class Tensorflow::Variable

Attributes

dtype[R]
handle[R]
name[R]

Public Class Methods

new(initial_value = nil, dtype: nil, shape: nil, shared_name: nil, name: 'Variable', trainable: true) click to toggle source
# File lib/tensorflow/variable.rb, line 7
# Creates a resource variable and registers it in the current execution context.
#
# @param initial_value [nil, Graph::Operation, Tensor, Object] value assigned on
#   creation; plain Ruby values are wrapped with Tensor.from_value. nil means
#   no assign op is emitted.
# @param dtype used when initial_value is nil, a Graph::Operation without a
#   dtype override, or a plain value; ignored for a Tensor (its dtype wins).
# @param shape explicit shape for the handle; when omitted it is derived from
#   initial_value (or [] when there is none). Ignored for plain Ruby values,
#   whose converted tensor's shape is used.
# @param shared_name resource shared name; defaults to the unique name.
# @param name base name passed to ExecutionContext#unique_name.
# @param trainable when true the variable is also added to TRAINABLE_VARIABLES.
def initialize(initial_value = nil, dtype: nil, shape: nil, shared_name: nil, name: 'Variable', trainable: true)
  initial_value = case initial_value
                  when NilClass
                    # Nothing to assign: @dtype comes only from dtype:. Fall back to
                    # [] only when no shape: was given (it used to be discarded).
                    @dtype = dtype
                    shape ||= []
                    initial_value
                  when Graph::Operation
                    @dtype = dtype || initial_value.dtype
                    shape ||= initial_value.output_shapes.first
                    initial_value
                  when Tensor
                    # The tensor's own dtype takes precedence over dtype:.
                    @dtype = initial_value.dtype
                    shape ||= initial_value.shape
                    initial_value
                  else
                    tensor = Tensor.from_value(initial_value, dtype: dtype)
                    @dtype = tensor.dtype
                    shape = tensor.shape
                    tensor
                  end

  name = name&.to_s
  shared_name = shared_name&.to_s
  # shared_name is only used as the base name when name: was explicitly nil.
  unique_name = ExecutionContext.current.unique_name(name || shared_name)
  shared_name ||= unique_name
  @name = unique_name

  collections = [Graph::GraphKeys::GLOBAL_VARIABLES]
  collections << Graph::GraphKeys::TRAINABLE_VARIABLES if trainable
  ExecutionContext.current.add_to_collections(collections, self)

  @handle = RawOps.var_handle_op(dtype: @dtype, shape: shape, shared_name: shared_name, name: unique_name)
  # Only emit an assign op when there is a value to assign.
  self.value = initial_value if initial_value
end

Public Instance Methods

assign_add(value, dtype: nil) click to toggle source
# File lib/tensorflow/variable.rb, line 101
# Adds +value+ to the variable in place.
#
# @param value anything Tensor.from_value accepts.
# @param dtype optional dtype used when converting +value+ (nil lets
#   Tensor.from_value decide); the result is then cast to the variable's dtype.
# @return the result of RawOps.assign_add_variable_op.
def assign_add(value, dtype: nil)
  # Invalidate the memoized read so #value_handle re-reads after the update.
  @value_handle = nil
  tensor = Tensor.from_value(value, dtype: dtype)
  # `dtype` is the keyword argument here, so the variable's dtype needs self.
  tensor = Tensorflow.cast(tensor, self.dtype)
  # Pass the cast tensor (not the raw input) so the cast actually takes effect.
  RawOps.assign_add_variable_op(handle, tensor, dtype: tensor.dtype)
end
assign_sub(value) click to toggle source
# File lib/tensorflow/variable.rb, line 108
# Subtracts +value+ from the variable in place.
#
# @param value anything Tensor.from_value accepts.
# @param dtype optional dtype used when converting +value+; defaults to the
#   variable's own dtype (the previous, and still default, behaviour). The
#   result is always cast to the variable's dtype.
# @return the result of RawOps.assign_sub_variable_op.
def assign_sub(value, dtype: nil)
  # Invalidate the memoized read so #value_handle re-reads after the update.
  @value_handle = nil
  tensor = Tensor.from_value(value, dtype: dtype || self.dtype)
  tensor = Tensorflow.cast(tensor, self.dtype)
  # Pass the cast tensor (not the raw input) so the cast actually takes effect.
  RawOps.assign_sub_variable_op(handle, tensor, dtype: tensor.dtype)
