class TensorStream::Protobuf

A .pb graph deserializer

Constants

UNESCAPES

Public Class Methods

new() click to toggle source
# File lib/tensor_stream/graph_deserializers/protobuf.rb, line 6
def initialize
end

Public Instance Methods

evaluate_tensor_node(node) click to toggle source
# File lib/tensor_stream/graph_deserializers/protobuf.rb, line 31
def evaluate_tensor_node(node)
  if !node["shape"].empty? && node["tensor_content"]
    content = node["tensor_content"]
    unpacked = eval(%("#{content}"))

    if node["dtype"] == "DT_FLOAT"
      TensorShape.reshape(unpacked.unpack("f*"), node["shape"])
    elsif node["dtype"] == "DT_INT32"
      TensorShape.reshape(unpacked.unpack("l*"), node["shape"])
    elsif node["dtype"] == "DT_STRING"
      node["string_val"]
    else
      raise "unknown dtype #{node["dtype"]}"
    end
  else

    val = if node["dtype"] == "DT_FLOAT"
      node["float_val"] ? node["float_val"].to_f : []
    elsif node["dtype"] == "DT_INT32"
      node["int_val"] ? node["int_val"].to_i : []
    elsif node["dtype"] == "DT_STRING"
      node["string_val"]
    else
      raise "unknown dtype #{node["dtype"]}"
    end

    if node["shape"] == [1]
      [val]
    else
      val
    end
  end
end
load(pbfile) click to toggle source

parsers a protobuf file and spits out a ruby hash

# File lib/tensor_stream/graph_deserializers/protobuf.rb, line 16
def load(pbfile)
  f = File.new(pbfile, "r")
  lines = []
  while !f.eof? && (str = f.readline.strip)
    lines << str
  end
  evaluate_lines(lines)
end
load_from_string(buffer) click to toggle source
# File lib/tensor_stream/graph_deserializers/protobuf.rb, line 9
def load_from_string(buffer)
  evaluate_lines(buffer.split("\n").map(&:strip))
end
map_type_to_ts(attr_value) click to toggle source
# File lib/tensor_stream/graph_deserializers/protobuf.rb, line 65
def map_type_to_ts(attr_value)
  case attr_value
  when "DT_FLOAT"
    :float32
  when "DT_INT32"
    :int32
  when "DT_INT64"
    :int64
  when "DT_STRING"
    :string
  when "DT_BOOL"
    :boolean
  else
    raise "unknown type #{attr_value}"
  end
end
options_evaluator(node) click to toggle source
# File lib/tensor_stream/graph_deserializers/protobuf.rb, line 82
def options_evaluator(node)
  return {} if node["attributes"].nil?

  node["attributes"].map { |attribute|
    attr_type, attr_value = attribute["value"].flat_map { |k, v| [k, v] }

    if attr_type == "tensor"
      attr_value = evaluate_tensor_node(attr_value)
    elsif attr_type == "type"
      attr_value = map_type_to_ts(attr_value)
    elsif attr_type == "b"
      attr_value = attr_value == "true"
    end

    [attribute["key"], attr_value]
  }.to_h
end
parse_value(value_node) click to toggle source
# File lib/tensor_stream/graph_deserializers/protobuf.rb, line 25
def parse_value(value_node)
  return unless value_node["tensor"]

  evaluate_tensor_node(value_node["tensor"])
end

Protected Instance Methods

