class Tensorflow::Train::Optimizer

Attributes

name[R]

Public Class Methods

new(name: nil, use_locking: false) click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 7
# Builds an optimizer base with empty slot / non-slot caches.
#
# name:        - required optimizer name; a falsy value raises.
# use_locking: - flag stored for subclasses (not read in this class).
#
# Raises Error::InvalidArgumentError when no name is given.
def initialize(name: nil, use_locking: false)
  @name, @use_locking = name, use_locking

  raise Error::InvalidArgumentError, "Must specify the optimizer name" unless name

  # slot name => { var_key => slot variable }
  @slots = {}
  # [name, graph] => non-slot variable
  @non_slots = {}
end

Public Instance Methods

apply_gradients(grads_and_vars, global_step: nil, name: nil) click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 28
# Builds the ops that apply each (gradient, variable) pair and groups them.
#
# grads_and_vars - enumerable of [gradient, variable] pairs (e.g. from
#                  #compute_gradients). Pairs whose gradient is nil are skipped,
#                  matching TensorFlow's Optimizer.apply_gradients.
# global_step:   - optional variable; when given, it is incremented by 1 under a
#                  control dependency on the grouped update op.
# name:          - name for the grouped update op (used when global_step is nil).
#
# Returns the grouped update op, or the global_step increment op.
def apply_gradients(grads_and_vars, global_step: nil, name: nil)
  # NOTE: create_slots(var_list) and name scoping are currently not wired up
  # (they were commented out); subclasses get no create_slots call from here.
  prepare

  apply_ops = grads_and_vars.reject { |grad, _var| grad.nil? }.map do |grad, var|
    apply_dense(grad, var)
  end

  return finish(apply_ops, name) if global_step.nil?

  # Increment global_step only after every variable update has run.
  global_step.handle.graph.control_dependencies([finish(apply_ops, "update")]) do
    global_step.assign_add(Tensorflow.constant(1, dtype: global_step.dtype))
  end
end
compute_gradients(loss, var_list: nil, grad_loss: nil) click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 49
# Computes gradients of +loss+ with respect to the variables to train.
#
# loss       - the loss tensor to differentiate.
# var_list:  - variables to differentiate against; defaults to the graph's
#              TRAINABLE_VARIABLES collection.
# grad_loss: - optional initial gradient passed as +grad_ys+.
#
# Returns an Array of [gradient, variable] pairs, one per variable.
# Raises Error::InvalidArgumentError when there are no variables.
def compute_gradients(loss, var_list: nil, grad_loss: nil)
  vars_to_train = var_list || graph.get_collection_ref(Tensorflow::Graph::GraphKeys::TRAINABLE_VARIABLES)
  no_vars = vars_to_train.nil? || vars_to_train.empty?
  raise Error::InvalidArgumentError, 'There are no variables to train for the loss function' if no_vars

  gradient_builder = Graph::Gradients.new(graph)
  gradient_builder.gradients(loss, vars_to_train, grad_ys: grad_loss).zip(vars_to_train)
end
get_slot(var, name) click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 61
# Looks up the slot named +name+ created for +var+.
#
# Returns the cached slot, or nil when no slot of that name exists or +var+
# has none.
def get_slot(var, name)
  @slots.fetch(name, nil)&.fetch(var_key(var), nil)
end
get_slot_names() click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 68
# Returns the names of all slot kinds created so far, sorted.
def get_slot_names
  @slots.each_key.sort
end
graph() click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 16
# Returns ExecutionContext.current, the context this optimizer builds ops in.
# NOTE(review): presumably a Graph — compute_gradients calls
# get_collection_ref on it and passes it to Graph::Gradients.new; confirm.
def graph
  ExecutionContext.current
end
minimize(loss, var_list: nil, grad_loss: nil, global_step: nil, name: nil) click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 20
# Computes gradients of +loss+ and applies them in one step.
#
# loss         - the loss tensor to minimize.
# var_list:    - variables to optimize (defaults to trainable variables).
# grad_loss:   - optional initial gradient for +loss+.
# global_step: - optional variable incremented after the update.
# name:        - name for the resulting update op.
#
# Returns the op built by #apply_gradients.
# Raises Error::InvalidArgumentError when no variable received a gradient.
def minimize(loss, var_list: nil, grad_loss: nil, global_step: nil, name: nil)
  grads_and_vars = compute_gradients(loss, var_list: var_list, grad_loss: grad_loss)
  # compute_gradients pairs every variable with an entry, so an empty check
  # alone never fires; a nil gradient contributes no update (presumably the
  # variable does not affect the loss — TODO confirm for Graph::Gradients).
  if grads_and_vars.none? { |grad, _var| grad }
    raise(Error::InvalidArgumentError, "No gradients provided for any variable, check your graph for ops that do not support gradients")
  end

  apply_gradients(grads_and_vars, global_step: global_step, name: name)
end

Protected Instance Methods

apply_dense(_grad, _var) click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 86
# Subclass hook: build the op that applies a gradient to one variable.
# The base class has no update rule and always raises.
#
# Raises Error::UnimplementedError.
def apply_dense(_grad, _var)
  raise Error::UnimplementedError, "Not implemented"
end
call_if_callable(param) click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 130
# Resolves a possibly-deferred hyperparameter: Procs (including lambdas) are
# invoked with no arguments; any other value is returned unchanged.
def call_if_callable(param)
  return param.call if param.is_a?(Proc)

  param
end
create_non_slot_variable(initial_value, name, colocate_with) click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 134
# Returns the non-slot variable +name+ for the graph of +colocate_with+,
# creating it (non-trainable, starting at +initial_value+) on first use.
# Cached in @non_slots under [name, graph], the same key shape
# get_non_slot_variable reads.
def create_non_slot_variable(initial_value, name, colocate_with)
  key = [name, colocate_with.graph]
  # NOTE(review): previously called TensorStream.variable, a leftover from the
  # TensorStream port; confirm Tensorflow::Variable.new accepts name:/trainable:.
  @non_slots[key] ||= Tensorflow::Variable.new(initial_value, name: name, trainable: false)
end
create_slots(var_list) click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 78
# Subclass hook for eagerly creating slots for +var_list+; the base class
# creates none and returns nil.
def create_slots(var_list); end
finish(update_ops, name_scope) click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 74
# Groups the per-variable update ops into one op via Control.group, named
# +name_scope+. apply_gradients passes the caller's name, or "update" when a
# global_step increment depends on the result.
def finish(update_ops, name_scope)
  Control.group(update_ops, name: name_scope)
end
get_non_slot_variable(name, graph: nil) click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 125
# Looks up a non-slot variable by +name+ and +graph+ ([name, graph] key).
# Returns nil when none was created for that pair.
def get_non_slot_variable(name, graph: nil)
  @non_slots.fetch([name, graph], nil)
end
get_or_make_slot_with_initializer(var, initializer, shape, dtype, slot_name, op_name) click to toggle source

Find or create a slot for a variable, using an Initializer.

# File lib/tensorflow/train/optimizer.rb, line 148
# Returns the slot +slot_name+ for +var+, creating it with
# create_slot_with_initializer on first request and caching it afterwards.
def get_or_make_slot_with_initializer(var, initializer, shape, dtype, slot_name, op_name)
  slots_for_name = slot_dict(slot_name)
  key = var_key(var)
  slots_for_name.fetch(key) do
    slots_for_name[key] = create_slot_with_initializer(var, initializer, shape, dtype, op_name)
  end
end
prepare() click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 82
# Subclass hook run once at the start of apply_gradients; the base class does
# nothing and returns nil.
def prepare; end
slot_dict(slot_name) click to toggle source

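Returns the Hash that caches slots created under the given name, creating and registering an empty one on first use.

Args: slot_name - String name for the slot

Returns: A Hash mapping variable keys (as produced by var_key: [graph, op name]) to the slot variables created for them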

# File lib/tensorflow/train/optimizer.rb, line 112
# Returns the per-variable slot Hash registered under +slot_name+, creating
# and registering an empty one the first time that name is used.
def slot_dict(slot_name)
  @slots[slot_name] ||= {}
end
var_key(var) click to toggle source
# File lib/tensorflow/train/optimizer.rb, line 121
# Cache key identifying +var+ within slot hashes: its op's graph plus op name.
def var_key(var)
  op = var.op
  [op.graph, op.name]
end
zeros_slot(var, slot_name, op_name) click to toggle source

Find or create a slot initialized with 0.0.

Args:

var: Variable - A Variable object
slot_name: string - Name for the slot
op_name: string - Name to use when scoping the Variable that needs to be created
# File lib/tensorflow/train/optimizer.rb, line 97
# Returns the slot +slot_name+ for +var+, creating it with create_zeros_slot
# on first request and caching it afterwards.
#
# var       - the primary variable
# slot_name - name of the slot kind
# op_name   - name used when creating the slot variable
def zeros_slot(var, slot_name, op_name)
  slots_for_name = slot_dict(slot_name)
  key = var_key(var)
  slots_for_name.fetch(key) { slots_for_name[key] = create_zeros_slot(var, op_name) }
end