class Chainer::Functions::Connection::LinearFunction

Public Class Methods

linear(x, w, b=nil)
# File lib/chainer/functions/connection/linear.rb, line 5
# Applies the linear function to +x+ with weight +w+ and optional bias +b+.
#
# An +x+ with more than two dimensions is first reshaped to
# (x.shape.first, -1), i.e. the first axis is kept and the rest are merged.
#
# @param x input; flattened as described above when x.ndim > 2
# @param w weight; #forward multiplies it as x.dot(w.transpose)
# @param b optional bias; left out of the inputs entirely when nil
# @return the first element of the outputs returned by #apply
def self.linear(x, w, b=nil)
  x = x.reshape(x.shape.first, -1) if x.ndim > 2
  inputs = b.nil? ? [x, w] : [x, w, b]
  new.apply(inputs).first
end

Public Instance Methods

backward(indexes, grad_outputs)
# File lib/chainer/functions/connection/linear.rb, line 33
# Returns gradients for the inputs whose positions appear in +indexes+.
#
# Positions are 0 => x, 1 => w, 2 => b. Gradients are always emitted in
# that order, regardless of how +indexes+ is ordered.
#
# @param indexes [Array<Integer>] input positions whose gradients are requested
# @param grad_outputs output gradients; only the first one is used
# @return [Array] the requested gradients
def backward(indexes, grad_outputs)
  x, w = get_retained_inputs
  gy = grad_outputs.first

  [].tap do |grads|
    if indexes.include?(0)
      # linear(a, b) evaluates a.dot(b.transpose) (see #forward), so this is gy . w.
      grad_x = LinearFunction.linear(gy, w.transpose)
      grads << Chainer::Functions::Array::Cast.cast(grad_x, x.dtype)
    end
    if indexes.include?(1)
      # gy^T . x, expressed through linear by transposing x.
      grad_w = LinearFunction.linear(gy.transpose, x.transpose)
      grads << Chainer::Functions::Array::Cast.cast(grad_w, w.dtype)
    end
    # The bias gradient is gy summed over axis 0; unlike the others it is not cast.
    grads << Chainer::Functions::Math::Sum.sum(gy, axis: 0) if indexes.include?(2)
  end
end
forward(inputs)
# File lib/chainer/functions/connection/linear.rb, line 19
# Computes y = x.dot(w.transpose), plus the bias when one is supplied.
#
# @param inputs [Array] either [x, w] or [x, w, b]
# @return [Array] a one-element array holding y
def forward(inputs)
  x, w = inputs[0], inputs[1]
  # Cast the product back to the class of x.
  y = x.dot(w.transpose).cast_to(x.class)
  y += inputs[2] if inputs.size == 3
  # #backward reads only x and w back, so the bias is not retained.
  retain_inputs([0, 1])
  [y]
end