class Tensorflow::Tensor

Public Class Methods

finalize(pointer)
# File lib/tensorflow/tensor.rb, line 6
# Builds the finalizer used to free a native tensor.
#
# The returned proc closes over +pointer+ only (never over the Ruby wrapper
# object), so registering it with ObjectSpace.define_finalizer does not keep
# the wrapper alive. Being a plain proc, it ignores the object id that Ruby
# passes to finalizers.
#
# @param pointer the native tensor handle to release
# @return [Proc] calls FFI.TF_DeleteTensor(pointer) when invoked
def self.finalize(pointer)
  proc { FFI.TF_DeleteTensor(pointer) }
end
from_pointer(pointer)
# File lib/tensorflow/tensor.rb, line 41
# Wraps an existing native tensor handle without copying its data.
#
# #initialize is bypassed (via allocate); the handle is stored directly in
# @pointer and a finalizer is registered that deletes the native tensor once
# the wrapper is garbage collected, so the returned object takes over
# responsibility for freeing +pointer+.
#
# @param pointer the native tensor handle to wrap
# @return [Tensor] the new wrapper
def self.from_pointer(pointer)
  allocate.tap do |tensor|
    tensor.instance_variable_set(:@pointer, pointer)
    ObjectSpace.define_finalizer(tensor, finalize(pointer))
  end
end
from_proto(proto)
# File lib/tensorflow/tensor.rb, line 27
# Builds a Tensor from a TensorProto message or its serialized bytes.
#
# The element data is read from the proto's +tensor_content+ bytes and
# decoded with the Numo class mapped to the proto's dtype. A proto with an
# empty shape is treated as a scalar: the decoded flat array is unwrapped to
# its first element.
#
# NOTE(review): only +tensor_content+ is read; protos that carry their values
# in the typed repeated fields instead are not handled here — confirm callers
# never produce those.
#
# @param proto [TensorProto, String] a decoded proto, or bytes to decode
# @return [Tensor]
# @raise [ArgumentError] if the proto's dtype has no Numo type mapping
def self.from_proto(proto)
  proto = TensorProto.decode(proto) unless proto.is_a?(TensorProto)
  shape = proto.tensor_shape.dim.map(&:size)
  dtype = FFI::DataType[DataType.resolve(proto.dtype)]
  numo_klass = TensorData::DTYPE_TO_NUMO_TYPE_MAP[dtype]
  # Without this guard an unmapped dtype surfaces as an opaque NoMethodError
  # on nil at the from_binary call below.
  unless numo_klass
    raise ArgumentError, "Unsupported dtype for TensorProto conversion: #{dtype.inspect}"
  end

  value = if shape.empty?
            # Scalar: decode as a flat array and take the single element.
            numo_klass.from_binary(proto.tensor_content)[0]
          else
            numo_klass.from_binary(proto.tensor_content, shape)
          end
  new(value, dtype: dtype, shape: shape)
end
from_value(value, dtype: nil)
# File lib/tensorflow/tensor.rb, line 12
# Coerces an arbitrary value into something usable as a tensor input.
#
# Tensors and graph operations pass through unchanged; eager tensor handles
# yield their underlying tensor; datasets yield their variant tensor. Any
# other value is wrapped in a new Tensor built with the given +dtype+.
#
# @param value the value to coerce
# @param dtype requested data type when a new Tensor has to be built
# @return [Tensor, Graph::Operation]
def self.from_value(value, dtype: nil)
  case value
  when Tensor, Graph::Operation then value
  when Eager::TensorHandle then value.tensor
  when Data::Dataset then value.variant_tensor
  else Tensor.new(value, dtype: dtype)
  end
end
new(value, dtype: nil, shape: [])
# File lib/tensorflow/tensor.rb, line 48
# Builds a new native TensorFlow tensor from a Ruby value.
#
# Numo narrays are used as-is, plain Ruby arrays are cast to a narray, and any
# other value is passed to TensorData.value_with_shape together with +dtype+
# and +shape+. A TensorData is then built from the result; its dtype and shape
# (not the arguments) are what the native tensor is created with via
# TF_NewTensor. A finalizer is registered so the native tensor is deleted
# with TF_DeleteTensor once this object is garbage collected.
#
# @param value [Numo::NArray, Array, Object] the tensor contents
# @param dtype requested data type, forwarded to TensorData
# @param shape [Array<Integer>] requested shape, forwarded to TensorData
def initialize(value, dtype: nil, shape: [])
  value =
    case value
    when Numo::NArray
      value
    when Array
      # We convert all arrays to narrays. This makes it a lot easier to
      # support multidimensional arrays.
      Numo::NArray.cast(value)
    else
      TensorData.value_with_shape(value, dtype, shape)
    end

  tensor_data = TensorData.new(value, dtype: dtype, shape: shape)
  dtype = tensor_data.dtype
  shape = tensor_data.shape
  rank = shape ? shape.size : 0

  # Rank-0 tensors pass nil as the dims pointer; otherwise the dimensions are
  # copied into an int64 buffer for the C API.
  dims_ptr = nil
  if rank.positive?
    dims_ptr = ::FFI::MemoryPointer.new(:int64, rank)
    dims_ptr.write_array_of_int64(shape)
  end

  # NOTE(review): tensor_data is only referenced from this local; if it owns
  # the buffer handed to TF_NewTensor, confirm that TensorData::Deallocator
  # (and not Ruby GC of tensor_data) controls that buffer's lifetime.
  @pointer = FFI.TF_NewTensor(dtype,
                              dims_ptr, rank,
                              tensor_data, tensor_data.byte_size,
                              TensorData::Deallocator, nil)

  ObjectSpace.define_finalizer(self, self.class.finalize(@pointer))
