class Spark::Mllib::KMeansModel

KMeansModel

A clustering model derived from the k-means method.

Examples:

Spark::Mllib.import

# Dense vectors
# Two points near the origin and two points far from it.
data = [
  DenseVector.new([0.0,0.0]),
  DenseVector.new([1.0,1.0]),
  DenseVector.new([9.0,8.0]),
  DenseVector.new([8.0,9.0])
]

# The second argument (2) is presumably the number of clusters -- confirm
# against KMeans.train.
model = KMeans.train($sc.parallelize(data), 2, max_iterations: 10,
                     runs: 30, initialization_mode: "random")

# Neighbouring points are expected to receive the same cluster index.
model.predict([0.0, 0.0]) == model.predict([1.0, 1.0])
# => true
model.predict([8.0, 9.0]) == model.predict([9.0, 8.0])
# => true

# Sparse vectors
# Each vector is built from a size (3) and an {index => value} hash.
data = [
    SparseVector.new(3, {1 => 1.0}),
    SparseVector.new(3, {1 => 1.1}),
    SparseVector.new(3, {2 => 1.0}),
    SparseVector.new(3, {2 => 1.1})
]
model = KMeans.train($sc.parallelize(data), 2, initialization_mode: "k-means||")

# predict is called here with plain arrays ...
model.predict([0.0, 1.0, 0.0]) == model.predict([0, 1.1, 0.0])
# => true
model.predict([0.0, 0.0, 1.0]) == model.predict([0, 0, 1.1])
# => true
# ... and here with the SparseVector objects themselves.
model.predict(data[0]) == model.predict(data[1])
# => true
model.predict(data[2]) == model.predict(data[3])
# => true

Attributes

centers[R]

Public Class Methods

from_java(object) click to toggle source
# File lib/spark/mllib/clustering/kmeans.rb, line 72
# Build a KMeansModel from a Java-side model object.
#
# Reads the object's +clusterCenters+, converts every entry with
# +Spark.jb.java_to_ruby+ (the collection is rewritten via +map!+),
# and wraps the converted collection in a new KMeansModel.
def self.from_java(object)
  converted = object.clusterCenters
  converted.map! { |java_center| Spark.jb.java_to_ruby(java_center) }

  KMeansModel.new(converted)
end
new(centers) click to toggle source
# File lib/spark/mllib/clustering/kmeans.rb, line 51
# Create a model from its cluster centers.
#
# centers:: collection of cluster centers; stored as given (no copy is made)
#           in @centers
def initialize(centers)
  @centers = centers
end

Public Instance Methods

predict(vector) click to toggle source

Find the cluster to which the given vector belongs in this model. Returns the index of the center with the smallest squared distance to the vector.

# File lib/spark/mllib/clustering/kmeans.rb, line 56
# Return the index of the center nearest to +vector+.
#
# The argument is first passed through Spark::Mllib::Vectors.to_vector.
# Nearness is measured with +squared_distance+; on a tie the earliest
# center wins because only a strictly smaller distance replaces the
# current best. With no centers the result is 0.
def predict(vector)
  point = Spark::Mllib::Vectors.to_vector(vector)
  nearest_index = 0
  shortest = Float::INFINITY

  @centers.each_with_index do |center, index|
    candidate = point.squared_distance(center)
    # Skip anything that is not strictly closer than the best seen so far.
    next unless candidate < shortest

    nearest_index, shortest = index, candidate
  end

  nearest_index
end