class Chainer::Training::Extensions::ProgressBar
Public Class Methods
new(training_length: nil, update_interval: 100, bar_length: 50, out: STDOUT)
click to toggle source
# File lib/chainer/training/extensions/progress_bar.rb, line 8
#
# Sets up the progress bar state.
#
# @param training_length [Array, nil] a `[length, unit]` pair; when nil it is
#   resolved from the trainer's stop trigger on the first `call`
# @param update_interval [Integer] the bar is redrawn only on iterations that
#   are a multiple of this value
# @param bar_length [Integer] width of each bar in characters
# @param out [#write, #flush] destination stream for the rendered bar
def initialize(training_length: nil, update_interval: 100, bar_length: 50, out: STDOUT)
  @training_length = training_length
  @status_template = nil
  @update_interval = update_interval
  @bar_length = bar_length
  @out = out
  # Only IO-like objects expose `sync=`; any other writer that merely
  # implements `write`/`flush` is accepted as-is instead of raising here.
  @out.sync = true if @out.respond_to?(:sync=)
  @recent_timing = []
end
Public Instance Methods
call(trainer)
click to toggle source
# File lib/chainer/training/extensions/progress_bar.rb, line 18
#
# Redraws the progress display: a total bar, a current-epoch bar, a status
# line and a speed / estimated-time line (four lines in total).
#
# Nothing is drawn unless the updater's iteration is a multiple of
# `@update_interval`.
#
# @param trainer [Object] must respond to `updater` (which provides
#   `iteration`, `epoch_detail` and `bind`) and, when no training length was
#   given to the constructor, to `stop_trigger`
# @raise [TypeError] if the training length must be taken from the stop
#   trigger and that trigger is not an IntervalTrigger
# @return [Object, nil] not meaningful to callers
def call(trainer)
  if @training_length.nil?
    t = trainer.stop_trigger
    raise TypeError, "cannot retrieve the training length #{t.class}" unless t.is_a?(Chainer::Training::Triggers::IntervalTrigger)
    @training_length = [t.period, t.unit]
  end

  if @status_template.nil?
    @status_template = ERB.new("<%= sprintf('%10d', self.iteration) %> iter, <%= self.epoch %> epoch / #{@training_length[0]} #{@training_length[1]}s\n")
  end

  length, unit = @training_length
  iteration = trainer.updater.iteration

  # print the progress bar according to interval
  return unless iteration % @update_interval == 0

  epoch = trainer.updater.epoch_detail
  now = Time.now.to_f
  @recent_timing << [iteration, epoch, now]

  @out.write("\033[J")

  # Builds one bar. The number of '#' marks is clamped to 0..@bar_length so
  # that progress beyond 100% cannot produce a negative repeat count, which
  # String#* rejects with ArgumentError.
  render_bar = lambda do |fraction|
    filled = [[(fraction * @bar_length).to_i, 0].max, @bar_length].min
    ('#' * filled) + ('.' * (@bar_length - filled))
  end

  rate = (unit == 'iteration' ? iteration : epoch).to_f / length
  @out.write(sprintf(" total [%s] %6.2f%%\n", render_bar.call(rate), rate * 100))

  # Fractional part of epoch_detail = progress within the current epoch.
  epoch_rate = epoch - epoch.to_i
  @out.write(sprintf("this epoch [%s] %6.2f%%\n", render_bar.call(epoch_rate), epoch_rate * 100))

  status = @status_template.result(trainer.updater.bind)
  @out.write(status)

  # Speed is measured against the oldest sample kept in the timing window.
  old_t, old_e, old_sec = @recent_timing[0]
  span = now - old_sec
  if span.zero?
    speed_t = Float::INFINITY
    speed_e = Float::INFINITY
  else
    speed_t = (iteration - old_t) / span
    speed_e = (epoch - old_e) / span
  end

  estimated_time =
    if unit == 'iteration'
      (length - iteration) / speed_t
    else
      (length - epoch) / speed_e
    end

  # Format the remaining seconds as HH:MM:SS by plain arithmetic. Going
  # through a Time object would wrap around after 24 hours, and the minutes
  # field must not be confused with strftime's month directive.
  eta =
    if estimated_time.to_f.finite?
      remaining = [estimated_time, 0].max.to_i
      hours, rest = remaining.divmod(3600)
      minutes, seconds = rest.divmod(60)
      sprintf('%02d:%02d:%02d', hours, minutes, seconds)
    else
      # No usable speed measurement (e.g. zero measured progress).
      '--:--:--'
    end
  @out.write(sprintf("%10.5g iters/sec. Estimated time to finish: %s.\n", speed_t, eta))

  # move the cursor to the head of the progress bar
  @out.write("\033[4A") # TODO: Support Windows
  @out.flush

  # Keep the timing window bounded to the most recent 100 samples.
  @recent_timing.delete_at(0) if @recent_timing.size > 100
end
finalize()
click to toggle source
# File lib/chainer/training/extensions/progress_bar.rb, line 83
#
# Clears whatever the progress bar left on screen: writes the
# "erase from cursor to end of display" escape sequence to the output
# stream and flushes it.
def finalize
  erase_to_end = "\033[J" # TODO: Support Windows
  @out.write(erase_to_end)
  @out.flush
end