class DNN::Layers::Split
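Attributes
axis [R] — read-only; the axis along which the input is split.
dim [R] — read-only; the size of the first output along that axis.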
Public Class Methods
new(axis: 1, dim: nil)
click to toggle source
Calls superclass method
DNN::Layers::Layer::new
# File lib/dnn/core/layers/split_layers.rb, line 10
# Builds a layer that splits its input into two tensors along +axis+.
#
# @param axis [Integer] index into the input's shape selecting the split axis (default 1).
# @param dim [Integer] size of the first output along +axis+; the second
#   output receives the remaining entries. Required.
# @raise [DNNError] if +dim+ is nil.
def initialize(axis: 1, dim: nil)
  super()
  # dim has no usable default, so a missing value is rejected immediately.
  raise DNNError, "dim is nil" if dim.nil?
  @axis = axis
  @dim = dim
end
Public Instance Methods
backward_node(dy1, dy2)
click to toggle source
# File lib/dnn/core/layers/split_layers.rb, line 25
# Backward pass: joins the gradients of the two outputs back into a single
# gradient by concatenating +dy2+ after +dy1+ along the split axis.
#
# @param dy1 gradient of the first output.
# @param dy2 gradient of the second output.
# @return the concatenated gradient.
def backward_node(dy1, dy2)
  join_axis = @axis
  dy1.concatenate(dy2, axis: join_axis)
end
forward_node(x)
click to toggle source
# File lib/dnn/core/layers/split_layers.rb, line 17
# Forward pass: splits +x+ along @axis into two tensors. The first holds
# the leading @dim entries and the second holds the rest.
#
# @param x tensor exposing +shape+ and +split(indices, axis:)+
#   (assumed Numo::NArray-style semantics — TODO confirm for all backends).
# @return [Array] two-element array +[y1, y2]+.
def forward_node(x)
  # One split index cuts the axis into exactly two pieces:
  # [0, @dim) and [@dim, x.shape[@axis]). Passing [@dim, size] instead
  # produces an extra, always-empty trailing piece.
  y1, y2 = x.split([@dim], axis: @axis)
  [y1, y2]
end
load_hash(hash)
click to toggle source
# File lib/dnn/core/layers/split_layers.rb, line 33
# Restores the layer's configuration from a hash by re-running initialize
# with its :axis and :dim entries. A missing key is passed through as nil
# rather than falling back to initialize's default.
#
# @param hash [Hash] serialized layer data containing :axis and :dim.
# @return the return value of initialize.
def load_hash(hash)
  stored_axis, stored_dim = hash.values_at(:axis, :dim)
  initialize(axis: stored_axis, dim: stored_dim)
end
to_hash()
click to toggle source
Calls superclass method
DNN::Layers::Layer#to_hash
# File lib/dnn/core/layers/split_layers.rb, line 29
# Serializes the layer by handing its split configuration (:axis, then
# :dim) to the parent's to_hash. Presumably the parent merges these into
# the serialized hash; confirm in DNN::Layers::Layer#to_hash.
def to_hash
  split_axis = @axis
  split_dim = @dim
  super(axis: split_axis, dim: split_dim)
end