spark-commits mailing list archives

From jkbrad...@apache.org
Subject spark git commit: [SPARK-7425][ML] spark.ml Predictor should support other numeric types for label
Date Sat, 02 Apr 2016 01:25:50 GMT
Repository: spark
Updated Branches:
  refs/heads/master abc6c42c2 -> 36e8fb800


[SPARK-7425][ML] spark.ml Predictor should support other numeric types for label

Currently, the Predictor abstraction expects the input labelCol type to be DoubleType, but
we should support other numeric types. This will involve updating the PredictorParams.validateAndTransformSchema
method.
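
For example, after this change a label column of IntegerType is accepted directly.
A minimal sketch of the new behavior (illustrative data; assumes a SQLContext named
sqlContext is in scope):

    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.mllib.linalg.Vectors

    // The label column is inferred as IntegerType, not DoubleType.
    val df = sqlContext.createDataFrame(Seq(
      (0, Vectors.dense(0.0, 1.1)),
      (1, Vectors.dense(2.0, 1.0)),
      (4, Vectors.dense(3.0, 0.5))
    )).toDF("label", "features")

    // Previously this threw an IllegalArgumentException (the label had to be DoubleType);
    // the label column is now cast to DoubleType internally during fitting.
    val model = new LinearRegression().setMaxIter(1).fit(df)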

Author: BenFradet <benjamin.fradet@gmail.com>

Closes #10355 from BenFradet/SPARK-7425.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/36e8fb80
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/36e8fb80
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/36e8fb80

Branch: refs/heads/master
Commit: 36e8fb8005eccea67a9dea8cf68ec3105aa43351
Parents: abc6c42
Author: BenFradet <benjamin.fradet@gmail.com>
Authored: Fri Apr 1 18:25:43 2016 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Fri Apr 1 18:25:43 2016 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/Predictor.scala   |  9 +-
 .../ml/classification/LogisticRegression.scala  |  7 +-
 .../spark/ml/classification/OneVsRest.scala     |  4 +-
 .../ml/regression/AFTSurvivalRegression.scala   | 11 +--
 .../GeneralizedLinearRegression.scala           | 13 +--
 .../ml/regression/IsotonicRegression.scala      |  4 +-
 .../spark/ml/regression/LinearRegression.scala  | 11 +--
 .../org/apache/spark/ml/util/SchemaUtils.scala  | 24 ++++--
 .../DecisionTreeClassifierSuite.scala           | 15 +++-
 .../ml/classification/GBTClassifierSuite.scala  |  9 +-
 .../LogisticRegressionSuite.scala               | 11 ++-
 .../MultilayerPerceptronClassifierSuite.scala   | 12 +++
 .../ml/classification/NaiveBayesSuite.scala     | 14 +++-
 .../ml/classification/OneVsRestSuite.scala      | 16 +++-
 .../RandomForestClassifierSuite.scala           |  8 ++
 .../regression/AFTSurvivalRegressionSuite.scala |  9 ++
 .../regression/DecisionTreeRegressorSuite.scala |  8 ++
 .../spark/ml/regression/GBTRegressorSuite.scala |  8 +-
 .../GeneralizedLinearRegressionSuite.scala      | 12 ++-
 .../ml/regression/IsotonicRegressionSuite.scala |  9 ++
 .../ml/regression/LinearRegressionSuite.scala   | 17 +++-
 .../regression/RandomForestRegressorSuite.scala |  8 ++
 .../apache/spark/ml/tree/impl/TreeTests.scala   | 18 ++++
 .../apache/spark/ml/util/MLTestingUtils.scala   | 86 +++++++++++++++++++-
 24 files changed, 294 insertions(+), 49 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index ebe4870..d23ae6f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -36,6 +36,7 @@ private[ml] trait PredictorParams extends Params
 
   /**
    * Validates and transforms the input schema with the provided param map.
+   *
    * @param schema input schema
    * @param fitting whether this is in fitting
    * @param featuresDataType  SQL DataType for FeaturesType.
@@ -49,8 +50,7 @@ private[ml] trait PredictorParams extends Params
     // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
     SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
     if (fitting) {
-      // TODO: Allow other numeric types
-      SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+      SchemaUtils.checkNumericType(schema, $(labelCol))
     }
     SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
   }
@@ -121,9 +121,8 @@ abstract class Predictor[
    * and put it in an RDD with strong types.
    */
   protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = {
-    dataset.select($(labelCol), $(featuresCol)).rdd.map {
-      case Row(label: Double, features: Vector) =>
-        LabeledPoint(label, features)
+    dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
+      case Row(label: Double, features: Vector) => LabeledPoint(label, features)
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 3d1d5b6..aeb94a6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -38,6 +38,7 @@ import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -265,7 +266,7 @@ class LogisticRegression @Since("1.2.0") (
       LogisticRegressionModel = {
     val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
-      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
+      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
         case Row(label: Double, weight: Double, features: Vector) =>
           Instance(label, weight, features)
       }
@@ -361,7 +362,7 @@ class LogisticRegression @Since("1.2.0") (
         if (optInitialModel.isDefined && optInitialModel.get.coefficients.size != numFeatures) {
           val vec = optInitialModel.get.coefficients
           logWarning(
-            s"Initial coefficients provided ${vec} did not match the expected size ${numFeatures}")
+            s"Initial coefficients provided $vec did not match the expected size $numFeatures")
         }
 
         if (optInitialModel.isDefined && optInitialModel.get.coefficients.size == numFeatures) {
@@ -522,7 +523,7 @@ class LogisticRegressionModel private[spark] (
       (LogisticRegressionModel, String) = {
     $(probabilityCol) match {
       case "" =>
-        val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString()
+        val probabilityColName = "probability_" + java.util.UUID.randomUUID.toString
         (copy(ParamMap.empty).setProbabilityCol(probabilityColName), probabilityColName)
       case p => (this, p)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 98b99a3..263d54c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -295,10 +295,12 @@ final class OneVsRest @Since("1.4.0") (
 
   @Since("1.4.0")
   override def fit(dataset: DataFrame): OneVsRestModel = {
+    transformSchema(dataset.schema)
+
     // determine number of classes either from metadata if provided, or via computation.
     val labelSchema = dataset.schema($(labelCol))
     val computeNumClasses: () => Int = () => {
-      val Row(maxLabelIndex: Double) = dataset.agg(max($(labelCol))).head()
+      val Row(maxLabelIndex: Double) = dataset.agg(max(col($(labelCol)).cast(DoubleType))).head()
       // classes are assumed to be numbered from 0,...,maxLabelIndex
       maxLabelIndex.toInt + 1
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index ba5708a..3278974 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -103,7 +103,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     if (fitting) {
       SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
-      SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+      SchemaUtils.checkNumericType(schema, $(labelCol))
     }
     if (hasQuantilesCol) {
       SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT)
@@ -184,10 +184,11 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
    * and put it in an RDD with strong types.
    */
   protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = {
-    dataset.select($(featuresCol), $(labelCol), $(censorCol)).rdd.map {
-      case Row(features: Vector, label: Double, censor: Double) =>
-        AFTPoint(features, label, censor)
-    }
+    dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol)))
+      .rdd.map {
+        case Row(features: Vector, label: Double, censor: Double) =>
+          AFTPoint(features, label, censor)
+      }
   }
 
   @Since("1.6.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 0e71e8d..a40d373 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -33,7 +33,7 @@ import org.apache.spark.mllib.linalg.{BLAS, Vector}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
 /**
  * Params for Generalized Linear Regression.
@@ -47,6 +47,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
    * to be used in the model.
    * Supported options: "gaussian", "binomial", "poisson" and "gamma".
    * Default is "gaussian".
+   *
    * @group param
    */
   @Since("2.0.0")
@@ -63,6 +64,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
    * Param for the name of link function which provides the relationship
    * between the linear predictor and the mean of the distribution function.
    * Supported options: "identity", "log", "inverse", "logit", "probit", "cloglog" and "sqrt".
+   *
    * @group param
    */
   @Since("2.0.0")
@@ -210,9 +212,10 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
     }
 
     val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
-    val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd
-      .map { case Row(label: Double, weight: Double, features: Vector) =>
-        Instance(label, weight, features)
+    val instances: RDD[Instance] =
+      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
+        case Row(label: Double, weight: Double, features: Vector) =>
+          Instance(label, weight, features)
       }
 
     if (familyObj == Gaussian && linkObj == Identity) {
@@ -698,7 +701,7 @@ class GeneralizedLinearRegressionModel private[ml] (
     : (GeneralizedLinearRegressionModel, String) = {
     $(predictionCol) match {
       case "" =>
-        val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
+        val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString
         (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
       case p => (this, p)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index fb733f9..bd0b631 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -90,7 +90,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
     } else {
       lit(1.0)
     }
-    dataset.select(col($(labelCol)), f, w).rdd.map {
+    dataset.select(col($(labelCol)).cast(DoubleType), f, w).rdd.map {
       case Row(label: Double, feature: Double, weight: Double) =>
         (label, feature, weight)
     }
@@ -106,7 +106,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
       schema: StructType,
       fitting: Boolean): StructType = {
     if (fitting) {
-      SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+      SchemaUtils.checkNumericType(schema, $(labelCol))
       if (hasWeightCol) {
         SchemaUtils.checkColumnType(schema, $(weightCol), DoubleType)
       } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 5ec0213..ba5ad4c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -40,6 +40,7 @@ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -171,7 +172,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
       // For low dimensional data, WeightedLeastSquares is more efficiently since the
       // training algorithm only requires one pass through the data. (SPARK-10668)
       val instances: RDD[Instance] = dataset.select(
-        col($(labelCol)), w, col($(featuresCol))).rdd.map {
+        col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
           case Row(label: Double, weight: Double, features: Vector) =>
             Instance(label, weight, features)
       }
@@ -431,7 +432,7 @@ class LinearRegressionModel private[ml] (
   private[regression] def findSummaryModelAndPredictionCol(): (LinearRegressionModel, String) = {
     $(predictionCol) match {
       case "" =>
-        val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString()
+        val predictionColName = "prediction_" + java.util.UUID.randomUUID.toString
         (copy(ParamMap.empty).setPredictionCol(predictionColName), predictionColName)
       case p => (this, p)
     }
@@ -550,7 +551,7 @@ class LinearRegressionSummary private[regression] (
 
   @transient private val metrics = new RegressionMetrics(
     predictions
-      .select(predictionCol, labelCol)
+      .select(col(predictionCol), col(labelCol).cast(DoubleType))
       .rdd
       .map { case Row(pred: Double, label: Double) => (pred, label) },
     !model.getFitIntercept)
@@ -653,7 +654,7 @@ class LinearRegressionSummary private[regression] (
           col(model.getWeightCol)).as("wse")).agg(sum(col("wse"))).first().getDouble(0)
       }
       val sigma2 = rss / degreesOfFreedom
-      diagInvAtWA.map(_ * sigma2).map(math.sqrt(_))
+      diagInvAtWA.map(_ * sigma2).map(math.sqrt)
     }
   }
 
@@ -826,7 +827,7 @@ private class LeastSquaresAggregator(
     instance match { case Instance(label, weight, features) =>
       require(dim == features.size, s"Dimensions mismatch when adding new sample." +
         s" Expecting $dim but got ${features.size}.")
-      require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
+      require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
 
       if (weight == 0.0) return this
 

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
index 76021ad..334410c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.ml.util
 
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.types.{DataType, NumericType, StructField, StructType}
 
 
 /**
@@ -44,10 +44,10 @@ private[spark] object SchemaUtils {
   }
 
   /**
-    * Check whether the given schema contains a column of one of the require data types.
-    * @param colName  column name
-    * @param dataTypes  required column data types
-    */
+   * Check whether the given schema contains a column of one of the require data types.
+   * @param colName  column name
+   * @param dataTypes  required column data types
+   */
   def checkColumnTypes(
       schema: StructType,
       colName: String,
@@ -61,6 +61,20 @@ private[spark] object SchemaUtils {
   }
 
   /**
+   * Check whether the given schema contains a column of the numeric data type.
+   * @param colName  column name
+   */
+  def checkNumericType(
+      schema: StructType,
+      colName: String,
+      msg: String = ""): Unit = {
+    val actualDataType = schema(colName).dataType
+    val message = if (msg != null && msg.trim.length > 0) " " + msg else ""
+    require(actualDataType.isInstanceOf[NumericType], s"Column $colName must be of type " +
+      s"NumericType but was actually of type $actualDataType.$message")
+  }
+
+  /**
    * Appends a new column to the input schema. This fails if the given output column already exists.
    * @param schema input schema
    * @param colName new column name. If this column name is an empty string "", this method returns

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 2b07524..fe839e1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -27,8 +27,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{DataFrame, Row}
 
 class DecisionTreeClassifierSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -176,7 +175,7 @@ class DecisionTreeClassifierSuite
   }
 
   test("Multiclass classification tree with 10-ary (ordered) categorical features," +
-      " with just enough bins") {
+    " with just enough bins") {
     val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD
     val dt = new DecisionTreeClassifier()
       .setImpurity("Gini")
@@ -273,7 +272,7 @@ class DecisionTreeClassifierSuite
     ))
     val df = TreeTests.setMetadata(data, Map(0 -> 1), 2)
     val dt = new DecisionTreeClassifier().setMaxDepth(3)
-    val model = dt.fit(df)
+    dt.fit(df)
   }
 
   test("Use soft prediction for binary classification with ordered categorical features") {
@@ -335,6 +334,14 @@ class DecisionTreeClassifierSuite
     assert(importances.toArray.forall(_ >= 0.0))
   }
 
+  test("should support all NumericType labels and not support other types") {
+    val dt = new DecisionTreeClassifier().setMaxDepth(1)
+    MLTestingUtils.checkNumericTypes[DecisionTreeClassificationModel, DecisionTreeClassifier](
+      dt, isClassification = true, sqlContext) { (expected, actual) =>
+        TreeTests.checkEqual(expected, actual)
+      }
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index bf7481e..76d8c93 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -31,7 +31,6 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.util.Utils
 
-
 /**
  * Test suite for [[GBTClassifier]].
  */
@@ -102,6 +101,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
     Utils.deleteRecursively(tempDir)
   }
 
+  test("should support all NumericType labels and not support other types") {
+    val gbt = new GBTClassifier().setMaxDepth(1)
+    MLTestingUtils.checkNumericTypes[GBTClassificationModel, GBTClassifier](
+      gbt, isClassification = true, sqlContext) { (expected, actual) =>
+        TreeTests.checkEqual(expected, actual)
+      }
+  }
+
   // TODO: Reinstate test once runWithValidation is implemented   SPARK-7132
   /*
   test("runWithValidation stops early and performs better on a validation dataset") {

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index afeeaf7..7eefaf2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -103,7 +103,7 @@ class LogisticRegressionSuite
     assert(model.hasSummary)
     // Validate that we re-insert a probability column for evaluation
     val fieldNames = model.summary.predictions.schema.fieldNames
-    assert((dataset.schema.fieldNames.toSet).subsetOf(
+    assert(dataset.schema.fieldNames.toSet.subsetOf(
       fieldNames.toSet))
     assert(fieldNames.exists(s => s.startsWith("probability_")))
   }
@@ -934,6 +934,15 @@ class LogisticRegressionSuite
     testEstimatorAndModelReadWrite(lr, dataset, LogisticRegressionSuite.allParamSettings,
       checkModelData)
   }
+
+  test("should support all NumericType labels and not support other types") {
+    val lr = new LogisticRegression().setMaxIter(1)
+    MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression](
+      lr, isClassification = true, sqlContext) { (expected, actual) =>
+        assert(expected.intercept === actual.intercept)
+        assert(expected.coefficients.toArray === actual.coefficients.toArray)
+      }
+  }
 }
 
 object LogisticRegressionSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 4378138..06ff049 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.MLTestingUtils
 import org.apache.spark.mllib.classification.LogisticRegressionSuite._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
@@ -162,4 +163,15 @@ class MultilayerPerceptronClassifierSuite
     assert(newMlpModel.layers === mlpModel.layers)
     assert(newMlpModel.weights === mlpModel.weights)
   }
+
+  test("should support all NumericType labels and not support other types") {
+    val layers = Array(3, 2)
+    val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1)
+    MLTestingUtils.checkNumericTypes[
+        MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier](
+      mpc, isClassification = true, sqlContext) { (expected, actual) =>
+        assert(expected.layers === actual.layers)
+        assert(expected.weights === actual.weights)
+      }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 082a6bc..4727cd4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -21,7 +21,7 @@ import breeze.linalg.{Vector => BV}
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial}
 import org.apache.spark.mllib.classification.NaiveBayesSuite._
 import org.apache.spark.mllib.linalg._
@@ -86,7 +86,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       model: NaiveBayesModel,
       modelType: String): Unit = {
     featureAndProbabilities.collect().foreach {
-      case Row(features: Vector, probability: Vector) => {
+      case Row(features: Vector, probability: Vector) =>
         assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10)
         val expected = modelType match {
           case Multinomial =>
@@ -97,7 +97,6 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
             throw new UnknownError(s"Invalid modelType: $modelType.")
         }
         assert(probability ~== expected relTol 1.0e-10)
-      }
     }
   }
 
@@ -185,6 +184,15 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     val nb = new NaiveBayes()
     testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
   }
+
+  test("should support all NumericType labels and not support other types") {
+    val nb = new NaiveBayes()
+    MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes](
+      nb, isClassification = true, sqlContext) { (expected, actual) =>
+        assert(expected.pi === actual.pi)
+        assert(expected.theta === actual.theta)
+      }
+  }
 }
 
 object NaiveBayesSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 51c1baf..4131396 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -74,7 +74,7 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
     // copied model must have the same parent.
     MLTestingUtils.checkCopy(ovaModel)
 
-    assert(ovaModel.models.size === numClasses)
+    assert(ovaModel.models.length === numClasses)
     val transformedDataset = ovaModel.transform(dataset)
 
     // check for label metadata in prediction col
@@ -224,6 +224,20 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
     val newOvaModel = testDefaultReadWrite(ovaModel, testParams = false)
     checkModelData(ovaModel, newOvaModel)
   }
+
+  test("should support all NumericType labels and not support other types") {
+    val ovr = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(1))
+    MLTestingUtils.checkNumericTypes[OneVsRestModel, OneVsRest](
+      ovr, isClassification = true, sqlContext) { (expected, actual) =>
+        val expectedModels = expected.models.map(m => m.asInstanceOf[LogisticRegressionModel])
+        val actualModels = actual.models.map(m => m.asInstanceOf[LogisticRegressionModel])
+        assert(expectedModels.length === actualModels.length)
+        expectedModels.zip(actualModels).foreach { case (e, a) =>
+          assert(e.intercept === a.intercept)
+          assert(e.coefficients.toArray === a.coefficients.toArray)
+        }
+      }
+  }
 }
 
 private class MockLogisticRegression(uid: String) extends LogisticRegression(uid) {

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index b896099..052bc83 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -178,6 +178,14 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
     assert(importances.toArray.forall(_ >= 0.0))
   }
 
+  test("should support all NumericType labels and not support other types") {
+    val rf = new RandomForestClassifier().setMaxDepth(1)
+    MLTestingUtils.checkNumericTypes[RandomForestClassificationModel, RandomForestClassifier](
+      rf, isClassification = true, sqlContext) { (expected, actual) =>
+        TreeTests.checkEqual(expected, actual)
+      }
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index dbd752d..f4844cc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -347,6 +347,15 @@ class AFTSurvivalRegressionSuite
     }
   }
 
+  test("should support all NumericType labels") {
+    val aft = new AFTSurvivalRegression().setMaxIter(1)
+    MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression](
+      aft, isClassification = false, sqlContext) { (expected, actual) =>
+        assert(expected.intercept === actual.intercept)
+        assert(expected.coefficients === actual.coefficients)
+      }
+  }
+
   test("read/write") {
     def checkModelData(
         model: AFTSurvivalRegressionModel,

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 662e3fc..e9fb267 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -117,6 +117,14 @@ class DecisionTreeRegressorSuite
     assert(importances.toArray.forall(_ >= 0.0))
   }
 
+  test("should support all NumericType labels and not support other types") {
+    val dt = new DecisionTreeRegressor().setMaxDepth(1)
+    MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor](
+      dt, isClassification = false, sqlContext) { (expected, actual) =>
+        TreeTests.checkEqual(expected, actual)
+      }
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index dfb8418..914818f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -29,7 +29,6 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.util.Utils
 
-
 /**
  * Test suite for [[GBTRegressor]].
  */
@@ -110,7 +109,14 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     sc.checkpointDir = None
     Utils.deleteRecursively(tempDir)
+  }
 
+  test("should support all NumericType labels and not support other types") {
+    val gbt = new GBTRegressor().setMaxDepth(1)
+    MLTestingUtils.checkNumericTypes[GBTRegressionModel, GBTRegressor](
+      gbt, isClassification = false, sqlContext) { (expected, actual) =>
+        TreeTests.checkEqual(expected, actual)
+      }
   }
 
   // TODO: Reinstate test once runWithValidation is implemented  SPARK-7132

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 4ebdbf2..2265464 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -982,6 +982,16 @@ class GeneralizedLinearRegressionSuite
     testEstimatorAndModelReadWrite(glr, datasetPoissonLog,
       GeneralizedLinearRegressionSuite.allParamSettings, checkModelData)
   }
+
+  test("should support all NumericType labels and not support other types") {
+    val glr = new GeneralizedLinearRegression().setMaxIter(1)
+    MLTestingUtils.checkNumericTypes[
+        GeneralizedLinearRegressionModel, GeneralizedLinearRegression](
+      glr, isClassification = false, sqlContext) { (expected, actual) =>
+        assert(expected.intercept === actual.intercept)
+        assert(expected.coefficients === actual.coefficients)
+      }
+  }
 }
 
 object GeneralizedLinearRegressionSuite {
@@ -1023,7 +1033,7 @@ object GeneralizedLinearRegressionSuite {
     generator.setSeed(seed)
 
     (0 until nPoints).map { _ =>
-      val features = Vectors.dense(coefficients.indices.map { rndElement(_) }.toArray)
+      val features = Vectors.dense(coefficients.indices.map(rndElement).toArray)
       val eta = BLAS.dot(Vectors.dense(coefficients), features) + intercept
       val mu = link match {
         case "identity" => eta

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
index b8874b4..3a10ad7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala
@@ -180,6 +180,15 @@ class IsotonicRegressionSuite
     testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings,
       checkModelData)
   }
+
+  test("should support all NumericType labels and not support other types") {
+    val ir = new IsotonicRegression()
+    MLTestingUtils.checkNumericTypes[IsotonicRegressionModel, IsotonicRegression](
+      ir, isClassification = false, sqlContext) { (expected, actual) =>
+        assert(expected.boundaries === actual.boundaries)
+        assert(expected.predictions === actual.predictions)
+      }
+  }
 }
 
 object IsotonicRegressionSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index bd45d21..cccb7f8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -61,9 +61,9 @@ class LinearRegressionSuite
     val featureSize = 4100
     datasetWithSparseFeature = sqlContext.createDataFrame(
       sc.parallelize(LinearDataGenerator.generateLinearInput(
-        intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble).toArray,
-        xMean = Seq.fill(featureSize)(r.nextDouble).toArray,
-        xVariance = Seq.fill(featureSize)(r.nextDouble).toArray, nPoints = 200,
+        intercept = 0.0, weights = Seq.fill(featureSize)(r.nextDouble()).toArray,
+        xMean = Seq.fill(featureSize)(r.nextDouble()).toArray,
+        xVariance = Seq.fill(featureSize)(r.nextDouble()).toArray, nPoints = 200,
         seed, eps = 0.1, sparsity = 0.7), 2))
 
     /*
@@ -687,7 +687,7 @@ class LinearRegressionSuite
       // Validate that we re-insert a prediction column for evaluation
       val modelNoPredictionColFieldNames
       = modelNoPredictionCol.summary.predictions.schema.fieldNames
-      assert((datasetWithDenseFeature.schema.fieldNames.toSet).subsetOf(
+      assert(datasetWithDenseFeature.schema.fieldNames.toSet.subsetOf(
         modelNoPredictionColFieldNames.toSet))
       assert(modelNoPredictionColFieldNames.exists(s => s.startsWith("prediction_")))
 
@@ -1006,6 +1006,15 @@ class LinearRegressionSuite
     testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings,
       checkModelData)
   }
+
+  test("should support all NumericType labels and not support other types") {
+    val lr = new LinearRegression().setMaxIter(1)
+    MLTestingUtils.checkNumericTypes[LinearRegressionModel, LinearRegression](
+      lr, isClassification = false, sqlContext) { (expected, actual) =>
+        assert(expected.intercept === actual.intercept)
+        assert(expected.coefficients === actual.coefficients)
+      }
+  }
 }
 
 object LinearRegressionSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 6be0c8b..2ab4f1b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -94,6 +94,14 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
     assert(importances.toArray.forall(_ >= 0.0))
   }
 
+  test("should support all NumericType labels and not support other types") {
+    val rf = new RandomForestRegressor().setMaxDepth(1)
+    MLTestingUtils.checkNumericTypes[RandomForestRegressionModel, RandomForestRegressor](
+      rf, isClassification = false, sqlContext) { (expected, actual) =>
+        TreeTests.checkEqual(expected, actual)
+      }
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Tests of model save/load
   /////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
index 12808b0..bd5bd17 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/TreeTests.scala
@@ -74,6 +74,24 @@ private[ml] object TreeTests extends SparkFunSuite {
   }
 
   /**
+   * Set label metadata (particularly the number of classes) on a DataFrame.
+   * @param data  Dataset.  Categorical features and labels must already have 0-based indices.
+   *              This must be non-empty.
+   * @param numClasses  Number of classes label can take. If 0, mark as continuous.
+   * @param labelColName  Name of the label column on which to set the metadata.
+   * @return DataFrame with metadata
+   */
+  def setMetadata(data: DataFrame, numClasses: Int, labelColName: String): DataFrame = {
+    val labelAttribute = if (numClasses == 0) {
+      NumericAttribute.defaultAttr.withName(labelColName)
+    } else {
+      NominalAttribute.defaultAttr.withName(labelColName).withNumValues(numClasses)
+    }
+    val labelMetadata = labelAttribute.toMetadata()
+    data.select(data("features"), data(labelColName).as(labelColName, labelMetadata))
+  }
+
+  /**
    * Check if the two trees are exactly the same.
    * Note: I hesitate to override Node.equals since it could cause problems if users
    *       make mistakes such as creating loops of Nodes.

http://git-wip-us.apache.org/repos/asf/spark/blob/36e8fb80/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index d290cc9..8108460 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -17,14 +17,96 @@
 
 package org.apache.spark.ml.util
 
-import org.apache.spark.ml.Model
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.tree.impl.TreeTests
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
 
-object MLTestingUtils {
+object MLTestingUtils extends SparkFunSuite {
   def checkCopy(model: Model[_]): Unit = {
     val copied = model.copy(ParamMap.empty)
       .asInstanceOf[Model[_]]
     assert(copied.parent.uid == model.parent.uid)
     assert(copied.parent == model.parent)
   }
+
+  def checkNumericTypes[M <: Model[M], T <: Estimator[M]](
+      estimator: T,
+      isClassification: Boolean,
+      sqlContext: SQLContext)(check: (M, M) => Unit): Unit = {
+    val dfs = if (isClassification) {
+      genClassifDFWithNumericLabelCol(sqlContext)
+    } else {
+      genRegressionDFWithNumericLabelCol(sqlContext)
+    }
+    val expected = estimator.fit(dfs(DoubleType))
+    val actuals = dfs.keys.filter(_ != DoubleType).map(t => estimator.fit(dfs(t)))
+    actuals.foreach(actual => check(expected, actual))
+
+    val dfWithStringLabels = generateDFWithStringLabelCol(sqlContext)
+    val thrown = intercept[IllegalArgumentException] {
+      estimator.fit(dfWithStringLabels)
+    }
+    assert(thrown.getMessage contains
+      "Column label must be of type NumericType but was actually of type StringType")
+  }
+
+  def genClassifDFWithNumericLabelCol(
+      sqlContext: SQLContext,
+      labelColName: String = "label",
+      featuresColName: String = "features"): Map[NumericType, DataFrame] = {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, Vectors.dense(0, 2, 3)),
+      (1, Vectors.dense(0, 3, 1)),
+      (0, Vectors.dense(0, 2, 2)),
+      (1, Vectors.dense(0, 3, 9)),
+      (0, Vectors.dense(0, 2, 6))
+    )).toDF(labelColName, featuresColName)
+
+    val types =
+      Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
+    types.map(t => t -> df.select(col(labelColName).cast(t), col(featuresColName)))
+      .map { case (t, d) => t -> TreeTests.setMetadata(d, 2, labelColName) }
+      .toMap
+  }
+
+  def genRegressionDFWithNumericLabelCol(
+      sqlContext: SQLContext,
+      labelColName: String = "label",
+      featuresColName: String = "features",
+      censorColName: String = "censor"): Map[NumericType, DataFrame] = {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, Vectors.dense(0)),
+      (1, Vectors.dense(1)),
+      (2, Vectors.dense(2)),
+      (3, Vectors.dense(3)),
+      (4, Vectors.dense(4))
+    )).toDF(labelColName, featuresColName)
+
+    val types =
+      Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DoubleType, DecimalType(10, 0))
+    types
+      .map(t => t -> df.select(col(labelColName).cast(t), col(featuresColName)))
+      .map { case (t, d) =>
+        t -> TreeTests.setMetadata(d, 0, labelColName).withColumn(censorColName, lit(0.0))
+      }
+      .toMap
+  }
+
+  def generateDFWithStringLabelCol(
+      sqlContext: SQLContext,
+      labelColName: String = "label",
+      featuresColName: String = "features",
+      censorColName: String = "censor"): DataFrame =
+    sqlContext.createDataFrame(Seq(
+      ("0", Vectors.dense(0, 2, 3), 0.0),
+      ("1", Vectors.dense(0, 3, 1), 1.0),
+      ("0", Vectors.dense(0, 2, 2), 0.0),
+      ("1", Vectors.dense(0, 3, 9), 1.0),
+      ("0", Vectors.dense(0, 2, 6), 0.0)
+    )).toDF(labelColName, featuresColName, censorColName)
 }



