class TensorStream::Train::MomentumOptimizer

Optimizer that implements the Momentum algorithm, loosely based on the TensorFlow implementation.

Public Class Methods

new(learning_rate, momentum, name: "momentum", use_nesterov: false, use_locking: false)

Construct a new Momentum optimizer.

Args:

learning_rate: A Tensor or a floating point value giving the learning rate.
momentum: A Tensor or a floating point value giving the momentum coefficient.
name: Optional name prefix. Defaults to "momentum".
use_nesterov: Boolean. If true, Nesterov momentum is used. See http://jmlr.org/proceedings/papers/v28/sutskever13.pdf
use_locking: Boolean. Placeholder argument kept for compatibility; currently unused.
Calls superclass method
# File lib/tensor_stream/train/momentum_optimizer.rb, line 16
def initialize(learning_rate, momentum, name: "momentum", use_nesterov: false, use_locking: false)
  @learning_rate = learning_rate
  @momentum = momentum
  @use_nesterov = use_nesterov
  super(name: name, use_locking: use_locking)
end

Protected Instance Methods

apply_dense(grad, var)
# File lib/tensor_stream/train/momentum_optimizer.rb, line 36
def apply_dense(grad, var)
  mom = get_slot(var, "momentum")

  _op(:apply_momentum, var, mom,
    TensorStream.cast(@learning_rate_tensor, var.data_type),
    grad,
    TensorStream.cast(@momentum_tensor, var.data_type),
    use_locking: @use_locking,
    use_nesterov: @use_nesterov)
end
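
The :apply_momentum op itself is not shown on this page. As a rough illustration, the sketch below performs one update step in plain Ruby, assuming the op follows the same rule as TensorFlow's ApplyMomentum (accumulate first, then step the variable). The helper name momentum_step and its array-based arguments are hypothetical and are not part of TensorStream.

```ruby
# Plain-Ruby sketch of one momentum update on arrays of floats.
# Hypothetical helper for illustration; assumes TensorFlow's ApplyMomentum rule:
#   accum = momentum * accum + grad
#   var  -= learning_rate * accum                          (standard)
#   var  -= learning_rate * (grad + momentum * accum)      (Nesterov)
def momentum_step(var, accum, grad, learning_rate:, momentum:, use_nesterov: false)
  # Update the accumulator (the "momentum" slot) first.
  new_accum = accum.zip(grad).map { |a, g| momentum * a + g }

  # Then move each variable element using the updated accumulator.
  new_var = var.each_with_index.map do |v, i|
    if use_nesterov
      v - learning_rate * (grad[i] + momentum * new_accum[i])
    else
      v - learning_rate * new_accum[i]
    end
  end

  [new_var, new_accum]
end

var   = [1.0, 2.0]
accum = [0.0, 0.0] # zero-initialized, like the "momentum" slot
grad  = [0.5, -0.5]
var, accum = momentum_step(var, accum, grad, learning_rate: 0.1, momentum: 0.9)
```

On the first step the accumulator is zero, so the standard update reduces to plain gradient descent. The momentum term only takes effect from the second step onward.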

create_slots(var_list)
# File lib/tensor_stream/train/momentum_optimizer.rb, line 30
def create_slots(var_list)
  var_list.each do |v|
    zeros_slot(v, "momentum", @name)
  end
end
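
A slot here is per-variable optimizer state: create_slots registers a zero-filled "momentum" accumulator for each variable, and apply_dense later retrieves it with get_slot. The sketch below mimics that bookkeeping in plain Ruby. The SlotRegistry class and its method signatures are hypothetical and only loosely mirror the zeros_slot and get_slot calls above.

```ruby
# Hypothetical stand-in for the optimizer's slot bookkeeping.
# Not TensorStream code; variables are represented by a name and an array of floats.
class SlotRegistry
  def initialize
    # slot name => { variable name => accumulator }
    @slots = Hash.new { |hash, key| hash[key] = {} }
  end

  # Create (once) a zero-filled accumulator with the same length as the variable.
  def zeros_slot(var_name, values, slot_name)
    @slots[slot_name][var_name] ||= Array.new(values.length, 0.0)
  end

  # Return the accumulator for a variable, or nil if none was created.
  def get_slot(var_name, slot_name)
    @slots.fetch(slot_name, {})[var_name]
  end
end

registry = SlotRegistry.new
registry.zeros_slot("weights", [0.3, -1.2, 0.8], "momentum")
momentum_slot = registry.get_slot("weights", "momentum")
```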

prepare()
# File lib/tensor_stream/train/momentum_optimizer.rb, line 25
def prepare
  @learning_rate_tensor = TensorStream.convert_to_tensor(@learning_rate, name: "learning_rate")
  @momentum_tensor = TensorStream.convert_to_tensor(@momentum, name: "momentum")
end