module DNN::Layers::MathUtils
Public Instance Methods
align_ndim(shape1, shape2)
click to toggle source
# File lib/dnn/core/layers/math_layers.rb, line 6
# Brings two shapes to the same rank by inserting size-1 axes into the
# shorter one. Axes are walked from the front; wherever the shorter shape
# disagrees with the longer one a 1 is inserted, until the ranks match.
# (This is not NumPy's right-aligned rule, see the second example.)
#
#   align_ndim([3], [2, 3])        #=> [[1, 3], [2, 3]]
#   align_ndim([2], [2, 3])        #=> [[2, 1], [2, 3]]
#   align_ndim([1, 3], [2, 4, 3])  #=> [[1, 1, 3], [2, 4, 3]]
#
# @param shape1 [Array<Integer>] first shape; not modified
# @param shape2 [Array<Integer>] second shape; not modified
# @return [Array(Array<Integer>, Array<Integer>)] aligned copies of
#   shape1 and shape2, in that order
def align_ndim(shape1, shape2)
  # Work on copies: callers pass shapes they keep (e.g. a requested
  # target shape), and Array#insert would otherwise modify them in place.
  shape1 = shape1.dup
  shape2 = shape2.dup
  shorter, longer = shape1.length < shape2.length ? [shape1, shape2] : [shape2, shape1]
  longer.length.times do |axis|
    # Stop as soon as the ranks agree; continuing would keep inserting 1s
    # for size-1 (broadcast) axes and overshoot the target rank.
    break if shorter.length == longer.length

    shorter.insert(axis, 1) unless shorter[axis] == longer[axis]
  end
  [shape1, shape2]
end
broadcast_to(x, target_shape)
click to toggle source
# File lib/dnn/core/layers/math_layers.rb, line 23
# Expands +x+ towards +target_shape+ by repeating it along every axis whose
# size differs from the target (after rank alignment via align_ndim).
#
# @param x array-like responding to +shape+, +reshape+ and
#   +concatenate(*arrays, axis:)+ (presumably a Numo/Cumo NArray; confirm)
# @param target_shape [Array<Integer>] shape to broadcast to; not modified
# @return +x+ itself when it already has +target_shape+, otherwise a new
#   array tiled along the mismatching axes
def broadcast_to(x, target_shape)
  return x if x.shape == target_shape

  # Pass a copy: align_ndim may insert axes into the shape it is given, and
  # the caller's target_shape must not change underneath it.
  x_shape, aligned_target = align_ndim(x.shape, target_shape.dup)
  x = x.reshape(*x_shape)
  x_shape.each_index do |axis|
    next if x.shape[axis] == aligned_target[axis]

    copies = aligned_target[axis] - 1
    next unless copies.positive?

    # One concatenate with all copies instead of growing x one copy at a
    # time, which re-copied the accumulated array on every iteration.
    x = x.concatenate(*Array.new(copies, x), axis: axis)
  end
  x
end
sum_to(x, target_shape)
click to toggle source
# File lib/dnn/core/layers/math_layers.rb, line 38
# Reduces +x+ to +target_shape+ by summing (with keepdims) over every axis
# whose size differs from the target after rank alignment via align_ndim.
# This is the reverse of broadcast_to.
#
# @param x array-like responding to +shape+, +reshape+ and
#   +sum(axis:, keepdims:)+ (presumably a Numo/Cumo NArray; confirm)
# @param target_shape [Array<Integer>] requested result shape; not modified
# @return +x+ itself when it already has +target_shape+; otherwise the summed
#   array, reshaped to exactly +target_shape+ when the shapes are compatible
def sum_to(x, target_shape)
  return x if x.shape == target_shape

  # Pass a copy: align_ndim may insert axes into the shape it is given, which
  # would otherwise rewrite the caller's target_shape (e.g. [3] -> [1, 3]).
  x_shape, aligned_target = align_ndim(x.shape, target_shape.dup)
  x = x.reshape(*x_shape)
  x_shape.each_index do |axis|
    next if x.shape[axis] == aligned_target[axis]

    x = x.sum(axis: axis, keepdims: true)
  end
  # keepdims leaves x at the rank-aligned shape; when it matches the aligned
  # target, drop the padding axes so the caller gets target_shape back.
  # Incompatible shapes and scalar targets are returned as before.
  return x unless x.shape == aligned_target && !target_shape.empty?

  x.reshape(*target_shape)
end