class Mnist::Playground

Attributes

logger[RW]
strategies[RW]

Public Class Methods

new()
# File lib/mnist/playground.rb, line 10
# Builds a playground holding one instance of every strategy class.
#
# - +@strategies+ gets a fresh instance of each class returned by
#   +Mnist::Strategies::Base.descendants+.
# - +@results+ is a per-strategy list of guess records; each missing key
#   gets its own new array.
# - +@logger+ writes to STDOUT and starts at INFO level.
#
# NOTE(review): Class#descendants is not core Ruby in most versions;
# presumably ActiveSupport provides it here — confirm. It also only sees
# strategy classes that were already loaded when this runs.
def initialize
  @strategies = Mnist::Strategies::Base.descendants.map { |strategy_class| strategy_class.new }
  @logger = Logger.new(STDOUT)
  @results = Hash.new { |store, strategy| store[strategy] = [] }

  log_level(Logger::INFO)
end

Public Instance Methods

guess(dataset = Mnist::Dataset::TEST_DUMMY)
# File lib/mnist/playground.rb, line 30
# Asks every strategy to classify each sample in +dataset+ and records the
# outcome under that strategy in +@results+.
#
# Each record is a Hash with:
# - +:value+  — the expected label (first field of the line)
# - +:guess+  — what the strategy returned
# - +:result+ — whether the guess equals the label
#
# Records accumulate across calls; nothing here clears +@results+.
#
# @param dataset path to the dataset file (defaults to TEST_DUMMY)
# @return whatever #with_dataset returns
def guess(dataset = Mnist::Dataset::TEST_DUMMY)
  with_dataset(dataset) do |expected, pixels|
    strategies.each do |strategy|
      # Named +prediction+ so it does not shadow this method's name.
      prediction = strategy.guess(pixels)
      outcome = { value: expected, guess: prediction, result: prediction == expected }
      @results[strategy] << outcome
    end
  end
end
log_level(level)
# File lib/mnist/playground.rb, line 18
# Sets the minimum severity the playground's logger will emit.
#
# @param level a Logger severity, e.g. Logger::INFO
# @return the +level+ that was passed in (value of the assignment)
def log_level(level)
  target = @logger
  target.level = level
end
results()
# File lib/mnist/playground.rb, line 39
# Summarises the outcomes recorded by #guess, one entry per strategy.
#
# A record counts as a success when its +:result+ is truthy. All other
# records count as failures.
#
# @return [Hash] keyed by +strategy.name+. Each value is
#   +{ success: Integer, failure: Integer }+. If two strategies report the
#   same name, the later entry overwrites the earlier one.
def results
  @results.each_with_object({}) do |(strategy, outcomes), summary|
    # Only the sizes are needed, so count rather than partition into two
    # throwaway arrays.
    successes = outcomes.count { |outcome| outcome[:result] }
    summary[strategy.name] = { success: successes, failure: outcomes.size - successes }
  end
end
train(dataset = Mnist::Dataset::TRAIN_DUMMY)
# File lib/mnist/playground.rb, line 22
# Feeds every sample in +dataset+ to every strategy's +train+ method.
#
# For each line, the first field is passed as the label and the remaining
# fields as the data, in the order the lines appear in the file.
#
# @param dataset path to the dataset file (defaults to TRAIN_DUMMY)
# @return whatever #with_dataset returns
def train(dataset = Mnist::Dataset::TRAIN_DUMMY)
  with_dataset(dataset) do |label, pixels|
    strategies.each { |strategy| strategy.train(label, pixels) }
  end
end

Private Instance Methods

with_dataset(dataset) { |value, data, index| ... }
# File lib/mnist/playground.rb, line 48
# Reads +dataset+ line by line and yields each parsed sample.
#
# Each line is split on commas and every field is converted with +to_i+.
# The first field is yielded as +value+ and the rest as +data+.
#
# @param dataset path to the dataset file
# @yieldparam value [Integer] first field of the line
# @yieldparam data [Array<Integer>] remaining fields of the line
# @yieldparam index [Integer] zero-based position of the line in the file
# @return [Array<String>] the raw lines read from the file
def with_dataset(dataset)
  # File.readlines opens and closes the file itself. The previous
  # File.open(dataset).readlines left the descriptor open until GC.
  lines = File.readlines(dataset)
  total = lines.size

  logger.debug("Filename: #{dataset}")
  logger.debug("Dataset: #{total} entries")

  lines.each_with_index do |line, index|
    # Block form: the progress string is only built when DEBUG is enabled.
    # This matters because it runs once per sample.
    logger.debug { "#{((index.to_f + 1) / total * 100).round(2)}%" }
    value, *data = line.split(',').map(&:to_i)
    yield(value, data, index)
  end
end