From: meng@apache.org
To: commits@spark.apache.org
Subject: git commit: [SPARK-2163] class LBFGS optimize with Double tolerance instead of Int
Date: Fri, 20 Jun 2014 15:52:36 +0000 (UTC)

Repository: spark
Updated Branches:
  refs/heads/master 2f6a835e1 -> d484ddeff

[SPARK-2163] class LBFGS optimize with Double tolerance instead of Int

https://issues.apache.org/jira/browse/SPARK-2163

This pull request includes the changes for **[SPARK-2163]**:

* Changed the convergence tolerance parameter from type `Int` to type `Double`.
* Added types for vars in `class LBFGS`, making the style consistent with `class GradientDescent`.
* Added an associated test to check that optimizing via `class LBFGS` produces the same results as calling `runLBFGS` from `object LBFGS`.

This is a very minor change, but it solves a problem in my implementation of a regression model for count data, where I use LBFGS for parameter estimation.

Author: Gang Bai

Closes #1104 from BaiGang/fix_int_tol and squashes the following commits:

cecf02c [Gang Bai] Changed `setConvergenceTol` to specify the tolerance with a parameter of type Double. For the reason and the problem caused by an Int parameter, see https://issues.apache.org/jira/browse/SPARK-2163. Added a test in LBFGSSuite validating that optimizing via class LBFGS produces the same results as calling runLBFGS from object LBFGS. Kept the indentation and style consistent.
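To make the type issue concrete, here is a minimal, hypothetical before/after sketch (illustration-only classes, not the Spark source): with an `Int` parameter, Scala rejects a `Double` literal such as `1e-9`, so the only tolerance below 1.0 a caller could express was 0.

// Hypothetical sketch; IntTolOptimizer and DoubleTolOptimizer are
// illustration-only names, not classes from Spark.
class IntTolOptimizer {
  private var convergenceTol: Double = 1e-4

  // Pre-fix signature: the tolerance parameter is an Int.
  def setConvergenceTol(tolerance: Int): this.type = {
    this.convergenceTol = tolerance // the Int silently widens to Double
    this
  }
}

class DoubleTolOptimizer {
  private var convergenceTol: Double = 1e-4

  // Post-fix signature, as in SPARK-2163: the tolerance is a Double.
  def setConvergenceTol(tolerance: Double): this.type = {
    this.convergenceTol = tolerance
    this
  }
}

object TolDemo extends App {
  // new IntTolOptimizer().setConvergenceTol(1e-9) // does not compile: Double literal where Int expected
  new IntTolOptimizer().setConvergenceTol(0)       // compiles, but 0 is the only usable value below 1.0
  new DoubleTolOptimizer().setConvergenceTol(1e-9) // fine after the fix
}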
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d484ddef
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d484ddef
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d484ddef

Branch: refs/heads/master
Commit: d484ddeff1440d8e14e05c3cd7e7a18746f1a586
Parents: 2f6a835
Author: Gang Bai
Authored: Fri Jun 20 08:52:20 2014 -0700
Committer: Xiangrui Meng
Committed: Fri Jun 20 08:52:20 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/optimization/LBFGS.scala |  2 +-
 .../spark/mllib/optimization/LBFGSSuite.scala   | 34 ++++++++++++++++++++
 2 files changed, 35 insertions(+), 1 deletion(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/d484ddef/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index 8f187c9..7bbed9c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -60,7 +60,7 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
    * Set the convergence tolerance of iterations for L-BFGS. Default 1E-4.
    * Smaller value will lead to higher accuracy with the cost of more iterations.
    */
-  def setConvergenceTol(tolerance: Int): this.type = {
+  def setConvergenceTol(tolerance: Double): this.type = {
     this.convergenceTol = tolerance
     this
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/d484ddef/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index 4b18506..fe7a903 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -195,4 +195,38 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
     assert(lossLBFGS3.length == 6)
     assert((lossLBFGS3(4) - lossLBFGS3(5)) / lossLBFGS3(4) < convergenceTol)
   }
+
+  test("Optimize via class LBFGS.") {
+    val regParam = 0.2
+
+    // Prepare another non-zero weights to compare the loss in the first iteration.
+    val initialWeightsWithIntercept = Vectors.dense(0.3, 0.12)
+    val convergenceTol = 1e-12
+    val maxNumIterations = 10
+
+    val lbfgsOptimizer = new LBFGS(gradient, squaredL2Updater)
+      .setNumCorrections(numCorrections)
+      .setConvergenceTol(convergenceTol)
+      .setMaxNumIterations(maxNumIterations)
+      .setRegParam(regParam)
+
+    val weightLBFGS = lbfgsOptimizer.optimize(dataRDD, initialWeightsWithIntercept)
+
+    val numGDIterations = 50
+    val stepSize = 1.0
+    val (weightGD, _) = GradientDescent.runMiniBatchSGD(
+      dataRDD,
+      gradient,
+      squaredL2Updater,
+      stepSize,
+      numGDIterations,
+      regParam,
+      miniBatchFrac,
+      initialWeightsWithIntercept)
+
+    // For class LBFGS and the optimize method, we only look at the weights.
+    assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) &&
+      compareDouble(weightLBFGS(1), weightGD(1), 0.02),
+      "The weight differences between LBFGS and GD should be within 2%.")
+  }
 }
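For completeness, a hedged usage sketch of the updated public API, not suite code verbatim: it assumes a running SparkContext, Spark 1.x MLlib on the classpath, and a `data: RDD[(Double, Vector)]` of (label, features) pairs prepared the way LBFGSSuite prepares `dataRDD`; `LeastSquaresGradient` is an illustrative choice of gradient, and the setter values mirror the new test.

import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.optimization.{LBFGS, LeastSquaresGradient, SquaredL2Updater}
import org.apache.spark.rdd.RDD

// Sketch: fit weights with the class-based L-BFGS API exercised by the new test.
def fitWithLBFGS(data: RDD[(Double, Vector)]): Vector = {
  val lbfgs = new LBFGS(new LeastSquaresGradient(), new SquaredL2Updater())
    .setNumCorrections(10)     // history size of the L-BFGS approximation
    .setConvergenceTol(1e-12)  // fractional tolerance, expressible only post-fix
    .setMaxNumIterations(10)
    .setRegParam(0.2)

  // Initial weights include the intercept term, mirroring the new test.
  lbfgs.optimize(data, Vectors.dense(0.3, 0.12))
}

The very tight 1e-12 tolerance is what lets the test drive L-BFGS close enough to the optimum that its weights agree with mini-batch gradient descent to within 2%; an Int-typed setter could never request it.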