class Secryst::Translator

Public Class Methods

new(model:, vocabs_dir:, hyperparameters:, model_file:)
# File lib/secryst/translator.rb, line 3
def initialize(model:, vocabs_dir:, hyperparameters:, model_file:)
  # Inference always runs on the CPU.
  @device = "cpu"
  @vocabs_dir = vocabs_dir

  # Load input_vocab.json and target_vocab.json from the vocabs directory.
  load_vocabs

  if model == 'transformer'
    # Build the Transformer from the given hyperparameters plus the
    # vocabulary sizes derived from the loaded vocabs.
    @model = Secryst::Transformer.new(hyperparameters.merge({
      input_vocab_size: @input_vocab.length,
      target_vocab_size: @target_vocab.length,
    }))
  else
    raise ArgumentError, 'Only transformer model is currently supported'
  end

  # Restore the trained weights and put the model in evaluation mode.
  @model.load_state_dict(Torch.load(model_file))
  @model.eval
end
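
A hedged instantiation sketch. The paths are placeholders, and the hyperparameter keys follow the usual Transformer settings; the exact key set accepted by Secryst::Transformer is not shown in this file, so treat those keys and values as assumptions:

require 'secryst'

translator = Secryst::Translator.new(
  model: 'transformer',                  # only supported value
  vocabs_dir: 'path/to/vocabs',          # must contain input_vocab.json and target_vocab.json
  model_file: 'path/to/model.pth',       # state dict loadable with Torch.load
  hyperparameters: {                     # assumed keys, mirroring a standard Transformer
    d_model: 64,
    nhead: 8,
    num_encoder_layers: 4,
    num_decoder_layers: 4,
    dim_feedforward: 256,
    dropout: 0.1,
  }
)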

Public Instance Methods

translate(phrase, max_seq_length: 100)
# File lib/secryst/translator.rb, line 22
def translate(phrase, max_seq_length: 100)
  # Wrap the input characters in <sos>/<eos> markers and map them to
  # vocabulary indices, producing a (seq_len, 1) column tensor.
  input = ['<sos>'] + phrase.chars + ['<eos>']
  input = Torch.tensor([input.map {|i| @input_vocab.stoi[i]}]).t
  # Decoding starts from a single <sos> token.
  output = Torch.tensor([[@target_vocab.stoi['<sos>']]])
  # Mask source positions holding index 1, which the model treats as padding.
  src_key_padding_mask = input.t.eq(1)

  # Greedy decoding: append the most likely next token until <eos> is
  # produced or max_seq_length steps have been taken.
  max_seq_length.times do |i|
    tgt_key_padding_mask = output.t.eq(1)
    # Causal mask: position i may only attend to positions 0..i.
    tgt_mask = Torch.triu(Torch.ones(i+1,i+1)).eq(0).transpose(0,1)
    opts = {
      tgt_mask: tgt_mask,
      src_key_padding_mask: src_key_padding_mask,
      tgt_key_padding_mask: tgt_key_padding_mask,
      memory_key_padding_mask: src_key_padding_mask,
    }
    # Take the argmax token at every decoder position; only the newest
    # position (index i) extends the output sequence.
    prediction = @model.call(input, output, opts).map {|i| i.argmax.item }
    break if @target_vocab.itos[prediction[i]] == '<eos>'
    output = Torch.cat([output, Torch.tensor([[prediction[i]]])])
  end

  # Print the decoded tokens, skipping the leading <sos>.
  puts "#{output[1..-1].map {|i| @target_vocab.itos[i.item]}.join('')}"
end
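
A minimal call sketch, assuming the translator instance from the constructor example above; the phrase is illustrative only. Note that translate prints the decoded string to stdout rather than returning it:

translator.translate('hello')                      # greedy decode, up to 100 steps
translator.translate('hello', max_seq_length: 50)  # cap decoding at 50 steps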

Private Instance Methods

load_vocabs()
# File lib/secryst/translator.rb, line 46
def load_vocabs
  @input_vocab = Vocab.new(JSON.parse(File.read("#{@vocabs_dir}/input_vocab.json")))
  @target_vocab = Vocab.new(JSON.parse(File.read("#{@vocabs_dir}/target_vocab.json")))
end
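
As the translate method above shows, the loaded Vocab objects are used only through stoi (token to index), itos (index to token), and length. A minimal lookup sketch; the concrete index value is hypothetical and depends on the vocab file:

vocab = Vocab.new(JSON.parse(File.read("#{@vocabs_dir}/target_vocab.json")))
sos_index = vocab.stoi['<sos>']  # Integer index fed to the decoder (e.g. 2, illustrative)
vocab.itos[sos_index]            # => "<sos>"
vocab.length                     # token count, used as target_vocab_size in the constructor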