class Chainer::Functions::Connection::EmbedIDFunction

Public Class Methods

embed_id(x, w, ignore_label: nil) click to toggle source
# File lib/chainer/functions/connection/embed_id.rb, line 9
# Convenience entry point: builds a function object configured with
# +ignore_label+ and immediately invokes it on +x+ and +w+.
#
# @param x first argument forwarded to the function object (the ID array that
#   +forward+ looks up)
# @param w second argument forwarded to the function object (the embedding
#   matrix that +forward+ reads rows from)
# @param ignore_label [Object, nil] passed through to the constructor
# @return whatever invoking the function object returns
def self.embed_id(x, w, ignore_label: nil)
  function = new(ignore_label: ignore_label)
  function.call(x, w)
end
new(ignore_label: nil) click to toggle source
# File lib/chainer/functions/connection/embed_id.rb, line 5
# Creates the function object and remembers the label to ignore.
#
# @param ignore_label [Object, nil] ID value to ignore. It is stored as is;
#   +forward+ zero-fills the output rows of IDs equal to it and +backward+
#   skips them. With +nil+ (the default) +forward+ does no masking.
def initialize(ignore_label: nil)
  @ignore_label = ignore_label
end

Public Instance Methods

backward(inputs, grad_outputs) click to toggle source
# File lib/chainer/functions/connection/embed_id.rb, line 34
# Accumulates the output gradient into the rows of the embedding matrix that
# were selected by the IDs.
#
# @param inputs [Array] pair +[x, w]+: the ID array and the embedding matrix
# @param grad_outputs [Array] its first element is the gradient with respect
#   to the output; it is viewed as one row per ID
# @return [Array] +[nil, gw]+ - no gradient is produced for the IDs, and +gw+
#   has the same shape as +w+
def backward(inputs, grad_outputs)
  ids, weights = inputs
  id_count = ids.size
  leading_dims = weights.shape.take(weights.shape.size - 1)

  # One gradient row per ID; +true+ lets the array library infer the rest.
  upstream = grad_outputs[0].reshape(id_count, true)
  grad_w = weights.class.zeros(weights.shape).reshape(leading_dims.reduce(&:*), true)

  ids.reshape(id_count).each_with_index do |id, position|
    # IDs equal to the ignore label contribute nothing to the gradient.
    unless id == @ignore_label
      # Read-add-write, so an ID that occurs several times is summed up.
      grad_w[id, true] += upstream[position, true]
    end
  end

  [nil, grad_w.reshape(*weights.shape)]
end
forward(inputs) click to toggle source
# File lib/chainer/functions/connection/embed_id.rb, line 13
# Looks up the rows of the embedding matrix selected by the given IDs.
#
# @param inputs [Array] pair +[x, w]+: the ID array and the embedding matrix
# @return [Array] single-element array holding the looked-up rows; when an
#   ignore label is set, rows for IDs equal to it are filled with zeros
def forward(inputs)
  x, w = inputs

  # Without an ignore label every ID is looked up as is.
  return [Chainer::Utils::Array.take(w, x, axis: 0)] unless @ignore_label

  valid_x = x.ne(@ignore_label)
  # Nothing to mask when no ID matches the ignore label.
  return [Chainer::Utils::Array.take(w, x, axis: 0)] if valid_x.count == x.size

  # Multiplying by the mask rebinds the local +x+ only (the caller's array is
  # not modified). Presumably this maps ignored IDs to 0 so the lookup stays
  # in range; their rows are overwritten below.
  x *= valid_x
  y = Chainer::Utils::Array.take(w, x, axis: 0).dup

  # Flatten all leading axes so that each looked-up row is addressed by one index.
  y = y.reshape(y.shape.take(y.shape.size - 1).reduce(&:*), true)

  # The zero row is loop-invariant, so build it once.
  zero_row = y.class.zeros(y.shape.last)
  # where2.last: flat positions where the mask is false, i.e. the ignored IDs
  # (assumes Numo::Bit#where2 returns [true indices, false indices]).
  valid_x.where2.last.each { |i| y[i, true] = zero_row }

  [y.reshape(*x.shape, true)]
end