spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject [2/2] spark git commit: [SPARK-5995] [ML] Make Prediction dev API public
Date Wed, 06 May 2015 23:15:56 GMT
[SPARK-5995] [ML] Make Prediction dev API public

Changes:
* Update protected prediction methods, following design doc. **<--most interesting change**
* Changed abstract classes for Estimator and Model to be public.  Added DeveloperApi tag.  (I kept the traits for Estimator/Model Params private.)
* Changed ProbabilisticClassificationModel method names to use probability instead of probabilities.

CC: mengxr shivaram etrain

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #5913 from jkbradley/public-dev-api and squashes the following commits:

e9aa0ea [Joseph K. Bradley] moved findMax to DenseVector and renamed to argmax. fixed bug for vector of length 0
15b9957 [Joseph K. Bradley] renamed probabilities to probability in method names
5cda84d [Joseph K. Bradley] regenerated sharedParams
7d1877a [Joseph K. Bradley] Made spark.ml prediction abstractions public.  Organized their prediction methods for efficient computation of multiple output columns.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1ad04dae
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1ad04dae
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1ad04dae

Branch: refs/heads/master
Commit: 1ad04dae038673a448f529c39b17817b78d6acd0
Parents: 7740996
Author: Joseph K. Bradley <joseph@databricks.com>
Authored: Wed May 6 16:15:51 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Wed May 6 16:15:51 2015 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/Predictor.scala   | 191 ++++++++
 .../spark/ml/classification/Classifier.scala    | 110 ++---
 .../classification/DecisionTreeClassifier.scala |   5 +-
 .../spark/ml/classification/GBTClassifier.scala |   5 +-
 .../ml/classification/LogisticRegression.scala  | 100 ++---
 .../ProbabilisticClassifier.scala               | 100 +++--
 .../classification/RandomForestClassifier.scala |   5 +-
 .../spark/ml/impl/estimator/Predictor.scala     | 217 ----------
 .../apache/spark/ml/impl/tree/treeParams.scala  | 431 -------------------
 .../ml/param/shared/SharedParamsCodeGen.scala   |   6 +-
 .../spark/ml/param/shared/sharedParams.scala    |   4 +-
 .../ml/regression/DecisionTreeRegressor.scala   |   5 +-
 .../spark/ml/regression/GBTRegressor.scala      |   5 +-
 .../spark/ml/regression/LinearRegression.scala  |   5 +-
 .../ml/regression/RandomForestRegressor.scala   |   5 +-
 .../apache/spark/ml/regression/Regressor.scala  |  42 +-
 .../org/apache/spark/ml/tree/treeParams.scala   | 431 +++++++++++++++++++
 .../org/apache/spark/mllib/linalg/Vectors.scala |  22 +
 18 files changed, 814 insertions(+), 875 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
new file mode 100644
index 0000000..0e53877
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+import org.apache.spark.sql.{DataFrame, Row}
+
+/**
+ * (private[ml])  Trait for parameters for prediction (regression and classification).
+ */
+private[ml] trait PredictorParams extends Params
+  with HasLabelCol with HasFeaturesCol with HasPredictionCol {
+
+  /**
+   * Validates and transforms the input schema with the provided param map.
+   * @param schema input schema
+   * @param fitting whether this is called during fitting (if so, labelCol is also checked)
+   * @param featuresDataType  SQL DataType for FeaturesType.
+   *                          E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+   * @return output schema
+   */
+  protected def validateAndTransformSchema(
+      schema: StructType,
+      fitting: Boolean,
+      featuresDataType: DataType): StructType = {
+    // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
+    SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
+    if (fitting) {
+      // TODO: Allow other numeric types
+      SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+    }
+    SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Abstraction for prediction problems (regression and classification).
+ *
+ * @tparam FeaturesType  Type of features.
+ *                       E.g., [[org.apache.spark.mllib.linalg.Vector]] for vector features.
+ * @tparam Learner  Specialization of this class.  If you subclass this type, use this type
+ *                  parameter to specify the concrete type.
+ * @tparam M  Specialization of [[PredictionModel]].  If you subclass this type, use this type
+ *            parameter to specify the concrete type for the corresponding model.
+ */
+@DeveloperApi
+abstract class Predictor[
+    FeaturesType,
+    Learner <: Predictor[FeaturesType, Learner, M],
+    M <: PredictionModel[FeaturesType, M]]
+  extends Estimator[M] with PredictorParams {
+
+  /** @group setParam */
+  def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner]
+
+  /** @group setParam */
+  def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner]
+
+  /** @group setParam */
+  def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
+
+  override def fit(dataset: DataFrame): M = {
+    // This handles a few items such as schema validation.
+    // Developers only need to implement train().
+    transformSchema(dataset.schema, logging = true)
+    copyValues(train(dataset))
+  }
+
+  override def copy(extra: ParamMap): Learner = {
+    super.copy(extra).asInstanceOf[Learner]
+  }
+
+  /**
+   * Train a model using the given dataset and parameters.
+   * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation
+   * and copying parameters into the model.
+   *
+   * @param dataset  Training dataset
+   * @return  Fitted model
+   */
+  protected def train(dataset: DataFrame): M
+
+  /**
+   * Returns the SQL DataType corresponding to the FeaturesType type parameter.
+   *
+   * This is used by [[validateAndTransformSchema()]].
+   * This workaround is needed since SQL has different APIs for Scala and Java.
+   *
+   * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
+   */
+  protected def featuresDataType: DataType = new VectorUDT
+
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema, fitting = true, featuresDataType)
+  }
+
+  /**
+   * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
+   * and put them in an RDD with strong types.
+   */
+  protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = {
+    dataset.select($(labelCol), $(featuresCol))
+      .map { case Row(label: Double, features: Vector) =>
+      LabeledPoint(label, features)
+    }
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Abstraction for a model for prediction tasks (regression and classification).
+ *
+ * @tparam FeaturesType  Type of features.
+ *                       E.g., [[org.apache.spark.mllib.linalg.Vector]] for vector features.
+ * @tparam M  Specialization of [[PredictionModel]].  If you subclass this type, use this type
+ *            parameter to specify the concrete type for the corresponding model.
+ */
+@DeveloperApi
+abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
+  extends Model[M] with PredictorParams {
+
+  /** @group setParam */
+  def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M]
+
+  /** @group setParam */
+  def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]
+
+  /**
+   * Returns the SQL DataType corresponding to the FeaturesType type parameter.
+   *
+   * This is used by [[validateAndTransformSchema()]].
+   * This workaround is needed since SQL has different APIs for Scala and Java.
+   *
+   * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
+   */
+  protected def featuresDataType: DataType = new VectorUDT
+
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema, fitting = false, featuresDataType)
+  }
+
+  /**
+   * Transforms dataset by reading from [[featuresCol]], calling [[predict()]], and storing
+   * the predictions as a new column [[predictionCol]].
+   *
+   * @param dataset input dataset
+   * @return transformed dataset with [[predictionCol]] of type [[Double]]
+   */
+  override def transform(dataset: DataFrame): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
+    if ($(predictionCol).nonEmpty) {
+      dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol))))
+    } else {
+      this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
+        " since no output columns were set.")
+      dataset
+    }
+  }
+
+  /**
+   * Predict label for the given features.
+   * This internal method is used to implement [[transform()]] and output [[predictionCol]].
+   */
+  protected def predict(features: FeaturesType): Double
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index d3361e2..263d580 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -17,8 +17,8 @@
 
 package org.apache.spark.ml.classification
 
-import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
-import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
 import org.apache.spark.ml.param.shared.HasRawPredictionCol
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
@@ -26,15 +26,12 @@ import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
 