end
consumers() click to toggle source

These methods mirror the Operation API so that variables can be used with gradients and sessions.

# File lib/tensorflow/variable.rb, line 71
# Forwards to the resource handle's #consumers.
def consumers
  resource = handle
  resource.consumers
end
initialized?() click to toggle source
# File lib/tensorflow/variable.rb, line 66
# Returns the result of RawOps.var_is_initialized_op applied to the handle.
def initialized?
  resource = handle
  RawOps.var_is_initialized_op(resource)
end
initializer() click to toggle source
# File lib/tensorflow/variable.rb, line 62
def initializer
  # Set by #value= to the op returned from RawOps.assign_variable_op;
  # nil until a value has been assigned.
  @initializer
end
inspect() click to toggle source
# File lib/tensorflow/variable.rb, line 119
# Human-readable summary: the handle's name (when it has one), then the shape
# and dtype of the current read handle.
def inspect
  fields = []
  fields << "name: #{handle.name}" if handle.respond_to?(:name)
  reader = value_handle
  fields << "shape: #{reader.shape}"
  fields << "dtype: #{reader.dtype}"
  "#<#{self.class} #{fields.join(', ')}>"
end
outputs() click to toggle source

This allows a variable to be executed in a session to fetch its value.

# File lib/tensorflow/variable.rb, line 76
# Single-element list holding output 0 of the read handle.
def outputs
  first_output = Graph::OperationOutput.from_index(value_handle, 0)
  [first_output]
end
rank() click to toggle source
# File lib/tensorflow/variable.rb, line 93
# Number of dimensions, i.e. the size of #shape.
def rank
  dims = shape
  dims.size
end
reshape(shape) click to toggle source
# File lib/tensorflow/variable.rb, line 97
# Delegates to RawOps.reshape with this variable as the input.
def reshape(target_shape)
  RawOps.reshape(self, target_shape)
end
shape() click to toggle source
# File lib/tensorflow/variable.rb, line 84
# Shape as reported by the current read handle.
def shape
  reader = value_handle
  reader.shape
end
tensor() click to toggle source
# File lib/tensorflow/variable.rb, line 88
# Returns the read handle's tensor.
# Raises Error::UnavailableError when Tensorflow.execution_mode is GRAPH_MODE.
def tensor
  if Tensorflow.execution_mode == Tensorflow::GRAPH_MODE
    raise Error::UnavailableError, "Only supported in eager execution mode"
  end

  value_handle.tensor
end
to_ptr() click to toggle source
# File lib/tensorflow/variable.rb, line 80
# Forwards to the resource handle's #to_ptr.
def to_ptr
  resource = handle
  resource.to_ptr
end
to_s() click to toggle source
# File lib/tensorflow/variable.rb, line 115
# Same text as #inspect.
def to_s
  self.inspect
end
value() click to toggle source
# File lib/tensorflow/variable.rb, line 49
# Current value of the variable.
# Eager::TensorHandle -> its #value; Graph::Operation -> the operation itself;
# any other kind of read handle -> nil.
def value
  reader = value_handle
  case reader
  when Eager::TensorHandle then reader.value
  when Graph::Operation then reader
  end
end
value=(value) click to toggle source
# File lib/tensorflow/variable.rb, line 58
# Assigns +value+ to the variable via RawOps.assign_variable_op and stores the
# resulting op in @initializer (exposed by #initializer).
def value=(value)
  # Drop the memoized read so #value_handle issues a fresh read after this
  # assignment, as #assign_add and #assign_sub already do; otherwise a handle
  # read before re-assignment would keep being returned.
  @value_handle = nil
  @initializer = RawOps.assign_variable_op(self.handle, value, dtype: @dtype)
end
value_handle() click to toggle source
# File lib/tensorflow/variable.rb, line 45
# Memoized handle from RawOps.read_variable_op; creates it on first use (or
# after the cache has been cleared).
def value_handle
  return @value_handle if @value_handle

  @value_handle = RawOps.read_variable_op(handle, dtype: @dtype)
end