class TensorStream::Train::Saver

High-level class used for loading and saving variables.

Public Class Methods

new(var_list = nil)
# File lib/tensor_stream/train/saver.rb, line 10
# Builds the save and restore ops for +var_list+. When +var_list+ is nil, every
# variable in the default graph's GraphKeys::GLOBAL_VARIABLES collection is used.
#
# Both ops read the checkpoint path from one string placeholder named
# "ts_filename". An existing graph entry with that name is reused; otherwise a new
# placeholder is created.
def initialize(var_list = nil)
  default_graph = TensorStream::Graph.get_default_graph
  variables = var_list || default_graph.get_collection(GraphKeys::GLOBAL_VARIABLES)

  @filename = default_graph["ts_filename"]
  @filename ||= TensorStream.placeholder(:string, name: "ts_filename", shape: [])

  # The save op receives the variables themselves; the restore op receives only their names.
  @save_op = _op(:save_ts, @filename, *variables)
  variable_names = variables.map(&:name)
  @restore_op = _op(:restore_ts, @filename, *variable_names)
end

Public Instance Methods

restore(session, modelpath)
# File lib/tensor_stream/train/saver.rb, line 46
# Loads variable values from the checkpoint recorded in <tt>modelpath/model.meta</tt>.
#
# The meta file is JSON whose "gs" entry is the global step used when saving. The
# checkpoint file name is rebuilt from it the same way #save builds it. If no meta
# file exists, nothing is restored and nil is returned.
#
# @param session session used to run the restore op
# @param modelpath [String] directory previously passed to #save
def restore(session, modelpath)
  meta_path = File.join(modelpath, "model.meta")
  return unless File.exist?(meta_path)

  step = JSON.parse(File.read(meta_path))["gs"]
  name_parts = ["model", step, ".ckpt"].compact
  checkpoint_path = File.join(modelpath, name_parts.join("-"))

  session.run(@restore_op, feed_dict: {@filename => checkpoint_path})
end
save(session, outputdir, global_step: nil, latest_filename: nil, meta_graph_suffix: "meta", write_meta_graph: true, write_state: true, strip_default_attrs: false)
# File lib/tensor_stream/train/saver.rb, line 20
# Writes a checkpoint of the saver's variables into +outputdir+.
#
# Files produced in +outputdir+:
# * the checkpoint, named <tt>model-<gs>-.ckpt</tt> (or <tt>model-.ckpt</tt> when
#   there is no global step), written by running the save op;
# * <tt>model.meta</tt>, JSON <tt>{"gs": gs}</tt> that #restore reads to locate the checkpoint;
# * <tt>model.yaml</tt>, the serialized default graph, when +write_meta_graph+ is true.
#
# @param session session used to run the save op (and to resolve a symbolic +global_step+)
# @param outputdir [String] target directory; created if missing
# @param global_step [nil, Tensor, String, Symbol, #to_i] optional step folded into the
#   checkpoint name (resolved by #eval_global_step)
# @param write_meta_graph [Boolean] also serialize the default graph as YAML
# @param latest_filename, meta_graph_suffix, write_state, strip_default_attrs
#   currently ignored; accepted for interface compatibility
# @return [String] +outputdir+
def save(session, outputdir, global_step: nil,
  latest_filename: nil,
  meta_graph_suffix: "meta",
  write_meta_graph: true,
  write_state: true,
  strip_default_attrs: false)
  graph = TensorStream::Graph.get_default_graph
  gs = eval_global_step(session, global_step)

  FileUtils.mkdir_p(outputdir)
  basename = "model"
  # compact.join("-") yields e.g. "model-5-.ckpt"; #restore rebuilds the same name,
  # so keep the two in sync if this ever changes.
  new_filename = File.join(outputdir, [basename, gs, ".ckpt"].compact.join("-"))
  session.run(@save_op, feed_dict: {@filename => new_filename})

  # Record the step only after the checkpoint has been written, so a failed save
  # never leaves model.meta pointing at a checkpoint that does not exist.
  File.write(File.join(outputdir, "#{basename}.meta"), {"gs" => gs}.to_json)

  if write_meta_graph
    TensorStream.train.write_graph(graph, outputdir, "#{basename}.yaml", serializer: :yaml)
  end
  outputdir
