class Chainer::Functions::Array::SelectItem
Select elements stored in given indices.
Public Class Methods
select_item(x, t)
Select elements stored in given indices.
This function returns $t.choose(x.T)$, that is, $y[i] == x[i, t[i]]$ for all $i$.
@param [Chainer::Variable] x Variable storing a two-dimensional array (one row per sample).
@param [Chainer::Variable] t Variable storing integer index numbers, one per row of x.
@return [Chainer::Variable] Variable whose $i$-th element is the $t[i]$-th element of row $i$ of $x$.
# File lib/chainer/functions/array/select_item.rb, line 13
def self.select_item(x, t)
  SelectItem.new.apply([x, t]).first
end
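The indexing rule $y[i] = x[i, t[i]]$ and its equivalence to $t.choose(x.T)$ can be checked without red-chainer installed. The following is a minimal plain-Ruby sketch on nested arrays, assuming a 2-D `x` (one row per sample) and a 1-D list of integer indices `t`. The helpers `select_item_plain` and `choose_transposed` are hypothetical and only illustrate the semantics; they are not part of the Chainer API.

```ruby
# Hypothetical helper: plain-Ruby version of the rule y[i] = x[i, t[i]].
# x is an Array of rows (2-D), t is an Array of Integer column indices.
def select_item_plain(x, t)
  raise ArgumentError, "x and t must have the same length" unless x.size == t.size

  t.each_with_index.map { |ti, i| x[i][ti] }
end

# Hypothetical helper: the same result written as t.choose(x.T),
# i.e. for each i, take row t[i] of the transpose at column i.
def choose_transposed(x, t)
  xt = x.transpose
  t.each_with_index.map { |ti, i| xt[ti][i] }
end

x = [[0.1, 0.7, 0.2],
     [0.5, 0.3, 0.2],
     [0.0, 0.4, 0.6]]
t = [1, 0, 2]

p select_item_plain(x, t)                             # => [0.7, 0.5, 0.6]
p choose_transposed(x, t) == select_item_plain(x, t)  # => true
```

A typical use is picking the score of the correct class for each sample, where `x` holds per-class scores and `t` holds the label indices.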
Public Instance Methods
backward(indexes, gy)
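# File lib/chainer/functions/array/select_item.rb, line 33
def backward(indexes, gy)
  # Only t (input 1) was retained in forward.
  t = get_retained_inputs.first
  ret = []
  if indexes.include?(0)
    # Gradient w.r.t. x, built by the Assign function node from gy and t.
    ggx = Assign.new(@in_shape, @in_dtype, t).apply(gy).first
    ret << ggx
  end
  if indexes.include?(1)
    # t holds integer indices and is not differentiable.
    ret << nil
  end
  ret
end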
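Because each output element copies exactly one input element, the gradient with respect to `x` is zero everywhere except at the positions $(i, t[i])$, where it equals `gy[i]`; `t` holds integer indices, so its gradient is `nil`. The sketch below illustrates that scatter on plain nested arrays, assuming a 2-D input shape `[rows, cols]`. The helper `scatter_grad` is hypothetical: it shows what the gradient should be mathematically, which is presumably what the `Assign` node produces, not how `Assign` is actually implemented.

```ruby
# Hypothetical helper: gradient of y[i] = x[i, t[i]] with respect to x.
# Builds a zero array of the input shape and writes gy[i] at (i, t[i]).
# Each i touches a different row, so plain assignment cannot collide.
def scatter_grad(in_shape, t, gy)
  rows, cols = in_shape
  gx = Array.new(rows) { Array.new(cols, 0.0) }
  t.each_with_index do |ti, i|
    gx[i][ti] = gy[i]
  end
  gx
end

gx = scatter_grad([3, 3], [1, 0, 2], [1.0, 2.0, 3.0])
p gx  # => [[0.0, 1.0, 0.0], [2.0, 0.0, 0.0], [0.0, 0.0, 3.0]]
```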
forward(inputs)
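# File lib/chainer/functions/array/select_item.rb, line 17
def forward(inputs)
  # Keep only t (input 1) for backward; x itself is not needed there.
  retain_inputs([1])
  x, t = inputs
  # Remember the input shape and array class (e.g. Numo::DFloat) for backward.
  @in_shape = x.shape
  @in_dtype = x.class
  # TODO: x[six.moves.range(t.size), t]
  # Gather x[i, t[i]] row by row into a new 1-D array.
  new_x = x.class.zeros(t.size)
  t.size.times.each do |i|
    new_x[i] = x[i, t[i]]
  end
  x = new_x
  [x]
end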
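The `TODO` comment refers to the NumPy-style fancy indexing `x[range(t.size), t]` used by Python Chainer, which this port replaces with an explicit per-row loop. One way to express the same gather as a single lookup is to convert each pair `(i, t[i])` into a row-major flat offset `i * cols + t[i]` and index the flattened array once. The sketch below shows the idea on plain Ruby arrays; `gather_flat` is a hypothetical helper, and whether this approach is faster than the loop when applied to Numo arrays is an untested assumption.

```ruby
# Hypothetical helper: gather x[i, t[i]] via row-major flat offsets
# instead of two-dimensional indexing inside a loop.
def gather_flat(x, t)
  cols = x.first.size
  flat = x.flatten                                           # row-major copy of x
  offsets = t.each_with_index.map { |ti, i| i * cols + ti }  # (i, t[i]) -> flat index
  flat.values_at(*offsets)
end

x = [[0.1, 0.7, 0.2],
     [0.5, 0.3, 0.2],
     [0.0, 0.4, 0.6]]
p gather_flat(x, [1, 0, 2])  # => [0.7, 0.5, 0.6]
```

For `t = [1, 0, 2]` and three columns the offsets are `[1, 3, 8]`, which select the same elements as the loop in `forward`.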