spark-commits mailing list archives

From jkbrad...@apache.org
Subject spark git commit: [SPARK-10097] Adds `shouldMaximize` flag to `ml.evaluation.Evaluator`
Date Wed, 19 Aug 2015 18:35:22 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.5 a8e880818 -> f25c32475


[SPARK-10097] Adds `shouldMaximize` flag to `ml.evaluation.Evaluator`

Previously, the users of an evaluator (`CrossValidator` and `TrainValidationSplit`) would always maximize
the metric the evaluator returned. This forced a hacky workaround: metrics meant to be minimized were
negated, causing erroneous negative values to be reported to the user.

This PR adds an `isLargerBetter` attribute to the `Evaluator` base class, which tells users
of `Evaluator` whether the chosen metric should be maximized or minimized.
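
For illustration, a minimal end-to-end sketch of how the flag plays out (the `LinearRegression`
estimator, the parameter grid, and `trainingDF` are assumptions for the example, not part of this patch):

    import org.apache.spark.ml.evaluation.RegressionEvaluator
    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

    val lr = new LinearRegression()
    val paramGrid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.01, 0.1))
      .build()

    // "rmse" reports isLargerBetter = false, so CrossValidator now minimizes
    // it directly instead of maximizing a negated value.
    val eval = new RegressionEvaluator().setMetricName("rmse")

    val cv = new CrossValidator()
      .setEstimator(lr)
      .setEvaluator(eval)
      .setEstimatorParamMaps(paramGrid)
      .setNumFolds(3)
    // val cvModel = cv.fit(trainingDF)  // trainingDF: an assumed training DataFrame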

CC jkbradley

Author: Feynman Liang <fliang@databricks.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #8290 from feynmanliang/SPARK-10097.

(cherry picked from commit 28a98464ea65aa7b35e24fca5ddaa60c2e5d53ee)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f25c3247
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f25c3247
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f25c3247

Branch: refs/heads/branch-1.5
Commit: f25c324758095ddf572157e64c8f1a93843f79c7
Parents: a8e8808
Author: Feynman Liang <fliang@databricks.com>
Authored: Wed Aug 19 11:35:05 2015 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Wed Aug 19 11:35:17 2015 -0700

----------------------------------------------------------------------
 .../BinaryClassificationEvaluator.scala         | 20 ++++++++++++--------
 .../apache/spark/ml/evaluation/Evaluator.scala  |  7 +++++++
 .../MulticlassClassificationEvaluator.scala     |  8 ++++++++
 .../ml/evaluation/RegressionEvaluator.scala     | 19 +++++++++++--------
 .../apache/spark/ml/tuning/CrossValidator.scala |  4 +++-
 .../spark/ml/tuning/TrainValidationSplit.scala  |  4 +++-
 .../evaluation/RegressionEvaluatorSuite.scala   |  4 ++--
 .../spark/ml/tuning/CrossValidatorSuite.scala   |  2 ++
 .../ml/tuning/TrainValidationSplitSuite.scala   |  2 ++
 python/pyspark/ml/evaluation.py                 |  4 ++--
 10 files changed, 52 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f25c3247/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 5d5cb7e..56419a0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -40,8 +40,11 @@ class BinaryClassificationEvaluator(override val uid: String)
    * param for metric name in evaluation
    * @group param
    */
-  val metricName: Param[String] = new Param(this, "metricName",
-    "metric name in evaluation (areaUnderROC|areaUnderPR)")
+  val metricName: Param[String] = {
+    val allowedParams = ParamValidators.inArray(Array("areaUnderROC", "areaUnderPR"))
+    new Param(
+      this, "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)", allowedParams)
+  }
 
   /** @group getParam */
   def getMetricName: String = $(metricName)
@@ -76,16 +79,17 @@ class BinaryClassificationEvaluator(override val uid: String)
       }
     val metrics = new BinaryClassificationMetrics(scoreAndLabels)
     val metric = $(metricName) match {
-      case "areaUnderROC" =>
-        metrics.areaUnderROC()
-      case "areaUnderPR" =>
-        metrics.areaUnderPR()
-      case other =>
-        throw new IllegalArgumentException(s"Does not support metric $other.")
+      case "areaUnderROC" => metrics.areaUnderROC()
+      case "areaUnderPR" => metrics.areaUnderPR()
     }
     metrics.unpersist()
     metric
   }
 
+  override def isLargerBetter: Boolean = $(metricName) match {
+    case "areaUnderROC" => true
+    case "areaUnderPR" => true
+  }
+
   override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
 }
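
Note that the removed `case other` fallback is no longer needed: `ParamValidators.inArray` rejects an
unsupported metric name when the param is set, before `evaluate` ever runs. A small sketch of the new
failure mode (assuming the standard `setMetricName` setter):

    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

    val eval = new BinaryClassificationEvaluator()
    eval.setMetricName("areaUnderROC")  // accepted by the validator
    // eval.setMetricName("accuracy")   // would throw IllegalArgumentException at set time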

http://git-wip-us.apache.org/repos/asf/spark/blob/f25c3247/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
index e56c946..13bd330 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
@@ -46,5 +46,12 @@ abstract class Evaluator extends Params {
    */
   def evaluate(dataset: DataFrame): Double
 
+  /**
+   * Indicates whether the metric returned by [[evaluate()]] should be maximized (true, default)
+   * or minimized (false).
+   * A given evaluator may support multiple metrics which may be maximized or minimized.
+   */
+  def isLargerBetter: Boolean = true
+
   override def copy(extra: ParamMap): Evaluator
 }
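
As a sketch of the new contract for third-party evaluators (the class, its metric, and the `cost`
column below are hypothetical, not part of this patch):

    import org.apache.spark.ml.evaluation.Evaluator
    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.ml.util.Identifiable
    import org.apache.spark.sql.DataFrame

    // Hypothetical evaluator that reports a cost, so smaller is better.
    class CostEvaluator(override val uid: String) extends Evaluator {
      def this() = this(Identifiable.randomUID("costEval"))

      override def evaluate(dataset: DataFrame): Double =
        dataset.selectExpr("avg(cost)").head().getDouble(0)  // assumes a "cost" column

      // Tells CrossValidator/TrainValidationSplit to minimize this metric.
      override def isLargerBetter: Boolean = false

      override def copy(extra: ParamMap): CostEvaluator = defaultCopy(extra)
    }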

http://git-wip-us.apache.org/repos/asf/spark/blob/f25c3247/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index 44f779c..f73d234 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -81,5 +81,13 @@ class MulticlassClassificationEvaluator (override val uid: String)
     metric
   }
 
+  override def isLargerBetter: Boolean = $(metricName) match {
+    case "f1" => true
+    case "precision" => true
+    case "recall" => true
+    case "weightedPrecision" => true
+    case "weightedRecall" => true
+  }
+
   override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f25c3247/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 01c000b..d21c88a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -73,17 +73,20 @@ final class RegressionEvaluator(override val uid: String)
       }
     val metrics = new RegressionMetrics(predictionAndLabels)
     val metric = $(metricName) match {
-      case "rmse" =>
-        -metrics.rootMeanSquaredError
-      case "mse" =>
-        -metrics.meanSquaredError
-      case "r2" =>
-        metrics.r2
-      case "mae" =>
-        -metrics.meanAbsoluteError
+      case "rmse" => metrics.rootMeanSquaredError
+      case "mse" => metrics.meanSquaredError
+      case "r2" => metrics.r2
+      case "mae" => metrics.meanAbsoluteError
     }
     metric
   }
 
+  override def isLargerBetter: Boolean = $(metricName) match {
+    case "rmse" => false
+    case "mse" => false
+    case "r2" => true
+    case "mae" => false
+  }
+
   override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
 }
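
A quick check of the per-metric behavior this enables (a sketch; the assertions follow directly from
the match above):

    import org.apache.spark.ml.evaluation.RegressionEvaluator

    val eval = new RegressionEvaluator()             // default metric is "rmse"
    assert(!eval.isLargerBetter)                     // rmse is minimized
    assert(eval.setMetricName("r2").isLargerBetter)  // r2 is maximized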

http://git-wip-us.apache.org/repos/asf/spark/blob/f25c3247/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 4792eb0..0679bfd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -100,7 +100,9 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
     }
     f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
     logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
-    val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
+    val (bestMetric, bestIndex) =
+      if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
+      else metrics.zipWithIndex.minBy(_._1)
     logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
     logInfo(s"Best cross-validation metric: $bestMetric.")
     val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]

http://git-wip-us.apache.org/repos/asf/spark/blob/f25c3247/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index c0edc73..73a14b8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -99,7 +99,9 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali
     validationDataset.unpersist()
 
     logInfo(s"Train validation split metrics: ${metrics.toSeq}")
-    val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
+    val (bestMetric, bestIndex) =
+      if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
+      else metrics.zipWithIndex.minBy(_._1)
     logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
     logInfo(s"Best train validation split metric: $bestMetric.")
     val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]

http://git-wip-us.apache.org/repos/asf/spark/blob/f25c3247/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
index 5b20378..aa722da 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -63,7 +63,7 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext
 
     // default = rmse
     val evaluator = new RegressionEvaluator()
-    assert(evaluator.evaluate(predictions) ~== -0.1019382 absTol 0.001)
+    assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001)
 
     // r2 score
     evaluator.setMetricName("r2")
@@ -71,6 +71,6 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext
 
     // mae
     evaluator.setMetricName("mae")
-    assert(evaluator.evaluate(predictions) ~== -0.08036075 absTol 0.001)
+    assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f25c3247/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index aaca08b..fde02e0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -143,6 +143,8 @@ object CrossValidatorSuite {
       throw new UnsupportedOperationException
     }
 
+    override def isLargerBetter: Boolean = true
+
     override val uid: String = "eval"
 
     override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)

http://git-wip-us.apache.org/repos/asf/spark/blob/f25c3247/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index c8e58f2..ef24e6f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -132,6 +132,8 @@ object TrainValidationSplitSuite {
       throw new UnsupportedOperationException
     }
 
+    override def isLargerBetter: Boolean = true
+
     override val uid: String = "eval"
 
     override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)

http://git-wip-us.apache.org/repos/asf/spark/blob/f25c3247/python/pyspark/ml/evaluation.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index e23ce05..6b0a9ff 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -163,11 +163,11 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
     ...
     >>> evaluator = RegressionEvaluator(predictionCol="raw")
     >>> evaluator.evaluate(dataset)
-    -2.842...
+    2.842...
     >>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"})
     0.993...
     >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
-    -2.649...
+    2.649...
     """
     # Because we will maximize evaluation value (ref: `CrossValidator`),
     # when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`),



