class DNN::Optimizers::Adam
Attributes
alpha[RW]
amsgrad[R]
beta1[RW]
beta2[RW]
eps[RW]
Public Class Methods
new(alpha: 0.001, beta1: 0.9, beta2: 0.999, eps: 1e-7, amsgrad: false, clip_norm: nil)
click to toggle source
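@param [Float] alpha Base step size used to compute the learning rate. @param [Float] beta1 Decay rate of the moving average of the gradient (first moment). @param [Float] beta2 Decay rate of the moving average of the squared gradient (second moment). @param [Float] eps Small value that prevents division by zero. @param [Boolean] amsgrad Set to true to enable AMSGrad. @param clip_norm Passed through unchanged to DNN::Optimizers::Optimizer::new.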
Calls superclass method
DNN::Optimizers::Optimizer::new
# File lib/dnn/core/optimizers.rb, line 263
# Builds an Adam optimizer with empty moment buffers and a zeroed step counter.
#
# @param [Float] alpha Base step size; scaled on every update to obtain the learning rate.
# @param [Float] beta1 Decay rate applied to the running average of the gradient.
# @param [Float] beta2 Decay rate applied to the running average of the squared gradient.
# @param [Float] eps Small constant added under the square root of the update denominator.
# @param [Boolean] amsgrad When true, a third buffer (@s) is allocated for the AMSGrad variant.
# @param clip_norm Forwarded as-is to the superclass constructor.
def initialize(alpha: 0.001, beta1: 0.9, beta2: 0.999, eps: 1e-7, amsgrad: false, clip_norm: nil)
  super(clip_norm: clip_norm)

  # Hyper-parameters.
  @alpha = alpha
  @beta1 = beta1
  @beta2 = beta2
  @eps = eps
  @amsgrad = amsgrad

  # Number of update steps performed so far.
  @t = 0

  # Per-parameter buffers, filled lazily on the first update.
  @m = {}
  @v = {}
  # Always assign @s so that re-running initialize (see load_hash) resets it.
  @s = nil
  @s = {} if amsgrad

  @status = { m: @m, v: @v, s: @s }
end
Public Instance Methods
load_hash(hash)
click to toggle source
# File lib/dnn/core/optimizers.rb, line 302
# Re-initializes this optimizer from a settings hash (the shape produced by to_hash).
# Because it re-runs initialize, the step counter and moment buffers are reset as well.
#
# @param [Hash] hash Settings keyed by :alpha, :beta1, :beta2, :eps, :amsgrad and :clip_norm.
#   Other keys are ignored; a missing key is forwarded as nil rather than the default.
def load_hash(hash)
  option_names = %i[alpha beta1 beta2 eps amsgrad clip_norm]
  options = option_names.map { |name| [name, hash[name]] }.to_h
  initialize(**options)
end
to_hash()
click to toggle source
# File lib/dnn/core/optimizers.rb, line 277
# Serializes the optimizer's configuration into a plain Hash.
#
# @return [Hash] The class name under :class followed by the hyper-parameters
#   :alpha, :beta1, :beta2, :eps, :amsgrad and :clip_norm. Only configuration
#   is included; @t and the moment buffers are not part of the result.
def to_hash
  hyper_params = {
    alpha: @alpha,
    beta1: @beta1,
    beta2: @beta2,
    eps: @eps,
    amsgrad: @amsgrad,
    clip_norm: @clip_norm
  }
  # :class stays the first key, as in a single hash literal.
  { class: self.class.name }.merge(hyper_params)
end
Private Instance Methods
update_params(params)
click to toggle source
# File lib/dnn/core/optimizers.rb, line 284
# Performs one Adam update on every parameter in +params+.
#
# Increments the step counter, refreshes the running averages kept in @m
# (gradient) and @v (squared gradient) for each parameter, and subtracts the
# resulting step from param.data. With @amsgrad enabled, the denominator uses
# @s, the element-wise maximum of all @v values seen so far for that parameter.
#
# @param [Enumerable] params Objects responding to +data+, +data=+ and +grad+.
#   NOTE(review): data/grad are presumably Xumo arrays of the same shape - confirm.
def update_params(params)
  @t += 1
  # Step size with the bias-correction factors of both moments folded in.
  step = @alpha * Math.sqrt(1 - @beta2**@t) / (1 - @beta1**@t)
  zeros_like = ->(param) { Xumo::SFloat.zeros(*param.data.shape) }

  params.each do |param|
    # Buffers are created on first sight of a parameter.
    @m[param] ||= zeros_like.call(param)
    @v[param] ||= zeros_like.call(param)

    # Exponential moving averages written in incremental form.
    @m[param] = @m[param] + (1 - @beta1) * (param.grad - @m[param])
    @v[param] = @v[param] + (1 - @beta2) * (param.grad**2 - @v[param])

    second_moment =
      if @amsgrad
        @s[param] ||= zeros_like.call(param)
        @s[param] = Xumo::SFloat.maximum(@s[param], @v[param])
      else
        @v[param]
      end

    # eps is added inside the square root, not after it.
    param.data -= step * @m[param] / Xumo::NMath.sqrt(second_moment + @eps)
  end
end