class Chainer::Training::StandardUpdater

Attributes

iteration [RW] — read/write counter, set to 0 by `new`, incremented by `update`, and restored by `serialize`.

Public Class Methods

new(iterator, optimizer, converter: nil, device: nil, loss_func: nil)
# File lib/chainer/training/standard_updater.rb, line 6
# Builds an updater around one or more iterators and optimizers.
#
# @param iterator [Hash, Object] a Hash of named iterators, or a single
#   iterator, which is stored under the +:main+ key.
# @param optimizer [Hash, Object] a Hash of named optimizers, or a single
#   optimizer, which is stored under the +:main+ key.
# @param converter [#call, nil] called as +converter.call(batch, device: device)+
#   on each batch; defaults to +Dataset::Convert.concat_examples+.
# @param device [Object, nil] forwarded to the converter as +device:+.
# @param loss_func [Object, nil] handed to +optimizer.update+; when nil the
#   main optimizer's +target+ is used instead.
def initialize(iterator, optimizer, converter: nil, device: nil, loss_func: nil)
  # Anything that is not already a Hash is treated as a single iterator /
  # optimizer. Checking for Hash (rather than for Dataset::Iterator) keeps
  # duck-typed iterators working and makes both arguments behave the same.
  @iterators = iterator.kind_of?(Hash) ? iterator : { main: iterator }
  @optimizers = optimizer.kind_of?(Hash) ? optimizer : { main: optimizer }

  @converter = converter || Dataset::Convert.method(:concat_examples)
  @loss_func = loss_func
  @device = device
  @iteration = 0
end

Public Instance Methods

epoch()
# File lib/chainer/training/standard_updater.rb, line 36
# Current epoch, as reported by the +:main+ iterator.
#
# @return [Object] whatever the main iterator's +epoch+ returns.
def epoch
  main_iterator = @iterators[:main]
  main_iterator.epoch
end
epoch_detail()
# File lib/chainer/training/standard_updater.rb, line 40
# Detailed epoch progress, delegated to the +:main+ iterator.
#
# NOTE(review): presumably a fractional epoch count; the exact value and
# type depend on the iterator implementation.
#
# @return [Object] whatever the main iterator's +epoch_detail+ returns.
def epoch_detail
  main_iterator = @iterators[:main]
  main_iterator.epoch_detail
end
finalize()
# File lib/chainer/training/standard_updater.rb, line 60
# Calls +finalize+ on every registered iterator, in insertion order.
#
# @return [Hash] the iterator table itself.
def finalize
  @iterators.each_value(&:finalize)
end
get_all_optimizers()
# File lib/chainer/training/standard_updater.rb, line 27
# Returns all registered optimizers keyed by name.
#
# @return [Hash] a shallow copy of the optimizer table. Adding or removing
#   entries in the result does not affect the updater.
def get_all_optimizers
  # Hash#to_h returns the receiver itself for a plain Hash, which would
  # expose the internal table; dup so callers get an independent Hash.
  @optimizers.to_h.dup
end
get_optimizer(name)
# File lib/chainer/training/standard_updater.rb, line 23
# Looks up a single optimizer by the name it was registered under.
#
# @param name [Object] key into the optimizer table (e.g. +:main+).
# @return [Object, nil] the optimizer, or nil if no optimizer is
#   registered under +name+.
def get_optimizer(name)
  table = @optimizers
  table[name]
end
serialize(serializer)
# File lib/chainer/training/standard_updater.rb, line 66
# Saves or restores the updater state through +serializer+.
#
# Each iterator is serialized under "iterator:<name>". Each optimizer is
# serialized under "optimizer:<name>", and its +target+ under
# "model:<name>". Finally the iteration counter is passed through
# +serializer.call('iteration', value)+ and replaced by its result.
#
# @param serializer [#[], #call] serializer/deserializer object.
# @return [Object] the new value of the iteration counter.
def serialize(serializer)
  @iterators.each_pair do |key, iter|
    iter.serialize(serializer["iterator:#{key}"])
  end

  @optimizers.each_pair do |key, opt|
    opt.serialize(serializer["optimizer:#{key}"])
    opt.target.serialize(serializer["model:#{key}"])
  end

  @iteration = serializer.call('iteration', @iteration)
end
update()
# File lib/chainer/training/standard_updater.rb, line 31
# Performs one training step and advances the iteration counter.
#
# The counter is only bumped after +update_core+ returns, so a step that
# raises is not counted.
#
# @return [Integer] the updated iteration count.
def update
  update_core
  @iteration = @iteration + 1
end
update_core()
# File lib/chainer/training/standard_updater.rb, line 44
# Runs one optimization step on the +:main+ iterator/optimizer pair.
#
# Pulls the next batch from the +:main+ iterator and converts it with the
# configured converter. The result is then handed to the +:main+ optimizer:
# an Array is splatted into positional arguments, a Hash into keyword
# arguments, and anything else is passed as a single argument.
def update_core
  minibatch = @iterators[:main].next
  converted = @converter.call(minibatch, device: @device)

  main_optimizer = @optimizers[:main]
  # Without an explicit loss function, the optimizer's own target is used.
  objective = @loss_func || main_optimizer.target

  if converted.kind_of?(Array)
    main_optimizer.update(objective, *converted)
  elsif converted.kind_of?(Hash)
    main_optimizer.update(objective, **converted)
  else
    main_optimizer.update(objective, converted)
  end
end