+ * When a univariate real function y = f(x) does depend on some + * unknown parameters p0, p1 ... pn-1, + * this class can be used to find these parameters. It does this + * by fitting the curve so it remains very close to a set of + * observed points (x0, y0), (x1, + * y1) ... (xk-1, yk-1). This fitting + * is done by finding the parameters values that minimizes the objective + * function ∑(yi-f(xi))2. This is + * really a least squares problem. + * + * @param Function to use for the fit. + * + * @version \$Id: CurveFitter.java 1416643 2012-12-03 19:37:14Z tn \$ + * @since 2.0 + */ +public class CurveFitter { + /** Optimizer to use for the fitting. */ + private final MultivariateVectorOptimizer optimizer; + /** Observed points. */ + private final List observations; + + /** + * Simple constructor. + * + * @param optimizer Optimizer to use for the fitting. + * @since 3.1 + */ + public CurveFitter(final MultivariateVectorOptimizer optimizer) { + this.optimizer = optimizer; + observations = new ArrayList(); + } + + /** Add an observed (x,y) point to the sample with unit weight. + *

Calling this method is equivalent to calling + * {@code addObservedPoint(1.0, x, y)}.

+ * Usage example: + *

```+ *   GaussianFitter fitter = new GaussianFitter(
+ *     new LevenbergMarquardtOptimizer());
+ *   fitter.addObservedPoint(x, y); // repeat for each observed point (at least 3)
+ *   double[] parameters = fitter.fit();
+ * ```
+ * + * @since 2.2 + * @version \$Id: GaussianFitter.java 1416643 2012-12-03 19:37:14Z tn \$ + */
public class GaussianFitter extends CurveFitter<Gaussian.Parametric> {
    /**
     * Constructs an instance using the specified optimizer.
     *
     * @param optimizer Optimizer to use for the fitting.
     */
    public GaussianFitter(MultivariateVectorOptimizer optimizer) {
        super(optimizer);
    }

    /**
     * Fits a Gaussian function to the observed points.
     *
     * @param initialGuess First guess values in the following order:
     * <ul>
     *  <li>Norm</li>
     *  <li>Mean</li>
     *  <li>Sigma</li>
     * </ul>
     * @return the parameters of the Gaussian function that best fits the
     * observed points (in the same order as above).
     * @since 3.0
     */
    public double[] fit(double[] initialGuess) {
        // Wrap the parametric Gaussian so that parameter values it rejects
        // with NotStrictlyPositiveException evaluate to +infinity instead of
        // aborting the whole fit.
        final Gaussian.Parametric f = new Gaussian.Parametric() {
            @Override
            public double value(double x, double ... p) {
                double v = Double.POSITIVE_INFINITY;
                try {
                    v = super.value(x, p);
                } catch (NotStrictlyPositiveException e) {
                    // Deliberately ignored: keep the +infinity sentinel.
                }
                return v;
            }

            @Override
            public double[] gradient(double x, double ... p) {
                double[] v = { Double.POSITIVE_INFINITY,
                               Double.POSITIVE_INFINITY,
                               Double.POSITIVE_INFINITY };
                try {
                    v = super.gradient(x, p);
                } catch (NotStrictlyPositiveException e) {
                    // Deliberately ignored: keep the +infinity sentinel.
                }
                return v;
            }
        };

        return fit(f, initialGuess);
    }

    /**
     * Fits a Gaussian function to the observed points, starting from a
     * first guess computed by {@link ParameterGuesser}.
     *
     * @return the parameters of the Gaussian function that best fits the
     * observed points (in the same order as for {@link #fit(double[])}).
     * @throws NumberIsTooSmallException if fewer than 3 points have been
     * observed.
     */
    public double[] fit() {
        final double[] guess = (new ParameterGuesser(getObservations())).guess();
        return fit(guess);
    }

    /**
     * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
     * of a {@link org.apache.commons.math3.analysis.function.Gaussian.Parametric}
     * based on the specified observed points.
     */
    public static class ParameterGuesser {
        /** Normalization factor. */
        private final double norm;
        /** Mean. */
        private final double mean;
        /** Standard deviation. */
        private final double sigma;

        /**
         * Constructs instance with the specified observed points.
         *
         * @param observations Observed points from which to guess the
         * parameters of the Gaussian.
         * @throws NullArgumentException if {@code observations} is
         * {@code null}.
         * @throws NumberIsTooSmallException if there are less than 3
         * observations.
         */
        public ParameterGuesser(WeightedObservedPoint[] observations) {
            if (observations == null) {
                throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
            }
            if (observations.length < 3) {
                throw new NumberIsTooSmallException(observations.length, 3, true);
            }

            final WeightedObservedPoint[] sorted = sortObservations(observations);
            final double[] params = basicGuess(sorted);

            norm = params[0];
            mean = params[1];
            sigma = params[2];
        }

        /**
         * Gets an estimation of the parameters.
         *
         * @return the guessed parameters, in the following order:
         * <ul>
         *  <li>Normalization factor</li>
         *  <li>Mean</li>
         *  <li>Standard deviation</li>
         * </ul>
         */
        public double[] guess() {
            return new double[] { norm, mean, sigma };
        }

        /**
         * Sorts a copy of the observations by abscissa, then ordinate, then
         * weight; {@code null} entries come first.
         *
         * @param unsorted Input observations (left unmodified).
         * @return a sorted copy of the input observations.
         */
        private WeightedObservedPoint[] sortObservations(WeightedObservedPoint[] unsorted) {
            final WeightedObservedPoint[] observations = unsorted.clone();
            final Comparator<WeightedObservedPoint> cmp
                = new Comparator<WeightedObservedPoint>() {
                public int compare(WeightedObservedPoint p1,
                                   WeightedObservedPoint p2) {
                    if (p1 == null && p2 == null) {
                        return 0;
                    }
                    if (p1 == null) {
                        return -1;
                    }
                    if (p2 == null) {
                        return 1;
                    }
                    if (p1.getX() < p2.getX()) {
                        return -1;
                    }
                    if (p1.getX() > p2.getX()) {
                        return 1;
                    }
                    if (p1.getY() < p2.getY()) {
                        return -1;
                    }
                    if (p1.getY() > p2.getY()) {
                        return 1;
                    }
                    if (p1.getWeight() < p2.getWeight()) {
                        return -1;
                    }
                    if (p1.getWeight() > p2.getWeight()) {
                        return 1;
                    }
                    return 0;
                }
            };

            Arrays.sort(observations, cmp);
            return observations;
        }

        /**
         * Guesses the parameters based on the specified observed points.
         * The peak point gives the normalization factor and the mean; the
         * full width at half maximum (FWHM) gives the standard deviation.
         *
         * @param points Observed points, sorted.
         * @return the guessed parameters (normalization factor, mean and
         * sigma).
         */
        private double[] basicGuess(WeightedObservedPoint[] points) {
            final int maxYIdx = findMaxY(points);
            final double n = points[maxYIdx].getY();
            final double m = points[maxYIdx].getX();

            double fwhmApprox;
            try {
                // The width is measured at half of the peak height; the
                // model (norm, mean, sigma) has no baseline offset.
                final double halfY = n / 2;
                final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
                final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY);
                fwhmApprox = fwhmX2 - fwhmX1;
            } catch (OutOfRangeException e) {
                // Half maximum not bracketed on both sides: fall back to the
                // whole abscissa range.
                // TODO: Exceptions should not be used for flow control.
                fwhmApprox = points[points.length - 1].getX() - points[0].getX();
            }
            // For a Gaussian, FWHM = 2 sqrt(2 ln 2) sigma.
            final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2)));

            return new double[] { n, m, s };
        }

        /**
         * Finds index of point in specified points with the largest Y.
         *
         * @param points Points to search.
         * @return the index in specified points array (the first one in
         * case of ties).
         */
        private int findMaxY(WeightedObservedPoint[] points) {
            int maxYIdx = 0;
            for (int i = 1; i < points.length; i++) {
                if (points[i].getY() > points[maxYIdx].getY()) {
                    maxYIdx = i;
                }
            }
            return maxYIdx;
        }

        /**
         * Interpolates using the specified points to determine X at the
         * specified Y.
         *
         * @param points Points to use for interpolation.
         * @param startIdx Index within points from which to start the search for
         * interpolation bounds points.
         * @param idxStep Index step for searching interpolation bounds points.
         * @param y Y value for which X should be determined.
         * @return the value of X for the specified Y.
         * @throws ZeroException if {@code idxStep} is 0.
         * @throws OutOfRangeException if specified {@code y} is not within the
         * range of the specified {@code points}.
         */
        private double interpolateXAtY(WeightedObservedPoint[] points,
                                       int startIdx,
                                       int idxStep,
                                       double y)
            throws OutOfRangeException {
            if (idxStep == 0) {
                throw new ZeroException();
            }
            final WeightedObservedPoint[] twoPoints
                = getInterpolationPointsForY(points, startIdx, idxStep, y);
            final WeightedObservedPoint p1 = twoPoints[0];
            final WeightedObservedPoint p2 = twoPoints[1];
            if (p1.getY() == y) {
                return p1.getX();
            }
            if (p2.getY() == y) {
                return p2.getX();
            }
            // Linear interpolation between the two bracketing points.
            return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
                                (p2.getY() - p1.getY()));
        }

        /**
         * Gets the two bounding interpolation points from the specified points
         * suitable for determining X at the specified Y.
         *
         * @param points Points to use for interpolation.
         * @param startIdx Index within points from which to start search for
         * interpolation bounds points.
         * @param idxStep Index step for search for interpolation bounds points.
         * @param y Y value for which X should be determined.
         * @return the array containing two points suitable for determining X at
         * the specified Y, ordered by increasing index.
         * @throws ZeroException if {@code idxStep} is 0.
         * @throws OutOfRangeException if specified {@code y} is not within the
         * range of the specified {@code points}.
         */
        private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
                                                                   int startIdx,
                                                                   int idxStep,
                                                                   double y)
            throws OutOfRangeException {
            if (idxStep == 0) {
                throw new ZeroException();
            }
            for (int i = startIdx;
                 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
                 i += idxStep) {
                final WeightedObservedPoint p1 = points[i];
                final WeightedObservedPoint p2 = points[i + idxStep];
                if (isBetween(y, p1.getY(), p2.getY())) {
                    if (idxStep < 0) {
                        return new WeightedObservedPoint[] { p2, p1 };
                    } else {
                        return new WeightedObservedPoint[] { p1, p2 };
                    }
                }
            }

            // Boundaries are replaced by dummy values because the raised
            // exception is caught and the message never displayed.
            // TODO: Exceptions should not be used for flow control.
            throw new OutOfRangeException(y,
                                          Double.NEGATIVE_INFINITY,
                                          Double.POSITIVE_INFINITY);
        }

        /**
         * Determines whether a value is between two other values.
         *
         * @param value Value to test whether it is between {@code boundary1}
         * and {@code boundary2}.
         * @param boundary1 One end of the range.
         * @param boundary2 Other end of the range.
         * @return {@code true} if {@code value} is between {@code boundary1} and
         * {@code boundary2} (inclusive), {@code false} otherwise.
         */
        private boolean isBetween(double value,
                                  double boundary1,
                                  double boundary2) {
            return (value >= boundary1 && value <= boundary2) ||
                (value >= boundary2 && value <= boundary1);
        }
    }
}
Propchange: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/GaussianFitter.java ------------------------------------------------------------------------------ svn:eol-style = native Added: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/HarmonicFitter.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/HarmonicFitter.java?rev=1420684&view=auto ============================================================================== --- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/HarmonicFitter.java (added) +++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/HarmonicFitter.java Wed Dec 12 14:10:38 2012 @@ -0,0 +1,382 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.commons.math3.fitting; + +import org.apache.commons.math3.optim.nonlinear.vector.MultivariateVectorOptimizer; +import org.apache.commons.math3.analysis.function.HarmonicOscillator; +import org.apache.commons.math3.exception.ZeroException; +import org.apache.commons.math3.exception.NumberIsTooSmallException; +import org.apache.commons.math3.exception.MathIllegalStateException; +import org.apache.commons.math3.exception.util.LocalizedFormats; +import org.apache.commons.math3.util.FastMath; + +/** + * Class that implements a curve fitting specialized for sinusoids. + * + * Harmonic fitting is a very simple case of curve fitting. The + * estimated coefficients are the amplitude a, the pulsation ω and + * the phase φ: `f (t) = a cos (ω t + φ)`. They are + * searched by a least square estimator initialized with a rough guess + * based on integrals. + * + * @version \$Id: HarmonicFitter.java 1416643 2012-12-03 19:37:14Z tn \$ + * @since 2.0 + */ +public class HarmonicFitter extends CurveFitter { + /** + * Simple constructor. + * @param optimizer Optimizer to use for the fitting. + */ + public HarmonicFitter(final MultivariateVectorOptimizer optimizer) { + super(optimizer); + } + + /** + * Fit an harmonic function to the observed points. + * + * @param initialGuess First guess values in the following order: + *
+ *
• Amplitude
• + *
• Angular frequency
• + *
• Phase
• + *
+ * @return the parameters of the harmonic function that best fits the + * observed points (in the same order as above). + */ + public double[] fit(double[] initialGuess) { + return fit(new HarmonicOscillator.Parametric(), initialGuess); + } + + /** + * Fit an harmonic function to the observed points. + * An initial guess will be automatically computed. + * + * @return the parameters of the harmonic function that best fits the + * observed points (see the other {@link #fit(double[]) fit} method. + * @throws NumberIsTooSmallException if the sample is too short for the + * the first guess to be computed. + * @throws ZeroException if the first guess cannot be computed because + * the abscissa range is zero. + */ + public double[] fit() { + return fit((new ParameterGuesser(getObservations())).guess()); + } + + /** + * This class guesses harmonic coefficients from a sample. + *

The algorithm used to guess the coefficients is as follows:

+ * + *

We know f (t) at some sampling points ti and want to find a, + * ω and φ such that f (t) = a cos (ω t + φ). + *

+ * + *

From the analytical expression, we can compute two primitives: + *

```+     *     If2  (t) = ∫ f²  = a² × [t + S (t)] / 2
+     *     If'2 (t) = ∫ f'² = a² ω² × [t - S (t)] / 2
+     *     where S (t) = sin (2 (ω t + φ)) / (2 ω)
+     * ```
+ *

+ * + *

We can eliminate S between these expressions: + *

```+     *     If'2 (t) = a² ω² t - ω² If2 (t)
+     * ```
+ *

+ * + *

The preceding expression shows that If'2 (t) is a linear + * combination of both t and If2 (t): If'2 (t) = A × t + B × If2 (t) + *

+ * + *

From the primitives, we can deduce the same form for definite + * integrals between t1 and ti for each ti: + *

```+     *   If'2 (ti) - If'2 (t1) = A × (ti - t1) + B × (If2 (ti) - If2 (t1))
+     * ```
+ *

+ * + *

We can find the coefficients A and B that best fit the sample + * to this linear expression by computing the definite integrals for + * each sample points. + *

+ * + *

For a linear expression in two variables z (xi, yi) = A × xi + B × yi, the + * coefficients A and B that minimize a least-squares criterion + * ∑ (zi - z (xi, yi))² are given by these expressions:

+ *
```+     *
+     *         ∑yiyi ∑xizi - ∑xiyi ∑yizi
+     *     A = ------------------------
+     *         ∑xixi ∑yiyi - ∑xiyi ∑xiyi
+     *
+     *         ∑xixi ∑yizi - ∑xiyi ∑xizi
+     *     B = ------------------------
+     *         ∑xixi ∑yiyi - ∑xiyi ∑xiyi
+     * ```
+ *

+ * + * + *

In fact, we can assume both a and ω are positive and + * compute them directly, knowing that A = a² ω² and that + * B = - ω². The complete algorithm is therefore:

+ *
```+     *
+     * for each ti from t1 to tn-1, compute:
+     *   f  (ti)
+     *   f' (ti) = (f (ti+1) - f(ti-1)) / (ti+1 - ti-1)
+     *   xi = ti - t1
+     *   yi = ∫ f² from t1 to ti
+     *   zi = ∫ f'² from t1 to ti
+     *   update the sums ∑xixi, ∑yiyi, ∑xiyi, ∑xizi and ∑yizi
+     * end for
+     *
+     *            |--------------------------
+     *         \  | ∑yiyi ∑xizi - ∑xiyi ∑yizi
+     * a     =  \ | ------------------------
+     *           \| ∑xiyi ∑xizi - ∑xixi ∑yizi
+     *
+     *
+     *            |--------------------------
+     *         \  | ∑xiyi ∑xizi - ∑xixi ∑yizi
+     * ω     =  \ | ------------------------
+     *           \| ∑xixi ∑yiyi - ∑xiyi ∑xiyi
+     *
+     * ```
+ *

+ * + *

Once we know ω, we can compute: + *

```+     *    fc = ω f (t) cos (ω t) - f' (t) sin (ω t)
+     *    fs = ω f (t) sin (ω t) + f' (t) cos (ω t)
+     * ```
+ *

+ * + *

It appears that `fc = a ω cos (φ)` and + * `fs = -a ω sin (φ)`, so we can use these + * expressions to compute φ. The best estimate over the sample is + * given by averaging these expressions. + *

+ * + *

Since integrals and means are involved in the preceding + * estimations, these operations run in O(n) time, where n is the + * number of measurements.

+ */ + public static class ParameterGuesser { + /** Amplitude. */ + private final double a; + /** Angular frequency. */ + private final double omega; + /** Phase. */ + private final double phi; + + /** + * Simple constructor. + * + * @param observations Sampled observations. + * @throws NumberIsTooSmallException if the sample is too short. + * @throws ZeroException if the abscissa range is zero. + * @throws MathIllegalStateException when the guessing procedure cannot + * produce sensible results. + */ + public ParameterGuesser(WeightedObservedPoint[] observations) { + if (observations.length < 4) { + throw new NumberIsTooSmallException(LocalizedFormats.INSUFFICIENT_OBSERVED_POINTS_IN_SAMPLE, + observations.length, 4, true); + } + + final WeightedObservedPoint[] sorted = sortObservations(observations); + + final double aOmega[] = guessAOmega(sorted); + a = aOmega[0]; + omega = aOmega[1]; + + phi = guessPhi(sorted); + } + + /** + * Gets an estimation of the parameters. + * + * @return the guessed parameters, in the following order: + *
+ *
• Amplitude
• + *
• Angular frequency
• + *
• Phase
• + *

Instances of this class are guaranteed to be immutable.

 * <p>
 * Each point couples an abscissa {@code x}, the value {@code y} observed at
 * that abscissa, and the {@code weight} this measurement carries in the
 * fitting process.
 * </p>
 *
 * @version \$Id: WeightedObservedPoint.java 1416643 2012-12-03 19:37:14Z tn \$
 * @since 2.0
 */
public class WeightedObservedPoint implements Serializable {
    /** Serializable version id. */
    private static final long serialVersionUID = 5306874947404636157L;

    /** Abscissa of the point. */
    private final double x;
    /** Value of the function observed at {@link #x}. */
    private final double y;
    /** Weight of the measurement in the fitting process. */
    private final double weight;

    /**
     * Creates a point from its weight and its coordinates.
     *
     * @param weight Weight of the measurement in the fitting process.
     * @param x Abscissa of the measurement.
     * @param y Ordinate of the measurement.
     */
    public WeightedObservedPoint(final double weight, final double x, final double y) {
        this.x = x;
        this.y = y;
        this.weight = weight;
    }

    /**
     * Returns the weight this measurement carries in the fitting process.
     *
     * @return the weight of the measurement in the fitting process.
     */
    public double getWeight() {
        return this.weight;
    }

    /**
     * Returns the abscissa of this point.
     *
     * @return the abscissa of the point.
     */
    public double getX() {
        return this.x;
    }

    /**
     * Returns the value of the function observed at the abscissa.
     *
     * @return the observed value of the function at x.
     */
    public double getY() {
        return this.y;
    }
}
+ Propchange: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/WeightedObservedPoint.java ------------------------------------------------------------------------------ svn:eol-style = native Added: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/package-info.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/package-info.java?rev=1420684&view=auto ============================================================================== --- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/package-info.java (added) +++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/package-info.java Wed Dec 12 14:10:38 2012 @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Classes to perform curve fitting.
+ * + * Curve fitting is a special case of a least squares problem + * where the parameters are the coefficients of a function {@code f} + * whose graph {@code y = f(x)} should pass through sample points, and + * where the objective function is the squared sum of the residuals + * `f(xi) - yi` for observed points + * `(xi, yi)`. + */ +package org.apache.commons.math3.fitting; Propchange: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/fitting/package-info.java ------------------------------------------------------------------------------ svn:eol-style = native Added: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/AbstractConvergenceChecker.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/AbstractConvergenceChecker.java?rev=1420684&view=auto ============================================================================== --- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/AbstractConvergenceChecker.java (added) +++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/AbstractConvergenceChecker.java Wed Dec 12 14:10:38 2012 @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */
package org.apache.commons.math3.optim;

/**
 * Base class for all convergence checker implementations.
 * It only stores the relative and absolute tolerance thresholds; the
 * convergence test itself is left to subclasses.
 *
 * @param <PAIR> Type of (point, value) pair.
 *
 * @version \$Id: AbstractConvergenceChecker.java 1370215 2012-08-07 12:38:59Z sebb \$
 * @since 3.0
 */
public abstract class AbstractConvergenceChecker<PAIR>
    implements ConvergenceChecker<PAIR> {
    /**
     * Relative tolerance threshold.
     */
    private final double relativeThreshold;
    /**
     * Absolute tolerance threshold.
     */
    private final double absoluteThreshold;

    /**
     * Build an instance with the specified thresholds.
     *
     * @param relativeThreshold relative tolerance threshold
     * @param absoluteThreshold absolute tolerance threshold
     */
    public AbstractConvergenceChecker(final double relativeThreshold,
                                      final double absoluteThreshold) {
        this.relativeThreshold = relativeThreshold;
        this.absoluteThreshold = absoluteThreshold;
    }

    /**
     * Gets the relative tolerance threshold given at construction.
     *
     * @return the relative threshold.
     */
    public double getRelativeThreshold() {
        return relativeThreshold;
    }

    /**
     * Gets the absolute tolerance threshold given at construction.
     *
     * @return the absolute threshold.
     */
    public double getAbsoluteThreshold() {
        return absoluteThreshold;
    }

    /**
     * {@inheritDoc}
     */
    public abstract boolean converged(int iteration,
                                      PAIR previous,
                                      PAIR current);
}
Propchange: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/AbstractConvergenceChecker.java ------------------------------------------------------------------------------ svn:eol-style = native Added: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/BaseMultiStartMultivariateOptimizer.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/BaseMultiStartMultivariateOptimizer.java?rev=1420684&view=auto ============================================================================== --- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/BaseMultiStartMultivariateOptimizer.java (added) +++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/BaseMultiStartMultivariateOptimizer.java Wed Dec 12 14:10:38 2012 @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */
package org.apache.commons.math3.optim;

import org.apache.commons.math3.exception.MathIllegalStateException;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.NullArgumentException;
import org.apache.commons.math3.random.RandomVectorGenerator;
import org.apache.commons.math3.optim.InitialGuess;

/**
 * Base class multi-start optimizer for a multivariate function.
 * <p>
 * This class wraps an optimizer in order to use it several times in
 * turn with different starting points (trying to avoid being trapped
 * in a local extremum when looking for a global one).
 * <em>It is not a "user" class.</em>
 * </p>
 *
 * @param <PAIR> Type of the point/value pair returned by the optimization
 * algorithm.
 *
 * @version \$Id\$
 * @since 3.0
 */
public abstract class BaseMultiStartMultivariateOptimizer<PAIR>
    extends BaseMultivariateOptimizer<PAIR> {
    /** Underlying classical optimizer. */
    private final BaseMultivariateOptimizer<PAIR> optimizer;
    /** Number of evaluations already performed for all starts. */
    private int totalEvaluations;
    /** Number of starts to go. */
    private int starts;
    /** Random generator for multi-start. */
    private RandomVectorGenerator generator;
    /**
     * Private copy of the optimization data passed to the current
     * {@link #optimize(OptimizationData[]) optimize} call; its
     * {@link MaxEval} and {@link InitialGuess} slots are overwritten
     * before each start.
     */
    private OptimizationData[] optimData;
    /**
     * Location in {@link #optimData} where the updated maximum
     * number of evaluations will be stored.
     */
    private int maxEvalIndex = -1;
    /**
     * Location in {@link #optimData} where the updated start value
     * will be stored.
     */
    private int initialGuessIndex = -1;

    /**
     * Create a multi-start optimizer from a single-start optimizer.
     *
     * @param optimizer Single-start optimizer to wrap.
     * @param starts Number of starts to perform. If {@code starts == 1},
     * the {@link #optimize(OptimizationData[]) optimize} will return the
     * same solution as the given {@code optimizer} would return.
     * @param generator Random vector generator to use for restarts.
     * @throws NullArgumentException if {@code optimizer} or {@code generator}
     * is {@code null}.
     * @throws NotStrictlyPositiveException if {@code starts < 1}.
     */
    public BaseMultiStartMultivariateOptimizer(final BaseMultivariateOptimizer<PAIR> optimizer,
                                               final int starts,
                                               final RandomVectorGenerator generator) {
        // Null checks must run before "optimizer" is dereferenced to obtain
        // the convergence checker handed to the superclass.
        super(checkedConvergenceChecker(optimizer, generator));

        if (starts < 1) {
            throw new NotStrictlyPositiveException(starts);
        }

        this.optimizer = optimizer;
        this.starts = starts;
        this.generator = generator;
    }

    /**
     * Validates the mandatory constructor arguments and returns the
     * convergence checker of the wrapped optimizer.
     *
     * @param optimizer Single-start optimizer to wrap.
     * @param generator Random vector generator to use for restarts.
     * @param <P> Type of the point/value pair.
     * @return the convergence checker of {@code optimizer}.
     * @throws NullArgumentException if {@code optimizer} or {@code generator}
     * is {@code null}.
     */
    private static <P> ConvergenceChecker<P> checkedConvergenceChecker(final BaseMultivariateOptimizer<P> optimizer,
                                                                       final RandomVectorGenerator generator) {
        if (optimizer == null ||
            generator == null) {
            throw new NullArgumentException();
        }
        return optimizer.getConvergenceChecker();
    }

    /** {@inheritDoc} */
    @Override
    public int getEvaluations() {
        return totalEvaluations;
    }

    /**
     * Gets all the optima found during the last call to {@code optimize}.
     * The optimizer stores all the optima found during a set of
     * restarts. The {@code optimize} method returns the best point only.
     * This method returns all the points found at the end of each starts,
     * including the best one already returned by the {@code optimize} method.
     * <br/>
     * The returned array has one element for each start as specified
     * in the constructor. It is ordered with the results from the
     * runs that did converge first, sorted from best to worst
     * objective value (i.e in ascending order if minimizing and in
     * descending order if maximizing), followed by {@code null} elements
     * corresponding to the runs that did not converge. This means all
     * elements will be {@code null} if the {@code optimize} method did throw
     * an exception.
     * This also means that if the first element is not {@code null}, it is
     * the best point found across all starts.
     * <br/>
     * The behaviour is undefined if this method is called before
     * {@code optimize}; it will likely throw {@code NullPointerException}.
     *
     * @return an array containing the optima sorted from best to worst.
     */
    public abstract PAIR[] getOptima();

    /**
     * {@inheritDoc}
     *
     * @throws MathIllegalStateException if {@code optData} does not contain an
     * instance of {@link MaxEval} or {@link InitialGuess}.
     */
    @Override
    public PAIR optimize(OptimizationData... optData) {
        // Keep a private copy of the arguments for the internal optimizer:
        // "doOptimize" overwrites some of its slots, which must not alter
        // the array supplied by the caller.
        optimData = optData.clone();
        // Set up base class and perform computations.
        return super.optimize(optData);
    }

    /** {@inheritDoc} */
    @Override
    protected PAIR doOptimize() {
        // Remove all instances of "MaxEval" and "InitialGuess" from the
        // array that will be passed to the internal optimizer.
        // The former is to enforce smaller numbers of allowed evaluations
        // (according to how many have been used up already), and the latter
        // to impose a different start value for each start.
        // The indices are recomputed on every call so that positions left
        // over from a previous call cannot hide a missing argument.
        maxEvalIndex = -1;
        initialGuessIndex = -1;
        for (int i = 0; i < optimData.length; i++) {
            if (optimData[i] instanceof MaxEval) {
                optimData[i] = null;
                maxEvalIndex = i;
            } else if (optimData[i] instanceof InitialGuess) {
                optimData[i] = null;
                initialGuessIndex = i;
            }
        }
        if (maxEvalIndex == -1) {
            throw new MathIllegalStateException();
        }
        if (initialGuessIndex == -1) {
            throw new MathIllegalStateException();
        }

        RuntimeException lastException = null;
        totalEvaluations = 0;
        clear();

        final int maxEval = getMaxEvaluations();
        final double[] startPoint = getStartPoint();

        // Multi-start loop.
        for (int i = 0; i < starts; i++) {
            // CHECKSTYLE: stop IllegalCatch
            try {
                // Decrease number of allowed evaluations.
                optimData[maxEvalIndex] = new MaxEval(maxEval - totalEvaluations);
                // New start value.
                // XXX Random start points do not enforce the bounds
                // (getLowerBound() / getUpperBound()).
                final double[] s = (i == 0) ?
                    startPoint :
                    generator.nextVector();
                optimData[initialGuessIndex] = new InitialGuess(s);
                // Optimize.
                final PAIR result = optimizer.optimize(optimData);
                store(result);
            } catch (RuntimeException mue) {
                lastException = mue;
            }
            // CHECKSTYLE: resume IllegalCatch

            totalEvaluations += optimizer.getEvaluations();
            if (totalEvaluations >= maxEval) {
                // Evaluation budget exhausted: no further start could be
                // granted any evaluation, and continuing would add the last
                // run's evaluation count to the total a second time.
                break;
            }
        }

        final PAIR[] optima = getOptima();
        if (optima.length == 0) {
            // All runs failed.
            throw lastException; // Cannot be null if starts >= 1.
        }

        // Return the best optimum.
        return optima[0];
    }

    /**
     * Method that will be called in order to store each found optimum.
     *
     * @param optimum Result of an optimization run.
     */
    protected abstract void store(PAIR optimum);
    /**
     * Method that will be called in order to clear all stored optima.
     */
    protected abstract void clear();
}
Propchange: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/optim/BaseMultiStartMultivariateOptimizer.java ------------------------------------------------------------------------------ svn:eol-style = native