class Chainer::Functions::Math::Sum
Sum of array elements over a given axis.
Public Class Methods
new(axis: nil, keepdims: false)
click to toggle source
# File lib/chainer/functions/math/sum.rb, line 16
#
# Builds a Sum function node and normalizes the requested reduction axes.
#
# @param [nil, Integer, Array<Integer>] axis Axis or axes to reduce. An Integer
#   is wrapped into a one-element Array; nil is stored as nil.
# @param [Boolean] keepdims Stored as given.
# @raise [ArgumentError] if +axis+ is an Array that contains a repeated value.
# @raise [TypeError] if +axis+ is not nil, an Integer, or an Array of Integers.
def initialize(axis: nil, keepdims: false)
  @axis =
    case axis
    when nil
      nil
    when Integer
      [axis]
    when ::Array # ::Array, because a bare Array could resolve to Chainer::Functions::Array here
      raise TypeError, 'nil, Integer or Array of int are required' unless axis.all? { |e| e.is_a?(Integer) }
      # NOTE: values are compared literally, so e.g. [0, -1] is not flagged even
      # when both name the same axis; the input rank is unknown at this point.
      raise ArgumentError, "duplicate value in axis: #{axis}" if axis.uniq.size != axis.size
      axis
    else
      raise TypeError, 'nil, Integer or Array of int are required'
    end
  @keepdims = keepdims
end
sum(x, axis: nil, keepdims: false)
click to toggle source
Sum of array elements over a given axis.

@param [Chainer::Variable] x Elements to sum.
@param [nil, Integer, Array<Integer>] axis Axis or axes along which the sum is performed; `nil` sums over all elements.
@param [Boolean] keepdims If `true`, the reduced axes are kept as axes of length one.
@return [Chainer::Variable] Output variable.
# File lib/chainer/functions/math/sum.rb, line 12
#
# Sum of array elements over a given axis.
#
# @param [Chainer::Variable] x Elements to sum
# @param [nil, Integer, Array<Integer>] axis Axis or axes along which the sum is performed
# @param [Boolean] keepdims If +true+, the reduced axes are kept as axes of length one
# @return [Chainer::Variable] Output variable (the first element of +apply+'s result)
def self.sum(x, axis: nil, keepdims: false)
  # Explicitly Sum (not `new`) so calls through a subclass still build a Sum node.
  node = Sum.new(axis: axis, keepdims: keepdims)
  outputs = node.apply([x])
  outputs.first
end
Public Instance Methods
backward(indexes, grad_outputs)
click to toggle source
# File lib/chainer/functions/math/sum.rb, line 38
#
# Gradient of Sum: every input element receives the upstream gradient of the
# output element it was summed into, so +gy+ is broadcast back to the input shape.
#
# @param [Object] indexes Input indexes for which gradients are requested (unused here).
# @param [Array] grad_outputs Upstream gradients; only the first is used.
# @return [Array] One-element Array holding +gy+ broadcast to the input's shape.
def backward(indexes, grad_outputs)
  gy = grad_outputs.first
  in_shape = @inputs.first.shape
  ndim = in_shape.size

  # Without keepdims the summed axes were dropped from gy. Re-insert them as
  # length-1 dims so gy can be broadcast back to the input shape. This is not
  # needed for a 0-d input, for a full reduction (axis nil), or when keepdims
  # already kept them.
  reduced_axes_dropped = ndim != 0 && !@axis.nil? && !@keepdims
  if reduced_axes_dropped
    actual_axis = @axis.map { |axis| axis >= 0 ? axis : axis + ndim }
    # dup: insert below mutates, and gy.shape's Array must not be altered in place.
    shape = gy.shape.dup
    # Ascending order so each insert position refers to the final (restored) layout.
    actual_axis.sort.each { |axis| shape.insert(axis, 1) }
    gy = Chainer::Functions::Array::Reshape.reshape(gy, shape)
  end

  [Chainer::Functions::Array::BroadcastTo.broadcast_to(gy, in_shape)]
end
forward(inputs)
click to toggle source
# File lib/chainer/functions/math/sum.rb, line 31
#
# Sums the single input over the configured axes.
#
# @param [Array] inputs Input arrays; only the first is used.
# @return [Array] One-element Array with the reduced result, passed through
#   the input's class-level +cast+.
def forward(inputs)
  x = inputs.first
  summed = x.sum(axis: @axis, keepdims: @keepdims)
  # Re-wrap via the input's own class; presumably because a full reduction can
  # yield a plain scalar rather than an array — TODO confirm against the array library.
  result = x.class.cast(summed)
  [result]
end