class TensorStream::YamlLoader

Deserializes a TensorStream model graph from YAML, read either from a file or from an in-memory string.

Public Class Methods

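new(graph = nil)

Creates a loader that restores models into the given graph, or into the default graph (TensorStream.get_default_graph) when no graph is given.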
# File lib/tensor_stream/graph_deserializers/yaml_loader.rb, line 5
# Binds the loader to the graph that deserialized models are restored into.
#
# @param graph [Object, nil] target graph; when falsy, the graph returned by
#   TensorStream.get_default_graph is used instead
def initialize(graph = nil)
  target_graph = graph
  target_graph ||= TensorStream.get_default_graph
  @graph = target_graph
end

Public Instance Methods

load_from_file(filename)

Loads a model from a YAML file and rebuilds it in the graph.

Args: filename: String - path to the YAML file

Returns: the Graph into which the model was restored

# File lib/tensor_stream/graph_deserializers/yaml_loader.rb, line 16
# Reads the YAML model stored at +filename+ and restores it via load_from_string.
#
# @param filename [String] path of the YAML file to read
# @return the value of load_from_string (the graph the model was restored into)
def load_from_file(filename)
  yaml_text = File.read(filename)
  load_from_string(yaml_text)
end
load_from_string(buffer)

Loads a model from a YAML-formatted string and rebuilds it in the graph.

Args: buffer: String - the model serialized in YAML format

Returns: the Graph into which the model was restored

# File lib/tensor_stream/graph_deserializers/yaml_loader.rb, line 27
# Deserializes a YAML model description and rebuilds each node in @graph.
#
# The document is expected to be a list of Hashes with symbol keys; the keys
# read below are :name, :op, :inputs and :attrs (attrs may carry :data_type
# and, for variables, :var_name). Each input name is looked up with
# @graph.get_tensor_by_name before the current node is added, so entries
# presumably must appear in dependency order — confirm against the serializer.
#
# An empty document restores nothing and returns the graph unchanged.
#
# @param buffer [String] YAML text describing the model
# @return the graph (@graph) the model was restored into
def load_from_string(buffer)
  parse_serialized_ops(buffer).each do |op_def|
    op_type = op_def[:op].to_sym
    attrs = op_def[:attrs] || {}
    inputs = Array(op_def[:inputs]).map { |i| @graph.get_tensor_by_name(i) }

    # variable_v2 nodes also need a backing Variable registered with the graph.
    new_var = restore_variable(attrs) if op_type == :variable_v2

    # Hand the Operation a copy so it never shares the parsed attrs hash.
    new_op = Operation.new(@graph, inputs: inputs, options: attrs.dup)
    new_op.operation = op_type
    new_op.name = op_def[:name]
    new_op.shape = TensorShape.new(TensorStream::InferShape.infer_shape(new_op))
    new_op.rank = new_op.shape.rank
    new_op.data_type = new_op.set_data_type(attrs[:data_type])
    new_op.is_const = new_op.infer_const
    new_op.given_name = new_op.name
    new_var.op = new_op if new_var

    @graph.add_node(new_op)
  end
  @graph
end

# Parses +buffer+ with YAML.safe_load, permitting Symbol (the ops use symbol
# keys) and aliases. Psych 3.1 introduced keyword arguments and Psych 4
# (Ruby 3.1+) removed the positional form, so the keyword API is used whenever
# safe_load accepts it. Returns [] for an empty document instead of nil/false.
# NOTE(review): aliases are enabled; only load YAML from trusted sources.
private def parse_serialized_ops(buffer)
  ops =
    if YAML.method(:safe_load).parameters.include?([:key, :permitted_classes])
      YAML.safe_load(buffer, permitted_classes: [Symbol], permitted_symbols: [], aliases: true)
    else
      YAML.safe_load(buffer, [Symbol], [], true)
    end
  ops || []
end

# Builds the Variable backing a serialized variable_v2 node and registers it
# with @graph under the serialized :var_name.
private def restore_variable(attrs)
  variable = Variable.new(attrs[:data_type])
  var_options = { name: attrs[:var_name] }
  variable.prepare(nil, nil, TensorStream.get_variable_scope, var_options)
  @graph.add_variable(variable, var_options)
  variable
end