class Chainer::Functions::Activation::LogSoftmaxGrad
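
Computes the gradient of LogSoftmax as a differentiable FunctionNode, so the backward pass of log_softmax can itself be backpropagated through (double backprop).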

Public Class Methods

new(x_shape, x_dtype)
# File lib/chainer/functions/activation/log_softmax.rb, line 78
def initialize(x_shape, x_dtype)
  # Record the shape and dtype of the original log_softmax input so they
  # are available when this gradient node is itself differentiated.
  @x_shape = x_shape
  @x_dtype = x_dtype
end
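
This node is created inside LogSoftmax#backward with the shape and dtype of the original input. A hedged sketch of that call site (the exact red-chainer code may differ in detail):

# Inside LogSoftmax#backward: y is the retained log_softmax output and
# gy.first the upstream gradient; apply returns [gx].
LogSoftmaxGrad.new(@x_shape, @x_dtype).apply([y, gy.first])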

Public Instance Methods

backward(indexes, ggx)
# File lib/chainer/functions/activation/log_softmax.rb, line 92
def backward(indexes, ggx)
  # y: retained log_softmax output; gy: retained upstream gradient.
  # ggx.first is the gradient flowing into this node's output gx.
  y, gy = get_retained_inputs
  ret = []
  exp_y = Chainer::Functions::Math::Exp.exp(y)

  if indexes.include?(0)
    # Second-order gradient w.r.t. y: -ggx * softmax(x) * sum(gy).
    gy_sum = Chainer::Functions::Math::Sum.sum(gy, axis: 1, keepdims: true)
    gy_sum = Chainer::Functions::Array::BroadcastTo.broadcast_to(gy_sum, gy.shape)

    g0 = -ggx.first * exp_y * gy_sum
    ret << g0
  end
  if indexes.include?(1)
    # Second-order gradient w.r.t. gy: ggx - sum(ggx * softmax(x)).
    a = Chainer::Functions::Math::Sum.sum(ggx.first * exp_y, axis: 1, keepdims: true)
    a = Chainer::Functions::Array::BroadcastTo.broadcast_to(a, gy.shape)
    g1 = ggx.first - a
    ret << g1
  end

  ret
end
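
Both branches reuse exp_y, which equals softmax(x) because y is a log_softmax output: g0 is the second-order gradient with respect to y and g1 the one with respect to gy. A minimal sketch of the same arithmetic on raw arrays, assuming Numo::NArray as the CPU backend (illustration only; the method itself works on Chainer variables so the results stay differentiable):

require 'numo/narray'

y   = Numo::DFloat[[-2.4076, -1.4076, -0.4076]]  # one row of log_softmax output
gy  = Numo::DFloat[[0.1, 0.2, 0.3]]              # upstream gradient
ggx = Numo::DFloat[[0.5, 0.5, 0.5]]              # gradient arriving at gx

exp_y = Numo::NMath.exp(y)                       # == softmax(x)

g0 = -ggx * exp_y * gy.sum(axis: 1, keepdims: true)    # w.r.t. y
g1 = ggx - (ggx * exp_y).sum(axis: 1, keepdims: true)  # w.r.t. gy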
forward(inputs)
# File lib/chainer/functions/activation/log_softmax.rb, line 83
def forward(inputs)
  # Keep both inputs so backward can retrieve them via get_retained_inputs.
  retain_inputs([0, 1])
  y, gy = inputs

  xm = Chainer.get_array_module(y)
  # gx = gy - softmax(x) * sum(gy), with softmax(x) recovered as exp(y).
  gx = gy - xm::NMath.exp(y) * gy.sum(axis: 1, keepdims: true)
  [gx]
end
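
forward implements the standard log_softmax gradient, gx = gy - exp(y) * sum(gy, axis: 1), reusing exp(y) == softmax(x) instead of recomputing the softmax. A quick numerical sketch with plain Numo arrays (assuming Numo::NArray is what Chainer.get_array_module returns on CPU):

require 'numo/narray'

x  = Numo::DFloat[[1.0, 2.0, 3.0]]
gy = Numo::DFloat[[0.1, 0.2, 0.3]]

# y = log_softmax(x), with the usual max subtraction for numerical stability.
m = x.max(axis: 1, keepdims: true)
y = x - m - Numo::NMath.log(Numo::NMath.exp(x - m).sum(axis: 1, keepdims: true))

# The same expression as forward: gx = gy - exp(y) * sum(gy).
gx = gy - Numo::NMath.exp(y) * gy.sum(axis: 1, keepdims: true)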