commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From er...@apache.org
Subject svn commit: r1072056 - in /commons/proper/math/trunk/src: main/java/org/apache/commons/math/optimization/fitting/GaussianFitter.java site/xdoc/changes.xml test/java/org/apache/commons/math/optimization/fitting/GaussianFitterTest.java
Date Fri, 18 Feb 2011 16:45:57 GMT
Author: erans
Date: Fri Feb 18 16:45:57 2011
New Revision: 1072056

URL: http://svn.apache.org/viewvc?rev=1072056&view=rev
Log:
MATH-512
Refactoring of the "GaussianFitter" class by Ole Ersoy.

Modified:
    commons/proper/math/trunk/src/main/java/org/apache/commons/math/optimization/fitting/GaussianFitter.java
    commons/proper/math/trunk/src/site/xdoc/changes.xml
    commons/proper/math/trunk/src/test/java/org/apache/commons/math/optimization/fitting/GaussianFitterTest.java

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math/optimization/fitting/GaussianFitter.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math/optimization/fitting/GaussianFitter.java?rev=1072056&r1=1072055&r2=1072056&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math/optimization/fitting/GaussianFitter.java (original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math/optimization/fitting/GaussianFitter.java Fri Feb 18 16:45:57 2011
@@ -17,12 +17,22 @@
 
 package org.apache.commons.math.optimization.fitting;
 
+import java.util.Arrays;
+import java.util.Comparator;
+
+import org.apache.commons.math.analysis.function.Gaussian;
+import org.apache.commons.math.exception.NullArgumentException;
+import org.apache.commons.math.exception.NumberIsTooSmallException;
+import org.apache.commons.math.exception.OutOfRangeException;
+import org.apache.commons.math.exception.ZeroException;
+import org.apache.commons.math.exception.util.LocalizedFormats;
 import org.apache.commons.math.optimization.DifferentiableMultivariateVectorialOptimizer;
 import org.apache.commons.math.optimization.fitting.CurveFitter;
 import org.apache.commons.math.optimization.fitting.WeightedObservedPoint;
 
 /**
- * Fits points to a Gaussian function (that is, a {@link GaussianFunction}).
+ * Fits points to a {@link
+ * org.apache.commons.math.analysis.function.Gaussian.Parametric Gaussian} function.
  * <p>
  * Usage example:
  * <pre>
@@ -40,16 +50,13 @@ import org.apache.commons.math.optimizat
  *   fitter.addObservedPoint(4.07525716, 1447024.0);
  *   fitter.addObservedPoint(4.08237071, 717104.0);
  *   fitter.addObservedPoint(4.08366408, 620014.0);
- *   GaussianFunction fitFunction = fitter.fit();
+ *   double[] parameters = fitter.fit();
  * </pre>
  *
- * @see ParametricGaussianFunction
  * @since 2.2
  * @version $Revision$ $Date$
  */
-public class GaussianFitter {
-    /** Fitter used for fitting. */
-    private final CurveFitter fitter;
+public class GaussianFitter extends CurveFitter {
 
     /**
      * Constructs an instance using the specified optimizer.
@@ -57,54 +64,244 @@ public class GaussianFitter {
      * @param optimizer optimizer to use for the fitting
      */
     public GaussianFitter(DifferentiableMultivariateVectorialOptimizer optimizer) {
-        fitter = new CurveFitter(optimizer);
+        super(optimizer);;
     }
 
-    /**
-     * Adds point ({@code x}, {@code y}) to list of observed points
-     * with a weight of 1.
-     *
-     * @param x Abscissa value.
-     * @param y Ordinate value.
-     */
-    public void addObservedPoint(double x, double y) {
-        addObservedPoint(1, x, y);
-    }
-
-    /**
-     * Adds point ({@code x}, {@code y}) to list of observed points
-     * with a weight of {@code weight}.
-     *
-     * @param weight Weight assigned to the given point.
-     * @param x Abscissa value.
-     * @param y Ordinate value.
-     */
-    public void addObservedPoint(double weight, double x, double y) {
-        fitter.addObservedPoint(weight, x, y);
-    }
 
     /**
      * Fits Gaussian function to the observed points.
      * It will call the base class
-     * {@link CurveFitter#fit(ParametricUnivariateRealFunction,double[]) fit} method.
+     * {@link CurveFitter#fit(
+     * org.apache.commons.math.analysis.ParametricUnivariateRealFunction,
+     * double[]) fit} method.
      *
      * @return the Gaussian function that best fits the observed points.
-     * @see CurveFitter
      */
-    public GaussianFunction fit() {
-        return new GaussianFunction(fitter.fit(new ParametricGaussianFunction(),
-                                               createParametersGuesser(fitter.getObservations()).guess()));
+    public double[] fit() {
+        return fit(new Gaussian.Parametric(),
+                   (new ParameterGuesser(getObservations())).guess());
     }
 
     /**
-     * Factory method to create a {@code GaussianParametersGuesser}
-     * instance initialized with the specified observations.
-     *
-     * @param observations points used to initialize the created
-     * {@code GaussianParametersGuesser} instance.
-     * @return a new {@code GaussianParametersGuesser} instance.
+     * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
+     * of a {@link org.apache.commons.math.analysis.function.Gaussian.Parametric}
+     * based on the specified observed points.
      */
-    protected GaussianParametersGuesser createParametersGuesser(WeightedObservedPoint[] observations) {
-        return new GaussianParametersGuesser(observations);
+    public static class ParameterGuesser {
+        /** Observed points. */
+        private final WeightedObservedPoint[] observations;
+
+        /** Resulting guessed parameters. */
+        private double[] parameters;
+
+        /**
+         * Constructs instance with the specified observed points.
+         *
+         * @param observations observed points upon which should base guess
+         */
+        public ParameterGuesser(WeightedObservedPoint[] observations) {
+            if (observations == null) {
+                throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
+            }
+            if (observations.length < 3) {
+                throw new NumberIsTooSmallException(observations.length, 3, true);
+            }
+            this.observations = observations.clone();
+        }
+
+        /**
+         * Guesses the parameters based on the observed points.
+         *
+         * @return guessed parameters array <code>{norm, mean, sigma}</code>
+         */
+        public double[] guess() {
+            if (parameters == null) {
+                parameters = basicGuess(observations);
+            }
+            return parameters.clone();
+        }
+
+        /**
+         * Guesses the parameters based on the specified observed points.
+         *
+         * @param points observed points upon which should base guess
+         *
+         * @return guessed parameters array <code>{norm, mean, sigma}</code>
+         */
+        private double[] basicGuess(WeightedObservedPoint[] points) {
+            Arrays.sort(points, createWeightedObservedPointComparator());
+            double[] params = new double[3];
+
+
+            int maxYIdx = findMaxY(points);
+            params[0] = points[maxYIdx].getY();
+            params[1] = points[maxYIdx].getX();
+
+            double fwhmApprox;
+            try {
+                double halfY = params[0] + ((params[1] - params[0]) / 2.0);
+                double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
+                double fwhmX2 = interpolateXAtY(points, maxYIdx, +1, halfY);
+                fwhmApprox = fwhmX2 - fwhmX1;
+            } catch (OutOfRangeException e) {
+                fwhmApprox = points[points.length - 1].getX() - points[0].getX();
+            }
+            params[2] = fwhmApprox / (2.0 * Math.sqrt(2.0 * Math.log(2.0)));
+
+            return params;
+        }
+
+        /**
+         * Finds index of point in specified points with the largest Y.
+         *
+         * @param points points to search
+         *
+         * @return index in specified points array
+         */
+        private int findMaxY(WeightedObservedPoint[] points) {
+            int maxYIdx = 0;
+            for (int i = 1; i < points.length; i++) {
+                if (points[i].getY() > points[maxYIdx].getY()) {
+                    maxYIdx = i;
+                }
+            }
+            return maxYIdx;
+        }
+
+        /**
+         * Interpolates using the specified points to determine X at the specified
+         * Y.
+         *
+         * @param points points to use for interpolation
+         * @param startIdx index within points from which to start search for
+         *        interpolation bounds points
+         * @param idxStep index step for search for interpolation bounds points
+         * @param y Y value for which X should be determined
+         *
+         * @return value of X at the specified Y
+         *
+         * @throws IllegalArgumentException if idxStep is 0
+         * @throws OutOfRangeException if specified <code>y</code> is not within the
+         *         range of the specified <code>points</code>
+         */
+        private double interpolateXAtY(WeightedObservedPoint[] points,
+                                       int startIdx, int idxStep, double y)
+            throws OutOfRangeException {
+            if (idxStep == 0) {
+                throw new ZeroException();
+            }
+            WeightedObservedPoint[] twoPoints = getInterpolationPointsForY(points, startIdx, idxStep, y);
+            WeightedObservedPoint pointA = twoPoints[0];
+            WeightedObservedPoint pointB = twoPoints[1];
+            if (pointA.getY() == y) {
+                return pointA.getX();
+            }
+            if (pointB.getY() == y) {
+                return pointB.getX();
+            }
+            return pointA.getX() +
+                   (((y - pointA.getY()) * (pointB.getX() - pointA.getX())) /
+                    (pointB.getY() - pointA.getY()));
+        }
+
+        /**
+         * Gets the two bounding interpolation points from the specified points
+         * suitable for determining X at the specified Y.
+         *
+         * @param points points to use for interpolation
+         * @param startIdx index within points from which to start search for
+         *        interpolation bounds points
+         * @param idxStep index step for search for interpolation bounds points
+         * @param y Y value for which X should be determined
+         *
+         * @return array containing two points suitable for determining X at the
+         *         specified Y
+         *
+         * @throws IllegalArgumentException if idxStep is 0
+         * @throws OutOfRangeException if specified <code>y</code> is not within the
+         *         range of the specified <code>points</code>
+         */
+        private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
+                                                                   int startIdx, int idxStep, double y)
+            throws OutOfRangeException {
+            if (idxStep == 0) {
+                throw new ZeroException();
+            }
+            for (int i = startIdx;
+                 (idxStep < 0) ? (i + idxStep >= 0) : (i + idxStep < points.length);
+                 i += idxStep) {
+                if (isBetween(y, points[i].getY(), points[i + idxStep].getY())) {
+                    return (idxStep < 0) ?
+                           new WeightedObservedPoint[] { points[i + idxStep], points[i] } :
+                           new WeightedObservedPoint[] { points[i], points[i + idxStep] };
+                }
+            }
+
+            double minY = Double.POSITIVE_INFINITY;
+            double maxY = Double.NEGATIVE_INFINITY;
+            for (final WeightedObservedPoint point : points) {
+                minY = Math.min(minY, point.getY());
+                maxY = Math.max(maxY, point.getY());
+            }
+            throw new OutOfRangeException(y, minY, maxY);
+
+        }
+
+        /**
+         * Determines whether a value is between two other values.
+         *
+         * @param value value to determine whether is between <code>boundary1</code>
+         *        and <code>boundary2</code>
+         * @param boundary1 one end of the range
+         * @param boundary2 other end of the range
+         *
+         * @return true if <code>value</code> is between <code>boundary1</code> and
+         *         <code>boundary2</code> (inclusive); false otherwise
+         */
+        private boolean isBetween(double value, double boundary1, double boundary2) {
+            return (value >= boundary1 && value <= boundary2) ||
+                   (value >= boundary2 && value <= boundary1);
+        }
+
+        /**
+         * Factory method creating <code>Comparator</code> for comparing
+         * <code>WeightedObservedPoint</code> instances.
+         *
+         * @return new <code>Comparator</code> instance
+         */
+        private Comparator<WeightedObservedPoint> createWeightedObservedPointComparator() {
+            return new Comparator<WeightedObservedPoint>() {
+                public int compare(WeightedObservedPoint p1, WeightedObservedPoint p2) {
+                    if (p1 == null && p2 == null) {
+                        return 0;
+                    }
+                    if (p1 == null) {
+                        return -1;
+                    }
+                    if (p2 == null) {
+                        return 1;
+                    }
+                    if (p1.getX() < p2.getX()) {
+                        return -1;
+                    }
+                    if (p1.getX() > p2.getX()) {
+                        return 1;
+                    }
+                    if (p1.getY() < p2.getY()) {
+                        return -1;
+                    }
+                    if (p1.getY() > p2.getY()) {
+                        return 1;
+                    }
+                    if (p1.getWeight() < p2.getWeight()) {
+                        return -1;
+                    }
+                    if (p1.getWeight() > p2.getWeight()) {
+                        return 1;
+                    }
+                    return 0;
+                }
+            };
+        }
     }
 }

Modified: commons/proper/math/trunk/src/site/xdoc/changes.xml
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/site/xdoc/changes.xml?rev=1072056&r1=1072055&r2=1072056&view=diff
==============================================================================
--- commons/proper/math/trunk/src/site/xdoc/changes.xml (original)
+++ commons/proper/math/trunk/src/site/xdoc/changes.xml Fri Feb 18 16:45:57 2011
@@ -52,6 +52,11 @@ The <action> type attribute can be add,u
     If the output is not quite correct, check for invisible trailing spaces!
      -->
     <release version="3.0" date="TBD" description="TBD">
+      <action dev="erans" type="update" issue="MATH-512" due-to="Ole Ersoy">
+        Refactored "GaussianFitter" (in package "optimization.fitting").
+        The class now really fits a Gaussian function (whereas previously it was
+        fitting the sum of a constant and a Gaussian).
+      </action>
       <action dev="erans" type="fix" issue="MATH-442" due-to="Dietmar Wolz">
         Implementation of the CMA-ES optimization algorithm.
       </action>

Modified: commons/proper/math/trunk/src/test/java/org/apache/commons/math/optimization/fitting/GaussianFitterTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math/optimization/fitting/GaussianFitterTest.java?rev=1072056&r1=1072055&r2=1072056&view=diff
==============================================================================
--- commons/proper/math/trunk/src/test/java/org/apache/commons/math/optimization/fitting/GaussianFitterTest.java (original)
+++ commons/proper/math/trunk/src/test/java/org/apache/commons/math/optimization/fitting/GaussianFitterTest.java Fri Feb 18 16:45:57 2011
@@ -191,11 +191,11 @@ public class GaussianFitterTest {
     throws OptimizationException {
         GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
         addDatasetToGaussianFitter(DATASET1, fitter);
-        GaussianFunction fitFunction = fitter.fit();
-        assertEquals(99200.86969833552, fitFunction.getA(), 1e-4);
-        assertEquals(3410515.285208688, fitFunction.getB(), 1e-4);
-        assertEquals(4.054928275302832, fitFunction.getC(), 1e-4);
-        assertEquals(0.014609868872574, fitFunction.getD(), 1e-4);
+        double[] parameters = fitter.fit();
+
+        assertEquals(3496978.1837704973, parameters[0], 1e-4);
+        assertEquals(4.054933085999146, parameters[1], 1e-4);
+        assertEquals(0.015039355620304326, parameters[2], 1e-4);
     }
 
     /**
@@ -209,7 +209,7 @@ public class GaussianFitterTest {
         GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
         fitter.fit();
     }
-
+    
     /**
      * Two points is not enough observed points.
      *
@@ -225,7 +225,7 @@ public class GaussianFitterTest {
             fitter);
         fitter.fit();
     }
-
+    
     /**
      * Poor data: right of peak not symmetric with left of peak.
      *
@@ -233,16 +233,17 @@ public class GaussianFitterTest {
      */
     @Test
     public void testFit04()
-    throws OptimizationException {
+    throws OptimizationException 
+    {
         GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
         addDatasetToGaussianFitter(DATASET2, fitter);
-        GaussianFunction fitFunction = fitter.fit();
-        assertEquals(-256534.689445631, fitFunction.getA(), 1e-4);
-        assertEquals(481328.2181530679, fitFunction.getB(), 1e-4);
-        assertEquals(-10.5217226891099, fitFunction.getC(), 1e-4);
-        assertEquals(-7.64248239366800, fitFunction.getD(), 1e-4);
-    }
+        double[] parameters = fitter.fit();
 
+        assertEquals(233003.2967252038, parameters[0], 1e-4);
+        assertEquals(-10.654887521095983, parameters[1], 1e-4);
+        assertEquals(4.335937353196641, parameters[2], 1e-4);
+    }  
+    
     /**
      * Poor data: long tails.
      *
@@ -253,13 +254,13 @@ public class GaussianFitterTest {
     throws OptimizationException {
         GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
         addDatasetToGaussianFitter(DATASET3, fitter);
-        GaussianFunction fitFunction = fitter.fit();
-        assertEquals(491.6310079258938, fitFunction.getA(), 1e-4);
-        assertEquals(283508.6800413632, fitFunction.getB(), 1e-4);
-        assertEquals(-13.2966857238057, fitFunction.getC(), 1e-4);
-        assertEquals(1.725590356962981, fitFunction.getD(), 1e-4);
-    }
+        double[] parameters = fitter.fit();
 
+        assertEquals(283863.81929180305, parameters[0], 1e-4);
+        assertEquals(-13.29641995105174, parameters[1], 1e-4);
+        assertEquals(1.7297330293549908, parameters[2], 1e-4);
+    }
+    
     /**
      * Poor data: right of peak is missing.
      *
@@ -270,12 +271,12 @@ public class GaussianFitterTest {
     throws OptimizationException {
         GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
         addDatasetToGaussianFitter(DATASET4, fitter);
-        GaussianFunction fitFunction = fitter.fit();
-        assertEquals(530.3649792355617, fitFunction.getA(), 1e-4);
-        assertEquals(284517.0835567514, fitFunction.getB(), 1e-4);
-        assertEquals(-13.5355534565105, fitFunction.getC(), 1e-4);
-        assertEquals(1.512353018625465, fitFunction.getD(), 1e-4);
-    }
+        double[] parameters = fitter.fit();
+
+        assertEquals(285250.66754309234, parameters[0], 1e-4);
+        assertEquals(-13.528375695228455, parameters[1], 1e-4);
+        assertEquals(1.5204344894331614, parameters[2], 1e-4);
+    }    
 
     /**
      * Basic with smaller dataset.
@@ -284,16 +285,17 @@ public class GaussianFitterTest {
      */
     @Test
     public void testFit07()
-    throws OptimizationException {
+    throws OptimizationException 
+    {
         GaussianFitter fitter = new GaussianFitter(new LevenbergMarquardtOptimizer());
         addDatasetToGaussianFitter(DATASET5, fitter);
-        GaussianFunction fitFunction = fitter.fit();
-        assertEquals(176748.1400947575, fitFunction.getA(), 1e-4);
-        assertEquals(3361537.018813906, fitFunction.getB(), 1e-4);
-        assertEquals(4.054949992747176, fitFunction.getC(), 1e-4);
-        assertEquals(0.014192380137002, fitFunction.getD(), 1e-4);
-    }
+        double[] parameters = fitter.fit();
 
+        assertEquals(3514384.729342235, parameters[0], 1e-4);
+        assertEquals(4.054970307455625, parameters[1], 1e-4);
+        assertEquals(0.015029412832160017, parameters[2], 1e-4);
+    }
+    
     /**
      * Adds the specified points to specified <code>GaussianFitter</code>
      * instance.



Mime
View raw message