class DNN::Layers::Split

Attributes

axis[R]
dim[R]

Public Class Methods

new(axis: 1, dim: nil) click to toggle source
Calls superclass method DNN::Layers::Layer::new
# File lib/dnn/core/layers/split_layers.rb, line 10
# Stores the split configuration for this layer.
#
# @param axis the axis later handed to `split` / `concatenate` (defaults to 1)
# @param dim the size of the first output along `axis`; required
# @raise [DNNError] if `dim` is nil
def initialize(axis: 1, dim: nil)
  super()
  # `nil?` instead of `== nil`: the check cannot be fooled by a custom `==`.
  raise DNNError, "dim is nil" if dim.nil?

  @axis = axis
  @dim = dim
end

Public Instance Methods

backward_node(dy1, dy2) click to toggle source
# File lib/dnn/core/layers/split_layers.rb, line 25
# Joins the two incoming gradients back together along the configured axis.
#
# @param dy1 gradient for the first output; receives the `concatenate` call
# @param dy2 gradient for the second output
# @return whatever `dy1.concatenate` returns
def backward_node(dy1, dy2)
  join_axis = @axis
  dy1.concatenate(dy2, axis: join_axis)
end
forward_node(x) click to toggle source
# File lib/dnn/core/layers/split_layers.rb, line 17
# Splits `x` along `@axis` into a leading part of size `@dim` and the rest.
#
# @param x input responding to `shape` and `split(indices, axis:)`
# @return [Array] two-element array `[first_part, remaining_part]`
def forward_node(x)
  head_size = @dim
  tail_size = x.shape[@axis] - @dim
  boundaries = [head_size, head_size + tail_size]

  # Only the first two pieces returned by `split` are used.
  head, remainder = x.split(boundaries, axis: @axis)

  # NOTE(review): this branch is kept as-is; it only runs when the second
  # piece is itself an Array of pieces, which are then re-joined.
  if remainder.is_a?(Array)
    leading, *others = remainder
    tail = leading.concatenate(others, axis: @axis)
  else
    tail = remainder
  end

  [head, tail]
end
load_hash(hash) click to toggle source
# File lib/dnn/core/layers/split_layers.rb, line 33
# Re-runs `initialize` with the `:axis` and `:dim` entries read from `hash`.
#
# @param hash object indexed with `[:axis]` and `[:dim]`
# @return whatever `initialize` returns
def load_hash(hash)
  axis = hash[:axis]
  dim = hash[:dim]
  initialize(axis: axis, dim: dim)
end
to_hash() click to toggle source
Calls superclass method DNN::Layers::Layer#to_hash
# File lib/dnn/core/layers/split_layers.rb, line 29
# Passes this layer's `axis` and `dim` settings to the superclass `to_hash`.
#
# @return whatever the superclass `to_hash` returns
def to_hash
  settings = { axis: @axis, dim: @dim }
  # Double-splat so the values are still passed as keywords, as before.
  super(**settings)
end