class Chainer::Functions::Array::BroadcastTo

Function that broadcasts an array to a new shape.

Public Class Methods

broadcast_to(x, shape) click to toggle source
# File lib/chainer/functions/array/broadcast_to.rb, line 10
# Broadcasts +x+ to +shape+.
#
# When +x.shape+ already equals +shape+ no function node is created; +x+ is
# simply passed through Chainer::Variable.as_variable. Otherwise a new
# instance of this function is applied to +[x]+ and its first output is
# returned.
#
# @param x object responding to +shape+ (the value to broadcast)
# @param shape target shape, compared against +x.shape+ with +==+
# @return the result of as_variable, or the first output of +apply+
def self.broadcast_to(x, shape)
  if x.shape == shape
    Chainer::Variable.as_variable(x)
  else
    new(shape).apply([x]).first
  end
end
new(shape) click to toggle source
# File lib/chainer/functions/array/broadcast_to.rb, line 6
# Stores the target shape for later use by the instance.
#
# @param shape the shape to broadcast to; kept as-is in +@shape+
def initialize(shape)
  @shape = shape
end

Public Instance Methods

backward(indexes, grad_outputs) click to toggle source
# File lib/chainer/functions/array/broadcast_to.rb, line 20
# Reduces the output gradient back to the shape of the first input.
#
# The gradient is summed (with +keepdims: true+) over two groups of axes:
# the leading axes by which the gradient's rank exceeds the input's rank,
# and every axis whose length in the input shape is 1 (shifted by the number
# of leading axes). When there are leading axes they are then squeezed out.
#
# @param indexes not used by this method
# @param grad_outputs collection whose first element is the output gradient
# @return [Array] one-element array holding the gradient for the input
def backward(indexes, grad_outputs)
  grad = grad_outputs.first
  input_shape = @inputs.first.shape
  extra_dims = grad.ndim - input_shape.size
  leading_axes = (0...extra_dims).to_a
  # Positions of length-1 input axes, expressed in the gradient's axis numbering.
  unit_axes = input_shape.each_with_index.select { |dim, _| dim == 1 }.map { |_, idx| idx + extra_dims }

  summed = Chainer::Functions::Math::Sum.sum(grad, axis: leading_axes + unit_axes, keepdims: true)
  if extra_dims > 0
    [Chainer::Functions::Array::Squeeze.squeeze(summed, axis: leading_axes)]
  else
    [summed]
  end
end
forward(inputs) click to toggle source
# File lib/chainer/functions/array/broadcast_to.rb, line 15
# Broadcasts the first input to the shape stored in +@shape+.
#
# Delegates to Chainer::Utils::Array.broadcast_to; any inputs beyond the
# first are ignored.
#
# @param inputs collection whose first element is the array to broadcast
# @return [Array] one-element array containing the broadcast result
def forward(inputs)
  [Chainer::Utils::Array.broadcast_to(inputs.first, @shape)]
end

Private Instance Methods

backward_one(shape, dtype, g) click to toggle source
# File lib/chainer/functions/array/broadcast_to.rb, line 37
# Reduces a gradient +g+ so that it lines up with +shape+.
#
# * With no gradient (+g+ is nil/false), returns +dtype.zeros(shape)+.
# * When the rank of +g+ differs from +shape.size+, the leading axes
#   +0...(g.ndim - shape.size)+ are summed away first.
# * Every axis whose length in +shape+ is 1 is then summed with
#   +keepdims: true+; if there is no such axis the gradient is returned as is.
#
# @param shape target shape (responds to +size+ and +each_with_index+)
# @param dtype object responding to +zeros(shape)+
# @param g gradient responding to +ndim+ and +sum+, or nil
# @return the reduced gradient, or zeros when +g+ is absent
def backward_one(shape, dtype, g)
  return dtype.zeros(shape) unless g

  extra_dims = g.ndim - shape.size
  g = g.sum(axis: 0...extra_dims) if extra_dims != 0

  unit_axes = shape.each_with_index.each_with_object([]) do |(dim, idx), found|
    found << idx if dim == 1
  end
  return g if unit_axes.empty?

  g.sum(keepdims: true, axis: unit_axes)
end