class Chainer::Functions::Pooling::MaxPooling2DGrad

Public Class Methods

new(mpool2d) click to toggle source
# File lib/chainer/functions/pooling/max_pooling_2d.rb, line 42
# Captures the configuration and recorded state of the max-pooling function
# whose gradient this object computes.
#
# @param mpool2d [Object] the forward pooling function; it must respond to
#   kh, kw, sy, sx, ph, pw, cover_all, indexes, in_shape and in_dtype.
def initialize(mpool2d)
  # Each attribute is mirrored into an ivar of the same name (@kh, @kw, ...),
  # assigned in the listed order.
  %i[kh kw sy sx ph pw cover_all indexes in_shape in_dtype].each do |attr|
    instance_variable_set(:"@#{attr}", mpool2d.public_send(attr))
  end
  # Keep a reference to the function object itself as well.
  @mpool2d = mpool2d
end

Public Instance Methods

backward(indexes, ggx) click to toggle source
# File lib/chainer/functions/pooling/max_pooling_2d.rb, line 75
# Backward pass of this gradient function (i.e. the double-backward of max
# pooling): routes ggx through a MaxPooling2DWithIndexes built from the
# original pooling function.
#
# @param indexes [Object] not read here; presumably present to satisfy the
#   caller's backward(indexes, grad_outputs) protocol — TODO confirm.
# @param ggx [Object] gradients w.r.t. this function's outputs, passed to apply.
# @return [Object] whatever MaxPooling2DWithIndexes#apply returns.
def backward(indexes, ggx)
  pooling_with_indexes = MaxPooling2DWithIndexes.new(@mpool2d)
  pooling_with_indexes.apply(ggx)
end
forward(gy) click to toggle source
# File lib/chainer/functions/pooling/max_pooling_2d.rb, line 56
# Computes the gradient of max pooling w.r.t. its input. Each incoming gradient
# value is scattered into the slot of its pooling window selected by @indexes.
# The resulting columns are then folded back to image layout with col2im.
#
# @param gy [Array] one-element inputs list; gy[0] is presumably the gradient
#   w.r.t. the pooled output, and its shape is read as (n, c, out_h, out_w).
# @return [Array] one-element list holding the input gradient produced by col2im.
def forward(gy)
  n, c, out_h, out_w = gy[0].shape
  h, w = @in_shape[2..-1]
  kh, kw = @kh, @kw
  window = kh * kw

  gcol = @in_dtype.zeros(n * c * out_h * out_w * window)

  # @indexes gives, per output position, an offset inside that position's
  # kh*kw window. Adding 0, window, 2*window, ... turns these offsets into
  # flat positions in gcol.
  indexes = @indexes.flatten
  indexes += indexes.class.new(indexes.size).seq(0, window)

  # Indexed assignment copies the values, so no extra dup is needed here.
  gcol[indexes] = gy[0].flatten
  gcol = gcol.reshape(n, c, out_h, out_w, kh, kw)
  # Reorder axes to (n, c, kh, kw, out_h, out_w), presumably the column layout
  # that col2im expects.
  gcol = gcol.swapaxes(2, 4)
  gcol = gcol.swapaxes(3, 5)

  gx = Chainer::Utils::Conv.col2im(gcol, @sy, @sx, @ph, @pw, h, w)
  [gx]
end