Return-Path: X-Original-To: apmail-spark-commits-archive@minotaur.apache.org Delivered-To: apmail-spark-commits-archive@minotaur.apache.org Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by minotaur.apache.org (Postfix) with SMTP id 1C8B610A86 for ; Sun, 3 May 2015 18:42:10 +0000 (UTC) Received: (qmail 96511 invoked by uid 500); 3 May 2015 18:42:10 -0000 Delivered-To: apmail-spark-commits-archive@spark.apache.org Received: (qmail 96486 invoked by uid 500); 3 May 2015 18:42:09 -0000 Mailing-List: contact commits-help@spark.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Delivered-To: mailing list commits@spark.apache.org Received: (qmail 96477 invoked by uid 99); 3 May 2015 18:42:09 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Sun, 03 May 2015 18:42:09 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id C91D8DFF07; Sun, 3 May 2015 18:42:09 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: meng@apache.org To: commits@spark.apache.org Message-Id: <555bbb0a42694351aad985a29a738023@git.apache.org> X-Mailer: ASF-Git Admin Mailer Subject: spark git commit: [SPARK-7022] [PYSPARK] [ML] Add ML.Tuning.ParamGridBuilder to PySpark Date: Sun, 3 May 2015 18:42:09 +0000 (UTC) Repository: spark Updated Branches: refs/heads/master 49549d5a1 -> f4af92550 [SPARK-7022] [PYSPARK] [ML] Add ML.Tuning.ParamGridBuilder to PySpark Author: Omede Firouz Author: Omede Closes #5601 from oefirouz/paramgrid and squashes the following commits: c9e2481 [Omede Firouz] Make test a doctest 9a8ce22 [Omede] Fix linter issues 8b8a6d2 [Omede Firouz] [SPARK-7022][PySpark][ML] Add ML.Tuning.ParamGridBuilder to PySpark Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f4af9255 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f4af9255 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f4af9255 Branch: refs/heads/master Commit: f4af92550cb90e47a12d4625fa615dd2b1587d42 Parents: 49549d5 Author: Omede Firouz Authored: Sun May 3 11:42:02 2015 -0700 Committer: Xiangrui Meng Committed: Sun May 3 11:42:02 2015 -0700 ---------------------------------------------------------------------- python/pyspark/ml/tuning.py | 94 ++++++++++++++++++++++++++++++++++++++++ python/run-tests | 1 + 2 files changed, 95 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f4af9255/python/pyspark/ml/tuning.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py new file mode 100644 index 0000000..a383bd0 --- /dev/null +++ b/python/pyspark/ml/tuning.py @@ -0,0 +1,94 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +__all__ = ['ParamGridBuilder'] + + +class ParamGridBuilder(object): + """ + Builder for a param grid used in grid search-based model selection. + + >>> from classification import LogisticRegression + >>> lr = LogisticRegression() + >>> output = ParamGridBuilder().baseOn({lr.labelCol: 'l'}) \ + .baseOn([lr.predictionCol, 'p']) \ + .addGrid(lr.regParam, [1.0, 2.0, 3.0]) \ + .addGrid(lr.maxIter, [1, 5]) \ + .addGrid(lr.featuresCol, ['f']) \ + .build() + >>> expected = [ \ +{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ +{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ +{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 1, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ +{lr.regParam: 1.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ +{lr.regParam: 2.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}, \ +{lr.regParam: 3.0, lr.featuresCol: 'f', lr.maxIter: 5, lr.labelCol: 'l', lr.predictionCol: 'p'}] + >>> fail_count = 0 + >>> for e in expected: + ... if e not in output: + ... fail_count += 1 + >>> if len(expected) != len(output): + ... fail_count += 1 + >>> fail_count + 0 + """ + + def __init__(self): + self._param_grid = {} + + def addGrid(self, param, values): + """ + Sets the given parameters in this grid to fixed values. + """ + self._param_grid[param] = values + + return self + + def baseOn(self, *args): + """ + Sets the given parameters in this grid to fixed values. + Accepts either a parameter dictionary or a list of (parameter, value) pairs. + """ + if isinstance(args[0], dict): + self.baseOn(*args[0].items()) + else: + for (param, value) in args: + self.addGrid(param, [value]) + + return self + + def build(self): + """ + Builds and returns all combinations of parameters specified + by the param grid. + """ + param_maps = [{}] + for (param, values) in self._param_grid.items(): + new_param_maps = [] + for value in values: + for old_map in param_maps: + copied_map = old_map.copy() + copied_map[param] = value + new_param_maps.append(copied_map) + param_maps = new_param_maps + + return param_maps + + +if __name__ == "__main__": + import doctest + doctest.testmod() http://git-wip-us.apache.org/repos/asf/spark/blob/f4af9255/python/run-tests ---------------------------------------------------------------------- diff --git a/python/run-tests b/python/run-tests index 88b63b8..0e0eee3 100755 --- a/python/run-tests +++ b/python/run-tests @@ -98,6 +98,7 @@ function run_ml_tests() { echo "Run ml tests ..." run_test "pyspark/ml/feature.py" run_test "pyspark/ml/classification.py" + run_test "pyspark/ml/tuning.py" run_test "pyspark/ml/tests.py" } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org