end

Public Instance Methods

byte_size()
# File lib/tensorflow/tensor.rb, line 94
# Returns the byte size of this tensor's data buffer, as reported by the
# TensorFlow C API's TF_TensorByteSize.
# NOTE(review): +self+ is passed straight to FFI; presumably it is marshalled
# through #to_ptr — confirm against the FFI function signature.
def byte_size
  FFI.TF_TensorByteSize(self)
end
data()
# File lib/tensorflow/tensor.rb, line 103
# Reads this tensor's contents back into a TensorData.
#
# The raw buffer pointer comes from TF_TensorData; it is handed to
# TensorData.from_pointer together with the tensor's byte size, dtype and
# shape so the bytes can be interpreted.
#
# @return [TensorData]
def data
  raw_buffer = FFI.TF_TensorData(self)
  TensorData.from_pointer(raw_buffer, byte_size, dtype, shape)
end
dtype()
# File lib/tensorflow/tensor.rb, line 82
# Returns this tensor's element type, as reported by the TensorFlow C API's
# TF_TensorType.
# NOTE(review): the Ruby form of the result (integer vs. enum symbol) depends
# on how the FFI binding declares the return type — confirm.
def dtype
  FFI.TF_TensorType(self)
end
inspect()
# File lib/tensorflow/tensor.rb, line 98
# Human-readable summary listing the tensor's numo value, shape and dtype,
# e.g. "#<Tensorflow::Tensor numo: ..., shape: ..., dtype: ...>".
#
# @return [String]
def inspect
  details = %w[numo shape dtype].map do |attribute|
    "#{attribute}: #{send(attribute).inspect}"
  end
  "#<#{self.class} #{details.join(', ')}>"
end
to_ptr()
# File lib/tensorflow/tensor.rb, line 90
# Returns the native tensor handle held in @pointer (set by #initialize or
# .from_pointer).
# NOTE(review): presumably this is what lets a Tensor be passed directly
# where the FFI layer expects a pointer — confirm.
def to_ptr
  @pointer
end
to_s()
# File lib/tensorflow/tensor.rb, line 86
# Returns the same summary string as #inspect. Calling it dynamically (rather
# than aliasing) means an overridden #inspect is picked up as well.
def to_s
  inspect
end
value()
# File lib/tensorflow/tensor.rb, line 78
# Returns the tensor's contents as a Ruby value, by reading the native buffer
# through #data and asking the resulting TensorData for its value.
def value
  data.value
end

Private Instance Methods

calculate_shape(value)
# File lib/tensorflow/tensor.rb, line 121
# Infers the shape of +value+.
#
# Objects that respond to #shape (such as narrays) answer directly. For plain
# Ruby arrays, the length of each nesting level is recorded by following the
# first element down; only first elements are examined, so ragged arrays are
# not detected. Non-array scalars yield an empty shape.
#
# @return [Array<Integer>]
def calculate_shape(value)
  return value.shape if value.respond_to?(:shape)

  dims = []
  level = value
  while level.is_a?(Array)
    dims.push(level.size)
    level = level.first
  end
  dims
end
dim(index)
# File lib/tensorflow/tensor.rb, line 113
# Returns the length of dimension +index+, via the TensorFlow C API's TF_Dim.
# NOTE(review): no bounds check happens here; presumably +index+ must be in
# 0...num_dims — confirm what the C API does otherwise.
def dim(index)
  FFI.TF_Dim(self, index)
end
element_count()
# File lib/tensorflow/tensor.rb, line 117
# Returns the total number of elements in the tensor, via the TensorFlow C
# API's TF_TensorElementCount.
def element_count
  FFI.TF_TensorElementCount(self)
end
num_dims()
# File lib/tensorflow/tensor.rb, line 109
# Returns the tensor's number of dimensions (its rank), via the TensorFlow C
# API's TF_NumDims.
def num_dims
  FFI.TF_NumDims(self)
end