class TensorStream::Train::Saver
High-level class used for saving variables to checkpoint files and loading (restoring) them.
Public Class Methods
new(var_list = nil)
click to toggle source
# File lib/tensor_stream/train/saver.rb, line 10
# Builds the save and restore ops for a set of variables.
#
# @param var_list [Array, nil] variables to checkpoint; when nil, the default
#   graph's GraphKeys::GLOBAL_VARIABLES collection is used.
def initialize(var_list = nil)
  default_graph = TensorStream::Graph.get_default_graph
  variables = var_list || default_graph.get_collection(GraphKeys::GLOBAL_VARIABLES)

  # Reuse an existing "ts_filename" node from the graph if one is present;
  # otherwise create a scalar string placeholder under that name.
  @filename = default_graph["ts_filename"] ||
              TensorStream.placeholder(:string, name: "ts_filename", shape: [])

  # The save op receives the variables themselves; the restore op receives their names.
  @save_op = _op(:save_ts, @filename, *variables)
  @restore_op = _op(:restore_ts, @filename, *variables.map(&:name))
end
Public Instance Methods
restore(session, modelpath)
click to toggle source
# File lib/tensor_stream/train/saver.rb, line 46
# Restores variable values from a checkpoint directory written by #save.
#
# Reads "model.meta" (JSON) to find the recorded global step ("gs"). It then
# runs the restore op against "model-<gs>-.ckpt", or "model-.ckpt" when gs is null.
#
# @param session [Object] session responding to #run(op, feed_dict:)
# @param modelpath [String] directory containing the checkpoint files
# @return [Object, nil] the result of session.run, or nil when no meta file exists
def restore(session, modelpath)
  meta_path = File.join(modelpath, "model.meta")
  return unless File.exist?(meta_path)

  step = JSON.parse(File.read(meta_path))["gs"]
  # ".ckpt" is joined with "-" too, so names look like "model-5-.ckpt" (matches #save).
  name_parts = ["model", step].compact.push(".ckpt")
  checkpoint = File.join(modelpath, name_parts.join("-"))
  session.run(@restore_op, feed_dict: {@filename => checkpoint})
end
save(session, outputdir, global_step: nil, latest_filename: nil, meta_graph_suffix: "meta", write_meta_graph: true, write_state: true, strip_default_attrs: false)
click to toggle source
# File lib/tensor_stream/train/saver.rb, line 20
# Writes a checkpoint of the saver's variables into +outputdir+.
#
# Files written:
# * "model-<gs>-.ckpt" (or "model-.ckpt" when no global step is given), via the save op
# * "model.meta": JSON {"gs" => gs}, which #restore reads to locate the checkpoint
# * "model.yaml": the serialized graph, only when +write_meta_graph+ is true
#
# @param session [Object] session responding to #run(op, feed_dict:)
# @param outputdir [String] target directory (created if missing)
# @param global_step [Object, nil] step tag; resolved through #eval_global_step
# @param latest_filename [Object] accepted for API compatibility; currently ignored
# @param meta_graph_suffix [String] accepted for API compatibility; currently ignored
#   (the step file is always "model.meta", the name #restore reads)
# @param write_meta_graph [Boolean] also serialize the graph as YAML
# @param write_state [Boolean] accepted for API compatibility; currently ignored
# @param strip_default_attrs [Boolean] accepted for API compatibility; currently ignored
# @return [String] +outputdir+
def save(session, outputdir, global_step: nil, latest_filename: nil, meta_graph_suffix: "meta",
         write_meta_graph: true, write_state: true, strip_default_attrs: false)
  graph = TensorStream::Graph.get_default_graph
  gs = eval_global_step(session, global_step)
  FileUtils.mkdir_p(outputdir)

  basename = "model"
  checkpoint_path = File.join(outputdir, [basename, gs, ".ckpt"].compact.join("-"))
  session.run(@save_op, feed_dict: {@filename => checkpoint_path})

  # Record the step only after the checkpoint op has succeeded. #restore trusts
  # model.meta to name a checkpoint; writing it first would leave it pointing
  # at a file that a failed save never produced.
  File.write(File.join(outputdir, "#{basename}.meta"), {"gs" => gs}.to_json)

  if write_meta_graph
    TensorStream.train.write_graph(graph, outputdir, "#{basename}.yaml", serializer: :yaml)
  end
  outputdir
