class Datasets::MNIST

Constants

BASE_URL

Public Class Methods

new(type: :train) click to toggle source
Calls superclass method Datasets::Dataset::new
# File lib/datasets/mnist.rb, line 21
def initialize(type: :train)
  unless [:train, :test].include?(type)
    raise ArgumentError, "Please set type :train or :test: #{type.inspect}"
  end

  super()

  @metadata.id = "#{dataset_name.downcase}-#{type}"
  @metadata.name = "#{dataset_name}: #{type}"
  @metadata.url = self.class::BASE_URL
  @type = type

  case type
  when :train
    @metadata.description = "a training set of 60,000 examples"
  when :test
    @metadata.description = "a test set of 10,000 examples"
  end
end

Public Instance Methods

each(&block) click to toggle source
# File lib/datasets/mnist.rb, line 41
def each(&block)
  return to_enum(__method__) unless block_given?

  image_path = cache_dir_path + target_file(:image)
  label_path = cache_dir_path + target_file(:label)
  base_url = self.class::BASE_URL

  unless image_path.exist?
    download(image_path, base_url + target_file(:image))
  end

  unless label_path.exist?
    download(label_path, base_url + target_file(:label))
  end

  open_data(image_path, label_path, &block)
end

Private Instance Methods

dataset_name() click to toggle source
# File lib/datasets/mnist.rb, line 111
def dataset_name
  "MNIST"
end
open_data(image_path, label_path) { |record| ... } click to toggle source
# File lib/datasets/mnist.rb, line 60
def open_data(image_path, label_path, &block)
  labels = parse_labels(label_path)

  Zlib::GzipReader.open(image_path) do |f|
    n_uint32s = 4
    n_bytes = n_uint32s * 4
    mnist_magic_number = 2051
    magic, n_images, n_rows, n_cols = f.read(n_bytes).unpack("N*")
    if magic != mnist_magic_number
      raise Error, "This is not #{dataset_name} image file"
    end
    n_images.times do |i|
      data = f.read(n_rows * n_cols)
      label = labels[i]
      yield Record.new(data, label)
    end
  end
end
parse_labels(file_path) click to toggle source
# File lib/datasets/mnist.rb, line 98
def parse_labels(file_path)
  Zlib::GzipReader.open(file_path) do |f|
    n_uint32s = 4
    n_bytes = n_uint32s * 2
    mnist_magic_number = 2049
    magic, n_labels = f.read(n_bytes).unpack('N2')
    if magic != mnist_magic_number
      raise Error, "This is not #{dataset_name} label file"
    end
    f.read(n_labels).unpack('C*')
  end
end
target_file(data) click to toggle source
# File lib/datasets/mnist.rb, line 79
def target_file(data)
  case @type
  when :train
    case data
    when :image
      "train-images-idx3-ubyte.gz"
    when :label
      "train-labels-idx1-ubyte.gz"
    end
  when :test
    case data
    when :image
      "t10k-images-idx3-ubyte.gz"
    when :label
      "t10k-labels-idx1-ubyte.gz"
    end
  end
end