commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From celes...@apache.org
Subject svn commit: r1339014 - /commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/GammaDistributionTest.java
Date Wed, 16 May 2012 05:36:40 GMT
Author: celestin
Date: Wed May 16 05:36:40 2012
New Revision: 1339014

URL: http://svn.apache.org/viewvc?rev=1339014&view=rev
Log:
Unit tests for GammaDistribution, based on reference data generated with
Maxima. Solves MATH-753.

Modified:
    commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/GammaDistributionTest.java

Modified: commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/GammaDistributionTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/GammaDistributionTest.java?rev=1339014&r1=1339013&r2=1339014&view=diff
==============================================================================
--- commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/GammaDistributionTest.java	(original)
+++ commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/GammaDistributionTest.java	Wed May 16 05:36:40 2012
@@ -17,7 +17,15 @@
 
 package org.apache.commons.math3.distribution;
 
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+
 import org.apache.commons.math3.exception.NotStrictlyPositiveException;
+import org.apache.commons.math3.special.Gamma;
+import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
+import org.apache.commons.math3.util.FastMath;
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -167,4 +175,155 @@ public class GammaDistributionTest exten
         Assert.assertEquals(dist.getNumericalMean(), 1.1d * 4.2d, tol);
         Assert.assertEquals(dist.getNumericalVariance(), 1.1d * 4.2d * 4.2d, tol);
     }
