class MachineLearningWorkbench::Optimizer::NaturalEvolutionStrategies::SNES
Separable Natural Evolution Strategies
Attributes
variances[R]
Public Instance Methods
convergence()
click to toggle source
Estimate algorithm convergence as total variance
# File lib/machine_learning_workbench/optimizer/natural_evolution_strategies/snes.rb, line 41
# Estimate how far the search has converged, measured as the total variance.
#
# @return the value of `variances.sum`
def convergence
  variances.sum
end
initialize_distribution(mu_init: 0, sigma_init: 1)
click to toggle source
# File lib/machine_learning_workbench/optimizer/natural_evolution_strategies/snes.rb, line 9
# Set up the initial search distribution: assigns @mu, @variances and @sigma.
#
# @param mu_init [Array, Numeric] an Array of exactly `ndims` elements, or a
#   single number used to fill every entry of the mean
# @param sigma_init [Array, Numeric] an Array of exactly `ndims` elements, or a
#   single number used to fill every entry of the variances
# @raise [ArgumentError] if an Array argument does not have `ndims` elements,
#   or if an argument is neither an Array nor a Numeric
# @return the value assigned to @sigma (`@variances.diag`)
def initialize_distribution(mu_init: 0, sigma_init: 1)
  @mu =
    case mu_init
    when Array
      unless mu_init.size == ndims
        raise ArgumentError,
              "mu_init has #{mu_init.size} elements, expected #{ndims}"
      end
      # The Array is passed as ONE element (no splat), unlike sigma_init below.
      # NOTE(review): presumably this yields a 1 x ndims row, matching the
      # [1, ndims] shape requested in the Numeric branch -- confirm.
      NArray[mu_init]
    when Numeric
      NArray.new([1, ndims]).fill(mu_init)
    else
      raise ArgumentError, "Something is wrong with mu_init: #{mu_init}"
    end

  @variances =
    case sigma_init
    when Array
      unless sigma_init.size == ndims
        raise ArgumentError,
              "sigma_init has #{sigma_init.size} elements, expected #{ndims}"
      end
      NArray[*sigma_init]
    when Numeric
      NArray.new([ndims]).fill(sigma_init)
    else
      # Trailing space keeps the offending value separate from the hint.
      raise ArgumentError, "Something is wrong with sigma_init: #{sigma_init} " \
                           "(did you remember to copy the other cases from XNES?)"
    end

  # @sigma is always derived from @variances, never stored independently.
  @sigma = @variances.diag
end
load(data)
click to toggle source
# File lib/machine_learning_workbench/optimizer/natural_evolution_strategies/snes.rb, line 49
# Restore the distribution from a two-element collection `[mu, variances]`.
#
# Each element is converted with `to_na`; @sigma is then rebuilt as
# `@variances.diag`, the same derivation the other methods use.
#
# @param data a collection of exactly two elements, each responding to `to_na`
# @raise [ArgumentError] if `data` does not contain exactly two elements
# @return the value assigned to @sigma
def load(data)
  unless data.size == 2
    raise ArgumentError,
          "data must hold exactly 2 elements [mu, variances], got #{data.size}"
  end

  @mu, @variances = data.map(&:to_na)
  @sigma = @variances.diag
end
save()
click to toggle source
# File lib/machine_learning_workbench/optimizer/natural_evolution_strategies/snes.rb, line 45
# Export the distribution as a two-element Array: `mu` and `variances`,
# each converted with `to_a`, in that order.
#
# @return [Array] `[mu.to_a, variances.to_a]`
def save
  [mu, variances].map(&:to_a)
end
train(picks: sorted_inds)
click to toggle source
# File lib/machine_learning_workbench/optimizer/natural_evolution_strategies/snes.rb, line 32
# Run one update of the search distribution from `picks`.
#
# Reads `utils`, `sigma` and `lrate`, then reassigns @mu, @variances and
# @sigma in that order.
# NOTE(review): `utils`, `lrate` and `sorted_inds` are defined outside this
# view; their shapes and meaning are not established here -- confirm in the
# parent class.
#
# @param picks samples to learn from; defaults to `sorted_inds`
# @return the value assigned to @sigma (`@variances.diag`)
def train(picks: sorted_inds)
  mu_gradient = utils.dot(picks)
  sigma_gradient = utils.dot(picks**2 - 1)

  # Mean update: gradient mapped through `sigma`, scaled by the learning rate.
  mu_step = sigma.dot(mu_gradient.transpose).transpose * lrate
  @mu += mu_step

  # Variance update is multiplicative, using the exponential of the scaled
  # gradient at half the learning rate.
  variance_scale = (sigma_gradient * lrate / 2).exponential.flatten
  @variances *= variance_scale

  @sigma = @variances.diag
end