class Torchrb::DataSet

Attributes

collection[R]
histogram[R]
model[R]

Public Class Methods

new(set, model, collection)
# File lib/torchrb/data_set.rb, line 4
# Wraps one named split of a data collection for a given model.
#
# @param set presumably a Symbol naming the split; +:train_set+ marks the
#   training set, and its string form doubles as the Lua variable name
# @param model supplies the class labels, tensor layout and torch runtime
# @param collection the raw items to load (must respond to count and
#   each_with_index)
def initialize set, model, collection
  @set, @model, @collection = set, model, collection
end

Public Instance Methods

classes()
# File lib/torchrb/data_set.rb, line 18
# Class labels known to the model. A label's 1-based position in this list
# is the numeric label stored in the Lua label tensor (see load_single).
def classes
  model.classes
end
dimensions()
# File lib/torchrb/data_set.rb, line 41
# Sizes of this set's Lua input tensor, one Integer per dimension.
# Lua's # operator on a torch tensor yields its size, converted to a table
# and then to Ruby integers.
def dimensions
  lua = "return (##{var_name}.input):totable()"
  size_table = torch.eval(lua, __FILE__, __LINE__)
  size_table.values.map(&:to_i)
end
is_trainset?()
# File lib/torchrb/data_set.rb, line 10
# Whether this object wraps the training split, i.e. it was constructed
# with the set name +:train_set+.
def is_trainset?
  training_set_name = :train_set
  @set == training_set_name
end
load(&progress_callback)
# File lib/torchrb/data_set.rb, line 22
  # Populates this data set in the torch runtime and returns self.
  #
  # Steps: publish the class list (training set only), allocate the Lua
  # tensors, load every item of the collection, then make the Lua table
  # indexable as {input, label} pairs with a size() method.
  #
  # The given block is stored and later invoked by load_single, once per item.
  #
  # @return self
  def load &progress_callback
    @progress_callback = progress_callback
    load_classes if is_trainset?
    init_variables
    do_load
    install_lua_accessors
    self
  end

  # Gives this set's Lua table dataset-style access: set[i] returns
  # {input[i], label[i]}, and set:size() returns the number of input rows.
  def install_lua_accessors
    torch.eval <<-EOF, __FILE__, __LINE__
      setmetatable(#{var_name}, {__index = function(t, i)
           return {t.input[i], t.label[i]}
       end}
      );
      function #{var_name}:size()
        return #{var_name}.input:size(1)
      end
    EOF
  end
torch()
# File lib/torchrb/data_set.rb, line 45
# The torch (Lua) runtime owned by the model. Every Lua snippet in this
# class is evaluated through it.
def torch
  runtime = model.torch
  runtime
end
var_name()
# File lib/torchrb/data_set.rb, line 14
# Name of the Lua global that holds this set's tensors: the set name as a
# String (e.g. :train_set -> "train_set").
def var_name
  name = @set
  name.to_s
end

Private Instance Methods

cudify()
# File lib/torchrb/data_set.rb, line 81
  # Moves this set's label and input tensors to the GPU.
  #
  # Torch's Tensor:cuda() returns a converted copy rather than converting in
  # place, so the result has to be assigned back to the table fields.
  # Otherwise the CUDA copies are discarded and the data stays on the CPU.
  # NOTE(review): assumes cutorch is loaded in this Lua state — confirm.
  def cudify
    torch.eval <<-EOF, __FILE__, __LINE__
      #{var_name}.label = #{var_name}.label:cuda()
      #{var_name}.input = #{var_name}.input:cuda()
    EOF
  end
do_load()
# File lib/torchrb/data_set.rb, line 60
# Loads every item of the collection via load_single, passing each item and
# its zero-based position. Records in @histogram how many items produced each
# returned class, keyed in first-seen order.
#
# @return [Hash] the histogram (class => count)
def do_load
  loaded = collection.each_with_index.map { |item, position| load_single(item, position) }
  @histogram = loaded.each_with_object({}) do |klass, counts|
    counts[klass] = counts.fetch(klass, 0) + 1
  end
end
init_variables()
# File lib/torchrb/data_set.rb, line 88
  # Creates this set's Lua global table, preallocated for the whole collection:
  #   label - a LongTensor with one slot per item
  #   input - a torch.<model.tensor_type> of shape (count, *model.dimensions)
  def init_variables
    rows = collection.count
    shape = model.dimensions.join ", "
    torch.eval <<-EOF, __FILE__, __LINE__
      #{var_name} = {
        label= torch.LongTensor(#{rows}),
        input= torch.#{model.tensor_type}(#{rows} , #{shape})
      }
    EOF
  end
