class Torchrb::Torch

Attributes

error_rate[RW]
network_loaded[RW]
network_timestamp[RW]

Public Class Methods

new(options={})
Calls superclass method Torchrb::Lua::new
# File lib/torchrb/torch.rb, line 7
# Builds a Torch wrapper and, when a storage path is given, tries to load a
# previously stored network from it.
#
# Loading is best-effort: if it fails (e.g. no network has been stored yet),
# the instance stays untrained (network_loaded false, error_rate NaN) instead
# of failing construction. A missing path is skipped explicitly rather than
# relying on File.exist?(nil) raising.
#
# @param options [Hash] forwarded unchanged to the superclass initializer;
#   :network_storage_path (optional) is the file handed to #load_network.
def initialize options={}
  super
  @network_loaded = false
  @error_rate = Float::NAN
  network_storage_path = options[:network_storage_path]
  return if network_storage_path.nil? || network_loaded

  begin
    load_network network_storage_path
  rescue StandardError => e
    # Deliberately swallowed: an untrained/missing network is a valid state here.
    warn "Could not load network from #{network_storage_path}: #{e.message}" if debug
  end
end

Public Instance Methods

cudify()
# File lib/torchrb/torch.rb, line 102
  # Moves the datasets, the loss function and the network onto the GPU.
  #
  # Runs a Lua snippet in the embedded interpreter that:
  #   * converts input and label of train_set, test_set and validation_set
  #     with :cuda();
  #   * assigns a new nn.ClassNLLCriterion():cuda() to the global `criterion`;
  #   * converts `net` with net:cuda() followed by cudnn.convert.
  # It relies on the Lua state already defining those dataset globals and `net`,
  # and on CUDA tensor support and the `cudnn` global being available.
  #
  # NOTE(review): `criterion` is replaced by a fresh ClassNLLCriterion rather
  # than converting the existing one, so any Lua object that already holds the
  # old criterion (e.g. a trainer built earlier) keeps the CPU version —
  # confirm the trainer is created after cudify is called.
  def cudify
    eval <<-EOF, __FILE__, __LINE__
          -- print(sys.COLORS.red .. '==> using CUDA GPU #' .. cutorch.getDevice() .. sys.COLORS.black)
          train_set.input = train_set.input:cuda()
          train_set.label = train_set.label:cuda()
          test_set.input = test_set.input:cuda()
          test_set.label = test_set.label:cuda()
          validation_set.input = validation_set.input:cuda()
          validation_set.label = validation_set.label:cuda()

          criterion = nn.ClassNLLCriterion():cuda()
          net = cudnn.convert(net:cuda(), cudnn)
    EOF
  end
iteration_callback=(callback)
# File lib/torchrb/torch.rb, line 14
# Registers a Ruby callback that Lua code can invoke through the global
# function `iteration_callback(trainer, iteration, current_error)`.
#
# Each invocation stores current_error / 100.0 in #error_rate and calls the
# callback with a Hash:
#   :progress   - iteration / number_of_iterations, as a Float
#   :error_rate - the value just stored in #error_rate
#
# @param callback [#call] e.g. a Proc or lambda
def iteration_callback= callback
  state.function "iteration_callback" do |_trainer, iteration, current_error|
    # to_f forces float division: if the Lua bridge hands over integral
    # numbers, Integer#/ would truncate progress to 0 until the final step.
    progress = iteration.to_f / state['number_of_iterations']
    self.error_rate = current_error / 100.0
    callback.call progress: progress, error_rate: error_rate
  end
