class Chainer::Functions::Normalization::BatchNormalization

Attributes

running_mean[R]
running_var[R]

Public Class Methods

batch_normalization(x, gamma, beta, eps: 2e-5, running_mean: nil, running_var: nil, decay: 0.9)
# File lib/chainer/functions/normalization/batch_normalization.rb, line 30
# Functional interface for batch normalization.
#
# Builds a one-shot BatchNormalization function node configured with the
# given epsilon, optional running statistics, and decay rate, applies it
# to the inputs, and returns the single normalized output variable.
def self.batch_normalization(x, gamma, beta, eps: 2e-5, running_mean: nil, running_var: nil, decay: 0.9)
  fnode = BatchNormalization.new(eps: eps, mean: running_mean, var: running_var, decay: decay)
  fnode.apply([x, gamma, beta]).first
end
new(eps: 2e-5, mean: nil, var: nil, decay: 0.9)
# File lib/chainer/functions/normalization/batch_normalization.rb, line 34
# Set up a batch-normalization function node.
#
# eps::   small constant added to the batch variance for numerical stability.
# mean::  pre-existing running mean to update in place (nil means it will be
#         allocated lazily on the first forward pass).
# var::   pre-existing running variance, same convention as +mean+.
# decay:: exponential-decay rate applied when updating running statistics.
def initialize(eps: 2e-5, mean: nil, var: nil, decay: 0.9)
  # Per-batch statistics: populated by #forward, consumed by #backward.
  @mean = nil
  @inv_std = nil

  # Running (moving-average) statistics; may be shared with the caller.
  @running_mean = mean
  @running_var = var

  # Hyper-parameters.
  @eps = eps
  @decay = decay
end

Public Instance Methods

backward(indexes, grad_outputs)
# File lib/chainer/functions/normalization/batch_normalization.rb, line 88
# Compute the gradients with respect to x, gamma, and beta.
#
# Delegates to BatchNormalizationGrad, reusing the per-batch mean and
# inverse standard deviation cached by #forward, along with the expander
# lambda and reduction axes recorded there.
def backward(indexes, grad_outputs)
  x, gamma = get_retained_inputs
  gy = grad_outputs.first

  grad_fn = BatchNormalizationGrad.new(@eps, @expander, @axis, @mean, @inv_std)
  grad_fn.(x, gamma, gy)
end
forward(inputs)
# File lib/chainer/functions/normalization/batch_normalization.rb, line 44
# Forward pass of batch normalization (training mode).
#
# inputs:: three-element array [x, gamma, beta].
# Returns a one-element array [y] containing the normalized output.
#
# Side effects: caches @mean/@inv_std for #backward, records @expander and
# @axis, and updates @running_mean/@running_var via exponential decay.
def forward(inputs)
  # x (index 0) and gamma (index 1) are reused by #backward.
  retain_inputs([0, 1])
  x, gamma, beta = inputs
  xp = Chainer.get_array_module(x)

  # First call: lazily allocate running statistics shaped like gamma.
  # NOTE(review): NArray[*gamma] splats gamma — presumably gamma is 1-D
  # (one scale per channel); confirm against callers.
  if @running_mean.nil?
    @running_mean = xp::NArray[*gamma].new_zeros
    @running_var = xp::NArray[*gamma].new_zeros
  end

  # expander inserts singleton dimensions to gamma and beta so that they
  # can be broadcasted with x.
  head_ndim = gamma.ndim + 1
  # TODO: expander = (None, Ellipsis) + (None,) * (x.ndim - head_ndim)
  suffix = [1] * (x.ndim - head_ndim)
  expander = -> (arr) do
    shape = [1] + arr.shape + suffix
    arr.reshape(*shape)
  end
  @expander = expander
  # Reduce over the batch axis (0) and every trailing axis past the
  # dimensions covered by gamma.
  @axis = [0] + (head_ndim...(x.ndim)).to_a

  gamma = expander.(gamma)
  beta = expander.(beta)
  # Per-batch mean, cached for the backward pass.
  @mean = x.mean(axis: @axis)

  # TODO: Numo::Array can not be specified standard deviation
  var = ((x - x.mean(axis: @axis, keepdims: true)) ** 2).mean(axis: @axis)

  # Numerical stabilizer; from here on, var holds variance + eps.
  var += @eps
  @inv_std = var ** (-0.5)

  y = apply_bn_fwd(xp, x, expander.(@mean), expander.(@inv_std), gamma, beta)
  # Update running statistics
  # m = elements averaged per statistic; adjust applies Bessel's correction
  # to turn the biased batch variance into an unbiased estimate.
  m = x.size.div(gamma.size)
  adjust = m / [m - 1.0, 1.0].max
  @running_mean *= @decay
  @running_mean += (1 - @decay) * @mean
  @running_var *= @decay
  # NOTE(review): @eps was already added to var above, so the running
  # variance accumulates (variance + eps) — confirm this is intended.
  @running_var += (1 - @decay) * adjust * var

  [y]
end