class Chainer::Training::Extensions::ExponentialShift
Trainer extension to exponentially shift an optimizer attribute.
This extension exponentially increases or decreases the specified attribute of the optimizer. The typical use case is an exponential decay of the learning rate. This extension is also called before the training loop starts by default.
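The schedule described above can be sketched in plain Ruby without loading the library. The helper name exponential_shift is hypothetical and only mirrors the formula used by this extension (initial value times rate raised to the number of calls so far); the starting value 0.1 and rate 0.5 are illustrative assumptions.

```ruby
# Hypothetical helper, not part of red-chainer: the attribute value
# after the extension has been called t times.
def exponential_shift(init, rate, t)
  init * (rate**t)
end

# A learning rate that starts at 0.1 and is halved on every call.
schedule = (1..4).map { |t| exponential_shift(0.1, 0.5, t) }
puts schedule.inspect
```

A rate below 1 gives a decay, a rate above 1 gives growth, and a rate of exactly 1 leaves the attribute unchanged.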
Attributes
last_value[R]
Public Class Methods
new(attr, rate, init: nil, target: nil, optimizer: nil)
@param [string] attr Name of the attribute to shift.
@param [float] rate Rate of the exponential shift.
@param [float] init Initial value of the attribute.
@param [float] target Target value of the attribute.
@param [Chainer::Optimizer] optimizer Target optimizer to adjust the attribute.
# File lib/chainer/training/extensions/exponential_shift.rb, line 17
def initialize(attr, rate, init: nil, target: nil, optimizer: nil)
  @attr = attr
  raise 'ExponentialShift does not support negative rate' if rate < 0
  @rate = rate
  @init = init
  @target = target
  @optimizer = optimizer
  @t = 0
  @last_value = nil
end
Public Instance Methods
call(trainer)
# File lib/chainer/training/extensions/exponential_shift.rb, line 38
def call(trainer)
  @t += 1
  optimizer = get_optimizer(trainer)
  value = @init * (@rate ** @t)
  if @target
    if @rate > 1
      if value / @target > 1
        value = @target
      end
    else
      if value / @target < 1
        value = @target
      end
    end
  end
  update_value(optimizer, value)
end
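To make the effect of target easier to see, here is a stand-alone sketch of the value computation in call. The function name shifted_value is hypothetical, and the sample numbers are assumptions chosen for illustration; the ratio comparison is copied from the listing and behaves like a simple bound when value and target are both positive.

```ruby
# Hypothetical stand-alone version of the computation in `call`.
def shifted_value(init, rate, t, target = nil)
  value = init * (rate**t)
  return value if target.nil?

  if rate > 1
    # Growing attribute: stop at the target once it is exceeded.
    value = target if value / target > 1
  else
    # Decaying (or constant) attribute: never drop below the target.
    value = target if value / target < 1
  end
  value
end

puts shifted_value(0.1, 0.5, 2, 0.02)  # still above the target
puts shifted_value(0.1, 0.5, 3, 0.02)  # clamped to the target
puts shifted_value(1.0, 2.0, 3, 5.0)   # growth clamped to the target
```

Once the target is reached, every later call keeps returning the target, so the attribute stays fixed for the rest of training.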
init(trainer)
# File lib/chainer/training/extensions/exponential_shift.rb, line 28
def init(trainer)
  optimizer = get_optimizer(trainer)
  @init = optimizer.send(@attr) if @init.nil?
  if @last_value.nil?
    update_value(optimizer, @init)
  else
    update_value(optimizer, @last_value)
  end
end
serialize(serializer)
# File lib/chainer/training/extensions/exponential_shift.rb, line 57
def serialize(serializer)
  @t = serializer.('t', @t)
  @last_value = serializer.('last_value', @last_value)
  if Chainer.array?(@last_value)
    @last_value = @last_value[0]
  end
end
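The serializer.(key, value) form in serialize is Ruby shorthand for serializer.call(key, value). The sketch below illustrates that round trip with two lambdas standing in for a saving and a loading serializer. Both lambdas and the store hash are hypothetical; the assumption that a serializer returns the value it was given when saving and the stored value when loading is inferred from how serialize uses the return value, not from the serializer implementation.

```ruby
# Hypothetical stand-ins for a saving and a loading serializer.
store = {}
saver  = ->(key, value) { store[key] = value }       # returns the saved value
loader = ->(key, _current) { store.fetch(key) }      # returns the stored value

# Saving: state passes through unchanged and is recorded in the store.
t = 3
last_value = 0.0125
t = saver.('t', t)
last_value = saver.('last_value', last_value)

# Loading into a fresh extension: the defaults are replaced by stored state.
restored_t = loader.('t', 0)
restored_last_value = loader.('last_value', nil)

puts restored_t
puts restored_last_value
```

Because last_value is restored, init can reapply the most recent attribute value when training resumes instead of starting again from the initial value.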
Private Instance Methods
get_optimizer(trainer)
# File lib/chainer/training/extensions/exponential_shift.rb, line 67
def get_optimizer(trainer)
  @optimizer || trainer.updater.get_optimizer(:main)
end
update_value(optimizer, value)
# File lib/chainer/training/extensions/exponential_shift.rb, line 71
def update_value(optimizer, value)
  optimizer.send("#{@attr}=", value)
  @last_value = value
end
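Since the attribute is addressed by name, the extension reads it with send(attr) and writes it with send("attr=", value). A minimal sketch of that mechanism follows, using a Struct named FakeOptimizer as a hypothetical stand-in for a real optimizer; the only assumption is that the optimizer exposes a reader and a writer for the named attribute.

```ruby
# Hypothetical stand-in for an optimizer exposing an `lr` accessor.
FakeOptimizer = Struct.new(:lr)

optimizer = FakeOptimizer.new(0.1)
attr = 'lr'

# Read the current value by name, as `init` does.
current = optimizer.send(attr)

# Write the new value by name, as `update_value` does.
optimizer.send("#{attr}=", current * 0.5)

puts optimizer.lr
```

Any attribute with a matching writer can be shifted this way, which is why attr is passed as a string rather than being fixed to the learning rate.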