class DNN::Optimizers::Nesterov

Public Class Methods

new(lr: 0.01, momentum: 0.9, clip_norm: nil)
Calls superclass method DNN::Optimizers::SGD::new
# File lib/dnn/core/optimizers.rb, line 100
# Builds a Nesterov optimizer by handing every option straight to the
# superclass initializer (DNN::Optimizers::SGD).
#
# @param lr [Numeric] learning rate; defaults to 0.01
# @param momentum [Numeric] momentum coefficient; defaults to 0.9
# @param clip_norm [Object, nil] forwarded unchanged to the superclass;
#   defaults to nil
def initialize(lr: 0.01, momentum: 0.9, clip_norm: nil)
  # Bare `super` re-passes this method's own arguments (lr, momentum,
  # clip_norm) with their current values.
  super
end

Private Instance Methods

update_params(params)
# File lib/dnn/core/optimizers.rb, line 104
        # Applies one momentum-based update step to every entry of +params+.
        #
        # For each param this method:
        # 1. fetches its velocity from @v, starting from zeros shaped like
        #    param.data when none is stored yet;
        # 2. computes step = param.grad * @lr;
        # 3. stores velocity * @momentum - step back into @v[param];
        # 4. sets param.data to
        #    (param.data + @momentum**2 * velocity) - (1 + @momentum) * step,
        #    where velocity is the value stored in step 3.
        #
        # NOTE(review): step 4 uses the already-updated velocity in the
        # @momentum**2 term; some Nesterov formulations use the pre-update
        # velocity there -- confirm this is the intended variant.
        #
        # @param params [Enumerable] objects responding to #data, #data= and #grad
        # @return [Enumerable] +params+ itself (the result of #each)
        def update_params(params)
          params.each do |param|
            velocity = @v[param] || Xumo::SFloat.zeros(*param.data.shape)
            step = param.grad * @lr

            velocity = velocity * @momentum - step
            @v[param] = velocity

            lookahead = param.data + @momentum**2 * velocity
            param.data = lookahead - (1 + @momentum) * step
          end
        end