commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From l..@apache.org
Subject svn commit: r731307 - /commons/proper/math/trunk/src/java/org/apache/commons/math/linear/QRDecompositionImpl.java
Date Sun, 04 Jan 2009 18:08:18 GMT
Author: luc
Date: Sun Jan  4 10:08:18 2009
New Revision: 731307

URL: http://svn.apache.org/viewvc?rev=731307&view=rev
Log:
fixed a dimension error with under-determined problems
removed IllegalStateException
create a DenseRealMatrix when solving A.X = B

Modified:
    commons/proper/math/trunk/src/java/org/apache/commons/math/linear/QRDecompositionImpl.java

Modified: commons/proper/math/trunk/src/java/org/apache/commons/math/linear/QRDecompositionImpl.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/java/org/apache/commons/math/linear/QRDecompositionImpl.java?rev=731307&r1=731306&r2=731307&view=diff
==============================================================================
--- commons/proper/math/trunk/src/java/org/apache/commons/math/linear/QRDecompositionImpl.java	(original)
+++ commons/proper/math/trunk/src/java/org/apache/commons/math/linear/QRDecompositionImpl.java	Sun Jan  4 10:08:18 2009
@@ -17,6 +17,10 @@
 
 package org.apache.commons.math.linear;
 
+import java.util.Arrays;
+
+import org.apache.commons.math.MathRuntimeException;
+
 
 /**
  * Calculates the QR-decomposition of a matrix.
@@ -73,7 +77,7 @@
         final int m = matrix.getRowDimension();
         final int n = matrix.getColumnDimension();
         qrt = matrix.transpose().getData();
-        rDiag = new double[n];
+        rDiag = new double[Math.min(m, n)];
         cachedQ  = null;
         cachedQT = null;
         cachedR  = null;
@@ -170,8 +174,7 @@
     }
 
     /** {@inheritDoc} */
