class CooCoo::Trainer::MomentumStochastic

Constants

DEFAULT_OPTIONS

Public Instance Methods

learn(network, input, expecting, rate, last_deltas, momentum, cost_function, hidden_state)
# File lib/coo-coo/trainer/momentum_stochastic.rb, line 46
# Runs one momentum step on a single example and applies it to +network+.
#
# Pipeline: forward pass -> cost derivative -> backprop -> weight update with
# (rate * backprop deltas) - (momentum * last_deltas).
#
# @param network the network being trained; must respond to #forward,
#   #prep_output_target, #final_output, #backprop and #update_weights!
# @param input the example's input, fed to network.forward
# @param expecting the raw expected output; converted with network.prep_output_target
# @param rate [Float] learning rate that scales the backprop deltas
# @param last_deltas the deltas returned by the previous call (#train passes 0.0
#   for the first example of each batch)
# @param momentum [Float] multiplier applied to +last_deltas+ before it is
#   subtracted from this step's deltas
# @param cost_function object responding to #derivative(target, output) and
#   #call(target, output)
# @param hidden_state state passed to network.forward and threaded through backprop
# @return [Array] [cost, hidden_state, deltas]. +deltas+ are the rate-scaled
#   backprop deltas only (the momentum term is NOT folded in); callers feed them
#   back as +last_deltas+ on the next call.
def learn(network, input, expecting, rate, last_deltas, momentum, cost_function, hidden_state)
  output, hidden_state = network.forward(input, hidden_state)
  target = network.prep_output_target(expecting)
  final_output = network.final_output(output)
  errors = cost_function.derivative(target, final_output)
  deltas, hidden_state = network.backprop(input, output, errors, hidden_state)
  deltas = CooCoo::Sequence[deltas] * rate
  # NOTE(review): the momentum term is subtracted here; confirm this sign against
  # update_weights!'s convention (classic momentum adds the previous step).
  network.update_weights!(input, output, deltas - last_deltas * momentum)
  [cost_function.call(target, final_output), hidden_state, deltas]
end
options()
Calls superclass method CooCoo::Trainer::Base#options
# File lib/coo-coo/trainer/momentum_stochastic.rb, line 11
# Builds this trainer's command-line option parser.
#
# Delegates to the superclass implementation with DEFAULT_OPTIONS, then
# registers one extra flag, --momentum, whose Float value is written to the
# options object handed to the block.
def options
  super(DEFAULT_OPTIONS) do |parser, opts|
    parser.on('--momentum FLOAT', Float,
              'Multiplier for the accumulated changes.') { |value| opts.momentum = value }
  end
end
train(options, &block)

@option options [Float] :momentum The damping factor applied when reusing the previous network change (default 1/30.0).

# File lib/coo-coo/trainer/momentum_stochastic.rb, line 20
# Trains +network+ with momentum updates, one mini-batch at a time.
#
# @param options [#to_h] training options:
#   :network       (required) the network to train
#   :data          (required) enumerable of [expecting, input] pairs
#   :learning_rate (default 1/3.0)
#   :batch_size    (default 1024) examples per batch, via each_slice
#   :cost_function (default CostFunctions::MeanSquare)
#   :momentum      (default 1/30.0) damping factor on the reuse of the
#                  previous network change
# @yieldparam stats [BatchStats] built after each batch from
#   (self, batch index, batch_size, elapsed seconds, summed batch cost)
# @raise [KeyError] if :network or :data is missing
def train(options, &block)
  options = options.to_h
  network = options.fetch(:network)
  training_data = options.fetch(:data)
  learning_rate = options.fetch(:learning_rate, 1 / 3.0)
  batch_size = options.fetch(:batch_size, 1024)
  cost_function = options.fetch(:cost_function, CostFunctions::MeanSquare)
  momentum = options.fetch(:momentum, 1 / 30.0)

  # Monotonic clock: wall-clock (Time.now) differences can jump or go negative
  # if the system clock is adjusted while training runs.
  started = Process.clock_gettime(Process::CLOCK_MONOTONIC)

  training_data.each_slice(batch_size).with_index do |batch, i|
    # Momentum state restarts at every batch boundary.
    last_delta = 0.0
    total_errs = batch.inject(nil) do |acc, (expecting, input)|
      # Each example starts from a fresh, empty hidden state.
      errs, _hidden_state, last_delta = learn(network, input, expecting, learning_rate,
                                              last_delta, momentum, cost_function, {})
      errs + (acc || 0)
    end

    if block
      elapsed = Process.clock_gettime(Process::CLOCK_MONOTONIC) - started
      # NOTE(review): reports the configured batch_size even for a final partial
      # batch; confirm whether BatchStats consumers expect batch.size instead.
      block.call(BatchStats.new(self, i, batch_size, elapsed, total_errs))
    end

    # Reset after the callback so its runtime is excluded from the next batch's timing.
    started = Process.clock_gettime(Process::CLOCK_MONOTONIC)
  end
end