class Tensorflow::Graph::OperationDescription

Attributes

graph[R]
name[R]
op_def[R]

Public Class Methods

new(graph, op_type, inputs, attrs) click to toggle source
# File lib/tensorflow/graph/operation_description.rb, line 6
def initialize(graph, op_type, inputs, attrs)
  @graph = graph
  @op_def = case op_type
              when Function
                op_type.function_def.signature
              else
                self.graph.op_def(op_type)
            end
  raise(Error::InvalidArgumentError, "Invalid op type: #{op_type}") unless @op_def

  raw_name = attrs.delete(:name)&.to_s || self.op_def.name
  @name = self.graph.scoped_name(raw_name)
  @pointer = FFI.TF_NewOperation(graph, self.op_def.name, @name)

  inputs = Array(inputs)
  setup_inputs(inputs, attrs)
  setup_control_inputs(graph.control_inputs)
  setup_attrs(**attrs)
end

Public Instance Methods

add_input(operation) click to toggle source
# File lib/tensorflow/graph/operation_description.rb, line 174
def add_input(operation)
  # Check to see if the operation has multiple outputs, and if it does, we need to pack them together
  # to fit into one input
  if operation.is_a?(OperationOutput)
    FFI.TF_AddInput(self, operation)
  elsif operation.num_outputs > 1
    packed = Tensorflow.pack(operation, n: operation.num_outputs)
    FFI.TF_AddInput(self, packed.outputs.first)
  else
    FFI.TF_AddInput(self, operation.outputs.first)
  end
end
add_input_list(operations) click to toggle source
# File lib/tensorflow/graph/operation_description.rb, line 187
def add_input_list(operations)
  # Operation can represent multiple operations *or* one operation with multiple outputs (like SPLIT)
  outputs = Array(operations).map(&:outputs).flatten
  outputs_ptr = FFI::Output.array_to_ptr(outputs.map(&:output))
  FFI.TF_AddInputList(self, outputs_ptr, outputs.length)
end
capture(operation) click to toggle source
# File lib/tensorflow/graph/operation_description.rb, line 112
def capture(operation)
  if self.op_def.is_stateful
    raise(Error::InvalidArgumentError, "Cannot capture a stateful node (name: #{operation.name}, type: #{operation.op_type})")
  elsif operation.op_type == "Placeholder"
    raise(Error::InvalidArgumentError, "Cannot capture a placeholder by value (name: #{operation.name}, type: #{operation.op_type})")
  end

  attrs = operation.attributes.reduce(Hash.new) do |hash, attr|
    hash[attr.name.to_sym] = attr.value
    hash
  end
  attrs[:name] = operation.name

  captured_inputs = self.capture_inputs(operation, attrs)
  self.graph.create_operation(operation.op_type, captured_inputs, **attrs)
end
capture_inputs(operation, attrs) click to toggle source
# File lib/tensorflow/graph/operation_description.rb, line 79
def capture_inputs(operation, attrs)
  # First capture the inputs
  inputs = operation.inputs.map do |input|
    self.capture(input.operation)
  end

  # We now have to group the inputs together. For example, a TensorSlice dataset has 1 input argument
  # which a list. But the number of inputs returned by the operation is actually the number of items in
  # the list, so its usually more than one. We need to group them into one array to be able to call
  # the operation to create a captured copy.
  i = 0
  operation.op_def.input_arg.reduce(Array.new) do |result, input_arg|
    if !input_arg.number_attr.empty?
      input_len = attrs[input_arg.number_attr.to_sym]
      is_sequence = true
    elsif !input_arg.type_list_attr.empty?
      input_len = attrs[input_arg.type_list_attr.to_sym].length
      is_sequence = true
    else
      input_len = 1
      is_sequence = false
    end

    if is_sequence
      result << inputs[i..i+input_len]
    else
      result << inputs[i]
    end
    i += input_len
    result
  end
end
check_input(arg_def, input, dtype) click to toggle source
# File lib/tensorflow/graph/operation_description.rb, line 129
def check_input(arg_def, input, dtype)
  case input
    when Operation
      self.graph.equal?(input.graph) ? input : capture(input)
    when OperationOutput
      input
    when Variable
      arg_def.type == :DT_RESOURCE ? input.handle : input.value_handle
    else
      input_name = "#{self.name}/#{arg_def.name}"
      Tensorflow.constant(input, name: input_name, dtype: dtype)
  end
end
device=(value) click to toggle source
# File lib/tensorflow/graph/operation_description.rb, line 56
def device=(value)
  FFI.TF_SetDevice(self, value)
end
figure_dtype(attrs, inputs) click to toggle source
# File lib/tensorflow/graph/operation_description.rb, line 26
def figure_dtype(attrs, inputs)
  attr_def = self.op_def.attr.detect do |attr_def|
    attr_def.type == 'type'
  end

  result = attr_def ? attrs[attr_def.name.to_sym] : nil
  unless result
    inputs.each do |input|
      case input
        when Operation
          return input.output_types.first
        when Variable
          return input.dtype
      end
    end
  end
  result
