spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-9715] [ML] Store numFeatures in all ML PredictionModel types
Date Wed, 23 Sep 2015 22:00:57 GMT
Repository: spark
Updated Branches:
  refs/heads/master a18208047 -> 098be27ad


[SPARK-9715] [ML] Store numFeatures in all ML PredictionModel types

All prediction models should store `numFeatures` indicating the number of features the model
was trained on. Default value of -1 added for backwards compatibility.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #8675 from sethah/SPARK-9715.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/098be27a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/098be27a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/098be27a

Branch: refs/heads/master
Commit: 098be27ad53c485ee2fc7f5871c47f899020e87b
Parents: a182080
Author: sethah <seth.hendrickson16@gmail.com>
Authored: Wed Sep 23 15:00:52 2015 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Wed Sep 23 15:00:52 2015 -0700

----------------------------------------------------------------------
 .../examples/ml/JavaDeveloperApiExample.java    |  5 ++++
 .../spark/examples/ml/DeveloperApiExample.scala |  3 +++
 .../scala/org/apache/spark/ml/Predictor.scala   |  6 ++++-
 .../classification/DecisionTreeClassifier.scala | 13 ++++++----
 .../spark/ml/classification/GBTClassifier.scala | 26 ++++++++++++++------
 .../ml/classification/LogisticRegression.scala  |  2 ++
 .../MultilayerPerceptronClassifier.scala        |  2 ++
 .../spark/ml/classification/NaiveBayes.scala    |  2 ++
 .../classification/RandomForestClassifier.scala |  8 +++---
 .../ml/regression/DecisionTreeRegressor.scala   | 13 ++++++----
 .../spark/ml/regression/GBTRegressor.scala      | 24 +++++++++++++-----
 .../spark/ml/regression/LinearRegression.scala  |  2 ++
 .../ml/regression/RandomForestRegressor.scala   |  7 +++---
 .../spark/ml/tree/impl/RandomForest.scala       | 14 ++++++++---
 .../DecisionTreeClassifierSuite.scala           |  4 ++-
 .../ml/classification/GBTClassifierSuite.scala  | 11 ++++++---
 .../LogisticRegressionSuite.scala               |  2 ++
 .../MultilayerPerceptronClassifierSuite.scala   |  4 ++-
 .../ProbabilisticClassifierSuite.scala          |  6 +++--
 .../RandomForestClassifierSuite.scala           |  8 +++---
 .../regression/DecisionTreeRegressorSuite.scala |  2 ++
 .../spark/ml/regression/GBTRegressorSuite.scala |  7 ++++--
 .../ml/regression/LinearRegressionSuite.scala   |  4 ++-
 .../regression/RandomForestRegressorSuite.scala |  2 ++
 .../spark/ml/tree/impl/RandomForestSuite.scala  |  3 ++-
 25 files changed, 130 insertions(+), 50 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index a377694..0b4c0d9 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -220,6 +220,11 @@ class MyJavaLogisticRegressionModel
   public int numClasses() { return 2; }
 
   /**
+   * Number of features the model was trained on.
+   */
+  public int numFeatures() { return weights_.size(); }
+
+  /**
    * Create a copy of the model.
    * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
    * <p>

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index 340c355..3758edc 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -172,6 +172,9 @@ private class MyLogisticRegressionModel(
   /** Number of classes the label can take.  2 indicates binary classification. */
   override val numClasses: Int = 2
 
+  /** Number of features the model was trained on. */
+  override val numFeatures: Int = weights.size
+
   /**
    * Create a copy of the model.
    * The copy is shallow, except for the embedded paramMap, which gets a deep copy.

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 19fe039..e0dcd42 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.ml
 
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
@@ -145,6 +145,10 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
   /** @group setParam */
   def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]
 
