class DNN::Optimizers::Optimizer

Superclass of all optimizer classes.

Attributes

clip_norm[RW]

Gradient clip norm. When set, gradients are clipped to this global L2 norm before each update; nil (the default) disables clipping.

status[R]

Public Class Methods

from_hash(hash)

Builds an optimizer from a hash produced by to_hash. Returns nil when hash is nil; raises DNNError when the hash describes a class that is not an optimizer.
# File lib/dnn/core/optimizers.rb, line 9
def self.from_hash(hash)
  return nil unless hash
  optimizer_class = DNN.const_get(hash[:class])
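  # Allocate without calling initialize; load_hash restores the state instead.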
  optimizer = optimizer_class.allocate
  raise DNNError, "#{optimizer.class} is not an instance of #{self} class." unless optimizer.is_a?(self)
  optimizer.load_hash(hash)
  optimizer
end
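
A hypothetical round-trip sketch: serialize an optimizer with to_hash and rebuild it with from_hash. It assumes the gem's SGD subclass accepts an lr keyword and restores clip_norm in its load_hash; adjust to the optimizer you actually use.

require "dnn"

opt = DNN::Optimizers::SGD.new(lr: 0.01, clip_norm: 5.0)
hash = opt.to_hash  # => { class: "DNN::Optimizers::SGD", clip_norm: 5.0, ... }
restored = DNN::Optimizers::Optimizer.from_hash(hash)
restored.clip_norm  # => 5.0
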
new(clip_norm: nil)

@param [Float | NilClass] clip_norm Gradient clip norm. Pass nil (the default) to disable gradient clipping.

# File lib/dnn/core/optimizers.rb, line 19
def initialize(clip_norm: nil)
  @clip_norm = clip_norm
end

Public Instance Methods

load_hash(hash)

Restores the optimizer settings from a hash produced by to_hash.
# File lib/dnn/core/optimizers.rb, line 59
def load_hash(hash)
  initialize(clip_norm: hash[:clip_norm])
end
to_hash(merge_hash = nil)

Serializes the optimizer settings into a hash. Subclasses pass their own settings through merge_hash, which is merged into the result.
# File lib/dnn/core/optimizers.rb, line 39
def to_hash(merge_hash = nil)
  hash = { class: self.class.name, clip_norm: @clip_norm }
  hash.merge!(merge_hash) if merge_hash
  hash
end
update(params)

Updates the given params: clips the gradients when clip_norm is set, applies update_params, then resets every gradient to zero.
# File lib/dnn/core/optimizers.rb, line 23
def update(params)
  clip_grads(params) if @clip_norm
  update_params(params)
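  # Zero the gradients after the update so they do not accumulate across steps.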
  params.each do |param|
    param.grad = Xumo::SFloat[0]
  end
end
update_layers(layers)

Updates the layers that have params: selects the trainable layers, gathers their params, and updates those that carry a gradient.

# File lib/dnn/core/optimizers.rb, line 32
def update_layers(layers)
  target_params = layers.select { |layer| layer.is_a?(Layers::TrainableLayer) && layer.trainable }
                        .map { |layer| layer.get_params.values }.flatten.compact
                        .select(&:grad)
  update(target_params)
end
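
During training the framework invokes this method itself. A hypothetical manual call, assuming model responds to layers with the model's layer list:

optimizer.update_layers(model.layers)  # updates only trainable layers whose params carry gradients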

Private Instance Methods

clip_grads(params)

Scales the gradients so that their global L2 norm does not exceed clip_norm.
# File lib/dnn/core/optimizers.rb, line 50
def clip_grads(params)
  norm = Math.sqrt(params.reduce(0) { |total, param| total + (param.grad**2).sum.to_f })
  return if norm <= @clip_norm
  rate = @clip_norm / (norm + 1e-7)
  params.each do |param|
    param.grad *= rate
  end
end
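
A standalone numeric sketch of the same global-norm clipping, using plain Numo arrays (assumes numo-narray is installed; the values are illustrative):

require "numo/narray"

grads = [Numo::SFloat[3, 4], Numo::SFloat[0, 12]]
clip_norm = 5.0
norm = Math.sqrt(grads.reduce(0) { |total, g| total + (g**2).sum.to_f })  # => 13.0
if norm > clip_norm
  rate = clip_norm / (norm + 1e-7)  # ~0.3846
  grads.map! { |g| g * rate }       # global norm is now ~5.0
end
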
update_params(params)

Updates the params. Abstract; every concrete optimizer must implement this method.

# File lib/dnn/core/optimizers.rb, line 46
def update_params(params)
  raise NotImplementedError, "Class '#{self.class.name}' must implement the method 'update_params'"
end
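
A minimal sketch of a concrete optimizer built on this base class: plain gradient descent. MySGD and the Struct stand-in are illustrative, not part of the gem; it assumes each param responds to data and grad, as DNN::Param does.

require "dnn"

class MySGD < DNN::Optimizers::Optimizer
  def initialize(lr: 0.01, clip_norm: nil)
    super(clip_norm: clip_norm)
    @lr = lr
  end

  def load_hash(hash)
    initialize(lr: hash[:lr], clip_norm: hash[:clip_norm])
  end

  def to_hash
    super(lr: @lr)
  end

  private def update_params(params)
    # Plain gradient descent: step each param against its gradient.
    params.each { |param| param.data -= param.grad * @lr }
  end
end

param = Struct.new(:data, :grad).new(Xumo::SFloat[1, 2], Xumo::SFloat[0.5, 0.5])
MySGD.new(lr: 0.1).update([param])
param.data  # => approximately [0.95, 1.95]
param.grad  # => Xumo::SFloat[0], reset by update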