end
load_network(network_storage_path)
# File lib/torchrb/torch.rb, line 51
  # Restores a network previously written by #store_network.
  #
  # Loads `<path>` into the Lua global `net` and `<path>.meta` (a Lua table
  # {classes, error_rate, timestamp}) into the globals `classes` and
  # `timestamp`. It then mirrors the error rate and timestamp into the Ruby
  # attributes and marks the network as loaded.
  #
  # @param network_storage_path [String] path of the stored network file
  # @raise [RuntimeError] if no file exists at network_storage_path
  def load_network network_storage_path
    raise "Neuronal net not trained yet. Call 'Torch#update_training_data'." unless File.exist?(network_storage_path)
    # Escape backslashes, quotes and newlines so the path survives being embedded
    # in a single-quoted Lua string literal (e.g. "models/bob's_net.t7").
    lua_path = network_storage_path.to_s.gsub(/[\\'\n]/) { |char| "\\#{char}" }
    stored_error_rate = eval(<<-EOF, __FILE__, __LINE__)
        net = torch.load('#{lua_path}')
        metadata = torch.load('#{lua_path}.meta')
        classes = metadata[1]
        timestamp = metadata[3]
        return metadata[2]
    EOF
    # metadata[2] is nil when a NaN error rate was written as the undefined Lua
    # global `NaN`; fall back to NaN (the constructor's default) rather than
    # crashing on nil.to_ruby.
    self.error_rate = stored_error_rate.nil? ? Float::NAN : stored_error_rate.to_ruby
    self.network_timestamp = @state['timestamp']
    puts "Network with metadata [#{@state['classes'].to_h}, #{error_rate}] loaded from #{network_storage_path} @ #{network_timestamp}" if debug
    self.network_loaded = true
  end
predict(sample, network_storage_path=nil)
# File lib/torchrb/torch.rb, line 36
  # Classifies a single sample with the current network.
  #
  # If no network is loaded yet, it is first loaded from network_storage_path.
  # The Lua side then:
  #   * runs a forward pass (on the GPU when enable_cuda is set);
  #   * exponentiates the output into the global `confidences` table —
  #     presumably turning log-probabilities into probabilities, to be
  #     confirmed against the net's last layer;
  #   * returns the `classes` table.
  # Both tables are converted to Hashes and joined on their shared keys.
  #
  # @param sample [#to_tensor] emits Lua code defining `sample_data`
  # @param network_storage_path [String, nil] only used when no network is loaded
  # @return [Hash] class label => confidence
  def predict sample, network_storage_path=nil
    load_network network_storage_path unless network_loaded

    lua_classes = eval <<-EOF, __FILE__, __LINE__
        #{sample.to_tensor("sample_data").strip}
        local prediction = #{enable_cuda ? "net:forward(sample_data:cuda()):float()" : "net:forward(sample_data)"}
        prediction = prediction:exp()
        confidences = prediction:totable()
        return classes
    EOF
    confidences = @state['confidences'].to_h
    puts "predicted #{confidences} based on network @ #{network_timestamp}" if debug
    class_names = lua_classes.to_h
    # One pass into a single Hash; reduce(:merge) allocated a new Hash per class.
    confidences.each_with_object({}) do |(index, confidence), result|
      result[class_names[index]] = confidence
    end
  end
print_results() (source not included in this extract)
store_network(network_storage_path)
# File lib/torchrb/torch.rb, line 66
  # Persists the current network and its metadata.
  #
  # Writes the Lua global `net` to network_storage_path and the Lua table
  # {classes, error_rate, timestamp} to "<path>.meta". This is the layout
  # #load_network reads back.
  #
  # @param network_storage_path [String] destination file for the network
  def store_network network_storage_path
    # Escape backslashes, quotes and newlines for single-quoted Lua string literals.
    lua_escape = lambda { |text| text.to_s.gsub(/[\\'\n]/) { |char| "\\#{char}" } }
    lua_path = lua_escape.call(network_storage_path)
    lua_timestamp = lua_escape.call(network_timestamp)
    rate = error_rate
    # Lua has no NaN/Infinity literals. Ruby's "NaN"/"Infinity" would be read as
    # undefined globals (nil), and a nil rate would interpolate to nothing and
    # break the table constructor.
    lua_error_rate =
      if rate.nil? || rate.to_f.nan?
        '0/0'
      elsif rate.to_f.infinite?
        rate.to_f > 0 ? 'math.huge' : '-math.huge'
      else
        rate.to_s
      end
    eval <<-EOF, __FILE__, __LINE__
        torch.save('#{lua_path}', net)
        torch.save('#{lua_path}.meta', {classes, #{lua_error_rate}, '#{lua_timestamp}'} )
    EOF
    puts "Network with metadata [#{@state['classes'].to_h}, #{error_rate}] stored in #{network_storage_path} @ #{network_timestamp}" if debug
  end
train()
# File lib/torchrb/torch.rb, line 22
  # Trains the Lua `net` by calling trainer:train(train_set) with the global
  # `print` silenced. On return it marks the network as loaded and stamps it
  # with the current time.
  #
  # The training call runs under pcall so that `print` is always restored,
  # even when training fails. Any error is re-raised afterwards with
  # error(err, 0), which keeps the original message. Without this, a failed
  # run would leave every later Lua `print` permanently disabled.
  def train
    eval <<-EOF, __FILE__, __LINE__
        local oldprint = print
        print = function(...)
        end

        local ok, err = pcall(function()
          trainer:train(train_set)
        end)

        print = oldprint
        if not ok then
          error(err, 0)
        end
    EOF
    self.network_loaded = true
    self.network_timestamp = Time.now
  end