spark-commits mailing list archives

From m...@apache.org
Subject spark git commit: [SPARK-4306] [MLlib] Python API for LogisticRegressionWithLBFGS
Date Tue, 18 Nov 2014 23:58:13 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.2 a93d64c8c -> 4ae78abe6


[SPARK-4306] [MLlib] Python API for LogisticRegressionWithLBFGS

```
class LogisticRegressionWithLBFGS
 |  train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2",
 |        intercept=False, corrections=10, tolerance=1e-4)
 |      Train a logistic regression model on the given data.
 |
 |      :param data:           The training data, an RDD of LabeledPoint.
 |      :param iterations:     The number of iterations (default: 100).
 |      :param initialWeights: The initial weights (default: None).
 |      :param regParam:       The regularizer parameter (default: 0.01).
 |      :param regType:        The type of regularizer used for training
 |                             our model.
 |                             :Allowed values:
 |                               - "l1" for using L1 regularization
 |                               - "l2" for using L2 regularization
 |                               - None for no regularization
 |                               (default: "l2")
 |      :param intercept:      Boolean parameter which indicates the use
 |                             or not of the augmented representation for
 |                             training data (i.e. whether bias features
 |                             are activated or not).
 |      :param corrections:    The number of corrections used in the LBFGS
 |                             update (default: 10).
 |      :param tolerance:      The convergence tolerance of iterations for
 |                             L-BFGS (default: 1e-4).
 |
 |      >>> data = [
 |      ...     LabeledPoint(0.0, [0.0, 1.0]),
 |      ...     LabeledPoint(1.0, [1.0, 0.0]),
 |      ... ]
 |      >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data))
 |      >>> lrm.predict([1.0, 0.0])
 |      1
 |      >>> lrm.predict([0.0, 1.0])
 |      0
 |      >>> lrm.predict(sc.parallelize([[1.0, 0.0], [0.0, 1.0]])).collect()
 |      [1, 0]
```
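
A minimal usage sketch of the new Python API, assuming a running SparkContext named `sc` and a build that includes this patch; the data and parameter values below are illustrative only:

```
from pyspark.mllib.classification import LogisticRegressionWithLBFGS
from pyspark.mllib.regression import LabeledPoint

# Toy training set: the label follows the first feature.
points = sc.parallelize([
    LabeledPoint(0.0, [0.0, 1.0]),
    LabeledPoint(1.0, [1.0, 0.0]),
])

# Default settings: L2 regularization, 100 iterations, tolerance 1e-4.
model_l2 = LogisticRegressionWithLBFGS.train(points)

# L1 regularization with a custom penalty and a tighter tolerance.
model_l1 = LogisticRegressionWithLBFGS.train(
    points, iterations=200, regParam=0.1, regType="l1", tolerance=1e-6)

print(model_l2.predict([1.0, 0.0]))   # expected: 1
print(model_l1.predict([0.0, 1.0]))   # expected: 0
```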

Author: Davies Liu <davies@databricks.com>

Closes #3307 from davies/lbfgs and squashes the following commits:

34bd986 [Davies Liu] Merge branch 'master' of http://git-wip-us.apache.org/repos/asf/spark into lbfgs
5a945a6 [Davies Liu] address comments
941061b [Davies Liu] Merge branch 'master' of github.com:apache/spark into lbfgs
03e5543 [Davies Liu] add it to docs
ed2f9a8 [Davies Liu] add regType
76cd1b6 [Davies Liu] reorder arguments
4429a74 [Davies Liu] Update classification.py
9252783 [Davies Liu] python api for LogisticRegressionWithLBFGS

(cherry picked from commit d2e29516f2064f93f3a9070c91fc7460706e0b0a)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4ae78abe
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4ae78abe
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4ae78abe

Branch: refs/heads/branch-1.2
Commit: 4ae78abe66e593ac8bf9de37eca80413730c431b
Parents: a93d64c
Author: Davies Liu <davies@databricks.com>
Authored: Tue Nov 18 15:57:33 2014 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Tue Nov 18 15:57:48 2014 -0800

----------------------------------------------------------------------
 .../spark/mllib/api/python/PythonMLLibAPI.scala | 35 ++++++++++++
 python/pyspark/mllib/classification.py          | 57 ++++++++++++++++++--
 2 files changed, 88 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4ae78abe/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index c8476a5..6f94b7f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -230,6 +230,41 @@ class PythonMLLibAPI extends Serializable {
   }
 
   /**
+   * Java stub for Python mllib LogisticRegressionWithLBFGS.train()
+   */
+  def trainLogisticRegressionModelWithLBFGS(
+      data: JavaRDD[LabeledPoint],
+      numIterations: Int,
+      initialWeights: Vector,
+      regParam: Double,
+      regType: String,
+      intercept: Boolean,
+      corrections: Int,
+      tolerance: Double): JList[Object] = {
+    val LogRegAlg = new LogisticRegressionWithLBFGS()
+    LogRegAlg.setIntercept(intercept)
+    LogRegAlg.optimizer
+      .setNumIterations(numIterations)
+      .setRegParam(regParam)
+      .setNumCorrections(corrections)
+      .setConvergenceTol(tolerance)
+    if (regType == "l2") {
+      LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
+    } else if (regType == "l1") {
+      LogRegAlg.optimizer.setUpdater(new L1Updater)
+    } else if (regType == null) {
+      LogRegAlg.optimizer.setUpdater(new SimpleUpdater)
+    } else {
+      throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+        + " Can only be initialized using the following string values: ['l1', 'l2', None].")
+    }
+    trainRegressionModel(
+      LogRegAlg,
+      data,
+      initialWeights)
+  }
+
+  /**
    * Java stub for NaiveBayes.train()
    */
   def trainNaiveBayes(

http://git-wip-us.apache.org/repos/asf/spark/blob/4ae78abe/python/pyspark/mllib/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index ee0729b..f14d0ed 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -26,8 +26,8 @@ from pyspark.mllib.linalg import SparseVector, _convert_to_vector
 from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
 
 
-__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'SVMModel',
-           'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
+__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS',
+           'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
 
 
 class LinearBinaryClassificationModel(LinearModel):
@@ -151,7 +151,7 @@ class LogisticRegressionWithSGD(object):
 
                                      (default: "l2")
 
-        @param intercept:         Boolean parameter which indicates the use
+        :param intercept:         Boolean parameter which indicates the use
                                   or not of the augmented representation for
                                   training data (i.e. whether bias features
                                   are activated or not).
@@ -164,6 +164,55 @@ class LogisticRegressionWithSGD(object):
         return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
 
 
+class LogisticRegressionWithLBFGS(object):
+
+    @classmethod
+    def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2",
+              intercept=False, corrections=10, tolerance=1e-4):
+        """
+        Train a logistic regression model on the given data.
+
+        :param data:           The training data, an RDD of LabeledPoint.
+        :param iterations:     The number of iterations (default: 100).
+        :param initialWeights: The initial weights (default: None).
+        :param regParam:       The regularizer parameter (default: 0.01).
+        :param regType:        The type of regularizer used for training
+                               our model.
+
+                               :Allowed values:
+                                 - "l1" for using L1 regularization
+                                 - "l2" for using L2 regularization
+                                 - None for no regularization
+
+                                 (default: "l2")
+
+        :param intercept:      Boolean parameter which indicates the use
+                               or not of the augmented representation for
+                               training data (i.e. whether bias features
+                               are activated or not).
+        :param corrections:    The number of corrections used in the LBFGS
+                               update (default: 10).
+        :param tolerance:      The convergence tolerance of iterations for
+                               L-BFGS (default: 1e-4).
+
+        >>> data = [
+        ...     LabeledPoint(0.0, [0.0, 1.0]),
+        ...     LabeledPoint(1.0, [1.0, 0.0]),
+        ... ]
+        >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data))
+        >>> lrm.predict([1.0, 0.0])
+        1
+        >>> lrm.predict([0.0, 1.0])
+        0
+        """
+        def train(rdd, i):
+            return callMLlibFunc("trainLogisticRegressionModelWithLBFGS", rdd, int(iterations), i,
+                                 float(regParam), str(regType), bool(intercept), int(corrections),
+                                 float(tolerance))
+
+        return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
+
+
 class SVMModel(LinearBinaryClassificationModel):
 
     """A support vector machine.
@@ -241,7 +290,7 @@ class SVMWithSGD(object):
 
                                      (default: "l2")
 
-        @param intercept:         Boolean parameter which indicates the use
+        :param intercept:         Boolean parameter which indicates the use
                                   or not of the augmented representation for
                                   training data (i.e. whether bias features
                                   are activated or not).

