class Chainer::Optimizers::MomentumSGDRule
Update rule for the classical momentum SGD optimizer.
Public Class Methods
new(parent_hyperparam: nil, lr: nil, mementum: nil)
click to toggle source
Calls superclass method
Chainer::UpdateRule::new
# File lib/chainer/optimizers/momentum_sgd.rb, line 5
#
# Builds a momentum SGD update rule.
#
# @param parent_hyperparam [Hyperparameter, nil] hyperparameter to pass to
#   Chainer::UpdateRule#initialize; when nil, a fresh Hyperparameter with
#   lr = 0.01 and momentum = 0.9 is passed instead.
# @param lr [Numeric, nil] learning rate stored on @hyperparam when given.
# @param mementum [Numeric, nil] misspelled alias of +momentum+, kept so that
#   existing callers keep working; ignored when +momentum+ is also given.
# @param momentum [Numeric, nil] momentum coefficient stored on @hyperparam
#   when given.
def initialize(parent_hyperparam: nil, lr: nil, mementum: nil, momentum: nil)
  hyperparam = Hyperparameter.new
  hyperparam.instance_variable_set('@lr', 0.01)
  hyperparam.instance_variable_set('@momentum', 0.9)
  # NOTE(review): @hyperparam is presumably created by the superclass
  # constructor from parent_hyperparam — confirm in Chainer::UpdateRule.
  super(parent_hyperparam: parent_hyperparam || hyperparam)

  momentum = mementum if momentum.nil?
  @hyperparam.instance_variable_set('@lr', lr) if lr
  # Must be '@momentum': update_core reads @hyperparam.momentum. The previous
  # '@mementum' target meant a caller-supplied value was silently ignored.
  @hyperparam.instance_variable_set('@momentum', momentum) if momentum
end
Public Instance Methods
init_state(param)
click to toggle source
# File lib/chainer/optimizers/momentum_sgd.rb, line 16
#
# Allocates the velocity buffer for +param+ and stores it under @state[:v].
# The buffer comes from param.data.new_zeros (presumably a zero array shaped
# like param.data — confirm against the array library in use).
#
# @param param parameter whose +data+ responds to +new_zeros+.
# @return the newly allocated zero buffer.
def init_state(param)
  velocity = param.data.new_zeros
  @state[:v] = velocity
end
update_core(param)
click to toggle source
# File lib/chainer/optimizers/momentum_sgd.rb, line 20
#
# Applies one classical momentum SGD step to +param+:
#   v    <- momentum * v - lr * grad
#   data <- data + v
# Returns early (nil) when +param+ has no gradient.
#
# @param param parameter responding to +grad+, +data+ and +data=+; its
#   velocity buffer is expected in @state[:v] (set up by init_state).
def update_core(param)
  grad = param.grad
  return if grad.nil?

  # `v *= x` is `v = v * x` in Ruby: it rebinds the local and never touches
  # @state[:v]. Writing the result back is what lets momentum accumulate
  # across steps instead of restarting from zero each time.
  v = @state[:v] * @hyperparam.momentum - @hyperparam.lr * grad
  @state[:v] = v
  param.data += v
end