module CIFAR10

Public Class Methods

categorical(y_data)
# File lib/nn/cifar10.rb, line 42
# Converts class labels into one-hot vectors.
#
# @param y_data [Array<Integer>] class labels, each expected in 0...num_classes
# @param num_classes [Integer] length of every one-hot vector (defaults to 10,
#   the previously hard-coded class count)
# @return [Array<Array<Integer>>] one vector per label, with 1 at the label's
#   index and 0 everywhere else
# @raise [ArgumentError] if a label lies outside 0...num_classes (previously a
#   negative label silently set a slot counted from the end, and a too-large
#   label silently grew the vector)
def self.categorical(y_data, num_classes: 10)
  y_data.map do |label|
    unless (0...num_classes).cover?(label)
      raise ArgumentError, "label #{label.inspect} out of range 0...#{num_classes}"
    end

    Array.new(num_classes, 0).tap { |one_hot| one_hot[label] = 1 }
  end
end
dir()
# File lib/nn/cifar10.rb, line 50
# Relative path (resolved against the current working directory) of the
# folder that holds the CIFAR-10 binary batch files.
#
# @return [String] a fresh path string on every call
def self.dir
  'cifar-10-batches-bin'
end
load_test()
# File lib/nn/cifar10.rb, line 22
# Loads the CIFAR-10 test batch as +[x_test, y_test]+.
#
# The first call parses "#{dir}/test_batch.bin" and caches the result with
# Marshal as "CIFAR-10-test.marshal" in the current working directory (not in
# +dir+). Later calls load that cache instead of re-parsing.
#
# @return [Array(Array<Array<Integer>>, Array<Integer>)] per-record image byte
#   arrays (3072 values each) and their labels, in file order
# @raise [Errno::ENOENT] if neither the cache nor the batch file exists
def self.load_test
  cache_path = "CIFAR-10-test.marshal"
  # NOTE(review): Marshal.load can instantiate arbitrary objects. This is only
  # acceptable because the cache is expected to have been written by this
  # method; never point it at a file from an untrusted source.
  return Marshal.load(File.binread(cache_path)) if File.exist?(cache_path)

  x_test = []
  y_test = []
  # Each record is 1 label byte followed by 3072 image bytes. A single
  # each_slice pass replaces repeated shift/slice! calls on the front of one
  # large array, which can move the remaining elements on every record.
  File.binread("#{dir}/test_batch.bin").unpack("C*").each_slice(3073) do |label, *pixels|
    x_test << pixels
    y_test << label
  end

  test = [x_test, y_test]
  File.binwrite(cache_path, Marshal.dump(test))
  test
end
load_train(index)
# File lib/nn/cifar10.rb, line 2
# Loads CIFAR-10 training batch +index+ as +[x_train, y_train]+.
#
# The first call for a given index parses "#{dir}/data_batch_#{index}.bin"
# and caches the result with Marshal as "CIFAR-10-train#{index}.marshal" in
# the current working directory (not in +dir+). Later calls for the same
# index load that cache instead of re-parsing.
#
# @param index [#to_s] batch number interpolated into both file names
# @return [Array(Array<Array<Integer>>, Array<Integer>)] per-record image byte
#   arrays (3072 values each) and their labels, in file order
# @raise [Errno::ENOENT] if neither the cache nor the batch file exists
def self.load_train(index)
  cache_path = "CIFAR-10-train#{index}.marshal"
  # NOTE(review): Marshal.load can instantiate arbitrary objects. This is only
  # acceptable because the cache is expected to have been written by this
  # method; never point it at a file from an untrusted source.
  return Marshal.load(File.binread(cache_path)) if File.exist?(cache_path)

  x_train = []
  y_train = []
  # Each record is 1 label byte followed by 3072 image bytes. A single
  # each_slice pass replaces repeated shift/slice! calls on the front of one
  # large array, which can move the remaining elements on every record.
  File.binread("#{dir}/data_batch_#{index}.bin").unpack("C*").each_slice(3073) do |label, *pixels|
    x_train << pixels
    y_train << label
  end

  train = [x_train, y_train]
  File.binwrite(cache_path, Marshal.dump(train))
  train
end