commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From celes...@apache.org
Subject svn commit: r1179488 - in /commons/proper/math/trunk/src: main/java/org/apache/commons/math/linear/ConjugateGradient.java test/java/org/apache/commons/math/linear/ConjugateGradientTest.java
Date Thu, 06 Oct 2011 02:14:21 GMT
Author: celestin
Date: Thu Oct  6 02:14:20 2011
New Revision: 1179488

URL: http://svn.apache.org/viewvc?rev=1179488&view=rev
Log:
Modifications to the ConjugateGradient class and unit tests
  - altered the way iterations are counted: the Incrementor is now incremented prior to any modification
to the current state, so that the solver remains in a consistent state (the accessible residual corresponds
to the last estimate of the solution), even if a MaxCountExceededException occurs.
  - modified some tests that were not actually testing anything.

Modified:
    commons/proper/math/trunk/src/main/java/org/apache/commons/math/linear/ConjugateGradient.java
    commons/proper/math/trunk/src/test/java/org/apache/commons/math/linear/ConjugateGradientTest.java

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math/linear/ConjugateGradient.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math/linear/ConjugateGradient.java?rev=1179488&r1=1179487&r2=1179488&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math/linear/ConjugateGradient.java
(original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math/linear/ConjugateGradient.java
Thu Oct  6 02:14:20 2011
@@ -174,13 +174,15 @@ public class ConjugateGradient
         manager.resetIterationCount();
         final double r2max = delta * delta * b.dotProduct(b);
 
+        // Initialization phase counts as one iteration.
+        manager.incrementIterationCount();
         // p and x are constructed as copies of x0, since presumably, the type
         // of x is optimized for the calculation of the matrix-vector product
         // A.x.
         final RealVector x = x0;
         final RealVector p = x.copy();
         RealVector q = a.operate(p);
-        manager.incrementIterationCount();
+
         final RealVector r = b.combine(1, -1, q);
         double r2 = r.dotProduct(r);
         RealVector z;
@@ -213,6 +215,7 @@ public class ConjugateGradient
         }
         double rhoPrev = 0.;
         while (true) {
+            manager.incrementIterationCount();
             manager.fireIterationStartedEvent(event);
             if (m != null) {
                 z = m.solve(r);
@@ -226,13 +229,12 @@ public class ConjugateGradient
                 context.setValue(VECTOR, r);
                 throw e;
             }
-            if (manager.getIterations() == 1) {
+            if (manager.getIterations() == 2) {
                 p.setSubVector(0, z);
             } else {
                 p.combineToSelf(rhoNext / rhoPrev, 1., z);
             }
             q = a.operate(p);
-            manager.incrementIterationCount();
             final double pq = p.dotProduct(q);
             if (check && (pq <= 0.)) {
                 final NonPositiveDefiniteOperatorException e;

Modified: commons/proper/math/trunk/src/test/java/org/apache/commons/math/linear/ConjugateGradientTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math/linear/ConjugateGradientTest.java?rev=1179488&r1=1179487&r2=1179488&view=diff
==============================================================================
--- commons/proper/math/trunk/src/test/java/org/apache/commons/math/linear/ConjugateGradientTest.java
(original)
+++ commons/proper/math/trunk/src/test/java/org/apache/commons/math/linear/ConjugateGradientTest.java
Thu Oct  6 02:14:20 2011
@@ -153,7 +153,7 @@ public class ConjugateGradientTest {
      * due to the loss of orthogonality of the successive search directions.
      * Therefore, in the present test, the number of iterations is limited.
      */
-    @Test(expected = MaxCountExceededException.class)
+    @Test
     public void testUnpreconditionedResidual() {
         final int n = 10;
         final int maxIterations = n;
@@ -161,10 +161,11 @@ public class ConjugateGradientTest {
         final ConjugateGradient solver;
         solver = new ConjugateGradient(maxIterations, 1E-15, true);
         final RealVector r = new ArrayRealVector(n);
+        final RealVector x = new ArrayRealVector(n);
         final IterationListener listener = new IterationListener() {
 
             public void terminationPerformed(final IterationEvent e) {
-                r.setSubVector(0, ((ProvidesResidual) e).getResidual());
+                // Do nothing
             }
 
             public void iterationStarted(final IterationEvent e) {
@@ -172,7 +173,10 @@ public class ConjugateGradientTest {
             }
 
             public void iterationPerformed(final IterationEvent e) {
-                // Do nothing
+                RealVector v = ((ProvidesResidual)e).getResidual();
+                r.setSubVector(0, v);
+                v = ((IterativeLinearSolverEvent) e).getSolution();
+                x.setSubVector(0, v);
             }
 
             public void initializationPerformed(final IterationEvent e) {
@@ -180,22 +184,29 @@ public class ConjugateGradientTest {
             }
         };
         solver.getIterationManager().addIterationListener(listener);
-
         final RealVector b = new ArrayRealVector(n);
         for (int j = 0; j < n; j++) {
             b.set(0.);
             b.setEntry(j, 1.);
 
-            final RealVector x = solver.solve(a, b);
-            final RealVector y = a.operate(x);
-            for (int i = 0; i < n; i++) {
-                final double actual = b.getEntry(i) - y.getEntry(i);
-                final double expected = r.getEntry(i);
-                final double delta = 1E-6 * Math.abs(expected);
-                final String msg = String
-                    .format("column %d, residual %d", i, j);
-                Assert.assertEquals(msg, expected, actual, delta);
+            boolean caught = false;
+            try {
+                solver.solve(a, b);
+            } catch (MaxCountExceededException e) {
+                caught = true;
+                final RealVector y = a.operate(x);
+                for (int i = 0; i < n; i++) {
+                    final double actual = b.getEntry(i) - y.getEntry(i);
+                    final double expected = r.getEntry(i);
+                    final double delta = 1E-6 * Math.abs(expected);
+                    final String msg = String
+                        .format("column %d, residual %d", i, j);
+                    Assert.assertEquals(msg, expected, actual, delta);
+                }
             }
+            Assert
+                .assertTrue("MaxCountExceededException should have been caught",
+                            caught);
         }
     }
 
@@ -331,7 +342,7 @@ public class ConjugateGradientTest {
         }
     }
 
-    @Test(expected = MaxCountExceededException.class)
+    @Test
     public void testPreconditionedResidual() {
         final int n = 10;
         final int maxIterations = n;
@@ -340,10 +351,11 @@ public class ConjugateGradientTest {
         final ConjugateGradient solver;
         solver = new ConjugateGradient(maxIterations, 1E-15, true);
         final RealVector r = new ArrayRealVector(n);
+        final RealVector x = new ArrayRealVector(n);
         final IterationListener listener = new IterationListener() {
 
             public void terminationPerformed(final IterationEvent e) {
-                r.setSubVector(0, ((ProvidesResidual) e).getResidual());
+                // Do nothing
             }
 
             public void iterationStarted(final IterationEvent e) {
@@ -351,7 +363,10 @@ public class ConjugateGradientTest {
             }
 
             public void iterationPerformed(final IterationEvent e) {
-                // Do nothing
+                RealVector v = ((ProvidesResidual)e).getResidual();
+                r.setSubVector(0, v);
+                v = ((IterativeLinearSolverEvent) e).getSolution();
+                x.setSubVector(0, v);
             }
 
             public void initializationPerformed(final IterationEvent e) {
@@ -364,20 +379,25 @@ public class ConjugateGradientTest {
         for (int j = 0; j < n; j++) {
             b.set(0.);
             b.setEntry(j, 1.);
-            final RealVector x = solver.solve(a, m, b);
-            final RealVector y = a.operate(x);
-            double rnorm = 0.;
-            for (int i = 0; i < n; i++) {
-                final double actual = b.getEntry(i) - y.getEntry(i);
-                final double expected = r.getEntry(i);
-                final double delta = 1E-6 * Math.abs(expected);
-                final String msg = String
-                    .format("column %d, residual %d", i, j);
-                Assert.assertEquals(msg, expected, actual, delta);
+
+            boolean caught = false;
+            try {
+                solver.solve(a, m, b);
+            } catch (MaxCountExceededException e) {
+                caught = true;
+                final RealVector y = a.operate(x);
+                for (int i = 0; i < n; i++) {
+                    final double actual = b.getEntry(i) - y.getEntry(i);
+                    final double expected = r.getEntry(i);
+                    final double delta = 1E-6 * Math.abs(expected);
+                    final String msg = String
+                        .format("column %d, residual %d", i, j);
+                    Assert.assertEquals(msg, expected, actual, delta);
+                }
             }
-            rnorm = r.getNorm();
-            Assert.assertEquals("norm of residual", rnorm, r.getNorm(),
-                                1E-6 * Math.abs(rnorm));
+            Assert
+                .assertTrue("MaxCountExceededException should have been caught",
+                            caught);
         }
     }
 



Mime
View raw message