end
save() click to toggle source
# File lib/tensorflow/graph/operation_description.rb, line 49
def save
  Status.check do |status|
    ptr = FFI.TF_FinishOperation(self, status)
    Operation.new(self.graph, ptr)
  end
end
setup_attr(name, value) click to toggle source
# File lib/tensorflow/graph/operation_description.rb, line 200
def setup_attr(name, value)
  attr_def = self.op_def.attr.detect do |attr_def|
    name.to_s == attr_def.name
  end
  unless attr_def
    raise(Error::UnknownError, "Unknown attribute: #{name}")
  end

  case attr_def.type
    when 'bool'
      FFI.TF_SetAttrBool(self, attr_def.name, value ? 1 : 0)
    when 'int'
      FFI.TF_SetAttrInt(self, attr_def.name, value)
    when 'float'
      FFI.TF_SetAttrFloat(self, attr_def.name, value)
    when 'func'
      function_name = value.is_a?(Function) ? value.name : value
      FFI.TF_SetAttrFuncName(self, attr_def.name, function_name, function_name.length)
    when 'shape'
      pointer = ::FFI::MemoryPointer.new(:int64, value.length)
      pointer.write_array_of_int64(value)
      FFI.TF_SetAttrShape(self, attr_def.name, pointer, value.length)
    when 'list(shape)'
      dims_pointer = ::FFI::MemoryPointer.new(:pointer, value.length)
      num_dims_pointer = ::FFI::MemoryPointer.new(:int32, value.length)
      value.each_with_index do |shape, i|
        dim_pointer = ::FFI::MemoryPointer.new(:int64, shape.length)
        dim_pointer.write_array_of_int64(shape)
        dims_pointer.put_pointer(i * ::FFI.type_size(:pointer), dim_pointer)
        num_dims_pointer.put_int32(i * ::FFI.type_size(:int32), shape.length)
      end
      FFI.TF_SetAttrShapeList(self, attr_def.name, dims_pointer, num_dims_pointer, value.length)
    when 'string'
      FFI.TF_SetAttrString(self, attr_def.name, value, value.length)
    when 'list(string)'
      a = 1
      #FFI.TF_SetAttrString(self, attr_def.name, value, value.length)
    when 'tensor'
      Status.check do |status|
        FFI.TF_SetAttrTensor(self, attr_def.name, value, status)
      end
    when 'type'
      FFI.TF_SetAttrType(self, attr_def.name, value)
    when 'list(type)'
      value_ptr = ::FFI::MemoryPointer.new(FFI::DataType.native_type.size, value.count)
      value.each_with_index do |a_value, i|
        value_ptr.put_int32(i * FFI::DataType.native_type.size, FFI::DataType[a_value])
      end
      FFI.TF_SetAttrTypeList(self, attr_def.name, value_ptr, value.count)
    else
      raise(Error::UnimplementedError, "Unsupported attribute. #{self.op_def.name} - #{attr_def.name}")
  end
end
setup_attrs(**attrs) click to toggle source
# File lib/tensorflow/graph/operation_description.rb, line 194
def setup_attrs(**attrs)
  attrs.each do |attr_name, attr_value|
    self.setup_attr(attr_name, attr_value)
  end
end
setup_control_input(control_input) click to toggle source
# File lib/tensorflow/graph/operation_description.rb, line 66
def setup_control_input(control_input)
  control_input = case control_input
                    when Operation
                      control_input
                    when Variable
                      control_input.handle
                    else
                      raise(Error::InvalidArgumentError, "Invalid control input")
                    end

  FFI.TF_AddControlInput(self, control_input)
end
setup_control_inputs(control_inputs) click to toggle source
# File lib/tensorflow/graph/operation_description.rb, line 60
def setup_control_inputs(control_inputs)
  control_inputs.each do |control_input|
    setup_control_input(control_input)
  end
end
setup_input(index, value, attrs) click to toggle source
# File lib/tensorflow/graph/operation_description.rb, line 149
def setup_input(index, value, attrs)
  arg_def = self.op_def.input_arg[index]
  dtype = attrs[arg_def.type_attr.to_sym]

  # Value can be an operation with multiple outputs. For example calling PACK with an input operation of SPLIT
  checked_value = if (!arg_def.number_attr.empty? || !arg_def.type_list_attr.empty?)  && value.is_a?(Array)
                    value.map do |sub_value|
                      self.check_input(arg_def, sub_value, dtype)
                    end
                  else
                    self.check_input(arg_def, value, dtype)
                  end

  if !arg_def.type_list_attr.empty?
    # This input is a heterogeneous list
    self.add_input_list(checked_value)
  elsif !arg_def.number_attr.empty?
    # This input is a homogeneous list
    self.add_input_list(checked_value)
  else
    # This input is a single item
    self.add_input(checked_value)
  end
end
setup_inputs(inputs, attrs) click to toggle source
# File lib/tensorflow/graph/operation_description.rb, line 143
def setup_inputs(inputs, attrs)
  inputs.each_with_index do |input, index|
    self.setup_input(index, input, attrs)
  end
end
to_ptr() click to toggle source
# File lib/tensorflow/graph/operation_description.rb, line 45
def to_ptr
  @pointer
end