class Chainer::Functions::Pooling::MaxPooling2DGrad
Public Class Methods
new(mpool2d)
click to toggle source
# File lib/chainer/functions/pooling/max_pooling_2d.rb, line 42
#
# Captures the pooling hyper-parameters and the saved forward-pass state of
# +mpool2d+ so this gradient function can run without recomputing them.
#
# @param mpool2d [#kh, #kw, #sy, #sx, #ph, #pw, #cover_all, #indexes, #in_shape, #in_dtype]
#   presumably the MaxPooling2D function instance whose forward pass has
#   already run (so +indexes+, +in_shape+ and +in_dtype+ are populated) —
#   TODO confirm against the caller.
def initialize(mpool2d)
  # Kernel size, stride and padding, copied verbatim from the pooling function.
  @kh = mpool2d.kh
  @kw = mpool2d.kw
  @sy = mpool2d.sy
  @sx = mpool2d.sx
  @ph = mpool2d.ph
  @pw = mpool2d.pw
  @cover_all = mpool2d.cover_all

  # State recorded by the pooling function; #forward reads these to scatter
  # gradients and to size/allocate the input gradient.
  @indexes = mpool2d.indexes
  @in_shape = mpool2d.in_shape
  @in_dtype = mpool2d.in_dtype

  # Kept whole because #backward builds a MaxPooling2DWithIndexes from it.
  @mpool2d = mpool2d
end
Public Instance Methods
backward(indexes, ggx)
click to toggle source
# File lib/chainer/functions/pooling/max_pooling_2d.rb, line 75
#
# Backward of this gradient function: feeds +ggx+ through a fresh
# MaxPooling2DWithIndexes built from the original pooling function.
#
# @param _indexes [Object] unused here; presumably required by the caller's
#   backward(indexes, grad_outputs) protocol — TODO confirm.
# @param ggx presumably the gradient(s) w.r.t. this function's output gx.
# @return whatever MaxPooling2DWithIndexes#apply returns for +ggx+.
def backward(_indexes, ggx)
  MaxPooling2DWithIndexes.new(@mpool2d).apply(ggx)
end
forward(gy)
click to toggle source
# File lib/chainer/functions/pooling/max_pooling_2d.rb, line 56
#
# Computes the gradient w.r.t. the pooling input. Each element of gy[0] is
# written into one slot of its kh*kw window (slot chosen by @indexes); the
# windows are then folded back into an (h, w) image with col2im.
#
# @param gy [Array] output gradients; gy[0] is destructured as
#   (n, c, out_h, out_w).
# @return [Array] one-element array [gx], the col2im result for the spatial
#   size (h, w) taken from @in_shape[2..-1].
def forward(gy)
  grad_out = gy[0]
  n, c, out_h, out_w = grad_out.shape
  h, w = @in_shape[2..-1]
  kh, kw = @kh, @kw
  window = kh * kw

  # Flat column buffer: one kh*kw window per output element, zero everywhere
  # except the slots that receive a gradient.
  gcol = @in_dtype.zeros(n * c * out_h * out_w * window)

  # @indexes holds a position inside each window (presumably the argmax saved
  # by the forward max-pool — TODO confirm). Shifting entry i by i * window
  # turns it into an offset into the flat buffer.
  indexes = @indexes.flatten
  indexes += indexes.class.new(indexes.size).seq(0, window)

  # gcol is freshly allocated, so it cannot alias grad_out; no dup is needed.
  gcol[indexes] = grad_out.flatten

  # (n, c, out_h, out_w, kh, kw) -> (n, c, kh, kw, out_h, out_w), the layout
  # handed to col2im.
  gcol = gcol.reshape(n, c, out_h, out_w, kh, kw).swapaxes(2, 4).swapaxes(3, 5)

  gx = Chainer::Utils::Conv.col2im(gcol, @sy, @sx, @ph, @pw, h, w)
  [gx]
end