spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-7022] [PYSPARK] [ML] Add ML.Tuning.ParamGridBuilder to PySpark
Date Sun, 03 May 2015 18:42:09 GMT
Repository: spark
Updated Branches:
  refs/heads/master 49549d5a1 -> f4af92550


[SPARK-7022] [PYSPARK] [ML] Add ML.Tuning.ParamGridBuilder to PySpark

Author: Omede Firouz <ofirouz@palantir.com>
Author: Omede <omedefirouz@gmail.com>

Closes #5601 from oefirouz/paramgrid and squashes the following commits:

c9e2481 [Omede Firouz] Make test a doctest
9a8ce22 [Omede] Fix linter issues
8b8a6d2 [Omede Firouz] [SPARK-7022][PySpark][ML] Add ML.Tuning.ParamGridBuilder to PySpark


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f4af9255
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f4af9255
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f4af9255

Branch: refs/heads/master
Commit: f4af92550cb90e47a12d4625fa615dd2b1587d42
Parents: 49549d5
Author: Omede Firouz <ofirouz@palantir.com>
Authored: Sun May 3 11:42:02 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Sun May 3 11:42:02 2015 -0700

----------------------------------------------------------------------
 python/pyspark/ml/tuning.py | 94 ++++++++++++++++++++++++++++++++++++++++
 python/run-tests            |  1 +
 2 files changed, 95 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f4af9255/python/pyspark/ml/tuning.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
new file mode 100644
index 0000000..a383bd0
--- /dev/null
+++ b/python/pyspark/ml/tuning.py
@@ -0,0 +1,94 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
__all__ = ['ParamGridBuilder']


class ParamGridBuilder(object):
    """
    Builder for a param grid used in grid search-based model selection.

    Each call to :py:meth:`addGrid` or :py:meth:`baseOn` registers a param
    together with the values it may take; registering the same param again
    replaces its earlier values. :py:meth:`build` returns every combination
    of the registered values as a list of param maps (dicts from param to
    value). The order of the returned maps follows the internal dict's
    iteration order and should not be relied on.

    >>> from pyspark.ml.classification import LogisticRegression
    >>> lr = LogisticRegression()
    >>> output = (ParamGridBuilder()
    ...     .baseOn({lr.labelCol: 'l'})
    ...     .baseOn([lr.predictionCol, 'p'])
    ...     .addGrid(lr.regParam, [1.0, 2.0, 3.0])
    ...     .addGrid(lr.maxIter, [1, 5])
    ...     .addGrid(lr.featuresCol, ['f'])
    ...     .build())
    >>> expected = [{lr.labelCol: 'l', lr.predictionCol: 'p', lr.featuresCol: 'f',
    ...              lr.regParam: reg, lr.maxIter: max_iter}
    ...             for max_iter in [1, 5] for reg in [1.0, 2.0, 3.0]]
    >>> len(output) == len(expected) and all(m in output for m in expected)
    True
    """

    def __init__(self):
        # Maps each registered param to the list of values it may take.
        self._param_grid = {}

    def addGrid(self, param, values):
        """
        Sets the given parameter to vary over the given values in this grid,
        replacing any values previously registered for that parameter.

        :param param: the param to vary. It is used as a dict key, so it
            must be hashable.
        :param values: an iterable of candidate values. It is copied into a
            list, so exhausting a generator or later mutating the caller's
            container does not change this grid.
        :return: this builder, to allow call chaining.
        """
        # Copy eagerly: build() iterates the values once per call, so a
        # one-shot iterator stored by reference would be empty on re-builds.
        self._param_grid[param] = list(values)
        return self

    def baseOn(self, *args):
        """
        Sets the given parameters in this grid to fixed values.

        Each positional argument is either a dict mapping params to values,
        or a single (param, value) pair given as a tuple or two-element list.
        Dicts and pairs may be mixed freely; calling with no arguments is a
        no-op.

        :return: this builder, to allow call chaining.
        """
        for arg in args:
            # A dict contributes all its items; anything else is one pair.
            pairs = arg.items() if isinstance(arg, dict) else [arg]
            for param, value in pairs:
                self.addGrid(param, [value])
        return self

    def build(self):
        """
        Builds and returns all combinations of parameters specified
        by the param grid.

        :return: a list of dicts, one per combination, each mapping every
            registered param to one of its values. An empty builder yields
            ``[{}]`` (one empty map); any param registered with no values
            makes the result ``[]``.
        """
        param_maps = [{}]
        for param, values in self._param_grid.items():
            # Extend every partial map built so far with each candidate
            # value of this param (Cartesian product, one param at a time).
            new_param_maps = []
            for value in values:
                for old_map in param_maps:
                    new_map = old_map.copy()
                    new_map[param] = value
                    new_param_maps.append(new_map)
            param_maps = new_param_maps
        return param_maps
+
+
if __name__ == "__main__":
    import doctest
    import sys
    # doctest.testmod() only reports failing examples and returns their
    # count; exit with a non-zero status so that a harness running this file
    # as a script (and checking its exit code) actually notices failures.
    results = doctest.testmod()
    if results.failed:
        sys.exit(-1)

http://git-wip-us.apache.org/repos/asf/spark/blob/f4af9255/python/run-tests
----------------------------------------------------------------------
diff --git a/python/run-tests b/python/run-tests
index 88b63b8..0e0eee3 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -98,6 +98,7 @@ function run_ml_tests() {
     echo "Run ml tests ..."
     run_test "pyspark/ml/feature.py"
     run_test "pyspark/ml/classification.py"
+    run_test "pyspark/ml/tuning.py"
     run_test "pyspark/ml/tests.py"
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message