class Tensorflow::Train::Optimizer
Attributes
name[R]
Public Class Methods
new(name: nil, use_locking: false)
click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 7
# Sets up the optimizer's base state.
#
# Args:
#   name: name of the optimizer; the keyword defaults to nil but a value is required
#   use_locking: stored in @use_locking (not read by any code visible here)
#
# Raises:
#   Error::InvalidArgumentError when name is nil or false
def initialize(name: nil, use_locking: false)
  @name = name
  @use_locking = use_locking
  unless name
    raise Error::InvalidArgumentError, "Must specify the optimizer name"
  end

  # Caches filled in lazily by slot_dict and create_non_slot_variable.
  @slots = {}
  @non_slots = {}
end
Public Instance Methods
apply_gradients(grads_and_vars, global_step: nil, name: nil)
click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 28
# Builds the operation that applies each gradient to its variable.
#
# Args:
#   grads_and_vars: enumerable of [gradient, variable] pairs
#   global_step: optional variable; when given, it is incremented by 1 inside a
#     control_dependencies block that depends on the grouped updates
#   name: name handed to finish when global_step is nil (the fixed name
#     "update" is used otherwise)
#
# Returns:
#   The value of finish when global_step is nil; otherwise whatever
#   control_dependencies returns for the increment block
#   (presumably the assign_add operation - TODO confirm).
def apply_gradients(grads_and_vars, global_step: nil, name: nil)
  # NOTE: slot creation (create_slots) and name scoping are commented out in
  # the original source, so the variable list that was collected only for
  # them is no longer built here.
  prepare
  apply_ops = grads_and_vars.map { |grad, var| apply_dense(grad, var) }

  return finish(apply_ops, name) if global_step.nil?

  global_step.handle.graph.control_dependencies([finish(apply_ops, "update")]) do
    global_step.assign_add(Tensorflow.constant(1, dtype: global_step.dtype))
  end
end
compute_gradients(loss, var_list: nil, grad_loss: nil)
click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 49
# Computes the gradients of loss with respect to a list of variables.
#
# Args:
#   loss: value whose gradients are requested
#   var_list: variables to differentiate against; when nil/false the graph's
#     TRAINABLE_VARIABLES collection is used instead
#   grad_loss: forwarded to Gradients#gradients as grad_ys
#
# Returns:
#   The gradients zipped with their variables, i.e. [gradient, variable] pairs.
#
# Raises:
#   Error::InvalidArgumentError when there are no variables to train
def compute_gradients(loss, var_list: nil, grad_loss: nil)
  variables = var_list
  variables ||= graph.get_collection_ref(Tensorflow::Graph::GraphKeys::TRAINABLE_VARIABLES)

  if variables.nil? || variables.empty?
    raise Error::InvalidArgumentError, 'There are no variables to train for the loss function'
  end

  calculator = Graph::Gradients.new(graph)
  calculator.gradients(loss, variables, grad_ys: grad_loss).zip(variables)
end
get_slot(var, name)
click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 61
# Looks up the slot stored for var under the slot name.
#
# Args:
#   var: variable whose slot is wanted (keyed through var_key)
#   name: slot name
#
# Returns:
#   The stored slot, or nil when the name or the variable has no entry.
def get_slot(var, name)
  slots_for_name = @slots.fetch(name, nil)
  # Safe navigation skips var_key entirely when the name is unknown.
  slots_for_name&.fetch(var_key(var), nil)
end
get_slot_names()
click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 68
# Returns the names of all slots registered so far, sorted.
def get_slot_names
  slot_names = @slots.keys
  slot_names.sort
end
graph()
click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 16
# Returns the current execution context, which this class treats as its graph.
def graph
  current_context = ExecutionContext.current
  current_context
end
minimize(loss, var_list: nil, grad_loss: nil, global_step: nil, name: nil)
click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 20
# Computes the gradients of loss and applies them in one call.
#
# Args:
#   loss: value to minimize
#   var_list: forwarded to compute_gradients
#   grad_loss: forwarded to compute_gradients
#   global_step: forwarded to apply_gradients
#   name: forwarded to apply_gradients
#
# Returns:
#   Whatever apply_gradients returns.
#
# Raises:
#   Error::InvalidArgumentError when no variable received a gradient
def minimize(loss, var_list: nil, grad_loss: nil, global_step: nil, name: nil)
  grads_and_vars = compute_gradients(loss, var_list: var_list, grad_loss: grad_loss)

  # compute_gradients yields one pair per variable, so a bare emptiness check
  # cannot detect "no gradients"; look at the gradients themselves. An empty
  # list still raises because none? is true for it.
  if grads_and_vars.none? { |grad, _var| grad }
    raise Error::InvalidArgumentError,
          "No gradients provided for any variable, check your graph for ops that do not support gradients"
  end

  apply_gradients(grads_and_vars, global_step: global_step, name: name)
end
Protected Instance Methods
apply_dense(_grad, _var)
click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 86
# Applies one gradient to one variable. This base version always raises;
# apply_gradients calls it once per [gradient, variable] pair.
#
# Raises:
#   Error::UnimplementedError always
def apply_dense(_grad, _var)
  raise Error::UnimplementedError, "Not implemented"
