spark-commits mailing list archives

From m...@apache.org
Subject git commit: [SPARK-2934][MLlib] Adding LogisticRegressionWithLBFGS Interface
Date Tue, 12 Aug 2014 02:49:47 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.1 7e31f7c27 -> 8f6e2e9df


[SPARK-2934][MLlib] Adding LogisticRegressionWithLBFGS Interface

for training with the L-BFGS optimizer, which typically converges faster than SGD.
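A minimal usage sketch of the new interface (not part of the commit; it assumes an existing SparkContext named sc, and the toy data is illustrative):

    import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    // Toy training set; labels in Logistic Regression must be {0, 1}.
    val training = sc.parallelize(Seq(
      LabeledPoint(1.0, Vectors.dense(2.0)),
      LabeledPoint(0.0, Vectors.dense(-1.5))))

    // Train with the L-BFGS optimizer using the default parameters.
    val model = new LogisticRegressionWithLBFGS()
      .setIntercept(true)
      .run(training)

    // Score a single point locally on the driver.
    val prediction = model.predict(Vectors.dense(1.0))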

Author: DB Tsai <dbtsai@alpinenow.com>

Closes #1862 from dbtsai/dbtsai-lbfgs-lor and squashes the following commits:

aa84b81 [DB Tsai] small change
f852bcd [DB Tsai] Remove duplicate method
f119fdc [DB Tsai] Formatting
97776aa [DB Tsai] address more feedback
85b4a91 [DB Tsai] address feedback
3cf50c2 [DB Tsai] LogisticRegressionWithLBFGS interface

(cherry picked from commit 6fab941b65f0cb6c9b32e0f8290d76889cda6a87)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8f6e2e9d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8f6e2e9d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8f6e2e9d

Branch: refs/heads/branch-1.1
Commit: 8f6e2e9df41e7de22b1d1cbd524e20881f861dd0
Parents: 7e31f7c
Author: DB Tsai <dbtsai@alpinenow.com>
Authored: Mon Aug 11 19:49:29 2014 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Mon Aug 11 19:49:43 2014 -0700

----------------------------------------------------------------------
 .../classification/LogisticRegression.scala     | 51 ++++++++++-
 .../LogisticRegressionSuite.scala               | 89 +++++++++++++++++++-
 2 files changed, 136 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8f6e2e9d/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 2242329..31d474a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -101,7 +101,7 @@ class LogisticRegressionWithSGD private (
 }
 
 /**
- * Top-level methods for calling Logistic Regression.
+ * Top-level methods for calling Logistic Regression using Stochastic Gradient Descent.
  * NOTE: Labels used in Logistic Regression should be {0, 1}
  */
 object LogisticRegressionWithSGD {
@@ -188,3 +188,52 @@ object LogisticRegressionWithSGD {
     train(input, numIterations, 1.0, 1.0)
   }
 }
+
+/**
+ * Train a classification model for Logistic Regression using Limited-memory BFGS.
+ * NOTE: Labels used in Logistic Regression should be {0, 1}
+ */
+class LogisticRegressionWithLBFGS private (
+    private var convergenceTol: Double,
+    private var maxNumIterations: Int,
+    private var regParam: Double)
+  extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {
+
+  /**
+   * Construct a LogisticRegression object with default parameters
+   */
+  def this() = this(1E-4, 100, 0.0)
+
+  private val gradient = new LogisticGradient()
+  private val updater = new SimpleUpdater()
+  // Have to return new LBFGS object every time since users can reset the parameters anytime.
+  override def optimizer = new LBFGS(gradient, updater)
+    .setNumCorrections(10)
+    .setConvergenceTol(convergenceTol)
+    .setMaxNumIterations(maxNumIterations)
+    .setRegParam(regParam)
+
+  override protected val validators = List(DataValidators.binaryLabelValidator)
+
+  /**
+   * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4.
+   * Smaller value will lead to higher accuracy with the cost of more iterations.
+   */
+  def setConvergenceTol(convergenceTol: Double): this.type = {
+    this.convergenceTol = convergenceTol
+    this
+  }
+
+  /**
+   * Set the maximal number of iterations for L-BFGS. Default 100.
+   */
+  def setNumIterations(numIterations: Int): this.type = {
+    this.maxNumIterations = numIterations
+    this
+  }
+
+  override protected def createModel(weights: Vector, intercept: Double) = {
+    new LogisticRegressionModel(weights, intercept)
+  }
+
+}
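As the inline comment above notes, the optimizer override builds a fresh LBFGS instance on every access so that values changed through the setters take effect on the next run. A hypothetical configuration sketch (the tolerance and iteration count below are illustrative, not values from the commit):

    // Tighter tolerance trades extra iterations for accuracy; both
    // setters return this.type, so the calls chain.
    val lr = new LogisticRegressionWithLBFGS()
      .setConvergenceTol(1e-6)
      .setNumIterations(50)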

http://git-wip-us.apache.org/repos/asf/spark/blob/8f6e2e9d/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index da7c633..2289c6c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -67,7 +67,7 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
   }
 
   // Test if we can correctly learn A, B where Y = logistic(A + B*X)
-  test("logistic regression") {
+  test("logistic regression with SGD") {
     val nPoints = 10000
     val A = 2.0
     val B = -1.5
@@ -94,7 +94,36 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
     validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
   }
 
-  test("logistic regression with initial weights") {
+  // Test if we can correctly learn A, B where Y = logistic(A + B*X)
+  test("logistic regression with LBFGS") {
+    val nPoints = 10000
+    val A = 2.0
+    val B = -1.5
+
+    val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42)
+
+    val testRDD = sc.parallelize(testData, 2)
+    testRDD.cache()
+    val lr = new LogisticRegressionWithLBFGS().setIntercept(true)
+
+    val model = lr.run(testRDD)
+
+    // Test the weights
+    assert(model.weights(0) ~== -1.52 relTol 0.01)
+    assert(model.intercept ~== 2.00 relTol 0.01)
+    assert(model.weights(0) ~== model.weights(0) relTol 0.01)
+    assert(model.intercept ~== model.intercept relTol 0.01)
+
+    val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
+    val validationRDD = sc.parallelize(validationData, 2)
+    // Test prediction on RDD.
+    validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
+
+    // Test prediction on Array.
+    validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
+  }
+
+  test("logistic regression with initial weights with SGD") {
     val nPoints = 10000
     val A = 2.0
     val B = -1.5
@@ -125,11 +154,42 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
     // Test prediction on Array.
     validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
   }
+
+  test("logistic regression with initial weights with LBFGS") {
+    val nPoints = 10000
+    val A = 2.0
+    val B = -1.5
+
+    val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42)
+
+    val initialB = -1.0
+    val initialWeights = Vectors.dense(initialB)
+
+    val testRDD = sc.parallelize(testData, 2)
+    testRDD.cache()
+
+    // Use half as many iterations as the previous test.
+    val lr = new LogisticRegressionWithLBFGS().setIntercept(true)
+
+    val model = lr.run(testRDD, initialWeights)
+
+    // Test the weights
+    assert(model.weights(0) ~== -1.50 relTol 0.02)
+    assert(model.intercept ~== 1.97 relTol 0.02)
+
+    val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
+    val validationRDD = sc.parallelize(validationData, 2)
+    // Test prediction on RDD.
+    validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
+
+    // Test prediction on Array.
+    validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
+  }
 }
 
 class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {
 
-  test("task size should be small in both training and prediction") {
+  test("task size should be small in both training and prediction using SGD optimizer") {
     val m = 4
     val n = 200000
     val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
@@ -139,6 +199,29 @@ class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkCont
     // If we serialize data directly in the task closure, the size of the serialized task would be
     // greater than 1MB and hence Spark would throw an error.
     val model = LogisticRegressionWithSGD.train(points, 2)
+
     val predictions = model.predict(points.map(_.features))
+
+    // Materialize the RDDs
+    predictions.count()
   }
+
+  test("task size should be small in both training and prediction using LBFGS optimizer")
{
+    val m = 4
+    val n = 200000
+    val points = sc.parallelize(0 until m, 2).mapPartitionsWithIndex { (idx, iter) =>
+      val random = new Random(idx)
+      iter.map(i => LabeledPoint(1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
+    }.cache()
+    // If we serialize data directly in the task closure, the size of the serialized task would be
+    // greater than 1MB and hence Spark would throw an error.
+    val model =
+      (new LogisticRegressionWithLBFGS().setIntercept(true).setNumIterations(2)).run(points)
+
+    val predictions = model.predict(points.map(_.features))
+
+    // Materialize the RDDs
+    predictions.count()
+  }
+
 }
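The suites above exercise both prediction paths the model exposes. As a rough sketch of the distinction, assuming model, validationRDD, and validationData as defined in the tests:

    // Bulk path: predict over an RDD of feature vectors; the result is
    // an RDD[Double] that stays lazy until an action materializes it.
    val bulkPredictions = model.predict(validationRDD.map(_.features))
    bulkPredictions.count()

    // Local path: score one Vector at a time on the driver.
    val localPredictions = validationData.map(row => model.predict(row.features))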



