class Spark::Mllib::LogisticRegressionModel
A linear binary classification model derived from logistic regression.
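The model scores a feature vector x by applying the logistic (sigmoid) function to its margin. When a threshold t is set, it turns that score into a 0/1 label. In the notation below, w stands for `weights`, b for `intercept`, and t for `threshold` (0.5 by default), matching the `predict` source on this page:

$$
\operatorname{score}(\mathbf{x}) = \sigma(\mathbf{w}^\top \mathbf{x} + b) = \frac{1}{1 + e^{-(\mathbf{w}^\top \mathbf{x} + b)}},
\qquad
\hat{y} =
\begin{cases}
1 & \text{if } \operatorname{score}(\mathbf{x}) > t,\\
0 & \text{otherwise.}
\end{cases}
$$

After `clear_threshold`, `predict` returns score(x) itself instead of the label.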
Examples:
  Spark::Mllib.import

  # Dense vectors
  data = [
    LabeledPoint.new(0.0, [0.0, 1.0]),
    LabeledPoint.new(1.0, [1.0, 0.0]),
  ]
  lrm = LogisticRegressionWithSGD.train($sc.parallelize(data))
  lrm.predict([1.0, 0.0]) # => 1
  lrm.predict([0.0, 1.0]) # => 0

  lrm.clear_threshold
  lrm.predict([0.0, 1.0]) # => 0.123...

  # Sparse vectors
  data = [
    LabeledPoint.new(0.0, SparseVector.new(2, {0 => 0.0})),
    LabeledPoint.new(1.0, SparseVector.new(2, {1 => 1.0})),
    LabeledPoint.new(0.0, SparseVector.new(2, {0 => 1.0})),
    LabeledPoint.new(1.0, SparseVector.new(2, {1 => 2.0}))
  ]
  lrm = LogisticRegressionWithSGD.train($sc.parallelize(data))
  lrm.predict([0.0, 1.0]) # => 1
  lrm.predict([1.0, 0.0]) # => 0
  lrm.predict(SparseVector.new(2, {1 => 1.0})) # => 1
  lrm.predict(SparseVector.new(2, {0 => 1.0})) # => 0

  # LogisticRegressionWithLBFGS
  data = [
    LabeledPoint.new(0.0, [0.0, 1.0]),
    LabeledPoint.new(1.0, [1.0, 0.0]),
  ]
  lrm = LogisticRegressionWithLBFGS.train($sc.parallelize(data))
  lrm.predict([1.0, 0.0]) # => 1
  lrm.predict([0.0, 1.0]) # => 0
Public Class Methods
new(*args)
Calls superclass method Spark::Mllib::ClassificationModel::new, then sets the default decision threshold to 0.5.
  # File lib/spark/mllib/classification/logistic_regression.rb, line 62
  def initialize(*args)
    super
    @threshold = 0.5
  end
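The sigmoid exceeds 0.5 only for positive inputs. So with the default threshold of 0.5, a point is labeled 1 exactly when its margin `weights.dot(vector) + intercept` is positive, up to floating-point rounding for margins extremely close to zero. The plain-Ruby check below illustrates this without Spark. `sigmoid` is a hypothetical helper for illustration, not part of the gem.

```ruby
# Hypothetical helper mirroring the score computed in #predict.
def sigmoid(margin)
  1.0 / (1.0 + Math.exp(-margin))
end

# With threshold 0.5, "score > 0.5" and "margin > 0" give the same labels.
margins   = [-3.0, -0.5, 0.0, 0.5, 3.0]
by_score  = margins.map { |m| sigmoid(m) > 0.5 ? 1 : 0 }
by_margin = margins.map { |m| m > 0 ? 1 : 0 }

p by_score              # => [0, 0, 0, 1, 1]
p by_score == by_margin # => true
```

A margin of exactly 0 gives a score of exactly 0.5. Because the comparison in `predict` is strict (`score > threshold`), such a point is labeled 0.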
Public Instance Methods
predict(vector)
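Predict values for a single data point or an RDD of points using the trained model. A point is scored by applying the logistic (sigmoid) function to its margin (weights · x + intercept). If a threshold is set, the method returns 1 when the score exceeds it and 0 otherwise. If the threshold has been cleared with clear_threshold, it returns the raw score.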
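  # File lib/spark/mllib/classification/logistic_regression.rb, line 69
  def predict(vector)
    # Convert the input (e.g. an Array) to a Spark vector
    vector = Spark::Mllib::Vectors.to_vector(vector)
    # Margin: dot product of weights and features, plus the intercept
    margin = weights.dot(vector) + intercept
    # Logistic (sigmoid) function maps the margin into (0, 1)
    score = 1.0 / (1.0 + Math.exp(-margin))

    # No threshold (after clear_threshold): return the raw score
    if threshold.nil?
      return score
    end

    # Otherwise label the point 1 if the score exceeds the threshold
    if score > threshold
      1
    else
      0
    end
  end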
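Without a Spark context, the scoring logic of `predict` can be reproduced with plain Ruby arrays. This is a minimal sketch that assumes dense feature vectors given as Arrays. `ToyLogisticModel` is hypothetical. Its methods only mirror the gem's `weights`, `intercept`, `threshold`, `clear_threshold`, and `predict`. Unlike the gem, it does not accept `SparseVector` input.

```ruby
# Hypothetical stand-in for Spark::Mllib::LogisticRegressionModel that
# works on plain Ruby arrays instead of Spark vectors.
class ToyLogisticModel
  attr_reader :weights, :intercept, :threshold

  def initialize(weights, intercept)
    @weights = weights
    @intercept = intercept
    @threshold = 0.5 # same default as the gem's initializer
  end

  # After this call, #predict returns the raw probability score.
  def clear_threshold
    @threshold = nil
  end

  def predict(features)
    # Margin: dot product of weights and features, plus the intercept.
    margin = weights.zip(features).sum { |w, x| w * x } + intercept
    # Logistic (sigmoid) function maps the margin into (0, 1).
    score = 1.0 / (1.0 + Math.exp(-margin))
    return score if threshold.nil?

    score > threshold ? 1 : 0
  end
end

model = ToyLogisticModel.new([2.0, -2.0], 0.0)
p model.predict([1.0, 0.0]) # => 1
p model.predict([0.0, 1.0]) # => 0
model.clear_threshold
p model.predict([0.0, 1.0]).round(4) # => 0.1192
```

The weights `[2.0, -2.0]` were chosen by hand for illustration. They are not the output of `LogisticRegressionWithSGD` or `LogisticRegressionWithLBFGS`.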