class Chainer::Functions::Normalization::FixedBatchNormalizationGrad
Public Class Methods
new(eps, expander, axis, inv_std, inv_var)
click to toggle source
# File lib/chainer/functions/normalization/batch_normalization.rb, line 220
#
# Stores the configuration and the optional precomputed statistics used by
# #forward and #backward.
#
# @param eps value added to +var+ before it is inverted; only used by #forward
#   when +inv_std+ or +inv_var+ is nil
# @param expander [#call] callable applied to arrays such as mean, inv_std and
#   gamma_over_std before they are combined with +x+ / +gy+ (presumably it
#   reshapes them for broadcasting — TODO confirm)
# @param axis axis (or axes) reduced by the +sum(axis:)+ calls
# @param inv_std precomputed +sqrt(1 / (var + eps))+, or nil
# @param inv_var precomputed +1 / (var + eps)+, or nil; if either +inv_std+ or
#   +inv_var+ is nil, #forward recomputes both from +var+ and +eps+
def initialize(eps, expander, axis, inv_std, inv_var)
  @eps = eps
  @expander = expander
  @axis = axis
  @inv_std = inv_std
  @inv_var = inv_var
end
Public Instance Methods
backward(inputs, grad_outputs)
click to toggle source
# File lib/chainer/functions/normalization/batch_normalization.rb, line 252
#
# Differentiates the outputs of #forward with respect to its inputs (used for
# double backpropagation).
#
# Relies on state cached by #forward: @inv_std, @inv_var and @gamma_over_std.
#
# @param inputs [Array] the forward inputs +[x, gamma, mean, var, gy]+; +var+
#   is ignored because #forward does not retain it
# @param grad_outputs [Array] gradients w.r.t. #forward's outputs
#   +[gx, ggamma, gbeta, gmean, gvar]+; nil entries are replaced through
#   zero_if_none (presumably zeros of the matching shape and class)
# @return [Array] +[gx2, ggamma2, gmean2, gvar2, ggy2]+, the gradients w.r.t.
#   +x+, +gamma+, +mean+, +var+ and +gy+ respectively
def backward(inputs, grad_outputs)
  x, gamma, mean, _, gy = inputs
  ggx1, gggamma1, ggbeta1, ggmean1, ggvar1 = grad_outputs
  # Of the retained forward outputs only ggamma and gbeta are needed here.
  _gx1, ggamma1, gbeta1, _gmean1, _gvar1 = output_data

  # Handle nil in output gradients.
  xp = Chainer.get_array_module(x)
  ggx1 = zero_if_none(xp, ggx1, x.shape, x.class)
  gggamma1 = zero_if_none(xp, gggamma1, gamma.shape, gamma.class)
  ggbeta1 = zero_if_none(xp, ggbeta1, gamma.shape, gamma.class)
  ggmean1 = zero_if_none(xp, ggmean1, mean.shape, mean.class)
  ggvar1 = zero_if_none(xp, ggvar1, mean.shape, mean.class)

  expander = @expander

  x_hat = x_hat(x, expander.(mean), expander.(@inv_std))
  tmp = -0.5 * ggvar1

  gamma_over_var = gamma * @inv_var
  g_gamma_over_var = tmp * ggamma1

  gggamma2 = gggamma1 + tmp * gamma_over_var
  gx_hat = gy * expander.(gggamma2)
  gx2 = expander.(@inv_std) * gx_hat
  gmean2 = -@inv_std * gx_hat.sum(axis: @axis)

  g_gamma_over_std = (ggx1 * gy).sum(axis: @axis) - ggmean1 * gbeta1
  ggbeta2 = ggbeta1 - ggmean1 * @gamma_over_std
  ggy2 = (expander.(gggamma2) * x_hat + expander.(ggbeta2) +
          expander.(@gamma_over_std) * ggx1)

  ggamma2 = (@inv_var * g_gamma_over_var + @inv_std * g_gamma_over_std)
  gvar2 = -(ggamma2 * gamma_over_var +
            0.5 * @inv_var * ((x_hat * gx_hat).sum(axis: @axis) -
                              @gamma_over_std * g_gamma_over_std))

  [gx2, ggamma2, gmean2, gvar2, ggy2]
end
forward(inputs)
click to toggle source
# File lib/chainer/functions/normalization/batch_normalization.rb, line 228
#
# Computes the gradient terms gx, ggamma, gbeta, gmean and gvar from the
# normalization inputs and the upstream gradient +gy+.
#
# If either @inv_std or @inv_var is nil, both are computed from +var+ and @eps
# and cached on the instance. @gamma_over_std is cached too, because #backward
# reads it.
#
# @param inputs [Array] +[x, gamma, mean, var, gy]+
# @return [Array] +[gx, ggamma, gbeta, gmean, gvar]+
def forward(inputs)
  # var (index 3) is not retained: #backward uses the cached @inv_std and
  # @inv_var instead.
  retain_inputs([0, 1, 2, 4])
  x, gamma, mean, var, gy = inputs
  expander = @expander
  xp = Chainer.get_array_module(x)

  if @inv_std.nil? || @inv_var.nil?
    @inv_var = (var + @eps).reciprocal
    @inv_std = xp::NMath.sqrt(@inv_var)
  end

  @gamma_over_std = gamma * @inv_std
  x_hat = x_hat(x, expander.(mean), expander.(@inv_std))

  gx = expander.(@gamma_over_std) * gy
  gbeta = gy.sum(axis: @axis)
  ggamma = (x_hat * gy).sum(axis: @axis)
  gmean = -@gamma_over_std * gbeta
  gvar = -0.5 * gamma * @inv_var * ggamma

  # Outputs are retained so #backward can read ggamma and gbeta via output_data.
  retain_outputs([0, 1, 2, 3, 4])
  [gx, ggamma, gbeta, gmean, gvar]
end