class Mnist::Loader

Constants

IMAGE_FILE_MAGIC
LABEL_FILE_MAGIC

Attributes

filename_image[R]
filename_label[R]

Public Class Methods

new(filename_image, filename_label, one_hot) click to toggle source
# File lib/mnist-learn.rb, line 39
# Builds a loader for one image file / label file pair.
#
# Nothing is opened or read here; the paths are only remembered.
#
# filename_image - path of the image file.
# filename_label - path of the label file.
# one_hot        - when truthy, labels are handed out as one-hot vectors
#                  instead of plain numbers.
def initialize(filename_image, filename_label, one_hot)
  # Position of the sequential cursor within the data set.
  @index = 0
  @one_hot = one_hot
  @filename_image, @filename_label = filename_image, filename_label
end

Public Instance Methods

images() click to toggle source
# File lib/mnist-learn.rb, line 64
# Returns the list of all images, each one an array of raw byte values.
#
# The image file is parsed on the first call only; the result is cached and
# the same array is returned afterwards.
def images
  return @all_images if @all_images

  # load_images answers [rows, columns, images]; only the last part is kept.
  _rows, _columns, pixels = load_images
  @all_images = pixels
end
labels() click to toggle source
# File lib/mnist-learn.rb, line 68
# Returns the list of all labels, cached after the first call.
#
# When the loader was created with a truthy one_hot flag every label is
# converted with one_hot_transform; otherwise the values read from the file
# are returned unchanged.
def labels
  return @all_labels if @all_labels

  raw_labels = load_labels
  @all_labels =
    if @one_hot
      raw_labels.map { |digit| one_hot_transform(digit) }
    else
      raw_labels
    end
end
load_images() click to toggle source
# File lib/mnist-learn.rb, line 48
# Parses the whole image file.
#
# Reads, in order: the magic number (validated against IMAGE_FILE_MAGIC), the
# number of images, the image size (rows, columns) and then that many images.
# As a side effect @total_count is set to the image count.
#
# Returns [nrows, ncols, images] where images holds one array of byte values
# per image.
# Raises whatever check_magic raises when the magic number does not match.
def load_images
  io = input_images
  # The handle is memoized and shared between calls, so start from the
  # beginning every time; otherwise a second load would read the "magic
  # number" from wherever the previous load stopped.
  io.rewind
  check_magic(io, IMAGE_FILE_MAGIC)
  @total_count = read_total_count(io)
  nrows, ncols = read_image_size(io)
  images = Array.new(@total_count) { read_image(nrows, ncols) }
  [nrows, ncols, images]
end
load_labels() click to toggle source
# File lib/mnist-learn.rb, line 58
# Parses the whole label file.
#
# Reads, in order: the magic number (validated against LABEL_FILE_MAGIC), the
# number of labels and then that many one-byte labels. As a side effect
# @total_count is set to the label count.
#
# Returns the labels as an array of integers.
# Raises whatever check_magic raises when the magic number does not match.
def load_labels
  io = input_labels
  # The handle is memoized and shared between calls; rewind so that a repeated
  # load parses the header again instead of reading past the end of the data.
  io.rewind
  check_magic(io, LABEL_FILE_MAGIC)
  @total_count = read_total_count(io)
  read_labels(io, @total_count)
end
next(batch_size) click to toggle source
# File lib/mnist-learn.rb, line 72
# Returns the next batch_size samples in file order.
#
# The data is loaded on first use. Every pixel of a returned image is
# converted to a Float and divided by 255.0; a label is either passed through
# one_hot_transform (one_hot mode) or converted with to_f. The internal cursor
# @index advances by the number of samples handed out.
#
# batch_size - maximum number of samples to return.
#
# Returns [images, labels]. Both arrays hold fewer than batch_size entries
# (possibly none) once the data set is exhausted.
def next(batch_size)
  # Load only while nothing has been consumed AND nothing is loaded yet;
  # checking the cursor alone would re-read both files after a call that
  # returned no samples (e.g. batch_size == 0).
  if @index == 0 && (@images.nil? || @labels.nil?)
    @rows, @columns, @images = load_images
    @labels = load_labels
  end

  images = []
  labels = []
  batch_size.times do
    # Nothing left: stop instead of spinning through the remaining iterations.
    break if @index >= @total_count

    image_data = @images[@index]
    label_data = @labels[@index]
    # In-place normalisation; each stored image is visited only once.
    image_data.map! { |b| b.to_f / 255.0 }
    @index += 1
    images << image_data
    labels << (@one_hot ? one_hot_transform(label_data) : label_data.to_f)
  end
  [images, labels]
