class Chainer::Functions::Loss::MeanSquaredError
Mean squared error (a.k.a. Euclidean loss) function.
Public Class Methods
mean_squared_error(x0, x1)
Mean squared error function.
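This function computes the mean squared error between two variables. The mean is taken over all elements of the inputs: the sum of the squared element-wise differences is divided by the total number of elements. Note that the error is not scaled by 1/2.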
@param [Chainer::Variable or Numo::NArray or Cumo::NArray] x0 Input variable.
@param [Chainer::Variable or Numo::NArray or Cumo::NArray] x1 Input variable.
@return [Chainer::Variable] A variable holding a scalar (0-dimensional) array with the mean squared error of the two inputs.
# File lib/chainer/functions/loss/mean_squared_error.rb, line 15
def self.mean_squared_error(x0, x1)
  self.new.apply([x0, x1]).first
end
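As a worked illustration (not part of the library documentation), the value returned for inputs $x_0$ and $x_1$ with $N$ elements each can be written as below. Here $N$ corresponds to the `diff.size` used in the `forward` source, and there is no factor $\frac{1}{2}$:

$$
\operatorname{MSE}(x_0, x_1) = \frac{1}{N} \sum_{i=1}^{N} \left(x_{0,i} - x_{1,i}\right)^2
$$

For example, with $x_0 = (1, 2, 3, 4)$ and $x_1 = (0, 2, 3, 6)$:

$$
\frac{(1-0)^2 + (2-2)^2 + (3-3)^2 + (4-6)^2}{4} = \frac{1 + 0 + 0 + 4}{4} = 1.25
$$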
Public Instance Methods
backward(indexes, gy)
# File lib/chainer/functions/loss/mean_squared_error.rb, line 25
def backward(indexes, gy)
  x0, x1 = get_retained_inputs
  diff = x0 - x1
  gy0 = Chainer::Functions::Array::BroadcastTo.broadcast_to(gy[0], diff.shape)
  gx0 = gy0 * diff * (2.0 / diff.size)
  ret = []
  if indexes.include?(0)
    ret << gx0
  end
  if indexes.include?(1)
    ret << -gx0
  end
  ret
end
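Because the loss is not scaled by 1/2, the gradient with respect to each element of `x0` is `gy * 2 * (x0[i] - x1[i]) / N`, and the gradient with respect to `x1` is its negation; this is what the `2.0 / diff.size` factor and `-gx0` in `backward` express. The sketch below is a hypothetical pure-Ruby illustration on flat Arrays of Floats (the helper names `mse_backward`, `mse`, and `numerical_grad_x0` are invented here, and no Chainer or Numo objects are used). It also checks the analytic gradient against a central finite-difference estimate.

```ruby
# Hypothetical pure-Ruby sketch of the gradient that MeanSquaredError#backward
# computes, using flat Arrays of Floats instead of Chainer::Variable.
# Assumes x0 and x1 have the same number of elements.
def mse_backward(x0, x1, gy, indexes = [0, 1])
  raise ArgumentError, "size mismatch" unless x0.size == x1.size

  n = x0.size
  # gy is the upstream gradient of the scalar loss; applying it to every
  # element plays the role of broadcast_to(gy[0], diff.shape).
  gx0 = x0.zip(x1).map { |a, b| gy * (a - b) * (2.0 / n) }

  # Return only the requested gradients, mirroring the indexes check.
  ret = []
  ret << gx0 if indexes.include?(0)
  ret << gx0.map { |g| -g } if indexes.include?(1)
  ret
end

# Forward value, used here only for the finite-difference check.
def mse(x0, x1)
  total = x0.zip(x1).sum { |a, b| (a - b)**2 }
  total / x0.size.to_f
end

# Central finite-difference estimate of d(mse)/d(x0[i]).
def numerical_grad_x0(x0, x1, eps = 1e-6)
  x0.each_index.map do |i|
    plus = x0.dup
    plus[i] += eps
    minus = x0.dup
    minus[i] -= eps
    (mse(plus, x1) - mse(minus, x1)) / (2 * eps)
  end
end

x0 = [3.0, 0.0, 1.0, 5.0]
x1 = [1.0, 2.0, 0.0, 4.0]
gx0, gx1 = mse_backward(x0, x1, 1.0)
p gx0                              # prints [1.0, -1.0, 0.5, 0.5]
p gx1                              # prints [-1.0, 1.0, -0.5, -0.5]
p mse_backward(x0, x1, 1.0, [1])   # prints [[-1.0, 1.0, -0.5, -0.5]]
p numerical_grad_x0(x0, x1)        # approximately equal to gx0
```

Since the loss is quadratic in each input, the central difference agrees with the analytic gradient up to floating-point rounding.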
forward(inputs)
# File lib/chainer/functions/loss/mean_squared_error.rb, line 19
def forward(inputs)
  retain_inputs([0, 1])
  diff = (inputs[0] - inputs[1]).flatten.dup
  [diff.class.cast(diff.dot(diff) / diff.size)]
end
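The steps in `forward` (flatten the difference, take its dot product with itself, divide by the element count) can be mirrored on plain nested Arrays. The following is a hypothetical pure-Ruby sketch, not library code: `mse_forward` is an invented name, it assumes both inputs have the same shape (the library relies on NArray subtraction instead), and it has no counterpart to `retain_inputs`, which only matters for `backward`.

```ruby
# Hypothetical pure-Ruby sketch of MeanSquaredError#forward using nested
# Arrays of Floats instead of Numo::NArray. Assumes both inputs have the
# same shape.
def mse_forward(a, b)
  xs = a.flatten
  ys = b.flatten
  raise ArgumentError, "size mismatch" unless xs.size == ys.size

  # Element-wise difference of the flattened inputs.
  diff = xs.zip(ys).map { |x, y| x - y }
  # diff.dot(diff) is the sum of squared differences.
  sum_sq = diff.sum { |d| d * d }
  # Divide by the total number of elements (diff.size), with no 1/2 factor.
  sum_sq / diff.size.to_f
end

x0 = [[3.0, 0.0], [1.0, 5.0]]
x1 = [[1.0, 2.0], [0.0, 4.0]]
puts mse_forward(x0, x1) # prints 2.5
```

Here the differences are `2, -2, 1, 1`, their squares sum to 10, and dividing by the 4 elements gives 2.5.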