class MachineLearningWorkbench::Optimizer::NaturalEvolutionStrategies::RNES

Radial Natural Evolution Strategies

Attributes

variance[R]

Public Instance Methods

convergence() click to toggle source

Estimate algorithm convergence based on variance

# File lib/machine_learning_workbench/optimizer/natural_evolution_strategies/rnes.rb, line 40
# Estimate algorithm convergence based on variance.
#
# @return the current value of the `variance` reader, returned as-is
def convergence
  variance
end
initialize_distribution(mu_init: 0, sigma_init: 1) click to toggle source
# File lib/machine_learning_workbench/optimizer/natural_evolution_strategies/rnes.rb, line 8
# Initialize the search distribution: the mean `@mu`, the scalar `@variance`
# and the matrix `@sigma` derived from that variance.
#
# @param mu_init [Array, Numeric] initial mean: either an Array holding exactly
#   `ndims` elements, or a single Numeric used to fill every dimension
# @param sigma_init [Numeric] initial global variance; an Array is rejected
# @raise [ArgumentError] if `mu_init` is an Array whose size differs from
#   `ndims`, if `sigma_init` is an Array, or if either argument is neither an
#   Array nor a Numeric
def initialize_distribution mu_init: 0, sigma_init: 1
  @mu = case mu_init
  when Array
    unless mu_init.size == ndims
      raise ArgumentError, "mu_init has #{mu_init.size} elements, expected #{ndims}"
    end
    NArray[mu_init]
  when Numeric
    NArray.new([1, ndims]).fill mu_init
  else
    raise ArgumentError, "Something is wrong with mu_init: #{mu_init}"
  end
  @variance = sigma_init
  @sigma = case sigma_init
  when Array
    # The comma is required: without it Ruby parses `ArgumentError "..."` as a
    # call to a method named ArgumentError and raises NoMethodError instead.
    raise ArgumentError, "RNES uses single global variance"
  when Numeric
    NArray.new([ndims]).fill(variance).diag
  else
    raise ArgumentError, "Something is wrong with sigma_init: #{sigma_init}"
  end
end
load(data) click to toggle source
# File lib/machine_learning_workbench/optimizer/natural_evolution_strategies/rnes.rb, line 48
# Restore the distribution from a two-element `[mu, variance]` pair.
#
# Sets `@variance` from the second element, converts the first element with
# `to_na` to obtain `@mu`, and recomputes `@sigma` as `eye * variance`.
#
# @param data a two-element collection: the mean (anything responding to
#   `to_na`) followed by the variance
# @raise [ArgumentError] if `data.size` is not 2
# @return the recomputed `@sigma`
def load(data)
  raise ArgumentError if data.size != 2

  stored_mu, @variance = data
  @mu = stored_mu.to_na
  @sigma = eye * variance
end
save() click to toggle source
# File lib/machine_learning_workbench/optimizer/natural_evolution_strategies/rnes.rb, line 44
# Dump the distribution state as a two-element array.
#
# @return [Array] `[mu.to_a, variance]` — the mean converted with `to_a`,
#   followed by the current variance
def save
  mu_values = mu.to_a
  [mu_values, variance]
end
train(picks: sorted_inds) click to toggle source
# File lib/machine_learning_workbench/optimizer/natural_evolution_strategies/rnes.rb, line 29
# Run one update of the distribution from the given picks.
#
# Shifts `@mu`, rescales the scalar `@variance` by an exponential factor, and
# rebuilds `@sigma` from the new variance via `fill(variance).diag`.
#
# @param picks samples to learn from; defaults to `sorted_inds`
#   (NOTE(review): presumably one sample per row, ordered to line up with
#   `utils` — confirm against the caller)
# @return the rebuilt `@sigma`
def train(picks: sorted_inds)
  mu_gradient = utils.dot(picks)
  # 2-norm of each pick, taken along axis 1
  pick_norms = NLinalg.norm(picks, 2, axis: 1)
  sigma_gradient = utils.dot(pick_norms**2 - ndims)[0] # back to scalar
  mu_step = sigma.dot(mu_gradient.transpose).transpose * lrate
  @mu += mu_step
  @variance *= Math.exp(sigma_gradient * lrate / 2)
  @sigma = NArray.new([ndims]).fill(variance).diag
end