class Chainer::Functions::Normalization::BatchNormalizationGrad

Public Class Methods

new(eps, expander, axis, mean, inv_std) click to toggle source
# File lib/chainer/functions/normalization/batch_normalization.rb, line 107
# Stores the values this node reads again in #forward and #backward.
#
# @param eps      kept in @eps; not read by the methods shown alongside this one
# @param expander callable invoked on per-axis arrays (mean, inv_std, reduced sums)
#                 before they are combined with full-size arrays
# @param axis     passed as the `axis:` argument of the `sum` reductions
# @param mean     handed to the `x_hat` helper (after going through expander)
# @param inv_std  handed to the `x_hat` helper and used as a scaling factor
#                 (presumably the inverse standard deviation - confirm with caller)
def initialize(eps, expander, axis, mean, inv_std)
  @eps, @expander, @axis = eps, expander, axis
  @mean, @inv_std = mean, inv_std
end

Public Instance Methods

backward(inputs, grad_outputs) click to toggle source
# File lib/chainer/functions/normalization/batch_normalization.rb, line 132
# Maps the gradients received for this node's outputs to values for its
# three inputs.
#
# @param inputs [Array] `[x, gamma, gy]`
# @param grad_outputs [Array] `[ggx1, gggamma1, ggbeta1]`; any entry may be nil
# @return [Array] `[gx2, ggamma2, ggy2]`
def backward(inputs, grad_outputs)
  broadcast = @expander

  x, gamma, gy = inputs
  gx1, ggamma1, _unused = output_data
  ggx1, gggamma1, ggbeta1 = grad_outputs
  xp = Chainer.get_array_module(x)

  # --- auxiliary values ---
  # Reciprocal of (x.size / gamma.size, integer division), held in a fresh
  # instance of gamma's array class.
  inv_m = gamma.class.new.fill(1.0 / x.size.div(gamma.size))
  # r has to be derived from the caller-supplied ggx1, i.e. before a nil
  # ggx1 is swapped out further down.
  r = if ggx1.nil?
        0
      else
        (gx1 * ggx1).sum(axis: @axis)
      end
  coeff = gamma * @inv_std
  coeff_m = coeff * inv_m
  normalized = x_hat(x, broadcast.call(@mean), broadcast.call(@inv_std))

  # --- nil output gradients ---
  # zero_if_none is given the shape and class of the matching input;
  # presumably it yields a zero array when the gradient is nil.
  ggx1 = zero_if_none(xp, ggx1, x.shape, x.class)
  gggamma1 = zero_if_none(xp, gggamma1, gamma.shape, gamma.class)
  ggbeta1 = zero_if_none(xp, ggbeta1, gamma.shape, gamma.class)

  weighted_ggx1_sum = (normalized * ggx1).sum(axis: @axis)
  gggamma2 = gggamma1 - coeff_m * weighted_ggx1_sum
  plain_ggx1_sum = ggx1.sum(axis: @axis)
  ggbeta2 = ggbeta1 - coeff_m * plain_ggx1_sum

  ggamma2 = r / gamma

  gx_hat2 = broadcast.call(gggamma2) * gy - broadcast.call(coeff_m * ggamma1) * ggx1

  neg_inv_std = -@inv_std
  gstd2 = neg_inv_std * (r + (normalized * gx_hat2).sum(axis: @axis))
  gmean2 = neg_inv_std * gx_hat2.sum(axis: @axis)

  direct_term = broadcast.call(@inv_std) * gx_hat2
  stats_term = inv_m * (broadcast.call(gmean2) + normalized * broadcast.call(gstd2))
  gx2 = direct_term + stats_term

  ggy2 = broadcast.call(gggamma2) * normalized + broadcast.call(ggbeta2) + broadcast.call(coeff) * ggx1

  [gx2, ggamma2, ggy2]
end
forward(inputs) click to toggle source
# File lib/chainer/functions/normalization/batch_normalization.rb, line 115
# Computes the three gradient arrays from the input `x`, the scale `gamma`
# and the incoming gradient `gy`.
#
# All three inputs and the first two outputs are registered for retention
# (retain_inputs / retain_outputs); #backward reads its inputs and
# `output_data` again, presumably relying on this.
#
# @param inputs [Array] `[x, gamma, gy]`
# @return [Array] `[gx, ggamma, gbeta]`
def forward(inputs)
  retain_inputs([0, 1, 2])
  x, gamma, gy = inputs
  expander = @expander

  # Reciprocal of (x.size / gamma.size, integer division), held in a fresh
  # instance of gamma's array class.
  inv_m = gamma.class.new.fill(1.0 / x.size.div(gamma.size))

  gbeta = gy.sum(axis: @axis)
  # The argument list makes the right-hand side a call to the x_hat helper;
  # from here on `x_hat` names the local result.
  x_hat = x_hat(x, expander.(@mean), expander.(@inv_std))
  ggamma = (gy * x_hat).sum(axis: @axis)
  gx = expander.(gamma * @inv_std) * (gy - (x_hat * expander.(ggamma) + expander.(gbeta)) * inv_m)

  retain_outputs([0, 1])
  [gx, ggamma, gbeta]
end