class Chainer::Optimizers::AdamRule
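Update rule for the Adam optimizer. Each instance keeps an exponential moving average of the gradient (state :m) and of the squared gradient (state :v) for a single parameter, and applies a bias-corrected, adaptively scaled step to that parameter's data.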

Public Class Methods

new(parent_hyperparam: nil, alpha: nil, beta1: nil, beta2: nil, eps: nil)
Calls superclass method Chainer::UpdateRule::new
# File lib/chainer/optimizers/adam.rb, line 4
def initialize(parent_hyperparam: nil, alpha: nil, beta1: nil, beta2: nil, eps: nil)
  hyperparam = Hyperparameter.new
  hyperparam.instance_variable_set('@alpha', 0.001)
  hyperparam.instance_variable_set('@beta1', 0.9)
  hyperparam.instance_variable_set('@beta2', 0.999)
  hyperparam.instance_variable_set('@eps', 1e-8)

  super(parent_hyperparam: parent_hyperparam || hyperparam)

  @hyperparam.instance_variable_set('@alpha', alpha) if alpha
  @hyperparam.instance_variable_set('@beta1', beta1) if beta1
  @hyperparam.instance_variable_set('@beta2', beta2) if beta2
  @hyperparam.instance_variable_set('@eps', eps) if eps
end
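The constructor builds a default Hyperparameter (alpha 0.001, beta1 0.9, beta2 0.999, eps 1e-8) unless a parent hyperparameter is supplied, then overrides individual values from the keyword arguments. A minimal construction sketch; the require name is an assumption about the gem entry point, and in practice the Adam optimizer normally creates these rules itself, passing its own hyperparameter object via parent_hyperparam:.

require 'chainer'  # assumed entry point for the red-chainer gem

# All defaults: alpha = 0.001, beta1 = 0.9, beta2 = 0.999, eps = 1e-8.
rule = Chainer::Optimizers::AdamRule.new

# Override only the step size; the other hyperparameters keep their defaults.
fast_rule = Chainer::Optimizers::AdamRule.new(alpha: 0.01)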

Public Instance Methods

init_state(param)
# File lib/chainer/optimizers/adam.rb, line 19
def init_state(param)
  @state[:m] = param.data.new_zeros
  @state[:v] = param.data.new_zeros
end
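init_state allocates the two moment buffers, the running mean of the gradient (:m) and of the squared gradient (:v), as zero arrays with the same shape and class as the parameter data. A Numo-only illustration of what new_zeros returns; param_data is a hypothetical stand-in for param.data:

require 'numo/narray'

param_data = Numo::SFloat.new(2, 3).seq  # hypothetical stand-in for param.data
m = param_data.new_zeros                 # 2x3 Numo::SFloat filled with zeros
v = param_data.new_zeros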
lr()
# File lib/chainer/optimizers/adam.rb, line 36
def lr
  fix1 = 1.0 - @hyperparam.beta1 ** @t
  fix2 = 1.0 - @hyperparam.beta2 ** @t
  @hyperparam.alpha * Math.sqrt(fix2) / fix1
end
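lr returns the bias-corrected step size for the current step count @t. A worked example with the default hyperparameters at the first step:

alpha, beta1, beta2 = 0.001, 0.9, 0.999
t = 1

fix1 = 1.0 - beta1 ** t         # 0.1
fix2 = 1.0 - beta2 ** t         # roughly 0.001
alpha * Math.sqrt(fix2) / fix1  # roughly 3.16e-4

As @t grows, both correction factors approach 1 and the effective rate settles at alpha.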
update_core(param)
# File lib/chainer/optimizers/adam.rb, line 24
def update_core(param)
  grad = param.grad
  return if grad.nil?

  hp = @hyperparam

  @state[:m] += (1 - hp.beta1) * (grad - @state[:m])
  @state[:v] += (1 - hp.beta2) * (grad * grad - @state[:v])
  xm = Chainer.get_array_module(grad)
  param.data -= lr * @state[:m] / (xm::NMath.sqrt(@state[:v]) + hp.eps)
end
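update_core skips parameters that received no gradient, folds the new gradient into the two moving averages, and subtracts a step scaled by lr and the adaptive denominator. A self-contained Numo sketch of the same arithmetic on a toy parameter; every name below is local to the example rather than red-chainer API:

require 'numo/narray'

alpha, beta1, beta2, eps = 0.001, 0.9, 0.999, 1e-8
t = 1

data = Numo::DFloat[1.0, 2.0, 3.0]   # parameter values
grad = Numo::DFloat[0.1, -0.2, 0.3]  # gradient for this step
m = data.new_zeros
v = data.new_zeros

# Moving averages of the gradient and of the squared gradient.
m += (1 - beta1) * (grad - m)
v += (1 - beta2) * (grad * grad - v)

# Bias-corrected step size, matching lr above.
lr = alpha * Math.sqrt(1.0 - beta2 ** t) / (1.0 - beta1 ** t)
data -= lr * m / (Numo::NMath.sqrt(v) + eps)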