class CooCoo::Trainer::MomentumStochastic
Constants
- DEFAULT_OPTIONS
Public Instance Methods
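learn(network, input, expecting, rate, last_deltas, momentum, cost_function, hidden_state)
Performs one training step on a single example and returns its cost, the updated hidden state, and the rate-scaled deltas.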
# File lib/coo-coo/trainer/momentum_stochastic.rb, line 46
#
# Runs one forward/backward pass for a single example and applies a
# momentum-adjusted weight update to +network+.
#
# @param network the network being trained; must respond to #forward,
#   #prep_output_target, #final_output, #backprop and #update_weights!
# @param input the example's input
# @param expecting the expected output; converted with network.prep_output_target
# @param rate [Float] learning rate the backprop deltas are multiplied by
# @param last_deltas the rate-scaled deltas returned by the previous call
#   (#train passes 0.0 for the first example of each batch)
# @param momentum [Float] multiplier applied to +last_deltas+
# @param cost_function object responding to #derivative(target, output) and #call(target, output)
# @param hidden_state hidden state threaded through #forward and #backprop
# @return [Array] three elements: the cost for this example, the hidden state
#   returned by #backprop, and the rate-scaled deltas (to pass back as +last_deltas+)
#
# NOTE(review): the change applied is `deltas - last_deltas * momentum`, and only
# the un-accumulated `deltas` are returned. So the previous change is subtracted
# rather than added, and momentum never accumulates beyond one step. Whether this
# is intended depends on update_weights!'s sign convention. Confirm before changing.
def learn(network, input, expecting, rate, last_deltas, momentum, cost_function, hidden_state)
  output, hidden_state = network.forward(input, hidden_state)
  target = network.prep_output_target(expecting)
  final_output = network.final_output(output)
  errors = cost_function.derivative(target, final_output)
  deltas, hidden_state = network.backprop(input, output, errors, hidden_state)
  deltas = CooCoo::Sequence[deltas] * rate
  network.update_weights!(input, output, deltas - last_deltas * momentum)
  [cost_function.call(target, final_output), hidden_state, deltas]
end
options()
Calls the superclass method CooCoo::Trainer::Base#options and adds a --momentum FLOAT switch.
# File lib/coo-coo/trainer/momentum_stochastic.rb, line 11
#
# Extends the superclass option parser with a --momentum switch.
# DEFAULT_OPTIONS is handed to the superclass, whose block yields the parser
# and an options object; a parsed --momentum value (as a Float) is stored on
# that object. Returns whatever the superclass method returns.
def options
  super(DEFAULT_OPTIONS) do |parser, parsed|
    parser.on('--momentum FLOAT', Float, 'Multiplier for the accumulated changes.') { |value| parsed.momentum = value }
  end
end
train(options, &block)
Option :momentum (Float): the damping factor applied when reusing the previous network change.
# File lib/coo-coo/trainer/momentum_stochastic.rb, line 20
#
# Trains a network over the given data in batches. #learn is called once per
# example, and each example's scaled deltas feed the next example's momentum term.
#
# @param options [#to_h] training settings:
#   - :network (required) the network to train
#   - :data (required) enumerable of [expecting, input] pairs
#   - :learning_rate [Float] defaults to 1/3.0
#   - :batch_size [Integer] examples per batch; defaults to 1024
#   - :cost_function defaults to CostFunctions::MeanSquare
#   - :momentum [Float] the damping factor applied when reusing the previous
#     network change; defaults to 1/30.0
# @yield [stats] once per batch, when a block is given
# @yieldparam stats [BatchStats] built from the trainer, the batch index,
#   batch_size, elapsed seconds for the batch, and the summed cost
# @raise [KeyError] if :network or :data is missing
def train(options, &block)
  options = options.to_h
  network = options.fetch(:network)
  training_data = options.fetch(:data)
  learning_rate = options.fetch(:learning_rate, 1/3.0)
  batch_size = options.fetch(:batch_size, 1024)
  cost_function = options.fetch(:cost_function, CostFunctions::MeanSquare)
  momentum = options.fetch(:momentum, 1/30.0)

  # Monotonic clock, so batch timings are unaffected by wall-clock adjustments.
  started = Process.clock_gettime(Process::CLOCK_MONOTONIC)
  training_data.each_slice(batch_size).with_index do |batch, i|
    # Momentum is not carried across batch boundaries.
    last_delta = 0.0
    total_errs = batch.inject(nil) do |acc, (expecting, input)|
      # Every example starts from a fresh hidden state; the returned one is unused.
      errs, _hidden_state, last_delta = learn(network, input, expecting, learning_rate, last_delta,
                                              momentum, cost_function, {})
      errs + (acc || 0)
    end

    if block
      elapsed = Process.clock_gettime(Process::CLOCK_MONOTONIC) - started
      # NOTE(review): the final slice may contain fewer than batch_size examples,
      # yet batch_size is still reported. Confirm BatchStats expects the configured
      # size rather than batch.size.
      block.call(BatchStats.new(self, i, batch_size, elapsed, total_errs))
    end
    # Restart the timer after the callback so its run time is not charged to the next batch.
    started = Process.clock_gettime(Process::CLOCK_MONOTONIC)
  end
end