Author: luc Date: Sun Jan 4 10:08:18 2009 New Revision: 731307 URL: http://svn.apache.org/viewvc?rev=731307&view=rev Log: Fixed a dimension error with under-determined problems; removed IllegalStateException from the method signatures; create a DenseRealMatrix when solving A.X = B.
Modified: commons/proper/math/trunk/src/java/org/apache/commons/math/linear/QRDecompositionImpl.java Modified: commons/proper/math/trunk/src/java/org/apache/commons/math/linear/QRDecompositionImpl.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/java/org/apache/commons/math/linear/QRDecompositionImpl.java?rev=731307&r1=731306&r2=731307&view=diff ============================================================================== --- commons/proper/math/trunk/src/java/org/apache/commons/math/linear/QRDecompositionImpl.java (original) +++ commons/proper/math/trunk/src/java/org/apache/commons/math/linear/QRDecompositionImpl.java Sun Jan 4 10:08:18 2009 @@ -17,6 +17,10 @@ package org.apache.commons.math.linear; +import java.util.Arrays; + +import org.apache.commons.math.MathRuntimeException; + /** * Calculates the QR-decomposition of a matrix. @@ -73,7 +77,7 @@ final int m = matrix.getRowDimension(); final int n = matrix.getColumnDimension(); qrt = matrix.transpose().getData(); - rDiag = new double[n]; + rDiag = new double[Math.min(m, n)]; cachedQ = null; cachedQT = null; cachedR = null; @@ -170,8 +174,7 @@ } /** {...@inheritdoc} */ - public RealMatrix getQ() - throws IllegalStateException { + public RealMatrix getQ() { if (cachedQ == null) { cachedQ = getQT().transpose(); } @@ -179,8 +182,7 @@ } /** {...@inheritdoc} */ - public RealMatrix getQT() - throws IllegalStateException { + public RealMatrix getQT() { if (cachedQT == null) { @@ -224,8 +226,7 @@ } /** {...@inheritdoc} */ - public RealMatrix getH() - throws IllegalStateException { + public RealMatrix getH() { if (cachedH == null) { @@ -278,8 +279,7 @@ } /** {...@inheritdoc} */ - public boolean isNonSingular() - throws IllegalStateException { + public boolean isNonSingular() { for (double diag : rDiag) { if (diag == 0) { @@ -292,12 +292,14 @@ /** {...@inheritdoc} */ public double[] solve(double[] b) - throws IllegalStateException, IllegalArgumentException, InvalidMatrixException { + throws 
IllegalArgumentException, InvalidMatrixException { final int n = qrt.length; final int m = qrt[0].length; if (b.length != m) { - throw new IllegalArgumentException("Incorrect row dimension"); + throw MathRuntimeException.createIllegalArgumentException( + "vector length mismatch: got {0} but expected {1}", + new Object[] { b.length, m }); } if (!isNonSingular()) { throw new SingularMatrixException(); @@ -323,7 +325,7 @@ } // solve triangular system R.x = y - for (int row = n - 1; row >= 0; --row) { + for (int row = rDiag.length - 1; row >= 0; --row) { y[row] /= rDiag[row]; final double yRow = y[row]; final double[] qrtRow = qrt[row]; @@ -339,7 +341,7 @@ /** {...@inheritdoc} */ public RealVector solve(RealVector b) - throws IllegalStateException, IllegalArgumentException, InvalidMatrixException { + throws IllegalArgumentException, InvalidMatrixException { try { return solve((RealVectorImpl) b); } catch (ClassCastException cce) { @@ -351,76 +353,103 @@ * <p>The A matrix is implicit here. It is </p> * @param b right-hand side of the equation A × X = B * @return a vector X that minimizes the two norm of A × X - B - * @exception IllegalStateException if {...@link #decompose(RealMatrix) decompose} - * has not been called * @throws IllegalArgumentException if matrices dimensions don't match * @throws InvalidMatrixException if decomposed matrix is singular */ public RealVectorImpl solve(RealVectorImpl b) - throws IllegalStateException, IllegalArgumentException, InvalidMatrixException { + throws IllegalArgumentException, InvalidMatrixException { return new RealVectorImpl(solve(b.getDataRef()), false); } /** {...@inheritdoc} */ public RealMatrix solve(RealMatrix b) - throws IllegalStateException, IllegalArgumentException, InvalidMatrixException { + throws IllegalArgumentException, InvalidMatrixException { final int n = qrt.length; final int m = qrt[0].length; if (b.getRowDimension() != m) { - throw new IllegalArgumentException("Incorrect row dimension"); + throw 
MathRuntimeException.createIllegalArgumentException( + "dimensions mismatch: got {0}x{1} but expected {2}x{3}", + new Object[] { b.getRowDimension(), b.getColumnDimension(), m, "n"}); } if (!isNonSingular()) { throw new SingularMatrixException(); } - final int cols = b.getColumnDimension(); - final double[][] xData = new double[n][cols]; - final double[] y = new double[b.getRowDimension()]; - - for (int k = 0; k < cols; ++k) { + final int columns = b.getColumnDimension(); + final int blockSize = DenseRealMatrix.BLOCK_SIZE; + final int cBlocks = (columns + blockSize - 1) / blockSize; + final double[][] xBlocks = DenseRealMatrix.createBlocksLayout(n, columns); + final double[][] y = new double[b.getRowDimension()][blockSize]; + final double[] alpha = new double[blockSize]; + + for (int kBlock = 0; kBlock < cBlocks; ++kBlock) { + final int kStart = kBlock * blockSize; + final int kEnd = Math.min(kStart + blockSize, columns); + final int kWidth = kEnd - kStart; // get the right hand side vector - for (int j = 0; j < y.length; ++j) { - y[j] = b.getEntry(j, k); - } + b.copySubMatrix(0, m - 1, kStart, kEnd - 1, y); // apply Householder transforms to solve Q.y = b for (int minor = 0; minor < Math.min(m, n); minor++) { - final double[] qrtMinor = qrt[minor]; - double dotProduct = 0; - for (int row = minor; row < m; row++) { - dotProduct += y[row] * qrtMinor[row]; + final double factor = 1.0 / (rDiag[minor] * qrtMinor[minor]); + + Arrays.fill(alpha, 0, kWidth, 0.0); + for (int row = minor; row < m; ++row) { + final double d = qrtMinor[row]; + final double[] yRow = y[row]; + for (int k = 0; k < kWidth; ++k) { + alpha[k] += d * yRow[k]; + } + } + for (int k = 0; k < kWidth; ++k) { + alpha[k] *= factor; } - dotProduct /= rDiag[minor] * qrtMinor[minor]; - for (int row = minor; row < m; row++) { - y[row] += dotProduct * qrtMinor[row]; + for (int row = minor; row < m; ++row) { + final double d = qrtMinor[row]; + final double[] yRow = y[row]; + for (int k = 0; k < kWidth; ++k) { + 
yRow[k] += alpha[k] * d; + } } } // solve triangular system R.x = y - for (int row = n - 1; row >= 0; --row) { - y[row] /= rDiag[row]; - final double yRow = y[row]; - final double[] qrtRow = qrt[row]; - xData[row][k] = yRow; - for (int i = 0; i < row; i++) { - y[i] -= yRow * qrtRow[i]; + for (int j = rDiag.length - 1; j >= 0; --j) { + final int jBlock = j / blockSize; + final int jStart = jBlock * blockSize; + final double factor = 1.0 / rDiag[j]; + final double[] yJ = y[j]; + final double[] xBlock = xBlocks[jBlock * cBlocks + kBlock]; + for (int k = 0, index = (j - jStart) * kWidth; k < kWidth; ++k, ++index) { + yJ[k] *= factor; + xBlock[index] = yJ[k]; } + + final double[] qrtJ = qrt[j]; + for (int i = 0; i < j; ++i) { + final double rIJ = qrtJ[i]; + final double[] yI = y[i]; + for (int k = 0; k < kWidth; ++k) { + yI[k] -= yJ[k] * rIJ; + } + } + } } - return new RealMatrixImpl(xData, false); + return new DenseRealMatrix(n, columns, xBlocks, false); } /** {...@inheritdoc} */ public RealMatrix getInverse() - throws IllegalStateException, InvalidMatrixException { + throws InvalidMatrixException { return solve(MatrixUtils.createRealIdentityMatrix(rDiag.length)); }