class Chainer::Functions::Array::SelectItem

Select elements stored in given indices.

Public Class Methods

select_item(x, t)

Select elements stored in given indices.

This function returns $t.choose(x.T)$; that is,
$y[i] = x[i, t[i]]$ for every row index $i$.

@param [Chainer::Variable] x Variable storing arrays.
@param [Chainer::Variable] t Variable storing index numbers.
@return [Chainer::Variable] Variable holding, for each row $i$, the element $x[i, t[i]]$.
# File lib/chainer/functions/array/select_item.rb, line 13
# Builds a fresh SelectItem node, applies it to [x, t] and unwraps its
# single output.
def self.select_item(x, t)
  outputs = SelectItem.new.apply([x, t])
  outputs.first
end

Public Instance Methods

backward(indexes, gy)
# File lib/chainer/functions/array/select_item.rb, line 33
# Gradients of the selection with respect to the requested inputs.
#
# @param [Array<Integer>] indexes Input positions whose gradients are
#   requested: 0 is +x+, 1 is +t+.
# @param gy Upstream gradient(s), passed unchanged to +Assign#apply+.
# @return [Array] The gradient for +x+ (present when 0 is requested),
#   followed by +nil+ for +t+ (present when 1 is requested). The index
#   input never receives a gradient.
def backward(indexes, gy)
  indices = get_retained_inputs.first
  want_x = indexes.include?(0)
  want_t = indexes.include?(1)

  grads = []
  # Scatter gy back into an array shaped like the original x.
  grads.push(Assign.new(@in_shape, @in_dtype, indices).apply(gy).first) if want_x
  grads.push(nil) if want_t
  grads
end
forward(inputs)
# File lib/chainer/functions/array/select_item.rb, line 17
# Gathers one element per row of +x+, so that the result +y+ satisfies
# +y[i] == x[i, t[i]]+ for every +i+.
#
# @param [Array] inputs Two-element array +[x, t]+. +x+ must be 2-D and
#   +t+ must be 1-D, with one column index per row of +x+.
# @return [Array] One-element array holding the selected values. The array
#   is built with +x.class.zeros+, so it has the same array class as +x+.
# @raise [ArgumentError] if +x+ is not 2-D, +t+ is not 1-D, or the number
#   of rows of +x+ differs from the length of +t+.
def forward(inputs)
  # Retain only t (input index 1); backward reads it via get_retained_inputs.
  retain_inputs([1])
  x, t = inputs

  # Without this check, a t shorter than x's first dimension silently
  # yields a truncated result. These are the same shape constraints
  # Chainer's select_item enforces.
  unless x.shape.size == 2 && t.shape.size == 1 && x.shape[0] == t.shape[0]
    raise ArgumentError,
          "select_item expects 2-D x and 1-D t with matching leading dimension " \
          "(got x.shape=#{x.shape.inspect}, t.shape=#{t.shape.inspect})"
  end

  # Consumed by backward when it builds a gradient shaped like x.
  @in_shape = x.shape
  @in_dtype = x.class

  # TODO: vectorize (Chainer uses x[range(t.size), t]); element-wise loop for now.
  y = x.class.zeros(t.size)
  t.size.times { |i| y[i] = x[i, t[i]] }
  [y]
end