commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From l..@apache.org
Subject svn commit: r772114 - in /commons/proper/math/trunk/src: java/org/apache/commons/math/stat/regression/ test/org/apache/commons/math/stat/regression/
Date Wed, 06 May 2009 09:40:13 GMT
Author: luc
Date: Wed May  6 09:40:13 2009
New Revision: 772114

URL: http://svn.apache.org/viewvc?rev=772114&view=rev
Log:
replaced matrix by vector where possible

Modified:
    commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java
    commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java
    commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java

Modified: commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java?rev=772114&r1=772113&r2=772114&view=diff
==============================================================================
--- commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java	(original)
+++ commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java	Wed May  6 09:40:13 2009
@@ -18,6 +18,7 @@
 
 import org.apache.commons.math.linear.RealMatrix;
 import org.apache.commons.math.linear.RealMatrixImpl;
+import org.apache.commons.math.linear.RealVector;
 import org.apache.commons.math.linear.decomposition.LUDecompositionImpl;
 
 
@@ -91,12 +92,12 @@
      * @return beta
      */
     @Override
-    protected RealMatrix calculateBeta() {
+    protected RealVector calculateBeta() {
         RealMatrix OI = getOmegaInverse();
         RealMatrix XT = X.transpose();
         RealMatrix XTOIX = XT.multiply(OI).multiply(X);
         RealMatrix inverse = new LUDecompositionImpl(XTOIX).getSolver().getInverse();
-        return inverse.multiply(XT).multiply(OI).multiply(Y);
+        return inverse.multiply(XT).multiply(OI).operate(Y);
     }
 
     /**
@@ -122,9 +123,9 @@
      */
     @Override
     protected double calculateYVariance() {
-        RealMatrix u = calculateResiduals();
-        RealMatrix sse =  u.transpose().multiply(getOmegaInverse()).multiply(u);
-        return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension());
+        RealVector residuals = calculateResiduals();
+        double t = residuals.dotProduct(getOmegaInverse().operate(residuals));
+        return t / (X.getRowDimension() - X.getColumnDimension());
     }
     
 }

Modified: commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java?rev=772114&r1=772113&r2=772114&view=diff
==============================================================================
--- commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java	(original)
+++ commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java	Wed May  6 09:40:13 2009
@@ -18,6 +18,8 @@
 
 import org.apache.commons.math.linear.RealMatrix;
 import org.apache.commons.math.linear.RealMatrixImpl;
+import org.apache.commons.math.linear.RealVector;
+import org.apache.commons.math.linear.RealVectorImpl;
 import org.apache.commons.math.linear.decomposition.LUDecompositionImpl;
 import org.apache.commons.math.linear.decomposition.QRDecomposition;
 import org.apache.commons.math.linear.decomposition.QRDecompositionImpl;
@@ -137,8 +139,8 @@
      * @return beta
      */
     @Override
-    protected RealMatrix calculateBeta() {
-        return solveUpperTriangular(qr.getR(), qr.getQ().transpose().multiply(Y));
+    protected RealVector calculateBeta() {
+        return solveUpperTriangular(qr.getR(), qr.getQ().transpose().operate(Y));
     }
 
     /**
@@ -170,9 +172,9 @@
      */
     @Override
     protected double calculateYVariance() {
-        RealMatrix u = calculateResiduals();
-        RealMatrix sse = u.transpose().multiply(u);
-        return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension());
+        RealVector residuals = calculateResiduals();
+        return residuals.dotProduct(residuals) /
+               (X.getRowDimension() - X.getColumnDimension());
     }
     
     /** TODO:  Find a home for the following methods in the linear package */   
@@ -191,20 +193,16 @@
      * Similarly, extra (zero) rows in coefficients are ignored</p>
      * 
      * @param coefficients upper-triangular coefficients matrix
-     * @param constants column RHS constants matrix
-     * @return solution matrix as a column matrix
+     * @param constants column RHS constants vector
+     * @return solution matrix as a column vector
      * 
      */
-    private static RealMatrix solveUpperTriangular(RealMatrix coefficients,
-            RealMatrix constants) {
+    private static RealVector solveUpperTriangular(RealMatrix coefficients,
+                                                   RealVector constants) {
         if (!isUpperTriangular(coefficients, 1E-12)) {
             throw new IllegalArgumentException(
                    "Coefficients is not upper-triangular");
         }
-        if (constants.getColumnDimension() != 1) {
-            throw new IllegalArgumentException(
-                    "Constants not a column matrix.");
-        }
         int length = coefficients.getColumnDimension();
         double x[] = new double[length];
         for (int i = 0; i < length; i++) {
@@ -213,9 +211,9 @@
             for (int j = index + 1; j < length; j++) {
                 sum += coefficients.getEntry(index, j) * x[j];
             }
-            x[index] = (constants.getEntry(index, 0) - sum) / coefficients.getEntry(index, index);
+            x[index] = (constants.getEntry(index) - sum) / coefficients.getEntry(index, index);
         } 
-        return new RealMatrixImpl(x);
+        return new RealVectorImpl(x);
     }
     
     /**

Modified: commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java?rev=772114&r1=772113&r2=772114&view=diff
==============================================================================
--- commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java	(original)
+++ commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java	Wed May  6 09:40:13 2009
@@ -139,7 +139,7 @@
           new double[]{-3482258.63459582, 15.0618722713733,
                 -0.358191792925910E-01,-2.02022980381683,
                 -1.03322686717359,-0.511041056535807E-01,
-                 1829.15146461355}, 1E-8); // 
+                 1829.15146461355}, 2E-8); // 
         
         // Check expected residuals from R
         double[] residuals = model.estimateResiduals();
@@ -332,7 +332,7 @@
          */
         double[] residuals = model.estimateResiduals();
         RealMatrix I = MatrixUtils.createRealIdentityMatrix(10);
-        double[] hatResiduals = I.subtract(hat).multiply(model.Y).getColumn(0);
+        double[] hatResiduals = I.subtract(hat).operate(model.Y).getData();
         TestUtils.assertEquals(residuals, hatResiduals, 10e-12);    
     }
 }



Mime
View raw message