spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-17748][ML] Minor cleanups to one-pass linear regression with elastic net
Date Tue, 08 Nov 2016 20:58:46 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.1 876eee2b1 -> 21bbf94b4


[SPARK-17748][ML] Minor cleanups to one-pass linear regression with elastic net

## What changes were proposed in this pull request?

* Made SingularMatrixException private to Spark (`private[spark]`)
* WeightedLeastSquares: Changed to allow tol >= 0 instead of only tol > 0

## How was this patch tested?

existing tests

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #15779 from jkbradley/wls-cleanups.

(cherry picked from commit 26e1c53aceee37e3687a372ff6c6f05463fd8a94)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/21bbf94b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/21bbf94b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/21bbf94b

Branch: refs/heads/branch-2.1
Commit: 21bbf94b41fbd193e370a3820131e449aaf0e3db
Parents: 876eee2
Author: Joseph K. Bradley <joseph@databricks.com>
Authored: Tue Nov 8 12:58:29 2016 -0800
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Tue Nov 8 12:58:41 2016 -0800

----------------------------------------------------------------------
 .../spark/ml/optim/NormalEquationSolver.scala   |  9 ++++----
 .../spark/ml/optim/WeightedLeastSquares.scala   |  4 ++--
 .../spark/ml/regression/LinearRegression.scala  | 22 ++++++++++++++------
 3 files changed, 23 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/21bbf94b/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala
index 2f5299b..96fd0d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/NormalEquationSolver.scala
@@ -16,9 +16,10 @@
  */
 package org.apache.spark.ml.optim
 
+import scala.collection.mutable
+
 import breeze.linalg.{DenseVector => BDV}
 import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
-import scala.collection.mutable
 
 import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vectors}
 import org.apache.spark.mllib.linalg.CholeskyDecomposition
@@ -57,7 +58,7 @@ private[ml] sealed trait NormalEquationSolver {
  */
 private[ml] class CholeskySolver extends NormalEquationSolver {
 
-  def solve(
+  override def solve(
       bBar: Double,
       bbBar: Double,
       abBar: DenseVector,
@@ -80,7 +81,7 @@ private[ml] class QuasiNewtonSolver(
     tol: Double,
     l1RegFunc: Option[(Int) => Double]) extends NormalEquationSolver {
 
-  def solve(
+  override def solve(
       bBar: Double,
       bbBar: Double,
       abBar: DenseVector,
@@ -156,7 +157,7 @@ private[ml] class QuasiNewtonSolver(
  * Exception thrown when solving a linear system Ax = b for which the matrix A is non-invertible
  * (singular).
  */
-class SingularMatrixException(message: String, cause: Throwable)
+private[spark] class SingularMatrixException(message: String, cause: Throwable)
   extends IllegalArgumentException(message, cause) {
 
   def this(message: String) = this(message, null)

http://git-wip-us.apache.org/repos/asf/spark/blob/21bbf94b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
index 90c24e1..56ab967 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/optim/WeightedLeastSquares.scala
@@ -47,7 +47,7 @@ private[ml] class WeightedLeastSquaresModel(
  * formulation:
  *
  * min,,x,z,, 1/2 sum,,i,, w,,i,, (a,,i,,^T^ x + z - b,,i,,)^2^ / sum,,i,, w,,i,,
- *   + lambda / delta (1/2 (1 - alpha) sumj,, (sigma,,j,, x,,j,,)^2^
+ *   + lambda / delta (1/2 (1 - alpha) sum,,j,, (sigma,,j,, x,,j,,)^2^
  *   + alpha sum,,j,, abs(sigma,,j,, x,,j,,)),
  *
  * where lambda is the regularization parameter, alpha is the ElasticNet mixing parameter,
@@ -91,7 +91,7 @@ private[ml] class WeightedLeastSquares(
   require(elasticNetParam >= 0.0 && elasticNetParam <= 1.0,
     s"elasticNetParam must be in [0, 1]: $elasticNetParam")
   require(maxIter >= 0, s"maxIter must be a positive integer: $maxIter")
-  require(tol > 0, s"tol must be greater than zero: $tol")
+  require(tol >= 0.0, s"tol must be >= 0, but was set to $tol")
 
   /**
    * Creates a [[WeightedLeastSquaresModel]] from an RDD of [[Instance]]s.

http://git-wip-us.apache.org/repos/asf/spark/blob/21bbf94b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index ae876b3..9639b07 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -31,7 +31,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.linalg.BLAS._
-import org.apache.spark.ml.optim.{NormalEquationSolver, WeightedLeastSquares}
+import org.apache.spark.ml.optim.WeightedLeastSquares
 import org.apache.spark.ml.PredictorParams
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.param.shared._
@@ -160,11 +160,13 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
   /**
    * Set the solver algorithm used for optimization.
    * In case of linear regression, this can be "l-bfgs", "normal" and "auto".
-   * "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton
-   * optimization method. "normal" denotes using Normal Equation as an analytical
-   * solution to the linear regression problem.
-   * The default value is "auto" which means that the solver algorithm is
-   * selected automatically.
+   *  - "l-bfgs" denotes Limited-memory BFGS which is a limited-memory quasi-Newton
+   *    optimization method.
+   *  - "normal" denotes using Normal Equation as an analytical solution to the linear regression
+   *    problem.  This solver is limited to [[LinearRegression.MAX_FEATURES_FOR_NORMAL_SOLVER]].
+   *  - "auto" (default) means that the solver algorithm is selected automatically.
+   *    The Normal Equations solver will be used when possible, but this will automatically fall
+   *    back to iterative optimization methods when needed.
    *
    * @group setParam
    */
@@ -404,6 +406,14 @@ object LinearRegression extends DefaultParamsReadable[LinearRegression] {
 
   @Since("1.6.0")
   override def load(path: String): LinearRegression = super.load(path)
+
+  /**
+   * When using [[LinearRegression.solver]] == "normal", the solver must limit the number of
+   * features to at most this number.  The entire covariance matrix X^T^X will be collected
+   * to the driver. This limit helps prevent memory overflow errors.
+   */
+  @Since("2.1.0")
+  val MAX_FEATURES_FOR_NORMAL_SOLVER: Int = WeightedLeastSquares.MAX_NUM_FEATURES
 }
 
 /**


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message