class TensorStream::Graph

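A class that defines a TensorStream graph: it holds the graph's nodes, constants, and variable collections, and it tracks the per-thread name, device, and control-dependency scopes used when adding nodes.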

Attributes

collections[RW]
constants[RW]
eager_execution[RW]
node_keys[R]
nodes[RW]
random_seed[RW]

Public Class Methods

create_default() click to toggle source
# File lib/tensor_stream/graph.rb, line 75
# Installs a brand-new TensorStream::Graph as the calling thread's current
# graph and returns it.
def self.create_default
  fresh_graph = TensorStream::Graph.new
  Thread.current[:tensor_stream_current_graph] = fresh_graph
  fresh_graph
end
get_default_graph() click to toggle source
# File lib/tensor_stream/graph.rb, line 71
# Returns the calling thread's current graph. If this thread has none yet,
# one is created via create_default.
def self.get_default_graph
  current = Thread.current[:tensor_stream_current_graph]
  return current if current

  create_default
end
new() click to toggle source
# File lib/tensor_stream/graph.rb, line 9
# Creates an empty graph with eager execution disabled, no nodes or
# constants, and empty global and trainable variable collections (keyed by
# the Symbol form of the GraphKeys names).
def initialize
  @nodes = {}
  @node_keys = []
  @constants = {}
  @collections = {}
  [GraphKeys::GLOBAL_VARIABLES, GraphKeys::TRAINABLE_VARIABLES].each do |key|
    @collections[:"#{key}"] = []
  end
  @eager_execution = false
end
parse_from_string(buffer) click to toggle source
# File lib/tensor_stream/graph.rb, line 270
# Deserializes +buffer+ into a fresh Graph via TensorStream::GraphBuilder and
# returns whatever GraphBuilder#build returns.
def self.parse_from_string(buffer)
  TensorStream::GraphBuilder.new(Graph.new).build(buffer)
end

Public Instance Methods

[](name) click to toggle source
# File lib/tensor_stream/graph.rb, line 122
# Index-style lookup: graph[name] returns exactly what get_node(name) returns.
def [](name)
  get_node(name)
end
add_node(node, name = nil) click to toggle source
# File lib/tensor_stream/graph.rb, line 88
# Registers +node+ in this graph and wires it into its inputs/consumers.
#
# When +name+ is nil, a node whose name is already taken in @nodes is renamed
# with uniqunify before registration.
# NOTE(review): a non-nil +name+ is never assigned to the node; it only skips
# the rename step, so a clashing node.name would overwrite the existing
# @nodes entry and be appended to @node_keys again — confirm callers rely on this.
#
# The node is then assigned the current device scope, appended to
# @node_keys, stored in @nodes, and also indexed in @constants when
# node.is_const is truthy.
#
# @raise [RuntimeError] if +node+ is a Placeholder while eager execution is enabled
# @return the value of node's propagate_consumer call (add_op! uses tap, so it
#   does not depend on this return value)
def add_node(node, name = nil)
  raise "Placeholder cannot be used when eager_execution is enabled" if @eager_execution && node.is_a?(Placeholder)

  if name.nil?
    node.name = if @nodes[node.name]
      uniqunify(node.name)
    else
      node.name
    end
  end

  node.device = get_device_scope
  @node_keys << node.name
  @nodes[node.name] = node
  @constants[node.name] = node if node.is_const

  # Both hooks are invoked via send, presumably because they are not public.
  node.send(:propagate_outputs)
  node.send(:propagate_consumer, node)
end
add_node!(name, node) click to toggle source
# File lib/tensor_stream/graph.rb, line 126
# Stores +node+ under +name+ directly in the node table and returns +node+.
# Unlike add_node, this does no renaming, no @node_keys bookkeeping, no device
# assignment and no propagation.
def add_node!(name, node)
  @nodes.store(name, node)
end
add_op(operation, *args) click to toggle source
# File lib/tensor_stream/graph.rb, line 131
# Builds a new Operation for this graph without registering it (add_op! is
# the registering variant).
#
# A trailing Hash in +args+ is treated as the options hash. Every other
# argument is converted with TensorStream.convert_to_tensor, and its +op+
# becomes an input (falsy conversions become nil inputs).
#
# Name resolution: options[:internal_name] is used verbatim when present.
# Otherwise the current name scope is joined with "/" to options[:name]
# (or the operation type from set_operation_name), skipping nil/empty parts.
#
# @param operation the operation type stored on the new op
# @return [Operation] the fully configured, not-yet-added operation
def add_op(operation, *args)
  options = args.last.is_a?(Hash) ? args.pop : {}

  tensors = args.map { |arg| TensorStream.convert_to_tensor(arg) }
  inputs = tensors.map { |tensor| tensor ? tensor.op : nil }

  new_op = Operation.new(self, inputs: inputs, options: options)
  # caller_locations must be taken here so the recorded source is add_op's caller.
  new_op.source = format_source(caller_locations)
  new_op.operation = operation
  new_op.shape = TensorShape.new(TensorStream::InferShape.infer_shape(new_op))
  new_op.rank = new_op.shape.rank
  new_op.name =
    if options[:internal_name]
      options[:internal_name]
    else
      name_parts = [get_name_scope, options[:name] || set_operation_name(new_op)]
      name_parts.compact.reject(&:empty?).join("/")
    end
  new_op.internal = options[:internal]

  new_op.data_type = new_op.set_data_type(options[:data_type])
  new_op.is_const = new_op.infer_const

  new_op.given_name = new_op.name

  new_op
end
add_op!(operation, *args) click to toggle source
# File lib/tensor_stream/graph.rb, line 156
# Builds an operation with add_op, registers it with add_node, and returns
# the operation itself (add_node's return value is discarded).
def add_op!(operation, *args)
  new_op = add_op(operation, *args)
  add_node(new_op)
  new_op
end
add_to_collection(collection_name, val) click to toggle source
# File lib/tensor_stream/graph.rb, line 83
# Appends +val+ to the collection named +collection_name+ (normalized to a
# Symbol), creating the collection on first use. Returns the collection Array.
def add_to_collection(collection_name, val)
  (@collections[collection_name.to_sym] ||= []) << val
end
add_variable(node, options = {}) click to toggle source
# File lib/tensor_stream/graph.rb, line 164
# Registers a variable node in this graph's collections.
#
# If a node named node.name already exists, it is returned unchanged when the
# current variable scope allows reuse. Otherwise a duplicate raises.
# A new variable is added to every collection listed in options[:collections]
# (a single name or an Array of names), to GraphKeys::GLOBAL_VARIABLES, and to
# GraphKeys::TRAINABLE_VARIABLES when node.trainable? is truthy.
#
# The caller's +options+ hash is never modified.
#
# @param node the variable node to register
# @param options [Hash] only :collections is read here
# @return the existing node on reuse, otherwise +node+
# @raise [RuntimeError] on a duplicate name when reuse is off, or when node.shape is nil
def add_variable(node, options = {})
  scope = _variable_scope
  existing = @nodes[node.name]

  raise "duplicate variable detected #{node.name} and reuse=false in current scope" if existing && !scope.reuse
  return existing if existing
  raise "shape is not declared for #{node.name}" if node.shape.nil?

  extra_collections = options[:collections]
  if !extra_collections.nil? && !extra_collections.empty?
    # Wrap a single collection name locally instead of writing the wrapped
    # value back into the caller's options hash.
    extra_collections = [extra_collections] unless extra_collections.is_a?(Array)
    extra_collections.each { |coll| add_to_collection(coll, node) }
  end

  add_to_collection(GraphKeys::GLOBAL_VARIABLES, node)
  add_to_collection(GraphKeys::TRAINABLE_VARIABLES, node) if node.trainable?

  node
end
add_variable!(node, options = {}) click to toggle source
# File lib/tensor_stream/graph.rb, line 182
# Registers +node+ with add_variable, then adds a :variable_v2 op carrying the
# resulting node's name plus options[:shape] and options[:data_type].
# Returns that op.
#
# NOTE(review): the op is added to Graph.get_default_graph, not necessarily to
# self — confirm this is intended when self is not the default graph.
def add_variable!(node, options = {})
  variable = add_variable(node, options)
  Graph.get_default_graph.add_op!(:variable_v2, var_name: variable.name, shape: options[:shape], data_type: options[:data_type])
end
as_default() { |self| ... } click to toggle source
# File lib/tensor_stream/graph.rb, line 36
# Makes this graph the calling thread's default graph while the given block
# runs, then reinstates the previous default. The previous default is
# restored even when the block raises.
#
# @yieldparam graph [Graph] this graph
# @return [Graph] self
def as_default
  Thread.current[:tensor_stream_current_graph_queue] ||= []
  Thread.current[:tensor_stream_current_graph_queue] << Graph.get_default_graph

  Thread.current[:tensor_stream_current_graph] = self
  begin
    yield(self) if block_given?
  ensure
    # Always pop, so an exception in the block cannot leave this graph
    # installed as the default or leak an entry on the saved-graph stack.
    Thread.current[:tensor_stream_current_graph] = Thread.current[:tensor_stream_current_graph_queue].pop
  end
  self
end
as_graph_def() click to toggle source
# File lib/tensor_stream/graph.rb, line 266
# Serializes this graph to text with TensorStream::Pbtext and returns the result.
def as_graph_def
  serializer = TensorStream::Pbtext.new
  serializer.get_string(self)
end
control_dependencies(control_inputs = []) { || ... } click to toggle source
# File lib/tensor_stream/graph.rb, line 188
# Runs the block with a :no_op, built from +control_inputs+, pushed onto this
# graph's per-thread control-dependency stack (read back by
# get_dependency_scope). The entry is popped when the block finishes or raises.
# Returns the block's value.
def control_dependencies(control_inputs = [])
  storage_key = "ts_graph_#{object_id}"
  graph_storage = (Thread.current[storage_key] ||= {})
  dependency_stack = (graph_storage[:control_dependencies] ||= [])
  dependency_stack << Graph.get_default_graph.add_op!(:no_op, *control_inputs)
  begin
    yield
  ensure
    Thread.current[storage_key][:control_dependencies].pop
  end
end
device(device_name) { || ... } click to toggle source

Runs the given block with `device_name` as this graph's default device; nodes added to the graph inside the block are assigned that device.

# File lib/tensor_stream/graph.rb, line 60
# Runs the block with +device_name+ pushed onto this graph's per-thread device
# stack (read by get_device_scope). The entry is popped afterwards, even on
# error. Returns the block's value.
def device(device_name)
  key = "ts_graph_#{object_id}"
  devices = ((Thread.current[key] ||= {})[:default_device] ||= [])
  devices.push(device_name)
  begin
    yield
  ensure
    Thread.current[key][:default_device].pop
  end
end
disable_eager_execution() click to toggle source
# File lib/tensor_stream/graph.rb, line 203
# Turns eager execution off for this graph. Returns false.
def disable_eager_execution
  @eager_execution = false
end
enable_eager_execution() click to toggle source
# File lib/tensor_stream/graph.rb, line 199
# Turns eager execution on for this graph (add_node then rejects
# Placeholders). Returns true.
def enable_eager_execution
  @eager_execution = true
end
executing_eagerly?() click to toggle source
# File lib/tensor_stream/graph.rb, line 207
# Reports the raw eager-execution flag, exactly as stored in @eager_execution.
def executing_eagerly?
  @eager_execution
end
get_collection(name, _options = {}) click to toggle source
# File lib/tensor_stream/graph.rb, line 79
# Looks up the collection named +name+ (String or Symbol). Returns whatever
# Hash#[] on @collections yields for that Symbol key. +_options+ is accepted
# for API compatibility and ignored.
def get_collection(name, _options = {})
  key = name.to_sym
  @collections[key]
end
get_const_counter() click to toggle source
# File lib/tensor_stream/graph.rb, line 238
# Returns the name suffix for the next constant and advances the counter:
# "" on the first call, then "_1", "_2", ...
def get_const_counter
  current = (@const_counter ||= 0)
  @const_counter = current + 1
  current.zero? ? "" : "_#{current}"
end
get_dependency_scope() click to toggle source
# File lib/tensor_stream/graph.rb, line 254
# Returns the innermost control-dependency op pushed by control_dependencies
# for this graph on the current thread, or nil when there is none.
def get_dependency_scope
  graph_storage = Thread.current["ts_graph_#{object_id}"]
  dependency_stack = graph_storage && graph_storage[:control_dependencies]
  dependency_stack&.last
end
get_device_scope() click to toggle source
# File lib/tensor_stream/graph.rb, line 260
# Returns the innermost device pushed by #device for this graph on the
# current thread. Falls back to :default when no device block is active.
#
# The stack is left in place (empty) after a device block exits. It is
# therefore treated the same as a missing stack; without this, nodes added
# after a device block would get a nil device instead of :default.
def get_device_scope
  graph_thread_storage = Thread.current["ts_graph_#{object_id}"]
  devices = graph_thread_storage && graph_thread_storage[:default_device]
  return :default if devices.nil? || devices.empty?

  devices.last
end
get_name_scope() click to toggle source
# File lib/tensor_stream/graph.rb, line 247
# Returns the active name scopes for this graph on the current thread,
# joined with "/". Returns nil when name_scope has never been entered here,
# and "" once all scopes have been exited.
def get_name_scope
  graph_storage = Thread.current["ts_graph_#{object_id}"]
  scopes = graph_storage && graph_storage[:current_scope]
  scopes&.join("/")
end
get_node(name) click to toggle source
# File lib/tensor_stream/graph.rb, line 112
# Fetches the node registered under +name+ via Hash#[] on @nodes
# (no error for unknown names; see get_tensor_by_name for the strict variant).
def get_node(name)
  @nodes[name]
end
get_operation_counter() click to toggle source
# File lib/tensor_stream/graph.rb, line 211
# Returns the name suffix for the next operation and advances the counter:
# "" on the first call, then "_1", "_2", ...
def get_operation_counter
  index = @op_counter || 0
  @op_counter = index + 1
  return "" if index.zero?

  "_#{index}"
end
get_placeholder_counter() click to toggle source
# File lib/tensor_stream/graph.rb, line 221
# Advances the placeholder counter and returns the suffix for the new
# placeholder: "" on the first call, then "_2", "_3", ... (there is no "_1").
def get_placeholder_counter
  @placeholder_counter = (@placeholder_counter || 0) + 1
  @placeholder_counter == 1 ? "" : "_#{@placeholder_counter}"
end
get_tensor_by_name(name) click to toggle source
# File lib/tensor_stream/graph.rb, line 116
# Strict node lookup: returns get_node(name) when +name+ is registered.
#
# @raise [TensorStream::KeyError] when no node with that name exists
def get_tensor_by_name(name)
  return get_node(name) if @nodes.key?(name)

  raise TensorStream::KeyError, "#{name} not found"
end
get_var_counter() click to toggle source
# File lib/tensor_stream/graph.rb, line 230
# Advances the variable counter and returns the suffix for the new variable:
# "" on the first call, then "_2", "_3", ... (there is no "_1").
def get_var_counter
  @var_counter ||= 0
  count = (@var_counter += 1)
  count == 1 ? "" : "_#{count}"
end
graph_def_versions() click to toggle source
# File lib/tensor_stream/graph.rb, line 275
# Fixed version line for serialized graph definitions.
def graph_def_versions
  'producer: 26'
end
name_scope(name = nil) { |get_name_scope| ... } click to toggle source
# File lib/tensor_stream/graph.rb, line 46
# Pushes +name+ onto this graph's per-thread name-scope stack and, if a block
# is given, yields the resulting get_name_scope string to it. The name is
# popped afterwards, even on error. Returns the block's value, or nil with no
# block.
def name_scope(name = nil)
  key = "ts_graph_#{object_id}"
  scope_stack = ((Thread.current[key] ||= {})[:current_scope] ||= [])
  scope_stack.push(name)

  begin
    block_given? ? yield(get_name_scope) : nil
  ensure
    Thread.current[key][:current_scope].pop
  end
end
node_added?(name) click to toggle source
# File lib/tensor_stream/graph.rb, line 108
# True when a node is registered under +name+ in @nodes.
def node_added?(name)
  @nodes.include?(name)
end
reset() click to toggle source
# File lib/tensor_stream/graph.rb, line 20
# Returns the graph to an empty state. This zeroes all naming counters,
# clears the random seed, nodes, node keys and constants, and recreates empty
# global/trainable collections. It then asks TensorStream::Evaluator to drop
# any storage kept for this graph.
# The eager-execution flag is left unchanged.
def reset
  @placeholder_counter = @const_counter = @var_counter = @op_counter = 0
  @random_seed = nil
  @nodes = {}
  @node_keys = []
  @constants = {}
  @collections = [GraphKeys::GLOBAL_VARIABLES, GraphKeys::TRAINABLE_VARIABLES].map { |key| [:"#{key}", []] }.to_h
  TensorStream::Evaluator.clear_storages(self)
end
set_operation_name(op) click to toggle source
# File lib/tensor_stream/graph.rb, line 160
# Default (unscoped) name for +op+: its operation type converted with to_s.
def set_operation_name(op)
  operation_type = op.operation
  operation_type.to_s
end

Protected Instance Methods

_variable_scope() click to toggle source
# File lib/tensor_stream/graph.rb, line 281
# Returns the innermost variable scope on the current thread. When none is
# active, returns a fresh root VariableScope (empty name, reuse off, no
# initializer).
def _variable_scope
  scope_stack = Thread.current[:tensor_stream_variable_scope]
  if scope_stack.nil? || scope_stack.empty?
    VariableScope.new(name: "", reuse: false, initializer: nil)
  else
    scope_stack.last
  end
end
uniqunify(name) click to toggle source
# File lib/tensor_stream/graph.rb, line 287
# Returns the first name of the form "<name>_<n>" (n = 1, 2, ...) that is not
# yet a key in @nodes. The bare +name+ is never returned.
def uniqunify(name)
  suffix = 1
  suffix += 1 while @nodes.key?("#{name}_#{suffix}")
  "#{name}_#{suffix}"
end