class Secryst::PositionalEncoding
Public Class Methods
new(d_model, dropout: 0.1, max_len: 5000)
PositionalEncoding injects information about the relative or absolute position of the tokens in the sequence. The positional encodings have the same dimension as the embeddings so that the two can be summed. Here, sine and cosine functions of different frequencies are used.
Calls superclass method
# File lib/secryst/transformer.rb, line 347
def initialize(d_model, dropout: 0.1, max_len: 5000)
  super()
  @dropout = Torch::NN::Dropout.new(p: dropout)

  # Precompute the (max_len, d_model) table of positional encodings.
  pe = Torch.zeros(max_len, d_model)
  position = Torch.arange(0, max_len, dtype: :float).unsqueeze(1)
  # Frequencies 1 / 10000^(2i / d_model) for each pair of dimensions.
  div_term = Torch.exp(Torch.arange(0, d_model, 2).float() * (-Math.log(10000.0) / d_model))
  sin = Torch.sin(position * div_term).t
  cos = Torch.cos(position * div_term).t

  # Interleave: even dimensions get sine, odd dimensions get cosine.
  pe.t!
  pe.each.with_index do |row, i|
    pe[i] = sin[i / 2] if i % 2 == 0
    pe[i] = cos[(i - 1) / 2] if i % 2 != 0
  end
  pe.t!

  # Shape to (max_len, 1, d_model) and store as a non-trainable buffer.
  pe = pe.unsqueeze(0).transpose(0, 1)
  register_buffer('pe', pe)
end
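The even/odd interleaving performed above can be sketched in plain Ruby, without the Torch dependency. This is an illustrative sketch only; `d_model` and `max_len` below are small made-up values, and the real class operates on Torch tensors:

```ruby
# Build a small sinusoidal positional-encoding table.
# Each position pos gets sin(pos / 10000^(2i/d_model)) in even
# dimensions and cos of the same angle in odd dimensions.
d_model = 4
max_len = 3

pe = Array.new(max_len) do |pos|
  Array.new(d_model) do |i|
    angle = pos / (10000.0 ** (2 * (i / 2) / d_model.to_f))
    i.even? ? Math.sin(angle) : Math.cos(angle)
  end
end

pe.each { |row| puts row.map { |v| format('%+.4f', v) }.join(' ') }
# Row for pos = 0 is [0.0, 1.0, 0.0, 1.0], since sin(0) = 0 and cos(0) = 1.
```

Because adjacent even/odd dimensions share an angle, each pair behaves like a rotation whose frequency falls with the dimension index, which is what lets the model attend to relative positions.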
Public Instance Methods
forward(x)
# File lib/secryst/transformer.rb, line 366
def forward(x)
  # Add the first x.size(0) rows of the precomputed table to the input.
  # Note: @dropout is built in the constructor but not applied here.
  x = x + pe.narrow(0, 0, x.size(0))
  return x
end
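What `forward` does can be sketched with plain Ruby arrays: take the first `seq_len` rows of the precomputed table (the `pe.narrow(0, 0, x.size(0))` call) and add them elementwise to the embeddings. The values below are hypothetical; the real method operates on Torch tensors shaped `(seq_len, batch, d_model)`:

```ruby
# Precomputed encoding table (max_len = 3, d_model = 2), made-up values.
pe = [[0.0, 1.0], [0.8415, 0.5403], [0.9093, -0.4161]]
# Token embeddings for a sequence of length 2.
x = [[0.5, 0.5], [0.5, 0.5]]

# Equivalent of x + pe.narrow(0, 0, x.size(0)):
# only the first x.length rows of pe are used.
out = x.each_index.map do |t|
  x[t].zip(pe[t]).map { |xi, pi| xi + pi }
end
# out[0] == [0.5, 1.5]
```

Because `pe` is registered as a buffer, it moves with the module between devices but receives no gradient updates.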