spark-commits mailing list archives

From m...@apache.org
Subject [1/2] spark git commit: [SPARK-14500] [ML] Accept Dataset[_] instead of DataFrame in MLlib APIs
Date Mon, 11 Apr 2016 16:28:32 GMT
Repository: spark
Updated Branches:
  refs/heads/master e82d95bf6 -> 1c751fcf4
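
A minimal usage sketch of what this change enables (not part of the commit;
it assumes a SQLContext named sqlContext is in scope, per the Spark 2.0-era
API): ML Pipeline estimators and transformers now accept any Dataset[_], so a
strongly typed Dataset can be passed to fit() and transform() without
converting it to a DataFrame first.

    import org.apache.spark.ml.recommendation.ALS
    import org.apache.spark.sql.Dataset
    import sqlContext.implicits._

    case class Rating(user: Int, item: Int, rating: Float)

    // A typed Dataset[Rating] now satisfies fit(dataset: Dataset[_]) directly;
    // before this commit the signature required a DataFrame.
    val ratings: Dataset[Rating] = Seq(
      Rating(0, 0, 4.0f), Rating(0, 1, 2.0f), Rating(1, 1, 3.0f)).toDS()

    val model = new ALS()
      .setUserCol("user")
      .setItemCol("item")
      .setRatingCol("rating")
      .fit(ratings)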


http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 4a3ad66..36dce01 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -40,7 +40,7 @@ import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.CholeskyDecomposition
 import org.apache.spark.mllib.optimization.NNLS
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType}
 import org.apache.spark.storage.StorageLevel
@@ -200,8 +200,8 @@ class ALSModel private[ml] (
   @Since("1.3.0")
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
-  @Since("1.3.0")
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     // Register a UDF for DataFrame, and then
     // create a new column named map(predictionCol) by running the predict UDF.
     val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
@@ -385,8 +385,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
     this
   }
 
-  @Since("1.3.0")
-  override def fit(dataset: DataFrame): ALSModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): ALSModel = {
     import dataset.sqlContext.implicits._
     val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
     val ratings = dataset
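
(Aside, not part of the commit: Dataset[_] is an existentially typed
parameter, so a DataFrame, which in 2.0 is Dataset[Row], and any typed
Dataset[T] both satisfy it. Schema-based operations like the ones above depend
only on the logical plan, not on the element type, as this small sketch
illustrates.)

    import org.apache.spark.sql.Dataset

    // schema and column metadata are available on any Dataset[_],
    // regardless of its element type.
    def printColumns(data: Dataset[_]): Unit =
      data.schema.fieldNames.foreach(println)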

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 3278974..afed1f3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -32,7 +32,7 @@ import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, StructType}
 import org.apache.spark.storage.StorageLevel
@@ -183,7 +183,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
    * Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
    * and put it in an RDD with strong types.
    */
-  protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = {
+  protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = {
     dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol)))
       .rdd.map {
         case Row(features: Vector, label: Double, censor: Double) =>
@@ -191,8 +191,8 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
       }
   }
 
-  @Since("1.6.0")
-  override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
     validateAndTransformSchema(dataset.schema, fitting = true)
     val instances = extractAFTPoints(dataset)
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -299,8 +299,8 @@ class AFTSurvivalRegressionModel private[ml] (
     math.exp(BLAS.dot(coefficients, features) + intercept)
   }
 
-  @Since("1.6.0")
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema)
     val predictUDF = udf { features: Vector => predict(features) }
     val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)}

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 1289a31..c04c416 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -33,7 +33,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 
 
@@ -83,7 +83,7 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
   /** @group setParam */
   def setVarianceCol(value: String): this.type = set(varianceCol, value)
 
-  override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = {
+  override protected def train(dataset: Dataset[_]): DecisionTreeRegressionModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
@@ -158,15 +158,16 @@ final class DecisionTreeRegressionModel private[ml] (
     rootNode.predictImpl(features).impurityStats.calculate()
   }
 
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     transformImpl(dataset)
   }
 
-  override protected def transformImpl(dataset: DataFrame): DataFrame = {
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val predictUDF = udf { (features: Vector) => predict(features) }
     val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) }
-    var output = dataset
+    var output = dataset.toDF
     if ($(predictionCol).nonEmpty) {
       output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
     }
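
The "var output = dataset.toDF" change above is the recurring implementation
pattern in this commit: withColumn returns a DataFrame, so the incoming
Dataset[_] is converted once and the DataFrame is threaded through. A hedged
sketch of the same pattern (withPrediction and predict are illustrative
stand-ins, not the model's actual methods):

    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.sql.{DataFrame, Dataset}
    import org.apache.spark.sql.functions.{col, udf}

    def withPrediction(dataset: Dataset[_], featuresCol: String,
        predictionCol: String)(predict: Vector => Double): DataFrame = {
      val predictUDF = udf { features: Vector => predict(features) }
      // Convert once; subsequent withColumn calls stay on the DataFrame.
      dataset.toDF.withColumn(predictionCol, predictUDF(col(featuresCol)))
    }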

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 8eb2984..0b52fe2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -33,7 +33,7 @@ import org.apache.spark.mllib.tree.loss.{AbsoluteError => OldAbsoluteError, Loss
   SquaredError => OldSquaredError}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 
 /**
@@ -147,7 +147,7 @@ final class GBTRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: Stri
     }
   }
 
-  override protected def train(dataset: DataFrame): GBTRegressionModel = {
+  override protected def train(dataset: Dataset[_]): GBTRegressionModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
@@ -209,7 +209,7 @@ final class GBTRegressionModel private[ml](
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights
 
-  override protected def transformImpl(dataset: DataFrame): DataFrame = {
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
     val predictUDF = udf { (features: Any) =>
       bcastModel.value.predict(features.asInstanceOf[Vector])

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 05bf645..00cf25d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -31,7 +31,7 @@ import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.{BLAS, Vector}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
@@ -196,7 +196,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
   def setSolver(value: String): this.type = set(solver, value)
   setDefault(solver -> "irls")
 
-  override protected def train(dataset: DataFrame): GeneralizedLinearRegressionModel = {
+  override protected def train(dataset: Dataset[_]): GeneralizedLinearRegressionModel = {
     val familyObj = Family.fromName($(family))
     val linkObj = if (isDefined(link)) {
       Link.fromName($(link))

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index bd0b631..7a78ecb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression}
 import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, lit, udf}
 import org.apache.spark.sql.types.{DoubleType, StructType}
 import org.apache.spark.storage.StorageLevel
@@ -77,7 +77,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
    * Extracts (label, feature, weight) from input dataset.
    */
   protected[ml] def extractWeightedLabeledPoints(
-      dataset: DataFrame): RDD[(Double, Double, Double)] = {
+      dataset: Dataset[_]): RDD[(Double, Double, Double)] = {
     val f = if (dataset.schema($(featuresCol)).dataType.isInstanceOf[VectorUDT]) {
       val idx = $(featureIndex)
       val extract = udf { v: Vector => v(idx) }
@@ -164,8 +164,8 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") override val uid: Stri
   @Since("1.5.0")
   override def copy(extra: ParamMap): IsotonicRegression = defaultCopy(extra)
 
-  @Since("1.5.0")
-  override def fit(dataset: DataFrame): IsotonicRegressionModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): IsotonicRegressionModel = {
     validateAndTransformSchema(dataset.schema, fitting = true)
     // Extract columns from data.  If dataset is persisted, do not persist oldDataset.
     val instances = extractWeightedLabeledPoints(dataset)
@@ -236,8 +236,8 @@ class IsotonicRegressionModel private[ml] (
     copyValues(new IsotonicRegressionModel(uid, oldModel), extra).setParent(parent)
   }
 
-  @Since("1.5.0")
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     val predict = dataset.schema($(featuresCol)).dataType match {
       case DoubleType =>
         udf { feature: Double => oldModel.predict(feature) }

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index aacff4e..71e0273 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -38,7 +38,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.linalg.BLAS._
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.storage.StorageLevel
@@ -158,7 +158,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   def setSolver(value: String): this.type = set(solver, value)
   setDefault(solver -> "auto")
 
-  override protected def train(dataset: DataFrame): LinearRegressionModel = {
+  override protected def train(dataset: Dataset[_]): LinearRegressionModel = {
     // Extract the number of features before deciding optimization solver.
     val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map {
       case Row(features: Vector) => features.size
@@ -417,7 +417,7 @@ class LinearRegressionModel private[ml] (
    * @param dataset Test dataset to evaluate model on.
    */
   @Since("2.0.0")
-  def evaluate(dataset: DataFrame): LinearRegressionSummary = {
+  def evaluate(dataset: Dataset[_]): LinearRegressionSummary = {
     // Handle possible missing or invalid prediction columns
     val (summaryModel, predictionColName) = findSummaryModelAndPredictionCol()
     new LinearRegressionSummary(summaryModel.transform(dataset), predictionColName,

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 736cd9f..bee13c2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -32,7 +32,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 
 
@@ -93,7 +93,7 @@ final class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val
   override def setFeatureSubsetStrategy(value: String): this.type =
     super.setFeatureSubsetStrategy(value)
 
-  override protected def train(dataset: DataFrame): RandomForestRegressionModel = {
+  override protected def train(dataset: Dataset[_]): RandomForestRegressionModel = {
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
@@ -164,7 +164,7 @@ final class RandomForestRegressionModel private[ml] (
   @Since("1.4.0")
   override def treeWeights: Array[Double] = _treeWeights
 
-  override protected def transformImpl(dataset: DataFrame): DataFrame = {
+  override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
     val bcastModel = dataset.sqlContext.sparkContext.broadcast(this)
     val predictUDF = udf { (features: Any) =>
       bcastModel.value.predict(features.asInstanceOf[Vector])

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 4d9d4d4..de563d4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -33,7 +33,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.HasSeed
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -90,8 +90,8 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
   @Since("2.0.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
-  @Since("1.4.0")
-  override def fit(dataset: DataFrame): CrossValidatorModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): CrossValidatorModel = {
     val schema = dataset.schema
     transformSchema(schema, logging = true)
     val sqlCtx = dataset.sqlContext
@@ -100,7 +100,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
     val epm = $(estimatorParamMaps)
     val numModels = epm.length
     val metrics = new Array[Double](epm.length)
-    val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed))
+    val splits = MLUtils.kFold(dataset.toDF.rdd, $(numFolds), $(seed))
     splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
       val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
       val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
@@ -209,8 +209,8 @@ class CrossValidatorModel private[ml] (
     this(uid, bestModel, avgMetrics.asScala.toArray)
   }
 
-  @Since("1.4.0")
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     bestModel.transform(dataset)
   }
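
On the dataset.toDF.rdd change above: calling .rdd on a Dataset[_] yields an
RDD[_], while MLUtils.kFold and sqlCtx.createDataFrame(rdd, schema) need an
RDD[Row], so converting through toDF first recovers the Row-typed RDD. A
sketch of the fold loop under that reading (foldPairs is an illustrative
helper, not the commit's code):

    import org.apache.spark.mllib.util.MLUtils
    import org.apache.spark.sql.{DataFrame, Dataset}

    def foldPairs(dataset: Dataset[_], numFolds: Int,
        seed: Long): Array[(DataFrame, DataFrame)] = {
      val sqlCtx = dataset.sqlContext
      val schema = dataset.schema
      // toDF yields a Dataset[Row] whose .rdd is the RDD[Row] kFold expects.
      MLUtils.kFold(dataset.toDF.rdd, numFolds, seed).map {
        case (training, validation) =>
          (sqlCtx.createDataFrame(training, schema),
            sqlCtx.createDataFrame(validation, schema))
      }
    }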

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 0f2179c..12d6905 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.tuning
 import java.util.{List => JList}
 
 import scala.collection.JavaConverters._
+import scala.language.existentials
 
 import org.apache.hadoop.fs.Path
 import org.json4s.DefaultFormats
@@ -31,7 +32,7 @@ import org.apache.spark.ml.evaluation.Evaluator
 import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.HasSeed
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -89,8 +90,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
   @Since("2.0.0")
   def setSeed(value: Long): this.type = set(seed, value)
 
-  @Since("1.5.0")
-  override def fit(dataset: DataFrame): TrainValidationSplitModel = {
+  @Since("2.0.0")
+  override def fit(dataset: Dataset[_]): TrainValidationSplitModel = {
     val schema = dataset.schema
     transformSchema(schema, logging = true)
     val sqlCtx = dataset.sqlContext
@@ -207,8 +208,8 @@ class TrainValidationSplitModel private[ml] (
     this(uid, bestModel, validationMetrics.asScala.toArray)
   }
 
-  @Since("1.5.0")
-  override def transform(dataset: DataFrame): DataFrame = {
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
     bestModel.transform(dataset)
   }
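
A note on the added "import scala.language.existentials" at the top of this
file: Dataset[_] is shorthand for the existential type
Dataset[T] forSome { type T }, and scalac emits a feature warning when it
infers such a type unless the import is in scope. A rough illustration
(hypothetical code, only meant to show where the warning would come from):

    import scala.language.existentials
    import org.apache.spark.sql.Dataset

    // Dataset[_] in a parameter position compiles without the import, but an
    // inferred existential type (for example, from mixing differently typed
    // Datasets in one collection) would trigger the feature warning.
    def smallest(datasets: Seq[Dataset[_]]): Dataset[_] =
      datasets.minBy(_.count())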

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 0f0c3a2..5812cdd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -186,7 +186,7 @@ sealed trait Vector extends Serializable {
  * :: AlphaComponent ::
  *
  * User-defined type for [[Vector]] which allows easy interaction with SQL
- * via [[org.apache.spark.sql.DataFrame]].
+ * via [[org.apache.spark.sql.Dataset]].
  */
 @AlphaComponent
 class VectorUDT extends UserDefinedType[Vector] {

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index f3321fb..a8c4ac6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -31,7 +31,7 @@ import org.apache.spark.ml.param.{IntParam, ParamMap}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.StructType
 
class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -51,6 +51,12 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
     val dataset3 = mock[DataFrame]
     val dataset4 = mock[DataFrame]
 
+    when(dataset0.toDF).thenReturn(dataset0)
+    when(dataset1.toDF).thenReturn(dataset1)
+    when(dataset2.toDF).thenReturn(dataset2)
+    when(dataset3.toDF).thenReturn(dataset3)
+    when(dataset4.toDF).thenReturn(dataset4)
+
     when(estimator0.copy(any[ParamMap])).thenReturn(estimator0)
     when(model0.copy(any[ParamMap])).thenReturn(model0)
     when(transformer1.copy(any[ParamMap])).thenReturn(transformer1)
@@ -213,7 +219,7 @@ class WritableStage(override val uid: String) extends Transformer with MLWritabl
 
   override def write: MLWriter = new DefaultParamsWriter(this)
 
-  override def transform(dataset: DataFrame): DataFrame = dataset
+  override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF
 
   override def transformSchema(schema: StructType): StructType = schema
 }
@@ -234,7 +240,7 @@ class UnWritableStage(override val uid: String) extends Transformer {
 
   override def copy(extra: ParamMap): UnWritableStage = defaultCopy(extra)
 
-  override def transform(dataset: DataFrame): DataFrame = dataset
+  override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF
 
   override def transformSchema(schema: StructType): StructType = schema
 }
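
The when(datasetN.toDF).thenReturn(datasetN) stubs added near the top of this
suite exist because Pipeline stages now call toDF on their input internally;
an unstubbed Mockito call would return null and break the chain, so each mock
answers toDF with itself. The pattern, in sketch form (assumed to run inside a
suite mixing in org.scalatest.mock.MockitoSugar, as this one already does):

    import org.mockito.Mockito.when
    import org.apache.spark.sql.DataFrame

    val dataset = mock[DataFrame]
    when(dataset.toDF).thenReturn(dataset)  // make toDF a no-op on the mock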

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 7eefaf2..48db428 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -29,13 +29,13 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.lit
 
 class LogisticRegressionSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
-  @transient var dataset: DataFrame = _
+  @transient var dataset: Dataset[_] = _
   @transient var binaryDataset: DataFrame = _
   private val eps: Double = 1e-5
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 06ff049..80547fa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -26,12 +26,12 @@ import org.apache.spark.mllib.evaluation.MulticlassMetrics
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 
 class MultilayerPerceptronClassifierSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
-  @transient var dataset: DataFrame = _
+  @transient var dataset: Dataset[_] = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 4727cd4..80a46fc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -27,11 +27,11 @@ import org.apache.spark.mllib.classification.NaiveBayesSuite._
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 
class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
-  @transient var dataset: DataFrame = _
+  @transient var dataset: Dataset[_] = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index 4131396..f3e8fd1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -30,12 +30,12 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.Metadata
 
class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
-  @transient var dataset: DataFrame = _
+  @transient var dataset: Dataset[_] = _
   @transient var rdd: RDD[LabeledPoint] = _
 
   override def beforeAll(): Unit = {
@@ -246,7 +246,7 @@ private class MockLogisticRegression(uid: String) extends LogisticRegression(uid
 
   setMaxIter(1)
 
-  override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = {
+  override protected[spark] def train(dataset: Dataset[_]): LogisticRegressionModel = {
     val labelSchema = dataset.schema($(labelCol))
     // check for label attribute propagation.
     assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2))

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index 18f2c99..e641d79 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -20,13 +20,13 @@ package org.apache.spark.ml.clustering
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 
 class BisectingKMeansSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   final val k = 5
-  @transient var dataset: DataFrame = _
+  @transient var dataset: Dataset[_] = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index 8edd44e..1a274ae 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -20,14 +20,14 @@ package org.apache.spark.ml.clustering
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 
 
 class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
   with DefaultReadWriteTest {
 
   final val k = 5
-  @transient var dataset: DataFrame = _
+  @transient var dataset: Dataset[_] = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index c684bc1..2076c74 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -22,14 +22,14 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, SQLContext}
 
 private[clustering] case class TestRow(features: Vector)
 
class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   final val k = 5
-  @transient var dataset: DataFrame = _
+  @transient var dataset: Dataset[_] = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index a1c9389..ee8eae8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -24,7 +24,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
 
 
 object LDASuite {
@@ -64,7 +64,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
 
   val k: Int = 5
   val vocabSize: Int = 30
-  @transient var dataset: DataFrame = _
+  @transient var dataset: Dataset[_] = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
index 58fda29..e4e15f4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
@@ -22,7 +22,7 @@ import scala.beans.BeanInfo
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 
 @BeanInfo
 case class NGramTestData(inputTokens: Array[String], wantedNGrams: Array[String])
@@ -92,7 +92,7 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRe
 
 object NGramSuite extends SparkFunSuite {
 
-  def testNGram(t: NGram, dataset: DataFrame): Unit = {
+  def testNGram(t: NGram, dataset: Dataset[_]): Unit = {
     t.transform(dataset)
       .select("nGrams", "wantedNGrams")
       .collect()

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
index a5b24c1..3505bef 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala
@@ -20,10 +20,10 @@ package org.apache.spark.ml.feature
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 
 object StopWordsRemoverSuite extends SparkFunSuite {
-  def testStopWordsRemover(t: StopWordsRemover, dataset: DataFrame): Unit = {
+  def testStopWordsRemover(t: StopWordsRemover, dataset: Dataset[_]): Unit = {
     t.transform(dataset)
       .select("filtered", "expected")
       .collect()

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 2c3255e..d0f3cdc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -115,7 +115,7 @@ class StringIndexerSuite
       .setInputCol("label")
       .setOutputCol("labelIndex")
     val df = sqlContext.range(0L, 10L).toDF()
-    assert(indexerModel.transform(df).eq(df))
+    assert(indexerModel.transform(df).collect().toSet === df.collect().toSet)
   }
 
   test("StringIndexerModel can't overwrite output column") {

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
index 36e8e5d..299f622 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/TokenizerSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
 
 @BeanInfo
 case class TokenizerTestData(rawText: String, wantedTokens: Array[String])
@@ -106,7 +106,7 @@ class RegexTokenizerSuite
 
 object RegexTokenizerSuite extends SparkFunSuite {
 
-  def testRegexTokenizer(t: RegexTokenizer, dataset: DataFrame): Unit = {
+  def testRegexTokenizer(t: RegexTokenizer, dataset: Dataset[_]): Unit = {
     t.transform(dataset)
       .select("tokens", "wantedTokens")
       .collect()

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 2265464..4905f3e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -992,6 +992,14 @@ class GeneralizedLinearRegressionSuite
         assert(expected.coefficients === actual.coefficients)
       }
   }
+
+  test("glm accepts Dataset[LabeledPoint]") {
+    val context = sqlContext
+    import context.implicits._
+    new GeneralizedLinearRegression()
+      .setFamily("gaussian")
+      .fit(datasetGaussianIdentity.as[LabeledPoint])
+  }
 }
 
 object GeneralizedLinearRegressionSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 7af3c6d..3e734aa 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -29,13 +29,13 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.{StructField, StructType}
 
 class CrossValidatorSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
-  @transient var dataset: DataFrame = _
+  @transient var dataset: Dataset[_] = _
 
   override def beforeAll(): Unit = {
     super.beforeAll()
@@ -311,7 +311,7 @@ object CrossValidatorSuite extends SparkFunSuite {
 
  class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
 
-    override def fit(dataset: DataFrame): MyModel = {
+    override def fit(dataset: Dataset[_]): MyModel = {
       throw new UnsupportedOperationException
     }
 
@@ -325,7 +325,7 @@ object CrossValidatorSuite extends SparkFunSuite {
 
   class MyEvaluator extends Evaluator {
 
-    override def evaluate(dataset: DataFrame): Double = {
+    override def evaluate(dataset: Dataset[_]): Double = {
       throw new UnsupportedOperationException
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 4030956..dbee47c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.StructType
 
 class TrainValidationSplitSuite
@@ -158,7 +158,7 @@ object TrainValidationSplitSuite {
 
  class MyEstimator(override val uid: String) extends Estimator[MyModel] with HasInputCol {
 
-    override def fit(dataset: DataFrame): MyModel = {
+    override def fit(dataset: Dataset[_]): MyModel = {
       throw new UnsupportedOperationException
     }
 
@@ -172,7 +172,7 @@ object TrainValidationSplitSuite {
 
   class MyEvaluator extends Evaluator {
 
-    override def evaluate(dataset: DataFrame): Double = {
+    override def evaluate(dataset: Dataset[_]): Double = {
       throw new UnsupportedOperationException
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1c751fcf/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
index 1628047..7ebd7eb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala
@@ -25,7 +25,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
 
 trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
 
@@ -98,7 +98,7 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite =>
   def testEstimatorAndModelReadWrite[
     E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable](
       estimator: E,
-      dataset: DataFrame,
+      dataset: Dataset[_],
       testParams: Map[String, Any],
       checkModelData: (M, M) => Unit): Unit = {
     // Set some Params to make sure set Params are serialized.


