commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From er...@apache.org
Subject [2/8] git commit: MATH-1144 Allow caller to modify the set of parameters generated by the optimizer.
Date Mon, 03 Nov 2014 10:43:36 GMT
MATH-1144
Allow caller to modify the set of parameters generated by the optimizer.


Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/321fd029
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/321fd029
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/321fd029

Branch: refs/heads/master
Commit: 321fd029ec5c9c3c9717f1ede0add49d8a709a01
Parents: f820d06
Author: Gilles <erans@apache.org>
Authored: Mon Oct 13 19:43:26 2014 +0200
Committer: Gilles <erans@apache.org>
Committed: Mon Oct 13 19:43:26 2014 +0200

----------------------------------------------------------------------
 .../leastsquares/LeastSquaresBuilder.java       | 46 ++++++++++++++++-
 .../leastsquares/LeastSquaresFactory.java       | 54 ++++++++++++++------
 .../leastsquares/ValueAndJacobianFunction.java  |  2 +-
 .../fitting/leastsquares/EvaluationTest.java    |  8 +--
 .../LevenbergMarquardtOptimizerTest.java        | 42 +++++++++++++++
 5 files changed, 129 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-math/blob/321fd029/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java
index 7d3ccbb..7b14b37 100644
--- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java
+++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresBuilder.java
@@ -47,6 +47,17 @@ public class LeastSquaresBuilder {
     private RealVector start;
     /** weight matrix */
     private RealMatrix weight;
+    /**
+     * Lazy evaluation.
+     *
+     * @since 3.4
+     */
+    private boolean lazyEvaluation;
+    /** Validator.
+     *
+     * @since 3.4
+     */
+    private ParameterValidator paramValidator;
 
 
     /**
@@ -55,7 +66,15 @@ public class LeastSquaresBuilder {
      * @return a new {@link LeastSquaresProblem}.
      */
     public LeastSquaresProblem build() {
-        return LeastSquaresFactory.create(model, target, start, weight, checker, maxEvaluations, maxIterations);
+        return LeastSquaresFactory.create(model,
+                                          target,
+                                          start,
+                                          weight,
+                                          checker,
+                                          maxEvaluations,
+                                          maxIterations,
+                                          lazyEvaluation,
+                                          paramValidator);
     }
 
     /**
@@ -179,4 +198,29 @@ public class LeastSquaresBuilder {
         return this;
     }
 
+    /**
+     * Configure whether evaluation will be lazy or not.
+     *
+     * @param newValue Whether to perform lazy evaluation.
+     * @return this object.
+     *
+     * @since 3.4
+     */
+    public LeastSquaresBuilder lazyEvaluation(final boolean newValue) {
+        lazyEvaluation = newValue;
+        return this;
+    }
+
+    /**
+     * Configure the validator of the model parameters.
+     *
+     * @param newValidator Parameter validator.
+     * @return this object.
+     *
+     * @since 3.4
+     */
+    public LeastSquaresBuilder parameterValidator(final ParameterValidator newValidator) {
+        paramValidator = newValidator;
+        return this;
+    }
 }