+  /** Returns the number of features the model was trained on. If unknown, returns -1 */
+  @Since("1.6.0")
+  def numFeatures: Int = -1
+
   /**
    * Returns the SQL DataType corresponding to the FeaturesType type parameter.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index b8eb49f..a6f6d46 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -107,6 +107,7 @@ object DecisionTreeClassifier {
 final class DecisionTreeClassificationModel private[ml] (
     override val uid: String,
     override val rootNode: Node,
+    override val numFeatures: Int,
     override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
   with DecisionTreeModel with Serializable {
@@ -118,8 +119,8 @@ final class DecisionTreeClassificationModel private[ml] (
    * Construct a decision tree classification model.
    * @param rootNode  Root node of tree, with other nodes attached.
    */
-  private[ml] def this(rootNode: Node, numClasses: Int) =
-    this(Identifiable.randomUID("dtc"), rootNode, numClasses)
+  private[ml] def this(rootNode: Node, numFeatures: Int, numClasses: Int) =
+    this(Identifiable.randomUID("dtc"), rootNode, numFeatures, numClasses)
 
   override protected def predict(features: Vector): Double = {
     rootNode.predictImpl(features).prediction
@@ -141,7 +142,7 @@ final class DecisionTreeClassificationModel private[ml] (
   }
 
   override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
-    copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra)
+    copyValues(new DecisionTreeClassificationModel(uid, rootNode, numFeatures, numClasses), extra)
       .setParent(parent)
   }
 
@@ -161,12 +162,14 @@ private[ml] object DecisionTreeClassificationModel {
   def fromOld(
       oldModel: OldDecisionTreeModel,
       parent: DecisionTreeClassifier,
-      categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = {
+      categoricalFeatures: Map[Int, Int],
+      numFeatures: Int = -1): DecisionTreeClassificationModel = {
     require(oldModel.algo == OldAlgo.Classification,
       s"Cannot convert non-classification DecisionTreeModel (old API) to" +
         s" DecisionTreeClassificationModel (new API).  Algo is: ${oldModel.algo}")
     val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
-    new DecisionTreeClassificationModel(uid, rootNode, -1)
+    // Can't infer number of features from old model, so default to -1
+    new DecisionTreeClassificationModel(uid, rootNode, numFeatures, -1)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index ad86836..74aef94 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -33,7 +33,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss}
 import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{Row, DataFrame}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.DoubleType
 
@@ -138,10 +138,11 @@ final class GBTClassifier(override val uid: String)
     require(numClasses == 2,
       s"GBTClassifier only supports binary classification but was given numClasses = $numClasses")
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
     val oldGBT = new OldGBT(boostingStrategy)
     val oldModel = oldGBT.run(oldDataset)
-    GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures)
+    GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
   }
 
   override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
