class TensorStream::TensorShape

class that defines a shape for TensorFlow compatibility

Attributes

rank[RW]
shape[RW]

Public Class Methods

fix_inferred_elements(shape, total_size) click to toggle source
# File lib/tensor_stream/tensor_shape.rb, line 121
def self.fix_inferred_elements(shape, total_size)
  return shape if shape.empty?
  return nil if shape[0].is_a?(Tensor)

  current_size = shape.inject(1) { |product, n| n > 0 ? product * n : product }
  inferred_size = total_size.nil? ? nil : total_size / current_size
  shape.map { |s| s == -1 ? inferred_size : s }
end
infer_shape(shape_a, shape_b) click to toggle source
# File lib/tensor_stream/tensor_shape.rb, line 76
def self.infer_shape(shape_a, shape_b)
  return nil if shape_a.nil? || shape_b.nil?
  return shape_a if shape_b.empty?
  return shape_b if shape_a.empty?
  return shape_a if shape_a == shape_b
  return shape_b if shape_b.size > shape_a.size
  return shape_a if shape_a.size > shape_b.size

  reversed_a = shape_a.reverse
  reversed_b = shape_b.reverse

  reversed_a.each_with_index.collect { |s, index|
    next s if index >= reversed_b.size
    next nil if s.nil? || reversed_b[index].nil?
    next nil if s.is_a?(Tensor) || reversed_b[index].is_a?(Tensor)
    next reversed_b[index] if reversed_b[index] > s

    s
  }.reverse
end
new(shape, rank = nil) click to toggle source
# File lib/tensor_stream/tensor_shape.rb, line 6
def initialize(shape, rank = nil)
  @shape = shape
  @rank = rank.nil? && shape ? shape.size : rank
end
reshape(arr, new_shape) click to toggle source
# File lib/tensor_stream/tensor_shape.rb, line 97
def self.reshape(arr, new_shape)
  arr = arr.is_a?(Array) ? arr.flatten : [arr]
  new_shape = new_shape.is_a?(TensorShape) ? new_shape.shape : new_shape
  new_shape = TensorShape.fix_inferred_elements(new_shape, arr.size)
  return arr[0] if arr.size == 1 && new_shape.empty?

  new_shape = new_shape.dup

  s = new_shape.shift

  if new_shape.size.zero?
    raise "reshape dimen mismatch #{arr.size} != #{s}" if arr.size != s

    return arr
  end

  dim = (arr.size / s)
  return arr if dim.zero?

  arr.each_slice(dim).collect do |slice|
    reshape(slice, new_shape.dup)
  end
end

Public Instance Methods

[](index) click to toggle source
# File lib/tensor_stream/tensor_shape.rb, line 20
def [](index)
  new_shape = @shape[index]
  TensorShape.new(@shape[index])
end
as_dimension(value) click to toggle source
# File lib/tensor_stream/tensor_shape.rb, line 62
def as_dimension(value)
  value.is_a?(TensorShape) ? value.shape : value
end
assert_compatible_with(other) click to toggle source

Raises an exception if `other` is not compatible with this shape.

# File lib/tensor_stream/tensor_shape.rb, line 72
def assert_compatible_with(other)
  raise TensorStream::ValueError, "Dimensions #{self} and #{other} are not compatible" unless compatible_with?(other)
end
compatible_with?(other) click to toggle source
# File lib/tensor_stream/tensor_shape.rb, line 56
def compatible_with?(other)
  other = as_dimension(other)

  shape.nil? || other.nil? || shape == other
end
fully_defined?() click to toggle source
# File lib/tensor_stream/tensor_shape.rb, line 42
def fully_defined?
  known?
end
known?() click to toggle source
# File lib/tensor_stream/tensor_shape.rb, line 33
def known?
  return false if shape.nil?

  a_shape = shape.is_a?(Array) ? shape : [shape]
  a_shape.each { |s| return false if s.nil? || s < 0 }

  true
end
merge_with(other) click to toggle source
# File lib/tensor_stream/tensor_shape.rb, line 46
def merge_with(other)
  assert_compatible_with(other)

  if @shape.nil?
    TensorShape.new(other)
  else
    TensorShape.new(@shape)
  end
end
ndims() click to toggle source
# File lib/tensor_stream/tensor_shape.rb, line 25
def ndims
  shape ? shape.size : nil
end
scalar?() click to toggle source
# File lib/tensor_stream/tensor_shape.rb, line 29
def scalar?
  known? && shape.size.zero?
end
to_s() click to toggle source
# File lib/tensor_stream/tensor_shape.rb, line 11
def to_s
  return "?" if @shape.nil?

  dimensions = @shape.collect { |r|
    "Dimension(#{r})"
  }.join(",")
  "TensorShape([#{dimensions}])"
end
value() click to toggle source
# File lib/tensor_stream/tensor_shape.rb, line 66
def value
  shape
end