class Mnist::Loader
Constants
- IMAGE_FILE_MAGIC
- LABEL_FILE_MAGIC
Attributes
filename_image[R]
filename_label[R]
Public Class Methods
new(filename_image, filename_label, one_hot)
click to toggle source
# File lib/mnist-learn.rb, line 39
#
# Builds a loader for a pair of data files. Nothing is opened or read here;
# the files are opened lazily by the private input_images / input_labels
# helpers.
#
# @param filename_image [String] path handed to File.open for the image file
# @param filename_label [String] path handed to File.open for the label file
# @param one_hot [Boolean] when truthy, labels are expanded with
#   one_hot_transform by the batch/label readers
def initialize(filename_image, filename_label, one_hot)
  @filename_image, @filename_label, @one_hot = filename_image, filename_label, one_hot
  # Cursor used by #next to walk the data set sequentially.
  @index = 0
end
Public Instance Methods
images()
click to toggle source
# File lib/mnist-learn.rb, line 64
#
# Returns every image in the image file: the third element of load_images,
# i.e. one array of unpacked bytes per image. The result is memoized, so
# the file is only parsed on the first call.
def images
  return @all_images if @all_images

  _nrows, _ncols, pixel_arrays = load_images
  @all_images = pixel_arrays
end
labels()
click to toggle source
# File lib/mnist-learn.rb, line 68
#
# Returns every label in the label file, memoized after the first call.
# With one_hot enabled each value from load_labels is expanded through
# one_hot_transform; otherwise load_labels' result is returned as-is.
def labels
  return @all_labels if @all_labels

  @all_labels =
    if @one_hot
      load_labels.map { |value| one_hot_transform(value) }
    else
      load_labels
    end
end
load_images()
click to toggle source
# File lib/mnist-learn.rb, line 48
#
# Parses the whole image file. It verifies the magic number, reads the
# image count and the two size fields (rows, columns), then reads every
# image.
#
# input_images memoizes a single file handle, so the stream is rewound
# first. Without this, a second call (e.g. #images followed by #next, or
# #next followed by #next_batch) starts at EOF and fails.
#
# Side effect: sets @total_count to the image count read from the header.
#
# @return [Array] [nrows, ncols, images], where images holds one array of
#   unpacked bytes per image
# @raise InvalidMagic (from check_magic) when the header magic differs from
#   IMAGE_FILE_MAGIC
def load_images
  io = input_images
  io.rewind
  check_magic(io, IMAGE_FILE_MAGIC)
  @total_count = read_total_count(io)
  nrows, ncols = read_image_size(io)
  images = Array.new(@total_count) { read_image(nrows, ncols) }
  [nrows, ncols, images]
end
load_labels()
click to toggle source
# File lib/mnist-learn.rb, line 58
#
# Parses the whole label file. It verifies the magic number, reads the
# label count, then reads that many labels via read_labels.
#
# input_labels memoizes a single file handle, so the stream is rewound
# first. Without this, a second call (e.g. #labels followed by #next)
# starts at EOF and fails.
#
# Side effect: sets @total_count to the label count read from the header.
#
# @return [Array] the values returned by read_labels
# @raise InvalidMagic (from check_magic) when the header magic differs from
#   LABEL_FILE_MAGIC
def load_labels
  io = input_labels
  io.rewind
  check_magic(io, LABEL_FILE_MAGIC)
  @total_count = read_total_count(io)
  read_labels(io, @total_count)