end
next_batch(batch_size, rnd: Random.new) click to toggle source
# File lib/mnist-learn.rb, line 91
# Returns a random batch of samples.
#
# On first use the complete data set is loaded and prepared: every pixel is
# converted to a Float and divided by 255.0, and every label is either passed
# through one_hot_transform (one_hot mode) or converted with to_f. Each call
# then shuffles the prepared set in place and returns its leading entries.
#
# batch_size - number of samples to take after shuffling.
# rnd        - random number generator handed to Array#shuffle!.
#
# Returns [images, labels], with matching positions belonging together.
def next_batch(batch_size, rnd: Random.new)
  @data_set ||= begin
    _rows, _columns, images = load_images
    labels = load_labels
    # Pair images and labels by their own position. The sequential cursor
    # @index is deliberately left alone: it may already be non-zero, which
    # would make it index past the end of the data here.
    images.zip(labels).map do |image_data, label_data|
      image_data.map! { |b| b.to_f / 255.0 }
      [image_data, (@one_hot ? one_hot_transform(label_data) : label_data.to_f)]
    end
  end
  @data_set.shuffle!(random: rnd)
  batch = @data_set[0...batch_size]
  [batch.map(&:first), batch.map(&:last)]
end

Private Instance Methods

check_magic(input_file, expected_magic) click to toggle source
# File lib/mnist-learn.rb, line 116
# Verifies that the next magic number in input_file is the expected one.
#
# input_file     - IO positioned at the magic number.
# expected_magic - value the file must start with.
#
# Returns nil when the values match.
# Raises InvalidMagic otherwise.
def check_magic(input_file, expected_magic)
  found_magic = read_magic(input_file)
  return if found_magic == expected_magic

  raise InvalidMagic, "Expected #{expected_magic}, but #{found_magic} is given"
end
input_images() click to toggle source
# File lib/mnist-learn.rb, line 149
# Returns the IO for the image file, opening it on first use.
#
# The same handle is returned on every call. It is opened in binary mode
# because the file content is raw bytes; text mode may translate or truncate
# such data on some platforms.
# NOTE(review): the handle is never closed by the code visible here.
def input_images
  @input_images ||= File.open(filename_image, 'rb')
end
input_labels() click to toggle source
# File lib/mnist-learn.rb, line 153
# Returns the IO for the label file, opening it on first use.
#
# The same handle is returned on every call. It is opened in binary mode
# because the file content is raw bytes; text mode may translate or truncate
# such data on some platforms.
# NOTE(review): the handle is never closed by the code visible here.
def input_labels
  @input_labels ||= File.open(filename_label, 'rb')
end
one_hot_transform(label) click to toggle source
# File lib/mnist-learn.rb, line 110
# Encodes a label as a one-hot vector.
#
# label - index of the slot to switch on.
#
# Returns an array of ten floats, all 0.0 except 1.0 at position label.
def one_hot_transform(label)
  Array.new(10, 0.0).tap { |encoded| encoded[label] = 1.0 }
end
read_image(nrows, ncols) click to toggle source
# File lib/mnist-learn.rb, line 145
# Reads one image from the image file at its current position.
#
# nrows - number of rows in an image.
# ncols - number of columns in an image.
#
# Returns nrows * ncols unsigned byte values as a flat array of integers.
def read_image(nrows, ncols)
  pixel_count = nrows * ncols
  raw_pixels = input_images.read(pixel_count)
  raw_pixels.unpack("C*")
end
read_image_size(input_file) click to toggle source
# File lib/mnist-learn.rb, line 139
# Reads the image dimensions from input_file.
#
# Returns the next two 32-bit values as an array; the caller treats them as
# [rows, columns].
def read_image_size(input_file)
  dimension_count = 2
  read_uint32(input_file, dimension_count)
end
read_labels(input_file, n=1)
Alias for: read_uint8
read_magic(input_file) click to toggle source
# File lib/mnist-learn.rb, line 131
# Reads the magic number: the next single 32-bit value of input_file.
def read_magic(input_file)
  magic, = read_uint32(input_file)
  magic
end
read_total_count(input_file) click to toggle source
# File lib/mnist-learn.rb, line 135
# Reads the item count: the next single 32-bit value of input_file.
def read_total_count(input_file)
  count, = read_uint32(input_file)
  count
end
read_uint32(input_file, n=1) click to toggle source
# File lib/mnist-learn.rb, line 127
# Reads n big-endian unsigned 32-bit integers from input_file.
#
# input_file - IO to read from.
# n          - how many integers to read (default 1).
#
# Returns the decoded integers as an array.
def read_uint32(input_file, n = 1)
  bytes_per_word = 4
  raw_words = input_file.read(bytes_per_word * n)
  raw_words.unpack('N*')
end
read_uint8(input_file, n=1) click to toggle source
# File lib/mnist-learn.rb, line 123
# Reads n unsigned bytes from input_file (also available as read_labels).
#
# input_file - IO to read from.
# n          - how many bytes to read (default 1).
#
# Returns the byte values as an array of integers.
def read_uint8(input_file, n = 1)
  raw_bytes = input_file.read(n)
  raw_bytes.unpack('C*')
end
Also aliased as: read_labels