class Secryst::PositionalEncoding

Public Class Methods

new(d_model, dropout: 0.1, max_len: 5000)

The PositionalEncoding module injects information about the relative or absolute position of tokens in the sequence. The positional encodings have the same dimension as the embeddings so that the two can be summed. Sine and cosine functions of different frequencies are used.
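Concretely, this is the standard sinusoidal scheme from "Attention Is All You Need" (stated here for reference, not taken from this library's docs):

    PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

Each even/odd dimension pair shares one frequency, and the wavelengths form a geometric progression from 2π up to 10000·2π.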

Calls superclass method
# File lib/secryst/transformer.rb, line 347
def initialize(d_model, dropout: 0.1, max_len: 5000)
  super()
  @dropout = Torch::NN::Dropout.new(p: dropout)

  pe = Torch.zeros(max_len, d_model)
  position = Torch.arange(0, max_len, dtype: :float).unsqueeze(1)
  # Geometric progression of inverse frequencies, one per even dimension.
  div_term = Torch.exp(Torch.arange(0, d_model, 2).float * (-Math.log(10000.0) / d_model))
  sin = Torch.sin(position * div_term).t
  cos = Torch.cos(position * div_term).t
  # Interleave: after transposing, even rows (dimensions) take sine, odd rows cosine.
  pe.t!
  pe.each.with_index do |_row, i|
    pe[i] = i.even? ? sin[i / 2] : cos[(i - 1) / 2]
  end
  pe.t!
  pe = pe.unsqueeze(0).transpose(0, 1) # final shape: [max_len, 1, d_model]
  register_buffer('pe', pe)
end
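As a dependency-free sketch, the table this constructor builds can be reproduced in plain Ruby (no Torch), assuming the standard sinusoidal scheme; `positional_encoding` is a hypothetical helper, not part of this library:

```ruby
# Plain-Ruby sketch of the values the 'pe' buffer holds:
# even dimensions get sin(pos / 10000^(2i/d_model)), odd dimensions the
# matching cos, so each adjacent pair of dimensions shares a frequency.
def positional_encoding(max_len, d_model)
  Array.new(max_len) do |pos|
    Array.new(d_model) do |i|
      angle = pos / (10000.0**((i - i % 2).fdiv(d_model)))
      i.even? ? Math.sin(angle) : Math.cos(angle)
    end
  end
end

pe = positional_encoding(4, 8)
pe[0] # position 0 alternates sin(0) = 0.0 and cos(0) = 1.0
```

Every row is bounded in [-1, 1], so the encodings stay on the same scale as the (scaled) embeddings they are added to.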

Public Instance Methods

forward(x)
# File lib/secryst/transformer.rb, line 366
def forward(x)
  x = x + pe.narrow(0, 0, x.size(0)) # add encodings for the first seq_len positions
  @dropout.call(x)                   # apply the dropout configured in initialize
end
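What the addition in forward computes can be sketched with nested arrays standing in for tensors, where x is [seq_len][batch][d_model] and pe is [max_len][1][d_model] (a sketch of the broadcast semantics, not the Torch code; `add_positional` is a hypothetical name):

```ruby
# Sketch of pe.narrow(0, 0, x.size(0)) followed by a broadcast add.
def add_positional(x, pe)
  x.each_with_index.map do |step, t|  # only the first seq_len rows of pe are used
    # pe's batch dimension is 1, so pe[t][0] broadcasts across the whole batch.
    step.map { |vec| vec.zip(pe[t][0]).map { |a, b| a + b } }
  end
end

x  = [[[1.0, 1.0]], [[2.0, 2.0]]]                   # seq_len=2, batch=1, d_model=2
pe = [[[0.0, 1.0]], [[0.5, 0.25]], [[0.75, 0.125]]] # max_len=3
add_positional(x, pe) # => [[[1.0, 2.0]], [[2.5, 2.25]]]
```

Note that only the first x.size(0) rows of the cached table are consumed, which is why max_len only needs to be an upper bound on the sequence length.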