spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-10884][ML] Support prediction on single instance for regression and classification related models
Date Wed, 21 Mar 2018 16:39:17 GMT
Repository: spark
Updated Branches:
  refs/heads/master 500b21c3d -> bf09f2f71


[SPARK-10884][ML] Support prediction on single instance for regression and classification related models

## What changes were proposed in this pull request?

Support prediction on a single instance for regression- and classification-related models (i.e.,
PredictionModel, ClassificationModel, and their subclasses).
Add corresponding test cases.

## How was this patch tested?

Test cases added.

Author: WeichenXu <weichen.xu@databricks.com>

Closes #19381 from WeichenXu123/single_prediction.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bf09f2f7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bf09f2f7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bf09f2f7

Branch: refs/heads/master
Commit: bf09f2f71276d3b3a84a8f89109bd785a066c3e6
Parents: 500b21c
Author: WeichenXu <weichen.xu@databricks.com>
Authored: Wed Mar 21 09:39:14 2018 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Wed Mar 21 09:39:14 2018 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/Predictor.scala   |  5 ++--
 .../spark/ml/classification/Classifier.scala    |  6 ++---
 .../classification/DecisionTreeClassifier.scala |  2 +-
 .../spark/ml/classification/GBTClassifier.scala |  2 +-
 .../spark/ml/classification/LinearSVC.scala     |  2 +-
 .../ml/classification/LogisticRegression.scala  |  2 +-
 .../MultilayerPerceptronClassifier.scala        |  2 +-
 .../ml/regression/DecisionTreeRegressor.scala   |  2 +-
 .../spark/ml/regression/GBTRegressor.scala      |  2 +-
 .../GeneralizedLinearRegression.scala           |  2 +-
 .../spark/ml/regression/LinearRegression.scala  |  2 +-
 .../ml/regression/RandomForestRegressor.scala   |  2 +-
 .../DecisionTreeClassifierSuite.scala           | 17 +++++++++++++-
 .../ml/classification/GBTClassifierSuite.scala  |  9 ++++++++
 .../ml/classification/LinearSVCSuite.scala      |  6 +++++
 .../LogisticRegressionSuite.scala               |  9 ++++++++
 .../MultilayerPerceptronClassifierSuite.scala   | 12 ++++++++++
 .../ml/classification/NaiveBayesSuite.scala     | 22 ++++++++++++++++++
 .../RandomForestClassifierSuite.scala           | 16 +++++++++++++
 .../regression/DecisionTreeRegressorSuite.scala | 15 ++++++++++++
 .../spark/ml/regression/GBTRegressorSuite.scala |  8 +++++++
 .../GeneralizedLinearRegressionSuite.scala      |  8 +++++++
 .../ml/regression/LinearRegressionSuite.scala   |  7 ++++++
 .../regression/RandomForestRegressorSuite.scala | 24 ++++++++++++++++----
 .../scala/org/apache/spark/ml/util/MLTest.scala | 15 ++++++++++--
 25 files changed, 176 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 08b0cb9..d8f3dfa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -219,7 +219,8 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
 
   /**
    * Predict label for the given features.
-   * This internal method is used to implement `transform()` and output [[predictionCol]].
+   * This method is used to implement `transform()` and output [[predictionCol]].
    */
-  protected def predict(features: FeaturesType): Double
+  @Since("2.4.0")
+  def predict(features: FeaturesType): Double
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 9d1d5aa..7e5790a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.ml.classification
 
 import org.apache.spark.SparkException
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, VectorUDT}
@@ -192,12 +192,12 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
 
   /**
    * Predict label for the given features.
-   * This internal method is used to implement `transform()` and output [[predictionCol]].
+   * This method is used to implement `transform()` and output [[predictionCol]].
    *
    * This default implementation for classification predicts the index of the maximum value
    * from `predictRaw()`.
    */