-    public RealMatrix getQ()
-        throws IllegalStateException {
+    public RealMatrix getQ() {
         if (cachedQ == null) {
             cachedQ = getQT().transpose();
         }
@@ -179,8 +182,7 @@
     }
 
     /** {@inheritDoc} */
-    public RealMatrix getQT()
-        throws IllegalStateException {
+    public RealMatrix getQT() {
 
         if (cachedQT == null) {
 
@@ -224,8 +226,7 @@
     }
 
     /** {@inheritDoc} */
-    public RealMatrix getH()
-        throws IllegalStateException {
+    public RealMatrix getH() {
 
         if (cachedH == null) {
 
@@ -278,8 +279,7 @@
         }
 
         /** {@inheritDoc} */
-        public boolean isNonSingular()
-        throws IllegalStateException {
+        public boolean isNonSingular() {
 
             for (double diag : rDiag) {
                 if (diag == 0) {
@@ -292,12 +292,14 @@
 
         /** {@inheritDoc} */
         public double[] solve(double[] b)
-        throws IllegalStateException, IllegalArgumentException, InvalidMatrixException {
+        throws IllegalArgumentException, InvalidMatrixException {
 
             final int n = qrt.length;
             final int m = qrt[0].length;
             if (b.length != m) {
-                throw new IllegalArgumentException("Incorrect row dimension");
+                throw MathRuntimeException.createIllegalArgumentException(
+                        "vector length mismatch: got {0} but expected {1}",
+                        new Object[] { b.length, m });
             }
             if (!isNonSingular()) {
                 throw new SingularMatrixException();
@@ -323,7 +325,7 @@
             }
 
             // solve triangular system R.x = y
-            for (int row = n - 1; row >= 0; --row) {
+            for (int row = rDiag.length - 1; row >= 0; --row) {
                 y[row] /= rDiag[row];
                 final double yRow   = y[row];
                 final double[] qrtRow = qrt[row];
@@ -339,7 +341,7 @@
 
         /** {@inheritDoc} */
         public RealVector solve(RealVector b)
-        throws IllegalStateException, IllegalArgumentException, InvalidMatrixException {
+        throws IllegalArgumentException, InvalidMatrixException {
             try {
                 return solve((RealVectorImpl) b);
             } catch (ClassCastException cce) {
@@ -351,76 +353,103 @@
          * <p>The A matrix is implicit here. It is </p>
          * @param b right-hand side of the equation A &times; X = B
          * @return a vector X that minimizes the two norm of A &times; X - B
-         * @exception IllegalStateException if {@link #decompose(RealMatrix) decompose}
-         * has not been called
          * @throws IllegalArgumentException if matrices dimensions don't match
          * @throws InvalidMatrixException if decomposed matrix is singular
          */
         public RealVectorImpl solve(RealVectorImpl b)
-        throws IllegalStateException, IllegalArgumentException, InvalidMatrixException {
+        throws IllegalArgumentException, InvalidMatrixException {
             return new RealVectorImpl(solve(b.getDataRef()), false);
         }
 
         /** {@inheritDoc} */
         public RealMatrix solve(RealMatrix b)
-        throws IllegalStateException, IllegalArgumentException, InvalidMatrixException {
+        throws IllegalArgumentException, InvalidMatrixException {
 
             final int n = qrt.length;
             final int m = qrt[0].length;
             if (b.getRowDimension() != m) {
-                throw new IllegalArgumentException("Incorrect row dimension");
+                throw MathRuntimeException.createIllegalArgumentException(
+                        "dimensions mismatch: got {0}x{1} but expected {2}x{3}",
+                        new Object[] { b.getRowDimension(), b.getColumnDimension(), m, "n"});
             }
             if (!isNonSingular()) {
                 throw new SingularMatrixException();
             }
 
-            final int cols = b.getColumnDimension();
-            final double[][] xData = new double[n][cols];
-            final double[] y = new double[b.getRowDimension()];
-
-            for (int k = 0; k < cols; ++k) {
+            final int columns        = b.getColumnDimension();
+            final int blockSize      = DenseRealMatrix.BLOCK_SIZE;
+            final int cBlocks        = (columns + blockSize - 1) / blockSize;
+            final double[][] xBlocks = DenseRealMatrix.createBlocksLayout(n, columns);
+            final double[][] y       = new double[b.getRowDimension()][blockSize];
+            final double[]   alpha   = new double[blockSize];
+
+            for (int kBlock = 0; kBlock < cBlocks; ++kBlock) {
+                final int kStart = kBlock * blockSize;
+                final int kEnd   = Math.min(kStart + blockSize, columns);
+                final int kWidth = kEnd - kStart;
 
                 // get the right hand side vector
-                for (int j = 0; j < y.length; ++j) {
-                    y[j] = b.getEntry(j, k);
-                }
+                b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y);
 
                 // apply Householder transforms to solve Q.y = b
                 for (int minor = 0; minor < Math.min(m, n); minor++) {
-
                     final double[] qrtMinor = qrt[minor];
-                    double dotProduct = 0;
-                    for (int row = minor; row < m; row++) {
-                        dotProduct += y[row] * qrtMinor[row];
+                    final double factor     = 1.0 / (rDiag[minor] * qrtMinor[minor]); 
+
+                    Arrays.fill(alpha, 0, kWidth, 0.0);
+                    for (int row = minor; row < m; ++row) {
+                        final double   d    = qrtMinor[row];
+                        final double[] yRow = y[row];
+                        for (int k = 0; k < kWidth; ++k) {
+                            alpha[k] += d * yRow[k];
+                        }
+                    }
+                    for (int k = 0; k < kWidth; ++k) {
+                        alpha[k] *= factor;
                     }
-                    dotProduct /= rDiag[minor] * qrtMinor[minor];
 
-                    for (int row = minor; row < m; row++) {
-                        y[row] += dotProduct * qrtMinor[row];
+                    for (int row = minor; row < m; ++row) {
+                        final double   d    = qrtMinor[row];
+                        final double[] yRow = y[row];
+                        for (int k = 0; k < kWidth; ++k) {
+                            yRow[k] += alpha[k] * d;
+                        }
                     }
 
                 }
 
                 // solve triangular system R.x = y
-                for (int row = n - 1; row >= 0; --row) {
-                    y[row] /= rDiag[row];
-                    final double yRow = y[row];
-                    final double[] qrtRow = qrt[row];
-                    xData[row][k] = yRow;
-                    for (int i = 0; i < row; i++) {
-                        y[i] -= yRow * qrtRow[i];
+                for (int j = rDiag.length - 1; j >= 0; --j) {
+                    final int      jBlock = j / blockSize;
+                    final int      jStart = jBlock * blockSize;
+                    final double   factor = 1.0 / rDiag[j];
+                    final double[] yJ     = y[j];
+                    final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock];
+                    for (int k = 0, index = (j - jStart) * kWidth; k < kWidth; ++k, ++index) {
+                        yJ[k]        *= factor;
+                        xBlock[index] = yJ[k];
                     }
+
+                    final double[] qrtJ = qrt[j];
+                    for (int i = 0; i < j; ++i) {
+                        final double rIJ  = qrtJ[i];
+                        final double[] yI = y[i];
+                        for (int k = 0; k < kWidth; ++k) {
+                            yI[k] -= yJ[k] * rIJ;
+                        }
+                    }
+
                 }
 
             }
 
-            return new RealMatrixImpl(xData, false);
+            return new DenseRealMatrix(n, columns, xBlocks, false);
 
         }
 
         /** {@inheritDoc} */
         public RealMatrix getInverse()
-        throws IllegalStateException, InvalidMatrixException {
+        throws InvalidMatrixException {
             return solve(MatrixUtils.createRealIdentityMatrix(rDiag.length));
         }
 



Mime
View raw message