spark-commits mailing list archives

From m...@apache.org
Subject [2/2] spark git commit: [SPARK-4789] [SPARK-4942] [SPARK-5031] [mllib] Standardize ML Prediction APIs
Date Fri, 06 Feb 2015 07:44:15 GMT
[SPARK-4789] [SPARK-4942] [SPARK-5031] [mllib] Standardize ML Prediction APIs

This is part (1a) of the updates from the design doc in [https://docs.google.com/document/d/1BH9el33kBX8JiDdgUJXdLW14CA2qhTCWIG46eXZVoJs]

**UPDATE**: Most of the APIs are being kept private[spark] to allow further discussion.  Here is a list of changes which are public (a short usage sketch follows the list):
* new output columns: rawPrediction, probability
  * The “score” column is now called “rawPrediction”
* Classifiers now provide numClasses
* Params.get and .set are now protected instead of private[ml].
* ParamMap now has a size method.
* new classes: LinearRegression, LinearRegressionModel
* LogisticRegression now has an intercept.

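As a minimal, illustrative usage sketch of these public pieces (not part of the patch; it assumes `training` and `test` DataFrames of LabeledPoints built as in the bundled examples):

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.LinearRegression

// LogisticRegression now sits on the new classifier abstractions, and its model
// exposes numClasses and an intercept.
val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.1)
val lrModel = lr.fit(training)
println(s"numClasses = ${lrModel.numClasses}, intercept = ${lrModel.intercept}")

// transform() appends the new output columns (rawPrediction, probability).
lrModel.transform(test)
  .select("features", "label", "rawPrediction", "probability", "prediction")
  .show()

// New concrete classes: LinearRegression / LinearRegressionModel.
val linRegModel = new LinearRegression().fit(training)
linRegModel.transform(test).select("features", "prediction").show()

// ParamMap now has a size method.
val extra = ParamMap(lr.maxIter -> 20)
println(s"extra.size = ${extra.size}")
```
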
### Sketch of APIs (most of which are private[spark] for now)

Abstract classes for learning algorithms (+ corresponding Model abstractions):
* Classifier (+ ClassificationModel)
* ProbabilisticClassifier (+ ProbabilisticClassificationModel)
* Regressor (+ RegressionModel)
* Predictor (+ PredictionModel)
* *For all of these*:
 * There is no strongly typed training-time API.
 * There is a strongly typed test-time (prediction) API which helps developers implement new algorithms (see the condensed sketch just below).
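
For illustration, here is a condensed, hypothetical variant of the developer example added by this PR (the names `TinyClassifier`/`TinyModel` are invented for the sketch; the full, parameterized version is DeveloperApiExample.scala below).  Training consumes a DataFrame, while prediction is strongly typed on the feature type:

```scala
package org.apache.spark.examples.ml  // inside org.apache.spark, so the private[spark] API resolves

import org.apache.spark.ml.classification.{ClassificationModel, Classifier}
import org.apache.spark.ml.param.{ParamMap, Params}
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.sql.DataFrame

// Training-time API: DataFrame in, Model out (not strongly typed on features).
private class TinyClassifier
  extends Classifier[Vector, TinyClassifier, TinyModel] {

  override protected def train(dataset: DataFrame, paramMap: ParamMap): TinyModel = {
    val points = extractLabeledPoints(dataset, paramMap)       // RDD[LabeledPoint]
    val weights = Vectors.zeros(points.first().features.size)  // learning would happen here
    new TinyModel(this, paramMap, weights)
  }
}

// Test-time API: strongly typed, a feature Vector in, raw per-class scores out.
private class TinyModel(
    override val parent: TinyClassifier,
    override val fittingParamMap: ParamMap,
    val weights: Vector)
  extends ClassificationModel[Vector, TinyModel] {

  override val numClasses: Int = 2

  override protected def predictRaw(features: Vector): Vector = {
    val margin = BLAS.dot(features, weights)
    Vectors.dense(-margin, margin)
  }

  override protected def copy(): TinyModel = {
    val m = new TinyModel(parent, fittingParamMap, weights)
    Params.inheritValues(this.paramMap, this, m)
    m
  }
}
```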

Concrete classes: learning algorithms
* LinearRegression
* LogisticRegression (updated to use new abstract classes)
 * Also, removed "score" in favor of "probability" output column.  Changed BinaryClassificationEvaluator to match. (SPARK-5031)

Other updates:
* params.scala: Changed Params.set/get to be protected instead of private[ml]
 * This was needed for the example of defining a class from outside of the MLlib namespace.
* VectorUDT: Will later change from private[spark] to public.
 * This is needed for outside users to write their own validateAndTransformSchema() methods using vectors.
 * Also, added equals() method.
* SPARK-4942: ML Transformers should allow output cols to be turned on/off (a usage sketch follows this list)
 * Update validateAndTransformSchema
 * Update transform
* (Updated examples, test suites according to other changes)
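
For SPARK-4942, an output column is skipped when its column-name Param is set to the empty string.  A hypothetical sketch (reusing the `lrModel` and `test` names from the sketch above):

```scala
import org.apache.spark.ml.param.ParamMap

// An empty column name means "do not produce this output column".
// Here only the default "prediction" column is appended; rawPrediction and
// probability are skipped.  transform() logs a warning only if no output
// columns are left at all.
val noExtras = ParamMap(
  lrModel.rawPredictionCol -> "",
  lrModel.probabilityCol -> "")
val predictionsOnly = lrModel.transform(test, noExtras)
predictionsOnly.select("features", "prediction").show()
```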

New examples:
* DeveloperApiExample.scala (example of defining algorithm from outside of the MLlib namespace)
 * Added Java version too

Test Suites:
* LinearRegressionSuite
* LogisticRegressionSuite
* + Java versions of above suites

CC: mengxr  etrain  shivaram

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #3637 from jkbradley/ml-api-part1 and squashes the following commits:

405bfb8 [Joseph K. Bradley] Last edits based on code review.  Small cleanups
fec348a [Joseph K. Bradley] Added JavaDeveloperApiExample.java and fixed other issues: Made developer API private[spark] for now. Added constructors Java can understand to specialized Param types.
8316d5e [Joseph K. Bradley] fixes after rebasing on master
fc62406 [Joseph K. Bradley] fixed test suites after last commit
bcb9549 [Joseph K. Bradley] Fixed issues after rebasing from master (after move from SchemaRDD to DataFrame)
9872424 [Joseph K. Bradley] fixed JavaLinearRegressionSuite.java Java sql api
f542997 [Joseph K. Bradley] Added MIMA excludes for VectorUDT (now public), and added DeveloperApi annotation to it
216d199 [Joseph K. Bradley] fixed after sql datatypes PR got merged
f549e34 [Joseph K. Bradley] Updates based on code review.  Major ones are: * Created weakly typed Predictor.train() method which is called by fit() so that developers do not have to call schema validation or copy parameters. * Made Predictor.featuresDataType have a default value of VectorUDT.   * NOTE: This could be dangerous since the FeaturesType type parameter cannot have a default value.
343e7bd [Joseph K. Bradley] added blanket mima exclude for ml package
82f340b [Joseph K. Bradley] Fixed bug in LogisticRegression (introduced in this PR).  Fixed Java suites
0a16da9 [Joseph K. Bradley] Fixed Linear/Logistic RegressionSuites
c3c8da5 [Joseph K. Bradley] small cleanup
934f97b [Joseph K. Bradley] Fixed bugs from previous commit.
1c61723 [Joseph K. Bradley] * Made ProbabilisticClassificationModel into a subclass of ClassificationModel.  Also introduced ProbabilisticClassifier.  * This was to support output column “probabilityCol” in transform().
4e2f711 [Joseph K. Bradley] rat fix
bc654e1 [Joseph K. Bradley] Added spark.ml LinearRegressionSuite
8d13233 [Joseph K. Bradley] Added methods: * Classifier: batch predictRaw() * Predictor: train() without paramMap ProbabilisticClassificationModel.predictProbabilities() * Java versions of all above batch methods + others
1680905 [Joseph K. Bradley] Added JavaLabeledPointSuite.java for spark.ml, and added constructor to LabeledPoint which defaults weight to 1.0
adbe50a [Joseph K. Bradley] * fixed LinearRegression train() to use embedded paramMap * added Predictor.predict(RDD[Vector]) method * updated Linear/LogisticRegressionSuites
58802e3 [Joseph K. Bradley] added train() to Predictor subclasses which does not take a ParamMap.
57d54ab [Joseph K. Bradley] * Changed semantics of Predictor.train() to merge the given paramMap with the embedded paramMap. * remove threshold_internal from logreg * Added Predictor.copy() * Extended LogisticRegressionSuite
e433872 [Joseph K. Bradley] Updated docs.  Added LabeledPointSuite to spark.ml
54b7b31 [Joseph K. Bradley] Fixed issue with logreg threshold being set correctly
0617d61 [Joseph K. Bradley] Fixed bug from last commit (sorting paramMap by parameter names in toString).  Fixed bug in persisting logreg data.  Added threshold_internal to logreg for faster test-time prediction (avoiding map lookup).
601e792 [Joseph K. Bradley] Modified ParamMap to sort parameters in toString.  Cleaned up classes in class hierarchy, before implementing tests and examples.
d705e87 [Joseph K. Bradley] Added LinearRegression and Regressor back from ml-api branch
52f4fde [Joseph K. Bradley] removing everything except for simple class hierarchy for classification
d35bb5d [Joseph K. Bradley] fixed compilation issues, but have not added tests yet
bfade12 [Joseph K. Bradley] Added lots of classes for new ML API:

(cherry picked from commit dc0c4490a12ecedd8ca5a1bb256c7ccbdf0be04f)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/45b95e7d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/45b95e7d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/45b95e7d

Branch: refs/heads/branch-1.3
Commit: 45b95e7d23bdfcbe55473c44b1b056e4005d45b0
Parents: c35a11e
Author: Joseph K. Bradley <joseph@databricks.com>
Authored: Thu Feb 5 23:43:47 2015 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Thu Feb 5 23:44:02 2015 -0800

----------------------------------------------------------------------
 .../examples/ml/JavaCrossValidatorExample.java  |   6 +-
 .../examples/ml/JavaDeveloperApiExample.java    | 217 +++++++++++++++++
 .../examples/ml/JavaSimpleParamsExample.java    |  10 +-
 .../JavaSimpleTextClassificationPipeline.java   |   4 +-
 .../examples/ml/CrossValidatorExample.scala     |   7 +-
 .../spark/examples/ml/DeveloperApiExample.scala | 184 +++++++++++++++
 .../spark/examples/ml/SimpleParamsExample.scala |  16 +-
 .../ml/SimpleTextClassificationPipeline.scala   |   7 +-
 .../scala/org/apache/spark/ml/Estimator.scala   |   9 +-
 .../spark/ml/classification/Classifier.scala    | 206 ++++++++++++++++
 .../ml/classification/LogisticRegression.scala  | 212 ++++++++++-------
 .../ProbabilisticClassifier.scala               | 147 ++++++++++++
 .../BinaryClassificationEvaluator.scala         |  24 +-
 .../org/apache/spark/ml/feature/Tokenizer.scala |   4 +-
 .../spark/ml/impl/estimator/Predictor.scala     | 234 +++++++++++++++++++
 .../org/apache/spark/ml/param/params.scala      |  68 +++++-
 .../apache/spark/ml/param/sharedParams.scala    |  28 ++-
 .../spark/ml/regression/LinearRegression.scala  |  96 ++++++++
 .../apache/spark/ml/regression/Regressor.scala  |  78 +++++++
 .../org/apache/spark/mllib/linalg/Vectors.scala |  13 ++
 .../org/apache/spark/ml/JavaPipelineSuite.java  |   2 +-
 .../JavaLogisticRegressionSuite.java            |  91 +++++++-
 .../regression/JavaLinearRegressionSuite.java   |  89 +++++++
 .../LogisticRegressionSuite.scala               |  86 ++++++-
 .../ml/regression/LinearRegressionSuite.scala   |  65 ++++++
 project/MimaExcludes.scala                      |   6 +
 26 files changed, 1753 insertions(+), 156 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
index 0fbee6e..5041e0b 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
@@ -116,10 +116,12 @@ public class JavaCrossValidatorExample {
 
     // Make predictions on test documents. cvModel uses the best model found (lrModel).
     cvModel.transform(test).registerTempTable("prediction");
-    DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
+    DataFrame predictions = jsql.sql("SELECT id, text, probability, prediction FROM prediction");
     for (Row r: predictions.collect()) {
-      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
           + ", prediction=" + r.get(3));
     }
+
+    jsc.stop();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
new file mode 100644
index 0000000..42d4d7d
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.classification.Classifier;
+import org.apache.spark.ml.classification.ClassificationModel;
+import org.apache.spark.ml.param.IntParam;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.param.Params;
+import org.apache.spark.ml.param.Params$;
+import org.apache.spark.mllib.linalg.BLAS;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+
+
+/**
+ * A simple example demonstrating how to write your own learning algorithm using Estimator,
+ * Transformer, and other abstractions.
+ * This mimics {@link org.apache.spark.ml.classification.LogisticRegression}.
+ *
+ * Run with
+ * <pre>
+ * bin/run-example ml.JavaDeveloperApiExample
+ * </pre>
+ */
+public class JavaDeveloperApiExample {
+
+  public static void main(String[] args) throws Exception {
+    SparkConf conf = new SparkConf().setAppName("JavaDeveloperApiExample");
+    JavaSparkContext jsc = new JavaSparkContext(conf);
+    SQLContext jsql = new SQLContext(jsc);
+
+    // Prepare training data.
+    List<LabeledPoint> localTraining = Lists.newArrayList(
+        new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
+        new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
+        new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
+        new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
+    DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
+
+    // Create a LogisticRegression instance.  This instance is an Estimator.
+    MyJavaLogisticRegression lr = new MyJavaLogisticRegression();
+    // Print out the parameters, documentation, and any default values.
+    System.out.println("MyJavaLogisticRegression parameters:\n" + lr.explainParams() + "\n");
+
+    // We may set parameters using setter methods.
+    lr.setMaxIter(10);
+
+    // Learn a LogisticRegression model.  This uses the parameters stored in lr.
+    MyJavaLogisticRegressionModel model = lr.fit(training);
+
+    // Prepare test data.
+    List<LabeledPoint> localTest = Lists.newArrayList(
+        new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
+        new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
+        new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
+    DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
+
+    // Make predictions on test documents. cvModel uses the best model found (lrModel).
+    DataFrame results = model.transform(test);
+    double sumPredictions = 0;
+    for (Row r : results.select("features", "label", "prediction").collect()) {
+      sumPredictions += r.getDouble(2);
+    }
+    if (sumPredictions != 0.0) {
+      throw new Exception("MyJavaLogisticRegression predicted something other than 0," +
+          " even though all weights are 0!");
+    }
+
+    jsc.stop();
+  }
+}
+
+/**
+ * Example of defining a type of {@link Classifier}.
+ *
+ * NOTE: This is private since it is an example.  In practice, you may not want it to be private.
+ */
+class MyJavaLogisticRegression
+    extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel>
+    implements Params {
+
+  /**
+   * Param for max number of iterations
+   * <p/>
+   * NOTE: The usual way to add a parameter to a model or algorithm is to include:
+   * - val myParamName: ParamType
+   * - def getMyParamName
+   * - def setMyParamName
+   */
+  IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations");
+
+  int getMaxIter() { return (int)get(maxIter); }
+
+  public MyJavaLogisticRegression() {
+    setMaxIter(100);
+  }
+
+  // The parameter setter is in this class since it should return type MyJavaLogisticRegression.
+  MyJavaLogisticRegression setMaxIter(int value) {
+    return (MyJavaLogisticRegression)set(maxIter, value);
+  }
+
+  // This method is used by fit().
+  // In Java, we have to make it public since Java does not understand Scala's protected modifier.
+  public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap) {
+    // Extract columns from data using helper method.
+    JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset, paramMap).toJavaRDD();
+
+    // Do learning to estimate the weight vector.
+    int numFeatures = oldDataset.take(1).get(0).features().size();
+    Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.
+
+    // Create a model, and return it.
+    return new MyJavaLogisticRegressionModel(this, paramMap, weights);
+  }
+}
+
+/**
+ * Example of defining a type of {@link ClassificationModel}.
+ *
+ * NOTE: This is private since it is an example.  In practice, you may not want it to be private.
+ */
+class MyJavaLogisticRegressionModel
+    extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> implements Params {
+
+  private MyJavaLogisticRegression parent_;
+  public MyJavaLogisticRegression parent() { return parent_; }
+
+  private ParamMap fittingParamMap_;
+  public ParamMap fittingParamMap() { return fittingParamMap_; }
+
+  private Vector weights_;
+  public Vector weights() { return weights_; }
+
+  public MyJavaLogisticRegressionModel(
+      MyJavaLogisticRegression parent_,
+      ParamMap fittingParamMap_,
+      Vector weights_) {
+    this.parent_ = parent_;
+    this.fittingParamMap_ = fittingParamMap_;
+    this.weights_ = weights_;
+  }
+
+  // This uses the default implementation of transform(), which reads column "features" and outputs
+  // columns "prediction" and "rawPrediction."
+
+  // This uses the default implementation of predict(), which chooses the label corresponding to
+  // the maximum value returned by [[predictRaw()]].
+
+  /**
+   * Raw prediction for each possible label.
+   * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
+   * a measure of confidence in each possible label (where larger = more confident).
+   * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]].
+   *
+   * @return vector where element i is the raw prediction for label i.
+   * This raw prediction may be any real number, where a larger value indicates greater
+   * confidence for that label.
+   *
+   * In Java, we have to make this method public since Java does not understand Scala's protected
+   * modifier.
+   */
+  public Vector predictRaw(Vector features) {
+    double margin = BLAS.dot(features, weights_);
+    // There are 2 classes (binary classification), so we return a length-2 vector,
+    // where index i corresponds to class i (i = 0, 1).
+    return Vectors.dense(-margin, margin);
+  }
+
+  /**
+   * Number of classes the label can take.  2 indicates binary classification.
+   */
+  public int numClasses() { return 2; }
+
+  /**
+   * Create a copy of the model.
+   * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
+   * <p/>
+   * This is used for the default implementation of [[transform()]].
+   *
+   * In Java, we have to make this method public since Java does not understand Scala's protected
+   * modifier.
+   */
+  public MyJavaLogisticRegressionModel copy() {
+    MyJavaLogisticRegressionModel m =
+        new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_);
+    Params$.MODULE$.inheritValues(this.paramMap(), this, m);
+    return m;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index eaaa344..cc69e63 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -81,7 +81,7 @@ public class JavaSimpleParamsExample {
 
     // One can also combine ParamMaps.
     ParamMap paramMap2 = new ParamMap();
-    paramMap2.put(lr.scoreCol().w("probability")); // Change output column name
+    paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name
     ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);
 
     // Now learn a new model using the paramMapCombined parameters.
@@ -98,14 +98,16 @@ public class JavaSimpleParamsExample {
 
     // Make predictions on test documents using the Transformer.transform() method.
     // LogisticRegression.transform will only use the 'features' column.
-    // Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
-    // column since we renamed the lr.scoreCol parameter previously.
+    // Note that model2.transform() outputs a 'myProbability' column instead of the usual
+    // 'probability' column since we renamed the lr.probabilityCol parameter previously.
     model2.transform(test).registerTempTable("results");
     DataFrame results =
-        jsql.sql("SELECT features, label, probability, prediction FROM results");
+        jsql.sql("SELECT features, label, myProbability, prediction FROM results");
     for (Row r: results.collect()) {
       System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
           + ", prediction=" + r.get(3));
     }
+
+    jsc.stop();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
index 82d665a..d929f1a 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -85,8 +85,10 @@ public class JavaSimpleTextClassificationPipeline {
     model.transform(test).registerTempTable("prediction");
     DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
     for (Row r: predictions.collect()) {
-      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
           + ", prediction=" + r.get(3));
     }
+
+    jsc.stop();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
index b6c30a0..a2893f7 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
@@ -23,6 +23,7 @@ import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
 import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
 import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.sql.{Row, SQLContext}
 
 /**
@@ -100,10 +101,10 @@ object CrossValidatorExample {
 
     // Make predictions on test documents. cvModel uses the best model found (lrModel).
     cvModel.transform(test)
-      .select("id", "text", "score", "prediction")
+      .select("id", "text", "probability", "prediction")
       .collect()
-      .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
-      println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
+      .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
+      println(s"($id, $text) --> prob=$prob, prediction=$prediction")
     }
 
     sc.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
new file mode 100644
index 0000000..aed4423
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel}
+import org.apache.spark.ml.param.{Params, IntParam, ParamMap}
+import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+
+/**
+ * A simple example demonstrating how to write your own learning algorithm using Estimator,
+ * Transformer, and other abstractions.
+ * This mimics [[org.apache.spark.ml.classification.LogisticRegression]].
+ * Run with
+ * {{{
+ * bin/run-example ml.DeveloperApiExample
+ * }}}
+ */
+object DeveloperApiExample {
+
+  def main(args: Array[String]) {
+    val conf = new SparkConf().setAppName("DeveloperApiExample")
+    val sc = new SparkContext(conf)
+    val sqlContext = new SQLContext(sc)
+    import sqlContext.implicits._
+
+    // Prepare training data.
+    val training = sc.parallelize(Seq(
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
+      LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))))
+
+    // Create a LogisticRegression instance.  This instance is an Estimator.
+    val lr = new MyLogisticRegression()
+    // Print out the parameters, documentation, and any default values.
+    println("MyLogisticRegression parameters:\n" + lr.explainParams() + "\n")
+
+    // We may set parameters using setter methods.
+    lr.setMaxIter(10)
+
+    // Learn a LogisticRegression model.  This uses the parameters stored in lr.
+    val model = lr.fit(training)
+
+    // Prepare test data.
+    val test = sc.parallelize(Seq(
+      LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
+      LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
+
+    // Make predictions on test data.
+    val sumPredictions: Double = model.transform(test)
+      .select("features", "label", "prediction")
+      .collect()
+      .map { case Row(features: Vector, label: Double, prediction: Double) =>
+        prediction
+      }.sum
+    assert(sumPredictions == 0.0,
+      "MyLogisticRegression predicted something other than 0, even though all weights are 0!")
+
+    sc.stop()
+  }
+}
+
+/**
+ * Example of defining a parameter trait for a user-defined type of [[Classifier]].
+ *
+ * NOTE: This is private since it is an example.  In practice, you may not want it to be private.
+ */
+private trait MyLogisticRegressionParams extends ClassifierParams {
+
+  /**
+   * Param for max number of iterations
+   *
+   * NOTE: The usual way to add a parameter to a model or algorithm is to include:
+   *   - val myParamName: ParamType
+   *   - def getMyParamName
+   *   - def setMyParamName
+   * Here, we have a trait to be mixed in with the Estimator and Model (MyLogisticRegression
+   * and MyLogisticRegressionModel).  We place the setter (setMaxIter) method in the Estimator
+   * class since the maxIter parameter is only used during training (not in the Model).
+   */
+  val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
+  def getMaxIter: Int = get(maxIter)
+}
+
+/**
+ * Example of defining a type of [[Classifier]].
+ *
+ * NOTE: This is private since it is an example.  In practice, you may not want it to be private.
+ */
+private class MyLogisticRegression
+  extends Classifier[Vector, MyLogisticRegression, MyLogisticRegressionModel]
+  with MyLogisticRegressionParams {
+
+  setMaxIter(100) // Initialize
+
+  // The parameter setter is in this class since it should return type MyLogisticRegression.
+  def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+  // This method is used by fit()
+  override protected def train(
+      dataset: DataFrame,
+      paramMap: ParamMap): MyLogisticRegressionModel = {
+    // Extract columns from data using helper method.
+    val oldDataset = extractLabeledPoints(dataset, paramMap)
+
+    // Do learning to estimate the weight vector.
+    val numFeatures = oldDataset.take(1)(0).features.size
+    val weights = Vectors.zeros(numFeatures) // Learning would happen here.
+
+    // Create a model, and return it.
+    new MyLogisticRegressionModel(this, paramMap, weights)
+  }
+}
+
+/**
+ * Example of defining a type of [[ClassificationModel]].
+ *
+ * NOTE: This is private since it is an example.  In practice, you may not want it to be private.
+ */
+private class MyLogisticRegressionModel(
+    override val parent: MyLogisticRegression,
+    override val fittingParamMap: ParamMap,
+    val weights: Vector)
+  extends ClassificationModel[Vector, MyLogisticRegressionModel]
+  with MyLogisticRegressionParams {
+
+  // This uses the default implementation of transform(), which reads column "features" and outputs
+  // columns "prediction" and "rawPrediction."
+
+  // This uses the default implementation of predict(), which chooses the label corresponding to
+  // the maximum value returned by [[predictRaw()]].
+
+  /**
+   * Raw prediction for each possible label.
+   * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
+   * a measure of confidence in each possible label (where larger = more confident).
+   * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]].
+   *
+   * @return  vector where element i is the raw prediction for label i.
+   *          This raw prediction may be any real number, where a larger value indicates greater
+   *          confidence for that label.
+   */
+  override protected def predictRaw(features: Vector): Vector = {
+    val margin = BLAS.dot(features, weights)
+    // There are 2 classes (binary classification), so we return a length-2 vector,
+    // where index i corresponds to class i (i = 0, 1).
+    Vectors.dense(-margin, margin)
+  }
+
+  /** Number of classes the label can take.  2 indicates binary classification. */
+  override val numClasses: Int = 2
+
+  /**
+   * Create a copy of the model.
+   * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
+   *
+   * This is used for the default implementation of [[transform()]].
+   */
+  override protected def copy(): MyLogisticRegressionModel = {
+    val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights)
+    Params.inheritValues(this.paramMap, this, m)
+    m
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index 4d1530c..80c9f5f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -72,7 +72,7 @@ object SimpleParamsExample {
     paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params.
 
     // One can also combine ParamMaps.
-    val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Change output column name
+    val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name
     val paramMapCombined = paramMap ++ paramMap2
 
     // Now learn a new model using the paramMapCombined parameters.
@@ -80,21 +80,21 @@ object SimpleParamsExample {
     val model2 = lr.fit(training, paramMapCombined)
     println("Model 2 was fit using parameters: " + model2.fittingParamMap)
 
-    // Prepare test documents.
+    // Prepare test data.
     val test = sc.parallelize(Seq(
       LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
       LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
       LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
 
-    // Make predictions on test documents using the Transformer.transform() method.
+    // Make predictions on test data using the Transformer.transform() method.
     // LogisticRegression.transform will only use the 'features' column.
-    // Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
-    // column since we renamed the lr.scoreCol parameter previously.
+    // Note that model2.transform() outputs a 'myProbability' column instead of the usual
+    // 'probability' column since we renamed the lr.probabilityCol parameter previously.
     model2.transform(test)
-      .select("features", "label", "probability", "prediction")
+      .select("features", "label", "myProbability", "prediction")
       .collect()
-      .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) =>
-        println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction)
+      .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) =>
+        println("($features, $label) -> prob=$prob, prediction=$prediction")
       }
 
     sc.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
index dbbe01d..968cb29 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
@@ -23,6 +23,7 @@ import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.ml.Pipeline
 import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.sql.{Row, SQLContext}
 
 @BeanInfo
@@ -79,10 +80,10 @@ object SimpleTextClassificationPipeline {
 
     // Make predictions on test documents.
     model.transform(test)
-      .select("id", "text", "score", "prediction")
+      .select("id", "text", "probability", "prediction")
       .collect()
-      .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
-        println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
+      .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
+        println("($id, $text) --> prob=$prob, prediction=$prediction")
       }
 
     sc.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index bc3defe..eff7ef9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -34,7 +34,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
    * Fits a single model to the input data with optional parameters.
    *
    * @param dataset input dataset
-   * @param paramPairs optional list of param pairs (overwrite embedded params)
+   * @param paramPairs Optional list of param pairs.
+   *                   These values override any specified in this Estimator's embedded ParamMap.
    * @return fitted model
    */
   @varargs
@@ -47,7 +48,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
    * Fits a single model to the input data with provided parameter map.
    *
    * @param dataset input dataset
-   * @param paramMap parameter map
+   * @param paramMap Parameter map.
+   *                 These values override any specified in this Estimator's embedded ParamMap.
    * @return fitted model
    */
   def fit(dataset: DataFrame, paramMap: ParamMap): M
@@ -58,7 +60,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
    * Subclasses could overwrite this to optimize multi-model training.
    *
    * @param dataset input dataset
-   * @param paramMaps an array of parameter maps
+   * @param paramMaps An array of parameter maps.
+   *                  These values override any specified in this Estimator's embedded ParamMap.
    * @return fitted models, matching the input parameter maps
    */
   def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
new file mode 100644
index 0000000..1bf8eb4
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.{DeveloperApi, AlphaComponent}
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
+import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol}
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+
+
+/**
+ * :: DeveloperApi ::
+ * Params for classification.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@DeveloperApi
+private[spark] trait ClassifierParams extends PredictorParams
+  with HasRawPredictionCol {
+
+  override protected def validateAndTransformSchema(
+      schema: StructType,
+      paramMap: ParamMap,
+      fitting: Boolean,
+      featuresDataType: DataType): StructType = {
+    val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
+    val map = this.paramMap ++ paramMap
+    addOutputColumn(parentSchema, map(rawPredictionCol), new VectorUDT)
+  }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Single-label binary or multiclass classification.
+ * Classes are indexed {0, 1, ..., numClasses - 1}.
+ *
+ * @tparam FeaturesType  Type of input features.  E.g., [[Vector]]
+ * @tparam E  Concrete Estimator type
+ * @tparam M  Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class Classifier[
+    FeaturesType,
+    E <: Classifier[FeaturesType, E, M],
+    M <: ClassificationModel[FeaturesType, M]]
+  extends Predictor[FeaturesType, E, M]
+  with ClassifierParams {
+
+  def setRawPredictionCol(value: String): E =
+    set(rawPredictionCol, value).asInstanceOf[E]
+
+  // TODO: defaultEvaluator (follow-up PR)
+}
+
+/**
+ * :: AlphaComponent ::
+ * Model produced by a [[Classifier]].
+ * Classes are indexed {0, 1, ..., numClasses - 1}.
+ *
+ * @tparam FeaturesType  Type of input features.  E.g., [[Vector]]
+ * @tparam M  Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark]
+abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]]
+  extends PredictionModel[FeaturesType, M] with ClassifierParams {
+
+  def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M]
+
+  /** Number of classes (values which the label can take). */
+  def numClasses: Int
+
+  /**
+   * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
+   * parameters:
+   *  - predicted labels as [[predictionCol]] of type [[Double]]
+   *  - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]].
+   *
+   * @param dataset input dataset
+   * @param paramMap additional parameters, overwrite embedded params
+   * @return transformed dataset
+   */
+  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+    // This default implementation should be overridden as needed.
+
+    // Check schema
+    transformSchema(dataset.schema, paramMap, logging = true)
+    val map = this.paramMap ++ paramMap
+
+    // Prepare model
+    val tmpModel = if (paramMap.size != 0) {
+      val tmpModel = this.copy()
+      Params.inheritValues(paramMap, parent, tmpModel)
+      tmpModel
+    } else {
+      this
+    }
+
+    val (numColsOutput, outputData) =
+      ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map)
+    if (numColsOutput == 0) {
+      logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" +
+        " since no output columns were set.")
+    }
+    outputData
+  }
+
+  /**
+   * :: DeveloperApi ::
+   *
+   * Predict label for the given features.
+   * This internal method is used to implement [[transform()]] and output [[predictionCol]].
+   *
+   * This default implementation for classification predicts the index of the maximum value
+   * from [[predictRaw()]].
+   */
+  @DeveloperApi
+  override protected def predict(features: FeaturesType): Double = {
+    predictRaw(features).toArray.zipWithIndex.maxBy(_._1)._2
+  }
+
+  /**
+   * :: DeveloperApi ::
+   *
+   * Raw prediction for each possible label.
+   * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
+   * a measure of confidence in each possible label (where larger = more confident).
+   * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]].
+   *
+   * @return  vector where element i is the raw prediction for label i.
+   *          This raw prediction may be any real number, where a larger value indicates greater
+   *          confidence for that label.
+   */
+  @DeveloperApi
+  protected def predictRaw(features: FeaturesType): Vector
+
+}
+
+private[ml] object ClassificationModel {
+
+  /**
+   * Added prediction column(s).  This is separated from [[ClassificationModel.transform()]]
+   * since it is used by [[org.apache.spark.ml.classification.ProbabilisticClassificationModel]].
+   * @param dataset  Input dataset
+   * @param map  Parameter map.  This will NOT be merged with the embedded paramMap; the merge
+   *             should already be done.
+   * @return (number of columns added, transformed dataset)
+   */
+  def transformColumnsImpl[FeaturesType](
+      dataset: DataFrame,
+      model: ClassificationModel[FeaturesType, _],
+      map: ParamMap): (Int, DataFrame) = {
+
+    // Output selected columns only.
+    // This is a bit complicated since it tries to avoid repeated computation.
+    var tmpData = dataset
+    var numColsOutput = 0
+    if (map(model.rawPredictionCol) != "") {
+      // output raw prediction
+      val features2raw: FeaturesType => Vector = model.predictRaw
+      tmpData = tmpData.select($"*",
+        callUDF(features2raw, new VectorUDT,
+          col(map(model.featuresCol))).as(map(model.rawPredictionCol)))
+      numColsOutput += 1
+      if (map(model.predictionCol) != "") {
+        val raw2pred: Vector => Double = (rawPred) => {
+          rawPred.toArray.zipWithIndex.maxBy(_._1)._2
+        }
+        tmpData = tmpData.select($"*", callUDF(raw2pred, DoubleType,
+          col(map(model.rawPredictionCol))).as(map(model.predictionCol)))
+        numColsOutput += 1
+      }
+    } else if (map(model.predictionCol) != "") {
+      // output prediction
+      val features2pred: FeaturesType => Double = model.predict
+      tmpData = tmpData.select($"*",
+        callUDF(features2pred, DoubleType,
+          col(map(model.featuresCol))).as(map(model.predictionCol)))
+      numColsOutput += 1
+    }
+    (numColsOutput, tmpData)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index b46a5cd..c146fe2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -18,61 +18,32 @@
 package org.apache.spark.ml.classification
 
 import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml._
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
-import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.sql._
+import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors}
+import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.Dsl._
-import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.storage.StorageLevel
 
+
 /**
- * :: AlphaComponent ::
  * Params for logistic regression.
  */
-@AlphaComponent
-private[classification] trait LogisticRegressionParams extends Params
-  with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol
-  with HasScoreCol with HasPredictionCol {
+private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
+  with HasRegParam with HasMaxIter with HasThreshold
 
-  /**
-   * Validates and transforms the input schema with the provided param map.
-   * @param schema input schema
-   * @param paramMap additional parameters
-   * @param fitting whether this is in fitting
-   * @return output schema
-   */
-  protected def validateAndTransformSchema(
-      schema: StructType,
-      paramMap: ParamMap,
-      fitting: Boolean): StructType = {
-    val map = this.paramMap ++ paramMap
-    val featuresType = schema(map(featuresCol)).dataType
-    // TODO: Support casting Array[Double] and Array[Float] to Vector.
-    require(featuresType.isInstanceOf[VectorUDT],
-      s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
-    if (fitting) {
-      val labelType = schema(map(labelCol)).dataType
-      require(labelType == DoubleType,
-        s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.")
-    }
-    val fieldNames = schema.fieldNames
-    require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.")
-    require(!fieldNames.contains(map(predictionCol)),
-      s"Prediction column ${map(predictionCol)} already exists.")
-    val outputFields = schema.fields ++ Seq(
-      StructField(map(scoreCol), DoubleType, false),
-      StructField(map(predictionCol), DoubleType, false))
-    StructType(outputFields)
-  }
-}
 
 /**
+ * :: AlphaComponent ::
+ *
  * Logistic regression.
+ * Currently, this class only supports binary classification.
  */
-class LogisticRegression extends Estimator[LogisticRegressionModel] with LogisticRegressionParams {
+@AlphaComponent
+class LogisticRegression
+  extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
+  with LogisticRegressionParams {
 
   setRegParam(0.1)
   setMaxIter(100)
@@ -80,68 +51,151 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
 
   def setRegParam(value: Double): this.type = set(regParam, value)
   def setMaxIter(value: Int): this.type = set(maxIter, value)
-  def setLabelCol(value: String): this.type = set(labelCol, value)
   def setThreshold(value: Double): this.type = set(threshold, value)
-  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
-  def setScoreCol(value: String): this.type = set(scoreCol, value)
-  def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
-  override def fit(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = {
-    transformSchema(dataset.schema, paramMap, logging = true)
-    val map = this.paramMap ++ paramMap
-    val instances = dataset.select(map(labelCol), map(featuresCol))
-      .map { case Row(label: Double, features: Vector) =>
-        LabeledPoint(label, features)
-      }.persist(StorageLevel.MEMORY_AND_DISK)
+  override protected def train(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = {
+    // Extract columns from data.  If dataset is persisted, do not persist oldDataset.
+    val oldDataset = extractLabeledPoints(dataset, paramMap)
+    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+    if (handlePersistence) {
+      oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
+    // Train model
     val lr = new LogisticRegressionWithLBFGS
     lr.optimizer
-      .setRegParam(map(regParam))
-      .setNumIterations(map(maxIter))
-    val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights)
-    instances.unpersist()
-    // copy model params
-    Params.inheritValues(map, this, lrm)
-    lrm
-  }
+      .setRegParam(paramMap(regParam))
+      .setNumIterations(paramMap(maxIter))
+    val oldModel = lr.run(oldDataset)
+    val lrm = new LogisticRegressionModel(this, paramMap, oldModel.weights, oldModel.intercept)
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    validateAndTransformSchema(schema, paramMap, fitting = true)
+    if (handlePersistence) {
+      oldDataset.unpersist()
+    }
+    lrm
   }
 }
 
+
 /**
  * :: AlphaComponent ::
+ *
  * Model produced by [[LogisticRegression]].
  */
 @AlphaComponent
 class LogisticRegressionModel private[ml] (
     override val parent: LogisticRegression,
     override val fittingParamMap: ParamMap,
-    weights: Vector)
-  extends Model[LogisticRegressionModel] with LogisticRegressionParams {
+    val weights: Vector,
+    val intercept: Double)
+  extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
+  with LogisticRegressionParams {
+
+  setThreshold(0.5)
 
   def setThreshold(value: Double): this.type = set(threshold, value)
-  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
-  def setScoreCol(value: String): this.type = set(scoreCol, value)
-  def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
-    validateAndTransformSchema(schema, paramMap, fitting = false)
+  private val margin: Vector => Double = (features) => {
+    BLAS.dot(features, weights) + intercept
+  }
+
+  private val score: Vector => Double = (features) => {
+    val m = margin(features)
+    1.0 / (1.0 + math.exp(-m))
   }
 
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+    // This is overridden (a) to be more efficient (avoiding re-computing values when creating
+    // multiple output columns) and (b) to handle threshold, which the abstractions do not use.
+    // TODO: We should abstract away the steps defined by UDFs below so that the abstractions
+    // can call whichever UDFs are needed to create the output columns.
+
+    // Check schema
     transformSchema(dataset.schema, paramMap, logging = true)
+
     val map = this.paramMap ++ paramMap
-    val scoreFunction = udf { v: Vector =>
-      val margin = BLAS.dot(v, weights)
-      1.0 / (1.0 + math.exp(-margin))
+
+    // Output selected columns only.
+    // This is a bit complicated since it tries to avoid repeated computation.
+    //   rawPrediction (-margin, margin)
+    //   probability (1.0-score, score)
+    //   prediction (max margin)
+    var tmpData = dataset
+    var numColsOutput = 0
+    if (map(rawPredictionCol) != "") {
+      val features2raw: Vector => Vector = (features) => predictRaw(features)
+      tmpData = tmpData.select($"*",
+        callUDF(features2raw, new VectorUDT, col(map(featuresCol))).as(map(rawPredictionCol)))
+      numColsOutput += 1
+    }
+    if (map(probabilityCol) != "") {
+      if (map(rawPredictionCol) != "") {
+        val raw2prob: Vector => Vector = { (rawPreds: Vector) =>
+          val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
+          Vectors.dense(1.0 - prob1, prob1)
+        }
+        tmpData = tmpData.select($"*",
+          callUDF(raw2prob, new VectorUDT, col(map(rawPredictionCol))).as(map(probabilityCol)))
+      } else {
+        val features2prob: Vector => Vector = (features: Vector) => predictProbabilities(features)
+        tmpData = tmpData.select($"*",
+          callUDF(features2prob, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol)))
+      }
+      numColsOutput += 1
     }
-    val t = map(threshold)
-    val predictFunction = udf { score: Double =>
-      if (score > t) 1.0 else 0.0
+    if (map(predictionCol) != "") {
+      val t = map(threshold)
+      if (map(probabilityCol) != "") {
+        val predict: Vector => Double = { probs: Vector =>
+          if (probs(1) > t) 1.0 else 0.0
+        }
+        tmpData = tmpData.select($"*",
+          callUDF(predict, DoubleType, col(map(probabilityCol))).as(map(predictionCol)))
+      } else if (map(rawPredictionCol) != "") {
+        val predict: Vector => Double = { rawPreds: Vector =>
+          val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
+          if (prob1 > t) 1.0 else 0.0
+        }
+        tmpData = tmpData.select($"*",
+          callUDF(predict, DoubleType, col(map(rawPredictionCol))).as(map(predictionCol)))
+      } else {
+        val predict: Vector => Double = (features: Vector) => this.predict(features)
+        tmpData = tmpData.select($"*",
+          callUDF(predict, DoubleType, col(map(featuresCol))).as(map(predictionCol)))
+      }
+      numColsOutput += 1
     }
-    dataset
-      .select($"*", scoreFunction(col(map(featuresCol))).as(map(scoreCol)))
-      .select($"*", predictFunction(col(map(scoreCol))).as(map(predictionCol)))
+    if (numColsOutput == 0) {
+      this.logWarning(s"$uid: LogisticRegressionModel.transform() was called as NOOP" +
+        " since no output columns were set.")
+    }
+    tmpData
+  }
+
+  override val numClasses: Int = 2
+
+  /**
+   * Predict label for the given feature vector.
+   * The behavior of this can be adjusted using [[threshold]].
+   */
+  override protected def predict(features: Vector): Double = {
+    if (score(features) > paramMap(threshold)) 1.0 else 0.0
+  }
+
+  override protected def predictProbabilities(features: Vector): Vector = {
+    val s = score(features)
+    Vectors.dense(1.0 - s, s)
+  }
+
+  override protected def predictRaw(features: Vector): Vector = {
+    val m = margin(features)
+    Vectors.dense(0.0, m)
+  }
+
+  override protected def copy(): LogisticRegressionModel = {
+    val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept)
+    Params.inheritValues(this.paramMap, this, m)
+    m
   }
 }
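
For context, here is a minimal usage sketch of the behavior implemented above, assuming DataFrames `training` and `test` with "label" and "features" columns (the variable names are placeholders, not part of this patch):

    import org.apache.spark.ml.classification.LogisticRegression

    val lr = new LogisticRegression()
      .setMaxIter(50)
      .setRegParam(0.01)
      .setThreshold(0.6)
    val model = lr.fit(training)

    // transform() appends rawPrediction, probability, and prediction by default;
    // setting an output column name to "" (e.g. lr.setProbabilityCol("")) skips that column.
    val predictions = model.transform(test)
      .select("features", "rawPrediction", "probability", "prediction")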

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
new file mode 100644
index 0000000..1202528
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
+import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params}
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.types.{DataType, StructType}
+
+
+/**
+ * Params for probabilistic classification.
+ */
+private[classification] trait ProbabilisticClassifierParams
+  extends ClassifierParams with HasProbabilityCol {
+
+  override protected def validateAndTransformSchema(
+      schema: StructType,
+      paramMap: ParamMap,
+      fitting: Boolean,
+      featuresDataType: DataType): StructType = {
+    val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
+    val map = this.paramMap ++ paramMap
+    addOutputColumn(parentSchema, map(probabilityCol), new VectorUDT)
+  }
+}
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Single-label binary or multiclass classifier which can output class conditional probabilities.
+ *
+ * @tparam FeaturesType  Type of input features.  E.g., [[Vector]]
+ * @tparam E  Concrete Estimator type
+ * @tparam M  Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class ProbabilisticClassifier[
+    FeaturesType,
+    E <: ProbabilisticClassifier[FeaturesType, E, M],
+    M <: ProbabilisticClassificationModel[FeaturesType, M]]
+  extends Classifier[FeaturesType, E, M] with ProbabilisticClassifierParams {
+
+  def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E]
+}
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Model produced by a [[ProbabilisticClassifier]].
+ * Classes are indexed {0, 1, ..., numClasses - 1}.
+ *
+ * @tparam FeaturesType  Type of input features.  E.g., [[Vector]]
+ * @tparam M  Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class ProbabilisticClassificationModel[
+    FeaturesType,
+    M <: ProbabilisticClassificationModel[FeaturesType, M]]
+  extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams {
+
+  def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M]
+
+  /**
+   * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
+   * parameters:
+   *  - predicted labels as [[predictionCol]] of type [[Double]]
+   *  - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]]
+   *  - probability of each class as [[probabilityCol]] of type [[Vector]].
+   *
+   * @param dataset input dataset
+   * @param paramMap additional parameters, overwrite embedded params
+   * @return transformed dataset
+   */
+  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+    // This default implementation should be overridden as needed.
+
+    // Check schema
+    transformSchema(dataset.schema, paramMap, logging = true)
+    val map = this.paramMap ++ paramMap
+
+    // Prepare model
+    val tmpModel = if (paramMap.size != 0) {
+      val tmpModel = this.copy()
+      Params.inheritValues(paramMap, parent, tmpModel)
+      tmpModel
+    } else {
+      this
+    }
+
+    val (numColsOutput, outputData) =
+      ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map)
+
+    // Output selected columns only.
+    if (map(probabilityCol) != "") {
+      // output probabilities
+      val features2probs: FeaturesType => Vector = (features) => {
+        tmpModel.predictProbabilities(features)
+      }
+      outputData.select($"*",
+        callUDF(features2probs, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol)))
+    } else {
+      if (numColsOutput == 0) {
+        this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
+          " since no output columns were set.")
+      }
+      outputData
+    }
+  }
+
+  /**
+   * :: DeveloperApi ::
+   *
+   * Predict the probability of each class given the features.
+   * These predictions are also called class conditional probabilities.
+   *
+   * WARNING: Not all models output well-calibrated probability estimates!  These probabilities
+   *          should be treated as confidences, not precise probabilities.
+   *
+   * This internal method is used to implement [[transform()]] and output [[probabilityCol]].
+   */
+  @DeveloperApi
+  protected def predictProbabilities(features: FeaturesType): Vector
+}
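
To make the predictRaw / predictProbabilities contract concrete for the binary logistic case, here is a small standalone sketch (plain Scala, mirroring the raw2prob UDF in LogisticRegressionModel.transform() above; not part of the patch):

    import org.apache.spark.mllib.linalg.{Vector, Vectors}

    // The class-1 probability is the logistic sigmoid of the class-1 raw confidence.
    def rawToProbability(rawPrediction: Vector): Vector = {
      val p1 = 1.0 / (1.0 + math.exp(-rawPrediction(1)))
      Vectors.dense(1.0 - p1, p1)
    }

    rawToProbability(Vectors.dense(0.0, 2.0))  // approximately [0.119, 0.881]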

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 1979ab9..f21a306 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -18,19 +18,22 @@
 package org.apache.spark.ml.evaluation
 
 import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml._
+import org.apache.spark.ml.Evaluator
 import org.apache.spark.ml.param._
 import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.types.DoubleType
 
+
 /**
  * :: AlphaComponent ::
+ *
  * Evaluator for binary classification, which expects two input columns: rawPrediction and label.
  */
 @AlphaComponent
 class BinaryClassificationEvaluator extends Evaluator with Params
-  with HasScoreCol with HasLabelCol {
+  with HasRawPredictionCol with HasLabelCol {
 
   /** param for metric name in evaluation */
   val metricName: Param[String] = new Param(this, "metricName",
@@ -38,23 +41,20 @@ class BinaryClassificationEvaluator extends Evaluator with Params
   def getMetricName: String = get(metricName)
   def setMetricName(value: String): this.type = set(metricName, value)
 
-  def setScoreCol(value: String): this.type = set(scoreCol, value)
+  def setScoreCol(value: String): this.type = set(rawPredictionCol, value)
   def setLabelCol(value: String): this.type = set(labelCol, value)
 
   override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
     val map = this.paramMap ++ paramMap
 
     val schema = dataset.schema
-    val scoreType = schema(map(scoreCol)).dataType
-    require(scoreType == DoubleType,
-      s"Score column ${map(scoreCol)} must be double type but found $scoreType")
-    val labelType = schema(map(labelCol)).dataType
-    require(labelType == DoubleType,
-      s"Label column ${map(labelCol)} must be double type but found $labelType")
+    checkInputColumn(schema, map(rawPredictionCol), new VectorUDT)
+    checkInputColumn(schema, map(labelCol), DoubleType)
 
-    val scoreAndLabels = dataset.select(map(scoreCol), map(labelCol))
-      .map { case Row(score: Double, label: Double) =>
-        (score, label)
+    // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
+    val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol))
+      .map { case Row(rawPrediction: Vector, label: Double) =>
+        (rawPrediction(1), label)
       }
     val metrics = new BinaryClassificationMetrics(scoreAndLabels)
     val metric = map(metricName) match {

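A minimal sketch of the updated evaluator contract, assuming `predictions` is a DataFrame produced by LogisticRegressionModel.transform() that still carries the "label" column (placeholder name):

    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.param.ParamMap

    // The evaluator now reads the Vector-typed rawPrediction column
    // (element 1 is the class-1 confidence) instead of a Double score column.
    val evaluator = new BinaryClassificationEvaluator()
      .setMetricName("areaUnderROC")
    val auc: Double = evaluator.evaluate(predictions, ParamMap.empty)
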
http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index e622a5c..0b1f90d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -29,11 +29,11 @@ import org.apache.spark.sql.types.{DataType, StringType, ArrayType}
 @AlphaComponent
 class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
 
-  protected override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = {
+  override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = {
     _.toLowerCase.split("\\s")
   }
 
-  protected override def validateInputType(inputType: DataType): Unit = {
+  override protected def validateInputType(inputType: DataType): Unit = {
     require(inputType == StringType, s"Input type must be string type but got $inputType.")
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
new file mode 100644
index 0000000..89b53f3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl.estimator
+
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Trait for parameters for prediction (regression and classification).
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@DeveloperApi
+private[spark] trait PredictorParams extends Params
+  with HasLabelCol with HasFeaturesCol with HasPredictionCol {
+
+  /**
+   * Validates and transforms the input schema with the provided param map.
+   * @param schema input schema
+   * @param paramMap additional parameters
+   * @param fitting whether this is in fitting
+   * @param featuresDataType  SQL DataType for FeaturesType.
+   *                          E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+   * @return output schema
+   */
+  protected def validateAndTransformSchema(
+      schema: StructType,
+      paramMap: ParamMap,
+      fitting: Boolean,
+      featuresDataType: DataType): StructType = {
+    val map = this.paramMap ++ paramMap
+    // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
+    checkInputColumn(schema, map(featuresCol), featuresDataType)
+    if (fitting) {
+      // TODO: Allow other numeric types
+      checkInputColumn(schema, map(labelCol), DoubleType)
+    }
+    addOutputColumn(schema, map(predictionCol), DoubleType)
+  }
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Abstraction for prediction problems (regression and classification).
+ *
+ * @tparam FeaturesType  Type of features.
+ *                       E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+ * @tparam Learner  Specialization of this class.  If you subclass this type, use this type
+ *                  parameter to specify the concrete type.
+ * @tparam M  Specialization of [[PredictionModel]].  If you subclass this type, use this type
+ *            parameter to specify the concrete type for the corresponding model.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class Predictor[
+    FeaturesType,
+    Learner <: Predictor[FeaturesType, Learner, M],
+    M <: PredictionModel[FeaturesType, M]]
+  extends Estimator[M] with PredictorParams {
+
+  def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner]
+  def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner]
+  def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
+
+  override def fit(dataset: DataFrame, paramMap: ParamMap): M = {
+    // This handles a few items such as schema validation.
+    // Developers only need to implement train().
+    transformSchema(dataset.schema, paramMap, logging = true)
+    val map = this.paramMap ++ paramMap
+    val model = train(dataset, map)
+    Params.inheritValues(map, this, model) // copy params to model
+    model
+  }
+
+  /**
+   * :: DeveloperApi ::
+   *
+   * Train a model using the given dataset and parameters.
+   * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation
+   * and copying parameters into the model.
+   *
+   * @param dataset  Training dataset
+   * @param paramMap  Parameter map.  Unlike [[fit()]]'s paramMap, this paramMap has already
+   *                  been combined with the embedded ParamMap.
+   * @return  Fitted model
+   */
+  @DeveloperApi
+  protected def train(dataset: DataFrame, paramMap: ParamMap): M
+
+  /**
+   * :: DeveloperApi ::
+   *
+   * Returns the SQL DataType corresponding to the FeaturesType type parameter.
+   *
+   * This is used by [[validateAndTransformSchema()]].
+   * This workaround is needed since SQL has different APIs for Scala and Java.
+   *
+   * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
+   */
+  @DeveloperApi
+  protected def featuresDataType: DataType = new VectorUDT
+
+  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+    validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType)
+  }
+
+  /**
+   * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
+   * and put them in a strongly typed RDD of [[LabeledPoint]].
+   */
+  protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = {
+    val map = this.paramMap ++ paramMap
+    dataset.select(map(labelCol), map(featuresCol))
+      .map { case Row(label: Double, features: Vector) =>
+        LabeledPoint(label, features)
+      }
+  }
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Abstraction for a model for prediction tasks (regression and classification).
+ *
+ * @tparam FeaturesType  Type of features.
+ *                       E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+ * @tparam M  Specialization of [[PredictionModel]].  If you subclass this type, use this type
+ *            parameter to specify the concrete type for the corresponding model.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
+  extends Model[M] with PredictorParams {
+
+  def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M]
+
+  def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]
+
+  /**
+   * :: DeveloperApi ::
+   *
+   * Returns the SQL DataType corresponding to the FeaturesType type parameter.
+   *
+   * This is used by [[validateAndTransformSchema()]].
+   * This workaround is needed since SQL has different APIs for Scala and Java.
+   *
+   * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
+   */
+  @DeveloperApi
+  protected def featuresDataType: DataType = new VectorUDT
+
+  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+    validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType)
+  }
+
+  /**
+   * Transforms dataset by reading from [[featuresCol]], calling [[predict()]], and storing
+   * the predictions as a new column [[predictionCol]].
+   *
+   * @param dataset input dataset
+   * @param paramMap additional parameters, overwrite embedded params
+   * @return transformed dataset with [[predictionCol]] of type [[Double]]
+   */
+  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+    // This default implementation should be overridden as needed.
+
+    // Check schema
+    transformSchema(dataset.schema, paramMap, logging = true)
+    val map = this.paramMap ++ paramMap
+
+    // Prepare model
+    val tmpModel = if (paramMap.size != 0) {
+      val tmpModel = this.copy()
+      Params.inheritValues(paramMap, parent, tmpModel)
+      tmpModel
+    } else {
+      this
+    }
+
+    if (map(predictionCol) != "") {
+      val pred: FeaturesType => Double = (features) => {
+        tmpModel.predict(features)
+      }
+      dataset.select($"*", callUDF(pred, DoubleType, col(map(featuresCol))).as(map(predictionCol)))
+    } else {
+      this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
+        " since no output columns were set.")
+      dataset
+    }
+  }
+
+  /**
+   * :: DeveloperApi ::
+   *
+   * Predict label for the given features.
+   * This internal method is used to implement [[transform()]] and output [[predictionCol]].
+   */
+  @DeveloperApi
+  protected def predict(features: FeaturesType): Double
+
+  /**
+   * Create a copy of the model.
+   * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
+   */
+  protected def copy(): M
+}
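
The intended division of labor is that developers implement train(), predict(), and copy(), while fit() and transform() handle schema validation and parameter copying. A schematic sketch follows (MeanRegressor / MeanRegressionModel and their package are hypothetical; while these abstractions remain private[spark], such a subclass can only live inside the Spark source tree):

    package org.apache.spark.ml.example  // hypothetical; must sit under org.apache.spark

    import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
    import org.apache.spark.ml.param.{ParamMap, Params}
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.sql.DataFrame

    private[spark] class MeanRegressor
      extends Predictor[Vector, MeanRegressor, MeanRegressionModel] {

      override protected def train(dataset: DataFrame, paramMap: ParamMap): MeanRegressionModel = {
        // fit() has already validated the schema and merged paramMap with the embedded params.
        val labels = extractLabeledPoints(dataset, paramMap).map(_.label)
        new MeanRegressionModel(this, paramMap, labels.mean())
      }
    }

    private[spark] class MeanRegressionModel(
        override val parent: MeanRegressor,
        override val fittingParamMap: ParamMap,
        val mean: Double)
      extends PredictionModel[Vector, MeanRegressionModel] {

      // Ignores the features and always predicts the training-set mean label.
      override protected def predict(features: Vector): Double = mean

      override protected def copy(): MeanRegressionModel = {
        val m = new MeanRegressionModel(parent, fittingParamMap, mean)
        Params.inheritValues(this.paramMap, this, m)
        m
      }
    }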

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 5fb4379..17ece89 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -22,8 +22,10 @@ import scala.collection.mutable
 
 import java.lang.reflect.Modifier
 
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
 import org.apache.spark.ml.Identifiable
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+
 
 /**
  * :: AlphaComponent ::
@@ -65,37 +67,47 @@ class Param[T] (
 // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
 
 /** Specialized version of [[Param[Double]]] for Java. */
-class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double] = None)
+class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double])
   extends Param[Double](parent, name, doc, defaultValue) {
 
+  def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
   override def w(value: Double): ParamPair[Double] = super.w(value)
 }
 
 /** Specialized version of [[Param[Int]]] for Java. */
-class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int] = None)
+class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int])
   extends Param[Int](parent, name, doc, defaultValue) {
 
+  def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
   override def w(value: Int): ParamPair[Int] = super.w(value)
 }
 
 /** Specialized version of [[Param[Float]]] for Java. */
-class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float] = None)
+class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float])
   extends Param[Float](parent, name, doc, defaultValue) {
 
+  def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
   override def w(value: Float): ParamPair[Float] = super.w(value)
 }
 
 /** Specialized version of [[Param[Long]]] for Java. */
-class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long] = None)
+class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long])
   extends Param[Long](parent, name, doc, defaultValue) {
 
+  def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
   override def w(value: Long): ParamPair[Long] = super.w(value)
 }
 
 /** Specialized version of [[Param[Boolean]]] for Java. */
-class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean] = None)
+class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean])
   extends Param[Boolean](parent, name, doc, defaultValue) {
 
+  def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
   override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
 }
 
@@ -158,7 +170,7 @@ trait Params extends Identifiable with Serializable {
   /**
    * Sets a parameter in the embedded param map.
    */
-  private[ml] def set[T](param: Param[T], value: T): this.type = {
+  protected def set[T](param: Param[T], value: T): this.type = {
     require(param.parent.eq(this))
     paramMap.put(param.asInstanceOf[Param[Any]], value)
     this
@@ -174,7 +186,7 @@ trait Params extends Identifiable with Serializable {
   /**
    * Gets the value of a parameter in the embedded param map.
    */
-  private[ml] def get[T](param: Param[T]): T = {
+  protected def get[T](param: Param[T]): T = {
     require(param.parent.eq(this))
     paramMap(param)
   }
@@ -183,9 +195,40 @@ trait Params extends Identifiable with Serializable {
    * Internal param map.
    */
   protected val paramMap: ParamMap = ParamMap.empty
+
+  /**
+   * Check whether the given schema contains an input column.
+   * @param colName  Parameter name for the input column.
+   * @param dataType  SQL DataType of the input column.
+   */
+  protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = {
+    val actualDataType = schema(colName).dataType
+    require(actualDataType.equals(dataType),
+      s"Input column $colName must be of type $dataType" +
+        s" but was actually $actualDataType.  Column param description: ${getParam(colName)}")
+  }
+
+  protected def addOutputColumn(
+      schema: StructType,
+      colName: String,
+      dataType: DataType): StructType = {
+    if (colName.length == 0) return schema
+    val fieldNames = schema.fieldNames
+    require(!fieldNames.contains(colName), s"Prediction column $colName already exists.")
+    val outputFields = schema.fields ++ Seq(StructField(colName, dataType, nullable = false))
+    StructType(outputFields)
+  }
 }
 
-private[ml] object Params {
+/**
+ * :: DeveloperApi ::
+ *
+ * Helper functionality for developers.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@DeveloperApi
+private[spark] object Params {
 
   /**
    * Copies parameter values from the parent estimator to the child model it produced.
@@ -279,7 +322,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
   def copy: ParamMap = new ParamMap(map.clone())
 
   override def toString: String = {
-    map.map { case (param, value) =>
+    map.toSeq.sortBy(_._1.name).map { case (param, value) =>
       s"\t${param.parent.uid}-${param.name}: $value"
     }.mkString("{\n", ",\n", "\n}")
   }
@@ -310,6 +353,11 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
       ParamPair(param, value)
     }
   }
+
+  /**
+   * Number of param pairs in this set.
+   */
+  def size: Int = map.size
 }
 
 object ParamMap {

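Since set/get are now protected rather than private[ml], and checkInputColumn/addOutputColumn live on Params, parameter holders with builder-style setters can be written outside the ml package. A minimal sketch (MyParams, maxBins, and the column names are hypothetical):

    import org.apache.spark.ml.param.{IntParam, Params}
    import org.apache.spark.mllib.linalg.VectorUDT
    import org.apache.spark.sql.types.{DoubleType, StructType}

    trait MyParams extends Params {
      val maxBins: IntParam = new IntParam(this, "maxBins", "maximum number of bins")
      def setMaxBins(value: Int): this.type = set(maxBins, value)
      def getMaxBins: Int = get(maxBins)

      // The protected schema helpers can be reused from subclasses:
      protected def validateMySchema(schema: StructType): StructType = {
        checkInputColumn(schema, "features", new VectorUDT)
        addOutputColumn(schema, "myOutput", DoubleType)
      }
    }
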
http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
index ef141d3..32fc744 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -17,6 +17,12 @@
 
 package org.apache.spark.ml.param
 
+/* NOTE TO DEVELOPERS:
+ * If you mix these parameter traits into your algorithm, please add a setter method as well
+ * so that users may use a builder pattern:
+ *  val myLearner = new MyLearner().setParam1(x).setParam2(y)...
+ */
+
 private[ml] trait HasRegParam extends Params {
   /** param for regularization parameter */
   val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")
@@ -42,12 +48,6 @@ private[ml] trait HasLabelCol extends Params {
   def getLabelCol: String = get(labelCol)
 }
 
-private[ml] trait HasScoreCol extends Params {
-  /** param for score column name */
-  val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", Some("score"))
-  def getScoreCol: String = get(scoreCol)
-}
-
 private[ml] trait HasPredictionCol extends Params {
   /** param for prediction column name */
   val predictionCol: Param[String] =
@@ -55,6 +55,22 @@ private[ml] trait HasPredictionCol extends Params {
   def getPredictionCol: String = get(predictionCol)
 }
 
+private[ml] trait HasRawPredictionCol extends Params {
+  /** param for raw prediction column name */
+  val rawPredictionCol: Param[String] =
+    new Param(this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name",
+      Some("rawPrediction"))
+  def getRawPredictionCol: String = get(rawPredictionCol)
+}
+
+private[ml] trait HasProbabilityCol extends Params {
+  /** param for predicted class conditional probabilities column name */
+  val probabilityCol: Param[String] =
+    new Param(this, "probabilityCol", "column name for predicted class conditional probabilities",
+      Some("probability"))
+  def getProbabilityCol: String = get(probabilityCol)
+}
+
 private[ml] trait HasThreshold extends Params {
   /** param for threshold in (binary) prediction */
   val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction")

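From user code, the setter convention described in the note above is what enables the usual builder-style chaining, for example with the existing Tokenizer:

    import org.apache.spark.ml.feature.Tokenizer

    val tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
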
http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
new file mode 100644
index 0000000..d5a7bda
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam}
+import org.apache.spark.mllib.linalg.{BLAS, Vector}
+import org.apache.spark.mllib.regression.LinearRegressionWithSGD
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * Params for linear regression.
+ */
+private[regression] trait LinearRegressionParams extends RegressorParams
+  with HasRegParam with HasMaxIter
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Linear regression.
+ */
+@AlphaComponent
+class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
+  with LinearRegressionParams {
+
+  setRegParam(0.1)
+  setMaxIter(100)
+
+  def setRegParam(value: Double): this.type = set(regParam, value)
+  def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+  override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = {
+    // Extract columns from data.  If dataset is persisted, do not persist oldDataset.
+    val oldDataset = extractLabeledPoints(dataset, paramMap)
+    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+    if (handlePersistence) {
+      oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
+    }
+
+    // Train model
+    val lr = new LinearRegressionWithSGD()
+    lr.optimizer
+      .setRegParam(paramMap(regParam))
+      .setNumIterations(paramMap(maxIter))
+    val model = lr.run(oldDataset)
+    val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept)
+
+    if (handlePersistence) {
+      oldDataset.unpersist()
+    }
+    lrm
+  }
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Model produced by [[LinearRegression]].
+ */
+@AlphaComponent
+class LinearRegressionModel private[ml] (
+    override val parent: LinearRegression,
+    override val fittingParamMap: ParamMap,
+    val weights: Vector,
+    val intercept: Double)
+  extends RegressionModel[Vector, LinearRegressionModel]
+  with LinearRegressionParams {
+
+  override protected def predict(features: Vector): Double = {
+    BLAS.dot(features, weights) + intercept
+  }
+
+  override protected def copy(): LinearRegressionModel = {
+    val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept)
+    Params.inheritValues(this.paramMap, this, m)
+    m
+  }
+}
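
A minimal usage sketch for the new classes, assuming a DataFrame `training` with "label" and "features" columns and a DataFrame `test` with "features" (placeholder names):

    import org.apache.spark.ml.regression.LinearRegression

    val lir = new LinearRegression()
      .setRegParam(0.01)
      .setMaxIter(200)
    val model = lir.fit(training)

    // PredictionModel.transform() appends a Double "prediction" column.
    model.transform(test).select("features", "prediction")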