end
next(batch_size)
click to toggle source
# File lib/mnist-learn.rb, line 72
#
# Returns the next sequential batch of up to +batch_size+ samples and
# advances the @index cursor. Each image byte is scaled by 1/255.0. Each
# label becomes a one-hot vector when one_hot is set, and a Float
# otherwise. Once the cursor reaches the end of the data, empty arrays are
# returned.
#
# Data is loaded on first use only. The previous guard (`@index == 0`)
# reloaded whenever the cursor was still at zero, e.g. after `next(0)` or
# on an empty data set, re-reading files that had already been consumed.
#
# @param batch_size [Integer] maximum number of samples to return
# @return [Array] [images, labels] for this batch
def next(batch_size)
  unless @images
    @rows, @columns, @images = load_images
    @labels = load_labels
  end

  images = []
  labels = []
  batch_size.times do
    # The cursor only moves forward, so once it passes the end no later
    # iteration can yield a sample.
    break if @index >= @total_count

    image_data = @images[@index]
    label_data = @labels[@index]
    # Scaled in place; each sample is visited once because @index only grows.
    image_data.map! { |b| b.to_f / 255.0 }
    @index += 1
    images << image_data
    labels << (@one_hot ? one_hot_transform(label_data) : label_data.to_f)
  end
  [images, labels]
end
next_batch(batch_size, rnd: Random.new)
click to toggle source
# File lib/mnist-learn.rb, line 91
#
# Returns a random batch of up to +batch_size+ samples.
#
# On first use the whole data set is loaded and cached as
# [scaled_image, label] pairs. Images are scaled by 1/255.0. Labels are
# one-hot when one_hot is set, and Floats otherwise. Every call reshuffles
# the cached pairs in place with +rnd+ and takes the first +batch_size+.
#
# The cache is built with its own positional index rather than the @index
# cursor used by #next. Reusing @index made this method index past the end
# (and crash) after any prior #next call. It also left #next's cursor at
# the end of the data.
#
# @param batch_size [Integer] number of samples to draw
# @param rnd [Random] random source for the shuffle
# @return [Array] [images, labels] for this batch
def next_batch(batch_size, rnd: Random.new)
  @data_set ||= begin
    _rows, _columns, images = load_images
    labels = load_labels
    Array.new(images.size) do |i|
      # Scaled in place: the arrays come fresh from load_images.
      image_data = images[i].map! { |b| b.to_f / 255.0 }
      label_data = labels[i]
      [image_data, (@one_hot ? one_hot_transform(label_data) : label_data.to_f)]
    end
  end
  @data_set.shuffle!(random: rnd)
  batch = @data_set[0...batch_size]
  [batch.map(&:first), batch.map(&:last)]
end
Private Instance Methods
check_magic(input_file, expected_magic)
click to toggle source
# File lib/mnist-learn.rb, line 116
#
# Reads the magic number at the current position of +input_file+ (via
# read_magic) and compares it with +expected_magic+ using ==.
#
# @param input_file stream positioned at the magic number
# @param expected_magic the header value the file must carry
# @return [nil] when the magic matches
# @raise [InvalidMagic] when the value read differs from +expected_magic+
def check_magic(input_file, expected_magic)
  found = read_magic(input_file)
  return if found == expected_magic

  raise InvalidMagic, "Expected #{expected_magic}, but #{found} is given"
end
input_images()
click to toggle source
# File lib/mnist-learn.rb, line 149
#
# Lazily opens the image file and memoizes the handle, so every read shares
# one stream position.
#
# The file holds binary data (it is decoded with String#unpack), so it is
# opened in binary mode ("rb"). Text mode tags reads with the default
# external encoding and applies newline conversion on platforms such as
# Windows, which corrupts the pixel bytes.
#
# NOTE(review): nothing visible here ever closes this handle.
def input_images
  @input_images ||= File.open(filename_image, 'rb')
end
input_labels()
click to toggle source
# File lib/mnist-learn.rb, line 153
#
# Lazily opens the label file and memoizes the handle, so every read shares
# one stream position.
#
# The file holds binary data (it is decoded with String#unpack), so it is
# opened in binary mode ("rb"). Text mode tags reads with the default
# external encoding and applies newline conversion on platforms such as
# Windows, which corrupts the label bytes.
#
# NOTE(review): nothing visible here ever closes this handle.
def input_labels
  @input_labels ||= File.open(filename_label, 'rb')
