module TensorStream::VariableOps

Collection of machine-learning-related ops for working with variables.

Public Class Methods

included(klass) click to toggle source
# File lib/tensor_stream/evaluator/ruby/variable_ops.rb, line 4
# Module inclusion hook: registers the variable-related ops on +klass+
# through its +register_op+ class method. Each op block receives
# (context, tensor, inputs), where +inputs+ holds the op's resolved inputs.
#
# @param klass [Class] the including class; presumably an evaluator that
#   provides var_read_value / var_assign_value / broadcast / multi_array_op /
#   shape_eval — TODO confirm against callers.
def self.included(klass)
  klass.class_eval do
    # Returns the variable's current value; raises if it was never assigned.
    register_op :variable_v2 do |_context, tensor, _inputs|
      value = var_read_value(tensor)
      raise "variable #{tensor.options[:var_name]} not initialized" if value.nil?

      value
    end

    # Stores inputs[0] as the variable's new value.
    register_op :assign do |_context, tensor, inputs|
      var_assign_value(tensor, inputs[0])
    end

    # variable += inputs[0], elementwise, after passing both operands
    # through broadcast. Raises if the variable has no value yet.
    # NOTE(review): registered with no_eval: true while :assign_sub is not;
    # the reason is not visible here — confirm before unifying the two.
    register_op :assign_add, no_eval: true do |_context, tensor, inputs|
      current_val = var_read_value(tensor)
      raise "variable #{tensor.options[:var_name]} not initialized" if current_val.nil?

      eval_a, eval_b = broadcast(current_val, inputs[0])
      result = multi_array_op(->(var, val) { var + val }, eval_a, eval_b)
      var_assign_value(tensor, result)
    end

    # variable -= inputs[0], elementwise, after passing both operands
    # through broadcast. Raises if the variable has no value yet.
    register_op :assign_sub do |_context, tensor, inputs|
      current_val = var_read_value(tensor)
      raise "variable #{tensor.options[:var_name]} not initialized" if current_val.nil?

      eval_a, eval_b = broadcast(current_val, inputs[0])
      result = multi_array_op(->(var, val) { var - val }, eval_a, eval_b)
      var_assign_value(tensor, result)
    end

    # Writes variables to a YAML checkpoint file and returns nil.
    #
    # inputs[0] is the resolved output filename. tensor.inputs[1..] are the
    # variable tensors themselves (taken un-resolved so their var_name and
    # data_type are available). Each entry is stored under its var_name as
    #   { "shape" => shape_eval(value),
    #     "data"  => strict_base64(deflate(Packer.pack(value, data_type))) }
    # Raises if any variable has no value yet.
    register_op :save_ts do |_context, tensor, inputs|
      output_file = inputs[0]
      savables = tensor.inputs.drop(1) # skip the filename input

      variables = savables.each_with_object({}) do |savable, acc|
        val = var_read_value(savable)
        raise "variable #{savable.options[:var_name]} not initialized" if val.nil?

        packed_data = Zlib::Deflate.deflate(TensorStream::Packer.pack(val, savable.data_type))
        acc[savable.options[:var_name]] = {
          "shape" => shape_eval(val),
          "data" => Base64.strict_encode64(packed_data),
        }
      end

      File.write(output_file, {"variables" => variables}.to_yaml)
      nil
    end

    # Restores variables from a checkpoint in the format written by :save_ts
    # and returns nil.
    #
    # inputs[0] is the checkpoint filename; inputs[1..] are the names of the
    # variables to restore. Only graph GLOBAL_VARIABLES whose name appears both
    # in the file and in that list are overwritten; all others are untouched.
    register_op :restore_ts do |_context, tensor, inputs|
      filename, *tensor_names = inputs

      # Keyword form is required: the positional permitted-classes argument was
      # deprecated in Psych 3.1 and removed in Psych 4 (Ruby 3.1+).
      input_dump = YAML.safe_load(File.read(filename), permitted_classes: [Symbol])
      saved = input_dump["variables"]

      # Non-mutating select: get_collection may return the graph's own
      # collection array, which must not be pruned as a side effect.
      vars = Array(tensor.graph.get_collection(GraphKeys::GLOBAL_VARIABLES))
      vars = vars.select { |v| saved.key?(v.name) && tensor_names.include?(v.name) }

      vars.each do |variable|
        entry = saved[variable.name]
        packed = Zlib::Inflate.inflate(Base64.decode64(entry["data"]))
        data = TensorStream::Packer.unpack(packed, variable.data_type)
        # presumably discards a cached backing buffer so the restored value is
        # picked up — TODO confirm
        variable.buffer = nil
        var_assign_value(variable, TensorShape.reshape(data, entry["shape"]))
      end

      nil
    end
  end
end