class Chainer::Functions::Connection::Convolution2DGradW
Public Class Methods
new(conv2d)
click to toggle source
# File lib/chainer/functions/connection/convolution_2d_grad_w.rb, line 5
#
# Snapshots the hyperparameters of the forward 2-D convolution node that
# this weight-gradient node needs in +forward+ and +backward+.
#
# @param conv2d the forward convolution node. Its second input
#   (+inputs[1]+) is treated as the weight variable. It must respond to
#   +sy+, +sx+, +ph+, +pw+ and +cover_all+.
def initialize(conv2d)
  weight = conv2d.inputs[1]
  # Elements 2 and 3 of the weight shape are taken as kernel height/width.
  @kh, @kw = weight.shape[2..-1]
  @sy, @sx = conv2d.sy, conv2d.sx
  @ph, @pw = conv2d.ph, conv2d.pw
  @cover_all = conv2d.cover_all
  # Remembered so forward can cast the computed gradient back to the weight dtype.
  @w_dtype = weight.dtype
end
Public Instance Methods
backward(indexes, grad_outputs)
click to toggle source
# File lib/chainer/functions/connection/convolution_2d_grad_w.rb, line 26
#
# Computes gradients with respect to this node's inputs (+x+ and +gy+),
# given the gradient flowing into its output +gw+.
#
# @param indexes [Array<Integer>] which input gradients are requested:
#   0 selects +x+ and 1 selects +gy+.
# @param grad_outputs [Array] gradients w.r.t. this node's outputs. Only the
#   first element (the gradient of +gw+) is used.
# @return [Array] the requested gradients, in the order +x+ then +gy+,
#   containing only the entries selected by +indexes+.
def backward(indexes, grad_outputs)
  x, gy = get_retained_inputs
  ggw = grad_outputs.first

  ret = []
  if indexes.include?(0)
    # The x-gradient is a deconvolution of gy with ggw, sized back to x's
    # spatial extent (elements 2 and 3 of x.shape).
    xh, xw = x.shape[2..-1]
    # Fully qualified (like Convolution2DFunction below) so resolution does
    # not depend on whether the enclosing modules are opened in nested form.
    gx = Chainer::Functions::Connection::Deconvolution2DFunction.deconvolution_2d(
      gy, ggw, stride: [@sy, @sx], pad: [@ph, @pw], outsize: [xh, xw]
    )
    ret << gx
  end
  if indexes.include?(1)
    ggy = Chainer::Functions::Connection::Convolution2DFunction.convolution_2d(
      x, ggw, stride: [@sy, @sx], pad: [@ph, @pw], cover_all: @cover_all
    )
    ret << ggy
  end
  ret
end
forward(inputs)
click to toggle source
# File lib/chainer/functions/connection/convolution_2d_grad_w.rb, line 17
#
# Computes the convolution weight gradient +gw+ from the input +x+ and the
# output gradient +gy+.
#
# Both inputs are retained so +backward+ can read them. +x+ is unfolded with
# im2col using the kernel/stride/pad captured at construction. The result is
# contracted with +gy+ via tensordot.
#
# @param inputs [Array] two-element array +[x, gy]+.
# @return [Array] one-element array holding +gw+, cast to the weight dtype
#   recorded in the constructor.
def forward(inputs)
  retain_inputs([0, 1])
  x, gy = inputs

  columns = Chainer::Utils::Conv.im2col(
    x, @kh, @kw, @sy, @sx, @ph, @pw,
    cover_all: @cover_all
  )
  # Axis 0 and axes 2,3 of gy are summed against axis 0 and axes 4,5 of the
  # column tensor. These are presumably the batch and output-spatial axes,
  # assuming im2col yields (N, C, kh, kw, out_h, out_w); confirm.
  contraction_axes = [[0, 2, 3], [0, 4, 5]]
  summed = Chainer::Utils::Math.tensordot(gy, columns, contraction_axes)
  [summed.cast_to(@w_dtype)]
end