class Chainer::GradientMethod
Public Class Methods
new()
click to toggle source
Calls superclass method
# File lib/chainer/gradient_method.rb, line 3
# Constructs a gradient-based optimizer.
#
# The superclass initializer is invoked explicitly with an empty argument
# list (not a bare `super`, which would forward this method's arguments).
# After it returns, a new Hyperparameter object is stored in @hyperparam.
def initialize
  super()
  @hyperparam = Hyperparameter.new
end
Public Instance Methods
call_hooks()
click to toggle source
# File lib/chainer/gradient_method.rb, line 24
# Runs every registered hook, in the insertion order of @hooks.
#
# Each hook is dispatched through _call_hook. Right after each one,
# reallocate_cleared_grads is run again. Presumably this is so that a hook
# that cleared a gradient cannot leave it nil for the next hook or the
# parameter update (NOTE(review): inferred intent, confirm against hooks).
#
# The hook list is snapshotted with Hash#values before iterating, so
# changes to @hooks made by a hook do not affect the current pass.
#
# @return [Array] the snapshot of hook objects that were run
def call_hooks
  registered = @hooks.values
  registered.each do |hook|
    _call_hook(hook)
    reallocate_cleared_grads
  end
end
create_update_rule()
click to toggle source
# File lib/chainer/gradient_method.rb, line 60
# Abstract factory for a parameter's update rule.
#
# GradientMethod#setup calls this once per parameter of the link being
# optimized and assigns the result to that parameter's update_rule.
# Concrete optimizers must override it.
#
# @raise [NotImplementedError] always, in this base class
def create_update_rule
  raise NotImplementedError
end
reallocate_cleared_grads()
click to toggle source
# File lib/chainer/gradient_method.rb, line 15
# Gives a zero-filled gradient to every initialized parameter whose
# gradient is currently nil.
#
# Parameters are enumerated with @target.namedparams(include_uninit: false),
# so uninitialized parameters are skipped. Each yielded item is destructured
# as a [name, param] pair.
#
# The zeros are created with +new_zeros+ on the parameter's own data array.
# This yields an array of the same class, shape and element type as the
# data. The previous implementation rebuilt an NArray from the splatted Ruby
# values of the data (`NArray.[](*param.data)`), which has two problems:
# - it copied every element through Ruby objects;
# - it let the element type be re-inferred from those Ruby numerics, so,
#   for example, single-precision data could get a double-precision gradient.
def reallocate_cleared_grads
  @target.namedparams(include_uninit: false) do |(_name, param)|
    next unless param.grad.nil?

    param.grad = param.data.new_zeros
  end
end
setup(link)
click to toggle source
Calls superclass method
Chainer::Optimizer#setup
# File lib/chainer/gradient_method.rb, line 8
# Attaches this optimizer to +link+.
#
# The superclass implementation (Chainer::Optimizer#setup) runs first.
# Then every parameter yielded by link.params gets its update_rule set to
# the result of create_update_rule. That factory is called once per
# parameter, so each parameter receives the result of a separate call.
#
# @param link the model whose parameters will be optimized
# @return whatever link.params returns when given a block
def setup(link)
  super(link)
  link.params { |param| param.update_rule = create_update_rule }
end
update(lossfun=nil, *args, **kwds)
click to toggle source
# File lib/chainer/gradient_method.rb, line 31
# Performs one optimization step.
#
# When +lossfun+ is given:
# 1. it is called with +args+ / +kwds+ to compute a loss;
# 2. the target's gradients are cleared (@target.cleargrads) or zeroed
#    (@target.zerograds);
# 3. loss.backward is called.
#
# The choice in step 2 uses the value of a +use_cleargrads+ method when this
# object responds to one, and defaults to cleargrads otherwise.
#
# In every case, the step then:
# - reallocates nil gradients;
# - runs the hooks;
# - increments the step counter @t;
# - calls update on each parameter of @target.
#
# Previously a lossfun called with no arguments at all was never invoked,
# leaving +loss+ nil and crashing on loss.backward. It is now called with
# no arguments.
#
# @param lossfun [#call, nil] callable returning an object that responds to #backward
# @param args positional arguments forwarded to +lossfun+
# @param kwds keyword arguments forwarded to +lossfun+
def update(lossfun = nil, *args, **kwds)
  if lossfun
    # respond_to? avoids building the full method list on every step.
    use_cleargrads = respond_to?(:use_cleargrads) ? self.use_cleargrads : true

    # Splat only what is present. On older Rubies, `**{}` could reach a
    # callable without keyword parameters as an extra positional Hash.
    loss =
      if !args.empty? && !kwds.empty?
        lossfun.(*args, **kwds)
      elsif !args.empty?
        lossfun.(*args)
      elsif !kwds.empty?
        lossfun.(**kwds)
      else
        lossfun.()
      end

    if use_cleargrads
      @target.cleargrads
    else
      @target.zerograds
    end
    loss.backward
  end

  reallocate_cleared_grads
  call_hooks

  @t += 1
  @target.params { |param| param.update }
end