class Chainer::Training::Extensions::Snapshot

Attributes

filename_proc[RW]
save_class[RW]
target[RW]

Public Class Methods

new(save_class: nil, filename_proc: nil, target: nil)
# File lib/chainer/training/extensions/snapshot.rb, line 15
# Configures a snapshot extension.
#
# The extension fires once per epoch (trigger [1, 'epoch']) with priority -100.
#
# @param save_class [#save_file, nil] serializer used to write the snapshot;
#   falls back to Chainer::Serializers::MarshalSerializer when not given.
# @param filename_proc [Proc, nil] called with the trainer to produce the file
#   name; defaults to "snapshot_iter_<iteration>".
# @param target [Object, nil] object to snapshot, stored as-is.
def initialize(save_class: nil, filename_proc: nil, target: nil)
  @target = target
  @trigger = [1, 'epoch']
  @priority = -100

  @save_class = save_class
  @save_class ||= Chainer::Serializers::MarshalSerializer

  @filename_proc = filename_proc
  @filename_proc ||= proc { |t| "snapshot_iter_#{t.updater.iteration}" }
end
snapshot(save_class: nil, &block)
# File lib/chainer/training/extensions/snapshot.rb, line 11
# Builds a Snapshot extension with no explicit target.
#
# @param save_class [#save_file, nil] serializer; nil lets +new+ pick its default.
# @yield [trainer] optional block returning the snapshot file name.
# @return [Snapshot]
def self.snapshot(save_class: nil, &block)
  options = { save_class: save_class, filename_proc: block }
  new(**options)
end
snapshot_object(target:, save_class:, &block)
# File lib/chainer/training/extensions/snapshot.rb, line 7
# Builds a Snapshot extension configured with +target+ as its snapshot target.
#
# +save_class+ is optional, so it matches +snapshot+ and +new+. Existing
# callers that pass it explicitly are unaffected.
#
# @param target [Object] object stored as the extension's target.
# @param save_class [#save_file, nil] serializer; nil (the default) lets +new+
#   choose its default serializer.
# @yield [trainer] optional block returning the snapshot file name.
# @return [Snapshot]
def self.snapshot_object(target:, save_class: nil, &block)
  self.new(save_class: save_class, filename_proc: block, target: target)
end

Public Instance Methods

call(trainer)
# File lib/chainer/training/extensions/snapshot.rb, line 23
# Writes a snapshot of the configured target, or of the trainer itself when
# no target was set, into +trainer.out+.
#
# The data is first written to a temporary file created inside +trainer.out+.
# It is then moved onto the final name, so a failure mid-write never leaves a
# truncated file under the snapshot's own name.
#
# @param trainer [Object] the trainer; +trainer.out+ is the output directory,
#   and the trainer is passed to +filename_proc+ to compute the file name.
def call(trainer)
  target = @target || trainer
  filename = filename_proc.call(trainer)
  out_dir = trainer.out

  # Tempfile.create takes basename and tmpdir as positional arguments.
  temp_file = Tempfile.create("tmp#{filename}", out_dir)
  temp_path = temp_file.path
  # The serializer writes by path, so release our descriptor immediately
  # instead of leaking it.
  temp_file.close

  begin
    save_class.save_file(temp_path, target)
    FileUtils.move(temp_path, File.join(out_dir, filename))
  ensure
    # No-op after a successful move; removes the partial file if saving failed.
    FileUtils.rm_f(temp_path)
  end
end