+
+    /**
+     * Evaluates the gamma probability density with the formula that
+     * {@code GammaDistribution.density(double)} used before r1338548, so that
+     * the accuracy of the current implementation can be compared with it.
+     * <p>
+     * The order of the floating-point operations must be kept exactly as is:
+     * errors are measured in ulps, so even an algebraically equivalent
+     * rearrangement would change the measured errors.
+     *
+     * @param x point at which the density is evaluated
+     * @param shape shape parameter of the distribution
+     * @param scale scale parameter of the distribution
+     * @return the density at {@code x}, or {@code 0} when {@code x < 0}; for
+     * large {@code x} or {@code shape} the result may overflow to NaN or
+     * infinity (MATH-753)
+     */
+    public static double density(final double x, final double shape,
+                                 final double scale) {
+        /*
+         * This is a copy of
+         * double GammaDistribution.density(double)
+         * prior to r1338548.
+         */
+        if (x < 0) {
+            return 0;
+        }
+        // pow(x/scale, shape-1) and exp(logGamma(shape)) overflow for large
+        // arguments; this is the defect reported in MATH-753.
+        return FastMath.pow(x / scale, shape - 1) / scale *
+               FastMath.exp(-x / scale) / FastMath.exp(Gamma.logGamma(shape));
+    }
+
+    /*
+     * MATH-753: large values of x or shape parameter cause density(double) to
+     * overflow. Reference data is generated with the Maxima script
+     * gamma-distribution.mac, which can be found in
+     * src/test/resources/org/apache/commons/math3/distribution.
+     */
+
+    /**
+     * Compares the accuracy of {@link GammaDistribution#density(double)} with
+     * the pre-r1338548 formula ({@link #density(double, double, double)}),
+     * using reference values read from a resource file.
+     * <p>
+     * Each non-blank line of the resource holds two comma-separated values:
+     * the abscissa {@code x} and the expected density at {@code x} for the
+     * given shape and a unit scale. Errors are measured in ulps of the
+     * expected value. Points where the old formula overflows (NaN or
+     * infinity) are accumulated separately from points where it does not.
+     *
+     * @param shape shape parameter of the distribution (the scale is 1)
+     * @param meanNoOF maximum allowed mean error (ulps) of the new
+     * implementation at points where the old one does not overflow
+     * @param sdNoOF maximum allowed standard deviation of the error (ulps) of
+     * the new implementation at points where the old one does not overflow
+     * @param meanOF maximum allowed mean error (ulps) of the new
+     * implementation at points where the old one overflows
+     * @param sdOF maximum allowed standard deviation of the error (ulps) of
+     * the new implementation at points where the old one overflows
+     * @param resourceName name of the reference data resource, resolved
+     * relative to this class
+     * @throws IOException if the resource cannot be read
+     */
+    private void doTestMath753(final double shape,
+        final double meanNoOF, final double sdNoOF,
+        final double meanOF, final double sdOF,
+        final String resourceName) throws IOException {
+        final GammaDistribution distribution = new GammaDistribution(shape, 1.0);
+        final SummaryStatistics statOld = new SummaryStatistics();
+        final SummaryStatistics statNewNoOF = new SummaryStatistics();
+        final SummaryStatistics statNewOF = new SummaryStatistics();
+
+        final InputStream resourceAsStream =
+            this.getClass().getResourceAsStream(resourceName);
+        Assert.assertNotNull("Could not find resource " + resourceName,
+                             resourceAsStream);
+        // Decode explicitly instead of relying on the platform default charset.
+        final BufferedReader in = new BufferedReader(
+            new InputStreamReader(resourceAsStream, "UTF-8"));
+
+        // No catch clause: an IOException propagates to JUnit with its full
+        // stack trace instead of being reduced to its message.
+        try {
+            for (String line = in.readLine(); line != null; line = in.readLine()) {
+                if (line.trim().length() == 0) {
+                    // Tolerate blank lines, e.g. at the end of the file.
+                    continue;
+                }
+                final String[] tokens = line.split(",");
+                Assert.assertEquals("expected two floating-point values: " + line,
+                                    2, tokens.length);
+                final double x = Double.parseDouble(tokens[0].trim());
+                final double expected = Double.parseDouble(tokens[1].trim());
+                final String msg = "x = " + x + ", shape = " + shape +
+                                   ", scale = 1.0";
+                // Errors are expressed in ulps of the reference value.
+                final double ulp = FastMath.ulp(expected);
+                final double actualOld = density(x, shape, 1.0);
+                final double actualNew = distribution.density(x);
+                final double errNew = FastMath.abs((actualNew - expected) / ulp);
+
+                if (Double.isNaN(actualOld) || Double.isInfinite(actualOld)) {
+                    // The old formula overflowed here; the new one must not.
+                    Assert.assertFalse(msg, Double.isNaN(actualNew));
+                    Assert.assertFalse(msg, Double.isInfinite(actualNew));
+                    statNewOF.addValue(errNew);
+                } else {
+                    final double errOld = FastMath.abs((actualOld - expected) / ulp);
+                    statOld.addValue(errOld);
+                    statNewNoOF.addValue(errNew);
+                }
+            }
+        } finally {
+            in.close();
+        }
+
+        if (statOld.getN() != 0) {
+            checkNoOverflowStatistics(shape, statOld, statNewNoOF,
+                                      meanNoOF, sdNoOF);
+        }
+        if (statNewOF.getN() != 0) {
+            checkOverflowStatistics(shape, statNewOF, meanOF, sdOF);
+        }
+    }
+
+    /**
+     * Checks the points where the old formula does not overflow. The new
+     * implementation must be at least as accurate as the old one: its min,
+     * max, mean and standard deviation of the error (ulps) must all be no
+     * larger. Its mean and standard deviation must also not exceed the given
+     * bounds.
+     */
+    private static void checkNoOverflowStatistics(final double shape,
+        final SummaryStatistics statOld, final SummaryStatistics statNew,
+        final double maxMean, final double maxSd) {
+        final StringBuilder sb = new StringBuilder("shape = ");
+        sb.append(shape);
+        sb.append(", scale = 1.0\n");
+        sb.append("Old implementation\n");
+        sb.append("------------------\n");
+        sb.append(statOld.toString());
+        sb.append("New implementation\n");
+        sb.append("------------------\n");
+        sb.append(statNew.toString());
+        final String msg = sb.toString();
+
+        final double newMean = statNew.getMean();
+        final double newSd = statNew.getStandardDeviation();
+
+        Assert.assertTrue(msg, statNew.getMin() <= statOld.getMin());
+        Assert.assertTrue(msg, statNew.getMax() <= statOld.getMax());
+        Assert.assertTrue(msg, newMean <= statOld.getMean());
+        Assert.assertTrue(msg, newSd <= statOld.getStandardDeviation());
+
+        Assert.assertTrue(msg, newMean <= maxMean);
+        Assert.assertTrue(msg, newSd <= maxSd);
+    }
+
+    /**
+     * Checks the points where the old formula overflows. The mean and
+     * standard deviation of the error (ulps) of the new implementation must
+     * not exceed the given bounds.
+     */
+    private static void checkOverflowStatistics(final double shape,
+        final SummaryStatistics statNew, final double maxMean,
+        final double maxSd) {
+        final double newMean = statNew.getMean();
+        final double newSd = statNew.getStandardDeviation();
+
+        final StringBuilder sb = new StringBuilder("shape = ");
+        sb.append(shape);
+        sb.append(", scale = 1.0");
+        sb.append(", max. mean error (ulps) = ");
+        sb.append(maxMean);
+        sb.append(", actual mean error (ulps) = ");
+        sb.append(newMean);
+        sb.append(", max. sd of error (ulps) = ");
+        sb.append(maxSd);
+        sb.append(", actual sd of error (ulps) = ");
+        sb.append(newSd);
+        final String msg = sb.toString();
+
+        Assert.assertTrue(msg, newMean <= maxMean);
+        Assert.assertTrue(msg, newSd <= maxSd);
+    }
+
+
+    @Test
+    public void testMath753Shape1() throws IOException {
+        doTestMath753(1.0, 1.5, 0.5, 0.0, 0.0, "gamma-distribution-shape-1.csv");
+    }
+
+    @Test
+    public void testMath753Shape10() throws IOException {
+        doTestMath753(10.0, 1.0, 1.0, 0.0, 0.0, "gamma-distribution-shape-10.csv");
+    }
+
+    @Test
+    public void testMath753Shape100() throws IOException {
+        doTestMath753(100.0, 1.5, 1.0, 0.0, 0.0, "gamma-distribution-shape-100.csv");
+    }
+
+    @Test
+    public void testMath753Shape142() throws IOException {
+        doTestMath753(142.0, 0.5, 1.5, 40.0, 40.0, "gamma-distribution-shape-142.csv");
+    }
+
+    @Test
+    public void testMath753Shape1000() throws IOException {
+        doTestMath753(1000.0, 1.0, 1.0, 160.0, 220.0, "gamma-distribution-shape-1000.csv");
+    }
 }



Mime
View raw message