commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From er...@apache.org
Subject svn commit: r1073554 - in /commons/proper/math/trunk/src: main/java/org/apache/commons/math/optimization/fitting/GaussianFitter.java test/java/org/apache/commons/math/optimization/fitting/GaussianFitterTest.java
Date Tue, 22 Feb 2011 23:50:46 GMT
Author: erans
Date: Tue Feb 22 23:50:46 2011
New Revision: 1073554

URL: http://svn.apache.org/viewvc?rev=1073554&view=rev
Log:
MATH-519
Work around the exception generated when the optimizer tries invalid values for
the "sigma" parameter.
Added a method to allow users to pass their own initial guess.

Modified:
    commons/proper/math/trunk/src/main/java/org/apache/commons/math/optimization/fitting/GaussianFitter.java
    commons/proper/math/trunk/src/test/java/org/apache/commons/math/optimization/fitting/GaussianFitterTest.java

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math/optimization/fitting/GaussianFitter.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math/optimization/fitting/GaussianFitter.java?rev=1073554&r1=1073553&r2=1073554&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math/optimization/fitting/GaussianFitter.java	(original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math/optimization/fitting/GaussianFitter.java	Tue Feb 22 23:50:46 2011
@@ -21,10 +21,12 @@ import java.util.Arrays;
 import java.util.Comparator;
 
 import org.apache.commons.math.analysis.function.Gaussian;
+import org.apache.commons.math.analysis.ParametricUnivariateRealFunction;
 import org.apache.commons.math.exception.NullArgumentException;
 import org.apache.commons.math.exception.NumberIsTooSmallException;
 import org.apache.commons.math.exception.OutOfRangeException;
 import org.apache.commons.math.exception.ZeroException;
+import org.apache.commons.math.exception.NotStrictlyPositiveException;
 import org.apache.commons.math.exception.util.LocalizedFormats;
 import org.apache.commons.math.optimization.DifferentiableMultivariateVectorialOptimizer;
 import org.apache.commons.math.optimization.fitting.CurveFitter;
@@ -57,29 +59,66 @@ import org.apache.commons.math.optimizat
  * @version $Revision$ $Date$
  */
 public class GaussianFitter extends CurveFitter {
-
     /**
      * Constructs an instance using the specified optimizer.
      *
-     * @param optimizer optimizer to use for the fitting
+     * @param optimizer Optimizer to use for the fitting.
      */
     public GaussianFitter(DifferentiableMultivariateVectorialOptimizer optimizer) {
-        super(optimizer);;
+        super(optimizer);
     }
 
+    /**
+     * Fits a Gaussian function to the observed points.
+     *
+     * @param initialGuess First guess values in the following order:
+     * <ul>
+     *  <li>Norm</li>
+     *  <li>Mean</li>
+     *  <li>Sigma</li>
+     * </ul>
+     * @return the parameters of the Gaussian function that best fits the
+     * observed points (in the same order as above).
+     */
+    public double[] fit(double[] initialGuess) {
+        final ParametricUnivariateRealFunction f = new ParametricUnivariateRealFunction() {
+                private final ParametricUnivariateRealFunction g = new Gaussian.Parametric();
+
+                public double value(double x, double[] p) {
+                    double v = Double.POSITIVE_INFINITY;
+                    try {
+                        v = g.value(x, p);
+                    } catch (NotStrictlyPositiveException e) {
+                        // Do nothing.
+                    }
+                    return v;
+                }
+
+                public double[] gradient(double x, double[] p) {
+                    double[] v = { Double.POSITIVE_INFINITY,
+                                   Double.POSITIVE_INFINITY,
+                                   Double.POSITIVE_INFINITY };
+                    try {
+                        v = g.gradient(x, p);
+                    } catch (NotStrictlyPositiveException e) {
+                        // Do nothing.
+                    }
+                    return v;
+                }
+            };
+
+        return fit(f, initialGuess);
+    }
 
     /**
-     * Fits Gaussian function to the observed points.
-     * It will call the base class
-     * {@link CurveFitter#fit(
-     * org.apache.commons.math.analysis.ParametricUnivariateRealFunction,
-     * double[]) fit} method.
+     * Fits a Gaussian function to the observed points.
      *
-     * @return the Gaussian function that best fits the observed points.
+     * @return the parameters of the Gaussian function that best fits the
+     * observed points (in the same order as above).
      */
     public double[] fit() {
-        return fit(new Gaussian.Parametric(),
-                   (new ParameterGuesser(getObservations())).guess());
+        final double[] guess = (new ParameterGuesser(getObservations())).guess();
+        return fit(guess);
     }
 
     /**
@@ -90,7 +129,6 @@ public class GaussianFitter extends Curv
     public static class ParameterGuesser {
         /** Observed points. */
         private final WeightedObservedPoint[] observations;
-
         /** Resulting guessed parameters. */
         private double[] parameters;
 
@@ -112,7 +150,7 @@ public class GaussianFitter extends Curv
         /**
          * Guesses the parameters based on the observed points.
          *
-         * @return guessed parameters array <code>{norm, mean, sigma}</code>
+         * @return the guessed parameters: norm, mean and sigma.
          */
         public double[] guess() {
             if (parameters == null) {
@@ -124,15 +162,13 @@ public class GaussianFitter extends Curv
         /**
          * Guesses the parameters based on the specified observed points.
          *
-         * @param points observed points upon which should base guess
-         *
-         * @return guessed parameters array <code>{norm, mean, sigma}</code>
+         * @param points Observed points upon which should base guess.
+         * @return the guessed parameters: norm, mean and sigma.
          */
         private double[] basicGuess(WeightedObservedPoint[] points) {
             Arrays.sort(points, createWeightedObservedPointComparator());
             double[] params = new double[3];
 
-
             int maxYIdx = findMaxY(points);
             params[0] = points[maxYIdx].getY();
             params[1] = points[maxYIdx].getX();
@@ -154,9 +190,8 @@ public class GaussianFitter extends Curv
         /**
          * Finds index of point in specified points with the largest Y.
          *
-         * @param points points to search
-         *
-         * @return index in specified points array
+         * @param points Points to search.
+         * @return the index in specified points array.
          */
         private int findMaxY(WeightedObservedPoint[] points) {
             int maxYIdx = 0;
@@ -169,20 +204,18 @@ public class GaussianFitter extends Curv
         }
 
         /**
-         * Interpolates using the specified points to determine X at the specified
-         * Y.
-         *
-         * @param points points to use for interpolation
-         * @param startIdx index within points from which to start search for
-         *        interpolation bounds points
-         * @param idxStep index step for search for interpolation bounds points
-         * @param y Y value for which X should be determined
+         * Interpolates using the specified points to determine X at the
+         * specified Y.
          *
-         * @return value of X at the specified Y
-         *
-         * @throws IllegalArgumentException if idxStep is 0
-         * @throws OutOfRangeException if specified <code>y</code> is not within the
-         *         range of the specified <code>points</code>
+         * @param points Points to use for interpolation.
+         * @param startIdx Index within points from which to start search for
+         *  interpolation bounds points.
+         * @param idxStep Index step for search for interpolation bounds points.
+         * @param y Y value for which X should be determined.
+         * @return the value of X at the specified Y.
+         * @throws ZeroException if {@code idxStep} is 0.
+         * @throws OutOfRangeException if specified {@code y} is not within the
+         * range of the specified {@code points}.
          */
         private double interpolateXAtY(WeightedObservedPoint[] points,
                                        int startIdx, int idxStep, double y)
@@ -208,18 +241,16 @@ public class GaussianFitter extends Curv
          * Gets the two bounding interpolation points from the specified points
          * suitable for determining X at the specified Y.
          *
-         * @param points points to use for interpolation
-         * @param startIdx index within points from which to start search for
-         *        interpolation bounds points
-         * @param idxStep index step for search for interpolation bounds points
-         * @param y Y value for which X should be determined
-         *
-         * @return array containing two points suitable for determining X at the
-         *         specified Y
-         *
-         * @throws IllegalArgumentException if idxStep is 0
-         * @throws OutOfRangeException if specified <code>y</code> is not within the
-         *         range of the specified <code>points</code>
+         * @param points Points to use for interpolation.
+         * @param startIdx Index within points from which to start search for
+         * interpolation bounds points.
+         * @param idxStep Index step for search for interpolation bounds points.
+         * @param y Y value for which X should be determined.
+         * @return the array containing two points suitable for determining X at
+         * the specified Y.
+         * @throws ZeroException if {@code idxStep} is 0.
+         * @throws OutOfRangeException if specified {@code y} is not within the
+         * range of the specified {@code points}.
          */
         private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
                                                                    int startIdx, int idxStep, double y)
@@ -244,19 +275,17 @@ public class GaussianFitter extends Curv
                 maxY = Math.max(maxY, point.getY());
             }
             throw new OutOfRangeException(y, minY, maxY);
-
         }
 
         /**
          * Determines whether a value is between two other values.
          *
-         * @param value value to determine whether is between <code>boundary1</code>
-         *        and <code>boundary2</code>
-         * @param boundary1 one end of the range
-         * @param boundary2 other end of the range
-         *
-         * @return true if <code>value</code> is between <code>boundary1</code> and
-         *         <code>boundary2</code> (inclusive); false otherwise
+         * @param value Value to determine whether is between {@code boundary1}
+         * and {@code boundary2}.
+         * @param boundary1 One end of the range.
+         * @param boundary2 Other end of the range.
+         * @return {@code true} if {@code value} is between {@code boundary1} and
+         * {@code boundary2} (inclusive), {@code false} otherwise.
          */
         private boolean isBetween(double value, double boundary1, double boundary2) {
             return (value >= boundary1 && value <= boundary2) ||
@@ -264,10 +293,10 @@ public class GaussianFitter extends Curv
         }
 
         /**
-         * Factory method creating <code>Comparator</code> for comparing
-         * <code>WeightedObservedPoint</code> instances.
+         * Factory method creating {@code Comparator} for comparing
+         * {@code WeightedObservedPoint} instances.
          *
-         * @return new <code>Comparator</code> instance
+         * @return the new {@code Comparator} instance.
          */
         private Comparator<WeightedObservedPoint> createWeightedObservedPointComparator() {
             return new Comparator<WeightedObservedPoint>() {

Modified: commons/proper/math/trunk/src/test/java/org/apache/commons/math/optimization/fitting/GaussianFitterTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math/optimization/fitting/GaussianFitterTest.java?rev=1073554&r1=1073553&r2=1073554&view=diff
==============================================================================
--- commons/proper/math/trunk/src/test/java/org/apache/commons/math/optimization/fitting/GaussianFitterTest.java	(original)
+++ commons/proper/math/trunk/src/test/java/org/apache/commons/math/optimization/fitting/GaussianFitterTest.java	Tue Feb 22 23:50:46 2011
@@ -17,12 +17,11 @@
 
 package org.apache.commons.math.optimization.fitting;
 
-import static org.junit.Assert.assertEquals;
-
 import org.apache.commons.math.exception.MathIllegalArgumentException;
 import org.apache.commons.math.optimization.OptimizationException;
 import org.apache.commons.math.optimization.general.LevenbergMarquardtOptimizer;
 
+import org.junit.Assert;
 import org.junit.Test;
 
 /**
@@ -193,9 +192,9 @@ public class GaussianFitterTest {
         addDatasetToGaussianFitter(DATASET1, fitter);
         double[] parameters = fitter.fit();
 
-        assertEquals(3496978.1837704973, parameters[0], 1e-4);
-        assertEquals(4.054933085999146, parameters[1], 1e-4);
-        assertEquals(0.015039355620304326, parameters[2], 1e-4);
+        Assert.assertEquals(3496978.1837704973, parameters[0], 1e-4);
+        Assert.assertEquals(4.054933085999146, parameters[1], 1e-4);
+        Assert.assertEquals(0.015039355620304326, parameters[2], 1e-4);
     }
 
     /**
@@ -239,9 +238,9 @@ public class GaussianFitterTest {
         addDatasetToGaussianFitter(DATASET2, fitter);
         double[] parameters = fitter.fit();
 
-        assertEquals(233003.2967252038, parameters[0], 1e-4);
-        assertEquals(-10.654887521095983, parameters[1], 1e-4);
-        assertEquals(4.335937353196641, parameters[2], 1e-4);
+        Assert.assertEquals(233003.2967252038, parameters[0], 1e-4);
+        Assert.assertEquals(-10.654887521095983, parameters[1], 1e-4);
+        Assert.assertEquals(4.335937353196641, parameters[2], 1e-4);
     }  
     
     /**
@@ -256,9 +255,9 @@ public class GaussianFitterTest {
         addDatasetToGaussianFitter(DATASET3, fitter);
         double[] parameters = fitter.fit();
 
-        assertEquals(283863.81929180305, parameters[0], 1e-4);
-        assertEquals(-13.29641995105174, parameters[1], 1e-4);
-        assertEquals(1.7297330293549908, parameters[2], 1e-4);
+        Assert.assertEquals(283863.81929180305, parameters[0], 1e-4);
+        Assert.assertEquals(-13.29641995105174, parameters[1], 1e-4);
+        Assert.assertEquals(1.7297330293549908, parameters[2], 1e-4);
     }
     
     /**
@@ -273,9 +272,9 @@ public class GaussianFitterTest {
         addDatasetToGaussianFitter(DATASET4, fitter);
         double[] parameters = fitter.fit();
 
-        assertEquals(285250.66754309234, parameters[0], 1e-4);
-        assertEquals(-13.528375695228455, parameters[1], 1e-4);
-        assertEquals(1.5204344894331614, parameters[2], 1e-4);
+        Assert.assertEquals(285250.66754309234, parameters[0], 1e-4);
+        Assert.assertEquals(-13.528375695228455, parameters[1], 1e-4);
+        Assert.assertEquals(1.5204344894331614, parameters[2], 1e-4);
     }    
 
     /**
@@ -285,15 +284,59 @@ public class GaussianFitterTest {
      */
     @Test
     public void testFit07()
-    throws OptimizationException 
-    {
+    throws OptimizationException {
         GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
         addDatasetToGaussianFitter(DATASET5, fitter);
         double[] parameters = fitter.fit();
 
-        assertEquals(3514384.729342235, parameters[0], 1e-4);
-        assertEquals(4.054970307455625, parameters[1], 1e-4);
-        assertEquals(0.015029412832160017, parameters[2], 1e-4);
+        Assert.assertEquals(3514384.729342235, parameters[0], 1e-4);
+        Assert.assertEquals(4.054970307455625, parameters[1], 1e-4);
+        Assert.assertEquals(0.015029412832160017, parameters[2], 1e-4);
+    }
+
+    @Test
+    public void testMath519() {
+        // The optimizer will try negative sigma values but "GaussianFitter"
+        // will catch the raised exceptions and return NaN values instead.
+
+        final double[] data = { 
+            1.1143831578403364E-29,
+            4.95281403484594E-28,
+            1.1171347211930288E-26,
+            1.7044813962636277E-25,
+            1.9784716574832164E-24,
+            1.8630236407866774E-23,
+            1.4820532905097742E-22,
+            1.0241963854632831E-21,
+            6.275077366673128E-21,
+            3.461808994532493E-20,
+            1.7407124684715706E-19,
+            8.056687953553974E-19,
+            3.460193945992071E-18,
+            1.3883326374011525E-17,
+            5.233894983671116E-17,
+            1.8630791465263745E-16,
+            6.288759227922111E-16,
+            2.0204433920597856E-15,
+            6.198768938576155E-15,
+            1.821419346860626E-14,
+            5.139176445538471E-14,
+            1.3956427429045787E-13,
+            3.655705706448139E-13,
+            9.253753324779779E-13,
+            2.267636001476696E-12,
+            5.3880460095836855E-12,
+            1.2431632654852931E-11
+        };
+
+        GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
+        for (int i = 0; i < data.length; i++) {
+            fitter.addObservedPoint(i, data[i]);
+        }
+        final double[] p = fitter.fit();
+
+        Assert.assertEquals(53.1572792, p[1], 1e-7);
+        Assert.assertEquals(5.75214622, p[2], 1e-8);
     }
     
     /**



Mime
View raw message