spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-14565][ML] RandomForest should use parseInt and parseDouble for feature subset size instead of regexes
Date Fri, 15 Apr 2016 00:23:19 GMT
Repository: spark
Updated Branches:
  refs/heads/master d7e124edf -> 01dd1f5c0


[SPARK-14565][ML] RandomForest should use parseInt and parseDouble for feature subset size instead of regexes

## What changes were proposed in this pull request?

This fix tries to change RandomForest's supported strategies from using regexes to using parseInt and parseDouble, for the purpose of robustness and maintainability.

## How was this patch tested?

Existing tests passed.

Author: Yong Tang <yong.tang.github@outlook.com>

Closes #12360 from yongtang/SPARK-14565.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/01dd1f5c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/01dd1f5c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/01dd1f5c

Branch: refs/heads/master
Commit: 01dd1f5c07f5c9ba91389c1556f911b028475cd3
Parents: d7e124e
Author: Yong Tang <yong.tang.github@outlook.com>
Authored: Thu Apr 14 17:23:16 2016 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Thu Apr 14 17:23:16 2016 -0700

----------------------------------------------------------------------
 .../spark/ml/tree/impl/DecisionTreeMetadata.scala  | 17 +++++++++++++----
 .../org/apache/spark/ml/tree/treeParams.scala      | 11 ++++++-----
 .../org/apache/spark/mllib/tree/RandomForest.scala |  6 ++++--
 .../spark/ml/tree/impl/RandomForestSuite.scala     |  4 ++--
 4 files changed, 25 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/01dd1f5c/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index c7cde15..5f7c40f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -18,8 +18,10 @@
 package org.apache.spark.ml.tree.impl
 
 import scala.collection.mutable
+import scala.util.Try
 
 import org.apache.spark.internal.Logging
+import org.apache.spark.ml.tree.RandomForestParams
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
@@ -184,15 +186,22 @@ private[spark] object DecisionTreeMetadata extends Logging {
       case _ => featureSubsetStrategy
     }
 
-    val isIntRegex = "^([1-9]\\d*)$".r
-    val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r
     val numFeaturesPerNode: Int = _featureSubsetStrategy match {
       case "all" => numFeatures
       case "sqrt" => math.sqrt(numFeatures).ceil.toInt
       case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
       case "onethird" => (numFeatures / 3.0).ceil.toInt
-      case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt
-      case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt
+      case _ =>
+        Try(_featureSubsetStrategy.toInt).filter(_ > 0).toOption match {
+          case Some(value) => math.min(value, numFeatures)
+          case None =>
+            Try(_featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).toOption match {
+              case Some(value) => math.ceil(value * numFeatures).toInt
+              case _ => throw new IllegalArgumentException(s"Supported values:" +
+                s" ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
+                s" (0.0-1.0], [1-n].")
+            }
+        }
     }
 
     new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,

http://git-wip-us.apache.org/repos/asf/spark/blob/01dd1f5c/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index b678391..d7559f8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.tree
 
+import scala.util.Try
+
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
@@ -346,10 +348,12 @@ private[ml] trait HasFeatureSubsetStrategy extends Params {
    */
   final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
     "The number of features to consider for splits at each tree node." +
-      s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
+      s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}" +
+      s", (0.0-1.0], [1-n].",
     (value: String) =>
       RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase)
-      || value.matches(RandomForestParams.supportedFeatureSubsetStrategiesRegex))
+      || Try(value.toInt).filter(_ > 0).isSuccess
+      || Try(value.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess)
 
   setDefault(featureSubsetStrategy -> "auto")
 
@@ -396,9 +400,6 @@ private[spark] object RandomForestParams {
   // These options should be lowercase.
   final val supportedFeatureSubsetStrategies: Array[String] =
     Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
-
-  // The regex to capture "(0.0-1.0]", and "n" for integer 0 < n <= (number of features)
-  final val supportedFeatureSubsetStrategiesRegex = "^(?:[1-9]\\d*|0?\\.\\d*[1-9]\\d*|1\\.0+)$"
 }
 
 private[ml] trait RandomForestClassifierParams

http://git-wip-us.apache.org/repos/asf/spark/blob/01dd1f5c/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 2675584..ca7fb7f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.mllib.tree
 
 import scala.collection.JavaConverters._
+import scala.util.Try
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.JavaRDD
@@ -76,9 +77,10 @@ private class RandomForest (
   strategy.assertValid()
  require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.")
   require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy)
-    || featureSubsetStrategy.matches(NewRFParams.supportedFeatureSubsetStrategiesRegex),
+    || Try(featureSubsetStrategy.toInt).filter(_ > 0).isSuccess
+    || Try(featureSubsetStrategy.toDouble).filter(_ > 0).filter(_ <= 1.0).isSuccess,
     s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." +
-    s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}," +
+    s" Supported values: ${NewRFParams.supportedFeatureSubsetStrategies.mkString(", ")}," +
     s" (0.0-1.0], [1-n].")
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/01dd1f5c/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 6db9ce1..1719f9f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -440,7 +440,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     val invalidStrategies = Array("-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0")
     for (invalidStrategy <- invalidStrategies) {
-      intercept[MatchError]{
+      intercept[IllegalArgumentException]{
         val metadata =
           DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 1, invalidStrategy)
       }
@@ -463,7 +463,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       checkFeatureSubsetStrategy(numTrees = 2, strategy, expected)
     }
     for (invalidStrategy <- invalidStrategies) {
-      intercept[MatchError]{
+      intercept[IllegalArgumentException]{
         val metadata =
           DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees = 2, invalidStrategy)
       }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message