class Spark::Mllib::SVMModel

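SVMModel

A linear support vector machine (SVM) classification model. Instances are produced by `SVMWithSGD.train`, as shown in the examples below.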

Examples:

Spark::Mllib.import

# Dense vectors
data = [
    LabeledPoint.new(0.0, [0.0]),
    LabeledPoint.new(1.0, [1.0]),
    LabeledPoint.new(1.0, [2.0]),
    LabeledPoint.new(1.0, [3.0])
]
svm = SVMWithSGD.train($sc.parallelize(data))

svm.predict([1.0])
# => 1
svm.clear_threshold
svm.predict([1.0])
# => 1.25...

# Sparse vectors
data = [
    LabeledPoint.new(0.0, SparseVector.new(2, {0 => -1.0})),
    LabeledPoint.new(1.0, SparseVector.new(2, {1 => 1.0})),
    LabeledPoint.new(0.0, SparseVector.new(2, {0 => 0.0})),
    LabeledPoint.new(1.0, SparseVector.new(2, {1 => 2.0}))
]
svm = SVMWithSGD.train($sc.parallelize(data))

svm.predict(SparseVector.new(2, {1 => 1.0}))
# => 1
svm.predict(SparseVector.new(2, {0 => -1.0}))
# => 0

Public Class Methods

new(*args)

Calls the superclass constructor Spark::Mllib::ClassificationModel::new, then sets the decision threshold to its default of 0.0.
# File lib/spark/mllib/classification/svm.rb, line 44
# Builds the model through the ClassificationModel constructor and then
# sets the decision threshold to its default of 0.0.
#
# @param args [Array] forwarded unchanged (along with any block) to the
#   superclass constructor
def initialize(*args)
  super(*args)
  @threshold = 0.0
end

Public Instance Methods

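predict(vector)

Predicts a value for a single data point (for example an Array of numbers or a SparseVector) using the trained model. The point's margin is computed as weights · x + intercept. If a threshold is set, the method returns 1 when the margin exceeds it and 0 otherwise. If the threshold has been cleared (nil, e.g. via clear_threshold), the raw margin is returned instead.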

# File lib/spark/mllib/classification/svm.rb, line 51
# Scores one data point against the trained linear model.
#
# The raw score is +weights . point + intercept+. With no threshold set
# (nil, e.g. after +clear_threshold+) that raw score is returned as-is;
# otherwise it is mapped to a class label: 1 if it is strictly greater
# than the threshold, 0 if not.
#
# @param vector [Object] a data point accepted by
#   Spark::Mllib::Vectors.to_vector (e.g. an Array or a SparseVector)
# @return [Numeric] the raw margin when threshold is nil, else 1 or 0
def predict(vector)
  point = Spark::Mllib::Vectors.to_vector(vector)
  score = weights.dot(point) + intercept
  return score if threshold.nil?

  score > threshold ? 1 : 0
end