spark-commits mailing list archives

From dbt...@apache.org
Subject [3/3] spark git commit: [SPARK-17163][ML] Unified LogisticRegression interface
Date Tue, 20 Sep 2016 04:34:01 GMT
[SPARK-17163][ML] Unified LogisticRegression interface

## What changes were proposed in this pull request?

Merge `MultinomialLogisticRegression` into `LogisticRegression` and remove `MultinomialLogisticRegression`.

Marked as WIP because we should discuss the coefficients API in the model. See discussion below.

JIRA: [SPARK-17163](https://issues.apache.org/jira/browse/SPARK-17163)

## How was this patch tested?

Merged test suites and added some new unit tests.

## Design

### Switching between binomial and multinomial

By default, we automatically detect whether to run binomial or multinomial logistic regression. We expose a new parameter called `family`, which defaults to "auto". When "auto" is used, we run binomial logistic regression with pivoting if there are 1 or 2 label classes; otherwise, we run multinomial. If the user explicitly sets the family, we abide by that setting. If "binomial" is set but more than two label classes are detected, we throw an error.
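The selection rule described above can be sketched in plain Scala without any Spark dependency. This mirrors the `match` on `$(family)` that the patch adds to `train`; the object name `FamilySelection` and the standalone function signature are assumptions made for illustration only.

```scala
// A minimal sketch of the family resolution rule; `FamilySelection` is a
// hypothetical name and is not part of the Spark API.
object FamilySelection {
  /** Returns true when multinomial (softmax) training should be used. */
  def isMultinomial(family: String, numClasses: Int): Boolean = family match {
    case "binomial" =>
      // An explicit "binomial" setting is rejected for more than two classes.
      require(numClasses == 1 || numClasses == 2,
        s"Binomial family only supports 1 or 2 outcome classes but found $numClasses.")
      false
    case "multinomial" => true          // always honored, even for two classes
    case "auto" => numClasses > 2       // binomial for 1 or 2 classes
    case other => throw new IllegalArgumentException(s"Unsupported family: $other")
  }
}
```

Note that an explicit "multinomial" setting is honored even with two classes, in which case the model would hold two sets of coefficients instead of one.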

### coefficients/intercept model API (TODO)

This is the biggest design point remaining, IMO. We need to decide how to store the coefficients and intercepts in the model, and in turn how to expose them via the API. Two important points:

* We must maintain compatibility with the old API, i.e. we must expose `def coefficients: Vector` and `def intercept: Double`
* There are two separate cases: binomial logistic regression, where we have a single set of coefficients and a single intercept, and multinomial logistic regression, where we have `numClasses` sets of coefficients and `numClasses` intercepts.

Some options:

1. **Store the binomial coefficients as a `2 x numFeatures` matrix.** This means that we would center the model coefficients before storing them in the model. The binomial logistic regression algorithm gives `1 x numFeatures` coefficients, but we would convert them to `2 x numFeatures` coefficients before storing them, effectively doubling the storage in the model. This has the advantage that we can make the code cleaner (i.e. fewer `if (isMultinomial) ... else ...` branches) and we don't have to reason about the different cases as much. It has the disadvantage that we double the storage space and we could see small regressions at prediction time, since the prediction algorithms perform twice as many operations. Additionally, we still have to produce the uncentered coefficients/intercept via the API, so we would have to either also store the uncentered version, or compute it in `def coefficients: Vector` every time.

2. **Store the binomial coefficients as a `1 x numFeatures` matrix.** We still store the coefficients as a matrix and the intercepts as a vector. When users call `coefficients`, we return a `Vector` that is backed by the same underlying array as the `coefficientMatrix`, so we don't duplicate any data. At prediction time, we use the old prediction methods that are specialized for binary logistic regression. The benefits here are that we don't store extra data and we won't see any regressions in performance. The cost is that we have separate implementations of the predict methods for the binary and multiclass cases. The amount of duplicated code is small, but it is still a bit messy.
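To make the trade-off in option 1 concrete, here is a small sketch of how a binary model `(w, b)` could be expanded into two mean-centered rows, and how the original coefficients could be recovered. The centering scheme shown (class 0 gets `-w/2`, class 1 gets `+w/2`) is one natural choice and is an assumption; the PR text does not spell out the exact conversion. All names here are hypothetical.

```scala
// Sketch only: plain arrays stand in for Spark's Vector and Matrix types.
object CenteredCoefficients {
  def dot(a: Array[Double], b: Array[Double]): Double =
    a.zip(b).map { case (x, y) => x * y }.sum

  /** Binary model: P(y = 1 | x) = sigmoid(w . x + b). */
  def binaryProbability(w: Array[Double], b: Double, x: Array[Double]): Double =
    1.0 / (1.0 + math.exp(-(dot(w, x) + b)))

  /** Expands (w, b) into two mean-centered rows: class 0 gets -w/2, class 1 gets +w/2. */
  def center(w: Array[Double], b: Double): (Array[Array[Double]], Array[Double]) =
    (Array(w.map(v => -v / 2.0), w.map(v => v / 2.0)), Array(-b / 2.0, b / 2.0))

  /** Softmax probability of class 1 under the centered two-row model. */
  def softmaxProbability(
      rows: Array[Array[Double]],
      intercepts: Array[Double],
      x: Array[Double]): Double = {
    val margins = rows.zip(intercepts).map { case (r, c) => dot(r, x) + c }
    val exps = margins.map(m => math.exp(m))
    exps(1) / exps.sum
  }

  /** Recovers the uncentered binary coefficients: w = row1 - row0. */
  def uncenter(rows: Array[Array[Double]]): Array[Double] =
    rows(1).zip(rows(0)).map { case (hi, lo) => hi - lo }
}
```

Both representations give the same class-1 probability, but the centered form stores and multiplies twice as many values, and `uncenter` is the extra work that `def coefficients: Vector` would need on every call unless the result is cached.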

If we do decide to store the 2x coefficients, we would likely want to see some performance tests to understand the potential regressions.

**Update:** We have chosen option 2.
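A rough sketch of what option 2 looks like from the API side, using simplified stand-in types rather than Spark's `Matrix` and `Vector`: the binary `coefficients` accessor wraps the matrix's backing array instead of copying it, and the legacy accessors fail for multinomial models. The type names are hypothetical, and `UnsupportedOperationException` is substituted here for the `SparkException` that the patch throws.

```scala
// Sketch only: RowMajorMatrix and DenseVec are hypothetical stand-ins for
// Spark's DenseMatrix (with isTransposed = true) and DenseVector.
object SharedCoefficients {
  final case class RowMajorMatrix(numRows: Int, numCols: Int, values: Array[Double])
  final case class DenseVec(values: Array[Double])

  final class Model(
      val coefficientMatrix: RowMajorMatrix,
      val interceptVector: DenseVec,
      isMultinomial: Boolean) {

    require(coefficientMatrix.numRows == interceptVector.values.length,
      "Dimension mismatch between coefficientMatrix rows and interceptVector size")

    // Wraps the matrix's array; no coefficient data is duplicated.
    private lazy val binaryCoefficients: DenseVec = DenseVec(coefficientMatrix.values)

    /** Old API: only meaningful for a 1 x numFeatures (binomial) model. */
    def coefficients: DenseVec =
      if (isMultinomial) {
        throw new UnsupportedOperationException("use coefficientMatrix instead")
      } else {
        binaryCoefficients
      }

    /** Old API: the single intercept of a binomial model. */
    def intercept: Double =
      if (isMultinomial) {
        throw new UnsupportedOperationException("use interceptVector instead")
      } else {
        interceptVector.values.head
      }
  }
}
```

Because a `1 x numFeatures` row-major matrix has the same memory layout as a vector of length `numFeatures`, the wrapper is a constant-time view of the stored data.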

### Threshold/thresholds (TODO)

Currently, when `threshold` is set we clear whatever value is in `thresholds` and when `thresholds` is set we clear whatever value is in `threshold`. [SPARK-11543](https://issues.apache.org/jira/browse/SPARK-11543) was created to prefer thresholds over threshold. We should decide if we should implement this behavior now or if we want to do it in a separate JIRA.
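The current mutual-clearing behavior can be illustrated with a small stand-alone sketch. The `Params` class below is hypothetical and uses plain `Option` fields instead of Spark's `Param` machinery; it only models the "last setter wins" semantics described above.

```scala
// Sketch only: models the current behavior where setting one of
// threshold/thresholds clears the other. Not the Spark Params API.
object ThresholdParams {
  final class Params {
    private var threshold: Option[Double] = None
    private var thresholds: Option[Array[Double]] = None

    def setThreshold(value: Double): this.type = {
      thresholds = None          // clear the multiclass setting
      threshold = Some(value)
      this
    }

    def setThresholds(value: Array[Double]): this.type = {
      threshold = None           // clear the binary setting
      thresholds = Some(value)
      this
    }

    def isThresholdSet: Boolean = threshold.isDefined
    def isThresholdsSet: Boolean = thresholds.isDefined
  }
}
```

Under the alternative tracked in SPARK-11543, `thresholds` would take precedence whenever both are set, rather than whichever setter was called last.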

**Update:** Let's leave it for a follow-up PR.

## Follow up

* Summary model for multiclass logistic regression [SPARK-17139](https://issues.apache.org/jira/browse/SPARK-17139)
* Thresholds vs threshold [SPARK-11543](https://issues.apache.org/jira/browse/SPARK-11543)

Author: sethah <seth.hendrickson16@gmail.com>

Closes #14834 from sethah/SPARK-17163.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/26145a5a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/26145a5a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/26145a5a

Branch: refs/heads/master
Commit: 26145a5af9a88053c0eaf280206ca2621c8919f6
Parents: e719b1c
Author: sethah <seth.hendrickson16@gmail.com>
Authored: Mon Sep 19 21:33:54 2016 -0700
Committer: DB Tsai <dbt@netflix.com>
Committed: Mon Sep 19 21:33:54 2016 -0700

----------------------------------------------------------------------
 .../ml/classification/LogisticRegression.scala  |  476 +++++--
 .../MultinomialLogisticRegression.scala         |  632 ---------
 .../ProbabilisticClassifier.scala               |   22 +-
 .../classification/LogisticRegression.scala     |    6 +-
 .../LogisticRegressionSuite.scala               | 1288 ++++++++++++++++--
 .../MultinomialLogisticRegressionSuite.scala    | 1056 --------------
 .../ml/classification/OneVsRestSuite.scala      |    2 +-
 .../spark/ml/tuning/CrossValidatorSuite.scala   |    2 +-
 .../ml/tuning/TrainValidationSplitSuite.scala   |    2 +-
 project/MimaExcludes.scala                      |    3 +
 10 files changed, 1609 insertions(+), 1880 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/26145a5a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 757d520..343d50c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -42,6 +42,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, Row}
 import org.apache.spark.sql.functions.{col, lit}
 import org.apache.spark.sql.types.DoubleType
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.VersionUtils
 
 /**
  * Params for logistic regression.
@@ -50,6 +51,8 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
   with HasRegParam with HasElasticNetParam with HasMaxIter with HasFitIntercept with HasTol
   with HasStandardization with HasWeightCol with HasThreshold with HasAggregationDepth {
 
+  import org.apache.spark.ml.classification.LogisticRegression.supportedFamilyNames
+
   /**
    * Set threshold in binary classification, in range [0, 1].
    *
@@ -66,12 +69,37 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
    *
    * @group setParam
    */
+  // TODO: Implement SPARK-11543?
   def setThreshold(value: Double): this.type = {
     if (isSet(thresholds)) clear(thresholds)
     set(threshold, value)
   }
 
   /**
+   * Param for the name of family which is a description of the label distribution
+   * to be used in the model.
+   * Supported options: "auto", "multinomial", "binomial".
+   * Supported options:
+   *  - "auto": Automatically select the family based on the number of classes:
+   *            If numClasses == 1 || numClasses == 2, set to "binomial".
+   *            Else, set to "multinomial"
+   *  - "binomial": Binary logistic regression with pivoting.
+   *  - "multinomial": Multinomial logistic (softmax) regression without pivoting.
+   * Default is "auto".
+   *
+   * @group param
+   */
+  @Since("2.1.0")
+  final val family: Param[String] = new Param(this, "family",
+    "The name of family which is a description of the label distribution to be used in the " +
+      s"model. Supported options: ${supportedFamilyNames.mkString(", ")}.",
+    ParamValidators.inArray[String](supportedFamilyNames))
+
+  /** @group getParam */
+  @Since("2.1.0")
+  def getFamily: String = $(family)
+
+  /**
    * Get threshold for binary classification.
    *
    * If [[thresholds]] is set with length 2 (i.e., binary classification),
@@ -154,9 +182,8 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
 }
 
 /**
- * Logistic regression.
- * Currently, this class only supports binary classification.  For multiclass classification,
- * use [[MultinomialLogisticRegression]]
+ * Logistic regression. Supports multinomial logistic (softmax) regression and binomial logistic
+ * regression.
  */
 @Since("1.2.0")
 class LogisticRegression @Since("1.2.0") (
@@ -221,6 +248,16 @@ class LogisticRegression @Since("1.2.0") (
   setDefault(fitIntercept -> true)
 
   /**
+   * Sets the value of param [[family]].
+   * Default is "auto".
+   *
+   * @group setParam
+   */
+  @Since("2.1.0")
+  def setFamily(value: String): this.type = set(family, value)
+  setDefault(family -> "auto")
+
+  /**
    * Whether to standardize the training features before fitting the model.
    * The coefficients of models will be always returned on the original scale,
    * so it will be transparent for users. Note that with/without standardization,
@@ -261,6 +298,7 @@ class LogisticRegression @Since("1.2.0") (
    * If the dimensions of features or the number of partitions are large,
    * this param could be adjusted to a larger size.
    * Default is 2.
+   *
    * @group expertSetParam
    */
   @Since("2.1.0")
@@ -311,8 +349,27 @@ class LogisticRegression @Since("1.2.0") (
 
     val histogram = labelSummarizer.histogram
     val numInvalid = labelSummarizer.countInvalid
-    val numClasses = histogram.length
     val numFeatures = summarizer.mean.size
+    val numFeaturesPlusIntercept = if (getFitIntercept) numFeatures + 1 else numFeatures
+
+    val numClasses = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
+      case Some(n: Int) =>
+        require(n >= histogram.length, s"Specified number of classes $n was " +
+          s"less than the number of unique labels ${histogram.length}.")
+        n
+      case None => histogram.length
+    }
+
+    val isMultinomial = $(family) match {
+      case "binomial" =>
+        require(numClasses == 1 || numClasses == 2, s"Binomial family only supports 1 or 2 " +
+        s"outcome classes but found $numClasses.")
+        false
+      case "multinomial" => true
+      case "auto" => numClasses > 2
+      case other => throw new IllegalArgumentException(s"Unsupported family: $other")
+    }
+    val numCoefficientSets = if (isMultinomial) numClasses else 1
 
     if (isDefined(thresholds)) {
       require($(thresholds).length == numClasses, this.getClass.getSimpleName +
@@ -323,7 +380,7 @@ class LogisticRegression @Since("1.2.0") (
     instr.logNumClasses(numClasses)
     instr.logNumFeatures(numFeatures)
 
-    val (coefficients, intercept, objectiveHistory) = {
+    val (coefficientMatrix, interceptVector, objectiveHistory) = {
       if (numInvalid != 0) {
         val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
           s"Found $numInvalid invalid labels."
@@ -331,24 +388,26 @@ class LogisticRegression @Since("1.2.0") (
         throw new SparkException(msg)
       }
 
-      val isConstantLabel = histogram.count(_ != 0) == 1
+      val isConstantLabel = histogram.count(_ != 0.0) == 1
 
-      if (numClasses > 2) {
-        val msg = s"LogisticRegression with ElasticNet in ML package only supports " +
-          s"binary classification. Found $numClasses in the input dataset. Consider using " +
-          s"MultinomialLogisticRegression instead."
-        logError(msg)
-        throw new SparkException(msg)
-      } else if ($(fitIntercept) && numClasses == 2 && isConstantLabel) {
-        logWarning(s"All labels are one and fitIntercept=true, so the coefficients will be " +
-          s"zeros and the intercept will be positive infinity; as a result, " +
-          s"training is not needed.")
-        (Vectors.sparse(numFeatures, Seq()), Double.PositiveInfinity, Array.empty[Double])
-      } else if ($(fitIntercept) && numClasses == 1) {
-        logWarning(s"All labels are zero and fitIntercept=true, so the coefficients will be " +
-          s"zeros and the intercept will be negative infinity; as a result, " +
-          s"training is not needed.")
-        (Vectors.sparse(numFeatures, Seq()), Double.NegativeInfinity, Array.empty[Double])
+      if ($(fitIntercept) && isConstantLabel) {
+        logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " +
+          s"will be zeros. Training is not needed.")
+        val constantLabelIndex = Vectors.dense(histogram).argmax
+        // TODO: use `compressed` after SPARK-17471
+        val coefMatrix = if (numFeatures < numCoefficientSets) {
+          new SparseMatrix(numCoefficientSets, numFeatures,
+            Array.fill(numFeatures + 1)(0), Array.empty[Int], Array.empty[Double])
+        } else {
+          new SparseMatrix(numCoefficientSets, numFeatures, Array.fill(numCoefficientSets + 1)(0),
+            Array.empty[Int], Array.empty[Double], isTransposed = true)
+        }
+        val interceptVec = if (isMultinomial) {
+          Vectors.sparse(numClasses, Seq((constantLabelIndex, Double.PositiveInfinity)))
+        } else {
+          Vectors.dense(if (numClasses == 2) Double.PositiveInfinity else Double.NegativeInfinity)
+        }
+        (coefMatrix, interceptVec, Array.empty[Double])
       } else {
         if (!$(fitIntercept) && isConstantLabel) {
           logWarning(s"All labels belong to a single class and fitIntercept=false. It's a " +
@@ -370,7 +429,8 @@ class LogisticRegression @Since("1.2.0") (
 
         val bcFeaturesStd = instances.context.broadcast(featuresStd)
         val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
-          $(standardization), bcFeaturesStd, regParamL2, multinomial = false, $(aggregationDepth))
+          $(standardization), bcFeaturesStd, regParamL2, multinomial = isMultinomial,
+          $(aggregationDepth))
 
         val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
           new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
@@ -378,18 +438,28 @@ class LogisticRegression @Since("1.2.0") (
           val standardizationParam = $(standardization)
           def regParamL1Fun = (index: Int) => {
             // Remove the L1 penalization on the intercept
-            if (index == numFeatures) {
+            val isIntercept = $(fitIntercept) && ((index + 1) % numFeaturesPlusIntercept == 0)
+            if (isIntercept) {
               0.0
             } else {
               if (standardizationParam) {
                 regParamL1
               } else {
+                val featureIndex = if ($(fitIntercept)) {
+                  index % numFeaturesPlusIntercept
+                } else {
+                  index % numFeatures
+                }
                 // If `standardization` is false, we still standardize the data
                 // to improve the rate of convergence; as a result, we have to
                 // perform this reverse standardization by penalizing each component
                 // differently to get effectively the same objective function when
                 // the training dataset is not standardized.
-                if (featuresStd(index) != 0.0) regParamL1 / featuresStd(index) else 0.0
+                if (featuresStd(featureIndex) != 0.0) {
+                  regParamL1 / featuresStd(featureIndex)
+                } else {
+                  0.0
+                }
               }
             }
           }
@@ -397,22 +467,67 @@ class LogisticRegression @Since("1.2.0") (
         }
 
         val initialCoefficientsWithIntercept =
-          Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures)
-
-        if (optInitialModel.isDefined && optInitialModel.get.coefficients.size != numFeatures) {
-          val vecSize = optInitialModel.get.coefficients.size
-          logWarning(
-            s"Initial coefficients will be ignored!! As its size $vecSize did not match the " +
-            s"expected size $numFeatures")
+          Vectors.zeros(numCoefficientSets * numFeaturesPlusIntercept)
+
+        val initialModelIsValid = optInitialModel match {
+          case Some(_initialModel) =>
+            val providedCoefs = _initialModel.coefficientMatrix
+            val modelIsValid = (providedCoefs.numRows == numCoefficientSets) &&
+              (providedCoefs.numCols == numFeatures) &&
+              (_initialModel.interceptVector.size == numCoefficientSets) &&
+              (_initialModel.getFitIntercept == $(fitIntercept))
+            if (!modelIsValid) {
+              logWarning(s"Initial coefficients will be ignored! Its dimensions " +
+                s"(${providedCoefs.numRows}, ${providedCoefs.numCols}) did not match the " +
+                s"expected size ($numCoefficientSets, $numFeatures)")
+            }
+            modelIsValid
+          case None => false
         }
 
-        if (optInitialModel.isDefined && optInitialModel.get.coefficients.size == numFeatures) {
-          val initialCoefficientsWithInterceptArray = initialCoefficientsWithIntercept.toArray
-          optInitialModel.get.coefficients.foreachActive { case (index, value) =>
-            initialCoefficientsWithInterceptArray(index) = value
+        if (initialModelIsValid) {
+          val initialCoefWithInterceptArray = initialCoefficientsWithIntercept.toArray
+          val providedCoef = optInitialModel.get.coefficientMatrix
+          providedCoef.foreachActive { (row, col, value) =>
+            val flatIndex = row * numFeaturesPlusIntercept + col
+            // We need to scale the coefficients since they will be trained in the scaled space
+            initialCoefWithInterceptArray(flatIndex) = value * featuresStd(col)
           }
           if ($(fitIntercept)) {
-            initialCoefficientsWithInterceptArray(numFeatures) == optInitialModel.get.intercept
+            optInitialModel.get.interceptVector.foreachActive { (index, value) =>
+              val coefIndex = (index + 1) * numFeaturesPlusIntercept - 1
+              initialCoefWithInterceptArray(coefIndex) = value
+            }
+          }
+        } else if ($(fitIntercept) && isMultinomial) {
+          /*
+             For multinomial logistic regression, when we initialize the coefficients as zeros,
+             it will converge faster if we initialize the intercepts such that
+             it follows the distribution of the labels.
+             {{{
+               P(1) = \exp(b_1) / Z
+               ...
+               P(K) = \exp(b_K) / Z
+               where Z = \sum_{k=1}^{K} \exp(b_k)
+             }}}
+             Since this doesn't have a unique solution, one of the solutions that satisfies the
+             above equations is
+             {{{
+               \exp(b_k) = count_k * \exp(\lambda)
+               b_k = \log(count_k) * \lambda
+             }}}
+             \lambda is a free parameter, so choose the phase \lambda such that the
+             mean is centered. This yields
+             {{{
+               b_k = \log(count_k)
+               b_k' = b_k - \mean(b_k)
+             }}}
+           */
+          val rawIntercepts = histogram.map(c => math.log(c + 1)) // add 1 for smoothing
+          val rawMean = rawIntercepts.sum / rawIntercepts.length
+          rawIntercepts.indices.foreach { i =>
+            initialCoefficientsWithIntercept.toArray(i * numFeaturesPlusIntercept + numFeatures) =
+              rawIntercepts(i) - rawMean
           }
         } else if ($(fitIntercept)) {
           /*
@@ -446,6 +561,7 @@ class LogisticRegression @Since("1.2.0") (
           state = states.next()
           arrayBuilder += state.adjustedValue
         }
+        bcFeaturesStd.destroy(blocking = false)
 
         if (state == null) {
           val msg = s"${optimizer.getClass.getName} failed."
@@ -460,33 +576,85 @@ class LogisticRegression @Since("1.2.0") (
            as a result, no scaling is needed.
          */
         val rawCoefficients = state.x.toArray.clone()
-        var i = 0
-        while (i < numFeatures) {
-          rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
-          i += 1
+        val coefficientArray = Array.tabulate(numCoefficientSets * numFeatures) { i =>
+          // flatIndex will loop though rawCoefficients, and skip the intercept terms.
+          val flatIndex = if ($(fitIntercept)) i + i / numFeatures else i
+          val featureIndex = i % numFeatures
+          if (featuresStd(featureIndex) != 0.0) {
+            rawCoefficients(flatIndex) / featuresStd(featureIndex)
+          } else {
+            0.0
+          }
+        }
+
+        if ($(regParam) == 0.0 && isMultinomial) {
+          /*
+            When no regularization is applied, the multinomial coefficients lack identifiability
+            because we do not use a pivot class. We can add any constant value to the coefficients
+            and get the same likelihood. So here, we choose the mean centered coefficients for
+            reproducibility. This method follows the approach in glmnet, described here:
+
+            Friedman, et al. "Regularization Paths for Generalized Linear Models via
+              Coordinate Descent," https://core.ac.uk/download/files/153/6287975.pdf
+           */
+          val coefficientMean = coefficientArray.sum / coefficientArray.length
+          coefficientArray.indices.foreach { i => coefficientArray(i) -= coefficientMean}
         }
-        bcFeaturesStd.destroy(blocking = false)
 
-        if ($(fitIntercept)) {
-          (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last,
-            arrayBuilder.result())
+        val denseCoefficientMatrix =
+          new DenseMatrix(numCoefficientSets, numFeatures, coefficientArray, isTransposed = true)
+        // TODO: use `denseCoefficientMatrix.compressed` after SPARK-17471
+        val compressedCoefficientMatrix = if (isMultinomial) {
+          denseCoefficientMatrix
         } else {
-          (Vectors.dense(rawCoefficients).compressed, 0.0, arrayBuilder.result())
+          val compressedVector = Vectors.dense(coefficientArray).compressed
+          compressedVector match {
+            case dv: DenseVector => denseCoefficientMatrix
+            case sv: SparseVector =>
+              new SparseMatrix(1, numFeatures, Array(0, sv.indices.length), sv.indices, sv.values,
+                isTransposed = true)
+          }
         }
+
+        val interceptsArray: Array[Double] = if ($(fitIntercept)) {
+          Array.tabulate(numCoefficientSets) { i =>
+            val coefIndex = (i + 1) * numFeaturesPlusIntercept - 1
+            rawCoefficients(coefIndex)
+          }
+        } else {
+          Array[Double]()
+        }
+        val interceptVector = if (interceptsArray.nonEmpty && isMultinomial) {
+          // The intercepts are never regularized, so we always center the mean.
+          val interceptMean = interceptsArray.sum / numClasses
+          interceptsArray.indices.foreach { i => interceptsArray(i) -= interceptMean }
+          Vectors.dense(interceptsArray)
+        } else if (interceptsArray.length == 1) {
+          Vectors.dense(interceptsArray)
+        } else {
+          Vectors.sparse(numCoefficientSets, Seq())
+        }
+        (compressedCoefficientMatrix, interceptVector.compressed, arrayBuilder.result())
       }
     }
 
     if (handlePersistence) instances.unpersist()
 
-    val model = copyValues(new LogisticRegressionModel(uid, coefficients, intercept))
-    val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol()
-    val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
-      summaryModel.transform(dataset),
-      probabilityColName,
-      $(labelCol),
-      $(featuresCol),
-      objectiveHistory)
-    val m = model.setSummary(logRegSummary)
+    val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
+      numClasses, isMultinomial))
+    // TODO: implement summary model for multinomial case
+    val m = if (!isMultinomial) {
+      val (summaryModel, probabilityColName) = model.findSummaryModelAndProbabilityCol()
+      val logRegSummary = new BinaryLogisticRegressionTrainingSummary(
+        summaryModel.transform(dataset),
+        probabilityColName,
+        $(labelCol),
+        $(featuresCol),
+        objectiveHistory)
+      model.setSummary(logRegSummary)
+    } else {
+      model
+    }
     instr.logSuccess(m)
     m
   }
@@ -500,6 +668,9 @@ object LogisticRegression extends DefaultParamsReadable[LogisticRegression] {
 
   @Since("1.6.0")
   override def load(path: String): LogisticRegression = super.load(path)
+
+  private[classification] val supportedFamilyNames =
+    Array("auto", "binomial", "multinomial").map(_.toLowerCase)
 }
 
 /**
@@ -508,11 +679,59 @@ object LogisticRegression extends DefaultParamsReadable[LogisticRegression] {
 @Since("1.4.0")
 class LogisticRegressionModel private[spark] (
     @Since("1.4.0") override val uid: String,
-    @Since("2.0.0") val coefficients: Vector,
-    @Since("1.3.0") val intercept: Double)
+    @Since("2.1.0") val coefficientMatrix: Matrix,
+    @Since("2.1.0") val interceptVector: Vector,
+    @Since("1.3.0") override val numClasses: Int,
+    private val isMultinomial: Boolean)
   extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
   with LogisticRegressionParams with MLWritable {
 
+  require(coefficientMatrix.numRows == interceptVector.size, s"Dimension mismatch! Expected " +
+    s"coefficientMatrix.numRows == interceptVector.size, but ${coefficientMatrix.numRows} != " +
+    s"${interceptVector.size}")
+
+  private[spark] def this(uid: String, coefficients: Vector, intercept: Double) =
+    this(uid, new DenseMatrix(1, coefficients.size, coefficients.toArray, isTransposed = true),
+      Vectors.dense(intercept), 2, isMultinomial = false)
+
+  /**
+   * A vector of model coefficients for "binomial" logistic regression. If this model was trained
+   * using the "multinomial" family then an exception is thrown.
+   * @return Vector
+   */
+  @Since("2.0.0")
+  def coefficients: Vector = if (isMultinomial) {
+    throw new SparkException("Multinomial models contain a matrix of coefficients, use " +
+      "coefficientMatrix instead.")
+  } else {
+    _coefficients
+  }
+
+  // convert to appropriate vector representation without replicating data
+  private lazy val _coefficients: Vector = {
+    require(coefficientMatrix.isTransposed,
+      "LogisticRegressionModel coefficients should be row major.")
+    coefficientMatrix match {
+      case dm: DenseMatrix => Vectors.dense(dm.values)
+      case sm: SparseMatrix => Vectors.sparse(coefficientMatrix.numCols, sm.rowIndices, sm.values)
+    }
+  }
+
+  /**
+   * The model intercept for "binomial" logistic regression. If this model was fit with the
+   * "multinomial" family then an exception is thrown.
+   * @return Double
+   */
+  @Since("1.3.0")
+  def intercept: Double = if (isMultinomial) {
+    throw new SparkException("Multinomial models contain a vector of intercepts, use " +
+      "interceptVector instead.")
+  } else {
+    _intercept
+  }
+
+  private lazy val _intercept = interceptVector.toArray.head
+
   @Since("1.5.0")
   override def setThreshold(value: Double): this.type = super.setThreshold(value)
 
@@ -527,7 +746,14 @@ class LogisticRegressionModel private[spark] (
 
   /** Margin (rawPrediction) for class label 1.  For binary classification only. */
   private val margin: Vector => Double = (features) => {
-    BLAS.dot(features, coefficients) + intercept
+    BLAS.dot(features, _coefficients) + _intercept
+  }
+
+  /** Margin (rawPrediction) for each class label. */
+  private val margins: Vector => Vector = (features) => {
+    val m = interceptVector.toDense.copy
+    BLAS.gemv(1.0, coefficientMatrix, features, 1.0, m)
+    m
   }
 
   /** Score (probability) for class label 1.  For binary classification only. */
@@ -537,10 +763,7 @@ class LogisticRegressionModel private[spark] (
   }
 
   @Since("1.6.0")
-  override val numFeatures: Int = coefficients.size
-
-  @Since("1.3.0")
-  override val numClasses: Int = 2
+  override val numFeatures: Int = coefficientMatrix.numCols
 
   private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None
 
@@ -595,7 +818,9 @@ class LogisticRegressionModel private[spark] (
    * Predict label for the given feature vector.
    * The behavior of this can be adjusted using [[thresholds]].
    */
-  override protected def predict(features: Vector): Double = {
+  override protected def predict(features: Vector): Double = if (isMultinomial) {
+    super.predict(features)
+  } else {
     // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
     if (score(features) > getThreshold) 1 else 0
   }
@@ -603,13 +828,47 @@ class LogisticRegressionModel private[spark] (
   override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
     rawPrediction match {
       case dv: DenseVector =>
-        var i = 0
-        val size = dv.size
-        while (i < size) {
-          dv.values(i) = 1.0 / (1.0 + math.exp(-dv.values(i)))
-          i += 1
+        if (isMultinomial) {
+          val size = dv.size
+          val values = dv.values
+
+          // get the maximum margin
+          val maxMarginIndex = rawPrediction.argmax
+          val maxMargin = rawPrediction(maxMarginIndex)
+
+          if (maxMargin == Double.PositiveInfinity) {
+            var k = 0
+            while (k < size) {
+              values(k) = if (k == maxMarginIndex) 1.0 else 0.0
+              k += 1
+            }
+          } else {
+            val sum = {
+              var temp = 0.0
+              var k = 0
+              while (k < numClasses) {
+                values(k) = if (maxMargin > 0) {
+                  math.exp(values(k) - maxMargin)
+                } else {
+                  math.exp(values(k))
+                }
+                temp += values(k)
+                k += 1
+              }
+              temp
+            }
+            BLAS.scal(1 / sum, dv)
+          }
+          dv
+        } else {
+          var i = 0
+          val size = dv.size
+          while (i < size) {
+            dv.values(i) = 1.0 / (1.0 + math.exp(-dv.values(i)))
+            i += 1
+          }
+          dv
         }
-        dv
       case sv: SparseVector =>
         throw new RuntimeException("Unexpected error in LogisticRegressionModel:" +
           " raw2probabilitiesInPlace encountered SparseVector")
@@ -617,33 +876,46 @@ class LogisticRegressionModel private[spark] (
   }
 
   override protected def predictRaw(features: Vector): Vector = {
-    val m = margin(features)
-    Vectors.dense(-m, m)
+    if (isMultinomial) {
+      margins(features)
+    } else {
+      val m = margin(features)
+      Vectors.dense(-m, m)
+    }
   }
 
   @Since("1.4.0")
   override def copy(extra: ParamMap): LogisticRegressionModel = {
-    val newModel = copyValues(new LogisticRegressionModel(uid, coefficients, intercept), extra)
+    val newModel = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
+      numClasses, isMultinomial), extra)
     if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
     newModel.setParent(parent)
   }
 
   override protected def raw2prediction(rawPrediction: Vector): Double = {
-    // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
-    val t = getThreshold
-    val rawThreshold = if (t == 0.0) {
-      Double.NegativeInfinity
-    } else if (t == 1.0) {
-      Double.PositiveInfinity
+    if (isMultinomial) {
+      super.raw2prediction(rawPrediction)
     } else {
-      math.log(t / (1.0 - t))
+      // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
+      val t = getThreshold
+      val rawThreshold = if (t == 0.0) {
+        Double.NegativeInfinity
+      } else if (t == 1.0) {
+        Double.PositiveInfinity
+      } else {
+        math.log(t / (1.0 - t))
+      }
+      if (rawPrediction(1) > rawThreshold) 1 else 0
     }
-    if (rawPrediction(1) > rawThreshold) 1 else 0
   }
 
   override protected def probability2prediction(probability: Vector): Double = {
-    // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
-    if (probability(1) > getThreshold) 1 else 0
+    if (isMultinomial) {
+      super.probability2prediction(probability)
+    } else {
+      // Note: We should use getThreshold instead of $(threshold) since getThreshold is overridden.
+      if (probability(1) > getThreshold) 1 else 0
+    }
   }
 
   /**
@@ -676,39 +948,53 @@ object LogisticRegressionModel extends MLReadable[LogisticRegressionModel] {
     private case class Data(
         numClasses: Int,
         numFeatures: Int,
-        intercept: Double,
-        coefficients: Vector)
+        interceptVector: Vector,
+        coefficientMatrix: Matrix,
+        isMultinomial: Boolean)
 
     override protected def saveImpl(path: String): Unit = {
       // Save metadata and Params
       DefaultParamsWriter.saveMetadata(instance, path, sc)
       // Save model data: numClasses, numFeatures, intercept, coefficients
-      val data = Data(instance.numClasses, instance.numFeatures, instance.intercept,
-        instance.coefficients)
+      val data = Data(instance.numClasses, instance.numFeatures, instance.interceptVector,
+        instance.coefficientMatrix, instance.isMultinomial)
       val dataPath = new Path(path, "data").toString
       sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
   }
 
-  private class LogisticRegressionModelReader
-    extends MLReader[LogisticRegressionModel] {
+  private class LogisticRegressionModelReader extends MLReader[LogisticRegressionModel] {
 
     /** Checked against metadata when loading model */
     private val className = classOf[LogisticRegressionModel].getName
 
     override def load(path: String): LogisticRegressionModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+      val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
 
       val dataPath = new Path(path, "data").toString
       val data = sparkSession.read.format("parquet").load(dataPath)
 
-      // We will need numClasses, numFeatures in the future for multinomial logreg support.
-      // TODO: remove numClasses and numFeatures fields?
-      val Row(numClasses: Int, numFeatures: Int, intercept: Double, coefficients: Vector) =
-        MLUtils.convertVectorColumnsToML(data, "coefficients")
-          .select("numClasses", "numFeatures", "intercept", "coefficients")
-          .head()
-      val model = new LogisticRegressionModel(metadata.uid, coefficients, intercept)
+      val model = if (major.toInt < 2 || (major.toInt == 2 && minor.toInt == 0)) {
+        // 2.0 and before
+        val Row(numClasses: Int, numFeatures: Int, intercept: Double, coefficients: Vector) =
+          MLUtils.convertVectorColumnsToML(data, "coefficients")
+            .select("numClasses", "numFeatures", "intercept", "coefficients")
+            .head()
+        val coefficientMatrix =
+          new DenseMatrix(1, coefficients.size, coefficients.toArray, isTransposed = true)
+        val interceptVector = Vectors.dense(intercept)
+        new LogisticRegressionModel(metadata.uid, coefficientMatrix,
+          interceptVector, numClasses, isMultinomial = false)
+      } else {
+        // 2.1+
+        val Row(numClasses: Int, numFeatures: Int, interceptVector: Vector,
+        coefficientMatrix: Matrix, isMultinomial: Boolean) = data
+          .select("numClasses", "numFeatures", "interceptVector", "coefficientMatrix",
+            "isMultinomial").head()
+        new LogisticRegressionModel(metadata.uid, coefficientMatrix, interceptVector,
+          numClasses, isMultinomial)
+      }
 
       DefaultParamsReader.getAndSetParams(model, metadata)
       model

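The multinomial branch of the probability conversion in the hunk above is a softmax that subtracts the largest margin before exponentiating. The following is a minimal standalone sketch of that idea on a plain `Array[Double]`, not the code in the commit; `SoftmaxSketch` and `softmaxInPlace` are hypothetical names introduced for illustration, and the final loop stands in for the `BLAS.scal` call.

```scala
object SoftmaxSketch {
  /** Converts raw margins to probabilities in place and returns the same array. */
  def softmaxInPlace(margins: Array[Double]): Array[Double] = {
    require(margins.nonEmpty, "margins must be non-empty")
    val maxIndex = margins.indices.maxBy(i => margins(i))
    val maxMargin = margins(maxIndex)
    if (maxMargin == Double.PositiveInfinity) {
      // exp(inf) / exp(inf) would be NaN, so put all mass on the infinite margin.
      var k = 0
      while (k < margins.length) {
        margins(k) = if (k == maxIndex) 1.0 else 0.0
        k += 1
      }
    } else {
      // Shift only when the largest margin is positive, mirroring the diff.
      val shift = if (maxMargin > 0) maxMargin else 0.0
      var sum = 0.0
      var k = 0
      while (k < margins.length) {
        margins(k) = math.exp(margins(k) - shift)
        sum += margins(k)
        k += 1
      }
      // Normalize so the entries sum to one (stand-in for BLAS.scal(1 / sum, dv)).
      k = 0
      while (k < margins.length) {
        margins(k) /= sum
        k += 1
      }
    }
    margins
  }
}
```

Without the shift, a margin of 1000 would overflow `math.exp` to infinity; with it, `SoftmaxSketch.softmaxInPlace(Array(1000.0, 1000.0))` yields two equal probabilities.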
http://git-wip-us.apache.org/repos/asf/spark/blob/26145a5a/mllib/src/main/scala/org/apache/spark/ml/classification/MultinomialLogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultinomialLogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultinomialLogisticRegression.scala
deleted file mode 100644
index 006f57c..0000000
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultinomialLogisticRegression.scala
+++ /dev/null
@@ -1,632 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.classification
-
-import scala.collection.mutable
-
-import breeze.linalg.{DenseVector => BDV}
-import breeze.optimize.{CachedDiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
-import org.apache.hadoop.fs.Path
-
-import org.apache.spark.SparkException
-import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.internal.Logging
-import org.apache.spark.ml.feature.Instance
-import org.apache.spark.ml.linalg._
-import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util._
-import org.apache.spark.mllib.linalg.VectorImplicits._
-import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions.{col, lit}
-import org.apache.spark.sql.types.DoubleType
-import org.apache.spark.storage.StorageLevel
-
-/**
- * Params for multinomial logistic (softmax) regression.
- */
-private[classification] trait MultinomialLogisticRegressionParams
-  extends ProbabilisticClassifierParams with HasRegParam with HasElasticNetParam with HasMaxIter
-    with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
-    with HasAggregationDepth {
-
-  /**
-   * Set thresholds in multiclass (or binary) classification to adjust the probability of
-   * predicting each class. Array must have length equal to the number of classes, with values >= 0.
-   * The class with largest value p/t is predicted, where p is the original probability of that
-   * class and t is the class' threshold.
-   *
-   * @group setParam
-   */
-  def setThresholds(value: Array[Double]): this.type = {
-    set(thresholds, value)
-  }
-
-  /**
-   * Get thresholds for binary or multiclass classification.
-   *
-   * @group getParam
-   */
-  override def getThresholds: Array[Double] = {
-    $(thresholds)
-  }
-}
-
-/**
- * :: Experimental ::
- * Multinomial Logistic (softmax) regression.
- */
-@Since("2.1.0")
-@Experimental
-class MultinomialLogisticRegression @Since("2.1.0") (
-    @Since("2.1.0") override val uid: String)
-  extends ProbabilisticClassifier[Vector,
-    MultinomialLogisticRegression, MultinomialLogisticRegressionModel]
-    with MultinomialLogisticRegressionParams with DefaultParamsWritable with Logging {
-
-  @Since("2.1.0")
-  def this() = this(Identifiable.randomUID("mlogreg"))
-
-  /**
-   * Set the regularization parameter.
-   * Default is 0.0.
-   *
-   * @group setParam
-   */
-  @Since("2.1.0")
-  def setRegParam(value: Double): this.type = set(regParam, value)
-  setDefault(regParam -> 0.0)
-
-  /**
-   * Set the ElasticNet mixing parameter.
-   * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
-   * For 0 < alpha < 1, the penalty is a combination of L1 and L2.
-   * Default is 0.0 which is an L2 penalty.
-   *
-   * @group setParam
-   */
-  @Since("2.1.0")
-  def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value)
-  setDefault(elasticNetParam -> 0.0)
-
-  /**
-   * Set the maximum number of iterations.
-   * Default is 100.
-   *
-   * @group setParam
-   */
-  @Since("2.1.0")
-  def setMaxIter(value: Int): this.type = set(maxIter, value)
-  setDefault(maxIter -> 100)
-
-  /**
-   * Set the convergence tolerance of iterations.
-   * Smaller value will lead to higher accuracy with the cost of more iterations.
-   * Default is 1E-6.
-   *
-   * @group setParam
-   */
-  @Since("2.1.0")
-  def setTol(value: Double): this.type = set(tol, value)
-  setDefault(tol -> 1E-6)
-
-  /**
-   * Whether to fit an intercept term.
-   * Default is true.
-   *
-   * @group setParam
-   */
-  @Since("2.1.0")
-  def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
-  setDefault(fitIntercept -> true)
-
-  /**
-   * Whether to standardize the training features before fitting the model.
-   * The coefficients of models will be always returned on the original scale,
-   * so it will be transparent for users. Note that with/without standardization,
-   * the models should always converge to the same solution when no regularization
-   * is applied. In R's GLMNET package, the default behavior is true as well.
-   * Default is true.
-   *
-   * @group setParam
-   */
-  @Since("2.1.0")
-  def setStandardization(value: Boolean): this.type = set(standardization, value)
-  setDefault(standardization -> true)
-
-  /**
-   * Sets the value of param [[weightCol]].
-   * If this is not set or empty, we treat all instance weights as 1.0.
-   * Default is not set, so all instances have weight one.
-   *
-   * @group setParam
-   */
-  @Since("2.1.0")
-  def setWeightCol(value: String): this.type = set(weightCol, value)
-
-  @Since("2.1.0")
-  override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
-
-  /**
-   * Suggested depth for treeAggregate (>= 2).
-   * If the dimensions of features or the number of partitions are large,
-   * this param could be adjusted to a larger size.
-   * Default is 2.
-   * @group expertSetParam
-   */
-  @Since("2.1.0")
-  def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
-  setDefault(aggregationDepth -> 2)
-
-  override protected[spark] def train(dataset: Dataset[_]): MultinomialLogisticRegressionModel = {
-    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
-    val instances: RDD[Instance] =
-      dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
-        case Row(label: Double, weight: Double, features: Vector) =>
-          Instance(label, weight, features)
-      }
-
-    val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
-    if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
-
-    val instr = Instrumentation.create(this, instances)
-    instr.logParams(regParam, elasticNetParam, standardization, thresholds,
-      maxIter, tol, fitIntercept)
-
-    val (summarizer, labelSummarizer) = {
-      val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
-       instance: Instance) =>
-        (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight))
-
-      val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer),
-        c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) =>
-          (c1._1.merge(c2._1), c1._2.merge(c2._2))
-
-      instances.treeAggregate(
-        new MultivariateOnlineSummarizer, new MultiClassSummarizer)(seqOp, combOp)
-    }
-
-    val histogram = labelSummarizer.histogram
-    val numInvalid = labelSummarizer.countInvalid
-    val numFeatures = summarizer.mean.size
-    val numFeaturesPlusIntercept = if (getFitIntercept) numFeatures + 1 else numFeatures
-
-    val numClasses = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
-      case Some(n: Int) =>
-        require(n >= histogram.length, s"Specified number of classes $n was " +
-          s"less than the number of unique labels ${histogram.length}")
-        n
-      case None => histogram.length
-    }
-
-    instr.logNumClasses(numClasses)
-    instr.logNumFeatures(numFeatures)
-
-    val (coefficients, intercepts, objectiveHistory) = {
-      if (numInvalid != 0) {
-        val msg = s"Classification labels should be in {0 to ${numClasses - 1} " +
-          s"Found $numInvalid invalid labels."
-        logError(msg)
-        throw new SparkException(msg)
-      }
-
-      val isConstantLabel = histogram.count(_ != 0) == 1
-
-      if ($(fitIntercept) && isConstantLabel) {
-        // we want to produce a model that will always predict the constant label so all the
-        // coefficients will be zero, and the constant label class intercept will be +inf
-        val constantLabelIndex = Vectors.dense(histogram).argmax
-        (Matrices.sparse(numClasses, numFeatures, Array.fill(numFeatures + 1)(0),
-          Array.empty[Int], Array.empty[Double]),
-          Vectors.sparse(numClasses, Seq((constantLabelIndex, Double.PositiveInfinity))),
-          Array.empty[Double])
-      } else {
-        if (!$(fitIntercept) && isConstantLabel) {
-          logWarning(s"All labels belong to a single class and fitIntercept=false. It's" +
-            s"a dangerous ground, so the algorithm may not converge.")
-        }
-
-        val featuresStd = summarizer.variance.toArray.map(math.sqrt)
-        val featuresMean = summarizer.mean.toArray
-        if (!$(fitIntercept) && (0 until numFeatures).exists { i =>
-          featuresStd(i) == 0.0 && featuresMean(i) != 0.0 }) {
-          logWarning("Fitting MultinomialLogisticRegressionModel without intercept on dataset " +
-            "with constant nonzero column, Spark MLlib outputs zero coefficients for constant " +
-            "nonzero columns. This behavior is the same as R glmnet but different from LIBSVM.")
-        }
-
-        val regParamL1 = $(elasticNetParam) * $(regParam)
-        val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
-
-        val bcFeaturesStd = instances.context.broadcast(featuresStd)
-        val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
-          $(standardization), bcFeaturesStd, regParamL2, multinomial = true, $(aggregationDepth))
-
-        val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
-          new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
-        } else {
-          val standardizationParam = $(standardization)
-          def regParamL1Fun = (index: Int) => {
-            // Remove the L1 penalization on the intercept
-            val isIntercept = $(fitIntercept) && ((index + 1) % numFeaturesPlusIntercept == 0)
-            if (isIntercept) {
-              0.0
-            } else {
-              if (standardizationParam) {
-                regParamL1
-              } else {
-                val featureIndex = if ($(fitIntercept)) {
-                  index % numFeaturesPlusIntercept
-                } else {
-                  index % numFeatures
-                }
-                // If `standardization` is false, we still standardize the data
-                // to improve the rate of convergence; as a result, we have to
-                // perform this reverse standardization by penalizing each component
-                // differently to get effectively the same objective function when
-                // the training dataset is not standardized.
-                if (featuresStd(featureIndex) != 0.0) {
-                  regParamL1 / featuresStd(featureIndex)
-                } else {
-                  0.0
-                }
-              }
-            }
-          }
-          new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
-        }
-
-        val initialCoefficientsWithIntercept = Vectors.zeros(numClasses * numFeaturesPlusIntercept)
-
-        if ($(fitIntercept)) {
-          /*
-             For multinomial logistic regression, when we initialize the coefficients as zeros,
-             it will converge faster if we initialize the intercepts such that
-             it follows the distribution of the labels.
-             {{{
-               P(1) = \exp(b_1) / Z
-               ...
-               P(K) = \exp(b_K) / Z
-               where Z = \sum_{k=1}^{K} \exp(b_k)
-             }}}
-             Since this doesn't have a unique solution, one of the solutions that satisfies the
-             above equations is
-             {{{
-               \exp(b_k) = count_k * \exp(\lambda)
-               b_k = \log(count_k) + \lambda
-             }}}
-             \lambda is a free parameter, so choose the phase \lambda such that the
-             mean is centered. This yields
-             {{{
-               b_k = \log(count_k)
-               b_k' = b_k - \mean(b_k)
-             }}}
-           */
-          val rawIntercepts = histogram.map(c => math.log(c + 1)) // add 1 for smoothing
-          val rawMean = rawIntercepts.sum / rawIntercepts.length
-          rawIntercepts.indices.foreach { i =>
-            initialCoefficientsWithIntercept.toArray(i * numFeaturesPlusIntercept + numFeatures) =
-              rawIntercepts(i) - rawMean
-          }
-        }
-
-        val states = optimizer.iterations(new CachedDiffFunction(costFun),
-          initialCoefficientsWithIntercept.asBreeze.toDenseVector)
-
-        /*
-           Note that in Multinomial Logistic Regression, the objective history
-           (loss + regularization) is log-likelihood which is invariant under feature
-           standardization. As a result, the objective history from optimizer is the same as the
-           one in the original space.
-         */
-        val arrayBuilder = mutable.ArrayBuilder.make[Double]
-        var state: optimizer.State = null
-        while (states.hasNext) {
-          state = states.next()
-          arrayBuilder += state.adjustedValue
-        }
-
-        if (state == null) {
-          val msg = s"${optimizer.getClass.getName} failed."
-          logError(msg)
-          throw new SparkException(msg)
-        }
-        bcFeaturesStd.destroy(blocking = false)
-
-        /*
-           The coefficients are trained in the scaled space; we're converting them back to
-           the original space.
-           Note that the intercept in scaled space and original space is the same;
-           as a result, no scaling is needed.
-         */
-        val rawCoefficients = state.x.toArray
-        val interceptsArray: Array[Double] = if ($(fitIntercept)) {
-          Array.tabulate(numClasses) { i =>
-            val coefIndex = (i + 1) * numFeaturesPlusIntercept - 1
-            rawCoefficients(coefIndex)
-          }
-        } else {
-          Array.empty
-        }
-
-        val coefficientArray: Array[Double] = Array.tabulate(numClasses * numFeatures) { i =>
-          // flatIndex will loop through rawCoefficients, and skip the intercept terms.
-          val flatIndex = if ($(fitIntercept)) i + i / numFeatures else i
-          val featureIndex = i % numFeatures
-          if (featuresStd(featureIndex) != 0.0) {
-            rawCoefficients(flatIndex) / featuresStd(featureIndex)
-          } else {
-            0.0
-          }
-        }
-        val coefficientMatrix =
-          new DenseMatrix(numClasses, numFeatures, coefficientArray, isTransposed = true)
-
-        /*
-          When no regularization is applied, the coefficients lack identifiability because
-          we do not use a pivot class. We can add any constant value to the coefficients and
-          get the same likelihood. So here, we choose the mean centered coefficients for
-          reproducibility. This method follows the approach in glmnet, described here:
-
-          Friedman, et al. "Regularization Paths for Generalized Linear Models via
-            Coordinate Descent," https://core.ac.uk/download/files/153/6287975.pdf
-         */
-        if ($(regParam) == 0.0) {
-          val coefficientMean = coefficientMatrix.values.sum / (numClasses * numFeatures)
-          coefficientMatrix.update(_ - coefficientMean)
-        }
-        /*
-          The intercepts are never regularized, so we always center the mean.
-         */
-        val interceptVector = if (interceptsArray.nonEmpty) {
-          val interceptMean = interceptsArray.sum / numClasses
-          interceptsArray.indices.foreach { i => interceptsArray(i) -= interceptMean }
-          Vectors.dense(interceptsArray)
-        } else {
-          Vectors.sparse(numClasses, Seq())
-        }
-
-        (coefficientMatrix, interceptVector, arrayBuilder.result())
-      }
-    }
-
-    if (handlePersistence) instances.unpersist()
-
-    val model = copyValues(
-      new MultinomialLogisticRegressionModel(uid, coefficients, intercepts, numClasses))
-    instr.logSuccess(model)
-    model
-  }
-
-  @Since("2.1.0")
-  override def copy(extra: ParamMap): MultinomialLogisticRegression = defaultCopy(extra)
-}
-
-@Since("2.1.0")
-object MultinomialLogisticRegression extends DefaultParamsReadable[MultinomialLogisticRegression] {
-
-  @Since("2.1.0")
-  override def load(path: String): MultinomialLogisticRegression = super.load(path)
-}
-
-/**
- * :: Experimental ::
- * Model produced by [[MultinomialLogisticRegression]].
- */
-@Since("2.1.0")
-@Experimental
-class MultinomialLogisticRegressionModel private[spark] (
-    @Since("2.1.0") override val uid: String,
-    @Since("2.1.0") val coefficients: Matrix,
-    @Since("2.1.0") val intercepts: Vector,
-    @Since("2.1.0") val numClasses: Int)
-  extends ProbabilisticClassificationModel[Vector, MultinomialLogisticRegressionModel]
-    with MultinomialLogisticRegressionParams with MLWritable {
-
-  @Since("2.1.0")
-  override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
-
-  @Since("2.1.0")
-  override def getThresholds: Array[Double] = super.getThresholds
-
-  @Since("2.1.0")
-  override val numFeatures: Int = coefficients.numCols
-
-  /** Margin (rawPrediction) for each class label. */
-  private val margins: Vector => Vector = (features) => {
-    val m = intercepts.toDense.copy
-    BLAS.gemv(1.0, coefficients, features, 1.0, m)
-    m
-  }
-
-  /** Score (probability) for each class label. */
-  private val scores: Vector => Vector = (features) => {
-    val m = margins(features)
-    val maxMarginIndex = m.argmax
-    val marginArray = m.toArray
-    val maxMargin = marginArray(maxMarginIndex)
-
-    // adjust margins for overflow
-    val sum = {
-      var temp = 0.0
-      var k = 0
-      while (k < numClasses) {
-        marginArray(k) = if (maxMargin > 0) {
-          math.exp(marginArray(k) - maxMargin)
-        } else {
-          math.exp(marginArray(k))
-        }
-        temp += marginArray(k)
-        k += 1
-      }
-      temp
-    }
-
-    val scores = Vectors.dense(marginArray)
-    BLAS.scal(1 / sum, scores)
-    scores
-  }
-
-  /**
-   * Predict label for the given feature vector.
-   * The behavior of this can be adjusted using [[thresholds]].
-   */
-  override protected def predict(features: Vector): Double = {
-    if (isDefined(thresholds)) {
-      val thresholds: Array[Double] = getThresholds
-      val probabilities = scores(features).toArray
-      var argMax = 0
-      var max = Double.NegativeInfinity
-      var i = 0
-      while (i < numClasses) {
-        if (thresholds(i) == 0.0) {
-          max = Double.PositiveInfinity
-          argMax = i
-        } else {
-          val scaled = probabilities(i) / thresholds(i)
-          if (scaled > max) {
-            max = scaled
-            argMax = i
-          }
-        }
-        i += 1
-      }
-      argMax
-    } else {
-      scores(features).argmax
-    }
-  }
-
-  override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
-    rawPrediction match {
-      case dv: DenseVector =>
-        val size = dv.size
-        val values = dv.values
-
-        // get the maximum margin
-        val maxMarginIndex = rawPrediction.argmax
-        val maxMargin = rawPrediction(maxMarginIndex)
-
-        if (maxMargin == Double.PositiveInfinity) {
-          var k = 0
-          while (k < size) {
-            values(k) = if (k == maxMarginIndex) 1.0 else 0.0
-            k += 1
-          }
-        } else {
-          val sum = {
-            var temp = 0.0
-            var k = 0
-            while (k < numClasses) {
-              values(k) = if (maxMargin > 0) {
-                math.exp(values(k) - maxMargin)
-              } else {
-                math.exp(values(k))
-              }
-              temp += values(k)
-              k += 1
-            }
-            temp
-          }
-          BLAS.scal(1 / sum, dv)
-        }
-        dv
-      case sv: SparseVector =>
-        throw new RuntimeException("Unexpected error in MultinomialLogisticRegressionModel:" +
-          " raw2probabilitiesInPlace encountered SparseVector")
-    }
-  }
-
-  override protected def predictRaw(features: Vector): Vector = margins(features)
-
-  @Since("2.1.0")
-  override def copy(extra: ParamMap): MultinomialLogisticRegressionModel = {
-    val newModel =
-      copyValues(
-        new MultinomialLogisticRegressionModel(uid, coefficients, intercepts, numClasses), extra)
-    newModel.setParent(parent)
-  }
-
-  /**
-   * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance.
-   *
-   * This does not save the [[parent]] currently.
-   */
-  @Since("2.1.0")
-  override def write: MLWriter =
-    new MultinomialLogisticRegressionModel.MultinomialLogisticRegressionModelWriter(this)
-}
-
-
-@Since("2.1.0")
-object MultinomialLogisticRegressionModel extends MLReadable[MultinomialLogisticRegressionModel] {
-
-  @Since("2.1.0")
-  override def read: MLReader[MultinomialLogisticRegressionModel] =
-    new MultinomialLogisticRegressionModelReader
-
-  @Since("2.1.0")
-  override def load(path: String): MultinomialLogisticRegressionModel = super.load(path)
-
-  /** [[MLWriter]] instance for [[MultinomialLogisticRegressionModel]] */
-  private[MultinomialLogisticRegressionModel]
-  class MultinomialLogisticRegressionModelWriter(instance: MultinomialLogisticRegressionModel)
-    extends MLWriter with Logging {
-
-    private case class Data(
-        numClasses: Int,
-        numFeatures: Int,
-        intercepts: Vector,
-        coefficients: Matrix)
-
-    override protected def saveImpl(path: String): Unit = {
-      // Save metadata and Params
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
-      // Save model data: numClasses, numFeatures, intercept, coefficients
-      val data = Data(instance.numClasses, instance.numFeatures, instance.intercepts,
-        instance.coefficients)
-      val dataPath = new Path(path, "data").toString
-      sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
-    }
-  }
-
-  private class MultinomialLogisticRegressionModelReader
-    extends MLReader[MultinomialLogisticRegressionModel] {
-
-    /** Checked against metadata when loading model */
-    private val className = classOf[MultinomialLogisticRegressionModel].getName
-
-    override def load(path: String): MultinomialLogisticRegressionModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
-
-      val dataPath = new Path(path, "data").toString
-      val data = sqlContext.read.format("parquet").load(dataPath)
-        .select("numClasses", "numFeatures", "intercepts", "coefficients").head()
-      val numClasses = data.getAs[Int](data.fieldIndex("numClasses"))
-      val intercepts = data.getAs[Vector](data.fieldIndex("intercepts"))
-      val coefficients = data.getAs[Matrix](data.fieldIndex("coefficients"))
-      val model =
-        new MultinomialLogisticRegressionModel(metadata.uid, coefficients, intercepts, numClasses)
-
-      DefaultParamsReader.getAndSetParams(model, metadata)
-      model
-    }
-  }
-}

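The removed `train` method above initializes the intercepts from the label histogram: it takes the log of each smoothed class count and subtracts the mean so the intercepts are centered. A minimal standalone sketch of that step follows, assuming the histogram is a plain array of per-class weights; `InterceptInitSketch` and `initialIntercepts` are hypothetical names, not part of the commit.

```scala
object InterceptInitSketch {
  /** Mean-centered log of smoothed class counts, one value per class. */
  def initialIntercepts(histogram: Array[Double]): Array[Double] = {
    require(histogram.nonEmpty, "histogram must be non-empty")
    // Add 1 for smoothing so an empty class does not produce log(0).
    val raw = histogram.map(c => math.log(c + 1))
    val mean = raw.sum / raw.length
    // Centering picks one representative of the non-unique solution family.
    raw.map(_ - mean)
  }
}
```

Balanced classes get intercepts of zero, while a class with a larger count starts with a larger intercept.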
http://git-wip-us.apache.org/repos/asf/spark/blob/26145a5a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 19df8f7..1b6e775 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -201,11 +201,25 @@ abstract class ProbabilisticClassificationModel[
       probability.argmax
     } else {
       val thresholds: Array[Double] = getThresholds
-      val scaledProbability: Array[Double] =
-        probability.toArray.zip(thresholds).map { case (p, t) =>
-          if (t == 0.0) Double.PositiveInfinity else p / t
+      val probabilities = probability.toArray
+      var argMax = 0
+      var max = Double.NegativeInfinity
+      var i = 0
+      val probabilitySize = probability.size
+      while (i < probabilitySize) {
+        if (thresholds(i) == 0.0) {
+          max = Double.PositiveInfinity
+          argMax = i
+        } else {
+          val scaled = probabilities(i) / thresholds(i)
+          if (scaled > max) {
+            max = scaled
+            argMax = i
+          }
         }
-      Vectors.dense(scaledProbability).argmax
+        i += 1
+      }
+      argMax
     }
   }
 }

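The `ProbabilisticClassifier.scala` hunk above replaces the zip/map/argmax chain with a single loop that scales each probability by its threshold and tracks the running maximum. A minimal standalone sketch of that loop on plain arrays follows; `ThresholdArgmaxSketch` and `predict` are hypothetical names used only for illustration.

```scala
object ThresholdArgmaxSketch {
  /** Index of the class with the largest probability / threshold ratio. */
  def predict(probabilities: Array[Double], thresholds: Array[Double]): Int = {
    require(probabilities.length == thresholds.length, "length mismatch")
    var argMax = 0
    var max = Double.NegativeInfinity
    var i = 0
    while (i < probabilities.length) {
      if (thresholds(i) == 0.0) {
        // A zero threshold is treated as an infinite ratio; in this loop the
        // last class with a zero threshold is the one that is kept.
        max = Double.PositiveInfinity
        argMax = i
      } else {
        val scaled = probabilities(i) / thresholds(i)
        if (scaled > max) {
          max = scaled
          argMax = i
        }
      }
      i += 1
    }
    argMax
  }
}
```

For example, with probabilities `(0.6, 0.4)` and thresholds `(0.9, 0.3)` the scaled values are roughly `0.67` and `1.33`, so class 1 is predicted even though class 0 has the higher raw probability.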
http://git-wip-us.apache.org/repos/asf/spark/blob/26145a5a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index e4cbf5a..d851b98 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.classification
 
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Since
+import org.apache.spark.ml.linalg.DenseMatrix
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.classification.impl.GLMClassificationModel
 import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors}
@@ -430,8 +431,9 @@ class LogisticRegressionWithLBFGS
         lr.setStandardization(useFeatureScaling)
         if (userSuppliedWeights) {
           val uid = Identifiable.randomUID("logreg-static")
-          lr.setInitialModel(new org.apache.spark.ml.classification.LogisticRegressionModel(
-            uid, initialWeights.asML, 1.0))
+          lr.setInitialModel(new org.apache.spark.ml.classification.LogisticRegressionModel(uid,
+            new DenseMatrix(1, initialWeights.size, initialWeights.toArray),
+            Vectors.dense(1.0).asML, 2, false))
         }
         lr.setFitIntercept(addIntercept)
         lr.setMaxIter(optimizer.getNumIterations())