@@ -164,10 +165,11 @@ object GBTClassifier {
  * @param _treeWeights  Weights for the decision trees in the ensemble.
  */
 @Experimental
-final class GBTClassificationModel(
+final class GBTClassificationModel private[ml](
     override val uid: String,
     private val _trees: Array[DecisionTreeRegressionModel],
-    private val _treeWeights: Array[Double])
+    private val _treeWeights: Array[Double],
+    override val numFeatures: Int)
   extends PredictionModel[Vector, GBTClassificationModel]
   with TreeEnsembleModel with Serializable {
 
@@ -175,6 +177,14 @@ final class GBTClassificationModel(
   require(_trees.length == _treeWeights.length, "GBTClassificationModel given trees, treeWeights" +
     s" of non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
 
+  /**
+   * Construct a GBTClassificationModel
+   * @param _trees  Decision trees in the ensemble.
+   * @param _treeWeights  Weights for the decision trees in the ensemble.
+   */
+  def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
+    this(uid, _trees, _treeWeights, -1)
+
   override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
 
   override def treeWeights: Array[Double] = _treeWeights
@@ -196,7 +206,8 @@ final class GBTClassificationModel(
   }
 
   override def copy(extra: ParamMap): GBTClassificationModel = {
-    copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra).setParent(parent)
+    copyValues(new GBTClassificationModel(uid, _trees, _treeWeights, numFeatures),
+      extra).setParent(parent)
   }
 
   override def toString: String = {
@@ -215,7 +226,8 @@ private[ml] object GBTClassificationModel {
   def fromOld(
       oldModel: OldGBTModel,
       parent: GBTClassifier,
-      categoricalFeatures: Map[Int, Int]): GBTClassificationModel = {
+      categoricalFeatures: Map[Int, Int],
+      numFeatures: Int = -1): GBTClassificationModel = {
     require(oldModel.algo == OldAlgo.Classification, "Cannot convert GradientBoostedTreesModel" +
       s" with algo=${oldModel.algo} (old API) to GBTClassificationModel (new API).")
     val newTrees = oldModel.trees.map { tree =>
@@ -223,6 +235,6 @@ private[ml] object GBTClassificationModel {
       DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
     }
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtc")
-    new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights)
+    new GBTClassificationModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index bd96e8d..c17a7b0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -426,6 +426,8 @@ class LogisticRegressionModel private[ml] (
     1.0 / (1.0 + math.exp(-m))
   }
 
+  override val numFeatures: Int = weights.size
+
   override val numClasses: Int = 2
 
   private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 5f60dea..cd74625 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -181,6 +181,8 @@ class MultilayerPerceptronClassificationModel private[ml] (
   extends PredictionModel[Vector, MultilayerPerceptronClassificationModel]
   with Serializable {
 
+  override val numFeatures: Int = layers.head
+
   private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights)
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 082ea1f..a14dcec 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -137,6 +137,8 @@ class NaiveBayesModel private[ml] (
       throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
   }
 
+  override val numFeatures: Int = theta.numCols
+
   override val numClasses: Int = pi.size
 
   private def multinomialCalculation(features: Vector) = {

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index a6ebee1..bae3296 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -119,13 +119,12 @@ object RandomForestClassifier {
  * features.
  * @param _trees  Decision trees in the ensemble.
  *               Warning: These have null parents.
- * @param numFeatures  Number of features used by this model
  */
 @Experimental
 final class RandomForestClassificationModel private[ml] (
     override val uid: String,
     private val _trees: Array[DecisionTreeClassificationModel],
-    val numFeatures: Int,
+    override val numFeatures: Int,
     override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel]
   with TreeEnsembleModel with Serializable {
@@ -226,7 +225,8 @@ private[ml] object RandomForestClassificationModel {
       oldModel: OldRandomForestModel,
       parent: RandomForestClassifier,
       categoricalFeatures: Map[Int, Int],
-      numClasses: Int): RandomForestClassificationModel = {
+      numClasses: Int,
+      numFeatures: Int = -1): RandomForestClassificationModel = {
     require(oldModel.algo == OldAlgo.Classification, "Cannot convert RandomForestModel" +
       s" with algo=${oldModel.algo} (old API) to RandomForestClassificationModel (new API).")
     val newTrees = oldModel.trees.map { tree =>
@@ -234,6 +234,6 @@ private[ml] object RandomForestClassificationModel {
       DecisionTreeClassificationModel.fromOld(tree, null, categoricalFeatures)
     }
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("rfc")
-    new RandomForestClassificationModel(uid, newTrees, -1, numClasses)
+    new RandomForestClassificationModel(uid, newTrees, numFeatures, numClasses)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index d9a244b..88b79a4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -96,7 +96,8 @@ object DecisionTreeRegressor {
 @Experimental
 final class DecisionTreeRegressionModel private[ml] (
     override val uid: String,
-    override val rootNode: Node)
+    override val rootNode: Node,
+    override val numFeatures: Int)
   extends PredictionModel[Vector, DecisionTreeRegressionModel]
   with DecisionTreeModel with Serializable {
 
@@ -107,14 +108,15 @@ final class DecisionTreeRegressionModel private[ml] (
    * Construct a decision tree regression model.
    * @param rootNode  Root node of tree, with other nodes attached.
    */
-  private[ml] def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)
+  private[ml] def this(rootNode: Node, numFeatures: Int) =
+    this(Identifiable.randomUID("dtr"), rootNode, numFeatures)
 
   override protected def predict(features: Vector): Double = {
     rootNode.predictImpl(features).prediction
   }
 
   override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
-    copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra).setParent(parent)
+    copyValues(new DecisionTreeRegressionModel(uid, rootNode, numFeatures), extra).setParent(parent)
   }
 
   override def toString: String = {
@@ -133,12 +135,13 @@ private[ml] object DecisionTreeRegressionModel {
   def fromOld(
       oldModel: OldDecisionTreeModel,
       parent: DecisionTreeRegressor,
-      categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = {
+      categoricalFeatures: Map[Int, Int],
+      numFeatures: Int = -1): DecisionTreeRegressionModel = {
     require(oldModel.algo == OldAlgo.Regression,
       s"Cannot convert non-regression DecisionTreeModel (old API) to" +
         s" DecisionTreeRegressionModel (new API).  Algo is: ${oldModel.algo}")
     val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtr")
-    new DecisionTreeRegressionModel(uid, rootNode)
+    new DecisionTreeRegressionModel(uid, rootNode, numFeatures)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index d841ecb..65b5b3e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -128,10 +128,11 @@ final class GBTRegressor(override val uid: String)
     val categoricalFeatures: Map[Int, Int] =
       MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
     val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
+    val numFeatures = oldDataset.first().features.size
     val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
     val oldGBT = new OldGBT(boostingStrategy)
     val oldModel = oldGBT.run(oldDataset)
-    GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures)
+    GBTRegressionModel.fromOld(oldModel, this, categoricalFeatures, numFeatures)
   }
 
   override def copy(extra: ParamMap): GBTRegressor = defaultCopy(extra)
@@ -154,10 +155,11 @@ object GBTRegressor {
  * @param _treeWeights  Weights for the decision trees in the ensemble.
  */
 @Experimental
-final class GBTRegressionModel(
+final class GBTRegressionModel private[ml](
     override val uid: String,
     private val _trees: Array[DecisionTreeRegressionModel],
-    private val _treeWeights: Array[Double])
+    private val _treeWeights: Array[Double],
+    override val numFeatures: Int)
   extends PredictionModel[Vector, GBTRegressionModel]
   with TreeEnsembleModel with Serializable {
 
@@ -165,6 +167,14 @@ final class GBTRegressionModel(
   require(_trees.length == _treeWeights.length, "GBTRegressionModel given trees, treeWeights of" +
     s" non-matching lengths (${_trees.length}, ${_treeWeights.length}, respectively).")
 
+  /**
+   * Construct a GBTRegressionModel
+   * @param _trees  Decision trees in the ensemble.
+   * @param _treeWeights  Weights for the decision trees in the ensemble.
+   */
+  def this(uid: String, _trees: Array[DecisionTreeRegressionModel], _treeWeights: Array[Double]) =
+    this(uid, _trees, _treeWeights, -1)
+
   override def trees: Array[DecisionTreeModel] = _trees.asInstanceOf[Array[DecisionTreeModel]]
 
   override def treeWeights: Array[Double] = _treeWeights
@@ -185,7 +195,8 @@ final class GBTRegressionModel(
   }
 
   override def copy(extra: ParamMap): GBTRegressionModel = {
-    copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra).setParent(parent)
+    copyValues(new GBTRegressionModel(uid, _trees, _treeWeights, numFeatures),
+      extra).setParent(parent)
   }
 
   override def toString: String = {
@@ -204,7 +215,8 @@ private[ml] object GBTRegressionModel {
   def fromOld(
       oldModel: OldGBTModel,
       parent: GBTRegressor,
-      categoricalFeatures: Map[Int, Int]): GBTRegressionModel = {
+      categoricalFeatures: Map[Int, Int],
+      numFeatures: Int = -1): GBTRegressionModel = {
     require(oldModel.algo == OldAlgo.Regression, "Cannot convert GradientBoostedTreesModel" +
       s" with algo=${oldModel.algo} (old API) to GBTRegressionModel (new API).")
     val newTrees = oldModel.trees.map { tree =>
@@ -212,6 +224,6 @@ private[ml] object GBTRegressionModel {
       DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
     }
     val uid = if (parent != null) parent.uid else Identifiable.randomUID("gbtr")
-    new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights)
+    new GBTRegressionModel(parent.uid, newTrees, oldModel.treeWeights, numFeatures)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 78a67c5..a77e702 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -293,6 +293,8 @@ class LinearRegressionModel private[ml] (
 
   private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
 
+  override val numFeatures: Int = weights.size
+
   /**
    * Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is
    * thrown if `trainingSummary == None`.

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index ddb7214..64fc172 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -115,7 +115,7 @@ object RandomForestRegressor {
 final class RandomForestRegressionModel private[ml] (
     override val uid: String,
     private val _trees: Array[DecisionTreeRegressionModel],
-    val numFeatures: Int)
+    override val numFeatures: Int)
   extends PredictionModel[Vector, RandomForestRegressionModel]
   with TreeEnsembleModel with Serializable {
 
@@ -187,13 +187,14 @@ private[ml] object RandomForestRegressionModel {
   def fromOld(
       oldModel: OldRandomForestModel,
       parent: RandomForestRegressor,
-      categoricalFeatures: Map[Int, Int]): RandomForestRegressionModel = {
+      categoricalFeatures: Map[Int, Int],
+      numFeatures: Int = -1): RandomForestRegressionModel = {
     require(oldModel.algo == OldAlgo.Regression, "Cannot convert RandomForestModel" +
       s" with algo=${oldModel.algo} (old API) to RandomForestRegressionModel (new API).")
     val newTrees = oldModel.trees.map { tree =>
       // parent for each tree is null since there is no good way to set this.
       DecisionTreeRegressionModel.fromOld(tree, null, categoricalFeatures)
     }
-    new RandomForestRegressionModel(parent.uid, newTrees, -1)
+    new RandomForestRegressionModel(parent.uid, newTrees, numFeatures)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 4ac51a4..c494556 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -179,22 +179,28 @@ private[ml] object RandomForest extends Logging {
       }
     }
 
+    val numFeatures = metadata.numFeatures
+
     parentUID match {
       case Some(uid) =>
         if (strategy.algo == OldAlgo.Classification) {
           topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses)
+            new DecisionTreeClassificationModel(uid, rootNode.toNode, numFeatures,
+              strategy.getNumClasses)
           }
         } else {
-          topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode))
+          topNodes.map { rootNode =>
+            new DecisionTreeRegressionModel(uid, rootNode.toNode, numFeatures)
+          }
         }
       case None =>
         if (strategy.algo == OldAlgo.Classification) {
           topNodes.map { rootNode =>
-            new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses)
+            new DecisionTreeClassificationModel(rootNode.toNode, numFeatures,
+              strategy.getNumClasses)
           }
         } else {
-          topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode))
+          topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode, numFeatures))
         }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index f680d8d..815f6fd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -59,7 +59,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
 
   test("params") {
     ParamsSuite.checkParams(new DecisionTreeClassifier)
-    val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)
+    val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)
     ParamsSuite.checkParams(model)
   }
 
@@ -310,6 +310,7 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
       dt: DecisionTreeClassifier,
       categoricalFeatures: Map[Int, Int],
       numClasses: Int): Unit = {
+    val numFeatures = data.first().features.size
     val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses)
     val oldTree = OldDecisionTree.train(data, oldStrategy)
     val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
@@ -318,5 +319,6 @@ private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
     val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(
       oldTree, newTree.parent.asInstanceOf[DecisionTreeClassifier], categoricalFeatures)
     TreeTests.checkEqual(oldTreeAsNew, newTree)
+    assert(newTree.numFeatures === numFeatures)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index e3909bc..039141a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -59,8 +59,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
   test("params") {
     ParamsSuite.checkParams(new GBTClassifier)
     val model = new GBTClassificationModel("gbtc",
-      Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))),
-      Array(1.0))
+      Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null), 1)),
+      Array(1.0), 1)
     ParamsSuite.checkParams(model)
   }
 
@@ -145,7 +145,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
   */
 }
 
-private object GBTClassifierSuite {
+private object GBTClassifierSuite extends SparkFunSuite {
 
   /**
    * Train 2 models on the given dataset, one using the old API and one using the new API.
@@ -156,6 +156,7 @@ private object GBTClassifierSuite {
       validationData: Option[RDD[LabeledPoint]],
       gbt: GBTClassifier,
       categoricalFeatures: Map[Int, Int]): Unit = {
+    val numFeatures = data.first().features.size
     val oldBoostingStrategy =
       gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
     val oldGBT = new OldGBT(oldBoostingStrategy)
@@ -164,7 +165,9 @@ private object GBTClassifierSuite {
     val newModel = gbt.fit(newData)
     // Use parent from newTree since this is not checked anyways.
     val oldModelAsNew = GBTClassificationModel.fromOld(
-      oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures)
+      oldModel, newModel.parent.asInstanceOf[GBTClassifier], categoricalFeatures, numFeatures)
     TreeTests.checkEqual(oldModelAsNew, newModel)
+    assert(newModel.numFeatures === numFeatures)
+    assert(oldModelAsNew.numFeatures === numFeatures)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index f5219f9..ec01998 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -194,6 +194,8 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     val model = lr.fit(dataset)
     assert(model.numClasses === 2)
+    val numFeatures = dataset.select("features").first().getAs[Vector](0).size
+    assert(model.numFeatures === numFeatures)
 
     val threshold = model.getThreshold
     val results = model.transform(dataset)

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index ddc948f..2d1df9b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.classification.LogisticRegressionSuite._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
 import org.apache.spark.mllib.evaluation.MulticlassMetrics
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.Row
@@ -73,6 +73,8 @@ class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSp
       .setSeed(11L)
       .setMaxIter(numIterations)
     val model = trainer.fit(dataFrame)
+    val numFeatures = dataFrame.select("features").first().getAs[Vector](0).size
+    assert(model.numFeatures === numFeatures)
     val mlpPredictionAndLabels = model.transform(dataFrame).select("prediction", "label")
       .map { case Row(p: Double, l: Double) => (p, l) }
     // train multinomial logistic regression

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
index 8f50cb9..fb5f00e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
 
 final class TestProbabilisticClassificationModel(
     override val uid: String,
+    override val numFeatures: Int,
     override val numClasses: Int)
   extends ProbabilisticClassificationModel[Vector, TestProbabilisticClassificationModel] {
 
@@ -45,13 +46,14 @@ class ProbabilisticClassifierSuite extends SparkFunSuite {
 
   test("test thresholding") {
     val thresholds = Array(0.5, 0.2)
-    val testModel = new TestProbabilisticClassificationModel("myuid", 2).setThresholds(thresholds)
+    val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
+      .setThresholds(thresholds)
     assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 1.0))) === 1.0)
     assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 0.2))) === 0.0)
   }
 
   test("test thresholding not required") {
-    val testModel = new TestProbabilisticClassificationModel("myuid", 2)
+    val testModel = new TestProbabilisticClassificationModel("myuid", 2, 2)
     assert(testModel.friendlyPredict(Vectors.dense(Array(1.0, 2.0))) === 1.0)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index b4403ec..deb8ec7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -68,7 +68,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
   test("params") {
     ParamsSuite.checkParams(new RandomForestClassifier)
     val model = new RandomForestClassificationModel("rfc",
-      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2, 2)
+      Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 1, 2)), 2, 2)
     ParamsSuite.checkParams(model)
   }
 
@@ -209,7 +209,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
   */
 }
 
-private object RandomForestClassifierSuite {
+private object RandomForestClassifierSuite extends SparkFunSuite {
 
   /**
    * Train 2 models on the given dataset, one using the old API and one using the new API.
@@ -220,6 +220,7 @@ private object RandomForestClassifierSuite {
       rf: RandomForestClassifier,
       categoricalFeatures: Map[Int, Int],
       numClasses: Int): Unit = {
+    val numFeatures = data.first().features.size
     val oldStrategy =
       rf.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, rf.getOldImpurity)
     val oldModel = OldRandomForest.trainClassifier(
@@ -233,6 +234,7 @@ private object RandomForestClassifierSuite {
     TreeTests.checkEqual(oldModelAsNew, newModel)
     assert(newModel.hasParent)
     assert(!newModel.trees.head.asInstanceOf[DecisionTreeClassificationModel].hasParent)
-    assert(newModel.numClasses == numClasses)
+    assert(newModel.numClasses === numClasses)
+    assert(newModel.numFeatures === numFeatures)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index b092bcd..868fb8e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -89,6 +89,7 @@ private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {
       data: RDD[LabeledPoint],
       dt: DecisionTreeRegressor,
       categoricalFeatures: Map[Int, Int]): Unit = {
+    val numFeatures = data.first().features.size
     val oldStrategy = dt.getOldStrategy(categoricalFeatures)
     val oldTree = OldDecisionTree.train(data, oldStrategy)
     val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
@@ -97,5 +98,6 @@ private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {
     val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(
       oldTree, newTree.parent.asInstanceOf[DecisionTreeRegressor], categoricalFeatures)
     TreeTests.checkEqual(oldTreeAsNew, newTree)
+    assert(newTree.numFeatures === numFeatures)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index a68197b..0932660 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -156,7 +156,7 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
   */
 }
 
-private object GBTRegressorSuite {
+private object GBTRegressorSuite extends SparkFunSuite {
 
   /**
    * Train 2 models on the given dataset, one using the old API and one using the new API.
@@ -167,6 +167,7 @@ private object GBTRegressorSuite {
       validationData: Option[RDD[LabeledPoint]],
       gbt: GBTRegressor,
       categoricalFeatures: Map[Int, Int]): Unit = {
+    val numFeatures = data.first().features.size
     val oldBoostingStrategy = gbt.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Regression)
     val oldGBT = new OldGBT(oldBoostingStrategy)
     val oldModel = oldGBT.run(data)
@@ -174,7 +175,9 @@ private object GBTRegressorSuite {
     val newModel = gbt.fit(newData)
     // Use parent from newTree since this is not checked anyways.
     val oldModelAsNew = GBTRegressionModel.fromOld(
-      oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures)
+      oldModel, newModel.parent.asInstanceOf[GBTRegressor], categoricalFeatures, numFeatures)
     TreeTests.checkEqual(oldModelAsNew, newModel)
+    assert(newModel.numFeatures === numFeatures)
+    assert(oldModelAsNew.numFeatures === numFeatures)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 8428f4f..7cb9471 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -22,8 +22,8 @@ import scala.util.Random
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.MLTestingUtils
-import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors}
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.sql.{DataFrame, Row}
@@ -87,6 +87,8 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(model.getPredictionCol === "prediction")
     assert(model.intercept !== 0.0)
     assert(model.hasParent)
+    val numFeatures = dataset.select("features").first().getAs[Vector](0).size
+    assert(model.numFeatures === numFeatures)
   }
 
   test("linear regression with intercept without regularization") {

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
index 7b1b3f1..7e751e4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala
@@ -137,6 +137,7 @@ private object RandomForestRegressorSuite extends SparkFunSuite {
       data: RDD[LabeledPoint],
       rf: RandomForestRegressor,
       categoricalFeatures: Map[Int, Int]): Unit = {
+    val numFeatures = data.first().features.size
     val oldStrategy =
       rf.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, rf.getOldImpurity)
     val oldModel = OldRandomForest.trainRegressor(
@@ -147,5 +148,6 @@ private object RandomForestRegressorSuite extends SparkFunSuite {
     val oldModelAsNew = RandomForestRegressionModel.fromOld(
       oldModel, newModel.parent.asInstanceOf[RandomForestRegressor], categoricalFeatures)
     TreeTests.checkEqual(oldModelAsNew, newModel)
+    assert(newModel.numFeatures === numFeatures)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/098be27a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index dc85279..d5c238e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -77,7 +77,8 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
 
     // Forest consisting of (full tree) + (internal node with 2 leafs)
     val trees = Array(parent, grandParent).map { root =>
-      new DecisionTreeClassificationModel(root, numClasses = 3).asInstanceOf[DecisionTreeModel]
+      new DecisionTreeClassificationModel(root, numFeatures = 2, numClasses = 3)
+        .asInstanceOf[DecisionTreeModel]
     }
     val importances: Vector = RandomForest.featureImportances(trees, 2)
     val tree2norm = feature0importance + feature1importance


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message