commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From er...@apache.org
Subject svn commit: r1080571 - in /commons/proper/math/trunk/src: main/java/org/apache/commons/math/analysis/function/Logistic.java test/java/org/apache/commons/math/analysis/function/LogisticTest.java
Date Fri, 11 Mar 2011 14:03:59 GMT
Author: erans
Date: Fri Mar 11 14:03:59 2011
New Revision: 1080571

URL: http://svn.apache.org/viewvc?rev=1080571&view=rev
Log:
MATH-503
Added parametric version of the "Logistic" function.

Modified:
    commons/proper/math/trunk/src/main/java/org/apache/commons/math/analysis/function/Logistic.java
    commons/proper/math/trunk/src/test/java/org/apache/commons/math/analysis/function/LogisticTest.java

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math/analysis/function/Logistic.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math/analysis/function/Logistic.java?rev=1080571&r1=1080570&r2=1080571&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math/analysis/function/Logistic.java (original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math/analysis/function/Logistic.java Fri Mar 11 14:03:59 2011
@@ -19,7 +19,10 @@ package org.apache.commons.math.analysis
 
 import org.apache.commons.math.analysis.UnivariateRealFunction;
 import org.apache.commons.math.analysis.DifferentiableUnivariateRealFunction;
+import org.apache.commons.math.analysis.ParametricUnivariateRealFunction;
 import org.apache.commons.math.exception.NotStrictlyPositiveException;
+import org.apache.commons.math.exception.NullArgumentException;
+import org.apache.commons.math.exception.DimensionMismatchException;
 import org.apache.commons.math.util.FastMath;
 
 /**
@@ -76,7 +79,7 @@ public class Logistic implements Differe
 
     /** {@inheritDoc} */
     public double value(double x) {
-        return a + (k - a) / FastMath.pow(1 + q * FastMath.exp(b * (m - x)), oneOverN);
+        return value(m - x, k, b, q, a, oneOverN);
     }
 
     /** {@inheritDoc} */
@@ -94,4 +97,113 @@ public class Logistic implements Differe
             }
         };
     }
+
+    /**
+     * Parametric function where the input array contains the parameters of
+     * the logit function, ordered as follows:
+     * <ul>
+     *  <li>Lower asymptote</li>
+     *  <li>Higher asymptote</li>
+     * </ul>
+     */
+    public static class Parametric implements ParametricUnivariateRealFunction {
+        /**
+         * Computes the value of the sigmoid at {@code x}.
+         *
+         * @param x Value for which the function must be computed.
+         * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
+         * {@code a} and  {@code n}.
+         * @return the value of the function.
+         * @throws NullArgumentException if {@code param} is {@code null}.
+         * @throws DimensionMismatchException if the size of {@code param} is
+         * not 6.
+         */
+        public double value(double x,
+                            double[] param) {
+            validateParameters(param);
+            return Logistic.value(param[1] - x, param[0],
+                                  param[2], param[3],
+                                  param[4], 1 / param[5]);
+        }
+
+        /**
+         * Computes the value of the gradient at {@code x}.
+         * The components of the gradient vector are the partial
+         * derivatives of the function with respect to each of the
+         * <em>parameters</em>.
+         *
+         * @param x Value at which the gradient must be computed.
+         * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
+         * {@code a} and  {@code n}.
+         * @return the gradient vector at {@code x}.
+         * @throws NullArgumentException if {@code param} is {@code null}.
+         * @throws DimensionMismatchException if the size of {@code param} is
+         * not 6.
+         */
+        public double[] gradient(double x, double[] param) {
+            validateParameters(param);
+
+            final double b = param[2];
+            final double q = param[3];
+
+            final double mMinusX = param[1] - x;
+            final double oneOverN = 1 / param[5];
+            final double exp = FastMath.exp(b * mMinusX);
+            final double qExp = q * exp;
+            final double qExp1 = qExp + 1;
+            final double factor1 = (param[0] - param[4]) * oneOverN / FastMath.pow(qExp1, oneOverN);
+            final double factor2 = -factor1 / qExp1;
+
+            // Components of the gradient.
+            final double gk = Logistic.value(mMinusX, 1, b, q, 0, oneOverN);
+            final double gm = factor2 * b * qExp;
+            final double gb = factor2 * mMinusX * qExp;
+            final double gq = factor2 * exp;
+            final double ga = Logistic.value(mMinusX, 0, b, q, 1, oneOverN);
+            final double gn = factor1 * Math.log(qExp1) * oneOverN;
+
+            return new double[] { gk, gm, gb, gq, ga, gn };
+        }
+
+        /**
+         * Validates parameters to ensure they are appropriate for the evaluation of
+         * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
+         * methods.
+         *
+         * @param param Values for {@code k}, {@code m}, {@code b}, {@code q},
+         * {@code a} and  {@code n}.
+         * @throws NullArgumentException if {@code param} is {@code null}.
+         * @throws DimensionMismatchException if the size of {@code param} is
+         * not 6.
+         */
+        private void validateParameters(double[] param) {
+            if (param == null) {
+                throw new NullArgumentException();
+            }
+            if (param.length != 6) {
+                throw new DimensionMismatchException(param.length, 6);
+            }
+            if (param[5] <= 0) {
+                throw new NotStrictlyPositiveException(param[5]);
+            }
+        }
+    }
+
+    /**
+     * @param mMinusX {@code m - x}.
+     * @param k {@code k}.
+     * @param b {@code b}.
+     * @param q {@code q}.
+     * @param a {@code a}.
+     * @param oneOverN {@code 1 / n}.
+     * @return the value of the function.
+     */
+    private static double value(double mMinusX,
+                                double k,
+                                double b,
+                                double q,
+                                double a,
+                                double oneOverN) {
+        return a + (k - a) / FastMath.pow(1 + q * FastMath.exp(b * mMinusX), oneOverN);
+    }
 }

