Return-Path:
X-Original-To: apmail-spark-commits-archive@minotaur.apache.org
Delivered-To: apmail-spark-commits-archive@minotaur.apache.org
Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by minotaur.apache.org (Postfix) with SMTP id 7074E176CF for ; Fri, 6 Feb 2015 07:44:15 +0000 (UTC)
Received: (qmail 37851 invoked by uid 500); 6 Feb 2015 07:44:15 -0000
Delivered-To: apmail-spark-commits-archive@spark.apache.org
Received: (qmail 37822 invoked by uid 500); 6 Feb 2015 07:44:15 -0000
Mailing-List: contact commits-help@spark.apache.org; run by ezmlm
Precedence: bulk
List-Help:
List-Unsubscribe:
List-Post:
List-Id:
Delivered-To: mailing list commits@spark.apache.org
Received: (qmail 37807 invoked by uid 99); 6 Feb 2015 07:44:15 -0000
Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Fri, 06 Feb 2015 07:44:15 +0000
Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id 012E2E0664; Fri, 6 Feb 2015 07:44:14 +0000 (UTC)
Content-Type: text/plain; charset="us-ascii"
MIME-Version: 1.0
Content-Transfer-Encoding: 8bit
From: meng@apache.org
To: commits@spark.apache.org
Date: Fri, 06 Feb 2015 07:44:15 -0000
Message-Id:
In-Reply-To: <51a992413db749eda921af68df6c6d3b@git.apache.org>
References: <51a992413db749eda921af68df6c6d3b@git.apache.org>
X-Mailer: ASF-Git Admin Mailer
Subject: [2/2] spark git commit: [SPARK-4789] [SPARK-4942] [SPARK-5031] [mllib] Standardize ML Prediction APIs

[SPARK-4789] [SPARK-4942] [SPARK-5031] [mllib] Standardize ML Prediction APIs

This is part (1a) of the updates from the design doc in [https://docs.google.com/document/d/1BH9el33kBX8JiDdgUJXdLW14CA2qhTCWIG46eXZVoJs]

**UPDATE**: Most of the APIs are being kept private[spark] to allow further discussion.
Here is a list of changes which are public:
* new output columns: rawPrediction, probabilities
  * The “score” column is now called “rawPrediction”
* Classifiers now provide numClasses
* Params.get and .set are now protected instead of private[ml].
* ParamMap now has a size method.
* new classes: LinearRegression, LinearRegressionModel
* LogisticRegression now has an intercept.

### Sketch of APIs (most of which are private[spark] for now)

Abstract classes for learning algorithms (+ corresponding Model abstractions):
* Classifier (+ ClassificationModel)
* ProbabilisticClassifier (+ ProbabilisticClassificationModel)
* Regressor (+ RegressionModel)
* Predictor (+ PredictionModel)
* *For all of these*:
  * There is no strongly typed training-time API.
  * There is a strongly typed test-time (prediction) API which helps developers implement new algorithms.

Concrete classes: learning algorithms
* LinearRegression
* LogisticRegression (updated to use new abstract classes)
  * Also, removed "score" in favor of "probability" output column. Changed BinaryClassificationEvaluator to match. (SPARK-5031)

Other updates:
* params.scala: Changed Params.set/get to be protected instead of private[ml]
  * This was needed for the example of defining a class from outside of the MLlib namespace.
* VectorUDT: Will later change from private[spark] to public.
  * This is needed for outside users to write their own validateAndTransformSchema() methods using vectors.
  * Also, added equals() method.
* SPARK-4942 : ML Transformers should allow output cols to be turned on, off
  * Update validateAndTransformSchema
  * Update transform
* (Updated examples, test suites according to other changes)

New examples:
* DeveloperApiExample.scala (example of defining algorithm from outside of the MLlib namespace)
  * Added Java version too

Test Suites:
* LinearRegressionSuite
* LogisticRegressionSuite
* + Java versions of above suites

CC: mengxr etrain shivaram

Author: Joseph K. Bradley

Closes #3637 from jkbradley/ml-api-part1 and squashes the following commits:

405bfb8 [Joseph K. Bradley] Last edits based on code review. Small cleanups
fec348a [Joseph K. Bradley] Added JavaDeveloperApiExample.java and fixed other issues: Made developer API private[spark] for now. Added constructors Java can understand to specialized Param types.
8316d5e [Joseph K. Bradley] fixes after rebasing on master
fc62406 [Joseph K. Bradley] fixed test suites after last commit
bcb9549 [Joseph K. Bradley] Fixed issues after rebasing from master (after move from SchemaRDD to DataFrame)
9872424 [Joseph K. Bradley] fixed JavaLinearRegressionSuite.java Java sql api
f542997 [Joseph K. Bradley] Added MIMA excludes for VectorUDT (now public), and added DeveloperApi annotation to it
216d199 [Joseph K. Bradley] fixed after sql datatypes PR got merged
f549e34 [Joseph K. Bradley] Updates based on code review. Major ones are: * Created weakly typed Predictor.train() method which is called by fit() so that developers do not have to call schema validation or copy parameters. * Made Predictor.featuresDataType have a default value of VectorUDT. * NOTE: This could be dangerous since the FeaturesType type parameter cannot have a default value.
343e7bd [Joseph K. Bradley] added blanket mima exclude for ml package
82f340b [Joseph K. Bradley] Fixed bug in LogisticRegression (introduced in this PR). Fixed Java suites
0a16da9 [Joseph K. Bradley] Fixed Linear/Logistic RegressionSuites
c3c8da5 [Joseph K. Bradley] small cleanup
934f97b [Joseph K. Bradley] Fixed bugs from previous commit.
1c61723 [Joseph K. Bradley] * Made ProbabilisticClassificationModel into a subclass of ClassificationModel. Also introduced ProbabilisticClassifier. * This was to support output column “probabilityCol” in transform().
4e2f711 [Joseph K. Bradley] rat fix
bc654e1 [Joseph K. Bradley] Added spark.ml LinearRegressionSuite
8d13233 [Joseph K. Bradley] Added methods: * Classifier: batch predictRaw() * Predictor: train() without paramMap ProbabilisticClassificationModel.predictProbabilities() * Java versions of all above batch methods + others
1680905 [Joseph K. Bradley] Added JavaLabeledPointSuite.java for spark.ml, and added constructor to LabeledPoint which defaults weight to 1.0
adbe50a [Joseph K. Bradley] * fixed LinearRegression train() to use embedded paramMap * added Predictor.predict(RDD[Vector]) method * updated Linear/LogisticRegressionSuites
58802e3 [Joseph K. Bradley] added train() to Predictor subclasses which does not take a ParamMap.
57d54ab [Joseph K. Bradley] * Changed semantics of Predictor.train() to merge the given paramMap with the embedded paramMap. * remove threshold_internal from logreg * Added Predictor.copy() * Extended LogisticRegressionSuite
e433872 [Joseph K. Bradley] Updated docs. Added LabeledPointSuite to spark.ml
54b7b31 [Joseph K. Bradley] Fixed issue with logreg threshold being set correctly
0617d61 [Joseph K. Bradley] Fixed bug from last commit (sorting paramMap by parameter names in toString). Fixed bug in persisting logreg data. Added threshold_internal to logreg for faster test-time prediction (avoiding map lookup).
601e792 [Joseph K. Bradley] Modified ParamMap to sort parameters in toString. Cleaned up classes in class hierarchy, before implementing tests and examples.
d705e87 [Joseph K. Bradley] Added LinearRegression and Regressor back from ml-api branch
52f4fde [Joseph K. Bradley] removing everything except for simple class hierarchy for classification
d35bb5d [Joseph K. Bradley] fixed compilation issues, but have not added tests yet
bfade12 [Joseph K. Bradley] Added lots of classes for new ML API:

(cherry picked from commit dc0c4490a12ecedd8ca5a1bb256c7ccbdf0be04f)
Signed-off-by: Xiangrui Meng

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/45b95e7d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/45b95e7d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/45b95e7d

Branch: refs/heads/branch-1.3
Commit: 45b95e7d23bdfcbe55473c44b1b056e4005d45b0
Parents: c35a11e
Author: Joseph K. Bradley
Authored: Thu Feb 5 23:43:47 2015 -0800
Committer: Xiangrui Meng
Committed: Thu Feb 5 23:44:02 2015 -0800

----------------------------------------------------------------------
 .../examples/ml/JavaCrossValidatorExample.java  |   6 +-
 .../examples/ml/JavaDeveloperApiExample.java    | 217 +++++++++++++++++
 .../examples/ml/JavaSimpleParamsExample.java    |  10 +-
 .../JavaSimpleTextClassificationPipeline.java   |   4 +-
 .../examples/ml/CrossValidatorExample.scala     |   7 +-
 .../spark/examples/ml/DeveloperApiExample.scala | 184 +++++++++++++++
 .../spark/examples/ml/SimpleParamsExample.scala |  16 +-
 .../ml/SimpleTextClassificationPipeline.scala   |   7 +-
 .../scala/org/apache/spark/ml/Estimator.scala   |   9 +-
 .../spark/ml/classification/Classifier.scala    | 206 ++++++++++++++++
 .../ml/classification/LogisticRegression.scala  | 212 ++++++++++-------
 .../ProbabilisticClassifier.scala               | 147 ++++++++++++
 .../BinaryClassificationEvaluator.scala         |  24 +-
 .../org/apache/spark/ml/feature/Tokenizer.scala |   4 +-
 .../spark/ml/impl/estimator/Predictor.scala     | 234 +++++++++++++++++++
 .../org/apache/spark/ml/param/params.scala      |  68 +++++-
 .../apache/spark/ml/param/sharedParams.scala    |  28 ++-
 .../spark/ml/regression/LinearRegression.scala  |  96 ++++++++
 .../apache/spark/ml/regression/Regressor.scala  |  78 +++++++
 .../org/apache/spark/mllib/linalg/Vectors.scala |  13 ++
 .../org/apache/spark/ml/JavaPipelineSuite.java  |   2 +-
 .../JavaLogisticRegressionSuite.java            |  91 +++++++-
 .../regression/JavaLinearRegressionSuite.java   |  89 +++++++
 .../LogisticRegressionSuite.scala               |  86 ++++++-
 .../ml/regression/LinearRegressionSuite.scala   |  65 ++++++
 project/MimaExcludes.scala                      |   6 +
 26 files changed, 1753 insertions(+), 156 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
index 0fbee6e..5041e0b 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
@@ -116,10 +116,12 @@ public class JavaCrossValidatorExample {

     // Make predictions on test documents. cvModel uses the best model found (lrModel).
     cvModel.transform(test).registerTempTable("prediction");
-    DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
+    DataFrame predictions = jsql.sql("SELECT id, text, probability, prediction FROM prediction");
     for (Row r: predictions.collect()) {
-      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
           + ", prediction=" + r.get(3));
     }
+
+    jsc.stop();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
new file mode 100644
index 0000000..42d4d7d
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.classification.Classifier;
+import org.apache.spark.ml.classification.ClassificationModel;
+import org.apache.spark.ml.param.IntParam;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.param.Params;
+import org.apache.spark.ml.param.Params$;
+import org.apache.spark.mllib.linalg.BLAS;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+
+
+/**
+ * A simple example demonstrating how to write your own learning algorithm using Estimator,
+ * Transformer, and other abstractions.
+ * This mimics {@link org.apache.spark.ml.classification.LogisticRegression}.
+ *
+ * Run with
+ * <pre>
+ * bin/run-example ml.JavaDeveloperApiExample
+ * </pre>
+ */
+public class JavaDeveloperApiExample {
+
+  public static void main(String[] args) throws Exception {
+    SparkConf conf = new SparkConf().setAppName("JavaDeveloperApiExample");
+    JavaSparkContext jsc = new JavaSparkContext(conf);
+    SQLContext jsql = new SQLContext(jsc);
+
+    // Prepare training data.
+    List<LabeledPoint> localTraining = Lists.newArrayList(
+        new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
+        new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
+        new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
+        new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
+    DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
+
+    // Create a LogisticRegression instance. This instance is an Estimator.
+    MyJavaLogisticRegression lr = new MyJavaLogisticRegression();
+    // Print out the parameters, documentation, and any default values.
+    System.out.println("MyJavaLogisticRegression parameters:\n" + lr.explainParams() + "\n");
+
+    // We may set parameters using setter methods.
+    lr.setMaxIter(10);
+
+    // Learn a LogisticRegression model. This uses the parameters stored in lr.
+    MyJavaLogisticRegressionModel model = lr.fit(training);
+
+    // Prepare test data.
+    List<LabeledPoint> localTest = Lists.newArrayList(
+        new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
+        new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
+        new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
+    DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
+
+    // Make predictions on test data.
+    DataFrame results = model.transform(test);
+    double sumPredictions = 0;
+    for (Row r : results.select("features", "label", "prediction").collect()) {
+      sumPredictions += r.getDouble(2);
+    }
+    if (sumPredictions != 0.0) {
+      throw new Exception("MyJavaLogisticRegression predicted something other than 0," +
+          " even though all weights are 0!");
+    }
+
+    jsc.stop();
+  }
+}
+
+/**
+ * Example of defining a type of {@link Classifier}.
+ *
+ * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ */
+class MyJavaLogisticRegression
+    extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel>
+    implements Params {
+
+  /**
+   * Param for max number of iterations
+   * <p/>
+   * NOTE: The usual way to add a parameter to a model or algorithm is to include:
+   * - val myParamName: ParamType
+   * - def getMyParamName
+   * - def setMyParamName
+   */
+  IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations");
+
+  int getMaxIter() { return (int)get(maxIter); }
+
+  public MyJavaLogisticRegression() {
+    setMaxIter(100);
+  }
+
+  // The parameter setter is in this class since it should return type MyJavaLogisticRegression.
+  MyJavaLogisticRegression setMaxIter(int value) {
+    return (MyJavaLogisticRegression)set(maxIter, value);
+  }
+
+  // This method is used by fit().
+  // In Java, we have to make it public since Java does not understand Scala's protected modifier.
+  public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap) {
+    // Extract columns from data using helper method.
+    JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset, paramMap).toJavaRDD();
+
+    // Do learning to estimate the weight vector.
+    int numFeatures = oldDataset.take(1).get(0).features().size();
+    Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.
+
+    // Create a model, and return it.
+    return new MyJavaLogisticRegressionModel(this, paramMap, weights);
+  }
+}
+
+/**
+ * Example of defining a type of {@link ClassificationModel}.
+ *
+ * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ */
+class MyJavaLogisticRegressionModel
+    extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> implements Params {
+
+  private MyJavaLogisticRegression parent_;
+  public MyJavaLogisticRegression parent() { return parent_; }
+
+  private ParamMap fittingParamMap_;
+  public ParamMap fittingParamMap() { return fittingParamMap_; }
+
+  private Vector weights_;
+  public Vector weights() { return weights_; }
+
+  public MyJavaLogisticRegressionModel(
+      MyJavaLogisticRegression parent_,
+      ParamMap fittingParamMap_,
+      Vector weights_) {
+    this.parent_ = parent_;
+    this.fittingParamMap_ = fittingParamMap_;
+    this.weights_ = weights_;
+  }
+
+  // This uses the default implementation of transform(), which reads column "features" and outputs
+  // columns "prediction" and "rawPrediction."
+
+  // This uses the default implementation of predict(), which chooses the label corresponding to
+  // the maximum value returned by [[predictRaw()]].
+
+  /**
+   * Raw prediction for each possible label.
+   * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
+   * a measure of confidence in each possible label (where larger = more confident).
+   * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]].
+   *
+   * @return vector where element i is the raw prediction for label i.
+   *         This raw prediction may be any real number, where a larger value indicates greater
+   *         confidence for that label.
+   *
+   * In Java, we have to make this method public since Java does not understand Scala's protected
+   * modifier.
+   */
+  public Vector predictRaw(Vector features) {
+    double margin = BLAS.dot(features, weights_);
+    // There are 2 classes (binary classification), so we return a length-2 vector,
+    // where index i corresponds to class i (i = 0, 1).
+    return Vectors.dense(-margin, margin);
+  }
+
+  /**
+   * Number of classes the label can take. 2 indicates binary classification.
+   */
+  public int numClasses() { return 2; }
+
+  /**
+   * Create a copy of the model.
+   * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
+   * <p/>
+   * This is used for the default implementation of [[transform()]].
+   *
+   * In Java, we have to make this method public since Java does not understand Scala's protected
+   * modifier.
+   */
+  public MyJavaLogisticRegressionModel copy() {
+    MyJavaLogisticRegressionModel m =
+        new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_);
+    Params$.MODULE$.inheritValues(this.paramMap(), this, m);
+    return m;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index eaaa344..cc69e63 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -81,7 +81,7 @@ public class JavaSimpleParamsExample {

     // One can also combine ParamMaps.
     ParamMap paramMap2 = new ParamMap();
-    paramMap2.put(lr.scoreCol().w("probability")); // Change output column name
+    paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name
     ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);

     // Now learn a new model using the paramMapCombined parameters.
@@ -98,14 +98,16 @@ public class JavaSimpleParamsExample {

     // Make predictions on test documents using the Transformer.transform() method.
     // LogisticRegression.transform will only use the 'features' column.
-    // Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
-    // column since we renamed the lr.scoreCol parameter previously.
+    // Note that model2.transform() outputs a 'myProbability' column instead of the usual
+    // 'probability' column since we renamed the lr.probabilityCol parameter previously.
     model2.transform(test).registerTempTable("results");
     DataFrame results =
-        jsql.sql("SELECT features, label, probability, prediction FROM results");
+        jsql.sql("SELECT features, label, myProbability, prediction FROM results");
     for (Row r: results.collect()) {
       System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
           + ", prediction=" + r.get(3));
     }
+
+    jsc.stop();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
index 82d665a..d929f1a 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -85,8 +85,10 @@ public class JavaSimpleTextClassificationPipeline {
     model.transform(test).registerTempTable("prediction");
     DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
     for (Row r: predictions.collect()) {
-      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+      System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
           + ", prediction=" + r.get(3));
     }
+
+    jsc.stop();
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
index b6c30a0..a2893f7 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
@@ -23,6 +23,7 @@ import org.apache.spark.ml.classification.LogisticRegression
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
 import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
 import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.sql.{Row, SQLContext}

 /**
@@ -100,10 +101,10 @@ object CrossValidatorExample {

     // Make predictions on test documents. cvModel uses the best model found (lrModel).
     cvModel.transform(test)
-      .select("id", "text", "score", "prediction")
+      .select("id", "text", "probability", "prediction")
       .collect()
-      .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
-        println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
+      .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
+        println(s"($id, $text) --> prob=$prob, prediction=$prediction")
       }

     sc.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
new file mode 100644
index 0000000..aed4423
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel}
+import org.apache.spark.ml.param.{Params, IntParam, ParamMap}
+import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+
+/**
+ * A simple example demonstrating how to write your own learning algorithm using Estimator,
+ * Transformer, and other abstractions.
+ * This mimics [[org.apache.spark.ml.classification.LogisticRegression]].
+ * Run with
+ * {{{
+ * bin/run-example ml.DeveloperApiExample
+ * }}}
+ */
+object DeveloperApiExample {
+
+  def main(args: Array[String]) {
+    val conf = new SparkConf().setAppName("DeveloperApiExample")
+    val sc = new SparkContext(conf)
+    val sqlContext = new SQLContext(sc)
+    import sqlContext.implicits._
+
+    // Prepare training data.
+    val training = sc.parallelize(Seq(
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
+      LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))))
+
+    // Create a LogisticRegression instance. This instance is an Estimator.
+    val lr = new MyLogisticRegression()
+    // Print out the parameters, documentation, and any default values.
+    println("MyLogisticRegression parameters:\n" + lr.explainParams() + "\n")
+
+    // We may set parameters using setter methods.
+    lr.setMaxIter(10)
+
+    // Learn a LogisticRegression model. This uses the parameters stored in lr.
+    val model = lr.fit(training)
+
+    // Prepare test data.
+    val test = sc.parallelize(Seq(
+      LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
+      LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
+      LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
+
+    // Make predictions on test data.
+    val sumPredictions: Double = model.transform(test)
+      .select("features", "label", "prediction")
+      .collect()
+      .map { case Row(features: Vector, label: Double, prediction: Double) =>
+        prediction
+      }.sum
+    assert(sumPredictions == 0.0,
+      "MyLogisticRegression predicted something other than 0, even though all weights are 0!")
+
+    sc.stop()
+  }
+}
+
+/**
+ * Example of defining a parameter trait for a user-defined type of [[Classifier]].
+ *
+ * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ */
+private trait MyLogisticRegressionParams extends ClassifierParams {
+
+  /**
+   * Param for max number of iterations
+   *
+   * NOTE: The usual way to add a parameter to a model or algorithm is to include:
+   * - val myParamName: ParamType
+   * - def getMyParamName
+   * - def setMyParamName
+   * Here, we have a trait to be mixed in with the Estimator and Model (MyLogisticRegression
+   * and MyLogisticRegressionModel). We place the setter (setMaxIter) method in the Estimator
+   * class since the maxIter parameter is only used during training (not in the Model).
+   */
+  val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
+  def getMaxIter: Int = get(maxIter)
+}
+
+/**
+ * Example of defining a type of [[Classifier]].
+ *
+ * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ */
+private class MyLogisticRegression
+  extends Classifier[Vector, MyLogisticRegression, MyLogisticRegressionModel]
+  with MyLogisticRegressionParams {
+
+  setMaxIter(100) // Initialize
+
+  // The parameter setter is in this class since it should return type MyLogisticRegression.
+  def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+  // This method is used by fit()
+  override protected def train(
+      dataset: DataFrame,
+      paramMap: ParamMap): MyLogisticRegressionModel = {
+    // Extract columns from data using helper method.
+    val oldDataset = extractLabeledPoints(dataset, paramMap)
+
+    // Do learning to estimate the weight vector.
+    val numFeatures = oldDataset.take(1)(0).features.size
+    val weights = Vectors.zeros(numFeatures) // Learning would happen here.
+
+    // Create a model, and return it.
+    new MyLogisticRegressionModel(this, paramMap, weights)
+  }
+}
+
+/**
+ * Example of defining a type of [[ClassificationModel]].
+ *
+ * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ */
+private class MyLogisticRegressionModel(
+    override val parent: MyLogisticRegression,
+    override val fittingParamMap: ParamMap,
+    val weights: Vector)
+  extends ClassificationModel[Vector, MyLogisticRegressionModel]
+  with MyLogisticRegressionParams {
+
+  // This uses the default implementation of transform(), which reads column "features" and outputs
+  // columns "prediction" and "rawPrediction."
+
+  // This uses the default implementation of predict(), which chooses the label corresponding to
+  // the maximum value returned by [[predictRaw()]].
+
+  /**
+   * Raw prediction for each possible label.
+   * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
+   * a measure of confidence in each possible label (where larger = more confident).
+   * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]].
+   *
+   * @return vector where element i is the raw prediction for label i.
+   *         This raw prediction may be any real number, where a larger value indicates greater
+   *         confidence for that label.
+   */
+  override protected def predictRaw(features: Vector): Vector = {
+    val margin = BLAS.dot(features, weights)
+    // There are 2 classes (binary classification), so we return a length-2 vector,
+    // where index i corresponds to class i (i = 0, 1).
+    Vectors.dense(-margin, margin)
+  }
+
+  /** Number of classes the label can take. 2 indicates binary classification. */
+  override val numClasses: Int = 2
+
+  /**
+   * Create a copy of the model.
+   * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
+   *
+   * This is used for the default implementation of [[transform()]].
+   */
+  override protected def copy(): MyLogisticRegressionModel = {
+    val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights)
+    Params.inheritValues(this.paramMap, this, m)
+    m
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index 4d1530c..80c9f5f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -72,7 +72,7 @@ object SimpleParamsExample {
     paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params.

     // One can also combine ParamMaps.
- val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Change output column name + val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name val paramMapCombined = paramMap ++ paramMap2 // Now learn a new model using the paramMapCombined parameters. @@ -80,21 +80,21 @@ object SimpleParamsExample { val model2 = lr.fit(training, paramMapCombined) println("Model 2 was fit using parameters: " + model2.fittingParamMap) - // Prepare test documents. + // Prepare test data. val test = sc.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)))) - // Make predictions on test documents using the Transformer.transform() method. + // Make predictions on test data using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. - // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' - // column since we renamed the lr.scoreCol parameter previously. + // Note that model2.transform() outputs a 'myProbability' column instead of the usual + // 'probability' column since we renamed the lr.probabilityCol parameter previously. 
model2.transform(test) - .select("features", "label", "probability", "prediction") + .select("features", "label", "myProbability", "prediction") .collect() - .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) => - println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction) + .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) => + println("($features, $label) -> prob=$prob, prediction=$prediction") } sc.stop() http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala ---------------------------------------------------------------------- diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index dbbe01d..968cb29 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -23,6 +23,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} @BeanInfo @@ -79,10 +80,10 @@ object SimpleTextClassificationPipeline { // Make predictions on test documents. 
model.transform(test) - .select("id", "text", "score", "prediction") + .select("id", "text", "probability", "prediction") .collect() - .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => - println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) + .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => + println("($id, $text) --> prob=$prob, prediction=$prediction") } sc.stop() http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index bc3defe..eff7ef9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -34,7 +34,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * Fits a single model to the input data with optional parameters. * * @param dataset input dataset - * @param paramPairs optional list of param pairs (overwrite embedded params) + * @param paramPairs Optional list of param pairs. + * These values override any specified in this Estimator's embedded ParamMap. * @return fitted model */ @varargs @@ -47,7 +48,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * Fits a single model to the input data with provided parameter map. * * @param dataset input dataset - * @param paramMap parameter map + * @param paramMap Parameter map. + * These values override any specified in this Estimator's embedded ParamMap. * @return fitted model */ def fit(dataset: DataFrame, paramMap: ParamMap): M @@ -58,7 +60,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * Subclasses could overwrite this to optimize multi-model training. 
* * @param dataset input dataset - * @param paramMaps an array of parameter maps + * @param paramMaps An array of parameter maps. + * These values override any specified in this Estimator's embedded ParamMap. * @return fitted models, matching the input parameter maps */ def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = { http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala new file mode 100644 index 0000000..1bf8eb4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.{DeveloperApi, AlphaComponent} +import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams} +import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} + + +/** + * :: DeveloperApi :: + * Params for classification. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@DeveloperApi +private[spark] trait ClassifierParams extends PredictorParams + with HasRawPredictionCol { + + override protected def validateAndTransformSchema( + schema: StructType, + paramMap: ParamMap, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) + val map = this.paramMap ++ paramMap + addOutputColumn(parentSchema, map(rawPredictionCol), new VectorUDT) + } +} + +/** + * :: AlphaComponent :: + * Single-label binary or multiclass classification. + * Classes are indexed {0, 1, ..., numClasses - 1}. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam E Concrete Estimator type + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class Classifier[ + FeaturesType, + E <: Classifier[FeaturesType, E, M], + M <: ClassificationModel[FeaturesType, M]] + extends Predictor[FeaturesType, E, M] + with ClassifierParams { + + def setRawPredictionCol(value: String): E = + set(rawPredictionCol, value).asInstanceOf[E] + + // TODO: defaultEvaluator (follow-up PR) +} + +/** + * :: AlphaComponent :: + * Model produced by a [[Classifier]]. 
+ * Classes are indexed {0, 1, ..., numClasses - 1}. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] +abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]] + extends PredictionModel[FeaturesType, M] with ClassifierParams { + + def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M] + + /** Number of classes (values which the label can take). */ + def numClasses: Int + + /** + * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by + * parameters: + * - predicted labels as [[predictionCol]] of type [[Double]] + * - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]]. + * + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset + */ + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + // This default implementation should be overridden as needed. + + // Check schema + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + + // Prepare model + val tmpModel = if (paramMap.size != 0) { + val tmpModel = this.copy() + Params.inheritValues(paramMap, parent, tmpModel) + tmpModel + } else { + this + } + + val (numColsOutput, outputData) = + ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map) + if (numColsOutput == 0) { + logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" + + " since no output columns were set.") + } + outputData + } + + /** + * :: DeveloperApi :: + * + * Predict label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. 
+ * + * This default implementation for classification predicts the index of the maximum value + * from [[predictRaw()]]. + */ + @DeveloperApi + override protected def predict(features: FeaturesType): Double = { + predictRaw(features).toArray.zipWithIndex.maxBy(_._1)._2 + } + + /** + * :: DeveloperApi :: + * + * Raw prediction for each possible label. + * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives + * a measure of confidence in each possible label (where larger = more confident). + * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]]. + * + * @return vector where element i is the raw prediction for label i. + * This raw prediction may be any real number, where a larger value indicates greater + * confidence for that label. + */ + @DeveloperApi + protected def predictRaw(features: FeaturesType): Vector + +} + +private[ml] object ClassificationModel { + + /** + * Added prediction column(s). This is separated from [[ClassificationModel.transform()]] + * since it is used by [[org.apache.spark.ml.classification.ProbabilisticClassificationModel]]. + * @param dataset Input dataset + * @param map Parameter map. This will NOT be merged with the embedded paramMap; the merge + * should already be done. + * @return (number of columns added, transformed dataset) + */ + def transformColumnsImpl[FeaturesType]( + dataset: DataFrame, + model: ClassificationModel[FeaturesType, _], + map: ParamMap): (Int, DataFrame) = { + + // Output selected columns only. + // This is a bit complicated since it tries to avoid repeated computation. 
+ var tmpData = dataset + var numColsOutput = 0 + if (map(model.rawPredictionCol) != "") { + // output raw prediction + val features2raw: FeaturesType => Vector = model.predictRaw + tmpData = tmpData.select($"*", + callUDF(features2raw, new VectorUDT, + col(map(model.featuresCol))).as(map(model.rawPredictionCol))) + numColsOutput += 1 + if (map(model.predictionCol) != "") { + val raw2pred: Vector => Double = (rawPred) => { + rawPred.toArray.zipWithIndex.maxBy(_._1)._2 + } + tmpData = tmpData.select($"*", callUDF(raw2pred, DoubleType, + col(map(model.rawPredictionCol))).as(map(model.predictionCol))) + numColsOutput += 1 + } + } else if (map(model.predictionCol) != "") { + // output prediction + val features2pred: FeaturesType => Double = model.predict + tmpData = tmpData.select($"*", + callUDF(features2pred, DoubleType, + col(map(model.featuresCol))).as(map(model.predictionCol))) + numColsOutput += 1 + } + (numColsOutput, tmpData) + } + +} http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index b46a5cd..c146fe2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -18,61 +18,32 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS -import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT} -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql._ +import org.apache.spark.mllib.linalg.{VectorUDT, 
BLAS, Vector, Vectors} +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Dsl._ -import org.apache.spark.sql.types.{DoubleType, StructField, StructType} +import org.apache.spark.sql.types.DoubleType import org.apache.spark.storage.StorageLevel + /** - * :: AlphaComponent :: * Params for logistic regression. */ -@AlphaComponent -private[classification] trait LogisticRegressionParams extends Params - with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol - with HasScoreCol with HasPredictionCol { +private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams + with HasRegParam with HasMaxIter with HasThreshold - /** - * Validates and transforms the input schema with the provided param map. - * @param schema input schema - * @param paramMap additional parameters - * @param fitting whether this is in fitting - * @return output schema - */ - protected def validateAndTransformSchema( - schema: StructType, - paramMap: ParamMap, - fitting: Boolean): StructType = { - val map = this.paramMap ++ paramMap - val featuresType = schema(map(featuresCol)).dataType - // TODO: Support casting Array[Double] and Array[Float] to Vector. 
- require(featuresType.isInstanceOf[VectorUDT], - s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.") - if (fitting) { - val labelType = schema(map(labelCol)).dataType - require(labelType == DoubleType, - s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.") - } - val fieldNames = schema.fieldNames - require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.") - require(!fieldNames.contains(map(predictionCol)), - s"Prediction column ${map(predictionCol)} already exists.") - val outputFields = schema.fields ++ Seq( - StructField(map(scoreCol), DoubleType, false), - StructField(map(predictionCol), DoubleType, false)) - StructType(outputFields) - } -} /** + * :: AlphaComponent :: + * * Logistic regression. + * Currently, this class only supports binary classification. */ -class LogisticRegression extends Estimator[LogisticRegressionModel] with LogisticRegressionParams { +@AlphaComponent +class LogisticRegression + extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel] + with LogisticRegressionParams { setRegParam(0.1) setMaxIter(100) @@ -80,68 +51,151 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti def setRegParam(value: Double): this.type = set(regParam, value) def setMaxIter(value: Int): this.type = set(maxIter, value) - def setLabelCol(value: String): this.type = set(labelCol, value) def setThreshold(value: Double): this.type = set(threshold, value) - def setFeaturesCol(value: String): this.type = set(featuresCol, value) - def setScoreCol(value: String): this.type = set(scoreCol, value) - def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def fit(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = { - transformSchema(dataset.schema, paramMap, logging = true) - val map = this.paramMap ++ paramMap - val instances = dataset.select(map(labelCol), 
map(featuresCol)) - .map { case Row(label: Double, features: Vector) => - LabeledPoint(label, features) - }.persist(StorageLevel.MEMORY_AND_DISK) + override protected def train(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = { + // Extract columns from data. If dataset is persisted, do not persist oldDataset. + val oldDataset = extractLabeledPoints(dataset, paramMap) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) { + oldDataset.persist(StorageLevel.MEMORY_AND_DISK) + } + + // Train model val lr = new LogisticRegressionWithLBFGS lr.optimizer - .setRegParam(map(regParam)) - .setNumIterations(map(maxIter)) - val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights) - instances.unpersist() - // copy model params - Params.inheritValues(map, this, lrm) - lrm - } + .setRegParam(paramMap(regParam)) + .setNumIterations(paramMap(maxIter)) + val oldModel = lr.run(oldDataset) + val lrm = new LogisticRegressionModel(this, paramMap, oldModel.weights, oldModel.intercept) - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap, fitting = true) + if (handlePersistence) { + oldDataset.unpersist() + } + lrm } } + /** * :: AlphaComponent :: + * * Model produced by [[LogisticRegression]]. 
*/ @AlphaComponent class LogisticRegressionModel private[ml] ( override val parent: LogisticRegression, override val fittingParamMap: ParamMap, - weights: Vector) - extends Model[LogisticRegressionModel] with LogisticRegressionParams { + val weights: Vector, + val intercept: Double) + extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] + with LogisticRegressionParams { + + setThreshold(0.5) def setThreshold(value: Double): this.type = set(threshold, value) - def setFeaturesCol(value: String): this.type = set(featuresCol, value) - def setScoreCol(value: String): this.type = set(scoreCol, value) - def setPredictionCol(value: String): this.type = set(predictionCol, value) - private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { - validateAndTransformSchema(schema, paramMap, fitting = false) + private val margin: Vector => Double = (features) => { + BLAS.dot(features, weights) + intercept + } + + private val score: Vector => Double = (features) => { + val m = margin(features) + 1.0 / (1.0 + math.exp(-m)) } override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + // This is overridden (a) to be more efficient (avoiding re-computing values when creating + // multiple output columns) and (b) to handle threshold, which the abstractions do not use. + // TODO: We should abstract away the steps defined by UDFs below so that the abstractions + // can call whichever UDFs are needed to create the output columns. + + // Check schema transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap - val scoreFunction = udf { v: Vector => - val margin = BLAS.dot(v, weights) - 1.0 / (1.0 + math.exp(-margin)) + + // Output selected columns only. + // This is a bit complicated since it tries to avoid repeated computation. 
+ // rawPrediction (-margin, margin) + // probability (1.0-score, score) + // prediction (max margin) + var tmpData = dataset + var numColsOutput = 0 + if (map(rawPredictionCol) != "") { + val features2raw: Vector => Vector = (features) => predictRaw(features) + tmpData = tmpData.select($"*", + callUDF(features2raw, new VectorUDT, col(map(featuresCol))).as(map(rawPredictionCol))) + numColsOutput += 1 + } + if (map(probabilityCol) != "") { + if (map(rawPredictionCol) != "") { + val raw2prob: Vector => Vector = { (rawPreds: Vector) => + val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) + Vectors.dense(1.0 - prob1, prob1) + } + tmpData = tmpData.select($"*", + callUDF(raw2prob, new VectorUDT, col(map(rawPredictionCol))).as(map(probabilityCol))) + } else { + val features2prob: Vector => Vector = (features: Vector) => predictProbabilities(features) + tmpData = tmpData.select($"*", + callUDF(features2prob, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol))) + } + numColsOutput += 1 } - val t = map(threshold) - val predictFunction = udf { score: Double => - if (score > t) 1.0 else 0.0 + if (map(predictionCol) != "") { + val t = map(threshold) + if (map(probabilityCol) != "") { + val predict: Vector => Double = { probs: Vector => + if (probs(1) > t) 1.0 else 0.0 + } + tmpData = tmpData.select($"*", + callUDF(predict, DoubleType, col(map(probabilityCol))).as(map(predictionCol))) + } else if (map(rawPredictionCol) != "") { + val predict: Vector => Double = { rawPreds: Vector => + val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1))) + if (prob1 > t) 1.0 else 0.0 + } + tmpData = tmpData.select($"*", + callUDF(predict, DoubleType, col(map(rawPredictionCol))).as(map(predictionCol))) + } else { + val predict: Vector => Double = (features: Vector) => this.predict(features) + tmpData = tmpData.select($"*", + callUDF(predict, DoubleType, col(map(featuresCol))).as(map(predictionCol))) + } + numColsOutput += 1 } - dataset - .select($"*", 
scoreFunction(col(map(featuresCol))).as(map(scoreCol))) - .select($"*", predictFunction(col(map(scoreCol))).as(map(predictionCol))) + if (numColsOutput == 0) { + this.logWarning(s"$uid: LogisticRegressionModel.transform() was called as NOOP" + + " since no output columns were set.") + } + tmpData + } + + override val numClasses: Int = 2 + + /** + * Predict label for the given feature vector. + * The behavior of this can be adjusted using [[threshold]]. + */ + override protected def predict(features: Vector): Double = { + println(s"LR.predict with threshold: ${paramMap(threshold)}") + if (score(features) > paramMap(threshold)) 1 else 0 + } + + override protected def predictProbabilities(features: Vector): Vector = { + val s = score(features) + Vectors.dense(1.0 - s, s) + } + + override protected def predictRaw(features: Vector): Vector = { + val m = margin(features) + Vectors.dense(0.0, m) + } + + override protected def copy(): LogisticRegressionModel = { + val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept) + Params.inheritValues(this.paramMap, this, m) + m } } http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala new file mode 100644 index 0000000..1202528 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.types.{DataType, StructType} + + +/** + * Params for probabilistic classification. + */ +private[classification] trait ProbabilisticClassifierParams + extends ClassifierParams with HasProbabilityCol { + + override protected def validateAndTransformSchema( + schema: StructType, + paramMap: ParamMap, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType) + val map = this.paramMap ++ paramMap + addOutputColumn(parentSchema, map(probabilityCol), new VectorUDT) + } +} + + +/** + * :: AlphaComponent :: + * + * Single-label binary or multiclass classifier which can output class conditional probabilities. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam E Concrete Estimator type + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. 
+ */ +@AlphaComponent +private[spark] abstract class ProbabilisticClassifier[ + FeaturesType, + E <: ProbabilisticClassifier[FeaturesType, E, M], + M <: ProbabilisticClassificationModel[FeaturesType, M]] + extends Classifier[FeaturesType, E, M] with ProbabilisticClassifierParams { + + def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E] +} + + +/** + * :: AlphaComponent :: + * + * Model produced by a [[ProbabilisticClassifier]]. + * Classes are indexed {0, 1, ..., numClasses - 1}. + * + * @tparam FeaturesType Type of input features. E.g., [[Vector]] + * @tparam M Concrete Model type + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class ProbabilisticClassificationModel[ + FeaturesType, + M <: ProbabilisticClassificationModel[FeaturesType, M]] + extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams { + + def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M] + + /** + * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by + * parameters: + * - predicted labels as [[predictionCol]] of type [[Double]] + * - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]] + * - probability of each class as [[probabilityCol]] of type [[Vector]]. + * + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset + */ + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + // This default implementation should be overridden as needed. 
+ + // Check schema + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + + // Prepare model + val tmpModel = if (paramMap.size != 0) { + val tmpModel = this.copy() + Params.inheritValues(paramMap, parent, tmpModel) + tmpModel + } else { + this + } + + val (numColsOutput, outputData) = + ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map) + + // Output selected columns only. + if (map(probabilityCol) != "") { + // output probabilities + val features2probs: FeaturesType => Vector = (features) => { + tmpModel.predictProbabilities(features) + } + outputData.select($"*", + callUDF(features2probs, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol))) + } else { + if (numColsOutput == 0) { + this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" + + " since no output columns were set.") + } + outputData + } + } + + /** + * :: DeveloperApi :: + * + * Predict the probability of each class given the features. + * These predictions are also called class conditional probabilities. + * + * WARNING: Not all models output well-calibrated probability estimates! These probabilities + * should be treated as confidences, not precise probabilities. + * + * This internal method is used to implement [[transform()]] and output [[probabilityCol]]. 
+ */ + @DeveloperApi + protected def predictProbabilities(features: FeaturesType): Vector +} http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 1979ab9..f21a306 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -18,19 +18,22 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml._ +import org.apache.spark.ml.Evaluator import org.apache.spark.ml.param._ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType + /** * :: AlphaComponent :: + * * Evaluator for binary classification, which expects two input columns: score and label. 
*/ @AlphaComponent class BinaryClassificationEvaluator extends Evaluator with Params - with HasScoreCol with HasLabelCol { + with HasRawPredictionCol with HasLabelCol { /** param for metric name in evaluation */ val metricName: Param[String] = new Param(this, "metricName", @@ -38,23 +41,20 @@ class BinaryClassificationEvaluator extends Evaluator with Params def getMetricName: String = get(metricName) def setMetricName(value: String): this.type = set(metricName, value) - def setScoreCol(value: String): this.type = set(scoreCol, value) + def setScoreCol(value: String): this.type = set(rawPredictionCol, value) def setLabelCol(value: String): this.type = set(labelCol, value) override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { val map = this.paramMap ++ paramMap val schema = dataset.schema - val scoreType = schema(map(scoreCol)).dataType - require(scoreType == DoubleType, - s"Score column ${map(scoreCol)} must be double type but found $scoreType") - val labelType = schema(map(labelCol)).dataType - require(labelType == DoubleType, - s"Label column ${map(labelCol)} must be double type but found $labelType") + checkInputColumn(schema, map(rawPredictionCol), new VectorUDT) + checkInputColumn(schema, map(labelCol), DoubleType) - val scoreAndLabels = dataset.select(map(scoreCol), map(labelCol)) - .map { case Row(score: Double, label: Double) => - (score, label) + // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2. 
+ val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol)) + .map { case Row(rawPrediction: Vector, label: Double) => + (rawPrediction(1), label) } val metrics = new BinaryClassificationMetrics(scoreAndLabels) val metric = map(metricName) match { http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index e622a5c..0b1f90d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -29,11 +29,11 @@ import org.apache.spark.sql.types.{DataType, StringType, ArrayType} @AlphaComponent class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { - protected override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { + override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { _.toLowerCase.split("\\s") } - protected override def validateInputType(inputType: DataType): Unit = { + override protected def validateInputType(inputType: DataType): Unit = { require(inputType == StringType, s"Input type must be string type but got $inputType.") } http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala new file mode 100644 index 0000000..89b53f3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.impl.estimator + +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.linalg.{VectorUDT, Vector} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} + + +/** + * :: DeveloperApi :: + * + * Trait for parameters for prediction (regression and classification). + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@DeveloperApi +private[spark] trait PredictorParams extends Params + with HasLabelCol with HasFeaturesCol with HasPredictionCol { + + /** + * Validates and transforms the input schema with the provided param map. + * @param schema input schema + * @param paramMap additional parameters + * @param fitting whether this is in fitting + * @param featuresDataType SQL DataType for FeaturesType. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. 
+ * @return output schema + */ + protected def validateAndTransformSchema( + schema: StructType, + paramMap: ParamMap, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val map = this.paramMap ++ paramMap + // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector + checkInputColumn(schema, map(featuresCol), featuresDataType) + if (fitting) { + // TODO: Allow other numeric types + checkInputColumn(schema, map(labelCol), DoubleType) + } + addOutputColumn(schema, map(predictionCol), DoubleType) + } +} + +/** + * :: AlphaComponent :: + * + * Abstraction for prediction problems (regression and classification). + * + * @tparam FeaturesType Type of features. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. + * @tparam Learner Specialization of this class. If you subclass this type, use this type + * parameter to specify the concrete type. + * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type + * parameter to specify the concrete type for the corresponding model. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class Predictor[ + FeaturesType, + Learner <: Predictor[FeaturesType, Learner, M], + M <: PredictionModel[FeaturesType, M]] + extends Estimator[M] with PredictorParams { + + def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner] + def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner] + def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner] + + override def fit(dataset: DataFrame, paramMap: ParamMap): M = { + // This handles a few items such as schema validation. + // Developers only need to implement train(). 
+ transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + val model = train(dataset, map) + Params.inheritValues(map, this, model) // copy params to model + model + } + + /** + * :: DeveloperApi :: + * + * Train a model using the given dataset and parameters. + * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation + * and copying parameters into the model. + * + * @param dataset Training dataset + * @param paramMap Parameter map. Unlike [[fit()]]'s paramMap, this paramMap has already + * been combined with the embedded ParamMap. + * @return Fitted model + */ + @DeveloperApi + protected def train(dataset: DataFrame, paramMap: ParamMap): M + + /** + * :: DeveloperApi :: + * + * Returns the SQL DataType corresponding to the FeaturesType type parameter. + * + * This is used by [[validateAndTransformSchema()]]. + * This workaround is needed since SQL has different APIs for Scala and Java. + * + * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. + */ + @DeveloperApi + protected def featuresDataType: DataType = new VectorUDT + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType) + } + + /** + * Extract [[labelCol]] and [[featuresCol]] from the given dataset, + * and put it in an RDD with strong types. + */ + protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = { + val map = this.paramMap ++ paramMap + dataset.select(map(labelCol), map(featuresCol)) + .map { case Row(label: Double, features: Vector) => + LabeledPoint(label, features) + } + } +} + +/** + * :: AlphaComponent :: + * + * Abstraction for a model for prediction tasks (regression and classification). + * + * @tparam FeaturesType Type of features. + * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features. 
+ * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type + * parameter to specify the concrete type for the corresponding model. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@AlphaComponent +private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]] + extends Model[M] with PredictorParams { + + def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M] + + def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M] + + /** + * :: DeveloperApi :: + * + * Returns the SQL DataType corresponding to the FeaturesType type parameter. + * + * This is used by [[validateAndTransformSchema()]]. + * This workaround is needed since SQL has different APIs for Scala and Java. + * + * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector. + */ + @DeveloperApi + protected def featuresDataType: DataType = new VectorUDT + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType) + } + + /** + * Transforms dataset by reading from [[featuresCol]], calling [[predict()]], and storing + * the predictions as a new column [[predictionCol]]. + * + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset with [[predictionCol]] of type [[Double]] + */ + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + // This default implementation should be overridden as needed. 
+ + // Check schema + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + + // Prepare model + val tmpModel = if (paramMap.size != 0) { + val tmpModel = this.copy() + Params.inheritValues(paramMap, parent, tmpModel) + tmpModel + } else { + this + } + + if (map(predictionCol) != "") { + val pred: FeaturesType => Double = (features) => { + tmpModel.predict(features) + } + dataset.select($"*", callUDF(pred, DoubleType, col(map(featuresCol))).as(map(predictionCol))) + } else { + this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + + " since no output columns were set.") + dataset + } + } + + /** + * :: DeveloperApi :: + * + * Predict label for the given features. + * This internal method is used to implement [[transform()]] and output [[predictionCol]]. + */ + @DeveloperApi + protected def predict(features: FeaturesType): Double + + /** + * Create a copy of the model. + * The copy is shallow, except for the embedded paramMap, which gets a deep copy. 
+ */ + protected def copy(): M +} http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/param/params.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 5fb4379..17ece89 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -22,8 +22,10 @@ import scala.collection.mutable import java.lang.reflect.Modifier -import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.annotation.{AlphaComponent, DeveloperApi} import org.apache.spark.ml.Identifiable +import org.apache.spark.sql.types.{DataType, StructField, StructType} + /** * :: AlphaComponent :: @@ -65,37 +67,47 @@ class Param[T] ( // specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... /** Specialized version of [[Param[Double]]] for Java. */ -class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double] = None) +class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double]) extends Param[Double](parent, name, doc, defaultValue) { + def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) + override def w(value: Double): ParamPair[Double] = super.w(value) } /** Specialized version of [[Param[Int]]] for Java. */ -class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int] = None) +class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int]) extends Param[Int](parent, name, doc, defaultValue) { + def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) + override def w(value: Int): ParamPair[Int] = super.w(value) } /** Specialized version of [[Param[Float]]] for Java. 
*/ -class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float] = None) +class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float]) extends Param[Float](parent, name, doc, defaultValue) { + def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) + override def w(value: Float): ParamPair[Float] = super.w(value) } /** Specialized version of [[Param[Long]]] for Java. */ -class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long] = None) +class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long]) extends Param[Long](parent, name, doc, defaultValue) { + def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) + override def w(value: Long): ParamPair[Long] = super.w(value) } /** Specialized version of [[Param[Boolean]]] for Java. */ -class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean] = None) +class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean]) extends Param[Boolean](parent, name, doc, defaultValue) { + def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None) + override def w(value: Boolean): ParamPair[Boolean] = super.w(value) } @@ -158,7 +170,7 @@ trait Params extends Identifiable with Serializable { /** * Sets a parameter in the embedded param map. */ - private[ml] def set[T](param: Param[T], value: T): this.type = { + protected def set[T](param: Param[T], value: T): this.type = { require(param.parent.eq(this)) paramMap.put(param.asInstanceOf[Param[Any]], value) this @@ -174,7 +186,7 @@ trait Params extends Identifiable with Serializable { /** * Gets the value of a parameter in the embedded param map. 
*/ - private[ml] def get[T](param: Param[T]): T = { + protected def get[T](param: Param[T]): T = { require(param.parent.eq(this)) paramMap(param) } @@ -183,9 +195,40 @@ trait Params extends Identifiable with Serializable { * Internal param map. */ protected val paramMap: ParamMap = ParamMap.empty + + /** + * Check whether the given schema contains an input column. + * @param colName Parameter name for the input column. + * @param dataType SQL DataType of the input column. + */ + protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = { + val actualDataType = schema(colName).dataType + require(actualDataType.equals(dataType), + s"Input column $colName must be of type $dataType" + + s" but was actually $actualDataType. Column param description: ${getParam(colName)}") + } + + protected def addOutputColumn( + schema: StructType, + colName: String, + dataType: DataType): StructType = { + if (colName.length == 0) return schema + val fieldNames = schema.fieldNames + require(!fieldNames.contains(colName), s"Prediction column $colName already exists.") + val outputFields = schema.fields ++ Seq(StructField(colName, dataType, nullable = false)) + StructType(outputFields) + } } -private[ml] object Params { +/** + * :: DeveloperApi :: + * + * Helper functionality for developers. + * + * NOTE: This is currently private[spark] but will be made public later once it is stabilized. + */ +@DeveloperApi +private[spark] object Params { /** * Copies parameter values from the parent estimator to the child model it produced. 
@@ -279,7 +322,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten def copy: ParamMap = new ParamMap(map.clone()) override def toString: String = { - map.map { case (param, value) => + map.toSeq.sortBy(_._1.name).map { case (param, value) => s"\t${param.parent.uid}-${param.name}: $value" }.mkString("{\n", ",\n", "\n}") } @@ -310,6 +353,11 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten ParamPair(param, value) } } + + /** + * Number of param pairs in this set. + */ + def size: Int = map.size } object ParamMap { http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala index ef141d3..32fc744 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala @@ -17,6 +17,12 @@ package org.apache.spark.ml.param +/* NOTE TO DEVELOPERS: + * If you mix these parameter traits into your algorithm, please add a setter method as well + * so that users may use a builder pattern: + * val myLearner = new MyLearner().setParam1(x).setParam2(y)... 
+ */ + private[ml] trait HasRegParam extends Params { /** param for regularization parameter */ val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") @@ -42,12 +48,6 @@ private[ml] trait HasLabelCol extends Params { def getLabelCol: String = get(labelCol) } -private[ml] trait HasScoreCol extends Params { - /** param for score column name */ - val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", Some("score")) - def getScoreCol: String = get(scoreCol) -} - private[ml] trait HasPredictionCol extends Params { /** param for prediction column name */ val predictionCol: Param[String] = @@ -55,6 +55,22 @@ private[ml] trait HasPredictionCol extends Params { def getPredictionCol: String = get(predictionCol) } +private[ml] trait HasRawPredictionCol extends Params { + /** param for raw prediction column name */ + val rawPredictionCol: Param[String] = + new Param(this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name", + Some("rawPrediction")) + def getRawPredictionCol: String = get(rawPredictionCol) +} + +private[ml] trait HasProbabilityCol extends Params { + /** param for predicted class conditional probabilities column name */ + val probabilityCol: Param[String] = + new Param(this, "probabilityCol", "column name for predicted class conditional probabilities", + Some("probability")) + def getProbabilityCol: String = get(probabilityCol) +} + private[ml] trait HasThreshold extends Params { /** param for threshold in (binary) prediction */ val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") http://git-wip-us.apache.org/repos/asf/spark/blob/45b95e7d/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala new 
file mode 100644 index 0000000..d5a7bda --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam} +import org.apache.spark.mllib.linalg.{BLAS, Vector} +import org.apache.spark.mllib.regression.LinearRegressionWithSGD +import org.apache.spark.sql.DataFrame +import org.apache.spark.storage.StorageLevel + + +/** + * Params for linear regression. + */ +private[regression] trait LinearRegressionParams extends RegressorParams + with HasRegParam with HasMaxIter + + +/** + * :: AlphaComponent :: + * + * Linear regression. 
+ */ +@AlphaComponent +class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel] + with LinearRegressionParams { + + setRegParam(0.1) + setMaxIter(100) + + def setRegParam(value: Double): this.type = set(regParam, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) + + override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = { + // Extract columns from data. If dataset is persisted, do not persist oldDataset. + val oldDataset = extractLabeledPoints(dataset, paramMap) + val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE + if (handlePersistence) { + oldDataset.persist(StorageLevel.MEMORY_AND_DISK) + } + + // Train model + val lr = new LinearRegressionWithSGD() + lr.optimizer + .setRegParam(paramMap(regParam)) + .setNumIterations(paramMap(maxIter)) + val model = lr.run(oldDataset) + val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept) + + if (handlePersistence) { + oldDataset.unpersist() + } + lrm + } +} + +/** + * :: AlphaComponent :: + * + * Model produced by [[LinearRegression]]. + */ +@AlphaComponent +class LinearRegressionModel private[ml] ( + override val parent: LinearRegression, + override val fittingParamMap: ParamMap, + val weights: Vector, + val intercept: Double) + extends RegressionModel[Vector, LinearRegressionModel] + with LinearRegressionParams { + + override protected def predict(features: Vector): Double = { + BLAS.dot(features, weights) + intercept + } + + override protected def copy(): LinearRegressionModel = { + val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept) + Params.inheritValues(this.paramMap, this, m) + m + } +}
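For readers following the Predictor / PredictionModel abstractions in this diff, here is a minimal, Spark-free sketch of the two ideas they combine: a template-method `fit()` that does the shared checks and then delegates to a developer-supplied `train()`, and self-referential type parameters so that setters return the concrete learner type. Every name below (`PredictorSketch`, `Example`, `SketchPredictor`, `SketchModel`, `MeanRegressor`, `MeanModel`, `setLabelName`) is hypothetical and exists only for illustration. None of it is part of the Spark API, and the validation shown is a stand-in for the schema checks in the commit, not a copy of them.

```scala
// Hypothetical, Spark-free illustration of the pattern used by Predictor and
// PredictionModel in the diff. Not the Spark API.
object PredictorSketch {

  // Stand-in for a labeled training row (assumption: dense features as Array[Double]).
  final case class Example(label: Double, features: Array[Double])

  // Model abstraction. M is the concrete subclass (self-referential type parameter).
  abstract class SketchModel[M <: SketchModel[M]] {
    // Developers implement only the per-row prediction.
    def predict(features: Array[Double]): Double

    // Rough analogue of transform(): pair every input row with its prediction.
    def transform(rows: Seq[Array[Double]]): Seq[(Array[Double], Double)] =
      rows.map(r => (r, predict(r)))
  }

  // Learner abstraction. L is the concrete learner, M the model type it produces.
  abstract class SketchPredictor[L <: SketchPredictor[L, M], M <: SketchModel[M]] {
    private var labelName: String = "label"

    // Builder-style setter that returns the concrete learner type L.
    def setLabelName(value: String): L = { labelName = value; this.asInstanceOf[L] }
    def getLabelName: String = labelName

    // Template method: shared validation first, then the developer-supplied train().
    final def fit(data: Seq[Example]): M = {
      require(data.nonEmpty, "training data must not be empty")
      val dim = data.head.features.length
      require(data.forall(_.features.length == dim), "inconsistent feature dimension")
      train(data)
    }

    // The only method a new algorithm has to implement.
    protected def train(data: Seq[Example]): M
  }

  // Toy algorithm: always predicts the mean training label and ignores the features.
  final class MeanModel(val mean: Double) extends SketchModel[MeanModel] {
    override def predict(features: Array[Double]): Double = mean
  }

  final class MeanRegressor extends SketchPredictor[MeanRegressor, MeanModel] {
    override protected def train(data: Seq[Example]): MeanModel =
      new MeanModel(data.map(_.label).sum / data.size)
  }
}
```

With these definitions, `new PredictorSketch.MeanRegressor().setLabelName("y").fit(data)` type-checks as a `MeanModel` without any cast at the call site, which is the convenience the `Learner` and `M` type parameters are meant to provide in the real classes.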