spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From hol...@apache.org
Subject spark git commit: [SPARK-14712][ML] LogisticRegressionModel.toString should summarize model
Date Thu, 28 Jun 2018 19:41:05 GMT
Repository: spark
Updated Branches:
  refs/heads/master 5b0596648 -> 524827f06


[SPARK-14712][ML] LogisticRegressionModel.toString should summarize model

## What changes were proposed in this pull request?

[SPARK-14712](https://issues.apache.org/jira/browse/SPARK-14712)
spark.mllib LogisticRegressionModel overrides toString to print a little model info. We should
do the same in spark.ml and override `__repr__` in pyspark.

## How was this patch tested?

LogisticRegressionSuite.scala
Python doctest in pyspark.ml.classification.py

Author: bravo-zhang <mzhang1230@gmail.com>

Closes #18826 from bravo-zhang/spark-14712.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/524827f0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/524827f0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/524827f0

Branch: refs/heads/master
Commit: 524827f0626281847582ec3056982db7eb83f8b1
Parents: 5b05966
Author: bravo-zhang <mzhang1230@gmail.com>
Authored: Thu Jun 28 12:40:39 2018 -0700
Committer: Holden Karau <holden@pigscanfly.ca>
Committed: Thu Jun 28 12:40:39 2018 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/classification/LogisticRegression.scala    | 5 +++++
 .../spark/ml/classification/LogisticRegressionSuite.scala      | 6 ++++++
 python/pyspark/ml/classification.py                            | 5 +++++
 python/pyspark/mllib/classification.py                         | 3 +++
 4 files changed, 19 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/524827f0/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 06ca37b..92e342e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1202,6 +1202,11 @@ class LogisticRegressionModel private[spark] (
    */
   @Since("1.6.0")
   override def write: MLWriter = new LogisticRegressionModel.LogisticRegressionModelWriter(this)
+
+  override def toString: String = {
+    s"LogisticRegressionModel: " +
+    s"uid = ${super.toString}, numClasses = $numClasses, numFeatures = $numFeatures"
+  }
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/524827f0/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 36b7e51..75c2aeb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -2751,6 +2751,12 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
       assert(model.getFamily === family)
     }
   }
+
+  test("toString") {
+    val model = new LogisticRegressionModel("logReg", Vectors.dense(0.1, 0.2, 0.3), 0.0)
+    val expected = "LogisticRegressionModel: uid = logReg, numClasses = 2, numFeatures = 3"
+    assert(model.toString === expected)
+  }
 }
 
 object LogisticRegressionSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/524827f0/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 1754c48..d5963f4 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -239,6 +239,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
     True
     >>> blorModel.intercept == model2.intercept
     True
+    >>> model2
+    LogisticRegressionModel: uid = ..., numClasses = 2, numFeatures = 2
 
     .. versionadded:: 1.3.0
     """
@@ -562,6 +564,9 @@ class LogisticRegressionModel(JavaModel, JavaClassificationModel, JavaMLWritable
         java_blr_summary = self._call_java("evaluate", dataset)
         return BinaryLogisticRegressionSummary(java_blr_summary)
 
+    def __repr__(self):
+        return self._call_java("toString")
+
 
 class LogisticRegressionSummary(JavaWrapper):
     """

http://git-wip-us.apache.org/repos/asf/spark/blob/524827f0/python/pyspark/mllib/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index bb28198..e00ed95 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -258,6 +258,9 @@ class LogisticRegressionModel(LinearClassificationModel):
         model.setThreshold(threshold)
         return model
 
+    def __repr__(self):
+        return self._call_java("toString")
+
 
 class LogisticRegressionWithSGD(object):
     """


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message