class Secryst::Trainer

Public Class Methods

new( model:, batch_size:, lr:, data_input:, data_target:, hyperparameters:, max_epochs: nil, log_interval: 1, checkpoint_every:, checkpoint_dir:, scheduler_step_size:, gamma: )
# File lib/secryst/trainer.rb, line 4
def initialize(
  model:,
  batch_size:,
  lr:,
  data_input:,
  data_target:,
  hyperparameters:,
  max_epochs: nil,
  log_interval: 1,
  checkpoint_every:,
  checkpoint_dir:,
  scheduler_step_size:,
  gamma:
)
  @data_input = File.readlines(data_input, chomp: true)
  @data_target = File.readlines(data_target, chomp: true)

  @device = "cpu"
  @lr = lr
  @scheduler_step_size = scheduler_step_size
  @gamma = gamma
  @batch_size = batch_size
  @model_name = model
  @max_epochs = max_epochs
  @log_interval = log_interval
  @checkpoint_every = checkpoint_every
  @checkpoint_dir = checkpoint_dir
  FileUtils.mkdir_p(@checkpoint_dir)
  generate_vocabs_and_data
  save_vocabs

  case model
  when 'transformer'
    @model = Secryst::Transformer.new(hyperparameters.merge({
      input_vocab_size: @input_vocab.length,
      target_vocab_size: @target_vocab.length,
    }))
  else
    raise ArgumentError, 'Only transformer model is currently supported'
  end
end
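
A minimal construction sketch. The data files are plain text with one sequence per line, input and target aligned line by line. The keys inside hyperparameters below (d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout) are assumptions about what Secryst::Transformer accepts, not taken from this file; input_vocab_size and target_vocab_size are merged in automatically from the generated vocabularies.

require 'secryst'

trainer = Secryst::Trainer.new(
  model: 'transformer',
  batch_size: 32,
  lr: 0.1,
  data_input: 'data/input.txt',    # one source sequence per line
  data_target: 'data/target.txt',  # matching target sequence per line
  hyperparameters: {               # assumed keys; adjust to your Transformer
    d_model: 64,
    nhead: 8,
    num_encoder_layers: 4,
    num_decoder_layers: 4,
    dim_feedforward: 256,
    dropout: 0.1,
  },
  max_epochs: 20,
  log_interval: 10,
  checkpoint_every: 5,
  checkpoint_dir: 'checkpoints',
  scheduler_step_size: 1,
  gamma: 0.9
)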

Public Instance Methods

train()
# File lib/secryst/trainer.rb, line 46
def train
  best_model = nil
  best_val_loss = Float::INFINITY

  return unless @model_name == 'transformer'

  criterion = Torch::NN::CrossEntropyLoss.new(ignore_index: index_of('<pad>')).to(@device)
  optimizer = Torch::Optim::SGD.new(@model.parameters, lr: @lr)
  scheduler = Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: @scheduler_step_size, gamma: @gamma)

  total_loss = 0.0
  start_time = Time.now
  ntokens = @target_vocab.length
  epoch = 0

  loop do
    epoch_start_time = Time.now
    @model.train
    @train_data.each.with_index do |batch, i|
      inputs, targets, decoder_inputs, src_mask, tgt_mask, memory_mask = batch
      inputs = Torch.tensor(inputs).t
      decoder_inputs = Torch.tensor(decoder_inputs).t
      targets = Torch.tensor(targets).t
      src_key_padding_mask = inputs.t.eq(1)
      tgt_key_padding_mask = decoder_inputs.t.eq(1)

      optimizer.zero_grad
      opts = {
        # src_mask: src_mask,
        tgt_mask: tgt_mask,
        # memory_mask: memory_mask,
        src_key_padding_mask: src_key_padding_mask,
        tgt_key_padding_mask: tgt_key_padding_mask,
        memory_key_padding_mask: src_key_padding_mask,
      }
      output = @model.call(inputs, decoder_inputs, **opts)
      loss = criterion.call(output.transpose(0,1).reshape(-1, ntokens), targets.t.view(-1))
      loss.backward
      ClipGradNorm.clip_grad_norm(@model.parameters, max_norm: 0.5)
      optimizer.step

      # puts "i[#{i}] loss: #{loss}"
      total_loss += loss.item()
      if ( (i + 1) % @log_interval == 0 )
        cur_loss = total_loss / @log_interval
        elapsed = Time.now - start_time
        puts "| epoch #{epoch} | #{i + 1}/#{@train_data.length} batches | "\
              "lr #{scheduler.get_lr()[0].round(4)} | ms/batch #{(1000*elapsed.to_f / @log_interval).round} | "\
              "loss #{cur_loss.round(5)} | ppl #{Math.exp(cur_loss).round(5)}"
        total_loss = 0
        start_time = Time.now
      end
    end

    if epoch > 0 && epoch % @checkpoint_every == 0
      puts ">> Saving checkpoint '#{@checkpoint_dir}/checkpoint-#{epoch}.pth'"
      Torch.save(@model.state_dict, "#{@checkpoint_dir}/checkpoint-#{epoch}.pth")
    end

    # Evaluate
    @model.eval()
    total_loss = 0.0
    Torch.no_grad do
      @eval_data.each.with_index do |batch, i|
        inputs, targets, decoder_inputs, src_mask, tgt_mask, memory_mask = batch
        inputs = Torch.tensor(inputs).t
        decoder_inputs = Torch.tensor(decoder_inputs).t
        targets = Torch.tensor(targets).t
        src_key_padding_mask = inputs.t.eq(1)
        tgt_key_padding_mask = decoder_inputs.t.eq(1)

        opts = {
          # src_mask: src_mask,
          tgt_mask: tgt_mask,
          # memory_mask: memory_mask,
          src_key_padding_mask: src_key_padding_mask,
          tgt_key_padding_mask: tgt_key_padding_mask,
          memory_key_padding_mask: src_key_padding_mask,
        }
        output = @model.call(inputs, decoder_inputs, **opts)
        output_flat = output.transpose(0,1).reshape(-1, ntokens)

        total_loss += criterion.call(output_flat, targets.t.view(-1)).item
      end
      total_loss = total_loss / @eval_data.length
      puts('-' * 89)
      puts "| end of epoch #{epoch} | time: #{(Time.now - epoch_start_time).round(3)}s | "\
              " valid loss #{total_loss.round(5)} | valid ppl #{Math.exp(total_loss).round(5)} "
      puts('-' * 89)
      if total_loss < best_val_loss
        best_model = @model
        best_val_loss = total_loss
      end
    end
    scheduler.step

    epoch += 1
    break if @max_epochs && @max_epochs < epoch
  end
end
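
Typical usage, assuming the trainer constructed above. Training runs until epoch exceeds max_epochs (or indefinitely when max_epochs is nil), logging loss and perplexity every log_interval batches, evaluating after every epoch, and writing a checkpoint every checkpoint_every epochs.

trainer.train
# logs one line per log_interval batches (lr, ms/batch, loss, perplexity),
# prints validation loss/perplexity at the end of each epoch, and saves
# <checkpoint_dir>/checkpoint-<epoch>.pth every checkpoint_every epochs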

Private Instance Methods

batchify(data)
# File lib/secryst/trainer.rb, line 205
def batchify(data)
  batches = []

  (1 + data.length / @batch_size).times do |i|
    input_data = data[i*@batch_size, @batch_size].transpose[0]
    decoder_input_data = data[i*@batch_size, @batch_size].transpose[1]
    target_data = data[i*@batch_size, @batch_size].transpose[1]
    max_input_seq_length = input_data.max_by(&:length).length + 2
    max_target_seq_length = target_data.max_by(&:length).length + 1
    src_mask = Torch.triu(Torch.ones(max_input_seq_length,max_input_seq_length)).eq(0).transpose(0,1)
    tgt_mask = Torch.triu(Torch.ones(max_target_seq_length,max_target_seq_length)).eq(0).transpose(0,1)
    memory_mask = Torch.triu(Torch.ones(max_input_seq_length,max_target_seq_length)).eq(0).transpose(0,1)
    batches << [
      input_data.map {|line| pad(line.chars, max_input_seq_length).map {|c| @input_vocab[c]} },
      target_data.map {|line| pad(line.chars, max_target_seq_length, no_sos: true).map {|c| @target_vocab[c]} },
      decoder_input_data.map {|line| pad(line.chars, max_target_seq_length, no_eos: true).map {|c| @target_vocab[c]} },
      src_mask,
      tgt_mask,
      memory_mask
    ]
  end

  batches
end
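
Each batch produced here is a six-element array; the order matters because #train destructures it positionally (targets second, decoder inputs third). A quick sketch of what the training loop receives:

inputs, targets, decoder_inputs, src_mask, tgt_mask, memory_mask = @train_data.first
# inputs:         arrays of input-vocab indices, wrapped in <sos>/<eos> and padded
# targets:        target-vocab indices with <eos> but no <sos> (no_sos: true)
# decoder_inputs: target-vocab indices with <sos> but no <eos> (no_eos: true)
# *_mask:         Torch tensors; only tgt_mask is currently passed to the model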
generate_vocabs_and_data()
# File lib/secryst/trainer.rb, line 149
def generate_vocabs_and_data
  input_texts = []
  target_texts = []
  input_vocab_counter = Hash.new(0)
  target_vocab_counter = Hash.new(0)

  @data_input.each do |input_text|
    input_text.strip!
    input_texts.push(input_text)
    input_text.each_char do |char|
      input_vocab_counter[char] += 1
    end
  end

  @data_target.each do |target_text|
    target_text.strip!
    target_texts.push(target_text)
    target_text.each_char do |char|
      target_vocab_counter[char] += 1
    end
  end

  @input_vocab = Vocab.new(input_vocab_counter)
  @target_vocab = Vocab.new(target_vocab_counter)

  # Generate train, eval, and test batches
  seed = 1
  zipped_texts = input_texts.zip(target_texts)
  zipped_texts = zipped_texts.shuffle(random: Random.new(seed))

  # train - 90%, eval - 7%, test - 3%
  train_texts = zipped_texts[0..(zipped_texts.length*0.9).to_i]
  eval_texts = zipped_texts[(zipped_texts.length*0.9).to_i + 1..(zipped_texts.length*0.97).to_i]
  test_texts = zipped_texts[(zipped_texts.length*0.97).to_i+1..-1]

  # prepare batches
  @train_data = batchify(train_texts)
  @eval_data = batchify(eval_texts)
  @test_data = batchify(test_texts)

end
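
A worked example of the 90/7/3 split boundaries (computed with to_i on the shuffled zipped_texts) for 1,000 input/target pairs:

# train_texts = zipped_texts[0..900]    # 901 pairs (~90%)
# eval_texts  = zipped_texts[901..970]  #  70 pairs (~7%)
# test_texts  = zipped_texts[971..-1]   #  29 pairs (~3%)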
index_of(token)
# File lib/secryst/trainer.rb, line 201
def index_of(token)
  @target_vocab.stoi[token]
end
pad(arr, length, no_eos: false, no_sos: false)
# File lib/secryst/trainer.rb, line 191
def pad(arr, length, no_eos:false, no_sos:false)
  if !no_eos
    arr = arr + ["<eos>"]
  end
  if !no_sos
    arr = ["<sos>"] + arr
  end
  arr.fill("<pad>", arr.length...length)
end
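
For example, with the defaults pad adds both markers before filling to the requested length, while the no_sos/no_eos variants are used for targets and decoder inputs respectively:

pad(%w[a b], 6)
# => ["<sos>", "a", "b", "<eos>", "<pad>", "<pad>"]
pad(%w[a b], 5, no_sos: true)   # targets
# => ["a", "b", "<eos>", "<pad>", "<pad>"]
pad(%w[a b], 5, no_eos: true)   # decoder inputs
# => ["<sos>", "a", "b", "<pad>", "<pad>"]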
save_vocabs()
# File lib/secryst/trainer.rb, line 230
def save_vocabs
  File.write("#{@checkpoint_dir}/input_vocab.json", JSON.generate(@input_vocab.freqs))
  File.write("#{@checkpoint_dir}/target_vocab.json", JSON.generate(@target_vocab.freqs))
end
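
Assuming Vocab#freqs returns the character-frequency hash the vocabulary was built from (the Vocab class is not shown here), this writes two JSON files alongside the checkpoints, roughly of the form:

# <checkpoint_dir>/input_vocab.json    e.g. {"a":123,"b":45,...}
# <checkpoint_dir>/target_vocab.json   e.g. {"x":98,"y":12,...}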