class Spark::Mllib::GaussianMixture

Public Class Methods

train(rdd, k, convergence_tol: 0.001, max_iterations: 100, seed: nil)
# File lib/spark/mllib/clustering/gaussian_mixture.rb, line 66
# Trains a Gaussian mixture model with +k+ components on +rdd+.
#
# The fitting itself is delegated to the JVM side through the Spark.jb
# bridge ('trainGaussianMixtureModel'). The per-component means and
# covariance matrices it returns are converted back to Ruby objects and
# paired up index-by-index into MultivariateGaussian instances.
#
# @param rdd [Object] training data handed unchanged to the bridge
# @param k [Integer] number of mixture components; exactly +k+ gaussians are built
# @param convergence_tol [Float] convergence tolerance forwarded to the trainer
# @param max_iterations [Integer] iteration cap forwarded to the trainer
# @param seed [Integer, nil] random seed, passed through Spark.jb.to_long
#   (NOTE(review): nil handling is up to the bridge — confirm)
# @return [GaussianMixtureModel] built from the returned weights and gaussians
def self.train(rdd, k, convergence_tol: 0.001, max_iterations: 100, seed: nil)
  bridge   = Spark.jb
  api      = RubyMLLibAPI.new
  long_seed = bridge.to_long(seed)

  result = bridge.call(api, 'trainGaussianMixtureModel', rdd,
                       k, convergence_tol, max_iterations, long_seed)
  weights, java_means, java_sigmas = result

  # Convert every mean first, then every covariance (same order as before).
  ruby_means  = java_means.map { |mean| Spark.jb.java_to_ruby(mean) }
  ruby_sigmas = java_sigmas.map { |cov| Spark.jb.java_to_ruby(cov) }

  # Build exactly k components, pairing the i-th mean with the i-th covariance.
  gaussians = (0...k).map do |idx|
    MultivariateGaussian.new(ruby_means[idx], ruby_sigmas[idx])
  end

  GaussianMixtureModel.new(weights, gaussians)
end