spark-commits mailing list archives

From m...@apache.org
Subject spark git commit: [SPARK-13047][PYSPARK][ML] Pyspark Params.hasParam should not throw an error
Date Fri, 12 Feb 2016 00:42:56 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.6 91a5ca5e8 -> 9d45ec466


[SPARK-13047][PYSPARK][ML] Pyspark Params.hasParam should not throw an error

The PySpark `Params` class has a method `hasParam(paramName)` that returns `True` if the class
has a parameter by that name, but throws an `AttributeError` otherwise. There is currently no
way to get a Boolean indicating whether a class has a given parameter. With Spark 2.0 we could
modify the existing behavior of `hasParam` or add an additional method with this functionality.

In Python:
```python
from pyspark.ml.classification import NaiveBayes
nb = NaiveBayes()
print(nb.hasParam("smoothing"))
print(nb.hasParam("notAParam"))
```
produces:
> True
> AttributeError: 'NaiveBayes' object has no attribute 'notAParam'

However, in Scala:
```scala
import org.apache.spark.ml.classification.NaiveBayes
val nb = new NaiveBayes()
nb.hasParam("smoothing")
nb.hasParam("notAParam")
```
produces:
> true
> false
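
The patched Python behavior can be illustrated without a Spark installation. The following is a minimal sketch, not the PySpark source: `Param`, `Params`, and `ToyEstimator` are hypothetical stand-ins that keep only what the check needs, and the body of `hasParam` follows the logic in the diff below.

```python
class Param(object):
    """Hypothetical stand-in for pyspark.ml.param.Param (keeps only a name)."""

    def __init__(self, name):
        self.name = name


class Params(object):
    """Hypothetical stand-in for pyspark.ml.param.Params."""

    def hasParam(self, paramName):
        # Look the attribute up with a default so that an unknown name
        # yields None instead of raising AttributeError.
        if isinstance(paramName, str):
            p = getattr(self, paramName, None)
            return isinstance(p, Param)
        else:
            raise TypeError("hasParam(): paramName must be a string")


class ToyEstimator(Params):
    """Hypothetical estimator with one param and one ordinary attribute."""

    def __init__(self):
        self.smoothing = Param("smoothing")
        self.uid = "toy_1"  # an attribute that is not a Param


est = ToyEstimator()
results = (
    est.hasParam("smoothing"),  # a real param
    est.hasParam("notAParam"),  # no such attribute
    est.hasParam("uid"),        # attribute exists but is not a Param
)
print(results)  # (True, False, False)
```

Passing a `Param` object instead of its name raises `TypeError` in this sketch, which is why the updated tests call `hasParam(maxIter.name)` rather than `hasParam(maxIter)`.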

cc holdenk

Author: sethah <seth.hendrickson16@gmail.com>

Closes #10962 from sethah/SPARK-13047.

(cherry picked from commit b35467388612167f0bc3d17142c21a406f6c620d)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9d45ec46
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9d45ec46
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9d45ec46

Branch: refs/heads/branch-1.6
Commit: 9d45ec466a4067bb2d0b59ff1174bec630daa7b1
Parents: 91a5ca5
Author: sethah <seth.hendrickson16@gmail.com>
Authored: Thu Feb 11 16:42:44 2016 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Thu Feb 11 16:42:52 2016 -0800

----------------------------------------------------------------------
 python/pyspark/ml/param/__init__.py | 7 +++++--
 python/pyspark/ml/tests.py          | 9 +++++++--
 2 files changed, 12 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9d45ec46/python/pyspark/ml/param/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 35c9b77..666689c 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -156,8 +156,11 @@ class Params(Identifiable):
         Tests whether this instance contains a param with a given
         (string) name.
         """
-        param = self._resolveParam(paramName)
-        return param in self.params
+        if isinstance(paramName, str):
+            p = getattr(self, paramName, None)
+            return isinstance(p, Param)
+        else:
+            raise TypeError("hasParam(): paramName must be a string")
 
     @since("1.4.0")
     def getOrDefault(self, param):

http://git-wip-us.apache.org/repos/asf/spark/blob/9d45ec46/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 7a16cf5..674dbe9 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -170,6 +170,11 @@ class ParamTests(PySparkTestCase):
         self.assertEqual(maxIter.doc, "max number of iterations (>= 0).")
         self.assertTrue(maxIter.parent == testParams.uid)
 
+    def test_hasparam(self):
+        testParams = TestParams()
+        self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params]))
+        self.assertFalse(testParams.hasParam("notAParameter"))
+
     def test_params(self):
         testParams = TestParams()
         maxIter = testParams.maxIter
@@ -179,7 +184,7 @@ class ParamTests(PySparkTestCase):
         params = testParams.params
         self.assertEqual(params, [inputCol, maxIter, seed])
 
-        self.assertTrue(testParams.hasParam(maxIter))
+        self.assertTrue(testParams.hasParam(maxIter.name))
         self.assertTrue(testParams.hasDefault(maxIter))
         self.assertFalse(testParams.isSet(maxIter))
         self.assertTrue(testParams.isDefined(maxIter))
@@ -188,7 +193,7 @@ class ParamTests(PySparkTestCase):
         self.assertTrue(testParams.isSet(maxIter))
         self.assertEqual(testParams.getMaxIter(), 100)
 
-        self.assertTrue(testParams.hasParam(inputCol))
+        self.assertTrue(testParams.hasParam(inputCol.name))
         self.assertFalse(testParams.hasDefault(inputCol))
         self.assertFalse(testParams.isSet(inputCol))
         self.assertFalse(testParams.isDefined(inputCol))

