class TensorStream::Pbtext
Parses pbtext files and loads it as a graph
Public Instance Methods
get_string(tensor_or_graph, session = nil, graph_keys = nil) { |graph, k| ... }
click to toggle source
# File lib/tensor_stream/graph_serializers/pbtext.rb, line 7 def get_string(tensor_or_graph, session = nil, graph_keys = nil) graph = tensor_or_graph.is_a?(Tensor) ? tensor_or_graph.graph : tensor_or_graph @lines = [] node_keys = graph_keys.nil? ? graph.node_keys : graph.node_keys.select { |k| graph_keys.include?(k) } node_keys.each do |k| node = if block_given? yield graph, k else graph.get_tensor_by_name(k) end @lines << "node {" @lines << " name: #{node.name.to_json}" if node.is_a?(TensorStream::Operation) @lines << " op: #{camelize(node.operation.to_s).to_json}" node.inputs.each do |input| next unless input @lines << " input: #{input.name.to_json}" end # type pb_attr("T", "type: #{sym_to_protobuf_type(node.data_type)}") case node.operation.to_s when "const" pb_attr("value", tensor_value(node)) when "variable_v2" pb_attr("shape", shape_buf(node, "shape")) end process_options(node) end @lines << "}" end @lines << "versions {" @lines << " producer: 26" @lines << "}" @lines.flatten.join("\n") end
Private Instance Methods
attr_value(val, indent = 0)
click to toggle source
# File lib/tensor_stream/graph_serializers/pbtext.rb, line 63 def attr_value(val, indent = 0) spaces = " " * indent case val.class.to_s when "TrueClass", "FalseClass" @lines << "#{spaces}b: #{val}" when "Integer" @lines << "#{spaces}i: #{val}" when "String", @lines << "#{spaces}s: #{val}" when "Float" @lines << "#{spaces}f: #{val}" when "Symbol" @lines << "#{spaces}sym: #{val}" when "Array" @lines << "#{spaces}list {" val.each do |v_item| attr_value(v_item, indent + 2) end @lines << "#{spaces}}" when "TensorStream::TensorShape" @lines << "#{spaces}shape {" val.shape&.each do |dim| @lines << "#{spaces} dim {" @lines << "#{spaces} size: #{dim}" @lines << "#{spaces} }" end @lines << "#{spaces}}" when "TensorStream::Variable" else raise "unknown type #{val.class}" end end
pack_arr_float(float_arr)
click to toggle source
# File lib/tensor_stream/graph_serializers/pbtext.rb, line 96 def pack_arr_float(float_arr) float_arr.flatten.pack("f*").bytes.map { |b| /[^[:print:]]/.match?(b.chr) ? "\\#{sprintf("%o", b).rjust(3, "0")}" : b.chr }.join end
pack_arr_int(int_arr)
click to toggle source
# File lib/tensor_stream/graph_serializers/pbtext.rb, line 100 def pack_arr_int(int_arr) int_arr.flatten.pack("l*").bytes.map { |b| /[^[:print:]]/.match?(b.chr) ? "\\#{sprintf("%o", b).rjust(3, "0")}" : b.chr }.join end
pb_attr(key, value)
click to toggle source
# File lib/tensor_stream/graph_serializers/pbtext.rb, line 170 def pb_attr(key, value) @lines << " attr {" @lines << " key: \"#{key}\"" @lines << " value {" if value.is_a?(Array) value.each do |v| @lines << " #{v}" end else @lines << " #{value}" end @lines << " }" @lines << " }" end
process_options(node)
click to toggle source
# File lib/tensor_stream/graph_serializers/pbtext.rb, line 50 def process_options(node) return if node.options.nil? node.options.reject { |_k, v| v.nil? }.each do |k, v| next if %w[name internal_name data_type].include?(k.to_s) || k.to_s.start_with?("__") @lines << " attr {" @lines << " key: \"#{k}\"" @lines << " value {" attr_value(v, 6) @lines << " }" @lines << " }" end end
shape_buf(tensor, shape_type = "tensor_shape")
click to toggle source
# File lib/tensor_stream/graph_serializers/pbtext.rb, line 104 def shape_buf(tensor, shape_type = "tensor_shape") arr = [] arr << " #{shape_type} {" tensor.shape.shape&.each do |dim| arr << " dim {" arr << " size: #{dim}" arr << " }" end arr << " }" arr end
sym_to_protobuf_type(type)
click to toggle source
# File lib/tensor_stream/graph_serializers/pbtext.rb, line 153 def sym_to_protobuf_type(type) case type when :int32, :int "DT_INT32" when :int16 "DT_INT16" when :float, :float32 "DT_FLOAT" when :float64 "DT_FLOAT64" when :string "DT_STRING" else "UKNOWN" end end
tensor_value(tensor)
click to toggle source
# File lib/tensor_stream/graph_serializers/pbtext.rb, line 116 def tensor_value(tensor) arr = [] arr << "tensor {" arr << " dtype: #{sym_to_protobuf_type(tensor.data_type)}" arr += shape_buf(tensor) if tensor.rank > 0 if TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type) packed = pack_arr_float(tensor.const_value) arr << " tensor_content: \"#{packed}\"" elsif TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type) packed = pack_arr_int(tensor.const_value) arr << " tensor_content: \"#{packed}\"" elsif tensor.data_type == :string tensor.const_value.each do |v| arr << " string_val: #{v.to_json}" end else arr << " tensor_content: #{tensor.const_value.flatten}" end else val_type = if TensorStream::Ops::INTEGER_TYPES.include?(tensor.data_type) "int_val" elsif TensorStream::Ops::FLOATING_POINT_TYPES.include?(tensor.data_type) "float_val" elsif tensor.data_type == :string "string_val" else "val" end arr << " #{val_type}: #{tensor.const_value.to_json}" end arr << "}" arr end