spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-9112] [ML] Implement Stats for LogisticRegression
Date Thu, 06 Aug 2015 17:08:39 GMT
Repository: spark
Updated Branches:
  refs/heads/master 9f94c85ff -> c5c6aded6


[SPARK-9112] [ML] Implement Stats for LogisticRegression

I have added support for stats in LogisticRegression. The API is similar to that of LinearRegression,
with LogisticRegressionTrainingSummary and LogisticRegressionSummary.

I have some questions, which I have asked inline.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #7538 from MechCoder/log_reg_stats and squashes the following commits:

2e9f7c7 [MechCoder] Change defs into lazy vals
d775371 [MechCoder] Clean up class inheritance
9586125 [MechCoder] Add abstraction to handle Multiclass Metrics
40ad8ef [MechCoder] minor
640376a [MechCoder] remove unnecessary dataframe stuff and add docs
80d9954 [MechCoder] Added tests
fbed861 [MechCoder] DataFrame support for metrics
70a0fc4 [MechCoder] [SPARK-9112] [ML] Implement Stats for LogisticRegression


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c5c6aded
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c5c6aded
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c5c6aded

Branch: refs/heads/master
Commit: c5c6aded641048a3e66ac79d9e84d34e4b1abae7
Parents: 9f94c85
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Authored: Thu Aug 6 10:08:33 2015 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Thu Aug 6 10:08:33 2015 -0700

