class DNN::Link
Attributes
layer_node[RW]
next[RW]
num_outputs[R]
prevs[RW]
Public Class Methods
new(prevs: nil, layer_node: nil, num_outputs: 1)
Source:
# File lib/dnn/core/link.rb, line 8
# Creates a link that wraps +layer_node+ and connects it to predecessor links.
#
# @param prevs [Array, nil] predecessor links; forward waits for one input per
#   entry, and backward routes gradients to these entries.
# @param layer_node [Object, nil] object invoked as +layer_node.(*inputs)+ on
#   forward and +layer_node.backward_node(*grads)+ on backward.
# @param num_outputs [Integer] how many gradients backward buffers before it
#   calls +layer_node.backward_node+.
def initialize(prevs: nil, layer_node: nil, num_outputs: 1)
  # Parallel assignment runs left to right, so the ivars are created in the
  # same order as before.
  @prevs, @layer_node, @num_outputs = prevs, layer_node, num_outputs
  # No successor yet (settable via the +next+ accessor); when it is nil,
  # forward returns the layer output directly.
  @next = nil
  # Shared buffer for pending forward inputs / backward gradients.
  @hold = []
end
Public Instance Methods
backward(dy = Xumo::SFloat[1])
Source:
# File lib/dnn/core/link.rb, line 24
# Buffers one incoming gradient. Once +num_outputs+ gradients have arrived, it
# passes them (in arrival order) to +layer_node.backward_node+ and sends the
# result on to the predecessor links:
# - an Array result is split, so element i goes to +prevs[i]+;
# - any other result goes to +prevs.first+.
# nil entries in +prevs+ are skipped. A nil +prevs+ (the constructor default)
# is treated as an empty list instead of raising NoMethodError.
#
# @param dy [Object] one incoming gradient. The default +Xumo::SFloat[1]+ is
#   presumably the seed gradient for the graph's final link (TODO confirm).
# @return [Object, nil] nil while still waiting for gradients; otherwise the
#   value of the routing step.
def backward(dy = Xumo::SFloat[1])
  @hold << dy
  return if @hold.length < @num_outputs

  dys = @layer_node.backward_node(*@hold)
  @hold = []
  prevs = @prevs || []
  # Deliberately is_a?(Array), not duck typing: a single array-like gradient
  # must not be split element-wise.
  if dys.is_a?(Array)
    dys.each_with_index { |grad, i| prevs[i]&.backward(grad) }
  else
    prevs.first&.backward(dys)
  end
end
forward(x)
Source:
# File lib/dnn/core/link.rb, line 16
# Buffers one input from a predecessor. Once there is one input per entry in
# +prevs+, it calls +layer_node+ with the buffered inputs (in arrival order),
# clears the buffer, and hands the result to +next+.
#
# A nil +prevs+ (the constructor default) is treated as having no
# predecessors, so the link fires on its first input instead of raising
# NoMethodError.
#
# @param x [Object] one input value from a predecessor.
# @return [Object, nil] nil while still waiting for inputs; otherwise the
#   value of +next.forward+, or the layer output itself when +next+ is nil.
def forward(x)
  @hold << x
  return if @hold.length < (@prevs&.length || 0)

  y = @layer_node.(*@hold)
  @hold = []
  @next ? @next.forward(y) : y
end