class Chainer::Functions::Pooling::MaxPooling2DWithIndexes

Public Class Methods

new(mpool2d)
# File lib/chainer/functions/pooling/max_pooling_2d.rb, line 81
# Builds the index-driven pooling function from an existing MaxPooling2D
# instance, copying its window geometry and the argmax indexes it recorded.
#
# @param mpool2d [#kh, #kw, #sy, #sx, #ph, #pw, #cover_all, #indexes]
#   the max-pooling function whose configuration and indexes are reused
def initialize(mpool2d)
  # Kernel size (height, width).
  @kh, @kw = mpool2d.kh, mpool2d.kw
  # Stride (y, x).
  @sy, @sx = mpool2d.sy, mpool2d.sx
  # Padding (height, width).
  @ph, @pw = mpool2d.ph, mpool2d.pw
  # Output-size mode forwarded to im2col, plus the per-window selections to replay.
  @cover_all, @indexes = mpool2d.cover_all, mpool2d.indexes
end

Public Instance Methods

forward(x)
# File lib/chainer/functions/pooling/max_pooling_2d.rb, line 92
# Replays a previously computed max pooling: for every pooling window of
# x[0], picks the element at the kernel offset stored in +@indexes+.
#
# @param x [Array] one-element array holding the input batch; presumably an
#   (n, c, h, w) Numo array — TODO confirm against callers
# @return [Array] one-element array with the gathered values reshaped to
#   (n, c, out_h, out_w)
def forward(x)
  col = Chainer::Utils::Conv.im2col(x[0], @kh, @kw, @sy, @sx, @ph, @pw, pval: -Float::INFINITY, cover_all: @cover_all)
  n, c, kh, kw, out_h, out_w = col.shape
  col = col.reshape(n, c, kh * kw, out_h, out_w)
  # One row per output position, one column per kernel offset.
  col = col.transpose(0, 1, 3, 4, 2).reshape(nil, kh * kw)

  # One selected kernel offset per output position, in the same row order.
  indexes = @indexes.flatten

  # TODO: vectorize as col[numpy.arange(len(indexes)), indexes]
  # Use a plain Integer row counter. The previous version iterated over a
  # seq array of x[0]'s dtype, which yields Floats (inexact past 2**24 for
  # SFloat) and so could address the wrong row on large inputs.
  gathered = col.class.zeros(indexes.size)
  indexes.size.times do |row|
    gathered[row] = col[row, indexes[row]]
  end

  [gathered.reshape(n, c, out_h, out_w)]
end