class Chainer::Iterators::SerialIterator
Attributes
epoch[R]
is_new_epoch[R]
Public Class Methods
new(dataset, batch_size, repeat: true, shuffle: true, device: Chainer::Device.default)
click to toggle source
# File lib/chainer/iterators/serial_iterator.rb, line 6 def initialize(dataset, batch_size, repeat: true, shuffle: true, device: Chainer::Device.default) @dataset = dataset @batch_size = batch_size @repeat = repeat @shuffle = shuffle @device = device @xm = device.xm reset end
Public Instance Methods
epoch_detail()
click to toggle source
# File lib/chainer/iterators/serial_iterator.rb, line 56 def epoch_detail @epoch + @current_position.to_f / @dataset.size end
next()
click to toggle source
# File lib/chainer/iterators/serial_iterator.rb, line 17 def next raise StopIteration if !@repeat && @epoch > 0 @previous_epoch_detail = epoch_detail i = @current_position n = @dataset.size i_end = [i + @batch_size, n].min batch = @order[i...i_end].to_a.map { |index| @dataset[index] } if i_end >= n if @repeat rest = i_end - n unless @order.nil? @order = @order.class[*@order.to_a.shuffle] end if rest > 0 if @order.nil? batch = batch.append(@dataset[0...rest]) else batch = @dataset[0...rest].map { |index| @dataset[index] } end end @current_position = rest else @current_position = 0 end @epoch += 1 @is_new_epoch = true else @is_new_epoch = false @current_position = i_end end batch end
reset()
click to toggle source
# File lib/chainer/iterators/serial_iterator.rb, line 85 def reset if @shuffle order = @dataset.size.times.map(&:to_i).shuffle @order = @xm::Int64[*order] else order = @dataset.size.times.map(&:to_i) @order = @xm::Int64[*order] end @current_position = 0 @epoch = 0 @is_new_epoch = false @previous_epoch_detail = -1.0 end
serialize(serializer)
click to toggle source
# File lib/chainer/iterators/serial_iterator.rb, line 60 def serialize(serializer) @current_position = serializer.('current_position', @current_position) @epoch = serializer.('epoch', @epoch) @is_new_epoch = serializer.('is_new_epoch', @is_new_epoch) unless @order.nil? begin serializer.('order', @order) rescue KeyError serializer('_order', @order) end end begin @previous_epoch_detail = serializer.( 'previous_epoch_detail', @previous_epoch_detail) rescue KeyError # guess previous_epoch_detail for older version @previous_epoch_detail = @epoch + (@current_position - @batch_size) / @dataset.size if epoch_detail > 0 @previous_epoch_detail = [@previous_epoch_detail, 0.0].max else @previous_epoch_detail = -1.0 end end end