class Chainer::Functions::Normalization::BatchNormalizationGrad
Public Class Methods
new(eps, expander, axis, mean, inv_std)
click to toggle source
# File lib/chainer/functions/normalization/batch_normalization.rb, line 107
# Captures the values this gradient function needs later.
#
# @param eps      stored as @eps; not read by the forward/backward code shown here
# @param expander callable (invoked as `expander.(array)`) applied to per-channel
#                 arrays before they are combined with full-size arrays
# @param axis     passed as `axis:` to every `sum` reduction
# @param mean     passed through `expander` into `x_hat`
#                 (presumably the batch mean -- confirm against the caller)
# @param inv_std  multiplied with gamma and passed into `x_hat`
#                 (presumably 1 / std of the batch -- confirm against the caller)
def initialize(eps, expander, axis, mean, inv_std)
  @eps, @expander, @axis = eps, expander, axis
  @mean, @inv_std = mean, inv_std
end
Public Instance Methods
backward(inputs, grad_outputs)
click to toggle source
# File lib/chainer/functions/normalization/batch_normalization.rb, line 132
# Backward pass of the gradient function itself.
#
# @param inputs       [x, gamma, gy] -- the same triple that `forward` received
# @param grad_outputs [ggx1, gggamma1, ggbeta1] -- gradients arriving for the three
#                     values returned by `forward`; any of them may be nil
# @return [Array] [gx2, ggamma2, ggy2], one entry per element of `inputs`
def backward(inputs, grad_outputs)
  expand = @expander

  x, gamma, gy = inputs
  # Only the first two retained outputs of `forward` are needed here.
  gx1, ggamma1, = output_data
  ggx1, gggamma1, ggbeta1 = grad_outputs
  xp = Chainer.get_array_module(x)

  # --- shared intermediate quantities ---
  # Reciprocal of (x.size / gamma.size), stored in an array of gamma's class.
  inv_m = gamma.class.new.fill(1.0 / x.size.div(gamma.size))
  # r must be computed before ggx1 is replaced below, while nil is still visible.
  r =
    if ggx1.nil?
      0
    else
      (gx1 * ggx1).sum(axis: @axis)
    end
  coeff = gamma * @inv_std
  coeff_m = coeff * inv_m
  normalized = x_hat(x, expand.(@mean), expand.(@inv_std))

  # --- substitute missing (nil) incoming gradients ---
  # zero_if_none is defined elsewhere; presumably it yields a zero array of the
  # given shape/class when the gradient is nil -- confirm against its definition.
  ggx1 = zero_if_none(xp, ggx1, x.shape, x.class)
  gggamma1 = zero_if_none(xp, gggamma1, gamma.shape, gamma.class)
  ggbeta1 = zero_if_none(xp, ggbeta1, gamma.shape, gamma.class)

  gggamma2 = gggamma1 - coeff_m * (normalized * ggx1).sum(axis: @axis)
  ggbeta2 = ggbeta1 - coeff_m * ggx1.sum(axis: @axis)

  ggamma2 = r / gamma

  gx_hat2 = expand.(gggamma2) * gy - expand.(coeff_m * ggamma1) * ggx1
  gstd2 = -@inv_std * (r + (normalized * gx_hat2).sum(axis: @axis))
  gmean2 = -@inv_std * gx_hat2.sum(axis: @axis)

  # gx2 = inv_std * gx_hat2 + inv_m * (gmean2 + x_hat * gstd2)
  direct_term = expand.(@inv_std) * gx_hat2
  stats_term = inv_m * (expand.(gmean2) + normalized * expand.(gstd2))
  gx2 = direct_term + stats_term

  # ggy2 = gggamma2 * x_hat + ggbeta2 + coeff * ggx1
  ggy2 = expand.(gggamma2) * normalized + expand.(ggbeta2) + expand.(coeff) * ggx1

  [gx2, ggamma2, ggy2]
end
forward(inputs)
click to toggle source
# File lib/chainer/functions/normalization/batch_normalization.rb, line 115
# Computes the three gradients of the normalization from (x, gamma, gy).
#
# @param inputs [x, gamma, gy]; all three are retained for `backward`
# @return [Array] [gx, ggamma, gbeta]; the first two are retained as outputs
def forward(inputs)
  retain_inputs([0, 1, 2])
  x, gamma, gy = inputs
  expand = @expander

  # Reciprocal of (x.size / gamma.size), stored in an array of gamma's class.
  inv_m = gamma.class.new.fill(1.0 / x.size.div(gamma.size))
  # The result is not used below; the call is kept so behaviour is unchanged.
  Chainer.get_array_module(x)

  gbeta = gy.sum(axis: @axis)
  normalized = x_hat(x, expand.(@mean), expand.(@inv_std))
  ggamma = (gy * normalized).sum(axis: @axis)

  # gx = (gamma * inv_std) * (gy - (x_hat * ggamma + gbeta) * inv_m)
  scale = expand.(gamma * @inv_std)
  correction = (normalized * expand.(ggamma) + expand.(gbeta)) * inv_m
  gx = scale * (gy - correction)

  retain_outputs([0, 1])
  [gx, ggamma, gbeta]
end