spark-commits mailing list archives

From yli...@apache.org
Subject spark git commit: [SPARK-20631][PYTHON][ML] LogisticRegression._checkThresholdConsistency should use values not Params
Date Wed, 10 May 2017 08:58:16 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.2 ef50a9548 -> 3ed2f4d51


[SPARK-20631][PYTHON][ML] LogisticRegression._checkThresholdConsistency should use values not Params

## What changes were proposed in this pull request?

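- Replace `getParam` calls, which return `Param` objects, with `getOrDefault` calls, which return the param values.
- Fix the exception message: `",".join(ts)` raised an unintended `TypeError` because `ts` holds floats.
- Add unit tests.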
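The `TypeError` in the second bullet comes from the old message expression `",".join(ts)`. `str.join` accepts only strings, so with a float-valued `thresholds` list, building the `ValueError` message itself failed before the intended error could be raised. The plain-Python sketch below (no Spark required) shows the difference. The helper names `old_message` and `new_message` are hypothetical and only mirror the two message expressions in the diff.

```python
# Hypothetical helpers mirroring the message expressions before and after the patch.
def old_message(ts):
    # Pre-patch: str.join accepts only strings, so float thresholds raise TypeError.
    return " thresholds: " + ",".join(ts)


def new_message(ts):
    # Post-patch: format the list via str(), which works for any element type.
    return " thresholds: {0}".format(str(ts))


ts = [0.2, 0.3, 0.5]  # example thresholds with length != 2

try:
    old_message(ts)
except TypeError as exc:
    print("old message failed with TypeError:", exc)

message = new_message(ts)  # returns " thresholds: [0.2, 0.3, 0.5]"
print(message)
```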

## How was this patch tested?

New unit tests.

Author: zero323 <zero323@users.noreply.github.com>

Closes #17891 from zero323/SPARK-20631.

(cherry picked from commit 804949c6bf00b8e26c39d48bbcc4d0470ee84e47)
Signed-off-by: Yanbo Liang <ybliang8@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3ed2f4d5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3ed2f4d5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3ed2f4d5

Branch: refs/heads/branch-2.2
Commit: 3ed2f4d516ce02dfef929195778f8214703913d8
Parents: ef50a95
Author: zero323 <zero323@users.noreply.github.com>
Authored: Wed May 10 16:57:52 2017 +0800
Committer: Yanbo Liang <ybliang8@gmail.com>
Committed: Wed May 10 16:58:08 2017 +0800

----------------------------------------------------------------------
 python/pyspark/ml/classification.py |  6 +++---
 python/pyspark/ml/tests.py          | 12 ++++++++++++
 2 files changed, 15 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3ed2f4d5/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index a9756ea..dcc12d9 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -349,13 +349,13 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
 
     def _checkThresholdConsistency(self):
         if self.isSet(self.threshold) and self.isSet(self.thresholds):
-            ts = self.getParam(self.thresholds)
+            ts = self.getOrDefault(self.thresholds)
             if len(ts) != 2:
                 raise ValueError("Logistic Regression getThreshold only applies to" +
                                  " binary classification, but thresholds has length != 2." +
-                                 " thresholds: " + ",".join(ts))
+                                 " thresholds: {0}".format(str(ts)))
             t = 1.0/(1.0 + ts[0]/ts[1])
-            t2 = self.getParam(self.threshold)
+            t2 = self.getOrDefault(self.threshold)
             if abs(t2 - t) >= 1E-5:
                 raise ValueError("Logistic Regression getThreshold found inconsistent values for" +
                                  " threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
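The patched method compares two ways of expressing a binary decision threshold. For `thresholds = [t0, t1]` it computes the equivalent scalar threshold `1.0/(1.0 + t0/t1)`, which equals `t1/(t0 + t1)`, and treats the two settings as consistent when they differ by less than `1E-5`. The sketch below is a plain-Python mirror of that logic. It operates on values (what `getOrDefault` returns) rather than `Param` objects. The function name `check_threshold_consistency` and the use of `None` to stand in for "not set" are assumptions for illustration; neither is part of the PySpark API.

```python
def check_threshold_consistency(threshold, thresholds, tol=1e-5):
    """Hypothetical plain-Python mirror of the patched check.

    Works on plain values, which is what getOrDefault returns, instead of
    Param objects, which is what getParam returns.
    """
    if threshold is None or thresholds is None:
        # Stand-in for isSet(...) being False: nothing to compare.
        return
    if len(thresholds) != 2:
        raise ValueError(
            "Logistic Regression getThreshold only applies to binary classification,"
            " but thresholds has length != 2. thresholds: {0}".format(str(thresholds)))
    # Scalar threshold equivalent to thresholds [t0, t1].
    t = 1.0 / (1.0 + thresholds[0] / thresholds[1])
    if abs(threshold - t) >= tol:
        raise ValueError(
            "Logistic Regression getThreshold found inconsistent values for"
            " threshold (%g) and thresholds (equivalent to %g)" % (threshold, t))
```

For example, `check_threshold_consistency(0.5, [0.5, 0.5])` passes because `1/(1 + 0.5/0.5) = 0.5`, while `check_threshold_consistency(0.42, [0.5, 0.5])` raises `ValueError`. These are the same two cases exercised by the new test.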

http://git-wip-us.apache.org/repos/asf/spark/blob/3ed2f4d5/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 571ac4b..51a3e8e 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -807,6 +807,18 @@ class PersistenceTest(SparkSessionTestCase):
         except OSError:
             pass
 
+    def logistic_regression_check_thresholds(self):
+        self.assertIsInstance(
+            LogisticRegression(threshold=0.5, thresholds=[0.5, 0.5]),
+            LogisticRegressionModel
+        )
+
+        self.assertRaisesRegexp(
+            ValueError,
+            "Logistic Regression getThreshold found inconsistent.*$",
+            LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]
+        )
+
     def _compare_params(self, m1, m2, param):
         """
         Compare 2 ML Params instances for the given param, and assert both have the same param value
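Python's `unittest` loader collects only methods whose names start with `test`, which is the default `TestLoader.testMethodPrefix`. As added in this hunk, `logistic_regression_check_thresholds` therefore would not run under default discovery. If it were collected, its first assertion would also compare a `LogisticRegression` estimator against `LogisticRegressionModel`, which is a different class. The sketch below only demonstrates the naming rule on a hypothetical `TestCase` and does not import Spark.

```python
import unittest


class Example(unittest.TestCase):
    # Hypothetical methods mirroring the naming used in the patch.
    def logistic_regression_check_thresholds(self):  # no "test" prefix: skipped
        pass

    def test_logistic_regression_check_thresholds(self):  # collected by default
        pass


# getTestCaseNames returns the sorted names that match testMethodPrefix ("test").
names = unittest.TestLoader().getTestCaseNames(Example)
print(names)  # ['test_logistic_regression_check_thresholds']
```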



