class Chainer::Functions::Loss::MeanSquaredError

Mean squared error (a.k.a. Euclidean loss) function.

Public Class Methods

mean_squared_error(x0, x1)

Mean squared error function.

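This function computes the mean squared error between two variables. The mean is taken over all elements of the inputs, not only over the minibatch axis: the forward pass divides the sum of squared differences by the total element count (diff.size). Note that the error is not scaled by 1/2.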
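Written out, with N the total number of elements in each input (diff.size in the source listings) and no factor of 1/2, the loss and its gradients read as follows. This is derived from the forward and backward code in this class, not quoted from the library's own documentation.

```latex
\mathrm{MSE}(x_0, x_1) = \frac{1}{N} \sum_{i=1}^{N} \left( x_{0,i} - x_{1,i} \right)^2,
\qquad
\frac{\partial \mathrm{MSE}}{\partial x_{0,i}} = \frac{2}{N} \left( x_{0,i} - x_{1,i} \right) = -\frac{\partial \mathrm{MSE}}{\partial x_{1,i}}
```

In backward, both gradients are additionally multiplied by the upstream gradient gy[0].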

@param [Chainer::Variable or Numo::NArray or Cumo::NArray] x0 Input variable.
@param [Chainer::Variable or Numo::NArray or Cumo::NArray] x1 Input variable.
@return [Chainer::Variable] A variable holding an array representing the mean squared error of two inputs.

# File lib/chainer/functions/loss/mean_squared_error.rb, line 15
def self.mean_squared_error(x0, x1)
  self.new.apply([x0, x1]).first
end
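The class method is a thin convenience wrapper: it builds a fresh function node, calls apply with the inputs wrapped in an Array, and unwraps the single result with .first, since apply returns an Array of outputs. The plain-Ruby sketch below imitates that pattern with a hypothetical ToySquaredDiff class (not part of red-chainer). A fresh node per call is presumably used because the real node keeps state for backward (forward calls retain_inputs).

```ruby
# Hypothetical stand-in for a function node (not a red-chainer class):
# like the real apply, it always returns an Array of outputs.
class ToySquaredDiff
  def apply(inputs)
    a, b = inputs
    [(a - b)**2]
  end

  # Same shape as mean_squared_error: build a new node, apply it,
  # and unwrap the single output.
  def self.squared_diff(a, b)
    new.apply([a, b]).first
  end
end

ToySquaredDiff.new.apply([5, 2])  # => [9]
ToySquaredDiff.squared_diff(5, 2) # => 9
```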

Public Instance Methods

backward(indexes, gy)
# File lib/chainer/functions/loss/mean_squared_error.rb, line 25
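def backward(indexes, gy)
  x0, x1 = get_retained_inputs
  diff = x0 - x1
  # gy[0] is the upstream gradient of the scalar loss; broadcast it to the input shape.
  gy0 = Chainer::Functions::Array::BroadcastTo.broadcast_to(gy[0], diff.shape)
  # The derivative of sum(diff**2) / N with respect to x0 is 2 * diff / N.
  gx0 = gy0 * diff * (2.0 / diff.size)

  # Return gradients only for the requested inputs; the gradient for x1 is -gx0.
  ret = []
  if indexes.include?(0)
    ret << gx0
  end
  if indexes.include?(1)
    ret << -gx0
  end
  ret
end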
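The gradient in backward can be sanity-checked numerically. The sketch below assumes flat Arrays of Floats instead of Numo arrays and a scalar upstream gradient gy (standing in for the broadcast gy0). It mirrors gx0 = gy * diff * (2 / N) and gx1 = -gx0, then compares gx0 against a central finite difference of the loss.

```ruby
# Loss over flat Float Arrays (stand-ins for Numo arrays): mean over all elements.
def mse(x0, x1)
  diff = x0.zip(x1).map { |a, b| a - b }
  diff.sum { |d| d * d } / diff.size.to_f
end

# Mirrors backward: gx0 = gy * diff * (2 / N); the gradient for x1 is -gx0.
def mse_grad(x0, x1, gy = 1.0)
  n = x0.size.to_f
  gx0 = x0.zip(x1).map { |a, b| gy * (a - b) * (2.0 / n) }
  [gx0, gx0.map(&:-@)]
end

# Central finite difference of mse with respect to x0[i].
def numeric_grad_x0(x0, x1, i, eps = 1e-6)
  plus = x0.dup
  plus[i] += eps
  minus = x0.dup
  minus[i] -= eps
  (mse(plus, x1) - mse(minus, x1)) / (2 * eps)
end

x0 = [1.0, 2.0, 3.0]
x1 = [0.5, 2.5, 1.0]
gx0, gx1 = mse_grad(x0, x1) # gx0 is approx. [0.3333, -0.3333, 1.3333]
x0.each_index do |i|
  puts format("%d analytic=%.6f numeric=%.6f", i, gx0[i], numeric_grad_x0(x0, x1, i))
end
```

Because the loss is quadratic, the central difference matches the analytic gradient up to floating-point rounding.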

forward(inputs)
# File lib/chainer/functions/loss/mean_squared_error.rb, line 19
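def forward(inputs)
  # Keep both inputs; backward recomputes diff from them.
  retain_inputs([0, 1])
  # Flatten the difference so that diff.dot(diff) is the sum of squared differences.
  diff = (inputs[0] - inputs[1]).flatten.dup
  # Divide by the total element count and wrap the scalar in an array of diff's class.
  [diff.class.cast(diff.dot(diff) / diff.size)]
end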