class Torchrb::Wrapper

Attributes

model[R]
progress[R]

Public Class Methods

for(model_class, options={}) { |wrapper| ... }
# File lib/torchrb/wrapper.rb, line 4
# Returns the memoized wrapper for +model_class+, creating it on first use
# and caching it in @@instances.
#
# +options+ only take effect on the call that creates the wrapper; later
# calls for the same class reuse the cached instance as-is.
#
# @param model_class [Class] the model class to wrap
# @param options [Hash] passed to +new+ when the wrapper is first built
# @yield [wrapper] the cached wrapper, when a block is given
# @return the block's result when a block is given, otherwise the wrapper
def self.for(model_class, options = {})
  wrapper = (@@instances[model_class] ||= new(model_class, options))
  return yield(wrapper) if block_given?

  wrapper
end
new(model, options={})
Calls superclass method Torchrb::Torch::new
# File lib/torchrb/wrapper.rb, line 16
# Builds a wrapper around +model+ and mixes in the model's network and
# trainer definitions, then lets the model configure the wrapper.
#
# @param model [Class] the model class; it must be a class that subclasses,
#   includes, or extends Torchrb::ModelBase
# @param options [Hash] forwarded unchanged to the superclass initializer
# @raise [RuntimeError] if +model+ is not a class or is not a Torchrb::ModelBase
def initialize model, options={}
  # Both conditions must hold. `<` covers subclassing/including ModelBase;
  # `is_a?` covers a class that does `extend Torchrb::ModelBase`.
  valid_model = model.is_a?(Class) &&
                (model < Torchrb::ModelBase || model.is_a?(Torchrb::ModelBase))
  raise "#{model} must be a class and extend Torchrb::ModelBase!" unless valid_model

  @model = model
  super(options)
  # NOTE(review): these includes modify this object's class (the shared
  # Torchrb::Wrapper class or a subclass), not just this instance, so the
  # net/trainer modules of every model wrapped so far are visible to all
  # instances of that class. Per-instance isolation would need a redesign.
  self.class.include model.net
  self.class.include model.trainer
  model.setup self
end

Public Instance Methods

load_model_data()
# File lib/torchrb/wrapper.rb, line 26
# Resets the progress counter and loads the train, test and validation sets,
# in that order, through +load_dataset+.
#
# @return the value returned by loading the last (validation) set
def load_model_data
  @progress = 0
  %i[train_set test_set validation_set].map { |set_name| load_dataset(set_name) }.last
end
predict(sample)
Calls superclass method Torchrb::Torch#predict
# File lib/torchrb/wrapper.rb, line 44
# Predicts an output for +sample+ by delegating unchanged to the superclass
# implementation (Torchrb::Torch#predict).
#
# @param sample the input to run through the network
# @return whatever the superclass prediction returns
def predict(sample)
  # Bare `super` forwards the same argument list (just +sample+).
  super
end
train()
Calls superclass method Torchrb::Torch#train
# File lib/torchrb/wrapper.rb, line 33
# Defines the network and trainer, optionally moves them to CUDA, runs the
# superclass training (Torchrb::Torch#train), then prints results and stores
# the network.
#
# The steps run in a fixed order; presumably define_nn/define_trainer must
# precede cudify and the superclass training loop — TODO confirm.
#
# @return the value of +error_rate+ (the last expression)
def train
  define_nn
  define_trainer

  # Only convert to CUDA when the enable_cuda flag/method is truthy.
  cudify if enable_cuda
  # Bare `super`: forwards no arguments (train takes none) plus any block.
  super
  print_results
  store_network
  error_rate
end

Private Instance Methods

engine_storage()
# File lib/torchrb/wrapper.rb, line 63
# Path of the serialized network file ("net.t7") inside the cache directory.
#
# @return [String] the cache directory joined with "net.t7"
def engine_storage
  # File.join avoids a doubled separator when cache_dir ends in "/" and,
  # unlike `Pathname#+ "/net.t7"`, never treats the file name as an absolute
  # path that discards the directory.
  # NOTE(review): `Visit` looks like an application class referenced from
  # library code — confirm it is always defined where this runs.
  File.join(Visit.cache_dir, "net.t7")
end
load_dataset(set_name)
# File lib/torchrb/wrapper.rb, line 49
# Loads one named data set from the model into a Torchrb::DataSet, reporting
# progress to the model before loading and after each loaded element.
#
# @param set_name [Symbol] name of a model method returning the set
#   (e.g. :train_set); its +size+ drives the progress accounting
# @return the result of +normalize!+ when the model wants normalization and
#   the set is the training set, otherwise nil
def load_dataset set_name
  set_size = model.send(set_name).size
  # NOTE(review): String#humanize is not core Ruby (presumably ActiveSupport).
  model.progress_callback progress, message: "Loading #{set_name.to_s.humanize} with #{set_size} element(s)."

  # Each of the three sets loaded by load_model_data owns one third of the
  # progress range. An exact third (instead of 0.333) lets progress reach 1.0
  # after all three sets rather than stopping at 0.999.
  share_per_element = (1.0 / 3) / set_size
  set = Torchrb::DataSet.new set_name, self
  set.load do
    @progress += share_per_element
    model.progress_callback progress
  end
  set.normalize! if model.normalize? && set.is_trainset?
end