commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From pste...@apache.org
Subject svn commit: r680162 - in /commons/proper/math/branches/MATH_2_0/src: java/org/apache/commons/math/stat/regression/ site/xdoc/ test/org/apache/commons/math/stat/regression/
Date Sun, 27 Jul 2008 18:52:39 GMT
Author: psteitz
Date: Sun Jul 27 11:52:38 2008
New Revision: 680162

URL: http://svn.apache.org/viewvc?rev=680162&view=rev
Log:
Changed OLSMultipleLinearRegression implementation to use QR decomposition to
solve the normal equations.
JIRA: MATH-217


Modified:
    commons/proper/math/branches/MATH_2_0/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java
    commons/proper/math/branches/MATH_2_0/src/site/xdoc/changes.xml
    commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegressionTest.java
    commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java

Modified: commons/proper/math/branches/MATH_2_0/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java
URL: http://svn.apache.org/viewvc/commons/proper/math/branches/MATH_2_0/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java?rev=680162&r1=680161&r2=680162&view=diff
==============================================================================
--- commons/proper/math/branches/MATH_2_0/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java (original)
+++ commons/proper/math/branches/MATH_2_0/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java Sun Jul 27 11:52:38 2008
@@ -16,30 +16,52 @@
  */
 package org.apache.commons.math.stat.regression;
 
+import org.apache.commons.math.linear.QRDecomposition;
+import org.apache.commons.math.linear.QRDecompositionImpl;
 import org.apache.commons.math.linear.RealMatrix;