-  override protected def predict(features: FeaturesType): Double = {
+  override def predict(features: FeaturesType): Double = {
     raw2prediction(predictRaw(features))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 9f60f08..65cce69 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -181,7 +181,7 @@ class DecisionTreeClassificationModel private[ml] (
   private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
     this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)
 
-  override protected def predict(features: Vector): Double = {
+  override def predict(features: Vector): Double = {
     rootNode.predictImpl(features).prediction
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index f11bc1d..cd44489 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -267,7 +267,7 @@ class GBTClassificationModel private[ml](
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
 
-  override protected def predict(features: Vector): Double = {
+  override def predict(features: Vector): Double = {
     // If thresholds defined, use predictRaw to get probabilities, otherwise use optimization
     if (isDefined(thresholds)) {
       super.predict(features)

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index ce400f4..8f950cd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -316,7 +316,7 @@ class LinearSVCModel private[classification] (
     BLAS.dot(features, coefficients) + intercept
   }
 
-  override protected def predict(features: Vector): Double = {
+  override def predict(features: Vector): Double = {
     if (margin(features) > $(threshold)) 1.0 else 0.0
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index fa19160..3ae4db3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1090,7 +1090,7 @@ class LogisticRegressionModel private[spark] (
    * Predict label for the given feature vector.
    * The behavior of this can be adjusted using `thresholds`.
    */
-  override protected def predict(features: Vector): Double = if (isMultinomial) {
+  override def predict(features: Vector): Double = if (isMultinomial) {
     super.predict(features)
   } else {
     // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index fd4c98f..af2e469 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -322,7 +322,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
    * Predict label for the given features.
    * This internal method is used to implement `transform()` and output [[predictionCol]].
    */
-  override protected def predict(features: Vector): Double = {
+  override def predict(features: Vector): Double = {
     LabelConverter.decodeLabel(mlpModel.predict(features))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 0291a57..ad154fc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -178,7 +178,7 @@ class DecisionTreeRegressionModel private[ml] (
   private[ml] def this(rootNode: Node, numFeatures: Int) =
     this(Identifiable.randomUID("dtr"), rootNode, numFeatures)
 
-  override protected def predict(features: Vector): Double = {
+  override def predict(features: Vector): Double = {
     rootNode.predictImpl(features).prediction
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index f41d15b..6569ff2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -230,7 +230,7 @@ class GBTRegressionModel private[ml](
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
 
-  override protected def predict(features: Vector): Double = {
+  override def predict(features: Vector): Double = {
     // TODO: When we add a generic Boosting class, handle transform there?  SPARK-7129
     // Classifies by thresholding sum of weighted tree predictions
     val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 917a4d2..9f1f240 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -1010,7 +1010,7 @@ class GeneralizedLinearRegressionModel private[ml] (
 
   private lazy val familyAndLink = FamilyAndLink(this)
 
-  override protected def predict(features: Vector): Double = {
+  override def predict(features: Vector): Double = {
     predict(features, 0.0)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 6d3fe7a..9251015 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -699,7 +699,7 @@ class LinearRegressionModel private[ml] (
   }
 
 
-  override protected def predict(features: Vector): Double = {
+  override def predict(features: Vector): Double = {
     dot(features, coefficients) + intercept
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 200b234..2d59446 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -199,7 +199,7 @@ class RandomForestRegressionModel private[ml] (
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
 
-  override protected def predict(features: Vector): Double = {
+  override def predict(features: Vector): Double = {
     // TODO: When we add a generic Bagging class, handle transform there.  SPARK-7128
     // Predict average of tree predictions.
     // Ignore the weights since all are 1.0 for now.

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index eeb0324..2930f49 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.tree.{CategoricalSplit, InternalNode, LeafNode}
+import org.apache.spark.ml.tree.LeafNode
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
@@ -264,6 +264,21 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {
       Vector, DecisionTreeClassificationModel](this, newTree, newData)
   }
 
+  test("prediction on single instance") {
+    val rdd = continuousDataPointsForMulticlassRDD
+    val dt = new DecisionTreeClassifier()
+      .setImpurity("Gini")
+      .setMaxDepth(4)
+      .setMaxBins(100)
+    val categoricalFeatures = Map(0 -> 3)
+    val numClasses = 3
+
+    val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
+    val newTree = dt.fit(newData)
+
+    testPredictionModelSinglePrediction(newTree, newData)
+  }
+
   test("training with 1-category categorical feature") {
     val data = sc.parallelize(Seq(
       LabeledPoint(0, Vectors.dense(0, 2, 3)),

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 092b4a0..5779606 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -197,6 +197,15 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
       Vector, GBTClassificationModel](this, gbtModel, validationDataset)
   }
 
+  test("prediction on single instance") {
+
+    val gbt = new GBTClassifier().setSeed(123)
+    val trainingDataset = trainData.toDF("label", "features")
+    val gbtModel = gbt.fit(trainingDataset)
+
+    testPredictionModelSinglePrediction(gbtModel, trainingDataset)
+  }
+
   test("GBT parameter stepSize should be in interval (0, 1]") {
     withClue("GBT parameter stepSize should be in interval (0, 1]") {
       intercept[IllegalArgumentException] {

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index a93825b..c05c896 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -201,6 +201,12 @@ class LinearSVCSuite extends MLTest with DefaultReadWriteTest {
       dataset.as[LabeledPoint], estimator, modelEquals, 42L)
   }
 
+  test("prediction on single instance") {
+    val trainer = new LinearSVC()
+    val model = trainer.fit(smallBinaryDataset)
+    testPredictionModelSinglePrediction(model, smallBinaryDataset)
+  }
+
   test("linearSVC comparison with R e1071 and scikit-learn") {
     val trainer1 = new LinearSVC()
       .setRegParam(0.00002) // set regParam = 2.0 / datasize / c

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 9987cbf..36b7e51 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -499,6 +499,15 @@ class LogisticRegressionSuite extends MLTest with DefaultReadWriteTest {
       Vector, LogisticRegressionModel](this, model, smallBinaryDataset)
   }
 
+  test("prediction on single instance") {
+    val blor = new LogisticRegression().setFamily("binomial")
+    val blorModel = blor.fit(smallBinaryDataset)
+    testPredictionModelSinglePrediction(blorModel, smallBinaryDataset)
+    val mlor = new LogisticRegression().setFamily("multinomial")
+    val mlorModel = mlor.fit(smallMultinomialDataset)
+    testPredictionModelSinglePrediction(mlorModel, smallMultinomialDataset)
+  }
+
   test("coefficients and intercept methods") {
     val mlr = new LogisticRegression().setMaxIter(1).setFamily("multinomial")
     val mlrModel = mlr.fit(smallMultinomialDataset)

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index daa58a5..6b5fe6e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -76,6 +76,18 @@ class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTe
     }
   }
 
+  test("prediction on single instance") {
+    val layers = Array[Int](2, 5, 2)
+    val trainer = new MultilayerPerceptronClassifier()
+      .setLayers(layers)
+      .setBlockSize(1)
+      .setSeed(123L)
+      .setMaxIter(100)
+      .setSolver("l-bfgs")
+    val model = trainer.fit(dataset)
+    testPredictionModelSinglePrediction(model, dataset)
+  }
+
   test("Predicted class probabilities: calibration on toy dataset") {
     val layers = Array[Int](4, 5, 2)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 49115c8..5f9ab98 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -167,6 +167,28 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
       Vector, NaiveBayesModel](this, model, testDataset)
   }
 
+  test("prediction on single instance") {
+    val nPoints = 1000
+    val piArray = Array(0.5, 0.1, 0.4).map(math.log)
+    val thetaArray = Array(
+      Array(0.70, 0.10, 0.10, 0.10), // label 0
+      Array(0.10, 0.70, 0.10, 0.10), // label 1
+      Array(0.10, 0.10, 0.70, 0.10)  // label 2
+    ).map(_.map(math.log))
+    val pi = Vectors.dense(piArray)
+    val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)
+
+    val trainDataset =
+      generateNaiveBayesInput(piArray, thetaArray, nPoints, seed, "multinomial").toDF()
+    val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial")
+    val model = nb.fit(trainDataset)
+
+    val validationDataset =
+      generateNaiveBayesInput(piArray, thetaArray, nPoints, 17, "multinomial").toDF()
+
+    testPredictionModelSinglePrediction(model, validationDataset)
+  }
+
   test("Naive Bayes with weighted samples") {
     val numClasses = 3
     def modelEquals(m1: NaiveBayesModel, m2: NaiveBayesModel): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index 02a9d5c..ba4a9cf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -155,6 +155,22 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
       Vector, RandomForestClassificationModel](this, model, df)
   }
 
+  test("prediction on single instance") {
+    val rdd = orderedLabeledPoints5_20
+    val rf = new RandomForestClassifier()
+      .setImpurity("Gini")
+      .setMaxDepth(3)
+      .setNumTrees(3)
+      .setSeed(123)
+    val categoricalFeatures = Map.empty[Int, Int]
+    val numClasses = 2
+
+    val df: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
+    val model = rf.fit(df)
+
+    testPredictionModelSinglePrediction(model, df)
+  }
+
   test("Fitting without numClasses in metadata") {
     val df: DataFrame = TreeTests.featureImportanceData(sc).toDF()
     val rf = new RandomForestClassifier().setMaxDepth(1).setNumTrees(1)

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 68a1218..29a4383 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -136,6 +136,21 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
     assert(importances.toArray.forall(_ >= 0.0))
   }
 
+  test("prediction on single instance") {
+    val dt = new DecisionTreeRegressor()
+      .setImpurity("variance")
+      .setMaxDepth(3)
+      .setSeed(123)
+
+    // In this data, feature 1 is very important.
+    val data: RDD[LabeledPoint] = TreeTests.featureImportanceData(sc)
+    val categoricalFeatures = Map.empty[Int, Int]
+    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0)
+
+    val model = dt.fit(df)
+    testPredictionModelSinglePrediction(model, df)
+  }
+
   test("should support all NumericType labels and not support other types") {
     val dt = new DecisionTreeRegressor().setMaxDepth(1)
     MLTestingUtils.checkNumericTypes[DecisionTreeRegressionModel, DecisionTreeRegressor](

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 11c593b..fad11d0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -99,6 +99,14 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
     }
   }
 
+  test("prediction on single instance") {
+    val gbt = new GBTRegressor()
+      .setMaxDepth(2)
+      .setMaxIter(2)
+    val model = gbt.fit(trainData.toDF())
+    testPredictionModelSinglePrediction(model, validationData.toDF)
+  }
+
   test("Checkpointing") {
     val tempDir = Utils.createTempDir()
     val path = tempDir.toURI.toString

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index ef2ff94..d5bcbb2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -211,6 +211,14 @@ class GeneralizedLinearRegressionSuite extends MLTest with DefaultReadWriteTest
     assert(model.getLink === "identity")
   }
 
+  test("prediction on single instance") {
+    val glr = new GeneralizedLinearRegression
+    val model = glr.setFamily("gaussian").setLink("identity")
+      .fit(datasetGaussianIdentity)
+
+    testPredictionModelSinglePrediction(model, datasetGaussianIdentity)
+  }
+
   test("generalized linear regression: gaussian family against glm") {
     /*
        R code:

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index d42cb17..9b19f63 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -636,6 +636,13 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
     }
   }
 
+  test("prediction on single instance") {
+    val trainer = new LinearRegression
+    val model = trainer.fit(datasetWithDenseFeature)
+
+    testPredictionModelSinglePrediction(model, datasetWithDenseFeature)
+  }
+
   test("linear regression model with constant label") {
     /*
        R code:

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 8b8e8a6..e83c49f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -19,22 +19,22 @@ package org.apache.spark.ml.regression
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
+import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.tree.impl.TreeTests
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row}
 
 /**
  * Test suite for [[RandomForestRegressor]].
  */
-class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContext
-  with DefaultReadWriteTest{
+class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{
 
   import RandomForestRegressorSuite.compareAPIs
+  import testImplicits._
 
   private var orderedLabeledPoints50_1000: RDD[LabeledPoint] = _
 
@@ -74,6 +74,20 @@ class RandomForestRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
     regressionTestWithContinuousFeatures(rf)
   }
 
+  test("prediction on single instance") {
+    val rf = new RandomForestRegressor()
+      .setImpurity("variance")
+      .setMaxDepth(2)
+      .setMaxBins(10)
+      .setNumTrees(1)
+      .setFeatureSubsetStrategy("auto")
+      .setSeed(123)
+
+    val df = orderedLabeledPoints50_1000.toDF()
+    val model = rf.fit(df)
+    testPredictionModelSinglePrediction(model, df)
+  }
+
   test("Feature importance with toy data") {
     val rf = new RandomForestRegressor()
       .setImpurity("variance")

http://git-wip-us.apache.org/repos/asf/spark/blob/bf09f2f7/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
index 795fd0e..76d41f9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -22,8 +22,9 @@ import java.io.File
 import org.scalatest.Suite
 
 import org.apache.spark.SparkContext
-import org.apache.spark.ml.Transformer
-import org.apache.spark.sql.{DataFrame, Encoder, Row}
+import org.apache.spark.ml.{PredictionModel, Transformer}
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row}
 import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.streaming.StreamTest
@@ -136,4 +137,14 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite =>
       assert(hasExpectedMessage(exceptionOnStreamData))
     }
   }
+
+  def testPredictionModelSinglePrediction(model: PredictionModel[Vector, _],
+    dataset: Dataset[_]): Unit = {
+
+    model.transform(dataset).select(model.getFeaturesCol, model.getPredictionCol)
+      .collect().foreach {
+      case Row(features: Vector, prediction: Double) =>
+        assert(prediction === model.predict(features))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message