class Chainer::Functions::Connection::Convolution2DGradW

Public Class Methods

new(conv2d)
# File lib/chainer/functions/connection/convolution_2d_grad_w.rb, line 5
# Captures the hyperparameters of the forward Convolution2D function
# (+conv2d+) together with the weight node's kernel size and dtype, so
# the weight gradient can later be recomputed independently.
#
# conv2d - a convolution function object exposing +inputs+ (where
#          inputs[1] is the weight node with +shape+/+dtype+), the
#          strides +sy+/+sx+, paddings +ph+/+pw+, and +cover_all+.
def initialize(conv2d)
  weight_node = conv2d.inputs[1]

  # Kernel height/width are the trailing two axes of the weight shape.
  kernel_dims = weight_node.shape[2..-1]
  @kh = kernel_dims[0]
  @kw = kernel_dims[1]

  @sy, @sx = conv2d.sy, conv2d.sx
  @ph, @pw = conv2d.ph, conv2d.pw
  @cover_all = conv2d.cover_all
  @w_dtype = weight_node.dtype
end

Public Instance Methods

backward(indexes, grad_outputs)
# File lib/chainer/functions/connection/convolution_2d_grad_w.rb, line 26
# Backward pass of the weight-gradient computation.
#
# indexes      - which input gradients are requested (0 => x, 1 => gy)
# grad_outputs - gradients w.r.t. this function's output; the first
#                element is ggw, the incoming gradient of the weight
#                gradient.
#
# Returns an array holding the requested gradients, x's first.
def backward(indexes, grad_outputs)
  x, gy = get_retained_inputs
  ggw = grad_outputs.first

  grads = []

  # Gradient w.r.t. x: deconvolve gy with ggw back to x's spatial size.
  if indexes.include?(0)
    out_h, out_w = x.shape[2..-1]
    grads << Deconvolution2DFunction.deconvolution_2d(
      gy, ggw, stride: [@sy, @sx], pad: [@ph, @pw], outsize: [out_h, out_w]
    )
  end

  # Gradient w.r.t. gy: run the forward convolution of x with ggw as weights.
  if indexes.include?(1)
    grads << Chainer::Functions::Connection::Convolution2DFunction.convolution_2d(
      x, ggw, stride: [@sy, @sx], pad: [@ph, @pw], cover_all: @cover_all
    )
  end

  grads
end
forward(inputs)
# File lib/chainer/functions/connection/convolution_2d_grad_w.rb, line 17
# Computes the gradient of the convolution weights.
#
# inputs - two-element array [x, gy]: the forward input and the
#          gradient w.r.t. the convolution output.
#
# Expands x into column form (im2col) and contracts it with gy over
# the batch axis (0) and the spatial axes, then casts the result back
# to the weight dtype recorded at construction time.
#
# Returns a one-element array containing the weight gradient.
def forward(inputs)
  retain_inputs([0, 1])
  x, gy = inputs

  im_cols = Chainer::Utils::Conv.im2col(x, @kh, @kw, @sy, @sx, @ph, @pw, cover_all: @cover_all)
  grad_w = Chainer::Utils::Math.tensordot(gy, im_cols, [[0, 2, 3], [0, 4, 5]])

  [grad_w.cast_to(@w_dtype)]
end