class Chainer::Optimizers::MomentumSGDRule

Update rule for classical momentum SGD (stochastic gradient descent with momentum).

Public Class Methods

new(parent_hyperparam: nil, lr: nil, mementum: nil)
Calls superclass method Chainer::UpdateRule::new
# File lib/chainer/optimizers/momentum_sgd.rb, line 5
#
# Builds a classical momentum SGD update rule.
#
# When +parent_hyperparam+ is nil, a fresh Hyperparameter with lr = 0.01 and
# momentum = 0.9 is passed to the superclass as +parent_hyperparam+ instead.
# Explicit +lr+ / +momentum+ values are then written onto +@hyperparam+
# (assumed to be set up by Chainer::UpdateRule#initialize — TODO confirm).
#
# @param parent_hyperparam [Hyperparameter, nil] hyperparameter handed to the superclass
# @param lr [Numeric, nil] learning-rate override; ignored when nil
# @param mementum [Numeric, nil] deprecated, misspelled alias of +momentum+,
#   kept so existing callers keep working; ignored when +momentum+ is given
# @param momentum [Numeric, nil] momentum-coefficient override; ignored when nil
def initialize(parent_hyperparam: nil, lr: nil, mementum: nil, momentum: nil)
  defaults = Hyperparameter.new
  defaults.instance_variable_set('@lr', 0.01)
  defaults.instance_variable_set('@momentum', 0.9)

  super(parent_hyperparam: parent_hyperparam || defaults)

  momentum ||= mementum
  @hyperparam.instance_variable_set('@lr', lr) if lr
  # Must be '@momentum': update_core reads @hyperparam.momentum. The previous
  # '@mementum' was never read, so a requested momentum was silently dropped.
  @hyperparam.instance_variable_set('@momentum', momentum) if momentum
end

Public Instance Methods

init_state(param)
# File lib/chainer/optimizers/momentum_sgd.rb, line 16
#
# Allocates the velocity buffer for +param+.
#
# The buffer is whatever +param.data.new_zeros+ returns, and it is stored
# under the +:v+ key of +@state+. That stored object is also the return value.
def init_state(param)
  zero_velocity = param.data.new_zeros
  @state[:v] = zero_velocity
end
update_core(param)
# File lib/chainer/optimizers/momentum_sgd.rb, line 20
#
# Applies one classical momentum SGD step to +param+:
#
#   v    <- momentum * v - lr * grad
#   data <- data + v
#
# Returns nil without touching anything when +param.grad+ is nil.
#
# @param param [#grad, #data, #data=] the parameter being updated
def update_core(param)
  grad = param.grad
  return if grad.nil?

  # In Ruby, `v *= x` is `v = v * x`: it rebinds the local to a new object
  # instead of mutating @state[:v]. The velocity must therefore be written
  # back explicitly, otherwise it never accumulates across steps and the
  # rule silently degrades to plain SGD.
  velocity = @state[:v] * @hyperparam.momentum - @hyperparam.lr * grad
  @state[:v] = velocity
  param.data += velocity
end