class Chainer::UpdateRule

Attributes

state[R]

Public Class Methods

new(parent_hyperparam:) click to toggle source
# File lib/chainer/optimizer.rb, line 57
# Sets up a fresh update rule whose hyperparameters hang off the given parent.
#
# @param parent_hyperparam object forwarded to Chainer::Hyperparameter.new
#   as its `parent:` argument
def initialize(parent_hyperparam:)
  # Hyperparameter object for this rule, linked to the parent.
  @hyperparam = Chainer::Hyperparameter.new(parent: parent_hyperparam)

  # `update` returns early while this flag is falsy.
  @enabled = true
  # Callables run by `update` (via #call) before `update_core`.
  @hooks = {}
  # Stays nil until `prepare` or `serialize` creates the state hash.
  @state = nil
  # Incremented once per enabled `update` call.
  @t = 0
end

Public Instance Methods

init_state(param) click to toggle source
# File lib/chainer/optimizer.rb, line 95
# Base implementation of the state initializer; it always raises.
#
# `prepare` and `serialize` call this right after resetting @state to an
# empty hash. Presumably concrete update rules override it to fill @state
# (confirm against the subclasses).
#
# @param param the parameter handed over by the caller (ignored here)
# @raise [NotImplementedError] always
def init_state(param)
  raise(NotImplementedError)
end
serialize(serializer) click to toggle source

Serializes the update rule state. Be careful: this method only saves/loads the state of the update rule. The parameters of the target link are not saved/loaded by this method, so you need to serialize the target link separately if you want to fully recover the training state, including parameters.

@param [Chainer::AbstractSerializer] serializer Serializer object.

# File lib/chainer/optimizer.rb, line 107
# Passes every state entry through the serializer and stores what it returns.
#
# Only @state is handled here; nothing else on the rule is touched.
#
# @param serializer callable invoked as serializer.(key_string, value);
#   its return value replaces the entry in @state
# @return [Hash, Array, nil] the state hash when state already existed, the
#   array of state keys after a load into fresh state, or nil when there is
#   no state and the serializer is not a Chainer::Deserializer
def serialize(serializer)
  unless @state.nil?
    # State already exists: round-trip each entry through the serializer.
    return @state.each do |name, value|
      @state[name] = serializer.call(name.to_s, value)
    end
  end

  # Nothing to save, and nothing to load into unless we are deserializing.
  return unless serializer.is_a?(Chainer::Deserializer)

  # Try to initialize the state so we know which entries to retrieve.
  @state = {}
  # dup is a shallow copy, so the copy references the very same @state hash;
  # whatever init_state stores into it on the copy is visible here as well.
  probe = dup
  # TODO(sonots): pass device from outside
  array_module = Chainer::Device.default.xm
  placeholder = array_module::SFloat.new(1)
  probe.init_state(Chainer::Variable.new(placeholder, grad: placeholder))

  # Iterate over a snapshot of the keys while overwriting the values.
  @state.keys.each do |name|
    @state[name] = serializer.call(name.to_s, nil)
  end
end
update(param) click to toggle source
# File lib/chainer/optimizer.rb, line 65
# Runs one update step for the given parameter.
#
# Does nothing (and returns nil) while the rule is disabled. Otherwise it
# bumps the step counter, prepares state when the parameter carries data,
# runs every registered hook, and finally delegates to update_core.
#
# @param param object responding to #data; passed on to prepare, the hooks
#   and update_core
# @return whatever update_core returns, or nil when the rule is disabled
def update(param)
  if @enabled
    @t += 1

    # State is only prepared for parameters that actually hold data.
    has_data = !param.data.nil?
    prepare(param) if has_data

    # Iterate over a snapshot of the hook callables.
    registered = @hooks.values
    registered.each { |hook| hook.call(param) }

    update_core(param)
  end
end
update_core(param) click to toggle source
# File lib/chainer/optimizer.rb, line 78
# Dispatches the update to the GPU or CPU implementation.
#
# The GPU variant is chosen when Chainer.get_array_module reports Cumo for
# the parameter; every other array module goes to the CPU variant.
#
# @param param the parameter to update; forwarded unchanged
# @return whatever the selected implementation returns
def update_core(param)
  on_cumo = Chainer.get_array_module(param) == Cumo
  on_cumo ? update_core_gpu(param) : update_core_cpu(param)
end
update_core_cpu() click to toggle source
# File lib/chainer/optimizer.rb, line 87
# CPU update step. This base version always raises.
#
# update_core invokes this method with the parameter as its single
# argument, so the stub has to accept one; with a zero-arity definition
# that call failed with ArgumentError before NotImplementedError could be
# raised. The argument is optional so zero-argument calls keep working.
#
# @param _param the parameter passed by update_core (ignored here)
# @raise [NotImplementedError] always
def update_core_cpu(_param = nil)
  raise NotImplementedError
end
update_core_gpu() click to toggle source
# File lib/chainer/optimizer.rb, line 91
# GPU update step. This base version always raises.
#
# update_core invokes this method with the parameter as its single
# argument, so the stub has to accept one; with a zero-arity definition
# that call failed with ArgumentError before NotImplementedError could be
# raised. The argument is optional so zero-argument calls keep working.
#
# @param _param the parameter passed by update_core (ignored here)
# @raise [NotImplementedError] always
def update_core_gpu(_param = nil)
  raise NotImplementedError
end

Private Instance Methods

prepare(param) click to toggle source
# File lib/chainer/optimizer.rb, line 130
# Makes sure the state hash exists, then drops every entry whose value is
# not an array according to Chainer.array?.
#
# @param param the parameter forwarded to init_state when state is created
# @return [Hash, nil] result of Hash#select!: the state hash when entries
#   were removed, nil when nothing changed
def prepare(param)
  needs_init = @state.nil?
  if needs_init
    # Reset to an empty hash first; init_state is then called with it in place.
    @state = {}
    init_state(param)
  end

  @state.select! do |_name, entry|
    Chainer.array?(entry)
  end
end