class Secryst::MultiHeadAttentionForward
Public Class Methods
multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training: true, key_padding_mask: nil, need_weights: true, attn_mask: nil, use_separate_proj_weight: false, q_proj_weight: nil, k_proj_weight: nil, v_proj_weight: nil, static_k: nil, static_v: nil)
Args:
query, key, value: map a query and a set of key-value pairs to an output. See "Attention Is All You Need" for more details.
embed_dim_to_check: total dimension of the model.
num_heads: parallel attention heads.
in_proj_weight, in_proj_bias: input projection weight and bias.
bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
add_zero_attn: add a new batch of zeros to the key and value sequences at dim=1.
dropout_p: probability of an element to be zeroed.
out_proj_weight, out_proj_bias: the output projection weight and bias.
training: apply dropout if ``true``.
key_padding_mask: if provided, specified padding elements in the key will be ignored by the attention. This is a binary mask. When the value is true, the corresponding value on the attention layer will be filled with -inf.
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast for all the batches, while a 3D mask allows a different mask to be specified for the entries of each batch. (See the mask sketch after this list.)
use_separate_proj_weight: the function accepts the projection weights for query, key, and value in different forms. If false, in_proj_weight will be used, which is a combination of q_proj_weight, k_proj_weight and v_proj_weight.
q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
static_k, static_v: static key and value used for attention operators.
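Both mask arguments follow the boolean convention described above: ``true`` marks a position that attention must ignore. The following is a minimal sketch of building the two masks with torch-rb; the batch size, sequence lengths and padding pattern are invented purely for illustration.

  # Hypothetical sizes: batch of 2, source length 4, target length 4.
  n, s, l = 2, 4, 4

  # key_padding_mask, shape (N, S): true marks padded key positions to ignore.
  key_padding_mask = Torch.tensor([[0, 0, 0, 1],
                                   [0, 0, 1, 1]]).to(Torch.bool)

  # attn_mask, shape (L, S): a causal mask, position i may attend only positions <= i.
  causal = (0...l).map { |i| (0...s).map { |j| j > i ? 1 : 0 } }
  attn_mask = Torch.tensor(causal).to(Torch.bool)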
Shape:
Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is the embedding dimension.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is the embedding dimension.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions will be unchanged. If a BoolTensor is provided, the positions with the value of ``true`` will be ignored while the positions with the value of ``false`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length; or 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, S is the source sequence length. attn_mask ensures that position i is allowed to attend only the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``true`` are not allowed to attend while ``false`` values will be unchanged. If a FloatTensor is provided, it will be added to the attention weight.
- static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
- static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.

Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension.
- attn_output_weights: :math:`(N, L, S)` where N is the batch size, L is the target sequence length, S is the source sequence length.
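A minimal end-to-end sketch using the shapes listed above. The random weights and the sizes are purely illustrative; in practice these tensors would come from a trained model's parameters.

  l, s, n, e, h = 5, 7, 2, 16, 4   # tgt len, src len, batch, embed dim, heads

  query = Torch.randn(l, n, e)     # (L, N, E)
  key   = Torch.randn(s, n, e)     # (S, N, E)
  value = Torch.randn(s, n, e)     # (S, N, E)

  in_proj_weight  = Torch.randn(3 * e, e)   # packed q/k/v projection
  in_proj_bias    = Torch.zeros(3 * e)
  out_proj_weight = Torch.randn(e, e)
  out_proj_bias   = Torch.zeros(e)

  attn_output, attn_weights = Secryst::MultiHeadAttentionForward.multi_head_attention_forward(
    query, key, value, e, h,
    in_proj_weight, in_proj_bias,
    nil, nil,        # bias_k, bias_v
    false,           # add_zero_attn
    0.1,             # dropout_p
    out_proj_weight, out_proj_bias,
    training: false
  )

  attn_output.shape  # => [5, 2, 16]  (L, N, E)
  attn_weights.shape # => [2, 5, 7]   (N, L, S), averaged over heads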
# File lib/secryst/multi_head_attention_forward.rb, line 54
def self.multi_head_attention_forward(query, key, value, embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, out_proj_bias, training: true, key_padding_mask: nil, need_weights: true, attn_mask: nil, use_separate_proj_weight: false, q_proj_weight: nil, k_proj_weight: nil, v_proj_weight: nil, static_k: nil, static_v: nil)
  tgt_len, bsz, embed_dim = query.size()
  raise ArgumentError if embed_dim != embed_dim_to_check
  # allow MHA to have different sizes for the feature dimension
  raise ArgumentError if key.size(0) != value.size(0) or key.size(1) != value.size(1)

  head_dim = embed_dim / num_heads
  raise ArgumentError, "embed_dim must be divisible by num_heads" if head_dim * num_heads != embed_dim
  scaling = head_dim.to_f ** -0.5

  if !use_separate_proj_weight
    if Torch.equal(query, key) && Torch.equal(key, value)
      # self-attention
      q, k, v = linear(query, in_proj_weight, in_proj_bias).chunk(3, -1)
    elsif Torch.equal(key, value)
      # encoder-decoder attention
      # This is inline in_proj function with in_proj_weight and in_proj_bias
      _b = in_proj_bias
      _start = 0
      _end = embed_dim
      _w = in_proj_weight.slice(0, _start, _end) # NOTE: inc-trspl
      if _b
        _b = _b.slice(0, _start, _end)
      end
      q = linear(query, _w, _b)

      if !key
        raise ArgumentError if value
        k = nil
        v = nil
      else
        # This is inline in_proj function with in_proj_weight and in_proj_bias
        _b = in_proj_bias
        _start = embed_dim
        _end = nil
        _w = in_proj_weight.slice(0, _start)
        if _b
          _b = _b.slice(0, _start)
        end
        k, v = linear(key, _w, _b).chunk(2, -1)
      end
    else
      # This is inline in_proj function with in_proj_weight and in_proj_bias
      _b = in_proj_bias
      _start = 0
      _end = embed_dim
      _w = in_proj_weight.slice(0, _start, _end)
      if _b
        _b = _b.slice(0, _start, _end)
      end
      q = linear(query, _w, _b)

      # This is inline in_proj function with in_proj_weight and in_proj_bias
      _b = in_proj_bias
      _start = embed_dim
      _end = embed_dim * 2
      _w = in_proj_weight.slice(0, _start, _end)
      if _b
        _b = _b.slice(0, _start, _end)
      end
      k = linear(key, _w, _b)

      # This is inline in_proj function with in_proj_weight and in_proj_bias
      _b = in_proj_bias
      _start = embed_dim * 2
      _end = nil
      _w = in_proj_weight.slice(0, _start)
      if _b
        _b = _b.slice(0, _start)
      end
      v = linear(value, _w, _b)
    end
  else
    q_proj_weight_non_opt = q_proj_weight
    len1, len2 = q_proj_weight_non_opt.size()
    raise ArgumentError if len1 != embed_dim || len2 != query.size(-1)

    k_proj_weight_non_opt = k_proj_weight
    len1, len2 = k_proj_weight_non_opt.size()
    raise ArgumentError if len1 != embed_dim || len2 != key.size(-1)

    v_proj_weight_non_opt = v_proj_weight
    len1, len2 = v_proj_weight_non_opt.size()
    raise ArgumentError if len1 != embed_dim || len2 != value.size(-1)

    if in_proj_bias
      q = linear(query, q_proj_weight_non_opt, in_proj_bias.slice(0, 0, embed_dim))
      k = linear(key, k_proj_weight_non_opt, in_proj_bias.slice(0, embed_dim, embed_dim * 2))
      v = linear(value, v_proj_weight_non_opt, in_proj_bias.slice(0, embed_dim * 2))
    else
      q = linear(query, q_proj_weight_non_opt, in_proj_bias)
      k = linear(key, k_proj_weight_non_opt, in_proj_bias)
      v = linear(value, v_proj_weight_non_opt, in_proj_bias)
    end
  end

  q = q * scaling

  if attn_mask
    raise ArgumentError, 'Only float, byte, and bool types are supported for attn_mask, not %s' % attn_mask.dtype unless attn_mask.dtype == Torch.float32 || attn_mask.dtype == Torch.float64 || attn_mask.dtype == Torch.float16 || attn_mask.dtype == Torch.uint8 || attn_mask.dtype == Torch.bool

    if attn_mask.dtype == Torch.uint8
      puts "Byte tensor for attn_mask in NN::MultiheadAttention is deprecated. Use bool tensor instead."
      attn_mask = attn_mask.to(Torch.bool)
    end

    if attn_mask.dim() == 2
      attn_mask = attn_mask.unsqueeze(0)
      raise ArgumentError, 'The size of the 2D attn_mask is not correct.' if attn_mask.size() != [1, query.size(0), key.size(0)]
    elsif attn_mask.dim() == 3
      raise ArgumentError, 'The size of the 3D attn_mask is not correct.' if attn_mask.size() != [bsz * num_heads, query.size(0), key.size(0)]
    else
      raise ArgumentError, "attn_mask's dimension %s is not supported" % attn_mask.dim()
    end
    # attn_mask's dim is 3 now.
  end

  # convert ByteTensor key_padding_mask to bool
  if key_padding_mask && key_padding_mask.dtype == Torch.uint8
    puts("Byte tensor for key_padding_mask in NN::MultiheadAttention is deprecated. Use bool tensor instead.")
    key_padding_mask = key_padding_mask.to(Torch.bool)
  end

  if bias_k && bias_v
    if !static_k && !static_v
      k = Torch.cat([k, bias_k.repeat(1, bsz, 1)])
      v = Torch.cat([v, bias_v.repeat(1, bsz, 1)])
      attn_mask = pad(attn_mask, [0, 1]) if attn_mask
      key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
    else
      raise ArgumentError, "bias cannot be added to static key." unless !static_k
      raise ArgumentError, "bias cannot be added to static value." unless !static_v
    end
  else
    raise ArgumentError unless !bias_k
    raise ArgumentError unless !bias_v
  end

  q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
  k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) if k
  v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) if v

  if static_k
    raise ArgumentError unless static_k.size(0) == bsz * num_heads
    raise ArgumentError unless static_k.size(2) == head_dim
    k = static_k
  end

  if static_v
    raise ArgumentError unless static_v.size(0) == bsz * num_heads
    raise ArgumentError unless static_v.size(2) == head_dim
    v = static_v
  end

  src_len = k.size(1)

  if key_padding_mask
    raise ArgumentError unless key_padding_mask.size(0) == bsz
    raise ArgumentError unless key_padding_mask.size(1) == src_len
  end

  if add_zero_attn
    src_len += 1
    k_sizes = k.size()
    k_sizes[1] = 1
    k = Torch.cat([k, Torch.zeros(k_sizes, dtype: k.dtype, device: k.device)], 1)
    v_sizes = v.size()
    v_sizes[1] = 1
    v = Torch.cat([v, Torch.zeros(v_sizes, dtype: v.dtype, device: v.device)], 1)
    attn_mask = pad(attn_mask, [0, 1]) if attn_mask
    key_padding_mask = pad(key_padding_mask, [0, 1]) if key_padding_mask
  end

  attn_output_weights = Torch.bmm(q, k.transpose(1, 2))
  raise ArgumentError unless attn_output_weights.size() == [bsz * num_heads, tgt_len, src_len]

  if attn_mask
    if attn_mask.dtype == Torch.bool
      attn_output_weights.masked_fill!(attn_mask, -1.0/0.0)
    else
      attn_output_weights += attn_mask
    end
  end

  if key_padding_mask
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
    attn_output_weights = attn_output_weights.masked_fill(
      key_padding_mask.unsqueeze(1).unsqueeze(2),
      -1.0/0.0
    )
    attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
  end

  attn_output_weights = softmax(attn_output_weights, dim: -1)
  attn_output_weights = dropout(attn_output_weights, p: dropout_p, training: training)

  attn_output = Torch.bmm(attn_output_weights, v)
  raise ArgumentError unless attn_output.size() == [bsz * num_heads, tgt_len, head_dim]
  attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
  attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

  if need_weights
    # average attention weights over heads
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
    return attn_output, attn_output_weights.sum(1) / num_heads
  else
    return attn_output, nil
  end
end
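For readers tracing the source above, the per-head batching may be easier to follow with concrete numbers. The sketch below (sizes invented for illustration) repeats only the reshape step, showing how an (L, N, E) projection becomes an (N*num_heads, L, E/num_heads) batch so that a single Torch.bmm scores all heads at once.

  tgt_len, src_len, bsz, num_heads, head_dim = 5, 7, 2, 4, 8
  embed_dim = num_heads * head_dim   # 32

  q = Torch.randn(tgt_len, bsz, embed_dim)   # projected query, (L, N, E)
  k = Torch.randn(src_len, bsz, embed_dim)   # projected key,   (S, N, E)

  # Same reshape as in the method: fold the heads into the batch dimension.
  q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)  # (N*H, L, E/H)
  k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)       # (N*H, S, E/H)

  scores = Torch.bmm(q, k.transpose(1, 2))
  scores.shape # => [8, 5, 7], i.e. (bsz * num_heads, tgt_len, src_len)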