class Chainer::Functions::Pooling::MaxPooling2DWithIndexes
Public Class Methods
new(mpool2d)
click to toggle source
# File lib/chainer/functions/pooling/max_pooling_2d.rb, line 81
#
# Builds the indexed variant from an existing MaxPooling2D-like object by
# copying its pooling configuration (kernel size, stride, padding, cover_all)
# and its recorded @indexes. The values are shared, not copied.
#
# @param mpool2d [#kh, #kw, #sy, #sx, #ph, #pw, #cover_all, #indexes]
#   source whose public readers are read in that order.
def initialize(mpool2d)
  %i[kh kw sy sx ph pw cover_all indexes].each do |attribute|
    instance_variable_set("@#{attribute}", mpool2d.public_send(attribute))
  end
end
Public Instance Methods
forward(x)
click to toggle source
# File lib/chainer/functions/pooling/max_pooling_2d.rb, line 92
#
# Pools x[0] by picking, for every output position, the window element whose
# offset is stored in @indexes. It does not recompute the maximum.
# NOTE(review): @indexes is presumably the per-window argmax recorded by the
# MaxPooling2D instance passed to the constructor; confirm against the caller.
#
# @param x [Array] inputs; only x[0] is read.
# @return [Array] one-element array holding the gathered values reshaped to
#   (n, c, out_h, out_w).
def forward(x)
  col = Chainer::Utils::Conv.im2col(
    x[0], @kh, @kw, @sy, @sx, @ph, @pw, pval: -Float::INFINITY, cover_all: @cover_all
  )
  n, c, kh, kw, out_h, out_w = col.shape
  # Rearrange to one row per output position (n, c, out_h, out_w order) and
  # one column per element of the kh x kw window.
  col = col.reshape(n, c, kh * kw, out_h, out_w)
  col = col.transpose(0, 1, 3, 4, 2).reshape(nil, kh * kw)
  indexes = @indexes.flatten # read-only below, so no defensive copy needed

  # Row numbers are plain Ruby Integers, not a seq array of x[0]'s class.
  # A floating dtype (e.g. single precision) cannot represent every integer
  # above 2**24, which would make large inputs gather from the wrong rows.
  # TODO: vectorize as a fancy-index gather (numpy: col[arange(len(indexes)), indexes]).
  new_col = col.class.zeros(indexes.size)
  indexes.size.times { |row| new_col[row] = col[row, indexes[row]] }

  [new_col.reshape(n, c, out_h, out_w)]
end