class Chainer::Training::Trainer
Attributes
observation[RW]
out[RW]
stop_trigger[RW]
updater[RW]
Public Class Methods
new(updater, stop_trigger: nil, out: 'result')
click to toggle source
# File lib/chainer/training/trainer.rb, line 16 def initialize(updater, stop_trigger: nil, out: 'result') @updater = updater @stop_trigger = Chainer::Training::Util.get_trigger(stop_trigger) @observation = {} @out = out reporter = Reporter.new updater.get_all_optimizers().each do |(name, optimizer)| reporter.add_observer(name, optimizer.target) optimizer.target.namedlinks(skipself: true) do |suffix, observer| observer_name = name.to_s + suffix reporter.add_observer(observer_name, observer) end end @reporter = reporter @done = false @extensions = {} @start_at = nil @snapshot_elapsed_time = 0.0 @final_elapsed_time = nil updater.connect_trainer(self) end
Public Instance Methods
elapsed_time()
click to toggle source
# File lib/chainer/training/trainer.rb, line 42 def elapsed_time return @final_elapsed_time if @done raise "training has not been started yet" if @start_at.nil? Time.now.to_f - @start_at + @snapshot_elapsed_time.to_f end
extend(extension, name: nil, trigger: nil, priority: nil)
click to toggle source
# File lib/chainer/training/trainer.rb, line 49 def extend(extension, name: nil, trigger: nil, priority: nil) if name.nil? name = if extension.name extension.name elsif extension.default_name extension.default_name else raise ArgumentError, 'name is not given for the extension' end end raise 'the name "training" is prohibited as an extension name' if name == 'training' if trigger.nil? trigger = extension.methods.include?(:trigger) ? extension.trigger : [1, 'iteration'] end trigger = Chainer::Training::Util.get_trigger(trigger) if priority.nil? priority = extension.methods.include?(:priority) ? extension.priority : Extension::PRIORITY_READER end modified_name = name ordinal = 0 @extensions.each do |modified_name| ordinal += 1 modified_name = "#{name}_#{ordinal}" end extension.name = modified_name @extensions[modified_name] = ExtensionEntry.new(extension, priority, trigger) end
get_extension(name)
click to toggle source
# File lib/chainer/training/trainer.rb, line 83 def get_extension(name) if @extensions.keys.include?(name) @extensions[name].extension else raise "extension #{name} not found" end end
run()
click to toggle source
# File lib/chainer/training/trainer.rb, line 91 def run raise 'cannot run training loop multiple times' if @done FileUtils.mkdir_p(@out) extensions = @extensions.sort_by { |(_, e)| -e.priority }.map { |(name, extension)| [name, extension] } @start_at = Time.now.to_f extensions.each do |(_, entry)| initializer = entry.extension.methods.include?(:init) ? entry.extension.method(:init) : nil initializer.call(self) if initializer end update = @updater.method(:update) reporter = @reporter stop_trigger = @stop_trigger begin until stop_trigger.(self) do @observation = {} reporter.scope(@observation) do update.call extensions.each do |(name, entry)| entry.extension.(self) if entry.trigger.(self) end end end ensure extensions.each do |(_, entry)| finalize = entry.extension.methods.include?(:finalize) ? entry.extension.method(:finalize) : nil finalize.() if finalize end @updater.finalize() end @final_elapsed_time = @elapsed_time @done = true end
serialize(serializer)
click to toggle source
# File lib/chainer/training/trainer.rb, line 130 def serialize(serializer) updater.serialize(serializer['updater']) if @stop_trigger.respond_to?(:serialize) @stop_trigger.serialize(serializer['stop_trigger']) end s = serializer['extensions'] t = serializer['extension_triggers'] @extensions.each do |name, entry| if entry.extension.respond_to?(:serialize) entry.extension.serialize(s[name]) end if entry.trigger.respond_to?(:serialize) entry.trigger.serialize(t[name]) end end if serializer.is_a?(Chainer::Serializer) serializer.('_snapshot_elapsed_time', elapsed_time) else @snapshot_elapsed_time = serializer.('_snapshot_elapsed_time', 0.0) end end