evaluate_lines(lines = []) click to toggle source
# File lib/tensor_stream/graph_deserializers/protobuf.rb, line 102
def evaluate_lines(lines = [])
  block = []
  node = {}
  node_attr = {}
  state = :top

  lines.each do |str|
    case state
    when :top
      node["type"] = parse_node_name(str)
      state = :node_context
      next
    when :node_context
      if str == "attr {"
        state = :attr_context
        node_attr = {}
        node["attributes"] ||= []
        node["attributes"] << node_attr
        next
      elsif str == "}"
        state = :top
        block << node
        node = {}
        next
      else
        key, value = str.split(":", 2)
        if key == "input"
          node["input"] ||= []
          node["input"] << process_value(value.strip)
        else
          node[key] = process_value(value.strip)
        end
      end
    when :attr_context
      if str == "value {"
        state = :value_context
        node_attr["value"] = {}
        next
      elsif str == "}"
        state = :node_context
        next
      else
        key, value = str.split(":", 2)
        node_attr[key] = process_value(value.strip)
      end
    when :value_context
      if str == "list {"
        state = :list_context
        node_attr["value"] = []
        next
      elsif str == "shape {"
        state = :shape_context
        node_attr["value"]["shape"] = []
        next
      elsif str == "tensor {"
        state = :tensor_context
        node_attr["value"]["tensor"] = {}
        next
      elsif str == "}"
        state = :attr_context
        next
      else
        key, value = str.split(":", 2)
        if key == "dtype"
          node_attr["value"]["dtype"] = value.strip
        elsif key == "type"
          node_attr["value"]["type"] = value.strip
        else
          node_attr["value"][key] = process_value(value.strip)
        end
      end
    when :list_context
      if str == "}"
        state = :value_context
        next
      else
        key, value = str.split(":", 2)
        node_attr["value"] << {key => value}
      end
    when :tensor_context
      if str == "tensor_shape {"
        state = :tensor_shape_context
        node_attr["value"]["tensor"]["shape"] = []
        next
      elsif str == "}"
        state = :value_context
        next
      else
        key, value = str.split(":", 2)
        if node_attr["value"]["tensor"][key] && !node_attr["value"]["tensor"][key].is_a?(Array)
          node_attr["value"]["tensor"][key] = [node_attr["value"]["tensor"][key]]
          node_attr["value"]["tensor"][key] << process_value(value.strip)
        elsif node_attr["value"]["tensor"][key]
          node_attr["value"]["tensor"][key] << process_value(value.strip)
        else
          node_attr["value"]["tensor"][key] = process_value(value.strip)
        end
      end
    when :tensor_shape_context
      if str == "dim {"
        state = :tensor_shape_dim_context
        next
      elsif str == "}"
        state = :tensor_context
        next
      end
    when :shape_context
      if str == "}"
        state = :value_context
        next
      elsif str == "dim {"
        state = :shape_dim_context
        next
      end
    when :shape_dim_context
      if str == "}"
        state = :shape_context
        next
      else
        _key, value = str.split(":", 2)
        node_attr["value"]["shape"] << value.strip.to_i
      end
    when :tensor_shape_dim_context
      if str == "}"
        state = :tensor_shape_context
        next
      else
        _key, value = str.split(":", 2)
        node_attr["value"]["tensor"]["shape"] << value.strip.to_i
      end
    end
  end

  block
end
parse_node_name(str) click to toggle source
# File lib/tensor_stream/graph_deserializers/protobuf.rb, line 238
def parse_node_name(str)
  str.split(" ")[0]
end
process_value(value) click to toggle source
# File lib/tensor_stream/graph_deserializers/protobuf.rb, line 242
def process_value(value)
  if value.start_with?('"')
    unescape(value.gsub!(/\A"|"\Z/, ""))
  else
    unescape(value)
  end
end
unescape(str) click to toggle source
# File lib/tensor_stream/graph_deserializers/protobuf.rb, line 257
def unescape(str)
  # Escape all the things
  str.gsub(/\\(?:([#{UNESCAPES.keys.join}])|u([\da-fA-F]{4}))|\\0?x([\da-fA-F]{2})/) do
    if $1
      $1 == '\\' ? '\\' : UNESCAPES[$1]
    elsif $2 # escape \u0000 unicode
      [$2.to_s.hex].pack("U*")
    elsif $3 # escape \0xff or \xff
      [$3].pack("H2")
    end
  end
end