commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From l..@apache.org
Subject svn commit: r1401838 - /commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java
Date Wed, 24 Oct 2012 19:40:08 GMT
Author: luc
Date: Wed Oct 24 19:40:08 2012
New Revision: 1401838

URL: http://svn.apache.org/viewvc?rev=1401838&view=rev
Log:
Use the new differentiation API for all optimizers.

The older API is still supported as of version 3.1, but is implemented
by wrapping the user function into the new API and then calling the new
code.

Modified:
    commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java?rev=1401838&r1=1401837&r2=1401838&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java (original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optimization/general/AbstractLeastSquaresOptimizer.java Wed Oct 24 19:40:08 2012
@@ -17,16 +17,17 @@
 
 package org.apache.commons.math3.optimization.general;
 
-import org.apache.commons.math3.exception.NumberIsTooSmallException;
-import org.apache.commons.math3.exception.DimensionMismatchException;
 import org.apache.commons.math3.analysis.DifferentiableMultivariateVectorFunction;
-import org.apache.commons.math3.analysis.MultivariateMatrixFunction;
+import org.apache.commons.math3.analysis.FunctionUtils;
+import org.apache.commons.math3.analysis.differentiation.DerivativeStructure;
 import org.apache.commons.math3.analysis.differentiation.JacobianFunction;
 import org.apache.commons.math3.analysis.differentiation.MultivariateDifferentiableVectorFunction;
+import org.apache.commons.math3.exception.DimensionMismatchException;
+import org.apache.commons.math3.exception.NumberIsTooSmallException;
 import org.apache.commons.math3.exception.util.LocalizedFormats;
-import org.apache.commons.math3.linear.QRDecomposition;
 import org.apache.commons.math3.linear.DecompositionSolver;
 import org.apache.commons.math3.linear.MatrixUtils;
+import org.apache.commons.math3.linear.QRDecomposition;
 import org.apache.commons.math3.optimization.ConvergenceChecker;
 import org.apache.commons.math3.optimization.DifferentiableMultivariateVectorOptimizer;
 import org.apache.commons.math3.optimization.PointVectorValuePair;
@@ -75,7 +76,7 @@ public abstract class AbstractLeastSquar
     /** Cost value (square root of the sum of the residuals). */
     protected double cost;
     /** Objective function derivatives. */
-    private MultivariateMatrixFunction jF;
+    private MultivariateDifferentiableVectorFunction jF;
     /** Number of evaluations of the Jacobian. */
     private int jacobianEvaluations;
 
@@ -110,9 +111,22 @@ public abstract class AbstractLeastSquar
      */
     protected void updateJacobian() {
         ++jacobianEvaluations;
-        weightedResidualJacobian = jF.value(point);
-        if (weightedResidualJacobian.length != rows) {
-            throw new DimensionMismatchException(weightedResidualJacobian.length, rows);
+
+        DerivativeStructure[] dsPoint = new DerivativeStructure[point.length];
+        for (int i = 0; i < point.length; ++i) {
+            dsPoint[i] = new DerivativeStructure(point.length, 1, i, point[i]);
+        }
+        DerivativeStructure[] dsValue = jF.value(dsPoint);
+        if (dsValue.length != rows) {
+            throw new DimensionMismatchException(dsValue.length, rows);
+        }
+        for (int i = 0; i < rows; ++i) {
+            int[] orders = new int[point.length];
+            for (int j = 0; j < point.length; ++j) {
+                orders[j] = 1;
+                weightedResidualJacobian[i][j] = dsValue[i].getPartialDerivative(orders);
+                orders[j] = 0;
+            }
         }
 
         final double[] residualsWeights = getWeightRef();
@@ -303,23 +317,8 @@ public abstract class AbstractLeastSquar
                                          final DifferentiableMultivariateVectorFunction f,
                                          final double[] target, final double[] weights,
                                          final double[] startPoint) {
-        // Reset counter.
-        jacobianEvaluations = 0;
-
-        // Store least squares problem characteristics.
-        jF = f.jacobian();
-
-        // Arrays shared with the other private methods.
-        point = startPoint.clone();
-        rows = target.length;
-        cols = point.length;
-
-        weightedResidualJacobian = new double[rows][cols];
-        this.weightedResiduals = new double[rows];
-
-        cost = Double.POSITIVE_INFINITY;
-
-        return optimizeInternal(maxEval, f, target, weights, startPoint);
+        return optimize(maxEval, FunctionUtils.toMultivariateDifferentiableVectorFunction(f),
+                        target, weights, startPoint);
     }
 
     /**
@@ -351,7 +350,7 @@ public abstract class AbstractLeastSquar
         jacobianEvaluations = 0;
 
         // Store least squares problem characteristics.
-        jF = new JacobianFunction(f);
+        jF = f;
 
         // Arrays shared with the other private methods.
         point = startPoint.clone();



Mime
View raw message