end
Private Instance Methods
_add_saveable(saveables, seen_ops, saveable)
click to toggle source
# File lib/tensor_stream/train/saver.rb, line 80
# Appends +saveable+ to +saveables+ and records its op in +seen_ops+.
#
# @raise [TensorStream::ValueError] if the saveable's op was already recorded,
#   i.e. the same op would be restored under two names
# @return [Array] +seen_ops+
def _add_saveable(saveables, seen_ops, saveable)
  op = saveable.op
  if seen_ops.include?(op)
    raise TensorStream::ValueError, "The same saveable will be restored with two names: #{saveable.name}"
  end

  saveables.push(saveable)
  seen_ops.push(op)
end
_validate_and_slice_inputs(names_to_saveables)
click to toggle source
# File lib/tensor_stream/train/saver.rb, line 68
# Expands each entry of +names_to_saveables+ into saveable objects and
# collects them. #_add_saveable rejects ops that appear more than once.
#
# NOTE(review): this iterates the hash VALUES sorted by their first element,
# then destructures each value into (name, op). That only works if the values
# are themselves [name, op] pairs. If the intent was to sort the hash by key
# and iterate (key, value) pairs, `.values` is a bug — confirm against callers.
#
# @return [Array] the collected saveables
def _validate_and_slice_inputs(names_to_saveables)
  seen_ops = []
  ordered = names_to_saveables.values.sort_by { |entry| entry[0] }
  ordered.each_with_object([]) do |(name, op), saveables|
    _saveable_objects_for_op(op, name).each do |saveable|
      _add_saveable(saveables, seen_ops, saveable)
    end
  end
end
build_internal(names_to_saveables, reshape: false, sharded: false, max_to_keep: 5, keep_checkpoint_every_n_hours: 10000.0, name: nil, restore_sequentially: false, filename: "model", build_save: true, build_restore: true)
click to toggle source
# File lib/tensor_stream/train/saver.rb, line 58
# Partial builder for save/restore ops.
#
# It currently only validates +names_to_saveables+ and returns the collected
# saveables. Every keyword argument is accepted but not yet used.
#
# @return [Array] result of #_validate_and_slice_inputs
def build_internal(names_to_saveables, reshape: false, sharded: false, max_to_keep: 5,
                   keep_checkpoint_every_n_hours: 10_000.0, name: nil, restore_sequentially: false,
                   filename: "model", build_save: true, build_restore: true)
  _validate_and_slice_inputs(names_to_saveables)
end
eval_global_step(session, global_step)
click to toggle source
# File lib/tensor_stream/train/saver.rb, line 102
# Resolves +global_step+ to the value used to tag a checkpoint.
#
# * nil              -> nil
# * Tensor           -> session.last_session_context(tensor.name)
# * String / Symbol  -> session.last_session_context(global_step)
# * anything else    -> global_step.to_i
def eval_global_step(session, global_step)
  return nil if global_step.nil?

  case global_step
  when Tensor then session.last_session_context(global_step.name)
  when String, Symbol then session.last_session_context(global_step)
  else global_step.to_i
  end
end
save_op(filename_tensor, saveables)
click to toggle source
# File lib/tensor_stream/train/saver.rb, line 87
# Builds a :save_ts op that writes every spec tensor of +saveables+ to
# +filename_tensor+.
#
# The previous version also accumulated each spec's name and slice_spec. Those
# arrays were never passed to the op, so they were dropped as dead code. Only
# the tensors, in saveable/spec order, reach i_op.
#
# @param filename_tensor [Object] tensor holding the checkpoint path
# @param saveables [Array] objects whose #specs each respond to #tensor
# @return [Object] the op built by i_op
def save_op(filename_tensor, saveables)
  tensors = saveables.flat_map { |saveable| saveable.specs.map(&:tensor) }
  i_op(:save_ts, filename_tensor, *tensors)
end