class Datasets::CIFAR

Public Class Methods

new(n_classes: 10, type: :train) click to toggle source
Calls superclass method
# File lib/datasets/cifar.rb, line 28
# Builds a CIFAR dataset handle.
#
# n_classes - which CIFAR variant to use; must be 10 or 100.
# type      - which split to read; must be :train or :test.
#
# Raises ArgumentError when either argument is outside its accepted set.
# Both checks run before the superclass initializer is invoked.
def initialize(n_classes: 10, type: :train)
  accepted_class_counts = [10, 100]
  accepted_types = [:train, :test]

  raise ArgumentError, "Please set n_classes 10 or 100: #{n_classes.inspect}" unless accepted_class_counts.include?(n_classes)
  raise ArgumentError, "Please set type :train or :test: #{type.inspect}" unless accepted_types.include?(type)

  # NOTE(review): @metadata is presumably created by the superclass
  # initializer; this method only fills in its fields.
  super()

  label = "CIFAR-#{n_classes}"
  @metadata.id = "cifar-#{n_classes}"
  @metadata.name = label
  @metadata.url = "https://www.cs.toronto.edu/~kriz/cifar.html"
  @metadata.description = "#{label} is 32x32 image dataset"

  @n_classes = n_classes
  @type = type
end

Public Instance Methods

each(&block) click to toggle source
# File lib/datasets/cifar.rb, line 49
# Iterates over the dataset, passing the given block on to parse_data.
#
# Without a block, returns an Enumerator over this method instead.
# The archive is looked up under cache_dir_path and is downloaded only
# when no file exists at that location yet.
def each(&block)
  return enum_for(:each) unless block_given?

  archive_path = cache_dir_path + "cifar-#{@n_classes}.tar.gz"
  if !archive_path.exist?
    archive_url = "https://www.cs.toronto.edu/~kriz/cifar-#{@n_classes}-binary.tar.gz"
    download(archive_path, archive_url)
  end

  parse_data(archive_path, &block)
end

Private Instance Methods

parse_data(data_path, &block) click to toggle source
# File lib/datasets/cifar.rb, line 63
# Opens the archive at data_path and, for every name returned by
# target_file_names (in order), seeks to that member and hands it to
# parse_entry together with the caller's block.
def parse_data(data_path, &block)
  open_tar_gz(data_path) do |archive|
    target_file_names.each do |member_name|
      archive.seek(member_name) { |member| parse_entry(member, &block) }
    end
  end
end
parse_entry(entry) { |record10| ... } click to toggle source
# File lib/datasets/cifar.rb, line 106
# Reads fixed-size records from entry until it is exhausted, yielding one
# record object per image.
#
# For 10 classes each record is 1 label byte followed by the image bytes and
# is yielded as Record10.new(data, label). For 100 classes each record is a
# coarse-label byte, a fine-label byte and the image bytes, yielded as
# Record100.new(data, coarse_label, fine_label).
#
# End of input is detected only on the first byte of a record. Any other
# @n_classes value reads nothing and yields nothing.
def parse_entry(entry)
  # Decodes a one-byte string into an unsigned integer.
  decode_byte = ->(raw) { raw.unpack("C").first }
  # Bytes per image; presumably 32 * 32 pixels * 3 color channels.
  image_bytes = 3072

  case @n_classes
  when 10
    loop do
      raw_label = entry.read(1)
      break if raw_label.nil?

      label = decode_byte.call(raw_label)
      pixels = entry.read(image_bytes)
      yield Record10.new(pixels, label)
    end
  when 100
    loop do
      raw_coarse = entry.read(1)
      break if raw_coarse.nil?

      coarse = decode_byte.call(raw_coarse)
      fine = decode_byte.call(entry.read(1))
      pixels = entry.read(image_bytes)
      yield Record100.new(pixels, coarse, fine)
    end
  end
end
target_file_names() click to toggle source
# File lib/datasets/cifar.rb, line 73
# Lists the archive member paths to read for the current @n_classes / @type
# combination, in reading order.
#
# Returns a new Array of Strings on each call, or nil when the combination
# is not one of the four handled below.
def target_file_names
  cifar10_dir = 'cifar-10-batches-bin'
  cifar100_dir = "cifar-100-binary"

  case [@n_classes, @type]
  when [10, :train]
    (1..5).map { |batch_no| "#{cifar10_dir}/data_batch_#{batch_no}.bin" }
  when [10, :test]
    ["#{cifar10_dir}/test_batch.bin"]
  when [100, :train]
    ["#{cifar100_dir}/train.bin"]
  when [100, :test]
    ["#{cifar100_dir}/test.bin"]
  end
end