module Torchrb::NN::Basic

Public Instance Methods

define_nn()
# File lib/torchrb/nn/basic.rb, line 3
  # Builds the Torch (Lua) network for this model by evaluating a Lua
  # snippet through +torch.eval+.
  #
  # The snippet creates an nn.Sequential container assigned to the Lua
  # variable +net+ (a Lua global, since it is not declared +local+) holding:
  #   nn.Linear(input_size, hidden_size)
  #   nn.Linear(hidden_size, model.classes.size)
  #   nn.LogSoftMax()
  #
  # @param input_size [Integer] number of input features (default 1)
  # @param hidden_size [Integer] width of the intermediate layer (default 80)
  # @return [Hash] whatever +torch.eval+ returns, converted with #to_h. The
  #   snippet has no +return+ statement, so this is presumably +{}+ (nil.to_h)
  #   -- NOTE(review): confirm against torch.eval's behaviour.
  # @raise [ArgumentError] if +input_size+ or +hidden_size+ is not a positive
  #   Integer
  #
  # NOTE(review): there is no nonlinearity between the two Linear layers, so
  # they compose to a single affine map and the hidden layer adds no
  # expressive power. Consider inserting e.g. nn.Tanh() or nn.ReLU() between
  # them; left unchanged here to keep the existing network structure.
  def define_nn(input_size: 1, hidden_size: 80)
    { input_size: input_size, hidden_size: hidden_size }.each do |name, value|
      # Sizes are interpolated straight into Lua source below, so accept only
      # plain positive integers.
      unless value.is_a?(Integer) && value > 0
        raise ArgumentError, "#{name} must be a positive Integer, got #{value.inspect}"
      end
    end
    output_size = model.classes.size
    # __LINE__ + 1 because the heredoc body starts on the line after this
    # call; assuming torch.eval treats its third argument as the starting line
    # number (as Kernel#eval does), reported Lua line numbers match this file.
    torch.eval(<<-EOF, __FILE__, __LINE__ + 1).to_h
      net = nn.Sequential()
      net:add(nn.Linear(#{input_size}, #{hidden_size}))
      net:add(nn.Linear(#{hidden_size}, #{output_size}))
      net:add(nn.LogSoftMax())
    EOF
  end