http://git-wip-us.apache.org/repos/asf/commons-math/blob/321fd029/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java
index 917acfc..1a92ac9 100644
--- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java
+++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/LeastSquaresFactory.java
@@ -56,22 +56,33 @@ public class LeastSquaresFactory {
      * @param maxIterations  the maximum number to times to iterate in the algorithm
      * @param lazyEvaluation Whether the call to {@link Evaluation#evaluate(RealVector)}
      * will defer the evaluation until access to the value is requested.
+     * @param paramValidator Model parameters validator.
      * @return the specified General Least Squares problem.
+     *
+     * @since 3.4
      */
     public static LeastSquaresProblem create(final MultivariateJacobianFunction model,
                                              final RealVector observed,
                                              final RealVector start,
+                                             final RealMatrix weight,
                                              final ConvergenceChecker<Evaluation> checker,
                                              final int maxEvaluations,
                                              final int maxIterations,
-                                             final boolean lazyEvaluation) {
-        return new LocalLeastSquaresProblem(model,
-                                            observed,
-                                            start,
-                                            checker,
-                                            maxEvaluations,
-                                            maxIterations,
-                                            lazyEvaluation);
+                                             final boolean lazyEvaluation,
+                                             final ParameterValidator paramValidator) {
+        final LeastSquaresProblem p = new LocalLeastSquaresProblem(model,
+                                                                   observed,
+                                                                   start,
+                                                                   checker,
+                                                                   maxEvaluations, 
+                                                                   maxIterations,
+                                                                   lazyEvaluation,
+                                                                   paramValidator);
+        if (weight != null) {
+            return weightMatrix(p, weight);
+        } else {
+            return p;
+        }
     }
 
     /**
@@ -92,13 +103,15 @@ public class LeastSquaresFactory {
                                              final ConvergenceChecker<Evaluation> checker,
                                              final int maxEvaluations,
                                              final int maxIterations) {
-        return new LocalLeastSquaresProblem(model,
-                                            observed,
-                                            start,
-                                            checker,
-                                            maxEvaluations,
-                                            maxIterations,
-                                            false);
+        return create(model,
+                      observed,
+                      start,
+                      null,
+                      checker,
+                      maxEvaluations,
+                      maxIterations,
+                      false,
+                      null);
     }
 
     /**
@@ -345,6 +358,8 @@ public class LeastSquaresFactory {
         private final RealVector start;
         /** Whether to use lazy evaluation. */
         private final boolean lazyEvaluation;
+        /** Model parameters validator. */
+        private final ParameterValidator paramValidator;
 
         /**
          * Create a {@link LeastSquaresProblem} from the given data.
@@ -357,6 +372,7 @@ public class LeastSquaresFactory {
          * @param maxIterations  the allowed iterations
          * @param lazyEvaluation Whether the call to {@link Evaluation#evaluate(RealVector)}
          * will defer the evaluation until access to the value is requested.
+         * @param paramValidator Model parameters validator.
          */
         LocalLeastSquaresProblem(final MultivariateJacobianFunction model,
                                  final RealVector target,
@@ -364,12 +380,14 @@ public class LeastSquaresFactory {
                                  final ConvergenceChecker<Evaluation> checker,
                                  final int maxEvaluations,
                                  final int maxIterations,
-                                 boolean lazyEvaluation) {
+                                 final boolean lazyEvaluation,
+                                 final ParameterValidator paramValidator) {
             super(maxEvaluations, maxIterations, checker);
             this.target = target;
             this.model = model;
             this.start = start;
             this.lazyEvaluation = lazyEvaluation;
+            this.paramValidator = paramValidator;
 
             if (lazyEvaluation &&
                 !(model instanceof ValueAndJacobianFunction)) {
@@ -398,7 +416,9 @@ public class LeastSquaresFactory {
         /** {@inheritDoc} */
         public Evaluation evaluate(final RealVector point) {
             // Copy so optimizer can change point without changing our instance.
-            final RealVector p = point.copy();
+            final RealVector p = paramValidator == null ?
+                point.copy() :
+                paramValidator.validate(point.copy());
 
             if (lazyEvaluation) {
                 return new LazyUnweightedEvaluation((ValueAndJacobianFunction) model,

http://git-wip-us.apache.org/repos/asf/commons-math/blob/321fd029/src/main/java/org/apache/commons/math3/fitting/leastsquares/ValueAndJacobianFunction.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math3/fitting/leastsquares/ValueAndJacobianFunction.java b/src/main/java/org/apache/commons/math3/fitting/leastsquares/ValueAndJacobianFunction.java
index 39e7ae4..180e328 100644
--- a/src/main/java/org/apache/commons/math3/fitting/leastsquares/ValueAndJacobianFunction.java
+++ b/src/main/java/org/apache/commons/math3/fitting/leastsquares/ValueAndJacobianFunction.java
@@ -23,7 +23,7 @@ import org.apache.commons.math3.linear.RealVector;
  * A interface for functions that compute a vector of values and can compute their
  * derivatives (Jacobian).
  *
- * @since 3.3
+ * @since 3.4
  */
 public interface ValueAndJacobianFunction extends MultivariateJacobianFunction {
     /**

http://git-wip-us.apache.org/repos/asf/commons-math/blob/321fd029/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTest.java b/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTest.java
index 9cfbe0b..a53b3f7 100644
--- a/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTest.java
+++ b/src/test/java/org/apache/commons/math3/fitting/leastsquares/EvaluationTest.java
@@ -226,7 +226,7 @@ public class EvaluationTest {
 
         final LeastSquaresProblem p
             = LeastSquaresFactory.create(LeastSquaresFactory.model(dummyModel(), dummyJacobian()),
-                                         dummy, dummy, null, 0, 0, true);
+                                         dummy, dummy, null, null, 0, 0, true, null);
 
         // Should not throw because actual evaluation is deferred.
         final Evaluation eval = p.evaluate(dummy);
@@ -263,7 +263,7 @@ public class EvaluationTest {
 
         try {
             // Should throw.
-            LeastSquaresFactory.create(m1, dummy, dummy, null, 0, 0, true);
+            LeastSquaresFactory.create(m1, dummy, dummy, null, null, 0, 0, true, null);
             Assert.fail("Expecting MathIllegalStateException");
         } catch (MathIllegalStateException e) {
             // Expected.
@@ -282,7 +282,7 @@ public class EvaluationTest {
             };
 
         // Should pass.
-        LeastSquaresFactory.create(m2, dummy, dummy, null, 0, 0, true);
+        LeastSquaresFactory.create(m2, dummy, dummy, null, null, 0, 0, true, null);
     }
 
     @Test
@@ -291,7 +291,7 @@ public class EvaluationTest {
 
         final LeastSquaresProblem p
             = LeastSquaresFactory.create(LeastSquaresFactory.model(dummyModel(), dummyJacobian()),
-                                         dummy, dummy, null, 0, 0, false);
+                                         dummy, dummy, null, null, 0, 0, false, null);
 
         try {
             // Should throw.

http://git-wip-us.apache.org/repos/asf/commons-math/blob/321fd029/src/test/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizerTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizerTest.java b/src/test/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizerTest.java
index b2c8f54..46658db 100644
--- a/src/test/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizerTest.java
+++ b/src/test/java/org/apache/commons/math3/fitting/leastsquares/LevenbergMarquardtOptimizerTest.java
@@ -268,6 +268,48 @@ public class LevenbergMarquardtOptimizerTest
     }
 
     @Test
+    public void testParameterValidator() {
+        // Setup.
+        final double xCenter = 123.456;
+        final double yCenter = 654.321;
+        final double xSigma = 10;
+        final double ySigma = 15;
+        final double radius = 111.111;
+        final long seed = 3456789L;
+        final RandomCirclePointGenerator factory
+            = new RandomCirclePointGenerator(xCenter, yCenter, radius,
+                                             xSigma, ySigma,
+                                             seed);
+        final CircleProblem circle = new CircleProblem(xSigma, ySigma);
+
+        final int numPoints = 10;
+        for (Vector2D p : factory.generate(numPoints)) {
+            circle.addPoint(p.getX(), p.getY());
+        }
+
+        // First guess for the center's coordinates and radius.
+        final double[] init = { 90, 659, 115 };
+        final Optimum optimum
+            = optimizer.optimize(builder(circle).maxIterations(50).start(init).build());
+        final int numEval = optimum.getEvaluations();
+        Assert.assertTrue(numEval > 1);
+
+        // Build a new problem with an validator that amounts to cheating.
+        final ParameterValidator cheatValidator
+            = new ParameterValidator() {
+                    public RealVector validate(RealVector params) {
+                        // Cheat: return the optimum found previously.
+                        return optimum.getPoint();
+                    }
+                };
+
+        final Optimum cheatOptimum
+            = optimizer.optimize(builder(circle).maxIterations(50).start(init).parameterValidator(cheatValidator).build());
+        final int cheatNumEval = cheatOptimum.getEvaluations();
+        Assert.assertTrue(cheatNumEval < numEval);
+    }
+
+    @Test
     public void testEvaluationCount() {
         //setup
         LeastSquaresProblem lsp = new LinearProblem(new double[][] {{1}}, new double[] {1})


Mime
View raw message