class Chainer::Training::StandardUpdater
Attributes
iteration [RW] — update-step counter: starts at 0 and is incremented by update()
Public Class Methods
new(iterator, optimizer, converter: nil, device: nil, loss_func: nil)
Source:
# File lib/chainer/training/standard_updater.rb, line 6
#
# Builds an updater that drives one or more iterators and optimizers.
#
# @param iterator [Dataset::Iterator, Hash] a single iterator (stored under
#   +:main+) or a Hash mapping names to iterators
# @param optimizer [Object, Hash] a single optimizer (stored under +:main+) or
#   a Hash mapping names to optimizers
# @param converter [#call, nil] turns a batch into update arguments; defaults
#   to Dataset::Convert.concat_examples
# @param device [Object, nil] forwarded to the converter as +device:+
# @param loss_func [Object, nil] loss passed to the optimizer; when nil,
#   update_core falls back to the main optimizer's +target+
def initialize(iterator, optimizer, converter: nil, device: nil, loss_func: nil)
  # Wrap anything that is not already a name => object Hash, mirroring the
  # optimizer handling below. Checking for Dataset::Iterator specifically left
  # duck-typed iterators unwrapped, so later @iterators[:main] lookups broke.
  iterator = { main: iterator } unless iterator.kind_of?(Hash)
  @iterators = iterator

  optimizer = { main: optimizer } unless optimizer.kind_of?(Hash)
  @optimizers = optimizer

  @converter = converter || Dataset::Convert.method(:concat_examples)
  @loss_func = loss_func
  @device = device
  @iteration = 0
end
Public Instance Methods
epoch()
Source:
# File lib/chainer/training/standard_updater.rb, line 36
#
# Current epoch, as reported by the iterator registered under +:main+.
def epoch
  main_iterator = @iterators[:main]
  main_iterator.epoch
end
epoch_detail()
Source:
# File lib/chainer/training/standard_updater.rb, line 40
#
# Fine-grained epoch progress, as reported by the iterator registered
# under +:main+.
def epoch_detail
  main_iterator = @iterators[:main]
  main_iterator.epoch_detail
end
finalize()
Source:
# File lib/chainer/training/standard_updater.rb, line 60
#
# Calls +finalize+ on every registered iterator (names are ignored).
# Returns the iterator collection, as the underlying +each+ does.
def finalize
  @iterators.each { |_name, iter| iter.finalize }
end
get_all_optimizers()
Source:
# File lib/chainer/training/standard_updater.rb, line 27
#
# Returns a Hash mapping optimizer names to optimizers.
#
# The result is a shallow copy, so adding or removing entries does not affect
# this updater. Hash#to_h alone returns the receiver itself for a plain Hash,
# which would expose the internal optimizer table to mutation.
#
# @return [Hash] name => optimizer
def get_all_optimizers
  @optimizers.to_h.dup
end
get_optimizer(name)
Source:
# File lib/chainer/training/standard_updater.rb, line 23
#
# Looks up the optimizer registered under +name+. Uses plain Hash#[] lookup,
# so a missing name yields nil (or the Hash's default, if one was set).
def get_optimizer(name)
  registry = @optimizers
  registry[name]
end
serialize(serializer)
Source:
# File lib/chainer/training/standard_updater.rb, line 66
#
# Passes every iterator, every optimizer and each optimizer's target to
# +serializer+, using per-name child serializers obtained via
# serializer["iterator:<name>"], ["optimizer:<name>"] and ["model:<name>"].
# Finally the iteration counter is passed through serializer.call and the
# result is stored back into @iteration (presumably so a deserializer can
# restore it — confirm against serializer implementations).
def serialize(serializer)
  @iterators.each do |key, iter|
    iter.serialize(serializer["iterator:#{key}"])
  end

  @optimizers.each do |key, opt|
    opt.serialize(serializer["optimizer:#{key}"])
    model = opt.target
    model.serialize(serializer["model:#{key}"])
  end

  @iteration = serializer.call('iteration', @iteration)
end
update()
Source:
# File lib/chainer/training/standard_updater.rb, line 31
#
# Performs one training step via update_core, then advances the iteration
# counter. Returns the new iteration count.
def update
  update_core
  @iteration = @iteration + 1
end
update_core()
Source:
# File lib/chainer/training/standard_updater.rb, line 44
#
# One training step: pulls the next batch from the +:main+ iterator, converts
# it with @converter (passing +device: @device+), and hands the result to the
# +:main+ optimizer's +update+. The loss is @loss_func when set, otherwise
# the optimizer's +target+.
def update_core
  examples = @iterators[:main].next
  inputs = @converter.call(examples, device: @device)
  main_optimizer = @optimizers[:main]
  objective = @loss_func || main_optimizer.target

  # Array results become positional args, Hash results become keyword args,
  # and anything else is passed through as a single argument.
  case inputs
  when Array then main_optimizer.update(objective, *inputs)
  when Hash then main_optimizer.update(objective, **inputs)
  else main_optimizer.update(objective, inputs)
  end
end