class DNN::Optimizers::Optimizer
Superclass of all optimizer classes.
Attributes
clip_norm[RW]
status[R]
Public Class Methods
from_hash(hash)
click to toggle source
# File lib/dnn/core/optimizers.rb, line 9
#
# Rebuilds an optimizer from a hash such as the one produced by #to_hash.
#
# @param [Hash | NilClass] hash Serialized optimizer. +hash[:class]+ is the
#   class name, resolved with +DNN.const_get+; the whole hash is then passed
#   to #load_hash of a freshly allocated instance.
# @return [DNN::Optimizers::Optimizer | NilClass] The restored optimizer, or
#   nil when +hash+ is nil.
# @raise [DNNError] When +hash[:class]+ does not name this class or a subclass.
def self.from_hash(hash)
  return nil unless hash

  optimizer_class = DNN.const_get(hash[:class])
  # Validate the resolved constant *before* allocating it: a name that points
  # at a module, a non-class constant, or an unrelated class now fails with
  # DNNError instead of being allocated (or raising NoMethodError/TypeError
  # from +allocate+).
  unless optimizer_class.is_a?(Class) && optimizer_class <= self
    raise DNNError, "#{optimizer_class} is not an instance of #{self} class."
  end

  optimizer = optimizer_class.allocate
  optimizer.load_hash(hash)
  optimizer
end
new(clip_norm: nil)
click to toggle source
@param [Float | NilClass] clip_norm
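Gradient clipping threshold. When set, gradients are rescaled so that their combined (global) L2 norm does not exceed this value; nil disables clipping.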
# File lib/dnn/core/optimizers.rb, line 19
#
# Creates an optimizer.
#
# @param [Float | NilClass] clip_norm Threshold for global gradient-norm
#   clipping; nil leaves gradients unclipped.
def initialize(clip_norm: nil)
  @clip_norm = clip_norm
end
Public Instance Methods
load_hash(hash)
click to toggle source
# File lib/dnn/core/optimizers.rb, line 59
#
# Restores this optimizer's settings from a serialized hash by re-running
# +initialize+ with the stored +:clip_norm+ value. Called by .from_hash on a
# freshly allocated instance.
#
# @param [Hash] hash Serialized optimizer data.
def load_hash(hash)
  saved_clip_norm = hash[:clip_norm]
  initialize(clip_norm: saved_clip_norm)
end
to_hash(merge_hash = nil)
click to toggle source
# File lib/dnn/core/optimizers.rb, line 39
#
# Serializes the optimizer's settings.
#
# @param [Hash | NilClass] merge_hash Extra entries merged over the base
#   hash (they win on key collisions); ignored when nil or false.
# @return [Hash] A hash with +:class+ (this object's class name),
#   +:clip_norm+, and any entries from +merge_hash+.
def to_hash(merge_hash = nil)
  serialized = { class: self.class.name, clip_norm: @clip_norm }
  return serialized unless merge_hash

  # merge! mutates and returns +serialized+ itself.
  serialized.merge!(merge_hash)
end
update(params)
click to toggle source
# File lib/dnn/core/optimizers.rb, line 23
#
# Performs one optimization step on +params+.
#
# The steps are:
# 1. If a clip norm is configured, clip the gradients (see #clip_grads).
# 2. Apply the subclass-specific update (see #update_params).
# 3. Reset each gradient to a new +Xumo::SFloat[0]+.
#
# @param [Array] params Parameters exposing +grad+ / +grad=+.
# @return [Array] +params+ itself.
def update(params)
  clip_grads(params) if @clip_norm
  update_params(params)
  # A separate SFloat is built per parameter, so no two params share a grad object.
  params.each { |prm| prm.grad = Xumo::SFloat[0] }
end
update_layers(layers)
click to toggle source
Updates the parameters of every trainable layer in the given list.
# File lib/dnn/core/optimizers.rb, line 32
#
# Collects the parameters of all trainable layers and runs one #update step.
#
# Only +Layers::TrainableLayer+ instances whose +trainable+ flag is truthy
# contribute. Nil entries and parameters without a gradient are skipped.
#
# @param [Array] layers Layers to update.
# @return The result of #update.
def update_layers(layers)
  trainable_layers = layers.select do |layer|
    layer.is_a?(Layers::TrainableLayer) && layer.trainable
  end
  # flatten (not flat_map) so nested arrays of params are fully expanded.
  param_groups = trainable_layers.map { |layer| layer.get_params.values }
  target_params = param_groups.flatten.compact.select(&:grad)
  update(target_params)
end
Private Instance Methods
clip_grads(params)
click to toggle source
# File lib/dnn/core/optimizers.rb, line 50
#
# Clips gradients by their global L2 norm.
#
# The norm is the square root of the sum of squared gradient elements across
# all +params+. If it exceeds +@clip_norm+, every gradient is scaled by
# +@clip_norm / (norm + 1e-7)+. The epsilon guards the division, and the
# scaling keeps the gradient directions.
#
# @param [Array] params Parameters exposing +grad+ / +grad=+.
# @return [NilClass | Array] nil when no clipping was needed, otherwise +params+.
def clip_grads(params)
  # Array#sum uses compensated summation for floats, unlike a manual reduce.
  norm = Math.sqrt(params.sum { |param| (param.grad**2).sum.to_f })
  return if norm <= @clip_norm

  rate = @clip_norm / (norm + 1e-7)
  params.each { |param| param.grad *= rate }
end
update_params(params)
click to toggle source
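Applies the optimizer-specific update to the given params. Subclasses must implement this; the base implementation raises NotImplementedError.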
# File lib/dnn/core/optimizers.rb, line 46
#
# Hook for the optimizer-specific parameter update, called by #update.
# This base version always raises; each concrete optimizer must override it.
#
# @param [Array] params Parameters to update.
# @raise [NotImplementedError] Always, in this base implementation.
def update_params(params)
  raise NotImplementedError, "Class '#{self.class.name}' must implement method 'update_params'"
end