class DNN::Link

Attributes

layer_node[RW]
next[RW]
num_outputs[R]
prevs[RW]

Public Class Methods

new(prevs: nil, layer_node: nil, num_outputs: 1)
# File lib/dnn/core/link.rb, line 8
# Builds a link in the computation graph.
#
# @param prevs upstream links feeding this one (nil by default); nil
#   entries are tolerated by #backward, which skips them with `&.`.
# @param layer_node object invoked with `.()` on the forward pass and
#   `backward_node` on the backward pass.
# @param num_outputs number of downstream gradients #backward waits for
#   before calling `layer_node.backward_node` (defaults to 1).
def initialize(prevs: nil, layer_node: nil, num_outputs: 1)
  # Wiring supplied by the caller, stored as-is.
  @prevs = prevs
  @layer_node = layer_node
  @num_outputs = num_outputs

  # No downstream link yet; the `next` attribute is writable.
  @next = nil

  # Buffer for values that arrive before all expected ones are in.
  @hold = []
end

Public Instance Methods

backward(dy = Xumo::SFloat[1])
# File lib/dnn/core/link.rb, line 24
# Backpropagates one gradient through this link.
#
# Incoming gradients are buffered in @hold until @num_outputs of them
# have arrived. Only then is @layer_node.backward_node called with all of
# them (in arrival order), and the result is routed upstream:
#   * an Array result is split so that element i goes to @prevs[i];
#   * any other result goes to the first upstream link.
# Upstream entries that are nil are skipped, and a nil @prevs (the
# constructor default) is treated as "no upstream links".
#
# NOTE(review): @hold is also used by #forward; this assumes the forward
# pass has fully drained it before backpropagation starts — confirm.
#
# @param dy gradient from one downstream consumer; defaults to
#   Xumo::SFloat[1] (presumably the seed gradient of a scalar loss — confirm).
# @return [nil] while still waiting for more downstream gradients; the
#   value returned after propagation is not meaningful.
def backward(dy = Xumo::SFloat[1])
  @hold << dy
  return if @hold.length < @num_outputs

  dys = @layer_node.backward_node(*@hold)
  @hold = []
  # Array() maps a nil @prevs to [] instead of raising NoMethodError.
  upstream = Array(@prevs)
  if dys.is_a?(Array)
    # One gradient per input; `grad` avoids shadowing the `dy` parameter.
    dys.each_with_index { |grad, i| upstream[i]&.backward(grad) }
  else
    upstream.first&.backward(dys)
  end
end
forward(x)
# File lib/dnn/core/link.rb, line 16
# Feeds one input value into this link.
#
# Inputs are buffered in @hold until one has arrived for every upstream
# link in @prevs. The buffered values are then passed, in arrival order,
# to @layer_node via `.()`, and the result is handed to @next if one is
# set. A nil @prevs (the constructor default) is treated as zero upstream
# links, so the node fires on the first input.
#
# @param x one input value.
# @return nil while still waiting for inputs; otherwise the result of
#   @next.forward, or this node's output when @next is nil.
def forward(x)
  @hold << x
  # Array() maps a nil @prevs to [] instead of raising NoMethodError.
  return if @hold.length < Array(@prevs).length

  out = @layer_node.(*@hold)
  @hold = []
  @next ? @next.forward(out) : out
end