class Torchrb::Wrapper
Attributes
model[R]
progress[R]
Public Class Methods
for(model_class, options={}) { |instances| ... }
click to toggle source
# File lib/torchrb/wrapper.rb, line 4
#
# Returns the wrapper cached for +model_class+, creating it with +options+
# the first time that class is requested. Because of the <tt>||=</tt>, the
# +options+ of later calls are ignored once an instance is cached.
#
# With a block, the wrapper is yielded and the block's result is returned;
# without one, the wrapper itself is returned.
#
# NOTE(review): relies on the class variable <tt>@@instances</tt> being
# initialised elsewhere in the file (not visible here) -- confirm.
def self.for(model_class, options = {})
  @@instances[model_class] ||= new(model_class, options)
  wrapper = @@instances[model_class]
  return wrapper unless block_given?

  yield wrapper
end
new(model, options={})
click to toggle source
Calls superclass method
Torchrb::Torch::new
# File lib/torchrb/wrapper.rb, line 16
#
# Wraps +model+: stores it, forwards +options+ to the superclass
# constructor, mixes the model's net and trainer modules into the wrapper
# class and finally hands the wrapper to <tt>model.setup</tt>.
#
# Raises a RuntimeError when +model+ fails the sanity check below.
#
# NOTE(review): the check passes for *any* Class because the two conditions
# are joined with <tt>||</tt>, while the message reads "a class and extend
# Torchrb::ModelBase" -- verify which one is intended.
def initialize(model, options = {})
  acceptable = model.is_a?(Class) || model.class < Torchrb::ModelBase
  raise "#{model} must be a class and extend Torchrb::ModelBase!" unless acceptable

  @model = model
  super(options)

  # The modules are included into the wrapper's class, not its singleton
  # class, so they become visible to every instance of that class.
  wrapper_class = self.class
  wrapper_class.include(model.net)
  wrapper_class.include(model.trainer)

  model.setup(self)
  self
end
Public Instance Methods
load_model_data()
click to toggle source
# File lib/torchrb/wrapper.rb, line 26
#
# Resets the progress counter to zero and loads the train, test and
# validation sets, in that order, through +load_dataset+.
#
# Returns whatever the final +load_dataset+ call (validation set) returned.
def load_model_data
  @progress = 0
  [:train_set, :test_set, :validation_set].map { |set_name| load_dataset(set_name) }.last
end
predict(sample)
click to toggle source
Calls superclass method
Torchrb::Torch#predict
# File lib/torchrb/wrapper.rb, line 44
#
# Delegates the prediction for +sample+ to the next +predict+ in the
# ancestor chain and returns its result unchanged.
def predict(sample)
  # Bare +super+ forwards the method's own argument (and block) as-is.
  super
end
train()
click to toggle source
Calls superclass method
Torchrb::Torch#train
# File lib/torchrb/wrapper.rb, line 33
#
# Runs one full training pass:
#
# 1. defines the network and the trainer,
# 2. calls +cudify+ when +enable_cuda+ is truthy,
# 3. runs the inherited +train+,
# 4. prints the results and stores the network.
#
# Returns the value of +error_rate+.
def train
  define_nn
  define_trainer

  if enable_cuda
    cudify
  end

  super

  print_results
  store_network
  error_rate
end
Private Instance Methods
engine_storage()
click to toggle source
# File lib/torchrb/wrapper.rb, line 63
#
# Location of the stored network file: <tt>Visit.cache_dir</tt> with
# "/net.t7" appended via <tt>+</tt>.
#
# NOTE(review): +Visit+ is not defined in the visible source; presumably an
# application-level constant -- confirm it is always loaded before use.
def engine_storage
  cache_root = Visit.cache_dir
  cache_root + "/net.t7"
end
load_dataset(set_name)
click to toggle source
# File lib/torchrb/wrapper.rb, line 49
#
# Loads the data set named +set_name+ (e.g. <tt>:train_set</tt>) and
# reports progress to the model while doing so.
#
# The set size is read by sending +set_name+ to the model. Every time the
# block given to <tt>DataSet#load</tt> runs, progress advances by
# <tt>0.333 / size</tt> and the model's +progress_callback+ is invoked.
#
# Returns the result of <tt>normalize!</tt> when the model asks for
# normalisation and the set reports itself as a train set, otherwise +nil+.
def load_dataset(set_name)
  element_count = model.send(set_name).size
  announcement = "Loading #{set_name.to_s.humanize} with #{element_count} element(s)."
  model.progress_callback(progress, message: announcement)

  data_set = Torchrb::DataSet.new(set_name, self)
  data_set.load do
    @progress += 0.333 / element_count
    model.progress_callback(progress)
  end

  return unless model.normalize? && data_set.is_trainset?

  data_set.normalize!
end