spark-commits mailing list archives

From m...@apache.org
Subject spark git commit: [SPARK-5858][MLLIB] Remove unnecessary first() call in GLM
Date Tue, 17 Feb 2015 18:17:47 GMT
Repository: spark
Updated Branches:
  refs/heads/master 3ce46e94f -> c76da36c2


[SPARK-5858][MLLIB] Remove unnecessary first() call in GLM

`numFeatures` is only used by multinomial logistic regression. Calling `.first()` in every GLM
run launches an extra Spark job and causes a performance regression, especially in Python.

Author: Xiangrui Meng <meng@databricks.com>

Closes #4647 from mengxr/SPARK-5858 and squashes the following commits:

036dc7f [Xiangrui Meng] remove unnecessary first() call
12c5548 [Xiangrui Meng] check numFeatures only once


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c76da36c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c76da36c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c76da36c

Branch: refs/heads/master
Commit: c76da36c2163276b5c34e59fbb139eeb34ed0faa
Parents: 3ce46e9
Author: Xiangrui Meng <meng@databricks.com>
Authored: Tue Feb 17 10:17:45 2015 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Tue Feb 17 10:17:45 2015 -0800

----------------------------------------------------------------------
 .../spark/mllib/classification/LogisticRegression.scala       | 6 +++++-
 .../spark/mllib/regression/GeneralizedLinearAlgorithm.scala   | 7 ++++---
 2 files changed, 9 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c76da36c/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 420d6e2..b787667 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -355,6 +355,10 @@ class LogisticRegressionWithLBFGS
   }
 
   override protected def createModel(weights: Vector, intercept: Double) = {
-    new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1)
+    if (numOfLinearPredictor == 1) {
+      new LogisticRegressionModel(weights, intercept)
+    } else {
+      new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1)
+    }
   }
 }
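
The binary branch can skip `numFeatures` because, as far as I can tell from MLlib of this era, the two-argument `LogisticRegressionModel(weights, intercept)` constructor derives the feature count from the weight vector and fixes the number of classes at 2 (worth verifying against the source of your Spark version). The sketch below mimics the new `createModel` branch with a hypothetical, Spark-free `SketchLRModel` class and plain `Array[Double]` weights. It illustrates the idea and is not the MLlib code.

```scala
// Hypothetical stand-in for LogisticRegressionModel; NOT the Spark class.
// The auxiliary constructor derives numFeatures from the weights and fixes
// numClasses at 2, so the binary case never needs a precomputed numFeatures.
class SketchLRModel(
    val weights: Array[Double],
    val intercept: Double,
    val numFeatures: Int,
    val numClasses: Int) {
  def this(weights: Array[Double], intercept: Double) =
    this(weights, intercept, weights.length, 2)
}

object CreateModelSketch {
  // Mirrors the branch added to createModel in the diff above.
  def createModel(
      weights: Array[Double],
      intercept: Double,
      numFeatures: Int,
      numOfLinearPredictor: Int): SketchLRModel = {
    if (numOfLinearPredictor == 1) {
      // Binary case: numFeatures is not consulted at all.
      new SketchLRModel(weights, intercept)
    } else {
      // Multinomial case: numOfLinearPredictor + 1 classes.
      new SketchLRModel(weights, intercept, numFeatures, numOfLinearPredictor + 1)
    }
  }

  def main(args: Array[String]): Unit = {
    // numFeatures = -1 simulates "never computed"; the binary model is still well-formed.
    val binary = createModel(Array(0.5, -0.25), 0.1, -1, 1)
    println(s"binary: numFeatures=${binary.numFeatures}, numClasses=${binary.numClasses}")
    // prints "binary: numFeatures=2, numClasses=2"
    val multi = createModel(Array.fill(6)(0.0), 0.0, 3, 2)
    println(s"multinomial: numFeatures=${multi.numFeatures}, numClasses=${multi.numClasses}")
    // prints "multinomial: numFeatures=3, numClasses=3"
  }
}
```

Even with `numFeatures` left at the new `-1` default, the binary model reports a sensible feature count, because the count comes from the weights rather than from the training data.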

http://git-wip-us.apache.org/repos/asf/spark/blob/c76da36c/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 2b71453..7c66e8c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -126,7 +126,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
   /**
    * The dimension of training features.
    */
-  protected var numFeatures: Int = 0
+  protected var numFeatures: Int = -1
 
   /**
    * Set if the algorithm should use feature scaling to improve the convergence during optimization.
@@ -163,7 +163,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
    * RDD of LabeledPoint entries.
    */
   def run(input: RDD[LabeledPoint]): M = {
-    numFeatures = input.first().features.size
+    if (numFeatures < 0) {
+      numFeatures = input.map(_.features.size).first()
+    }
 
     /**
      * When `numOfLinearPredictor > 1`, the intercepts are encapsulated into weights,
@@ -193,7 +195,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
    * of LabeledPoint entries starting from the initial weights provided.
    */
   def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
-    numFeatures = input.first().features.size
 
     if (input.getStorageLevel == StorageLevel.NONE) {
       logWarning("The input data is not directly cached, which may hurt performance if its"
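
After this change, `run(input)` looks up the feature dimension only while `numFeatures` still holds the `-1` sentinel, and `run(input, initialWeights)` no longer looks it up at all. Mapping to the size before calling `first()` also means only an `Int`, rather than a whole `LabeledPoint`, comes back to the driver. The sketch below reproduces that control flow without Spark, under these assumptions: a `Seq` of `(label, features)` pairs stands in for the RDD, `SketchGLM`, `WeightCountGLM` and the `lookups` counter are hypothetical, and the initial-weight sizing (intercepts folded into the weights only when `numOfLinearPredictor > 1`) is inferred from the comment in the diff context, not from code shown in this commit.

```scala
// Hypothetical, Spark-free sketch of the two run() entry points after this commit.
// A Seq of (label, features) pairs stands in for RDD[LabeledPoint].
abstract class SketchGLM[M](numOfLinearPredictor: Int, addIntercept: Boolean) {
  // -1 means "not computed yet", mirroring the new default in the diff.
  protected var numFeatures: Int = -1
  // Hypothetical counter of how many times the dimension lookup actually ran.
  var lookups: Int = 0

  protected def createModel(weights: Array[Double], intercept: Double): M

  def run(input: Seq[(Double, Array[Double])]): M = {
    if (numFeatures < 0) {
      lookups += 1
      // Lazy iterator: only the first element's size is computed,
      // loosely like input.map(_.features.size).first() on an RDD.
      numFeatures = input.iterator.map(_._2.length).next()
    }
    // Assumed sizing: intercepts are folded into the weights only when
    // there is more than one linear predictor.
    val size =
      if (numOfLinearPredictor == 1) numFeatures
      else if (addIntercept) (numFeatures + 1) * numOfLinearPredictor
      else numFeatures * numOfLinearPredictor
    run(input, Array.fill(size)(0.0))
  }

  def run(input: Seq[(Double, Array[Double])], initialWeights: Array[Double]): M = {
    // No dimension lookup here any more; a real implementation would
    // optimize starting from initialWeights. This sketch passes them through.
    createModel(initialWeights, 0.0)
  }
}

// Toy subclass whose "model" is just the number of weights it received.
class WeightCountGLM(k: Int, withIntercept: Boolean)
    extends SketchGLM[Int](k, withIntercept) {
  override protected def createModel(weights: Array[Double], intercept: Double): Int =
    weights.length
}

object RunSketchDemo {
  def main(args: Array[String]): Unit = {
    val data = Seq((1.0, Array(0.1, 0.2, 0.3)), (0.0, Array(0.4, 0.5, 0.6)))
    val glm = new WeightCountGLM(2, true)
    val first = glm.run(data)
    val second = glm.run(data) // numFeatures is already set, so no second lookup
    println(s"weights=$first/$second, lookups=${glm.lookups}") // weights=8/8, lookups=1
  }
}
```

One consequence of caching the value on the algorithm instance: reusing the same instance on data of a different dimension would keep the first dimension it saw.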

