class TensorStream::GraphBuilder

Attributes

graph [RW] (read/write attribute)

Public Class Methods

new(graph)
# File lib/tensor_stream/graph_builder.rb, line 8
# Creates a builder bound to a single target graph.
#
# The graph is only stored here; #build later attaches the deserialized
# nodes to it and returns it.
#
# @param graph the graph to populate (presumably a TensorStream::Graph —
#   confirm against callers).
def initialize(graph)
  @graph = graph
end

Public Instance Methods

build(buffer)
# File lib/tensor_stream/graph_builder.rb, line 12
# Rebuilds the nodes described by a serialized graph definition into @graph.
#
# Parses +buffer+ with TensorStream::Protobuf and handles each parsed entry
# whose "type" is "node". Entries of any other type are skipped.
# - "Const" nodes become TensorStream::Constant.
# - "VariableV2" nodes go through #i_var.
# - "Placeholder" nodes become TensorStream::Placeholder.
# - Every other op is added to @graph via add_op!, with its inputs
#   resolved by name from @graph.
#
# @param buffer [String] serialized graph text handed to
#   TensorStream::Protobuf#load_from_string (format defined there).
# @return the graph given to #initialize, now populated.
# @raise [RuntimeError] if an op input names a tensor @graph cannot resolve.
def build(buffer)
  protobuf = TensorStream::Protobuf.new
  parsed_tree = protobuf.load_from_string(buffer)
  parsed_tree.each do |node|
    next unless node["type"] == "node"

    options = protobuf.options_evaluator(node)
    options[:name] = node["name"]
    # NOTE(review): the Constant/Placeholder/i_var results below are not kept
    # here; presumably they register themselves with the graph passed as
    # :__graph — confirm.
    options[:__graph] = @graph
    value = options.delete("value")
    options = symbolize_keys(options)
    case node["op"]
    when "Const"
      dimension = shape_eval(value)
      rank = dimension.size
      options[:value] = value
      options[:const] = true
      TensorStream::Constant.new(options[:dtype] || options[:T], rank, dimension, options)
    when "VariableV2"
      shape = options[:shape]
      i_var(options[:dtype] || options[:T], nil, shape, nil, options)
    when "Placeholder"
      shape = options[:shape]
      TensorStream::Placeholder.new(options[:dtype] || options[:T], nil, shape, options)
    else
      op = underscore(node["op"]).to_sym
      warn "warning unsupported op #{op}" unless TensorStream::Evaluator::RubyEvaluator.ops.key?(op)

      # A node without an "input" entry simply has no inputs.
      inputs = Array(node["input"]).map { |input| lookup_input_tensor(input) }

      options[:data_type] = options.delete(:T)
      # Add to the graph being built (where the inputs were resolved from),
      # not to the process-wide default graph.
      @graph.add_op!(op, *inputs, options)
    end
  end

  @graph
end

# Resolves one input reference string of a parsed node to a tensor in @graph.
#
# A leading "^" is stripped (presumably a TensorFlow control-input marker).
# "name" and "name:0" both resolve to the tensor named "name". "name:K" with
# K > 0 resolves to that tensor indexed with K.
#
# @param input [String] the input reference; it is not modified.
# @return the resolved tensor.
# @raise [RuntimeError] if the named tensor (or the indexed element) is missing.
private def lookup_input_tensor(input)
  name = input.delete_prefix("^")
  base_name, index = name.split(":")
  tensor = @graph.get_tensor_by_name(base_name)
  tensor = tensor[index.to_i] if tensor && index && index.to_i > 0
  raise "tensor not found by name #{name}" if tensor.nil?

  tensor
end