module Chainer::Datasets::MNIST
Public Class Methods
get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: nil, label_dtype: nil)
click to toggle source
# File lib/chainer/datasets/mnist.rb, line 6 def self.get_mnist(withlabel: true, ndim: 1, scale: 1.0, dtype: nil, label_dtype: nil) xm = Chainer::Device.default.xm dtype ||= xm::SFloat label_dtype ||= xm::Int32 train_raw = retrieve_mnist(type: :train) train = preprocess_mnist(train_raw, withlabel, ndim, scale, dtype, label_dtype) test_raw = retrieve_mnist(type: :test) test = preprocess_mnist(test_raw, withlabel, ndim, scale, dtype, label_dtype) [train, test] end
preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype)
click to toggle source
# File lib/chainer/datasets/mnist.rb, line 19 def self.preprocess_mnist(raw, withlabel, ndim, scale, image_dtype, label_dtype) images = raw[:x] if ndim == 2 images = images.reshape(true, 28, 28) elsif ndim == 3 images = images.reshape(true, 1, 28, 28) elsif ndim != 1 raise "invalid ndim for MNIST dataset" end images = images.cast_to(image_dtype) images *= scale / 255.0 if withlabel labels = raw[:y].cast_to(label_dtype) TupleDataset.new(images, labels) else images end end
retrieve_mnist(type:)
click to toggle source
# File lib/chainer/datasets/mnist.rb, line 40 def self.retrieve_mnist(type:) train_table = ::Datasets::MNIST.new(type: type).to_table xm = Chainer::Device.default.xm { x: xm::UInt8[*train_table[:pixels]], y: xm::UInt8[*train_table[:label]] } end