Modified: commons/proper/math/trunk/src/test/java/org/apache/commons/math/analysis/function/LogisticTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math/analysis/function/LogisticTest.java?rev=1080571&r1=1080570&r2=1080571&view=diff
==============================================================================
--- commons/proper/math/trunk/src/test/java/org/apache/commons/math/analysis/function/LogisticTest.java (original)
+++ commons/proper/math/trunk/src/test/java/org/apache/commons/math/analysis/function/LogisticTest.java Fri Mar 11 14:03:59 2011
@@ -19,6 +19,8 @@ package org.apache.commons.math.analysis
 
 import org.apache.commons.math.analysis.UnivariateRealFunction;
 import org.apache.commons.math.exception.NotStrictlyPositiveException;
+import org.apache.commons.math.exception.NullArgumentException;
+import org.apache.commons.math.exception.DimensionMismatchException;
 import org.apache.commons.math.util.FastMath;
 
 import org.junit.Assert;
@@ -97,4 +99,99 @@ public class LogisticTest {
             Assert.assertEquals("x=" + x, dgdx.value(x), dfdx.value(x), EPS);
         }
     }
+
+    @Test(expected=NullArgumentException.class)
+    public void testParametricUsage1() {
+        final Logistic.Parametric g = new Logistic.Parametric();
+        g.value(0, null);
+    }
+
+    @Test(expected=DimensionMismatchException.class)
+    public void testParametricUsage2() {
+        final Logistic.Parametric g = new Logistic.Parametric();
+        g.value(0, new double[] {0});
+    }
+
+    @Test(expected=NullArgumentException.class)
+    public void testParametricUsage3() {
+        final Logistic.Parametric g = new Logistic.Parametric();
+        g.gradient(0, null);
+    }
+
+    @Test(expected=DimensionMismatchException.class)
+    public void testParametricUsage4() {
+        final Logistic.Parametric g = new Logistic.Parametric();
+        g.gradient(0, new double[] {0});
+    }
+
+    @Test(expected=NotStrictlyPositiveException.class)
+    public void testParametricUsage5() {
+        final Logistic.Parametric g = new Logistic.Parametric();
+        g.value(0, new double[] {1, 0, 1, 1, 0 ,0});
+    }
+
+    @Test(expected=NotStrictlyPositiveException.class)
+    public void testParametricUsage6() {
+        final Logistic.Parametric g = new Logistic.Parametric();
+        g.gradient(0, new double[] {1, 0, 1, 1, 0 ,0});
+    }
+
+    @Test
+    public void testGradientComponent0Component4() {
+        final double k = 3;
+        final double a = 2;
+
+        final Logistic.Parametric f = new Logistic.Parametric();
+        // Compare using the "Sigmoid" function.
+        final Sigmoid.Parametric g = new Sigmoid.Parametric();
+        
+        final double x = 0.12345;
+        final double[] gf = f.gradient(x, new double[] {k, 0, 1, 1, a, 1});
+        final double[] gg = g.gradient(x, new double[] {a, k});
+
+        Assert.assertEquals(gg[0], gf[4], EPS);
+        Assert.assertEquals(gg[1], gf[0], EPS);
+    }
+
+    @Test
+    public void testGradientComponent5() {
+        final double m = 1.2;
+        final double k = 3.4;
+        final double a = 2.3;
+        final double q = 0.567;
+        final double b = -FastMath.log(q);
+        final double n = 3.4;
+
+        final Logistic.Parametric f = new Logistic.Parametric();
+        
+        final double x = m - 1;
+        final double qExp1 = 2;
+
+        final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
+
+        Assert.assertEquals((k - a) * FastMath.log(qExp1) / (n * n * FastMath.pow(qExp1, 1 / n)),
+                            gf[5], EPS);
+    }
+
+    @Test
+    public void testGradientComponent1Component2Component3() {
+        final double m = 1.2;
+        final double k = 3.4;
+        final double a = 2.3;
+        final double b = 0.567;
+        final double q = 1 / FastMath.exp(b * m);
+        final double n = 3.4;
+
+        final Logistic.Parametric f = new Logistic.Parametric();
+        
+        final double x = 0;
+        final double qExp1 = 2;
+
+        final double[] gf = f.gradient(x, new double[] {k, m, b, q, a, n});
+
+        final double factor = (a - k) / (n * FastMath.pow(qExp1, 1 / n + 1));
+        Assert.assertEquals(factor * b, gf[1], EPS);
+        Assert.assertEquals(factor * m, gf[2], EPS);
+        Assert.assertEquals(factor / q, gf[3], EPS);
+    }
 }



Mime
View raw message