+
 /**
- * :: DeveloperApi ::
- * Params for classification.
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ * (private[spark]) Params for classification.
  */
-@DeveloperApi
-private[spark] trait ClassifierParams extends PredictorParams
-  with HasRawPredictionCol {
+private[spark] trait ClassifierParams
+  extends PredictorParams with HasRawPredictionCol {
 
   override protected def validateAndTransformSchema(
       schema: StructType,
@@ -46,23 +43,21 @@ private[spark] trait ClassifierParams extends PredictorParams
 }
 
 /**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
+ *
  * Single-label binary or multiclass classification.
  * Classes are indexed {0, 1, ..., numClasses - 1}.
  *
  * @tparam FeaturesType  Type of input features.  E.g., [[Vector]]
  * @tparam E  Concrete Estimator type
  * @tparam M  Concrete Model type
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
  */
-@AlphaComponent
-private[spark] abstract class Classifier[
+@DeveloperApi
+abstract class Classifier[
     FeaturesType,
     E <: Classifier[FeaturesType, E, M],
     M <: ClassificationModel[FeaturesType, M]]
-  extends Predictor[FeaturesType, E, M]
-  with ClassifierParams {
+  extends Predictor[FeaturesType, E, M] with ClassifierParams {
 
   /** @group setParam */
   def setRawPredictionCol(value: String): E = set(rawPredictionCol, value).asInstanceOf[E]
@@ -71,17 +66,15 @@ private[spark] abstract class Classifier[
 }
 
 /**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
+ *
  * Model produced by a [[Classifier]].
  * Classes are indexed {0, 1, ..., numClasses - 1}.
  *
  * @tparam FeaturesType  Type of input features.  E.g., [[Vector]]
  * @tparam M  Concrete Model type
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
  */
-@AlphaComponent
-private[spark]
+@DeveloperApi
 abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]]
   extends PredictionModel[FeaturesType, M] with ClassifierParams {
 
@@ -101,13 +94,27 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
    * @return transformed dataset
    */
   override def transform(dataset: DataFrame): DataFrame = {
-    // This default implementation should be overridden as needed.
-
-    // Check schema
     transformSchema(dataset.schema, logging = true)
 
-    val (numColsOutput, outputData) =
-      ClassificationModel.transformColumnsImpl[FeaturesType](dataset, this)
+    // Output selected columns only.
+    // This is a bit complicated since it tries to avoid repeated computation.
+    var outputData = dataset
+    var numColsOutput = 0
+    if (getRawPredictionCol != "") {
+      outputData = outputData.withColumn(getRawPredictionCol,
+        callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol)))
+      numColsOutput += 1
+    }
+    if (getPredictionCol != "") {
+      val predUDF = if (getRawPredictionCol != "") {
+        callUDF(raw2prediction _, DoubleType, col(getRawPredictionCol))
+      } else {
+        callUDF(predict _, DoubleType, col(getFeaturesCol))
+      }
+      outputData = outputData.withColumn(getPredictionCol, predUDF)
+      numColsOutput += 1
+    }
+
     if (numColsOutput == 0) {
       logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" +
         " since no output columns were set.")
@@ -116,22 +123,17 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
   }
 
   /**
-   * :: DeveloperApi ::
-   *
    * Predict label for the given features.
    * This internal method is used to implement [[transform()]] and output [[predictionCol]].
    *
    * This default implementation for classification predicts the index of the maximum value
    * from [[predictRaw()]].
    */
-  @DeveloperApi
   override protected def predict(features: FeaturesType): Double = {
-    predictRaw(features).toArray.zipWithIndex.maxBy(_._1)._2
+    raw2prediction(predictRaw(features))
   }
 
   /**
-   * :: DeveloperApi ::
-   *
    * Raw prediction for each possible label.
    * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
    * a measure of confidence in each possible label (where larger = more confident).
@@ -141,48 +143,12 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
    *          This raw prediction may be any real number, where a larger value indicates greater
    *          confidence for that label.
    */
-  @DeveloperApi
   protected def predictRaw(features: FeaturesType): Vector
-}
-
-private[ml] object ClassificationModel {
 
   /**
-   * Added prediction column(s).  This is separated from [[ClassificationModel.transform()]]
-   * since it is used by [[org.apache.spark.ml.classification.ProbabilisticClassificationModel]].
-   * @param dataset  Input dataset
-   * @return (number of columns added, transformed dataset)
+   * Given a vector of raw predictions, select the predicted label.
+   * This may be overridden to support thresholds which favor particular labels.
+   * @return  predicted label
    */
-  def transformColumnsImpl[FeaturesType](
-      dataset: DataFrame,
-      model: ClassificationModel[FeaturesType, _]): (Int, DataFrame) = {
-
-    // Output selected columns only.
-    // This is a bit complicated since it tries to avoid repeated computation.
-    var tmpData = dataset
-    var numColsOutput = 0
-    if (model.getRawPredictionCol != "") {
-      // output raw prediction
-      val features2raw: FeaturesType => Vector = model.predictRaw
-      tmpData = tmpData.withColumn(model.getRawPredictionCol,
-        callUDF(features2raw, new VectorUDT, col(model.getFeaturesCol)))
-      numColsOutput += 1
-      if (model.getPredictionCol != "") {
-        val raw2pred: Vector => Double = (rawPred) => {
-          rawPred.toArray.zipWithIndex.maxBy(_._1)._2
-        }
-        tmpData = tmpData.withColumn(model.getPredictionCol,
-          callUDF(raw2pred, DoubleType, col(model.getRawPredictionCol)))
-        numColsOutput += 1
-      }
-    } else if (model.getPredictionCol != "") {
-      // output prediction
-      val features2pred: FeaturesType => Double = model.predict
-      tmpData = tmpData.withColumn(model.getPredictionCol,
-        callUDF(features2pred, DoubleType, col(model.getFeaturesCol)))
-      numColsOutput += 1
-    }
-    (numColsOutput, tmpData)
-  }
-
+  protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.toDense.argmax
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 419e5ba..dcebea1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -18,10 +18,9 @@
 package org.apache.spark.ml.classification
 
 import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
-import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
+import org.apache.spark.ml.tree.{TreeClassifierParams, DecisionTreeParams, DecisionTreeModel, Node}
 import org.apache.spark.ml.util.MetadataUtils
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint

http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index 534ea95..ae51b05 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -21,11 +21,10 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
-import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.ml.regression.DecisionTreeRegressionModel
-import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.tree.{GBTParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel}
 import org.apache.spark.ml.util.MetadataUtils
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint

http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index b73be03..550369d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -21,9 +21,8 @@ import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
-import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
+import org.apache.spark.mllib.linalg._
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.functions._
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -99,76 +98,17 @@ class LogisticRegressionModel private[ml] (
   /** @group setParam */
   def setThreshold(value: Double): this.type = set(threshold, value)
 
+  /** Margin (rawPrediction) for class label 1.  For binary classification only. */
   private val margin: Vector => Double = (features) => {
     BLAS.dot(features, weights) + intercept
   }
 
+  /** Score (probability) for class label 1.  For binary classification only. */
   private val score: Vector => Double = (features) => {
     val m = margin(features)
     1.0 / (1.0 + math.exp(-m))
   }
 
-  override def transform(dataset: DataFrame): DataFrame = {
-    // This is overridden (a) to be more efficient (avoiding re-computing values when creating
-    // multiple output columns) and (b) to handle threshold, which the abstractions do not use.
-    // TODO: We should abstract away the steps defined by UDFs below so that the abstractions
-    // can call whichever UDFs are needed to create the output columns.
-
-    // Check schema
-    transformSchema(dataset.schema, logging = true)
-
-    // Output selected columns only.
-    // This is a bit complicated since it tries to avoid repeated computation.
-    //   rawPrediction (-margin, margin)
-    //   probability (1.0-score, score)
-    //   prediction (max margin)
-    var tmpData = dataset
-    var numColsOutput = 0
-    if ($(rawPredictionCol) != "") {
-      val features2raw: Vector => Vector = (features) => predictRaw(features)
-      tmpData = tmpData.withColumn($(rawPredictionCol),
-        callUDF(features2raw, new VectorUDT, col($(featuresCol))))
-      numColsOutput += 1
-    }
-    if ($(probabilityCol) != "") {
-      if ($(rawPredictionCol) != "") {
-        val raw2prob = udf { (rawPreds: Vector) =>
-          val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
-          Vectors.dense(1.0 - prob1, prob1): Vector
-        }
-        tmpData = tmpData.withColumn($(probabilityCol), raw2prob(col($(rawPredictionCol))))
-      } else {
-        val features2prob = udf { (features: Vector) => predictProbabilities(features) : Vector }
-        tmpData = tmpData.withColumn($(probabilityCol), features2prob(col($(featuresCol))))
-      }
-      numColsOutput += 1
-    }
-    if ($(predictionCol) != "") {
-      val t = $(threshold)
-      if ($(probabilityCol) != "") {
-        val predict = udf { probs: Vector =>
-          if (probs(1) > t) 1.0 else 0.0
-        }
-        tmpData = tmpData.withColumn($(predictionCol), predict(col($(probabilityCol))))
-      } else if ($(rawPredictionCol) != "") {
-        val predict = udf { rawPreds: Vector =>
-          val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
-          if (prob1 > t) 1.0 else 0.0
-        }
-        tmpData = tmpData.withColumn($(predictionCol), predict(col($(rawPredictionCol))))
-      } else {
-        val predict = udf { features: Vector => this.predict(features) }
-        tmpData = tmpData.withColumn($(predictionCol), predict(col($(featuresCol))))
-      }
-      numColsOutput += 1
-    }
-    if (numColsOutput == 0) {
-      this.logWarning(s"$uid: LogisticRegressionModel.transform() was called as NOOP" +
-        " since no output columns were set.")
-    }
-    tmpData
-  }
-
   override val numClasses: Int = 2
 
   /**
@@ -179,17 +119,43 @@ class LogisticRegressionModel private[ml] (
     if (score(features) > getThreshold) 1 else 0
   }
 
-  override protected def predictProbabilities(features: Vector): Vector = {
-    val s = score(features)
-    Vectors.dense(1.0 - s, s)
+  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+    rawPrediction match {
+      case dv: DenseVector =>
+        var i = 0
+        while (i < dv.size) {
+          dv.values(i) = 1.0 / (1.0 + math.exp(-dv.values(i)))
+          i += 1
+        }
+        dv
+      case sv: SparseVector =>
+        throw new RuntimeException("Unexpected error in LogisticRegressionModel:" +
+          " raw2probabilitiesInPlace encountered SparseVector")
+    }
   }
 
   override protected def predictRaw(features: Vector): Vector = {
     val m = margin(features)
-    Vectors.dense(0.0, m)
+    Vectors.dense(-m, m)
   }
 
   override def copy(extra: ParamMap): LogisticRegressionModel = {
     copyValues(new LogisticRegressionModel(parent, weights, intercept), extra)
   }
+
+  override protected def raw2prediction(rawPrediction: Vector): Double = {
+    val t = getThreshold
+    val rawThreshold = if (t == 0.0) {
+      Double.NegativeInfinity
+    } else if (t == 1.0) {
+      Double.PositiveInfinity
+    } else {
+      Math.log(t / (1.0 - t))
+    }
+    if (rawPrediction(1) > rawThreshold) 1 else 0
+  }
+
+  override protected def probability2prediction(probability: Vector): Double = {
+    if (probability(1) > getThreshold) 1 else 0
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 8519841..330ae29 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -17,16 +17,16 @@
 
 package org.apache.spark.ml.classification
 
-import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DoubleType, DataType, StructType}
 
 /**
- * Params for probabilistic classification.
+ * (private[classification])  Params for probabilistic classification.
  */
 private[classification] trait ProbabilisticClassifierParams
   extends ClassifierParams with HasProbabilityCol {
@@ -42,17 +42,15 @@ private[classification] trait ProbabilisticClassifierParams
 
 
 /**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
  *
  * Single-label binary or multiclass classifier which can output class conditional probabilities.
  *
  * @tparam FeaturesType  Type of input features.  E.g., [[Vector]]
  * @tparam E  Concrete Estimator type
  * @tparam M  Concrete Model type
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
  */
-@AlphaComponent
+@DeveloperApi
 private[spark] abstract class ProbabilisticClassifier[
     FeaturesType,
     E <: ProbabilisticClassifier[FeaturesType, E, M],
@@ -65,17 +63,15 @@ private[spark] abstract class ProbabilisticClassifier[
 
 
 /**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
  *
  * Model produced by a [[ProbabilisticClassifier]].
  * Classes are indexed {0, 1, ..., numClasses - 1}.
  *
  * @tparam FeaturesType  Type of input features.  E.g., [[Vector]]
  * @tparam M  Concrete Model type
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
  */
-@AlphaComponent
+@DeveloperApi
 private[spark] abstract class ProbabilisticClassificationModel[
     FeaturesType,
     M <: ProbabilisticClassificationModel[FeaturesType, M]]
@@ -95,39 +91,79 @@ private[spark] abstract class ProbabilisticClassificationModel[
    * @return transformed dataset
    */
   override def transform(dataset: DataFrame): DataFrame = {
-    // This default implementation should be overridden as needed.
-
-    // Check schema
     transformSchema(dataset.schema, logging = true)
 
-    val (numColsOutput, outputData) =
-      ClassificationModel.transformColumnsImpl[FeaturesType](dataset, this)
-
     // Output selected columns only.
-    if ($(probabilityCol) != "") {
-      // output probabilities
-      outputData.withColumn($(probabilityCol),
-        callUDF(predictProbabilities _, new VectorUDT, col($(featuresCol))))
-    } else {
-      if (numColsOutput == 0) {
-        this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
-          " since no output columns were set.")
+    // This is a bit complicated since it tries to avoid repeated computation.
+    var outputData = dataset
+    var numColsOutput = 0
+    if ($(rawPredictionCol).nonEmpty) {
+      outputData = outputData.withColumn(getRawPredictionCol,
+        callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol)))
+      numColsOutput += 1
+    }
+    if ($(probabilityCol).nonEmpty) {
+      val probUDF = if ($(rawPredictionCol).nonEmpty) {
+        callUDF(raw2probability _, new VectorUDT, col($(rawPredictionCol)))
+      } else {
+        callUDF(predictProbability _, new VectorUDT, col($(featuresCol)))
+      }
+      outputData = outputData.withColumn($(probabilityCol), probUDF)
+      numColsOutput += 1
+    }
+    if ($(predictionCol).nonEmpty) {
+      val predUDF = if ($(rawPredictionCol).nonEmpty) {
+        callUDF(raw2prediction _, DoubleType, col($(rawPredictionCol)))
+      } else if ($(probabilityCol).nonEmpty) {
+        callUDF(probability2prediction _, DoubleType, col($(probabilityCol)))
+      } else {
+        callUDF(predict _, DoubleType, col($(featuresCol)))
       }
-      outputData
+      outputData = outputData.withColumn($(predictionCol), predUDF)
+      numColsOutput += 1
+    }
+
+    if (numColsOutput == 0) {
+      this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
+        " since no output columns were set.")
     }
+    outputData
   }
 
   /**
-   * :: DeveloperApi ::
+   * Estimate the probability of each class given the raw prediction,
+   * doing the computation in-place.
+   * These predictions are also called class conditional probabilities.
+   *
+   * This internal method is used to implement [[transform()]] and output [[probabilityCol]].
    *
+   * @return Estimated class conditional probabilities (modified input vector)
+   */
+  protected def raw2probabilityInPlace(rawPrediction: Vector): Vector
+
+  /** Non-in-place version of [[raw2probabilityInPlace()]] */
+  protected def raw2probability(rawPrediction: Vector): Vector = {
+    val probs = rawPrediction.copy
+    raw2probabilityInPlace(probs)
+  }
+
+  /**
    * Predict the probability of each class given the features.
    * These predictions are also called class conditional probabilities.
    *
-   * WARNING: Not all models output well-calibrated probability estimates!  These probabilities
-   *          should be treated as confidences, not precise probabilities.
-   *
    * This internal method is used to implement [[transform()]] and output [[probabilityCol]].
+   *
+   * @return Estimated class conditional probabilities
+   */
+  protected def predictProbability(features: FeaturesType): Vector = {
+    val rawPreds = predictRaw(features)
+    raw2probabilityInPlace(rawPreds)
+  }
+
+  /**
+   * Given a vector of class conditional probabilities, select the predicted label.
+   * This may be overridden to support thresholds which favor particular labels.
+   * @return  predicted label
    */
-  @DeveloperApi
-  protected def predictProbabilities(features: FeaturesType): Vector
+  protected def probability2prediction(probability: Vector): Double = probability.toDense.argmax
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 17f59bb..9954893 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -20,10 +20,9 @@ package org.apache.spark.ml.classification
 import scala.collection.mutable
 
 import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
-import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.tree.{RandomForestParams, TreeClassifierParams, DecisionTreeModel, TreeEnsembleModel}
 import org.apache.spark.ml.util.MetadataUtils
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint

http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
deleted file mode 100644
index e8b3628..0000000
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
+++ /dev/null
@@ -1,217 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.impl.estimator
-
-import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
-import org.apache.spark.ml.{Estimator, Model}
-import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.SchemaUtils
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
-
-/**
- * :: DeveloperApi ::
- *
- * Trait for parameters for prediction (regression and classification).
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
- */
-@DeveloperApi
-private[spark] trait PredictorParams extends Params
-  with HasLabelCol with HasFeaturesCol with HasPredictionCol {
-
-  /**
-   * Validates and transforms the input schema with the provided param map.
-   * @param schema input schema
-   * @param fitting whether this is in fitting
-   * @param featuresDataType  SQL DataType for FeaturesType.
-   *                          E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
-   * @return output schema
-   */
-  protected def validateAndTransformSchema(
-      schema: StructType,
-      fitting: Boolean,
-      featuresDataType: DataType): StructType = {
-    // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
-    SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
-    if (fitting) {
-      // TODO: Allow other numeric types
-      SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
-    }
-    SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
-  }
-}
-
-/**
- * :: AlphaComponent ::
- *
- * Abstraction for prediction problems (regression and classification).
- *
- * @tparam FeaturesType  Type of features.
- *                       E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
- * @tparam Learner  Specialization of this class.  If you subclass this type, use this type
- *                  parameter to specify the concrete type.
- * @tparam M  Specialization of [[PredictionModel]].  If you subclass this type, use this type
- *            parameter to specify the concrete type for the corresponding model.
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
- */
-@AlphaComponent
-private[spark] abstract class Predictor[
-    FeaturesType,
-    Learner <: Predictor[FeaturesType, Learner, M],
-    M <: PredictionModel[FeaturesType, M]]
-  extends Estimator[M] with PredictorParams {
-
-  /** @group setParam */
-  def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner]
-
-  /** @group setParam */
-  def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner]
-
-  /** @group setParam */
-  def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
-
-  override def fit(dataset: DataFrame): M = {
-    // This handles a few items such as schema validation.
-    // Developers only need to implement train().
-    transformSchema(dataset.schema, logging = true)
-    copyValues(train(dataset))
-  }
-
-  override def copy(extra: ParamMap): Learner = {
-    super.copy(extra).asInstanceOf[Learner]
-  }
-
-  /**
-   * :: DeveloperApi ::
-   *
-   * Train a model using the given dataset and parameters.
-   * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation
-   * and copying parameters into the model.
-   *
-   * @param dataset  Training dataset
-   * @return  Fitted model
-   */
-  @DeveloperApi
-  protected def train(dataset: DataFrame): M
-
-  /**
-   * :: DeveloperApi ::
-   *
-   * Returns the SQL DataType corresponding to the FeaturesType type parameter.
-   *
-   * This is used by [[validateAndTransformSchema()]].
-   * This workaround is needed since SQL has different APIs for Scala and Java.
-   *
-   * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
-   */
-  @DeveloperApi
-  protected def featuresDataType: DataType = new VectorUDT
-
-  override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema, fitting = true, featuresDataType)
-  }
-
-  /**
-   * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
-   * and put it in an RDD with strong types.
-   */
-  protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = {
-    dataset.select($(labelCol), $(featuresCol))
-      .map { case Row(label: Double, features: Vector) =>
-      LabeledPoint(label, features)
-    }
-  }
-}
-
-/**
- * :: AlphaComponent ::
- *
- * Abstraction for a model for prediction tasks (regression and classification).
- *
- * @tparam FeaturesType  Type of features.
- *                       E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
- * @tparam M  Specialization of [[PredictionModel]].  If you subclass this type, use this type
- *            parameter to specify the concrete type for the corresponding model.
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
- */
-@AlphaComponent
-private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
-  extends Model[M] with PredictorParams {
-
-  /** @group setParam */
-  def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M]
-
-  /** @group setParam */
-  def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]
-
-  /**
-   * :: DeveloperApi ::
-   *
-   * Returns the SQL DataType corresponding to the FeaturesType type parameter.
-   *
-   * This is used by [[validateAndTransformSchema()]].
-   * This workaround is needed since SQL has different APIs for Scala and Java.
-   *
-   * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
-   */
-  @DeveloperApi
-  protected def featuresDataType: DataType = new VectorUDT
-
-  override def transformSchema(schema: StructType): StructType = {
-    validateAndTransformSchema(schema, fitting = false, featuresDataType)
-  }
-
-  /**
-   * Transforms dataset by reading from [[featuresCol]], calling [[predict()]], and storing
-   * the predictions as a new column [[predictionCol]].
-   *
-   * @param dataset input dataset
-   * @return transformed dataset with [[predictionCol]] of type [[Double]]
-   */
-  override def transform(dataset: DataFrame): DataFrame = {
-    // This default implementation should be overridden as needed.
-
-    // Check schema
-    transformSchema(dataset.schema, logging = true)
-
-    if ($(predictionCol) != "") {
-      dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol))))
-    } else {
-      this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
-        " since no output columns were set.")
-      dataset
-    }
-  }
-
-  /**
-   * :: DeveloperApi ::
-   *
-   * Predict label for the given features.
-   * This internal method is used to implement [[transform()]] and output [[predictionCol]].
-   */
-  @DeveloperApi
-  protected def predict(features: FeaturesType): Double
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
deleted file mode 100644
index 0e22562..0000000
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
+++ /dev/null
@@ -1,431 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.impl.tree
-
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.ml.impl.estimator.PredictorParams
-import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed}
-import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
-import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
-
-/**
- * :: DeveloperApi ::
- * Parameters for Decision Tree-based algorithms.
- *
- * Note: Marked as private and DeveloperApi since this may be made public in the future.
- */
-@DeveloperApi
-private[ml] trait DecisionTreeParams extends PredictorParams {
-
-  /**
-   * Maximum depth of the tree (>= 0).
-   * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
-   * (default = 5)
-   * @group param
-   */
-  final val maxDepth: IntParam =
-    new IntParam(this, "maxDepth", "Maximum depth of the tree. (>= 0)" +
-      " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.",
-      ParamValidators.gtEq(0))
-
-  /**
-   * Maximum number of bins used for discretizing continuous features and for choosing how to split
-   * on features at each node.  More bins give higher granularity.
-   * Must be >= 2 and >= number of categories in any categorical feature.
-   * (default = 32)
-   * @group param
-   */
-  final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" +
-    " discretizing continuous features.  Must be >=2 and >= number of categories for any" +
-    " categorical feature.", ParamValidators.gtEq(2))
-
-  /**
-   * Minimum number of instances each child must have after split.
-   * If a split causes the left or right child to have fewer than minInstancesPerNode,
-   * the split will be discarded as invalid.
-   * Should be >= 1.
-   * (default = 1)
-   * @group param
-   */
-  final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" +
-    " number of instances each child must have after split.  If a split causes the left or right" +
-    " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
-    " Should be >= 1.", ParamValidators.gtEq(1))
-
-  /**
-   * Minimum information gain for a split to be considered at a tree node.
-   * (default = 0.0)
-   * @group param
-   */
-  final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain",
-    "Minimum information gain for a split to be considered at a tree node.")
-
-  /**
-   * Maximum memory in MB allocated to histogram aggregation.
-   * (default = 256 MB)
-   * @group expertParam
-   */
-  final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB",
-    "Maximum memory in MB allocated to histogram aggregation.",
-    ParamValidators.gtEq(0))
-
-  /**
-   * If false, the algorithm will pass trees to executors to match instances with nodes.
-   * If true, the algorithm will cache node IDs for each instance.
-   * Caching can speed up training of deeper trees.
-   * (default = false)
-   * @group expertParam
-   */
-  final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" +
-    " algorithm will pass trees to executors to match instances with nodes. If true, the" +
-    " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
-    " trees.")
-
-  /**
-   * Specifies how often to checkpoint the cached node IDs.
-   * E.g. 10 means that the cache will get checkpointed every 10 iterations.
-   * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
-   * [[org.apache.spark.SparkContext]].
-   * Must be >= 1.
-   * (default = 10)
-   * @group expertParam
-   */
-  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" +
-    " how often to checkpoint the cached node IDs.  E.g. 10 means that the cache will get" +
-    " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" +
-    " checkpoint directory is set in the SparkContext. Must be >= 1.",
-    ParamValidators.gtEq(1))
-
-  setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
-    maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
-
-  /** @group setParam */
-  def setMaxDepth(value: Int): this.type = set(maxDepth, value)
-
-  /** @group getParam */
-  final def getMaxDepth: Int = $(maxDepth)
-
-  /** @group setParam */
-  def setMaxBins(value: Int): this.type = set(maxBins, value)
-
-  /** @group getParam */
-  final def getMaxBins: Int = $(maxBins)
-
-  /** @group setParam */
-  def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
-
-  /** @group getParam */
-  final def getMinInstancesPerNode: Int = $(minInstancesPerNode)
-
-  /** @group setParam */
-  def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
-
-  /** @group getParam */
-  final def getMinInfoGain: Double = $(minInfoGain)
-
-  /** @group expertSetParam */
-  def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
-
-  /** @group expertGetParam */
-  final def getMaxMemoryInMB: Int = $(maxMemoryInMB)
-
-  /** @group expertSetParam */
-  def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
-
-  /** @group expertGetParam */
-  final def getCacheNodeIds: Boolean = $(cacheNodeIds)
-
-  /** @group expertSetParam */
-  def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
-
-  /** @group expertGetParam */
-  final def getCheckpointInterval: Int = $(checkpointInterval)
-
-  /** (private[ml]) Create a Strategy instance to use with the old API. */
-  private[ml] def getOldStrategy(
-      categoricalFeatures: Map[Int, Int],
-      numClasses: Int,
-      oldAlgo: OldAlgo.Algo,
-      oldImpurity: OldImpurity,
-      subsamplingRate: Double): OldStrategy = {
-    val strategy = OldStrategy.defaultStategy(oldAlgo)
-    strategy.impurity = oldImpurity
-    strategy.checkpointInterval = getCheckpointInterval
-    strategy.maxBins = getMaxBins
-    strategy.maxDepth = getMaxDepth
-    strategy.maxMemoryInMB = getMaxMemoryInMB
-    strategy.minInfoGain = getMinInfoGain
-    strategy.minInstancesPerNode = getMinInstancesPerNode
-    strategy.useNodeIdCache = getCacheNodeIds
-    strategy.numClasses = numClasses
-    strategy.categoricalFeaturesInfo = categoricalFeatures
-    strategy.subsamplingRate = subsamplingRate
-    strategy
-  }
-}
-
-/**
- * Parameters for Decision Tree-based classification algorithms.
- */
-private[ml] trait TreeClassifierParams extends Params {
-
-  /**
-   * Criterion used for information gain calculation (case-insensitive).
-   * Supported: "entropy" and "gini".
-   * (default = gini)
-   * @group param
-   */
-  final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
-    " information gain calculation (case-insensitive). Supported options:" +
-    s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}",
-    (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase))
-
-  setDefault(impurity -> "gini")
-
-  /** @group setParam */
-  def setImpurity(value: String): this.type = set(impurity, value)
-
-  /** @group getParam */
-  final def getImpurity: String = $(impurity).toLowerCase
-
-  /** Convert new impurity to old impurity. */
-  private[ml] def getOldImpurity: OldImpurity = {
-    getImpurity match {
-      case "entropy" => OldEntropy
-      case "gini" => OldGini
-      case _ =>
-        // Should never happen because of check in setter method.
-        throw new RuntimeException(
-          s"TreeClassifierParams was given unrecognized impurity: $impurity.")
-    }
-  }
-}
-
-private[ml] object TreeClassifierParams {
-  // These options should be lowercase.
-  final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
-}
-
-/**
- * Parameters for Decision Tree-based regression algorithms.
- */
-private[ml] trait TreeRegressorParams extends Params {
-
-  /**
-   * Criterion used for information gain calculation (case-insensitive).
-   * Supported: "variance".
-   * (default = variance)
-   * @group param
-   */
-  final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
-    " information gain calculation (case-insensitive). Supported options:" +
-    s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}",
-    (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase))
-
-  setDefault(impurity -> "variance")
-
-  /** @group setParam */
-  def setImpurity(value: String): this.type = set(impurity, value)
-
-  /** @group getParam */
-  final def getImpurity: String = $(impurity).toLowerCase
-
-  /** Convert new impurity to old impurity. */
-  private[ml] def getOldImpurity: OldImpurity = {
-    getImpurity match {
-      case "variance" => OldVariance
-      case _ =>
-        // Should never happen because of check in setter method.
-        throw new RuntimeException(
-          s"TreeRegressorParams was given unrecognized impurity: $impurity")
-    }
-  }
-}
-
-private[ml] object TreeRegressorParams {
-  // These options should be lowercase.
-  final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
-}
-
-/**
- * :: DeveloperApi ::
- * Parameters for Decision Tree-based ensemble algorithms.
- *
- * Note: Marked as private and DeveloperApi since this may be made public in the future.
- */
-@DeveloperApi
-private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
-
-  /**
-   * Fraction of the training data used for learning each decision tree, in range (0, 1].
-   * (default = 1.0)
-   * @group param
-   */
-  final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate",
-    "Fraction of the training data used for learning each decision tree, in range (0, 1].",
-    ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
-
-  setDefault(subsamplingRate -> 1.0)
-
-  /** @group setParam */
-  def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
-
-  /** @group getParam */
-  final def getSubsamplingRate: Double = $(subsamplingRate)
-
-  /** @group setParam */
-  def setSeed(value: Long): this.type = set(seed, value)
-
-  /**
-   * Create a Strategy instance to use with the old API.
-   * NOTE: The caller should set impurity and seed.
-   */
-  private[ml] def getOldStrategy(
-      categoricalFeatures: Map[Int, Int],
-      numClasses: Int,
-      oldAlgo: OldAlgo.Algo,
-      oldImpurity: OldImpurity): OldStrategy = {
-    super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
-  }
-}
-
-/**
- * :: DeveloperApi ::
- * Parameters for Random Forest algorithms.
- *
- * Note: Marked as private and DeveloperApi since this may be made public in the future.
- */
-@DeveloperApi
-private[ml] trait RandomForestParams extends TreeEnsembleParams {
-
-  /**
-   * Number of trees to train (>= 1).
-   * If 1, then no bootstrapping is used.  If > 1, then bootstrapping is done.
-   * TODO: Change to always do bootstrapping (simpler).  SPARK-7130
-   * (default = 20)
-   * @group param
-   */
-  final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
-    ParamValidators.gtEq(1))
-
-  /**
-   * The number of features to consider for splits at each tree node.
-   * Supported options:
-   *  - "auto": Choose automatically for task:
-   *            If numTrees == 1, set to "all."
-   *            If numTrees > 1 (forest), set to "sqrt" for classification and
-   *              to "onethird" for regression.
-   *  - "all": use all features
-   *  - "onethird": use 1/3 of the features
-   *  - "sqrt": use sqrt(number of features)
-   *  - "log2": use log2(number of features)
-   * (default = "auto")
-   *
-   * These various settings are based on the following references:
-   *  - log2: tested in Breiman (2001)
-   *  - sqrt: recommended by Breiman manual for random forests
-   *  - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
-   *    package.
-   * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf  Breiman (2001)]]
-   * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf  Breiman manual for
-   *     random forests]]
-   *
-   * @group param
-   */
-  final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
-    "The number of features to consider for splits at each tree node." +
-      s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
-    (value: String) =>
-      RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase))
-
-  setDefault(numTrees -> 20, featureSubsetStrategy -> "auto")
-
-  /** @group setParam */
-  def setNumTrees(value: Int): this.type = set(numTrees, value)
-
-  /** @group getParam */
-  final def getNumTrees: Int = $(numTrees)
-
-  /** @group setParam */
-  def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
-
-  /** @group getParam */
-  final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
-}
-
-private[ml] object RandomForestParams {
-  // These options should be lowercase.
-  final val supportedFeatureSubsetStrategies: Array[String] =
-    Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
-}
-
-/**
- * :: DeveloperApi ::
- * Parameters for Gradient-Boosted Tree algorithms.
- *
- * Note: Marked as private and DeveloperApi since this may be made public in the future.
- */
-@DeveloperApi
-private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
-
-  /**
-   * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
-   * estimator.
-   * (default = 0.1)
-   * @group param
-   */
-  final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." +
-    " learning rate) in interval (0, 1] for shrinking the contribution of each estimator",
-    ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
-
-  /* TODO: Add this doc when we add this param.  SPARK-7132
-   * Threshold for stopping early when runWithValidation is used.
-   * If the error rate on the validation input changes by less than the validationTol,
-   * then learning will stop early (before [[numIterations]]).
-   * This parameter is ignored when run is used.
-   * (default = 1e-5)
-   * @group param
-   */
-  // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "")
-  // validationTol -> 1e-5
-
-  setDefault(maxIter -> 20, stepSize -> 0.1)
-
-  /** @group setParam */
-  def setMaxIter(value: Int): this.type = set(maxIter, value)
-
-  /** @group setParam */
-  def setStepSize(value: Double): this.type = set(stepSize, value)
-
-  /** @group getParam */
-  final def getStepSize: Double = $(stepSize)
-
-  /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
-  private[ml] def getOldBoostingStrategy(
-      categoricalFeatures: Map[Int, Int],
-      oldAlgo: OldAlgo.Algo): OldBoostingStrategy = {
-    val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance)
-    // NOTE: The old API does not support "seed" so we ignore it.
-    new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize)
-  }
-
-  /** Get old Gradient Boosting Loss type */
-  private[ml] def getOldLossType: OldLoss
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index d379172..0e1ff97 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -40,8 +40,10 @@ private[shared] object SharedParamsCodeGen {
       ParamDesc[String]("predictionCol", "prediction column name", Some("\"prediction\"")),
       ParamDesc[String]("rawPredictionCol", "raw prediction (a.k.a. confidence) column name",
         Some("\"rawPrediction\"")),
-      ParamDesc[String]("probabilityCol",
-        "column name for predicted class conditional probabilities", Some("\"probability\"")),
+      ParamDesc[String]("probabilityCol", "Column name for predicted class conditional" +
+        " probabilities. Note: Not all models output well-calibrated probability estimates!" +
+        " These probabilities should be treated as confidences, not precise probabilities.",
+        Some("\"probability\"")),
       ParamDesc[Double]("threshold",
         "threshold in binary classification prediction, in range [0, 1]",
         isValid = "ParamValidators.inRange(0, 1)"),

http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index fb1874c..87f8680 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -128,10 +128,10 @@ private[ml] trait HasRawPredictionCol extends Params {
 private[ml] trait HasProbabilityCol extends Params {
 
   /**
-   * Param for column name for predicted class conditional probabilities.
+   * Param for column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.
    * @group param
    */
-  final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "column name for predicted class conditional probabilities")
+  final val probabilityCol: Param[String] = new Param[String](this, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.")
 
   setDefault(probabilityCol, "probability")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index b07c26f..f8f0b16 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -18,10 +18,9 @@
 package org.apache.spark.ml.regression
 
 import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
-import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
+import org.apache.spark.ml.tree.{TreeRegressorParams, DecisionTreeParams, DecisionTreeModel, Node}
 import org.apache.spark.ml.util.MetadataUtils
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint

http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index bc79695..461905c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -21,10 +21,9 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
-import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.tree.{GBTParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel}
 import org.apache.spark.ml.util.MetadataUtils
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint

http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 66c475f..e63c9a3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -25,6 +25,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS,
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
@@ -39,7 +40,7 @@ import org.apache.spark.util.StatCounter
 /**
  * Params for linear regression.
  */
-private[regression] trait LinearRegressionParams extends RegressorParams
+private[regression] trait LinearRegressionParams extends PredictorParams
   with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
 
 /**
@@ -240,7 +241,7 @@ class LinearRegressionModel private[ml] (
  *     + \bar{y} / \hat{y}||^2
  *   = 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2
  * }}}
- * where w_i^\prime is the effective weights defined by w_i/\hat{x_i}, offset is
+ * where w_i^\prime^ is the effective weights defined by w_i/\hat{x_i}, offset is
  * {{{
  * - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y}.
  * }}}, and diff is

http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 0468a1b..dbc6289 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -18,10 +18,9 @@
 package org.apache.spark.ml.regression
 
 import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
-import org.apache.spark.ml.impl.tree.{RandomForestParams, TreeRegressorParams}
+import org.apache.spark.ml.{PredictionModel, Predictor}
 import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.tree.{DecisionTreeModel, TreeEnsembleModel}
+import org.apache.spark.ml.tree.{RandomForestParams, TreeRegressorParams, DecisionTreeModel, TreeEnsembleModel}
 import org.apache.spark.ml.util.MetadataUtils
 import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.mllib.regression.LabeledPoint

http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
index c6b3327..c72ef29 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
@@ -17,62 +17,40 @@
 
 package org.apache.spark.ml.regression
 
-import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
-import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
 
-/**
- * :: DeveloperApi ::
- * Params for regression.
- * Currently empty, but may add functionality later.
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
- */
-@DeveloperApi
-private[spark] trait RegressorParams extends PredictorParams
 
 /**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
  *
  * Single-label regression
  *
  * @tparam FeaturesType  Type of input features.  E.g., [[org.apache.spark.mllib.linalg.Vector]]
  * @tparam Learner  Concrete Estimator type
  * @tparam M  Concrete Model type
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
  */
-@AlphaComponent
+@DeveloperApi
 private[spark] abstract class Regressor[
     FeaturesType,
     Learner <: Regressor[FeaturesType, Learner, M],
     M <: RegressionModel[FeaturesType, M]]
-  extends Predictor[FeaturesType, Learner, M]
-  with RegressorParams {
+  extends Predictor[FeaturesType, Learner, M] with PredictorParams {
 
   // TODO: defaultEvaluator (follow-up PR)
 }
 
 /**
- * :: AlphaComponent ::
+ * :: DeveloperApi ::
  *
  * Model produced by a [[Regressor]].
  *
  * @tparam FeaturesType  Type of input features.  E.g., [[org.apache.spark.mllib.linalg.Vector]]
  * @tparam M  Concrete Model type.
- *
- * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
  */
-@AlphaComponent
-private[spark] abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]]
-  extends PredictionModel[FeaturesType, M] with RegressorParams {
-
-  /**
-   * :: DeveloperApi ::
-   *
-   * Predict real-valued label for the given features.
-   * This internal method is used to implement [[transform()]] and output [[predictionCol]].
-   */
-  @DeveloperApi
-  protected def predict(features: FeaturesType): Double
+@DeveloperApi
+abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]]
+  extends PredictionModel[FeaturesType, M] with PredictorParams {
 
+  // TODO: defaultEvaluator (follow-up PR)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1ad04dae/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
new file mode 100644
index 0000000..816fced
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -0,0 +1,431 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.PredictorParams
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
+import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Decision Tree-based algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait DecisionTreeParams extends PredictorParams {
+
+  /**
+   * Maximum depth of the tree (>= 0).
+   * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   * (default = 5)
+   * @group param
+   */
+  final val maxDepth: IntParam =
+    new IntParam(this, "maxDepth", "Maximum depth of the tree. (>= 0)" +
+      " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.",
+      ParamValidators.gtEq(0))
+
+  /**
+   * Maximum number of bins used for discretizing continuous features and for choosing how to split
+   * on features at each node.  More bins give higher granularity.
+   * Must be >= 2 and >= number of categories in any categorical feature.
+   * (default = 32)
+   * @group param
+   */
+  final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" +
+    " discretizing continuous features.  Must be >=2 and >= number of categories for any" +
+    " categorical feature.", ParamValidators.gtEq(2))
+
+  /**
+   * Minimum number of instances each child must have after split.
+   * If a split causes the left or right child to have fewer than minInstancesPerNode,
+   * the split will be discarded as invalid.
+   * Should be >= 1.
+   * (default = 1)
+   * @group param
+   */
+  final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" +
+    " number of instances each child must have after split.  If a split causes the left or right" +
+    " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
+    " Should be >= 1.", ParamValidators.gtEq(1))
+
+  /**
+   * Minimum information gain for a split to be considered at a tree node.
+   * (default = 0.0)
+   * @group param
+   */
+  final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain",
+    "Minimum information gain for a split to be considered at a tree node.")
+
+  /**
+   * Maximum memory in MB allocated to histogram aggregation.
+   * (default = 256 MB)
+   * @group expertParam
+   */
+  final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB",
+    "Maximum memory in MB allocated to histogram aggregation.",
+    ParamValidators.gtEq(0))
+
+  /**
+   * If false, the algorithm will pass trees to executors to match instances with nodes.
+   * If true, the algorithm will cache node IDs for each instance.
+   * Caching can speed up training of deeper trees.
+   * (default = false)
+   * @group expertParam
+   */
+  final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" +
+    " algorithm will pass trees to executors to match instances with nodes. If true, the" +
+    " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
+    " trees.")
+
+  /**
+   * Specifies how often to checkpoint the cached node IDs.
+   * E.g. 10 means that the cache will get checkpointed every 10 iterations.
+   * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
+   * [[org.apache.spark.SparkContext]].
+   * Must be >= 1.
+   * (default = 10)
+   * @group expertParam
+   */
+  final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" +
+    " how often to checkpoint the cached node IDs.  E.g. 10 means that the cache will get" +
+    " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" +
+    " checkpoint directory is set in the SparkContext. Must be >= 1.",
+    ParamValidators.gtEq(1))
+
+  setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
+    maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
+
+  /** @group setParam */
+  def setMaxDepth(value: Int): this.type = set(maxDepth, value)
+
+  /** @group getParam */
+  final def getMaxDepth: Int = $(maxDepth)
+
+  /** @group setParam */
+  def setMaxBins(value: Int): this.type = set(maxBins, value)
+
+  /** @group getParam */
+  final def getMaxBins: Int = $(maxBins)
+
+  /** @group setParam */
+  def setMinInstancesPerNode(value: Int): this.type = set(minInstancesPerNode, value)
+
+  /** @group getParam */
+  final def getMinInstancesPerNode: Int = $(minInstancesPerNode)
+
+  /** @group setParam */
+  def setMinInfoGain(value: Double): this.type = set(minInfoGain, value)
+
+  /** @group getParam */
+  final def getMinInfoGain: Double = $(minInfoGain)
+
+  /** @group expertSetParam */
+  def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
+
+  /** @group expertGetParam */
+  final def getMaxMemoryInMB: Int = $(maxMemoryInMB)
+
+  /** @group expertSetParam */
+  def setCacheNodeIds(value: Boolean): this.type = set(cacheNodeIds, value)
+
+  /** @group expertGetParam */
+  final def getCacheNodeIds: Boolean = $(cacheNodeIds)
+
+  /** @group expertSetParam */
+  def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+
+  /** @group expertGetParam */
+  final def getCheckpointInterval: Int = $(checkpointInterval)
+
+  /** (private[ml]) Create a Strategy instance to use with the old API. */
+  private[ml] def getOldStrategy(
+      categoricalFeatures: Map[Int, Int],
+      numClasses: Int,
+      oldAlgo: OldAlgo.Algo,
+      oldImpurity: OldImpurity,
+      subsamplingRate: Double): OldStrategy = {
+    val strategy = OldStrategy.defaultStategy(oldAlgo)
+    strategy.impurity = oldImpurity
+    strategy.checkpointInterval = getCheckpointInterval
+    strategy.maxBins = getMaxBins
+    strategy.maxDepth = getMaxDepth
+    strategy.maxMemoryInMB = getMaxMemoryInMB
+    strategy.minInfoGain = getMinInfoGain
+    strategy.minInstancesPerNode = getMinInstancesPerNode
+    strategy.useNodeIdCache = getCacheNodeIds
+    strategy.numClasses = numClasses
+    strategy.categoricalFeaturesInfo = categoricalFeatures
+    strategy.subsamplingRate = subsamplingRate
+    strategy
+  }
+}
+
+/**
+ * Parameters for Decision Tree-based classification algorithms.
+ */
+private[ml] trait TreeClassifierParams extends Params {
+
+  /**
+   * Criterion used for information gain calculation (case-insensitive).
+   * Supported: "entropy" and "gini".
+   * (default = gini)
+   * @group param
+   */
+  final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+    " information gain calculation (case-insensitive). Supported options:" +
+    s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}",
+    (value: String) => TreeClassifierParams.supportedImpurities.contains(value.toLowerCase))
+
+  setDefault(impurity -> "gini")
+
+  /** @group setParam */
+  def setImpurity(value: String): this.type = set(impurity, value)
+
+  /** @group getParam */
+  final def getImpurity: String = $(impurity).toLowerCase
+
+  /** Convert new impurity to old impurity. */
+  private[ml] def getOldImpurity: OldImpurity = {
+    getImpurity match {
+      case "entropy" => OldEntropy
+      case "gini" => OldGini
+      case _ =>
+        // Should never happen because of check in setter method.
+        throw new RuntimeException(
+          s"TreeClassifierParams was given unrecognized impurity: $impurity.")
+    }
+  }
+}
+
+private[ml] object TreeClassifierParams {
+  // These options should be lowercase.
+  final val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
+}
+
+/**
+ * Parameters for Decision Tree-based regression algorithms.
+ */
+private[ml] trait TreeRegressorParams extends Params {
+
+  /**
+   * Criterion used for information gain calculation (case-insensitive).
+   * Supported: "variance".
+   * (default = variance)
+   * @group param
+   */
+  final val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+    " information gain calculation (case-insensitive). Supported options:" +
+    s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}",
+    (value: String) => TreeRegressorParams.supportedImpurities.contains(value.toLowerCase))
+
+  setDefault(impurity -> "variance")
+
+  /** @group setParam */
+  def setImpurity(value: String): this.type = set(impurity, value)
+
+  /** @group getParam */
+  final def getImpurity: String = $(impurity).toLowerCase
+
+  /** Convert new impurity to old impurity. */
+  private[ml] def getOldImpurity: OldImpurity = {
+    getImpurity match {
+      case "variance" => OldVariance
+      case _ =>
+        // Should never happen because of check in setter method.
+        throw new RuntimeException(
+          s"TreeRegressorParams was given unrecognized impurity: $impurity")
+    }
+  }
+}
+
+private[ml] object TreeRegressorParams {
+  // These options should be lowercase.
+  final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Decision Tree-based ensemble algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
+
+  /**
+   * Fraction of the training data used for learning each decision tree, in range (0, 1].
+   * (default = 1.0)
+   * @group param
+   */
+  final val subsamplingRate: DoubleParam = new DoubleParam(this, "subsamplingRate",
+    "Fraction of the training data used for learning each decision tree, in range (0, 1].",
+    ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
+
+  setDefault(subsamplingRate -> 1.0)
+
+  /** @group setParam */
+  def setSubsamplingRate(value: Double): this.type = set(subsamplingRate, value)
+
+  /** @group getParam */
+  final def getSubsamplingRate: Double = $(subsamplingRate)
+
+  /** @group setParam */
+  def setSeed(value: Long): this.type = set(seed, value)
+
+  /**
+   * Create a Strategy instance to use with the old API.
+   * NOTE: The caller should set impurity and seed.
+   */
+  private[ml] def getOldStrategy(
+      categoricalFeatures: Map[Int, Int],
+      numClasses: Int,
+      oldAlgo: OldAlgo.Algo,
+      oldImpurity: OldImpurity): OldStrategy = {
+    super.getOldStrategy(categoricalFeatures, numClasses, oldAlgo, oldImpurity, getSubsamplingRate)
+  }
+}
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Random Forest algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait RandomForestParams extends TreeEnsembleParams {
+
+  /**
+   * Number of trees to train (>= 1).
+   * If 1, then no bootstrapping is used.  If > 1, then bootstrapping is done.
+   * TODO: Change to always do bootstrapping (simpler).  SPARK-7130
+   * (default = 20)
+   * @group param
+   */
+  final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)",
+    ParamValidators.gtEq(1))
+
+  /**
+   * The number of features to consider for splits at each tree node.
+   * Supported options:
+   *  - "auto": Choose automatically for task:
+   *            If numTrees == 1, set to "all".
+   *            If numTrees > 1 (forest), set to "sqrt" for classification and
+   *              to "onethird" for regression.
+   *  - "all": use all features
+   *  - "onethird": use 1/3 of the features
+   *  - "sqrt": use sqrt(number of features)
+   *  - "log2": use log2(number of features)
+   * (default = "auto")
+   *
+   * These various settings are based on the following references:
+   *  - log2: tested in Breiman (2001)
+   *  - sqrt: recommended by Breiman manual for random forests
+   *  - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
+   *    package.
+   * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf  Breiman (2001)]]
+   * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf  Breiman manual for
+   *     random forests]]
+   *
+   * @group param
+   */
+  final val featureSubsetStrategy: Param[String] = new Param[String](this, "featureSubsetStrategy",
+    "The number of features to consider for splits at each tree node." +
+      s" Supported options: ${RandomForestParams.supportedFeatureSubsetStrategies.mkString(", ")}",
+    (value: String) =>
+      RandomForestParams.supportedFeatureSubsetStrategies.contains(value.toLowerCase))
+
+  setDefault(numTrees -> 20, featureSubsetStrategy -> "auto")
+
+  /** @group setParam */
+  def setNumTrees(value: Int): this.type = set(numTrees, value)
+
+  /** @group getParam */
+  final def getNumTrees: Int = $(numTrees)
+
+  /** @group setParam */
+  def setFeatureSubsetStrategy(value: String): this.type = set(featureSubsetStrategy, value)
+
+  /** @group getParam */
+  final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase
+}
+
+private[ml] object RandomForestParams {
+  // These options should be lowercase.
+  final val supportedFeatureSubsetStrategies: Array[String] =
+    Array("auto", "all", "onethird", "sqrt", "log2").map(_.toLowerCase)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Gradient-Boosted Tree algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter {
+
+  /**
+   * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each
+   * estimator.
+   * (default = 0.1)
+   * @group param
+   */
+  final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size (a.k.a." +
+    " learning rate) in interval (0, 1] for shrinking the contribution of each estimator",
+    ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true))
+
+  /* TODO: Add this doc when we add this param.  SPARK-7132
+   * Threshold for stopping early when runWithValidation is used.
+   * If the error rate on the validation input changes by less than the validationTol,
+   * then learning will stop early (before [[numIterations]]).
+   * This parameter is ignored when run is used.
+   * (default = 1e-5)
+   * @group param
+   */
+  // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "")
+  // validationTol -> 1e-5
+
+  setDefault(maxIter -> 20, stepSize -> 0.1)
+
+  /** @group setParam */
+  def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+  /** @group setParam */
+  def setStepSize(value: Double): this.type = set(stepSize, value)
+
+  /** @group getParam */
+  final def getStepSize: Double = $(stepSize)
+
+  /** (private[ml]) Create a BoostingStrategy instance to use with the old API. */
+  private[ml] def getOldBoostingStrategy(
+      categoricalFeatures: Map[Int, Int],
+      oldAlgo: OldAlgo.Algo): OldBoostingStrategy = {
+    val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 2, oldAlgo, OldVariance)
+    // NOTE: The old API does not support "seed" so we ignore it.
+    new OldBoostingStrategy(strategy, getOldLossType, getMaxIter, getStepSize)
+  }
+
+  /** Get old Gradient Boosting Loss type */
+  private[ml] def getOldLossType: OldLoss
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message