class Spark::Mllib::SVMModel
A support vector machine (SVM) model for binary classification, as returned by SVMWithSGD.train.
Examples:
  Spark::Mllib.import

  # Dense vectors
  data = [
    LabeledPoint.new(0.0, [0.0]),
    LabeledPoint.new(1.0, [1.0]),
    LabeledPoint.new(1.0, [2.0]),
    LabeledPoint.new(1.0, [3.0])
  ]
  svm = SVMWithSGD.train($sc.parallelize(data))
  svm.predict([1.0]) # => 1
  svm.clear_threshold
  svm.predict([1.0]) # => 1.25...

  # Sparse vectors
  data = [
    LabeledPoint.new(0.0, SparseVector.new(2, {0 => -1.0})),
    LabeledPoint.new(1.0, SparseVector.new(2, {1 => 1.0})),
    LabeledPoint.new(0.0, SparseVector.new(2, {0 => 0.0})),
    LabeledPoint.new(1.0, SparseVector.new(2, {1 => 2.0}))
  ]
  svm = SVMWithSGD.train($sc.parallelize(data))
  svm.predict(SparseVector.new(2, {1 => 1.0}))  # => 1
  svm.predict(SparseVector.new(2, {0 => -1.0})) # => 0
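The sparse examples above build vectors with SparseVector.new(size, {index => value}), storing only the non-zero entries. As a rough illustration of why that is enough for prediction, here is a pure-Ruby sketch (no Spark required) of a dot product between dense weights and such a sparse vector. The helper name sparse_dot and the weight values are hypothetical; they are not part of ruby-spark and do not come from a trained model.

```ruby
# Hypothetical helper: dot product of a dense weight array with a sparse
# vector described as (size, {index => value}), mirroring the arguments
# passed to SparseVector.new in the examples.
def sparse_dot(weights, size, entries)
  raise ArgumentError, "size mismatch" unless weights.length == size

  # Only the stored entries contribute; missing indices are implicitly 0.
  entries.sum { |index, value| weights[index] * value }
end

weights = [0.5, 0.8] # illustrative weights only

sparse_dot(weights, 2, { 1 => 1.0 })  # => 0.8  (positive margin)
sparse_dot(weights, 2, { 0 => -1.0 }) # => -0.5 (negative margin)
```

With a zero intercept and the default threshold of 0.0, these two margins would map to labels 1 and 0 respectively, matching the pattern of the sparse example.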
Public Class Methods
new(*args)
Calls superclass method Spark::Mllib::ClassificationModel::new, then sets the default decision threshold (@threshold) to 0.0.
  # File lib/spark/mllib/classification/svm.rb, line 44
  def initialize(*args)
    super
    @threshold = 0.0
  end
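The constructor relies on Ruby's bare super, which forwards every argument it received to the parent class, and then installs a default threshold of 0.0. The sketch below reproduces that pattern with hypothetical stand-in classes (ToyClassificationModel, ToySVMModel); their two-argument constructor is an assumption, since the real base class signature is not shown here. It also assumes that clear_threshold, used in the examples, simply sets the threshold to nil, which is consistent with the nil check in predict.

```ruby
# Hypothetical stand-in for Spark::Mllib::ClassificationModel.
# The (weights, intercept) signature is an assumption for illustration.
class ToyClassificationModel
  attr_reader :weights, :intercept

  def initialize(weights, intercept)
    @weights = weights
    @intercept = intercept
  end
end

class ToySVMModel < ToyClassificationModel
  attr_accessor :threshold

  def initialize(*args)
    super              # bare super forwards all arguments unchanged
    @threshold = 0.0   # default decision threshold, as in SVMModel#initialize
  end

  # Assumed behaviour: a nil threshold makes predict return the raw margin.
  def clear_threshold
    @threshold = nil
  end
end

model = ToySVMModel.new([1.0], 0.25)
model.threshold # => 0.0
model.clear_threshold
model.threshold # => nil
```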
Public Instance Methods
predict(vector)
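Predict a value for a single data point using the trained model. The margin is computed as weights.dot(vector) + intercept. If threshold is nil (for example after clear_threshold), the raw margin is returned; otherwise the result is 1 when the margin exceeds threshold and 0 when it does not.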
  # File lib/spark/mllib/classification/svm.rb, line 51
  def predict(vector)
    vector = Spark::Mllib::Vectors.to_vector(vector)
    margin = weights.dot(vector) + intercept

    if threshold.nil?
      return margin
    end

    if margin > threshold
      1
    else
      0
    end
  end
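To make the decision rule concrete, the following sketch reimplements predict as a plain Ruby function over arrays instead of Spark vectors. The function name svm_predict and the weight and intercept values are illustrative assumptions, not part of ruby-spark and not the output of SVMWithSGD.train.

```ruby
# Standalone sketch of the SVMModel#predict decision rule, using plain
# Ruby arrays in place of Spark vectors (an assumption for illustration).
def svm_predict(weights, intercept, vector, threshold = 0.0)
  raise ArgumentError, "dimension mismatch" unless weights.length == vector.length

  # margin = weights . vector + intercept
  margin = weights.zip(vector).sum { |w, x| w * x } + intercept

  return margin if threshold.nil? # cleared threshold: raw score
  margin > threshold ? 1 : 0      # otherwise: binary label
end

svm_predict([1.0], 0.25, [1.0])      # => 1
svm_predict([1.0], 0.25, [1.0], nil) # => 1.25
svm_predict([1.0], 0.25, [-1.0])     # => 0
```

Because the comparison is strict (margin > threshold), a margin exactly equal to the threshold is labelled 0.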