end

Private Instance Methods

_add_saveable(saveables, seen_ops, saveable)
# File lib/tensor_stream/train/saver.rb, line 80
# Appends +saveable+ to +saveables+ and records its op in +seen_ops+.
#
# Raises TensorStream::ValueError if the saveable's op was already recorded, i.e. the
# same op would be restored under two different names. Both arrays are mutated in
# place; the return value is +seen_ops+.
def _add_saveable(saveables, seen_ops, saveable)
  op = saveable.op
  if seen_ops.include?(op)
    raise TensorStream::ValueError, "The same saveable will be restored with two names: #{saveable.name}"
  end

  saveables.push(saveable)
  seen_ops.push(op)
end
_validate_and_slice_inputs(names_to_saveables)
# File lib/tensor_stream/train/saver.rb, line 68
# Expands +names_to_saveables+ into a flat list of saveable objects.
#
# +names_to_saveables+ is a Hash of name => op (an Array of [name, op] pairs also
# works). Entries are processed in name order. Each op is expanded via
# _saveable_objects_for_op, and every result goes through _add_saveable, which
# rejects the same op appearing under two names.
#
# Previously this iterated +names_to_saveables.values+ and destructured each value
# into |name, op|, so +name+ received the op and +op+ was nil. It now iterates the
# (name, op) entries, as the parameter name implies.
#
# @return [Array] the collected saveables
def _validate_and_slice_inputs(names_to_saveables)
  saveables = []
  seen_ops = []

  names_to_saveables.sort_by { |name, _op| name }.each do |name, op|
    _saveable_objects_for_op(op, name).each do |converted_saveable_object|
      _add_saveable(saveables, seen_ops, converted_saveable_object)
    end
  end
  saveables
end
build_internal(names_to_saveables, reshape: false, sharded: false, max_to_keep: 5, keep_checkpoint_every_n_hours: 10000.0, name: nil, restore_sequentially: false, filename: "model", build_save: true, build_restore: true)
# File lib/tensor_stream/train/saver.rb, line 58
# Validates +names_to_saveables+ and returns the resulting saveables
# (see _validate_and_slice_inputs).
#
# NOTE: all keyword arguments are accepted but not yet used by this
# implementation; they are kept so callers can pass them without error.
def build_internal(names_to_saveables, reshape: false, sharded: false, max_to_keep: 5,
  keep_checkpoint_every_n_hours: 10000.0,
  name: nil,
  restore_sequentially: false,
  filename: "model",
  build_save: true,
  build_restore: true)
  _validate_and_slice_inputs(names_to_saveables)
end
eval_global_step(session, global_step)
# File lib/tensor_stream/train/saver.rb, line 102
# Resolves the +global_step+ argument of #save into a concrete value:
#
# * nil            -> nil (no step)
# * Tensor         -> session.last_session_context(tensor.name)
# * String, Symbol -> session.last_session_context(global_step)
# * anything else  -> global_step.to_i
def eval_global_step(session, global_step)
  return if global_step.nil?
  return session.last_session_context(global_step.name) if global_step.is_a?(Tensor)

  named_lookup = global_step.is_a?(String) || global_step.is_a?(Symbol)
  named_lookup ? session.last_session_context(global_step) : global_step.to_i
end
save_op(filename_tensor, saveables)
# File lib/tensor_stream/train/saver.rb, line 87
# Builds a :save_ts op that writes the tensors exposed by +saveables+ to the path
# held in +filename_tensor+.
#
# Each saveable must respond to +specs+, whose entries respond to +tensor+. Tensors
# are passed in saveable order, then spec order within each saveable.
# NOTE(review): spec names and slice specs were previously gathered here but never
# passed to the op, so they are no longer collected.
def save_op(filename_tensor, saveables)
  tensors = saveables.flat_map { |saveable| saveable.specs.map(&:tensor) }
  i_op(:save_ts, filename_tensor, *tensors)
end