----------------------------------------------------------------------
 .../ml/classification/LogisticRegression.scala  | 166 ++++++++++++++++++-
 .../JavaLogisticRegressionSuite.java            |   9 +
 .../LogisticRegressionSuite.scala               |  37 ++++-
 3 files changed, 209 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c5c6aded/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 0d07383..f55134d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -30,10 +30,12 @@ import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.linalg._
 import org.apache.spark.mllib.linalg.BLAS._
 import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
 import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.functions.{col, udf}
 import org.apache.spark.storage.StorageLevel
 
 /**
@@ -284,7 +286,13 @@ class LogisticRegression(override val uid: String)
 
     if (handlePersistence) instances.unpersist()
 
-    copyValues(new LogisticRegressionModel(uid, weights, intercept))
+    val model = copyValues(new LogisticRegressionModel(uid, weights, intercept))
+    val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
+      model.transform(dataset),
+      $(probabilityCol),
+      $(labelCol),
+      objectiveHistory)
+    model.setSummary(logRegSummary)
   }
 
   override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
@@ -319,6 +327,38 @@ class LogisticRegressionModel private[ml] (
 
   override val numClasses: Int = 2
 
+  private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None
+
+  /**
+   * Gets summary of model on training set. An exception is
+   * thrown if `trainingSummary == None`.
+   */
+  def summary: LogisticRegressionTrainingSummary = trainingSummary match {
+    case Some(summ) => summ
+    case None =>
+      throw new SparkException(
+        "No training summary available for this LogisticRegressionModel",
+        new NullPointerException())
+  }
+
+  private[classification] def setSummary(
+      summary: LogisticRegressionTrainingSummary): this.type = {
+    this.trainingSummary = Some(summary)
+    this
+  }
+
+  /** Indicates whether a training summary exists for this model instance. */
+  def hasSummary: Boolean = trainingSummary.isDefined
+
+  /**
+   * Evaluates the model on a testset.
+   * @param dataset Test dataset to evaluate model on.
+   */
+  // TODO: decide on a good name before exposing to public API
+  private[classification] def evaluate(dataset: DataFrame): LogisticRegressionSummary = {
+    new BinaryLogisticRegressionSummary(this.transform(dataset), $(probabilityCol), $(labelCol))
+  }
+
   /**
    * Predict label for the given feature vector.
    * The behavior of this can be adjusted using [[thresholds]].
@@ -441,6 +481,128 @@ private[classification] class MultiClassSummarizer extends Serializable {
 }
 
 /**
+ * Abstraction for multinomial Logistic Regression Training results.
+ */
+sealed trait LogisticRegressionTrainingSummary extends LogisticRegressionSummary {
+
+  /** objective function (scaled loss + regularization) at each iteration. */
+  def objectiveHistory: Array[Double]
+
+  /** Number of training iterations until termination */
+  def totalIterations: Int = objectiveHistory.length
+
+}
+
+/**
+ * Abstraction for Logistic Regression Results for a given model.
+ */
+sealed trait LogisticRegressionSummary extends Serializable {
+
+  /** Dataframe outputted by the model's `transform` method. */
+  def predictions: DataFrame
+
+  /** Field in "predictions" which gives the calibrated probability of each sample as a vector. */
+  def probabilityCol: String
+
+  /** Field in "predictions" which gives the true label of each sample. */
+  def labelCol: String
+
+}
+
+/**
+ * :: Experimental ::
+ * Logistic regression training results.
+ * @param predictions dataframe outputted by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the calibrated probability of
+ *                       each sample as a vector.
+ * @param labelCol field in "predictions" which gives the true label of each sample.
+ * @param objectiveHistory objective function (scaled loss + regularization) at each iteration.
+ */
+@Experimental
+class BinaryLogisticRegressionTrainingSummary private[classification] (
+    predictions: DataFrame,
+    probabilityCol: String,
+    labelCol: String,
+    val objectiveHistory: Array[Double])
+  extends BinaryLogisticRegressionSummary(predictions, probabilityCol, labelCol)
+  with LogisticRegressionTrainingSummary {
+
+}
+
+/**
+ * :: Experimental ::
+ * Binary Logistic regression results for a given model.
+ * @param predictions dataframe outputted by the model's `transform` method.
+ * @param probabilityCol field in "predictions" which gives the calibrated probability of
+ *                       each sample.
+ * @param labelCol field in "predictions" which gives the true label of each sample.
+ */
+@Experimental
+class BinaryLogisticRegressionSummary private[classification] (
+    @transient override val predictions: DataFrame,
+    override val probabilityCol: String,
+    override val labelCol: String) extends LogisticRegressionSummary {
+
+  private val sqlContext = predictions.sqlContext
+  import sqlContext.implicits._
+
+  /**
+   * Returns a BinaryClassificationMetrics object.
+   */
+  // TODO: Allow the user to vary the number of bins using a setBins method in
+  // BinaryClassificationMetrics. For now the default is set to 100.
+  @transient private val binaryMetrics = new BinaryClassificationMetrics(
+    predictions.select(probabilityCol, labelCol).map {
+      case Row(score: Vector, label: Double) => (score(1), label)
+    }, 100
+  )
+
+  /**
+   * Returns the receiver operating characteristic (ROC) curve,
+   * which is a DataFrame having two fields (FPR, TPR)
+   * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it.
+   * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic
+   */
+  @transient lazy val roc: DataFrame = binaryMetrics.roc().toDF("FPR", "TPR")
+
+  /**
+   * Computes the area under the receiver operating characteristic (ROC) curve.
+   */
+  lazy val areaUnderROC: Double = binaryMetrics.areaUnderROC()
+
+  /**
+   * Returns the precision-recall curve, which is a DataFrame containing
+   * two fields recall, precision with (0.0, 1.0) prepended to it.
+   */
+  @transient lazy val pr: DataFrame = binaryMetrics.pr().toDF("recall", "precision")
+
+  /**
+   * Returns a dataframe with two fields (threshold, F-Measure) curve with beta = 1.0.
+   */
+  @transient lazy val fMeasureByThreshold: DataFrame = {
+    binaryMetrics.fMeasureByThreshold().toDF("threshold", "F-Measure")
+  }
+
+  /**
+   * Returns a dataframe with two fields (threshold, precision) curve.
+   * Every possible probability obtained in transforming the dataset are used
+   * as thresholds used in calculating the precision.
+   */
+  @transient lazy val precisionByThreshold: DataFrame = {
+    binaryMetrics.precisionByThreshold().toDF("threshold", "precision")
+  }
+
+  /**
+   * Returns a dataframe with two fields (threshold, recall) curve.
+   * Every possible probability obtained in transforming the dataset are used
+   * as thresholds used in calculating the recall.
+   */
+  @transient lazy val recallByThreshold: DataFrame = {
+    binaryMetrics.recallByThreshold().toDF("threshold", "recall")
+  }
+}
+
+/**
  * LogisticAggregator computes the gradient and loss for binary logistic loss function, as used
  * in binary classification for samples in sparse or dense vector in a online fashion.
  *

http://git-wip-us.apache.org/repos/asf/spark/blob/c5c6aded/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index fb1de51..7e9aa38 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -152,4 +152,13 @@ public class JavaLogisticRegressionSuite implements Serializable {
       }
     }
   }
+
+  @Test
+  public void logisticRegressionTrainingSummary() {
+    LogisticRegression lr = new LogisticRegression();
+    LogisticRegressionModel model = lr.fit(dataset);
+
+    LogisticRegressionTrainingSummary summary = model.summary();
+    assert(summary.totalIterations() == summary.objectiveHistory().length);
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c5c6aded/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index da13dcb..8c3d459 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -723,6 +723,41 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
     val weightsR = Vectors.dense(0.0, 0.0, 0.0, 0.0)
 
     assert(model1.intercept ~== interceptR relTol 1E-5)
-    assert(model1.weights ~= weightsR absTol 1E-6)
+    assert(model1.weights ~== weightsR absTol 1E-6)
+  }
+
+  test("evaluate on test set") {
+    // Evaluate on test set should be same as that of the transformed training data.
+    val lr = new LogisticRegression()
+      .setMaxIter(10)
+      .setRegParam(1.0)
+      .setThreshold(0.6)
+    val model = lr.fit(dataset)
+    val summary = model.summary.asInstanceOf[BinaryLogisticRegressionSummary]
+
+    val sameSummary = model.evaluate(dataset).asInstanceOf[BinaryLogisticRegressionSummary]
+    assert(summary.areaUnderROC === sameSummary.areaUnderROC)
+    assert(summary.roc.collect() === sameSummary.roc.collect())
+    assert(summary.pr.collect === sameSummary.pr.collect())
+    assert(
+      summary.fMeasureByThreshold.collect() === sameSummary.fMeasureByThreshold.collect())
+    assert(summary.recallByThreshold.collect() === sameSummary.recallByThreshold.collect())
+    assert(
+      summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect())
+  }
+
+  test("statistics on training data") {
+    // Test that loss is monotonically decreasing.
+    val lr = new LogisticRegression()
+      .setMaxIter(10)
+      .setRegParam(1.0)
+      .setThreshold(0.6)
+    val model = lr.fit(dataset)
+    assert(
+      model.summary
+        .objectiveHistory
+        .sliding(2)
+        .forall(x => x(0) >= x(1)))
+
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message