module RubyBrain::DataSet::Mnist

Public Class Methods

data() click to toggle source
# File lib/ruby_brain/dataset/mnist/data.rb, line 14
# Assembles the MNIST training and test sets, fetching any of the four
# archives that are not yet present in the current working directory.
#
# Returns a two-element Array, [training, test]. Each element is a Hash of
# the form {input: [...], output: [...]}: every input is the unpacked bytes
# of one image divided by 255.0, and every output is a 10-slot Array of 0s
# with a 1 stored at the index given by the label.
def data
  base_url = 'http://yann.lecun.com/exdb/mnist/'
  cwd = Dir.pwd

  # Archives are checked (and fetched when missing) in exactly this order.
  archive_names = %w[
    train-images-idx3-ubyte.gz
    train-labels-idx1-ubyte.gz
    t10k-images-idx3-ubyte.gz
    t10k-labels-idx1-ubyte.gz
  ]

  local_paths = archive_names.map do |archive_name|
    local_path = "#{cwd}/#{archive_name}"
    unless File.exist?(local_path)
      remote_url = "#{base_url}#{archive_name}"
      puts "downloading #{remote_url} ..."
      download_file(remote_url, local_path)
    end
    local_path
  end
  train_images_path, train_labels_path, test_images_path, test_labels_path = local_paths

  train_images = Mnist.load_images(train_images_path)
  train_labels = Mnist.load_labels(train_labels_path)
  test_images = Mnist.load_images(test_images_path)
  test_labels = Mnist.load_labels(test_labels_path)

  # Unsigned bytes (0..255) become floats by dividing each one by 255.0.
  to_pixels = ->(raw_image) { raw_image.unpack('C*').map { |byte| byte / 255.0 } }
  # A label becomes a 10-element vector that is 0 everywhere except at the label's index.
  to_one_hot = ->(digit) { Array.new(10, 0).tap { |slots| slots[digit] = 1 } }

  # Element 2 of the load_images result is iterated as the per-image raw
  # strings (presumably the first two elements are the image dimensions;
  # verify against Mnist.load_images).
  [
    { input: train_images[2].map(&to_pixels), output: train_labels.map(&to_one_hot) },
    { input: test_images[2].map(&to_pixels), output: test_labels.map(&to_one_hot) }
  ]
end
download_file(target_url, dest_path) click to toggle source
# File lib/ruby_brain/dataset/mnist/data.rb, line 6
# Fetches +target_url+ and stores its body at +dest_path+ in binary mode.
#
# The body is streamed into "<dest_path>.part" and renamed onto +dest_path+
# only after the transfer has finished. A failed transfer therefore never
# leaves a file at +dest_path+, which matters to callers that use
# File.exist?(dest_path) to decide whether a download is still needed.
#
# target_url - String locating the resource (anything URI.open accepts).
# dest_path  - String path of the file to create or overwrite.
#
# Returns the Integer number of bytes written.
# Raises whatever URI.open or the file operations raise; the partial file
# is removed before the error propagates.
def download_file(target_url, dest_path)
  # Kernel#open no longer opens URLs as of Ruby 3.0; URI.open (from open-uri)
  # is the supported entry point. The require is idempotent.
  require 'open-uri'

  partial_path = "#{dest_path}.part"
  bytes_written = File.open(partial_path, 'wb') do |saved_file|
    URI.open(target_url, 'rb') do |read_file|
      # Stream instead of read-then-write so the body is not held in memory.
      IO.copy_stream(read_file, saved_file)
    end
  end
  File.rename(partial_path, dest_path)
  bytes_written
ensure
  # Only present here if the transfer or rename failed.
  File.delete(partial_path) if partial_path && File.exist?(partial_path)
end

Private Instance Methods

data() click to toggle source
# File lib/ruby_brain/dataset/mnist/data.rb, line 14
# Assembles the MNIST training and test sets, fetching any of the four
# archives that are not yet present in the current working directory.
#
# Returns a two-element Array, [training, test]. Each element is a Hash of
# the form {input: [...], output: [...]}: every input is the unpacked bytes
# of one image divided by 255.0, and every output is a 10-slot Array of 0s
# with a 1 stored at the index given by the label.
def data
  base_url = 'http://yann.lecun.com/exdb/mnist/'
  cwd = Dir.pwd

  # Archives are checked (and fetched when missing) in exactly this order.
  archive_names = %w[
    train-images-idx3-ubyte.gz
    train-labels-idx1-ubyte.gz
    t10k-images-idx3-ubyte.gz
    t10k-labels-idx1-ubyte.gz
  ]

  local_paths = archive_names.map do |archive_name|
    local_path = "#{cwd}/#{archive_name}"
    unless File.exist?(local_path)
      remote_url = "#{base_url}#{archive_name}"
      puts "downloading #{remote_url} ..."
      download_file(remote_url, local_path)
    end
    local_path
  end
  train_images_path, train_labels_path, test_images_path, test_labels_path = local_paths

  train_images = Mnist.load_images(train_images_path)
  train_labels = Mnist.load_labels(train_labels_path)
  test_images = Mnist.load_images(test_images_path)
  test_labels = Mnist.load_labels(test_labels_path)

  # Unsigned bytes (0..255) become floats by dividing each one by 255.0.
  to_pixels = ->(raw_image) { raw_image.unpack('C*').map { |byte| byte / 255.0 } }
  # A label becomes a 10-element vector that is 0 everywhere except at the label's index.
  to_one_hot = ->(digit) { Array.new(10, 0).tap { |slots| slots[digit] = 1 } }

  # Element 2 of the load_images result is iterated as the per-image raw
  # strings (presumably the first two elements are the image dimensions;
  # verify against Mnist.load_images).
  [
    { input: train_images[2].map(&to_pixels), output: train_labels.map(&to_one_hot) },
    { input: test_images[2].map(&to_pixels), output: test_labels.map(&to_one_hot) }
  ]
end
download_file(target_url, dest_path) click to toggle source
# File lib/ruby_brain/dataset/mnist/data.rb, line 6
# Fetches +target_url+ and stores its body at +dest_path+ in binary mode.
#
# The body is streamed into "<dest_path>.part" and renamed onto +dest_path+
# only after the transfer has finished. A failed transfer therefore never
# leaves a file at +dest_path+, which matters to callers that use
# File.exist?(dest_path) to decide whether a download is still needed.
#
# target_url - String locating the resource (anything URI.open accepts).
# dest_path  - String path of the file to create or overwrite.
#
# Returns the Integer number of bytes written.
# Raises whatever URI.open or the file operations raise; the partial file
# is removed before the error propagates.
def download_file(target_url, dest_path)
  # Kernel#open no longer opens URLs as of Ruby 3.0; URI.open (from open-uri)
  # is the supported entry point. The require is idempotent.
  require 'open-uri'

  partial_path = "#{dest_path}.part"
  bytes_written = File.open(partial_path, 'wb') do |saved_file|
    URI.open(target_url, 'rb') do |read_file|
      # Stream instead of read-then-write so the body is not held in memory.
      IO.copy_stream(read_file, saved_file)
    end
  end
  File.rename(partial_path, dest_path)
  bytes_written
ensure
  # Only present here if the transfer or rename failed.
  File.delete(partial_path) if partial_path && File.exist?(partial_path)
end