class Torchrb::DataSet
Attributes
collection[R]
histogram[R]
model[R]
Public Class Methods
new(set, model, collection)
# File lib/torchrb/data_set.rb, line 4
def initialize set, model, collection
  @set = set
  @model = model
  @collection = collection
end
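A usage sketch, assuming a Torchrb model wrapper and an enumerable of samples (the model, images and test_images names below are illustrative, not part of the gem):

  # Hypothetical example: `model` responds to the Torchrb model interface
  # (classes, dimensions, tensor_type, ...) and `images` is an enumerable
  # of raw samples.
  train_set = Torchrb::DataSet.new(:train_set, model, images)
  test_set  = Torchrb::DataSet.new(:test_set, model, test_images)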
Public Instance Methods
classes()
# File lib/torchrb/data_set.rb, line 18
def classes
  model.classes
end
dimensions()
# File lib/torchrb/data_set.rb, line 41
def dimensions
  torch.eval("return (##{var_name}.input):totable()", __FILE__, __LINE__).values.map(&:to_i)
end
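The method asks Torch for the size of the input tensor and converts each entry to an integer, so the first value is the sample count followed by the model's dimensions. A sketch with illustrative numbers:

  # Illustrative: 100 samples of 3x32x32 input.
  train_set.dimensions  # => [100, 3, 32, 32]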
is_trainset?()
# File lib/torchrb/data_set.rb, line 10
def is_trainset?
  @set == :train_set
end
load(&progress_callback)
# File lib/torchrb/data_set.rb, line 22
def load &progress_callback
  @progress_callback = progress_callback
  load_classes if is_trainset?
  init_variables
  do_load
  torch.eval <<-EOF, __FILE__, __LINE__
    setmetatable(#{var_name},
      {__index = function(t, i)
                   return {t.input[i], t.label[i]}
                 end}
    );
    function #{var_name}:size()
      return #{var_name}.input:size(1)
    end
  EOF
  self
end
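load fills the Torch-side tensors sample by sample, then attaches an __index metamethod and a size() function so the Lua table follows the dataset convention Torch trainers such as nn.StochasticGradient expect: dataset[i] returns {input, label} and dataset:size() returns the sample count. A minimal usage sketch, with an illustrative progress counter:

  # Illustrative: the block is invoked once per sample while loading.
  loaded = 0
  train_set.load do
    loaded += 1
    print "\rloading #{loaded}/#{train_set.collection.count}"
  end
  puts "\nclass histogram: #{train_set.histogram.inspect}"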
torch()
# File lib/torchrb/data_set.rb, line 45
def torch
  model.torch
end
var_name()
# File lib/torchrb/data_set.rb, line 14
def var_name
  @set.to_s
end
Private Instance Methods
cudify()
# File lib/torchrb/data_set.rb, line 81
def cudify
  torch.eval <<-EOF, __FILE__, __LINE__
    -- :cuda() returns a CUDA copy rather than converting in place, so reassign
    #{var_name}.label = #{var_name}.label:cuda()
    #{var_name}.input = #{var_name}.input:cuda()
  EOF
end
do_load()
# File lib/torchrb/data_set.rb, line 60
def do_load
  values = collection.each_with_index.map do |data, index|
    load_single(data, index)
  end
  @histogram = Hash[*values.group_by { |v| v }.flat_map { |k, v| [k, v.size] }]
end
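The histogram maps each class returned by load_single to the number of samples in it. A minimal sketch of the grouping idiom, with illustrative class names:

  # Illustrative: per-sample classes collapsed into counts per class.
  values = ["cat", "dog", "cat"]
  Hash[*values.group_by { |v| v }.flat_map { |k, v| [k, v.size] }]
  # => {"cat"=>2, "dog"=>1}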
init_variables()
# File lib/torchrb/data_set.rb, line 88
def init_variables
  torch.eval <<-EOF, __FILE__, __LINE__
    #{var_name} = {
      label = torch.LongTensor(#{collection.count}),
      input = torch.#{model.tensor_type}(#{collection.count}, #{model.dimensions.join ", "})
    }
  EOF
end
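As a rough sketch (all concrete values hypothetical): for a training set of 100 samples, a tensor_type of FloatTensor and model dimensions [3, 32, 32], the interpolated string handed to torch.eval would read roughly like:

  # Hypothetical interpolation result for var_name "train_set":
  lua = <<-EOF
    train_set = {
      label = torch.LongTensor(100),
      input = torch.FloatTensor(100, 3, 32, 32)
    }
  EOF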
load_classes()
# File lib/torchrb/data_set.rb, line 50
def load_classes
  torch.eval <<-EOF, __FILE__, __LINE__
    classes = {#{classes.map(&:inspect).join ", "}}
  EOF
end
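Class names are inspected on the Ruby side so they arrive in Lua as quoted strings. A sketch of the interpolation with illustrative classes:

  # Illustrative: building the Lua table literal for the class list.
  classes = ["cat", "dog", "frog"]
  "classes = {#{classes.map(&:inspect).join(", ")}}"
  # => "classes = {\"cat\", \"dog\", \"frog\"}"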
load_from_cache(cached_file)
# File lib/torchrb/data_set.rb, line 56
def load_from_cache(cached_file)
  torch.eval "#{var_name} = torch.load('#{cached_file}')", __FILE__, __LINE__
end
load_single(data, index)
# File lib/torchrb/data_set.rb, line 67
def load_single(data, index)
  @progress_callback.call
  klass = model.prediction_class data
  label_index = classes.index(klass)
  raise "Returned class '#{klass}' is not one of #{classes}" if label_index.nil?
  label_value = label_index + 1
  torch.eval <<-EOF, __FILE__, __LINE__
    #{model.to_tensor("torchrb_data", data).strip}
    #{var_name}.label[#{index + 1}] = torch.LongTensor({#{label_value}})
    #{var_name}.input[#{index + 1}] = torchrb_data
  EOF
  klass
end
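Lua and Torch tensors are 1-indexed, which is why both the sample index and the class index are shifted by one before being written into the label and input tensors. A sketch with an illustrative class list:

  # Illustrative: the stored label is the class position in `classes` plus one.
  classes = ["cat", "dog", "frog"]
  label_value = classes.index("dog") + 1  # => 2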
normalize!()
# File lib/torchrb/data_set.rb, line 97
def normalize!
  if is_trainset?
    torch.eval(<<-EOF, __FILE__, __LINE__).to_h.map { |k, v| {k.humanize => v.values} }.reduce({}, :merge)
      mean = {} -- store the mean, to normalize the test set in the future
      stdv = {} -- store the standard deviation for the future
      for i = 1, #{model.dimensions.first - 1} do -- over each image channel
        mean[i] = #{var_name}.input[{ {}, {i}, {}, {} }]:mean() -- mean estimation
        stdv[i] = #{var_name}.input[{ {}, {i}, {}, {} }]:std()  -- std estimation
        #{var_name}.input[{ {}, {i}, {}, {} }]:add(-mean[i])    -- mean subtraction
        #{var_name}.input[{ {}, {i}, {}, {} }]:div(stdv[i])     -- std scaling
      end
      return {mean = mean, standard_deviation = stdv}
    EOF
  else
    torch.eval <<-EOF, __FILE__, __LINE__
      for i = 1, #{model.dimensions.first - 1} do -- over each image channel
        #{var_name}.input[{ {}, {i}, {}, {} }]:add(-mean[i]) -- mean subtraction
        #{var_name}.input[{ {}, {i}, {}, {} }]:div(stdv[i])  -- std scaling
      end
    EOF
  end
end
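On the training set the Lua block returns {mean = ..., standard_deviation = ...}, which the Ruby side flattens into a hash of humanized keys mapping to per-channel value arrays; the test-set branch then reuses the Lua globals mean and stdv computed here. A minimal sketch of the Ruby-side reshaping, using capitalize in place of ActiveSupport's humanize and illustrative numbers:

  # Illustrative: flattening the statistics returned from Lua.
  lua_result = {
    "mean"               => { 1 => 0.49, 2 => 0.48, 3 => 0.45 },
    "standard_deviation" => { 1 => 0.25, 2 => 0.24, 3 => 0.26 }
  }
  lua_result.map { |k, v| { k.tr("_", " ").capitalize => v.values } }.reduce({}, :merge)
  # => {"Mean"=>[0.49, 0.48, 0.45], "Standard deviation"=>[0.25, 0.24, 0.26]}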