commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From l..@apache.org
Subject svn commit: r1569353 - in /commons/proper/math/trunk/src: main/java/org/apache/commons/math3/fitting/leastsquares/ test/java/org/apache/commons/math3/fitting/leastsquares/
Date Tue, 18 Feb 2014 14:32:45 GMT
Author: luc
Date: Tue Feb 18 14:32:44 2014
New Revision: 1569353

URL: http://svn.apache.org/r1569353
Log:
Use Evaluation instead of PointVectorValuePair

Use Evaluation instead of PointVectorValuePair in the ConvergenceChecker. This
gives the checkers access to more information, such as the RMS and the covariances.
The change also simplifies the optimizer implementations, since they no longer
have to keep track of the current function value.

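A method, LeastSquaresFactory.evaluationChecker, was added to adapt a
ConvergenceChecker<PointVectorValuePair> into a ConvergenceChecker<Evaluation>,
and a method, LeastSquaresBuilder.checkerPair, was added so that the builder can
accept either type of checker. I would have preferred to do this through method
overloading, but overloading doesn't play well with generics: two
checker(ConvergenceChecker<...>) overloads would have the same erasure and would
not compile.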

Modified:
    commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java
    commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresAdapter.java
    commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java
    commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java
    commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblem.java
    commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblemImpl.java
    commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizer.java
    commons/proper/math/trunk/src/test/java/org/apache/commons/math3/fitting/leastsquares/AbstractLeastSquaresOptimizerAbstractTest.java
    commons/proper/math/trunk/src/test/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizerTest.java

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java?rev=1569353&r1=1569352&r2=1569353&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java
(original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizer.java
Tue Feb 18 14:32:44 2014
@@ -28,7 +28,6 @@ import org.apache.commons.math3.linear.Q
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.SingularMatrixException;
 import org.apache.commons.math3.optim.ConvergenceChecker;
-import org.apache.commons.math3.optim.PointVectorValuePair;
 import org.apache.commons.math3.util.Incrementor;
 
 /**
@@ -123,7 +122,7 @@ public class GaussNewtonOptimizer implem
         //create local evaluation and iteration counts
         final Incrementor evaluationCounter = lsp.getEvaluationCounter();
         final Incrementor iterationCounter = lsp.getIterationCounter();
-        final ConvergenceChecker<PointVectorValuePair> checker
+        final ConvergenceChecker<Evaluation> checker
                 = lsp.getConvergenceChecker();
 
         // Computation will be useless without a checker (see "for-loop").
@@ -137,25 +136,23 @@ public class GaussNewtonOptimizer implem
         final double[] currentPoint = lsp.getStart();
 
         // iterate until convergence is reached
-        PointVectorValuePair current = null;
+        Evaluation current = null;
         while (true) {
             iterationCounter.incrementCount();
 
             // evaluate the objective function and its jacobian
-            PointVectorValuePair previous = current;
+            Evaluation previous = current;
             // Value of the objective function at "currentPoint".
             evaluationCounter.incrementCount();
-            final Evaluation value = lsp.evaluate(currentPoint);
-            final double[] currentObjective = value.computeValue();
-            final double[] currentResiduals = value.computeResiduals();
-            final RealMatrix weightedJacobian = value.computeJacobian();
-            current = new PointVectorValuePair(currentPoint, currentObjective);
+            current = lsp.evaluate(currentPoint);
+            final double[] currentResiduals = current.computeResiduals();
+            final RealMatrix weightedJacobian = current.computeJacobian();
 
             // Check convergence.
             if (previous != null) {
                 if (checker.converged(iterationCounter.getCount(), previous, current)) {
                     return new OptimumImpl(
-                            value,
+                            current,
                             evaluationCounter.getCount(),
                             iterationCounter.getCount());
                 }

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresAdapter.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresAdapter.java?rev=1569353&r1=1569352&r2=1569353&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresAdapter.java
(original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresAdapter.java
Tue Feb 18 14:32:44 2014
@@ -1,7 +1,6 @@
 package org.apache.commons.math3.fitting.leastsquares;
 
 import org.apache.commons.math3.optim.ConvergenceChecker;
-import org.apache.commons.math3.optim.PointVectorValuePair;
 import org.apache.commons.math3.util.Incrementor;
 
 /**
@@ -54,7 +53,7 @@ public class LeastSquaresAdapter impleme
     }
 
     /** {@inheritDoc} */
-    public ConvergenceChecker<PointVectorValuePair> getConvergenceChecker() {
+    public ConvergenceChecker<Evaluation> getConvergenceChecker() {
         return problem.getConvergenceChecker();
     }
 }

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java?rev=1569353&r1=1569352&r2=1569353&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java
(original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java
Tue Feb 18 14:32:44 2014
@@ -2,6 +2,7 @@ package org.apache.commons.math3.fitting
 
 import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
 import org.apache.commons.math3.analysis.MultivariateVectorFunction;
+import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.optim.ConvergenceChecker;
 import org.apache.commons.math3.optim.PointVectorValuePair;
@@ -19,7 +20,7 @@ public class LeastSquaresBuilder {
     /** max iterations */
     private int maxIterations;
     /** convergence checker */
-    private ConvergenceChecker<PointVectorValuePair> checker;
+    private ConvergenceChecker<Evaluation> checker;
     /** model function */
     private MultivariateVectorFunction model;
     /** Jacobian function */
@@ -69,12 +70,24 @@ public class LeastSquaresBuilder {
      * @param checker the convergence checker.
      * @return this
      */
-    public LeastSquaresBuilder checker(final ConvergenceChecker<PointVectorValuePair>
checker) {
+    public LeastSquaresBuilder checker(final ConvergenceChecker<Evaluation> checker)
{
         this.checker = checker;
         return this;
     }
 
     /**
+     * Configure the convergence checker.
+     * <p/>
+     * This function is an overloaded version of {@link #checker(ConvergenceChecker)}.
+     *
+     * @param checker the convergence checker.
+     * @return this
+     */
+    public LeastSquaresBuilder checkerPair(final ConvergenceChecker<PointVectorValuePair>
checker) {
+        return this.checker(LeastSquaresFactory.evaluationChecker(checker));
+    }
+
+    /**
      * Configure the model function.
      *
      * @param model the model function

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java?rev=1569353&r1=1569352&r2=1569353&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java
(original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java
Tue Feb 18 14:32:44 2014
@@ -2,6 +2,7 @@ package org.apache.commons.math3.fitting
 
 import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
 import org.apache.commons.math3.analysis.MultivariateVectorFunction;
+import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.ArrayRealVector;
 import org.apache.commons.math3.linear.DiagonalMatrix;
@@ -40,7 +41,7 @@ public class LeastSquaresFactory {
     public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
                                              final double[] observed,
                                              final double[] start,
-                                             final ConvergenceChecker<PointVectorValuePair>
checker,
+                                             final ConvergenceChecker<Evaluation> checker,
                                              final int maxEvaluations,
                                              final int maxIterations) {
         return new LeastSquaresProblemImpl(
@@ -70,7 +71,7 @@ public class LeastSquaresFactory {
                                              final MultivariateMatrixFunction jacobian,
                                              final double[] observed,
                                              final double[] start,
-                                             final ConvergenceChecker<PointVectorValuePair>
checker,
+                                             final ConvergenceChecker<Evaluation> checker,
                                              final int maxEvaluations,
                                              final int maxIterations) {
         return create(
@@ -102,7 +103,7 @@ public class LeastSquaresFactory {
                                              final double[] observed,
                                              final double[] start,
                                              final RealMatrix weight,
-                                             final ConvergenceChecker<PointVectorValuePair>
checker,
+                                             final ConvergenceChecker<Evaluation> checker,
                                              final int maxEvaluations,
                                              final int maxIterations) {
         return weightMatrix(
@@ -175,6 +176,35 @@ public class LeastSquaresFactory {
     }
 
     /**
+     * View a convergence checker specified for a {@link PointVectorValuePair} as one
+     * specified for an {@link Evaluation}.
+     *
+     * @param checker the convergence checker to adapt.
+     * @return a convergence checker that delegates to {@code checker}.
+     */
+    public static ConvergenceChecker<Evaluation> evaluationChecker(
+            final ConvergenceChecker<PointVectorValuePair> checker
+    ) {
+        return new ConvergenceChecker<Evaluation>() {
+            public boolean converged(final int iteration,
+                                     final Evaluation previous,
+                                     final Evaluation current) {
+                return checker.converged(
+                        iteration,
+                        new PointVectorValuePair(
+                                previous.getPoint(),
+                                previous.computeValue(),
+                                false),
+                        new PointVectorValuePair(
+                                current.getPoint(),
+                                current.computeValue(),
+                                false)
+                );
+            }
+        };
+    }
+
+    /**
      * Computes the square-root of the weight matrix.
      *
      * @param m Symmetric, positive-definite (weight) matrix.

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblem.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblem.java?rev=1569353&r1=1569352&r2=1569353&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblem.java
(original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblem.java
Tue Feb 18 14:32:44 2014
@@ -1,8 +1,8 @@
 package org.apache.commons.math3.fitting.leastsquares;
 
 import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
 import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.commons.math3.optim.PointVectorValuePair;
 
 /**
  * The data necessary to define a non-linear least squares problem. Includes the observed
@@ -12,7 +12,7 @@ import org.apache.commons.math3.optim.Po
  *
  * @version $Id$
  */
-public interface LeastSquaresProblem extends OptimizationProblem<PointVectorValuePair>
{
+public interface LeastSquaresProblem extends OptimizationProblem<Evaluation> {
 
     /**
      * Gets the initial guess.

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblemImpl.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblemImpl.java?rev=1569353&r1=1569352&r2=1569353&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblemImpl.java
(original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresProblemImpl.java
Tue Feb 18 14:32:44 2014
@@ -17,11 +17,11 @@
 package org.apache.commons.math3.fitting.leastsquares;
 
 import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.fitting.leastsquares.LeastSquaresProblem.Evaluation;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.RealVector;
 import org.apache.commons.math3.optim.AbstractOptimizationProblem;
 import org.apache.commons.math3.optim.ConvergenceChecker;
-import org.apache.commons.math3.optim.PointVectorValuePair;
 import org.apache.commons.math3.util.Pair;
 
 /**
@@ -32,7 +32,7 @@ import org.apache.commons.math3.util.Pai
  * @since 3.3
  */
 class LeastSquaresProblemImpl
-        extends AbstractOptimizationProblem<PointVectorValuePair>
+        extends AbstractOptimizationProblem<Evaluation>
         implements LeastSquaresProblem {
 
     /** Target values for the model function at optimum. */
@@ -45,7 +45,7 @@ class LeastSquaresProblemImpl
     LeastSquaresProblemImpl(final MultivariateJacobianFunction model,
                             final double[] target,
                             final double[] start,
-                            final ConvergenceChecker<PointVectorValuePair> checker,
+                            final ConvergenceChecker<Evaluation> checker,
                             final int maxEvaluations,
                             final int maxIterations) {
         super(maxEvaluations, maxIterations, checker);

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizer.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizer.java?rev=1569353&r1=1569352&r2=1569353&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizer.java
(original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizer.java
Tue Feb 18 14:32:44 2014
@@ -23,7 +23,6 @@ import org.apache.commons.math3.linear.R
 import org.apache.commons.math3.exception.ConvergenceException;
 import org.apache.commons.math3.exception.util.LocalizedFormats;
 import org.apache.commons.math3.optim.ConvergenceChecker;
-import org.apache.commons.math3.optim.PointVectorValuePair;
 import org.apache.commons.math3.util.Incrementor;
 import org.apache.commons.math3.util.Precision;
 import org.apache.commons.math3.util.FastMath;
@@ -303,7 +302,7 @@ public class LevenbergMarquardtOptimizer
         final Incrementor iterationCounter = problem.getIterationCounter();
         final Incrementor evaluationCounter = problem.getEvaluationCounter();
         //convergence criterion
-        final ConvergenceChecker<PointVectorValuePair> checker
+        final ConvergenceChecker<Evaluation> checker
                 = problem.getConvergenceChecker();
 
         // arrays shared with the other private methods
@@ -319,7 +318,6 @@ public class LevenbergMarquardtOptimizer
         double[] diag    = new double[nC];
         double[] oldX    = new double[nC];
         double[] oldRes  = new double[nR];
-        double[] oldObj  = new double[nR];
         double[] qtf     = new double[nR];
         double[] work1   = new double[nC];
         double[] work2   = new double[nC];
@@ -329,23 +327,20 @@ public class LevenbergMarquardtOptimizer
         // Evaluate the function at the starting point and calculate its norm.
         evaluationCounter.incrementCount();
         //value will be reassigned in the loop
-        Evaluation value = problem.evaluate(currentPoint);
-        double[] currentObjective = value.computeValue();
-        double[] currentResiduals = value.computeResiduals();
-        PointVectorValuePair current = new PointVectorValuePair(currentPoint, currentObjective);
-        double currentCost = value.computeCost();
+        Evaluation current = problem.evaluate(currentPoint);
+        double[] currentResiduals = current.computeResiduals();
+        double currentCost = current.computeCost();
 
         // Outer loop.
         boolean firstIteration = true;
         while (true) {
             iterationCounter.incrementCount();
 
-            final PointVectorValuePair previous = current;
-            final Evaluation previousValue = value;
+            final Evaluation previous = current;
 
             // QR decomposition of the jacobian matrix
             final InternalData internalData
-                    = qrDecomposition(value.computeJacobian(), solvedCols);
+                    = qrDecomposition(current.computeJacobian(), solvedCols);
             final double[][] weightedJacobian = internalData.weightedJacobian;
             final int[] permutation = internalData.permutation;
             final double[] diagR = internalData.diagR;
@@ -404,7 +399,7 @@ public class LevenbergMarquardtOptimizer
             if (maxCosine <= orthoTolerance) {
                 // Convergence has been reached.
                 return new OptimumImpl(
-                        value,
+                        current,
                         evaluationCounter.getCount(),
                         iterationCounter.getCount());
             }
@@ -426,9 +421,6 @@ public class LevenbergMarquardtOptimizer
                 double[] tmpVec = weightedResidual;
                 weightedResidual = oldRes;
                 oldRes    = tmpVec;
-                tmpVec    = currentObjective;
-                currentObjective = oldObj;
-                oldObj    = tmpVec;
 
                 // determine the Levenberg-Marquardt parameter
                 lmPar = determineLMParameter(qtf, delta, diag,
@@ -452,11 +444,9 @@ public class LevenbergMarquardtOptimizer
 
                 // Evaluate the function at x + p and calculate its norm.
                 evaluationCounter.incrementCount();
-                value = problem.evaluate(currentPoint);
-                currentObjective = value.computeValue();
-                currentResiduals = value.computeResiduals();
-                current = new PointVectorValuePair(currentPoint, currentObjective);
-                currentCost = value.computeCost();
+                current = problem.evaluate(currentPoint);
+                currentResiduals = current.computeResiduals();
+                currentCost = current.computeCost();
 
                 // compute the scaled actual reduction
                 double actRed = -1.0;
@@ -515,7 +505,7 @@ public class LevenbergMarquardtOptimizer
 
                     // tests for convergence.
                     if (checker != null && checker.converged(iterationCounter.getCount(),
previous, current)) {
-                        return new OptimumImpl(value, iterationCounter.getCount(), evaluationCounter.getCount());
+                        return new OptimumImpl(current, iterationCounter.getCount(), evaluationCounter.getCount());
                     }
                 } else {
                     // failed iteration, reset the previous values
@@ -527,12 +517,8 @@ public class LevenbergMarquardtOptimizer
                     tmpVec    = weightedResidual;
                     weightedResidual = oldRes;
                     oldRes    = tmpVec;
-                    tmpVec    = currentObjective;
-                    currentObjective = oldObj;
-                    oldObj    = tmpVec;
                     // Reset "current" to previous values.
-                    current = new PointVectorValuePair(currentPoint, currentObjective);
-                    value = previousValue;
+                    current = previous;
                 }
 
                 // Default convergence criteria.
@@ -540,7 +526,7 @@ public class LevenbergMarquardtOptimizer
                      preRed <= costRelativeTolerance &&
                      ratio <= 2.0) ||
                     delta <= parRelativeTolerance * xNorm) {
-                    return new OptimumImpl(value, iterationCounter.getCount(), evaluationCounter.getCount());
+                    return new OptimumImpl(current, iterationCounter.getCount(), evaluationCounter.getCount());
                 }
 
                 // tests for termination and stringent tolerances

Modified: commons/proper/math/trunk/src/test/java/org/apache/commons/math3/fitting/leastsquares/AbstractLeastSquaresOptimizerAbstractTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math3/fitting/leastsquares/AbstractLeastSquaresOptimizerAbstractTest.java?rev=1569353&r1=1569352&r2=1569353&view=diff
==============================================================================
--- commons/proper/math/trunk/src/test/java/org/apache/commons/math3/fitting/leastsquares/AbstractLeastSquaresOptimizerAbstractTest.java
(original)
+++ commons/proper/math/trunk/src/test/java/org/apache/commons/math3/fitting/leastsquares/AbstractLeastSquaresOptimizerAbstractTest.java
Tue Feb 18 14:32:44 2014
@@ -46,7 +46,7 @@ public abstract class AbstractLeastSquar
 
     public LeastSquaresBuilder base() {
         return new LeastSquaresBuilder()
-                .checker(new SimpleVectorValueChecker(1e-6, 1e-6))
+                .checkerPair(new SimpleVectorValueChecker(1e-6, 1e-6))
                 .maxEvaluations(100)
                 .maxIterations(getMaxIterations());
     }

Modified: commons/proper/math/trunk/src/test/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizerTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizerTest.java?rev=1569353&r1=1569352&r2=1569353&view=diff
==============================================================================
--- commons/proper/math/trunk/src/test/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizerTest.java
(original)
+++ commons/proper/math/trunk/src/test/java/org/apache/commons/math3/fitting/leastsquares/GaussNewtonOptimizerTest.java
Tue Feb 18 14:32:44 2014
@@ -96,7 +96,7 @@ public class GaussNewtonOptimizerTest
         circle.addPoint( 45.0,  97.0);
 
         LeastSquaresProblem lsp = builder(circle)
-                .checker(new SimpleVectorValueChecker(1e-30, 1e-30))
+                .checkerPair(new SimpleVectorValueChecker(1e-30, 1e-30))
                 .maxIterations(Integer.MAX_VALUE)
                 .start(new double[]{98.680, 47.345})
                 .build();



Mime
View raw message