load_classes()
# File lib/torchrb/data_set.rb, line 50
  # Publishes the class labels to Lua as the global table `classes`, writing
  # each label with Ruby's #inspect.
  # NOTE(review): this assumes labels are plain Strings whose #inspect output
  # is also a valid Lua literal. Symbols or strings that Ruby escapes
  # differently from Lua would produce broken Lua — confirm.
  def load_classes
    lua_labels = classes.map(&:inspect).join ", "
    torch.eval <<-EOF, __FILE__, __LINE__
      classes = {#{lua_labels}}
    EOF
  end
load_from_cache(cached_file)
# File lib/torchrb/data_set.rb, line 56
# Replaces this set's Lua table with the object that torch.load reads from
# +cached_file+.
#
# The path is embedded as a Lua long-bracket string ([[...]], [=[...]=], ...),
# so quotes and backslashes in it are taken literally. This avoids the broken
# or injected Lua that a single-quoted literal produces for paths containing
# an apostrophe. The bracket level is raised until the closing delimiter
# cannot occur in the path, including a match formed by a trailing "]".
# (Lua drops a newline right after the opening bracket; paths are assumed not
# to start with one.)
def load_from_cache(cached_file)
  path = cached_file.to_s
  level = 0
  level += 1 while "#{path}]".include?("]#{'=' * level}]")
  equals = "=" * level
  torch.eval "#{var_name} = torch.load([#{equals}[#{path}]#{equals}])", __FILE__, __LINE__
end
load_single(data, index)
# File lib/torchrb/data_set.rb, line 67
  # Loads one collection item into row +index+ of this set's tensors (Ruby's
  # 0-based index becomes Lua's 1-based row).
  #
  # The model picks the item's class via prediction_class. That class's
  # 1-based position in +classes+ is stored as the label. The snippet from
  # model.to_tensor is evaluated first; it presumably assigns the Lua variable
  # torchrb_data, which is then copied into the input row.
  #
  # Invokes the progress callback stored by #load when one was given. A nil
  # callback (load called without a block) is skipped instead of raising
  # NoMethodError.
  #
  # @return the class returned by model.prediction_class
  # @raise [RuntimeError] if that class is not one of +classes+
  def load_single(data, index)
    @progress_callback.call if @progress_callback
    klass = model.prediction_class data
    label_index = classes.index(klass)
    raise "Returned class '#{klass}' is not one of #{classes}" if label_index.nil?
    label_value = label_index + 1
    torch.eval <<-EOF, __FILE__, __LINE__
       #{model.to_tensor("torchrb_data", data).strip}
       #{var_name}.label[#{index+1}] = torch.LongTensor({#{label_value}})
       #{var_name}.input[#{index+1}] = torchrb_data
    EOF
    klass
  end
normalize!()
# File lib/torchrb/data_set.rb, line 97
  # Standardizes every input channel of this set in place: subtract the mean,
  # then divide by the standard deviation.
  #
  # Training set: computes per-channel statistics and stores them in the Lua
  # globals `mean` and `stdv` (no `local`), so that other sets evaluated in the
  # same Lua state can reuse them. The statistics are also returned to Ruby.
  # Other sets: apply the stored `mean`/`stdv`, so the training set must be
  # normalized first in the same torch runtime.
  #
  # Input is indexed as {sample, channel, row, col}. It is assumed 4-D, with
  # model.dimensions.first channels on dimension 2. Lua's `for i=1,n` is
  # inclusive, so the loop bound is the channel count itself.
  #
  # The training branch is chosen by is_trainset?. The previous check of
  # @is_trainset read an instance variable that is never assigned.
  #
  # @return [Hash] training set: the Lua result keys passed through
  #   String#humanize (presumably ActiveSupport), each mapped to its
  #   per-channel values. The "standard_diviation" spelling is kept because
  #   callers may read that key. Other sets: whatever torch.eval returns.
  def normalize!
    channels = model.dimensions.first
    if is_trainset?
      stats = torch.eval(<<-EOF, __FILE__, __LINE__)
        mean = {} -- store the mean, to normalize the test set in the future
        stdv  = {} -- store the standard-deviation for the future
        for i=1,#{channels} do -- over each image channel
            mean[i] = #{var_name}.input[{ {}, {i}, {}, {}  }]:mean() -- mean estimation
            stdv[i] = #{var_name}.input[{ {}, {i}, {}, {}  }]:std() -- std estimation

            #{var_name}.input[{ {}, {i}, {}, {}  }]:add(-mean[i]) -- mean subtraction
            #{var_name}.input[{ {}, {i}, {}, {}  }]:div(stdv[i]) -- std scaling
        end
        return {mean= mean, standard_diviation= stdv}
      EOF
      # Each Lua table comes back keyed by channel index; keep only its values.
      stats.to_h.each_with_object({}) do |(key, per_channel), result|
        result[key.humanize] = per_channel.values
      end
    else
      torch.eval <<-EOF, __FILE__, __LINE__
      for i=1,#{channels} do -- over each image channel
          #{var_name}.input[{ {}, {i}, {}, {}  }]:add(-mean[i]) -- mean subtraction
          #{var_name}.input[{ {}, {i}, {}, {}  }]:div(stdv[i]) -- std scaling
      end
      EOF
    end
  end