class Chainer::Functions::Math::Sum

Sum of array elements over a given axis.

Public Class Methods

new(axis: nil, keepdims: false) click to toggle source
# File lib/chainer/functions/math/sum.rb, line 16
# Builds a Sum function node, validating and normalising the reduction axes.
#
# @param axis [nil, Integer, Array<Integer>] axes to reduce over. `nil` is
#   stored as-is, a single Integer is wrapped in a one-element array, and an
#   array of Integers is stored unchanged (same object).
# @param keepdims stored verbatim in `@keepdims` and later handed to the
#   underlying `sum` call in #forward.
# @raise [ArgumentError] if an axis array contains the same value twice.
#   Only literal duplicates are detected; a negative and a non-negative
#   index are never treated as equal here.
# @raise [TypeError] if `axis` is anything other than nil, an Integer or an
#   array consisting solely of Integers.
def initialize(axis: nil, keepdims: false)
  integer_list = axis.is_a?(::Array) && axis.all? { |elem| elem.is_a?(Integer) }

  unless axis.nil? || axis.is_a?(Integer) || integer_list
    raise TypeError, 'nil, Integer or Array of int are required'
  end

  if integer_list && axis.uniq.size != axis.size
    raise ArgumentError, "duplicate value in axis: #{axis}"
  end

  # A bare Integer becomes a one-element list; nil and arrays pass through.
  @axis = axis.is_a?(Integer) ? [axis] : axis
  @keepdims = keepdims
end
sum(x, axis: nil, keepdims: false) click to toggle source

Sum of array elements over a given axis

@param [Chainer::Variable] x Elements to sum @param [nil, Integer, Array<Integer>] axis Axis or axes along which the sum is performed @param [Boolean] keepdims If `true`, the summed axes are retained as axes of length one @return [Chainer::Variable] Output variable

# File lib/chainer/functions/math/sum.rb, line 12
# Convenience entry point: builds a Sum function, applies it to `x` and
# returns the first element of what `apply` hands back.
#
# @param x the single input passed (wrapped in an array) to `apply`
# @param axis forwarded unchanged to `Sum.new`
# @param keepdims forwarded unchanged to `Sum.new`
# @return the first output produced by `apply`
def self.sum(x, axis: nil, keepdims: false)
  function = Sum.new(axis: axis, keepdims: keepdims)
  outputs = function.apply([x])
  outputs.first
end

Public Instance Methods

backward(indexes, grad_outputs) click to toggle source
# File lib/chainer/functions/math/sum.rb, line 38
# Propagates the upstream gradient back to the input by broadcasting it to
# the input's shape.
#
# When specific axes were reduced and not kept (`@axis` set, `@keepdims`
# falsy, input not 0-dimensional), the gradient is first reshaped so that
# every reduced axis reappears with length one; only then is it broadcast.
#
# @param indexes not used by this implementation
# @param grad_outputs [Array] upstream gradients; only the first is used
# @return [Array] one-element array holding the broadcast gradient
def backward(indexes, grad_outputs)
  gy = grad_outputs.first
  input_shape = @inputs.first.shape
  ndim = input_shape.size

  if ndim != 0 && !@axis.nil? && !@keepdims
    # Map negative axes onto their non-negative equivalents.
    normalized = @axis.map { |ax| ax >= 0 ? ax : ax + ndim }
    # Insert in ascending order so earlier insertions do not shift the
    # positions of later ones. Work on a copy so the array handed back by
    # gy.shape is never modified in place.
    expanded = normalized.sort.each_with_object(gy.shape.dup) { |ax, dims| dims.insert(ax, 1) }
    gy = Chainer::Functions::Array::Reshape.reshape(gy, expanded)
  end

  [Chainer::Functions::Array::BroadcastTo.broadcast_to(gy, input_shape)]
end
forward(inputs) click to toggle source
# File lib/chainer/functions/math/sum.rb, line 31
# Computes the forward pass: sums the first input over `@axis`, honouring
# `@keepdims`, and converts the result with the input class's `cast`.
#
# @param inputs [Array] input arrays; only the first is used
# @return [Array] one-element array holding the cast sum
def forward(inputs)
  input = inputs.first
  reduced = input.sum(axis: @axis, keepdims: @keepdims)
  # NOTE(review): presumably the cast guarantees an array of the input's
  # class even when the reduction yields a bare scalar — confirm.
  [input.class.cast(reduced)]
end