end
call_if_callable(param)
click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 130
# Resolves a parameter that may be given either directly or as a callable.
#
# Args:
#   param: a plain value, or any object responding to #call (Proc, lambda,
#     Method, ...)
#
# Returns:
#   The result of param.call when param is callable, otherwise param itself.
def call_if_callable(param)
  # Duck-typed so Method objects and other callables work, not only Procs.
  param.respond_to?(:call) ? param.call : param
end
create_non_slot_variable(initial_value, name, colocate_with)
click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 134
# Returns the non-slot variable cached under [name, colocate_with.graph],
# creating and caching a non-trainable one on first use.
#
# Args:
#   initial_value: initial value for a newly created variable
#   name: variable name; also the first half of the cache key
#   colocate_with: object whose #graph forms the second half of the cache key
#
# Returns:
#   The cached or newly created variable.
def create_non_slot_variable(initial_value, name, colocate_with)
  cache_key = [name, colocate_with.graph]
  cached = @non_slots.fetch(cache_key, nil)
  return cached unless cached.nil?

  # NOTE(review): TensorStream is otherwise referenced only in commented-out
  # code in this file - confirm the constant is actually available here.
  @non_slots[cache_key] = TensorStream.variable(initial_value, name: name, trainable: false)
end
create_slots(var_list)
click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 78
# Hook for creating slots for the given variables. The base version does
# nothing and returns nil.
def create_slots(var_list)
  nil
end
finish(update_ops, name_scope)
click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 74
# Groups the per-variable update operations into one operation.
#
# Args:
#   update_ops: operations produced by apply_dense
#   name_scope: passed to Control.group as its name
#
# Returns:
#   Whatever Control.group returns.
def finish(update_ops, name_scope)
  grouped = Control.group(update_ops, name: name_scope)
  grouped
end
get_non_slot_variable(name, graph: nil)
click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 125
# Looks up a previously created non-slot variable.
#
# Args:
#   name: variable name
#   graph: second half of the cache key; with the default nil only entries
#     stored under [name, nil] match
#
# Returns:
#   The cached variable, or nil when nothing is stored under [name, graph].
def get_non_slot_variable(name, graph: nil)
  lookup_key = [name, graph]
  @non_slots.fetch(lookup_key, nil)
end
get_or_make_slot_with_initializer(var, initializer, shape, dtype, slot_name, op_name)
click to toggle source
Find or create a slot for a variable, using an Initializer.
# File lib/tensorflow/train/optimizer.rb, line 148
# Finds or creates a slot for a variable, using an initializer.
#
# Args:
#   var: variable the slot belongs to
#   initializer, shape, dtype, op_name: forwarded unchanged to
#     create_slot_with_initializer when the slot has to be created
#   slot_name: name of the slot group to look in
#
# Returns:
#   The existing slot for var, or the newly created one.
def get_or_make_slot_with_initializer(var, initializer, shape, dtype, slot_name, op_name)
  slots_for_name = slot_dict(slot_name)
  key = var_key(var)
  return slots_for_name[key] if slots_for_name.key?(key)

  slots_for_name[key] = create_slot_with_initializer(var, initializer, shape, dtype, op_name)
end
prepare()
click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 82
# Hook that apply_gradients calls before building the update operations.
# The base version does nothing and returns nil.
def prepare
  nil
end
slot_dict(slot_name)
click to toggle source
Returns the Hash that caches the slots created under the given name, creating an empty one on first use.
Args: slot_name: string - Name for the slot
Returns: A Hash that maps each primary variable's key (as built by var_key) to the slot created for it
# File lib/tensorflow/train/optimizer.rb, line 112
# Returns the Hash that caches the slots stored under slot_name, registering
# an empty Hash the first time a name is seen.
#
# Args:
#   slot_name: name of the slot group
#
# Returns:
#   The Hash stored in @slots for slot_name.
def slot_dict(slot_name)
  existing = @slots.fetch(slot_name, nil)
  return existing unless existing.nil?

  @slots[slot_name] = {}
end
var_key(var)
click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 121
# Builds the key under which a variable's slots are cached.
#
# Args:
#   var: object exposing #op, whose #graph and #name form the key
#
# Returns:
#   A two-element array: [graph of var's op, name of var's op].
def var_key(var)
  operation = var.op
  [operation.graph, operation.name]
end
zeros_slot(var, slot_name, op_name)
click to toggle source
Find or create a slot initialized with 0.0.
Args:
var: Variable - A Variable object; slot_name: string - Name for the slot; op_name: string - Name to use when scoping the Variable that needs to be created
# File lib/tensorflow/train/optimizer.rb, line 97
# Finds or creates a slot initialized with 0.0.
#
# Args:
#   var: variable the slot belongs to
#   slot_name: name of the slot group to look in
#   op_name: forwarded to create_zeros_slot when the slot has to be created
#
# Returns:
#   The existing slot for var, or the newly created one.
def zeros_slot(var, slot_name, op_name)
  slots_for_name = slot_dict(slot_name)
  key = var_key(var)
  # The block runs only when the key is missing, and stores what it creates.
  slots_for_name.fetch(key) { slots_for_name[key] = create_zeros_slot(var, op_name) }
end