-
+import org.apache.commons.math.linear.RealMatrixImpl;
 
 /**
- * The OLS implementation of the multiple linear regression.
+ * <p>Implements ordinary least squares (OLS) to estimate the parameters of a 
+ * multiple linear regression model.</p>
  * 
- * OLS assumes the covariance matrix of the error to be diagonal and with equal variance.
+ * <p>OLS assumes the covariance matrix of the error to be diagonal and with
+ * equal variance:</p>
  * <pre>
  * u ~ N(0, sigma^2*I)
- * </pre>
+ * </pre>
  * 
- * Estimated by OLS, 
+ * <p>The regression coefficients, b, satisfy the normal equations:</p>
  * <pre>
- * b=(X'X)^-1X'y
- * </pre>
- * whose variance is
+ * X^T X b = X^T y
+ * </pre>
+ * 
+ * <p>To solve the normal equations, this implementation uses QR decomposition
+ * of the X matrix. (See {@link QRDecompositionImpl} for details on the
+ * decomposition algorithm.)</p>
  * <pre>
- * Var(b)=MSE*(X'X)^-1, MSE=u'u/(n-k)
+ * X^T X b = X^T y
+ * (QR)^T (QR) b = (QR)^T y
+ * R^T (Q^T Q) R b = R^T Q^T y
+ * R^T R b = R^T Q^T y
+ * (R^T)^{-1} R^T R b = (R^T)^{-1} R^T Q^T y
+ * R b = Q^T y
  * </pre>
+ * <p>Given Q and R, and assuming X has full column rank (so the diagonal of R
+ * 
  * @version $Revision$ $Date$
  * @since 2.0
  */
 public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression {
+    
+    /** Cached QR decomposition of X matrix */
+    private QRDecomposition qr = null;
 
+    /**
+     * {@inheritDoc}
+     * 
+     * Computes and caches QR decomposition of the X matrix.
+     */
     public void newSampleData(double[] y, double[][] x) {
         validateSampleData(x, y);
         newYSampleData(y);
@@ -47,15 +69,33 @@
     }
     
     /**
-     * Calculates beta by OLS.
-     * <pre>
-     * b=(X'X)^-1X'y
-     * </pre> 
+     * {@inheritDoc}
+     * 
+     * Computes and caches QR decomposition of the X matrix.
+     */
+    public void newSampleData(double[] data, int nobs, int nvars) {
+        super.newSampleData(data, nobs, nvars);
+        qr = new QRDecompositionImpl(X);
+    }
+    
+    /**
+     * Loads new x sample data, overriding any previous sample, and caches its QR decomposition.
+     * 
+     * @param x the [n,k] array representing the x sample
+     */
+    protected void newXSampleData(double[][] x) {
+        this.X = new RealMatrixImpl(x);
+        qr = new QRDecompositionImpl(X);
+    }
+    
+    /**
+     * Calculates regression coefficients using OLS.
+     * 
      * @return beta
      */
     protected RealMatrix calculateBeta() {
-        RealMatrix XTX = X.transpose().multiply(X);
-        return XTX.inverse().multiply(X.transpose()).multiply(Y);
+        return solveUpperTriangular((RealMatrixImpl) qr.getR(),
+                (RealMatrixImpl) qr.getQ().transpose().multiply(Y));
     }
 
     /**
@@ -83,5 +123,76 @@
         RealMatrix sse = u.transpose().multiply(u);
         return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension());
     }
-
+    
+    // TODO: Find a home for the following methods in the linear package
+    
+    /**
+     * <p>Uses back substitution to solve the system</p>
+     * 
+     * <p>coefficients X = constants</p>
+     * 
+     * <p>coefficients must be upper-triangular and constants must be a column
+     * matrix.  The solution is returned as a column matrix.</p>
+     * 
+     * <p>The number of columns in coefficients determines the length
+     * of the returned solution vector (column matrix).  If constants
+     * has more rows than coefficients has columns, excess rows are ignored.
+     * Similarly, extra (zero) rows in coefficients are ignored.</p>
+     * 
+     * @param coefficients upper-triangular coefficients matrix
+     * @param constants column RHS constants matrix
+     * @return solution matrix as a column matrix
+     * 
+     */
+    private static RealMatrix solveUpperTriangular(RealMatrixImpl coefficients,
+            RealMatrixImpl constants) {
+        if (!isUpperTriangular(coefficients, 1E-12)) {
+            throw new IllegalArgumentException(
+                   "Coefficients is not upper-triangular");
+        }
+        if (constants.getColumnDimension() != 1) {
+            throw new IllegalArgumentException(
+                    "Constants not a column matrix.");
+        }
+        int length = coefficients.getColumnDimension();
+        double[][] cons = constants.getDataRef();
+        double[][] coef = coefficients.getDataRef();
+        double x[] = new double[length];
+        for (int i = 0; i < length; i++) {
+            int index = length - 1 - i;
+            double sum = 0;
+            for (int j = index + 1; j < length; j++) {
+                sum += coef[index][j] * x[j];
+            }
+            x[index] = (cons[index][0] - sum) / coef[index][index];
+        } 
+        return new RealMatrixImpl(x);
+    }
+    
+    /**
+     * <p>Returns true iff m is an upper-triangular matrix.</p>
+     * 
+     * <p>Makes sure all below-diagonal elements are within epsilon of 0.</p>
+     * 
+     * @param m matrix to check
+     * @param epsilon maximum allowable absolute value for elements below
+     * the main diagonal
+     * 
+     * @return true if m is upper-triangular; false otherwise
+     * @throws NullPointerException if m is null
+     */
+    private static boolean isUpperTriangular(RealMatrixImpl m, double epsilon) {
+        double[][] data = m.getDataRef();
+        int nCols = m.getColumnDimension();
+        int nRows = m.getRowDimension();
+        for (int r = 0; r < nRows; r++) {
+            int bound = Math.min(r, nCols);
+            for (int c = 0; c < bound; c++) {
+                if (Math.abs(data[r][c]) > epsilon) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
 }

Modified: commons/proper/math/branches/MATH_2_0/src/site/xdoc/changes.xml
URL: http://svn.apache.org/viewvc/commons/proper/math/branches/MATH_2_0/src/site/xdoc/changes.xml?rev=680162&r1=680161&r2=680162&view=diff
==============================================================================
--- commons/proper/math/branches/MATH_2_0/src/site/xdoc/changes.xml (original)
+++ commons/proper/math/branches/MATH_2_0/src/site/xdoc/changes.xml Sun Jul 27 11:52:38 2008
@@ -39,6 +39,10 @@
   </properties>
   <body>
     <release version="2.0" date="TBD" description="TBD">
+      <action dev="psteitz" type="update" issue="MATH-217">
+        Changed OLS regression implementation added in MATH-203 to use
+        QR decomposition to solve the normal equations.
+      </action>
       <action dev="luc" type="add">
         New ODE integrators have been added: the explicit Adams-Bashforth and implicit
         Adams-Moulton multistep methods. These methods support customizable starter

Modified: commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegressionTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegressionTest.java?rev=680162&r1=680161&r2=680162&view=diff
==============================================================================
--- commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegressionTest.java (original)
+++ commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/AbstractMultipleLinearRegressionTest.java Sun Jul 27 11:52:38 2008
@@ -58,8 +58,10 @@
 
     @Test
     public void canEstimateRegressandVariance(){
-        double variance = regression.estimateRegressandVariance();
-        assertTrue(variance > 0.0);
+        if (getSampleSize() > getNumberOfRegressors()) {
+            double variance = regression.estimateRegressandVariance();
+            assertTrue(variance > 0.0);
+        }
     }   
 
 }

Modified: commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java?rev=680162&r1=680161&r2=680162&view=diff
==============================================================================
--- commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java (original)
+++ commons/proper/math/branches/MATH_2_0/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java Sun Jul 27 11:52:38 2008
@@ -18,7 +18,10 @@
 
 import org.junit.Before;
 import org.junit.Test;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 import org.apache.commons.math.TestUtils;
+import org.apache.commons.math.linear.RealMatrixImpl;
 
 public class OLSMultipleLinearRegressionTest extends AbstractMultipleLinearRegressionTest {
 
@@ -131,7 +134,7 @@
           new double[]{-3482258.63459582, 15.0618722713733,
                 -0.358191792925910E-01,-2.02022980381683,
                 -1.03322686717359,-0.511041056535807E-01,
-                 1829.15146461355}, 1E-1); // <- UGH! need better accuracy!
+                 1829.15146461355}, 1E-8);
         
         // Check expected residuals from R
         double[] residuals = model.estimateResiduals();
@@ -142,7 +145,7 @@
                  455.394094551857,-17.26892711483297,-39.0550425226967,
                 -155.5499735953195,-85.6713080421283,341.9315139607727,
                 -206.7578251937366},
-                      1E-2); // <- UGH again! need better accuracy!
+                      1E-8);
         
         // Check standard errors from NIST
         double[][] errors = model.estimateRegressionParametersVariance();



Mime
View raw message