class Chainer::Training::Extensions::PrintReport
Public Class Methods
new(entries, log_report: 'LogReport', out: STDOUT)
click to toggle source
# File lib/chainer/training/extensions/print_report.rb, line 5
#
# Builds a reporter that prints selected entries of the accumulated log as a
# fixed-width text table.
#
# @param entries [Array<String>] observation keys to print, in column order.
# @param log_report [String, Chainer::Training::Extensions::LogReport] either
#   a LogReport instance, or the name under which the LogReport extension is
#   registered on the trainer.
# @param out [#write] destination that receives the header and table rows.
def initialize(entries, log_report: 'LogReport', out: STDOUT)
  @entries = entries
  @log_report = log_report
  @out = out

  @log_len = 0 # number of observations already printed

  # Each column is at least 10 characters wide, wider for long entry names.
  entry_widths = entries.map { |s| [10, s.size].max }

  templates = []
  header = []
  entries.zip(entry_widths).each do |entry, w|
    header << sprintf("%-#{w}s", entry)
    # A column spans w + 1 characters: the value padded to w, then one
    # separator space (matching the single-space join of the header below).
    # The placeholder used for a missing entry must span the same width,
    # otherwise every following column is shifted out of alignment.
    templates << [entry, "%-#{w}g ", ' ' * (w + 1)]
  end
  @header = header.join(' ') + "\n"
  @templates = templates
end
Public Instance Methods
call(trainer)
click to toggle source
# File lib/chainer/training/extensions/print_report.rb, line 25
#
# Prints the table header on the first invocation, then every log record
# that has been added since the previous invocation.
#
# If the report was given as a LogReport instance, it is invoked with the
# trainer first so its log is brought up to date. If it was given by name,
# it is looked up on the trainer and only read.
#
# @param trainer [Object] the trainer driving this extension; must respond
#   to #get_extension when the report is referenced by name.
# @raise [TypeError] if the configured log report is neither a String nor a
#   Chainer::Training::Extensions::LogReport.
def call(trainer)
  if @header
    @out.write(@header)
    @header = nil # the header is printed only once
  end

  # Start from the configured report. Previously the local was only assigned
  # in the String branch, so the LogReport branch invoked nil.
  log_report = @log_report
  if log_report.is_a?(String)
    log_report = trainer.get_extension(log_report)
  elsif log_report.is_a?(Chainer::Training::Extensions::LogReport)
    log_report.call(trainer) # update the log before reading it
  else
    raise TypeError, "log report has a wrong type #{log_report.class}"
  end

  log = log_report.log
  while log.size > @log_len
    @out.write("\033[J") # erase from the cursor to the end of the screen
    print(log[@log_len])
    @log_len += 1
  end
end
serialize(serializer)
click to toggle source
# File lib/chainer/training/extensions/print_report.rb, line 47
#
# Serializes the state of the held LogReport, if this extension holds a
# LogReport instance. When the report is referenced by name (a String),
# nothing is serialized.
#
# @param serializer [#[]] serializer whose '_log_report' child receives the
#   LogReport state.
# @return [Object, nil] the result of LogReport#serialize, or nil when no
#   LogReport instance is held.
def serialize(serializer)
  report = @log_report
  return unless report.is_a?(Chainer::Training::Extensions::LogReport)

  report.serialize(serializer['_log_report'])
end
Private Instance Methods
print(observation)
click to toggle source
# File lib/chainer/training/extensions/print_report.rb, line 55
#
# Writes one table row for +observation+ to the output stream.
#
# Each configured entry present in the observation is formatted with its
# column template. A missing entry is replaced with the blank placeholder
# stored alongside that template. The row is terminated by a newline.
#
# @param observation [Hash] one log record, keyed by entry name.
def print(observation)
  row = @templates.map do |entry, template, empty|
    # Hash#key? is a direct lookup; keys.include? built and scanned an array.
    observation.key?(entry) ? sprintf(template, observation[entry]) : empty
  end
  @out.write(row.join + "\n")
end