module DNN::Layers::MathUtils

Public Instance Methods

align_ndim(shape1, shape2) click to toggle source
# File lib/dnn/core/layers/math_layers.rb, line 6
# Pads the shorter of two shapes with length-1 axes until both shapes have
# the same number of dimensions.
#
# Walking the axes from the front, a 1 is inserted into the shorter shape at
# every axis where the two shapes disagree, and padding stops as soon as the
# ranks match. Shapes of equal rank are returned unchanged.
#
# The arguments are never modified; the returned arrays are copies.
#
# @param shape1 [Array<Integer>] first shape
# @param shape2 [Array<Integer>] second shape
# @return [Array(Array<Integer>, Array<Integer>)] the aligned copies of
#   shape1 and shape2, in that order
def align_ndim(shape1, shape2)
  aligned1 = shape1.dup
  aligned2 = shape2.dup
  short, long = aligned1.length < aligned2.length ? [aligned1, aligned2] : [aligned2, aligned1]
  long.length.times do |axis|
    # Without this stop, mismatching axes keep receiving 1s after the ranks
    # are already equal, leaving the "shorter" shape longer than the other.
    break if short.length == long.length

    short.insert(axis, 1) unless short[axis] == long[axis]
  end
  [aligned1, aligned2]
end
broadcast_to(x, target_shape) click to toggle source
# File lib/dnn/core/layers/math_layers.rb, line 23
# Expands x to target_shape by repeating it along every axis whose length
# differs from the target.
#
# @param x an array-like object responding to shape, reshape and
#   concatenate (presumably a Numo/Cumo NArray -- confirm against callers)
# @param target_shape [Array<Integer>] the desired shape; it is not modified
# @return x itself when its shape already equals target_shape, otherwise the
#   expanded array
def broadcast_to(x, target_shape)
  return x if x.shape == target_shape

  # Pass copies so the caller's target_shape can never be altered while the
  # two shapes are being aligned.
  x_shape, target_shape = align_ndim(x.shape.dup, target_shape.dup)
  x = x.reshape(*x_shape)
  x_shape.length.times do |axis|
    next if x.shape[axis] == target_shape[axis]

    tmp = x
    # Appending tmp (target - 1) times reaches the target length only when
    # the source axis has length 1; other mismatches are not validated here.
    (target_shape[axis] - 1).times do
      x = x.concatenate(tmp, axis: axis)
    end
  end
  x
end
sum_to(x, target_shape) click to toggle source
# File lib/dnn/core/layers/math_layers.rb, line 38
# Reduces x toward target_shape by summing over every axis whose length
# differs from the target, keeping each summed axis as a length-1 dimension.
#
# @param x an array-like object responding to shape, reshape and sum
#   (presumably a Numo/Cumo NArray -- confirm against callers)
# @param target_shape [Array<Integer>] the desired shape; it is not modified
# @return x itself when its shape already equals target_shape, otherwise the
#   reduced array
def sum_to(x, target_shape)
  return x if x.shape == target_shape

  # Pass copies so the caller's target_shape can never be altered while the
  # two shapes are being aligned.
  x_shape, target_shape = align_ndim(x.shape.dup, target_shape.dup)
  x = x.reshape(*x_shape)
  x_shape.length.times do |axis|
    next if x.shape[axis] == target_shape[axis]

    # keepdims leaves the summed axis in place with length 1, so later axis
    # indices still line up with target_shape.
    x = x.sum(axis: axis, keepdims: true)
  end
  x
end