end
one_hot_transform(label)
click to toggle source
# File lib/mnist-learn.rb, line 110
#
# Converts a class index into a one-hot Float vector.
#
# @param label [Integer] class index, expected within 0...num_classes
# @param num_classes [Integer] vector length; defaults to 10, the value
#   previously hard-coded
# @return [Array<Float>] all 0.0 except 1.0 at position +label+
# @raise [ArgumentError] if +label+ is outside 0...num_classes. Such labels
#   used to be accepted silently: a too-large label grew the array and
#   padded it with nils, and a negative label set an element counted from
#   the end.
def one_hot_transform(label, num_classes = 10)
  unless (0...num_classes).cover?(label)
    raise ArgumentError, "label #{label.inspect} is out of range 0...#{num_classes}"
  end

  Array.new(num_classes, 0.0).tap { |vec| vec[label] = 1.0 }
end
read_image(nrows, ncols)
click to toggle source
# File lib/mnist-learn.rb, line 145
#
# Reads one image of nrows * ncols bytes from the image stream
# (input_images) and unpacks it into an array of unsigned 8-bit integers.
#
# @param nrows [Integer] rows per image
# @param ncols [Integer] columns per image
# @return [Array<Integer>] one value per byte read
def read_image(nrows, ncols)
  pixel_count = nrows * ncols
  raw = input_images.read(pixel_count)
  raw.unpack("C*")
end
read_image_size(input_file)
click to toggle source
# File lib/mnist-learn.rb, line 139
#
# Reads the two 32-bit header fields that give the image dimensions.
# load_images destructures them as rows, then columns.
#
# @param input_file stream positioned at the size fields
# @return [Array<Integer>] [nrows, ncols]
def read_image_size(input_file)
  dimensions = read_uint32(input_file, 2)
  dimensions
end
read_magic(input_file)
click to toggle source
# File lib/mnist-learn.rb, line 131
#
# Reads the single 32-bit magic-number field at the current stream position.
#
# @param input_file stream positioned at the magic number
# @return [Integer, nil] the decoded value (first element of read_uint32)
def read_magic(input_file)
  magic, = read_uint32(input_file)
  magic
end
read_total_count(input_file)
click to toggle source
# File lib/mnist-learn.rb, line 135
#
# Reads the single 32-bit item-count field at the current stream position.
#
# @param input_file stream positioned at the count field
# @return [Integer, nil] the decoded count (first element of read_uint32)
def read_total_count(input_file)
  read_uint32(input_file)[0]
end
read_uint32(input_file, n=1)
click to toggle source
# File lib/mnist-learn.rb, line 127
#
# Reads +n+ consecutive 32-bit unsigned big-endian integers ('N' directive)
# from +input_file+.
#
# @param input_file stream to read from
# @param n [Integer] number of integers to read
# @return [Array<Integer>] the +n+ decoded values
# @raise [EOFError] if fewer than 4 * n bytes remain. Previously a read at
#   end of file returned nil and failed with NoMethodError, and a partial
#   read silently returned fewer than +n+ values.
def read_uint32(input_file, n = 1)
  byte_count = 4 * n
  bytes = input_file.read(byte_count)
  got = bytes ? bytes.bytesize : 0
  raise EOFError, "expected #{byte_count} bytes, got #{got}" if got < byte_count

  bytes.unpack('N*')
end
read_uint8(input_file, n=1)
click to toggle source
# File lib/mnist-learn.rb, line 123
#
# Reads +n+ unsigned bytes from +input_file+. The RDoc page lists this
# method as aliased to read_labels, which load_labels calls to read every
# label in one go.
#
# @param input_file stream to read from
# @param n [Integer] number of bytes to read
# @return [Array<Integer>] the +n+ byte values
# @raise [EOFError] if fewer than +n+ bytes remain. Previously a read at end
#   of file returned nil and failed with NoMethodError, and a partial read
#   silently returned fewer than +n+ values.
def read_uint8(input_file, n = 1)
  bytes = input_file.read(n)
  got = bytes ? bytes.bytesize : 0
  raise EOFError, "expected #{n} bytes, got #{got}" if got < n

  bytes.unpack('C*')
end
Also aliased as: read_labels