class Torchrb::ModelBase

Constants

REQUIRED_OPTIONS

Public Class Methods

error_rate()
# File lib/torchrb/model_base.rb, line 33
# Returns the error rate currently reported by the +torch+ wrapper.
def error_rate
  wrapper = torch
  wrapper.error_rate
end
predict(sample)
# File lib/torchrb/model_base.rb, line 59
# Runs a prediction for +sample+ through the +torch+ wrapper, passing along the
# configured +network_storage_path+.
def predict(sample)
  wrapper = torch
  wrapper.predict(sample, network_storage_path)
end
progress_callback(progress: nil, message: nil, error_rate: Float::NAN)
# File lib/torchrb/model_base.rb, line 5
# Progress hook that concrete models must override.
#
# It is called with status messages from +train+ and with accumulated progress
# values from +load_dataset+. It is also installed as Torch's iteration callback.
#
# progress   - numeric progress indicator, or nil
# message    - status text, or nil
# error_rate - error rate, defaulting to Float::NAN (presumably supplied by
#              Torch's iteration callback; confirm)
#
# Raises NotImplementedError unless overridden.
def progress_callback(progress: nil, message: nil, error_rate: Float::NAN)
  raise NotImplementedError, "Implement this method in your Model"
end
setup_nn(options={})
# File lib/torchrb/model_base.rb, line 9
# Configures the neural network for this model class.
#
# Steps:
# - Validates +options+ with check_options.
# - Merges +options+ over the defaults below.
# - Exposes every resulting setting as a class-level reader (e.g. +enable_cuda+,
#   +network_storage_path+) backed by a class variable.
# - Defines the +torch+ reader.
# - Mixes in the :net and :trainer extensions via load_extension.
#
# options - Hash of settings overriding the defaults. :net and :trainer may be
#           a Module or a { Module => options_hash } pair.
#
# Returns the trainer options produced by load_extension.
def setup_nn(options = {})
  check_options(options)
  settings = {
    net: Torchrb::NN::Basic,
    trainer: Torchrb::NN::TrainerDefault,
    tensor_type: "DoubleTensor",
    dimensions: [0],
    classes: [],
    dataset_split: [80, 10, 10],
    normalize: false,
    enable_cuda: false,
    auto_store_trained_network: true,
    network_storage_path: "tmp/cache/torchrb",
    debug: false,
  }.merge(options)
  settings.each do |option, value|
    cattr_reader(option)
    class_variable_set(:"@@#{option}", value)
  end
  # NOTE(review): Torch receives the caller's raw options, not the merged
  # settings (unchanged). Confirm whether Torchrb::Torch applies its own defaults.
  cattr_reader(:torch) { Torchrb::Torch.new options }

  # Use the merged settings so the :net / :trainer defaults actually apply when
  # the caller omits them. options[:net] would be nil and `extend nil` raises.
  @net_options = load_extension(settings[:net])
  @trainer_options = load_extension(settings[:trainer])
end
train()
# File lib/torchrb/model_base.rb, line 37
# Runs the full training pipeline and returns the final error rate.
#
# Steps, in order:
# 1. Load the model data.
# 2. Route Torch's iteration callback to progress_callback.
# 3. Define the network and trainer via define_nn / define_trainer (presumably
#    provided by the extensions mixed in by setup_nn), using the options that
#    setup_nn stored.
# 4. Move to CUDA when enabled.
# 5. Train and print the results.
# 6. Optionally store the trained network.
# 7. Call +after_training+ if the model responds to it.
def train
  progress_callback(message: 'Loading data')
  load_model_data
  torch.iteration_callback = method(:progress_callback)

  define_nn(@net_options)
  define_trainer(@trainer_options)
  torch.cudify if enable_cuda

  progress_callback(message: 'Start training')
  torch.train
  progress_callback(message: 'Done')

  torch.print_results
  if auto_store_trained_network
    torch.store_network(network_storage_path)
  end
  after_training if respond_to?(:after_training)

  torch.error_rate
end

Private Class Methods

check_options(options)
# File lib/torchrb/model_base.rb, line 65
# Ensures every key listed in REQUIRED_OPTIONS is present in +options+.
#
# Raises RuntimeError naming the first missing option.
def check_options(options)
  REQUIRED_OPTIONS.each do |key|
    next if options.key?(key)

    raise "Option '#{key}' is required."
  end
end
load_dataset(set_name, collection)
# File lib/torchrb/model_base.rb, line 96
# Builds a Torchrb::DataSet named +set_name+ from +collection+ and loads it,
# reporting progress through progress_callback. When the +normalize+ option is
# set, the training set is normalized afterwards.
#
# set_name   - :train_set, :test_set or :validation_set (see load_model_data)
# collection - the records to load; must respond to #size
#
# Each loaded element advances @progress by 0.333 / collection size, so one
# fully loaded set adds roughly a third of the total progress.
def load_dataset set_name, collection
  # Read the size once. On a lazily-loaded relation each #size call may issue a
  # separate query. NOTE(review): assumes the size is stable while loading.
  element_count = collection.size
  progress_callback progress: @progress, message: "Loading #{set_name.to_s.humanize} with #{element_count} element(s)."

  set = Torchrb::DataSet.new set_name, self, collection
  set.load do
    @progress += 0.333 / element_count
    progress_callback progress: @progress
  end
  set.normalize! if normalize && set.is_trainset?
end
load_extension(extension)
# File lib/torchrb/model_base.rb, line 86
# Mixes an extension module into the model and returns its options.
#
# extension - either a Module, or a Hash whose first key is the Module to mix
#             in and whose values are option Hashes.
#
# Returns a Hash of options. It is {} for a bare Module or when no options are
# given; otherwise it is the option Hashes of all entries merged in order.
def load_extension(extension)
  if extension.is_a?(Hash)
    # Only the first key is mixed in (unchanged behaviour). nil option values
    # are skipped so the result is always a Hash, like the bare-Module branch,
    # instead of nil for { Mod => nil }.
    extend extension.keys.first
    extension.values.compact.reduce({}, :merge)
  else
    extend extension
    {}
  end
end
load_model_data()
# File lib/torchrb/model_base.rb, line 71
# Shuffles data_model's ids, partitions them into train / test / validation
# subsets by the percentages in +dataset_split+, and loads each subset through
# load_dataset. Set names without a matching +dataset_split+ entry (nil after
# zip) are skipped.
#
# Raises RuntimeError unless the model implements both #to_tensor and
# #prediction_class (private implementations are accepted).
def load_model_data
  # respond_to?'s second argument is +include_all+, not another method name, so
  # each required method is checked on its own. include_all=true keeps private
  # implementations acceptable, as the previous (truthy) second argument did.
  unless [:to_tensor, :prediction_class].all? { |name| respond_to?(name, true) }
    raise "#{self} needs to implement '#to_tensor(var_name, data)' and '#prediction_class' method."
  end
  @progress = 0
  all_ids = data_model.ids.shuffle
  total = all_ids.count
  cumulative_percent = 0.0
  start = 0
  [:train_set, :test_set, :validation_set].zip(dataset_split).map do |set, split|
    next if split.nil?
    cumulative_percent += split.to_f
    # Derive integer boundaries from the cumulative percentage. Previously
    # Array#slice truncated the float offset and length separately, which
    # dropped ids: 7 ids at 80/10/10 left test and validation empty.
    # round(6) absorbs float noise such as 99.99999999999999.
    stop = (total * cumulative_percent / 100.0).round(6).floor.clamp(start, total)
    collection = data_model.where(id: all_ids.slice(start, stop - start))
    start = stop
    load_dataset set, collection
  end
end