class Chainer::Functions::Normalization::FixedBatchNormalizationGrad

Public Class Methods

new(eps, expander, axis, inv_std, inv_var)
# File lib/chainer/functions/normalization/batch_normalization.rb, line 220
# Stores the configuration and the optional precomputed statistics that
# #forward and #backward read from instance variables.
#
# @param eps value added to `var` before it is inverted in #forward
# @param expander callable invoked on statistic arrays (mean, inv_std, ...)
#   before they are combined with `x` / `gy`
# @param axis passed as `axis:` to every `sum` reduction
# @param inv_std precomputed inverse standard deviation, or nil
# @param inv_var precomputed inverse variance, or nil; when either of the
#   two is nil, #forward derives both from `var` and `eps`
def initialize(eps, expander, axis, inv_std, inv_var)
  # Fixed configuration.
  @axis = axis
  @expander = expander
  @eps = eps

  # Optional cached statistics (may be nil until #forward fills them in).
  @inv_std = inv_std
  @inv_var = inv_var
end

Public Instance Methods

backward(inputs, grad_outputs)
# File lib/chainer/functions/normalization/batch_normalization.rb, line 252
# Second-order backward pass: propagates the gradients received for the
# five outputs of #forward back to its inputs.
#
# Relies on @inv_std, @inv_var and @gamma_over_std, which #forward sets (or
# which were supplied to the constructor), and on the outputs retained by
# #forward (read through `output_data`).
#
# @param inputs five-element list `x, gamma, mean, var, gy`; `var` is not
#   used here
# @param grad_outputs five-element list `ggx1, gggamma1, ggbeta1, ggmean1,
#   ggvar1`; nil entries are replaced through `zero_if_none`
# @return [Array] `[gx2, ggamma2, gmean2, gvar2, ggy2]` (by naming these
#   look like the gradients w.r.t. x, gamma, mean, var and gy -- confirm
#   against the caller)
def backward(inputs, grad_outputs)
  x, gamma, mean, _var, gy = inputs
  ggx1, gggamma1, ggbeta1, ggmean1, ggvar1 = grad_outputs
  # Only ggamma and gbeta of the retained forward outputs are needed.
  _gx1, ggamma1, gbeta1, _gmean1, _gvar1 = output_data

  # Substitute missing (nil) upstream gradients via zero_if_none, using the
  # shape/class of the matching input.
  xp = Chainer.get_array_module(x)
  ggx1 = zero_if_none(xp, ggx1, x.shape, x.class)
  gggamma1 = zero_if_none(xp, gggamma1, gamma.shape, gamma.class)
  ggbeta1 = zero_if_none(xp, ggbeta1, gamma.shape, gamma.class)
  ggmean1 = zero_if_none(xp, ggmean1, mean.shape, mean.class)
  ggvar1 = zero_if_none(xp, ggvar1, mean.shape, mean.class)

  expand = @expander
  normalized = x_hat(x, expand.call(mean), expand.call(@inv_std))
  neg_half_ggvar = -0.5 * ggvar1

  gamma_over_var = gamma * @inv_var
  g_gamma_over_var = neg_half_ggvar * ggamma1

  # Contributions flowing through the normalized input.
  gggamma2 = gggamma1 + neg_half_ggvar * gamma_over_var
  gx_hat = gy * expand.call(gggamma2)
  gx2 = expand.call(@inv_std) * gx_hat
  gmean2 = -@inv_std * gx_hat.sum(axis: @axis)

  # Contributions flowing through gamma / std.
  g_gamma_over_std = (ggx1 * gy).sum(axis: @axis) - ggmean1 * gbeta1
  ggbeta2 = ggbeta1 - ggmean1 * @gamma_over_std
  ggy2 = expand.call(gggamma2) * normalized +
         expand.call(ggbeta2) +
         expand.call(@gamma_over_std) * ggx1

  ggamma2 = @inv_var * g_gamma_over_var + @inv_std * g_gamma_over_std
  x_hat_dot_gx_hat = (normalized * gx_hat).sum(axis: @axis)
  gvar2 = -(ggamma2 * gamma_over_var +
            0.5 * @inv_var * (x_hat_dot_gx_hat - @gamma_over_std * g_gamma_over_std))

  [gx2, ggamma2, gmean2, gvar2, ggy2]
end
forward(inputs)
# File lib/chainer/functions/normalization/batch_normalization.rb, line 228
# Computes the five gradient arrays of fixed batch normalization from
# `x, gamma, mean, var, gy`.
#
# Side effects visible in this method:
# * retains inputs 0, 1, 2, 4 and all five outputs for #backward;
# * when @inv_std or @inv_var is nil, sets @inv_var to the reciprocal of
#   `var + @eps` and @inv_std to its square root;
# * always stores `gamma * @inv_std` in @gamma_over_std.
#
# @param inputs five-element list `x, gamma, mean, var, gy`
# @return [Array] `[gx, ggamma, gbeta, gmean, gvar]`
def forward(inputs)
  retain_inputs([0, 1, 2, 4])
  x, gamma, mean, var, gy = inputs
  xp = Chainer.get_array_module(x)
  expand = @expander

  # Derive the inverse statistics only if the caller did not supply both.
  if @inv_var.nil? || @inv_std.nil?
    @inv_var = (var + @eps).reciprocal
    @inv_std = xp::NMath.sqrt(@inv_var)
  end

  @gamma_over_std = gamma * @inv_std
  normalized = x_hat(x, expand.call(mean), expand.call(@inv_std))

  # Reductions over @axis.
  gbeta = gy.sum(axis: @axis)
  ggamma = (normalized * gy).sum(axis: @axis)

  gx = expand.call(@gamma_over_std) * gy
  gmean = -@gamma_over_std * gbeta
  gvar = -0.5 * gamma * @inv_var * ggamma

  retain_outputs([0, 1, 2, 3, 4])
  [gx, ggamma, gbeta, gmean, gvar]
end