commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From celes...@apache.org
Subject svn commit: r1209942 - in /commons/proper/math/trunk/src: main/java/org/apache/commons/math/distribution/AbstractRealDistribution.java test/java/org/apache/commons/math/distribution/AbstractRealDistributionTest.java
Date Sat, 03 Dec 2011 16:24:55 GMT
Author: celestin
Date: Sat Dec  3 16:24:55 2011
New Revision: 1209942

URL: http://svn.apache.org/viewvc?rev=1209942&view=rev
Log:
New implementation of AbstractRealDistribution.inverseCumulativeProbability(double). Resolves
MATH-699 and leads to slightly shorter execution times.

Added:
    commons/proper/math/trunk/src/test/java/org/apache/commons/math/distribution/AbstractRealDistributionTest.java
  (with props)
Modified:
    commons/proper/math/trunk/src/main/java/org/apache/commons/math/distribution/AbstractRealDistribution.java

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math/distribution/AbstractRealDistribution.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math/distribution/AbstractRealDistribution.java?rev=1209942&r1=1209941&r2=1209942&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math/distribution/AbstractRealDistribution.java	(original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math/distribution/AbstractRealDistribution.java	Sat Dec  3 16:24:55 2011
@@ -20,7 +20,6 @@ import java.io.Serializable;
 
 import org.apache.commons.math.analysis.UnivariateFunction;
 import org.apache.commons.math.analysis.solvers.UnivariateRealSolverUtils;
-import org.apache.commons.math.exception.MathInternalError;
 import org.apache.commons.math.exception.NotStrictlyPositiveException;
 import org.apache.commons.math.exception.NumberIsTooLargeException;
 import org.apache.commons.math.exception.OutOfRangeException;
@@ -69,50 +68,80 @@ implements RealDistribution, Serializabl
 
     /** {@inheritDoc} */
     public double inverseCumulativeProbability(final double p) throws OutOfRangeException
{
-
         if (p < 0.0 || p > 1.0) {
             throw new OutOfRangeException(p, 0, 1);
         }
 
-        // by default, do simple root finding using bracketing and default solver.
-        // subclasses can override if there is a better method.
-        UnivariateFunction rootFindingFunction =
-            new UnivariateFunction() {
-            public double value(double x) {
+        double lowerBound = getSupportLowerBound();
+        if (p == 0.0) { // p = 0 maps directly to the lower end of the support
+            return lowerBound;
+        }
+
+        double upperBound = getSupportUpperBound();
+        if (p == 1.0) { // p = 1 maps directly to the upper end of the support
+            return upperBound;
+        }
+
+        final double mu = getNumericalMean(); // NOTE(review): mu and sig are computed even when both support bounds are finite
+        final double sig = FastMath.sqrt(getNumericalVariance()); // standard deviation
+        final boolean chebyshevApplies; // true only when both mu and sig are finite and not NaN
+        chebyshevApplies = !(Double.isInfinite(mu) || Double.isNaN(mu) ||
+                             Double.isInfinite(sig) || Double.isNaN(sig));
+
+        if (lowerBound == Double.NEGATIVE_INFINITY) { // unbounded below: choose a finite lower end for the solver
+            if (chebyshevApplies) {
+                lowerBound = mu - sig * FastMath.sqrt((1. - p) / p); // bound built from the mean and standard deviation
+            } else {
+                lowerBound = -1.0; // moments unusable: start at -1 and keep doubling
+                while (cumulativeProbability(lowerBound) >= p) { // double until the CDF falls below p
+                    lowerBound *= 2.0;
+                }
+            }
+        }
+
+        if (upperBound == Double.POSITIVE_INFINITY) { // unbounded above: choose a finite upper end for the solver
+            if (chebyshevApplies) {
+                upperBound = mu + sig * FastMath.sqrt(p / (1. - p)); // bound built from the mean and standard deviation
+            } else {
+                upperBound = 1.0; // moments unusable: start at 1 and keep doubling
+                while (cumulativeProbability(upperBound) < p) { // double until the CDF reaches p
+                    upperBound *= 2.0;
+                }
+            }
+        }
+
+        final UnivariateFunction toSolve = new UnivariateFunction() { // its root is the requested quantile
+
+            public double value(final double x) {
                 return cumulativeProbability(x) - p;
             }
         };
 
-        // Try to bracket root, test domain endpoints if this fails
-        double lowerBound = getDomainLowerBound(p);
-        double upperBound = getDomainUpperBound(p);
-        double[] bracket = null;
-        try {
-            bracket = UnivariateRealSolverUtils.bracket(
-                    rootFindingFunction, getInitialDomain(p),
-                    lowerBound, upperBound);
-        } catch (NumberIsTooLargeException ex) {
-            /*
-             * Check domain endpoints to see if one gives value that is within
-             * the default solver's defaultAbsoluteAccuracy of 0 (will be the
-             * case if density has bounded support and p is 0 or 1).
-             */
-            if (FastMath.abs(rootFindingFunction.value(lowerBound)) < getSolverAbsoluteAccuracy())
{
-                return lowerBound;
-            }
-            if (FastMath.abs(rootFindingFunction.value(upperBound)) < getSolverAbsoluteAccuracy())
{
-                return upperBound;
+        double x = UnivariateRealSolverUtils.solve(toSolve, // root search over [lowerBound, upperBound]
+                                                   lowerBound,
+                                                   upperBound,
+                                                   getSolverAbsoluteAccuracy());
+
+        if (!isSupportConnected()) { // the extra check below runs only for non-connected supports
+            /* Test for plateau: if the CDF is unchanged one accuracy step left of x, search for a smaller solution. */
+            final double dx = getSolverAbsoluteAccuracy();
+            if (x - dx >= getSupportLowerBound()) { // probe x - dx only when it is not below the support
+                double px = cumulativeProbability(x);
+                if (cumulativeProbability(x - dx) == px) { // CDF is flat between x - dx and x
+                    upperBound = x; // bisect [lowerBound, x] towards the left end of the plateau
+                    while (upperBound - lowerBound > dx) { // stop once the bracket is no wider than dx
+                        final double midPoint = 0.5 * (lowerBound + upperBound);
+                        if (cumulativeProbability(midPoint) < px) { // midPoint is still left of the plateau
+                            lowerBound = midPoint;
+                        } else {
+                            upperBound = midPoint;
+                        }
+                    }
+                    return upperBound; // upper end of the final bracket; its CDF is not below px
+                }
             }
-            // Failed bracket convergence was not because of corner solution
-            throw new MathInternalError(ex);
         }
-
-        // find root
-        double root = UnivariateRealSolverUtils.solve(rootFindingFunction,
-                // override getSolverAbsoluteAccuracy() to use a Brent solver with
-                // absolute accuracy different from the default.
-                bracket[0],bracket[1], getSolverAbsoluteAccuracy());
-        return root;
+        return x;
     }
 
     /**

Added: commons/proper/math/trunk/src/test/java/org/apache/commons/math/distribution/AbstractRealDistributionTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math/distribution/AbstractRealDistributionTest.java?rev=1209942&view=auto
==============================================================================
--- commons/proper/math/trunk/src/test/java/org/apache/commons/math/distribution/AbstractRealDistributionTest.java	(added)
+++ commons/proper/math/trunk/src/test/java/org/apache/commons/math/distribution/AbstractRealDistributionTest.java	Sat Dec  3 16:24:55 2011
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.math.distribution;
+
+import org.apache.commons.math.analysis.UnivariateFunction;
+import org.apache.commons.math.analysis.integration.RombergIntegrator;
+import org.apache.commons.math.analysis.integration.UnivariateRealIntegrator;
+import org.apache.commons.math.exception.OutOfRangeException;
+import org.junit.Assert;
+import org.junit.Test;
+
+/** Tests of {@code inverseCumulativeProbability} on distributions whose CDF has a plateau or a jump (MATH-699). */
+public class AbstractRealDistributionTest {
+
+    @Test
+    public void testContinuous() { // continuous CDF with a flat section between x1 and x2
+        final double x0 = 0.0;
+        final double x1 = 1.0;
+        final double x2 = 2.0;
+        final double x3 = 3.0;
+        final double p12 = 0.5; // value of the CDF on the plateau (x1, x2]
+        final AbstractRealDistribution distribution;
+        distribution = new AbstractRealDistribution() {
+
+            public double cumulativeProbability(final double x) {
+                if ((x < x0) || (x > x3)) { // reject arguments outside [x0, x3]
+                    throw new OutOfRangeException(x, x0, x3);
+                }
+                if (x <= x1) { // linear rise from 0 at x0 to p12 at x1
+                    return p12 * (x - x0) / (x1 - x0);
+                } else if (x <= x2) { // plateau: constant p12 on (x1, x2]
+                    return p12;
+                } else if (x <= x3) { // linear rise from p12 at x2 to 1 at x3
+                    return p12 + (1.0 - p12) * (x - x2) / (x3 - x2);
+                }
+                return 0.0; // reached only if every comparison above is false (e.g. x is NaN)
+            }
+
+            public double density(final double x) { // piecewise-constant derivative of the CDF above
+                if ((x < x0) || (x > x3)) {
+                    throw new OutOfRangeException(x, x0, x3);
+                }
+                if (x <= x1) {
+                    return p12 / (x1 - x0);
+                } else if (x <= x2) { // no mass on the plateau
+                    return 0.0;
+                } else if (x <= x3) {
+                    return (1.0 - p12) / (x3 - x2);
+                }
+                return 0.0; // reached only if every comparison above is false (e.g. x is NaN)
+            }
+
+            @Override
+            protected double getDomainLowerBound(final double p) {
+                throw new UnsupportedOperationException(); // any call to this hook raises and so fails the test
+            }
+
+            @Override
+            protected double getDomainUpperBound(final double p) {
+                throw new UnsupportedOperationException(); // any call to this hook raises and so fails the test
+            }
+
+            @Override
+            protected double getInitialDomain(final double p) {
+                throw new UnsupportedOperationException(); // any call to this hook raises and so fails the test
+            }
+
+            public double getNumericalMean() {
+                return ((x0 + x1) * p12 + (x2 + x3) * (1.0 - p12)) / 2.0; // p12 * midpoint of [x0, x1] + (1 - p12) * midpoint of [x2, x3]
+            }
+
+            public double getNumericalVariance() {
+                final double meanX = getNumericalMean();
+                final double meanX2; // E[X^2]; a uniform piece on [a, b] contributes (a^2 + a*b + b^2) / 3 times its weight
+                meanX2 = ((x0 * x0 + x0 * x1 + x1 * x1) * p12 + (x2 * x2 + x2
+                        * x3 + x3 * x3)
+                        * (1.0 - p12)) / 3.0;
+                return meanX2 - meanX * meanX; // Var[X] = E[X^2] - E[X]^2
+            }
+
+            public double getSupportLowerBound() {
+                return x0;
+            }
+
+            public double getSupportUpperBound() {
+                return x3;
+            }
+
+            public boolean isSupportConnected() {
+                return false; // the density is zero on (x1, x2], so the support is reported as not connected
+            }
+
+            public boolean isSupportLowerBoundInclusive() {
+                return true;
+            }
+
+            public boolean isSupportUpperBoundInclusive() {
+                return true;
+            }
+
+            public double probability(final double x) {
+                throw new UnsupportedOperationException(); // not implemented for this test distribution
+            }
+        };
+        final double expected = x1; // left end of the plateau on which the CDF equals p12
+        final double actual = distribution.inverseCumulativeProbability(p12);
+        Assert.assertEquals("", expected, actual,
+                distribution.getSolverAbsoluteAccuracy()); // tolerance is the distribution's own solver accuracy
+    }
+
+    @Test
+    public void testDiscontinuous() { // CDF with a plateau followed by a jump at x2
+        final double x0 = 0.0;
+        final double x1 = 0.25;
+        final double x2 = 0.5;
+        final double x3 = 0.75;
+        final double x4 = 1.0;
+        final double p12 = 1.0 / 3.0; // value of the CDF on (x1, x2]
+        final double p23 = 2.0 / 3.0; // value of the CDF on (x2, x3]
+        final AbstractRealDistribution distribution;
+        distribution = new AbstractRealDistribution() {
+
+            public double cumulativeProbability(final double x) {
+                if ((x < x0) || (x > x4)) { // reject arguments outside [x0, x4]
+                    throw new OutOfRangeException(x, x0, x4);
+                }
+                if (x <= x1) { // linear rise from 0 at x0 to p12 at x1
+                    return p12 * (x - x0) / (x1 - x0);
+                } else if (x <= x2) { // plateau: constant p12 on (x1, x2]
+                    return p12;
+                } else if (x <= x3) { // jump just after x2: constant p23 on (x2, x3]
+                    return p23;
+                } else { // linear rise from p23 at x3 to 1 at x4
+                    return (1.0 - p23) * (x - x3) / (x4 - x3) + p23;
+                }
+            }
+
+            public double density(final double x) { // NOTE(review): no term for the CDF jump at x2, so this integrates to 2/3 over [x0, x4] - confirm intended
+                if ((x < x0) || (x > x4)) {
+                    throw new OutOfRangeException(x, x0, x4);
+                }
+                if (x <= x1) {
+                    return p12 / (x1 - x0);
+                } else if (x <= x2) {
+                    return 0.0;
+                } else if (x <= x3) {
+                    return 0.0;
+                } else {
+                    return (1.0 - p23) / (x4 - x3);
+                }
+            }
+
+            @Override
+            protected double getDomainLowerBound(final double p) {
+                throw new UnsupportedOperationException(); // any call to this hook raises and so fails the test
+            }
+
+            @Override
+            protected double getDomainUpperBound(final double p) {
+                throw new UnsupportedOperationException(); // any call to this hook raises and so fails the test
+            }
+
+            @Override
+            protected double getInitialDomain(final double p) {
+                throw new UnsupportedOperationException(); // any call to this hook raises and so fails the test
+            }
+
+            public double getNumericalMean() {
+                final UnivariateFunction f = new UnivariateFunction() { // integrand x * density(x)
+
+                    public double value(final double x) {
+                        return x * density(x);
+                    }
+                };
+                final UnivariateRealIntegrator integrator = new RombergIntegrator();
+                return integrator.integrate(Integer.MAX_VALUE, f, x0, x4); // integral of f over [x0, x4]; first argument presumably the evaluation limit - confirm
+            }
+
+            public double getNumericalVariance() {
+                final double meanX = getNumericalMean();
+                final UnivariateFunction f = new UnivariateFunction() { // integrand x^2 * density(x)
+
+                    public double value(final double x) {
+                        return x * x * density(x);
+                    }
+                };
+                final UnivariateRealIntegrator integrator = new RombergIntegrator();
+                final double meanX2 = integrator.integrate(Integer.MAX_VALUE,
+                        f, x0, x4);
+                return meanX2 - meanX * meanX; // integral of x^2 * density minus the squared mean
+            }
+
+            public double getSupportLowerBound() {
+                return x0;
+            }
+
+            public double getSupportUpperBound() {
+                return x4;
+            }
+
+            public boolean isSupportConnected() {
+                return false; // the density is zero on (x1, x3], so the support is reported as not connected
+            }
+
+            public boolean isSupportLowerBoundInclusive() {
+                return true;
+            }
+
+            public boolean isSupportUpperBoundInclusive() {
+                return true;
+            }
+
+            public double probability(final double x) {
+                throw new UnsupportedOperationException(); // not implemented for this test distribution
+            }
+        };
+        final double expected = x2; // the CDF is p12 at x2 and p23 immediately to its right
+        final double actual = distribution.inverseCumulativeProbability(p23);
+        Assert.assertEquals("", expected, actual,
+                distribution.getSolverAbsoluteAccuracy()); // tolerance is the distribution's own solver accuracy
+
+    }
+}

Propchange: commons/proper/math/trunk/src/test/java/org/apache/commons/math/distribution/AbstractRealDistributionTest.java
------------------------------------------------------------------------------
    svn:eol-style = native

Propchange: commons/proper/math/trunk/src/test/java/org/apache/commons/math/distribution/AbstractRealDistributionTest.java
------------------------------------------------------------------------------
    svn:keywords = Author Date Id Revision



Mime
View raw message