Author: luc Date: Wed May 6 09:40:13 2009 New Revision: 772114 URL: http://svn.apache.org/viewvc?rev=772114&view=rev Log: replaced matrix by vector where possible
Modified: commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java Modified: commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java?rev=772114&r1=772113&r2=772114&view=diff ============================================================================== --- commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java (original) +++ commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/GLSMultipleLinearRegression.java Wed May 6 09:40:13 2009 @@ -18,6 +18,7 @@ import org.apache.commons.math.linear.RealMatrix; import org.apache.commons.math.linear.RealMatrixImpl; +import org.apache.commons.math.linear.RealVector; import org.apache.commons.math.linear.decomposition.LUDecompositionImpl; @@ -91,12 +92,12 @@ * @return beta */ @Override - protected RealMatrix calculateBeta() { + protected RealVector calculateBeta() { RealMatrix OI = getOmegaInverse(); RealMatrix XT = X.transpose(); RealMatrix XTOIX = XT.multiply(OI).multiply(X); RealMatrix inverse = new LUDecompositionImpl(XTOIX).getSolver().getInverse(); - return inverse.multiply(XT).multiply(OI).multiply(Y); + return inverse.multiply(XT).multiply(OI).operate(Y); } /** @@ -122,9 +123,9 @@ */ @Override protected double calculateYVariance() { - RealMatrix u = calculateResiduals(); - RealMatrix sse = u.transpose().multiply(getOmegaInverse()).multiply(u); - return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension()); + RealVector residuals = calculateResiduals(); + double t = residuals.dotProduct(getOmegaInverse().operate(residuals)); + return t / (X.getRowDimension() - X.getColumnDimension()); } } Modified: commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java?rev=772114&r1=772113&r2=772114&view=diff ============================================================================== --- commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java (original) +++ commons/proper/math/trunk/src/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java Wed May 6 09:40:13 2009 @@ -18,6 +18,8 @@ import org.apache.commons.math.linear.RealMatrix; import org.apache.commons.math.linear.RealMatrixImpl; +import org.apache.commons.math.linear.RealVector; +import org.apache.commons.math.linear.RealVectorImpl; import org.apache.commons.math.linear.decomposition.LUDecompositionImpl; import org.apache.commons.math.linear.decomposition.QRDecomposition; import org.apache.commons.math.linear.decomposition.QRDecompositionImpl; @@ -137,8 +139,8 @@ * @return beta */ @Override - protected RealMatrix calculateBeta() { - return solveUpperTriangular(qr.getR(), qr.getQ().transpose().multiply(Y)); + protected RealVector calculateBeta() { + return solveUpperTriangular(qr.getR(), qr.getQ().transpose().operate(Y)); } /** @@ -170,9 +172,9 @@ */ @Override protected double calculateYVariance() { - RealMatrix u = calculateResiduals(); - RealMatrix sse = u.transpose().multiply(u); - return sse.getTrace()/(X.getRowDimension()-X.getColumnDimension()); + RealVector residuals = calculateResiduals(); + return residuals.dotProduct(residuals) / + (X.getRowDimension() - X.getColumnDimension()); } /** TODO: Find a home for the following methods in the linear package */ @@ -191,20 +193,16 @@ * Similarly, extra (zero) rows in coefficients are ignored</p> * * @param coefficients upper-triangular coefficients matrix - * @param constants column RHS constants matrix - * @return solution matrix as a column matrix + * @param constants column RHS constants vector + * @return solution matrix as a column vector * */ - private static RealMatrix solveUpperTriangular(RealMatrix coefficients, - RealMatrix constants) { + private static RealVector solveUpperTriangular(RealMatrix coefficients, + RealVector constants) { if (!isUpperTriangular(coefficients, 1E-12)) { throw new IllegalArgumentException( "Coefficients is not upper-triangular"); } - if (constants.getColumnDimension() != 1) { - throw new IllegalArgumentException( - "Constants not a column matrix."); - } int length = coefficients.getColumnDimension(); double x[] = new double[length]; for (int i = 0; i < length; i++) { @@ -213,9 +211,9 @@ for (int j = index + 1; j < length; j++) { sum += coefficients.getEntry(index, j) * x[j]; } - x[index] = (constants.getEntry(index, 0) - sum) / coefficients.getEntry(index, index); + x[index] = (constants.getEntry(index) - sum) / coefficients.getEntry(index, index); } - return new RealMatrixImpl(x); + return new RealVectorImpl(x); } /** Modified: commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java?rev=772114&r1=772113&r2=772114&view=diff ============================================================================== --- commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java (original) +++ commons/proper/math/trunk/src/test/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java Wed May 6 09:40:13 2009 @@ -139,7 +139,7 @@ new double[]{-3482258.63459582, 15.0618722713733, -0.358191792925910E-01,-2.02022980381683, -1.03322686717359,-0.511041056535807E-01, - 1829.15146461355}, 1E-8); // + 1829.15146461355}, 2E-8); // // Check expected residuals from R double[] residuals = model.estimateResiduals(); @@ -332,7 +332,7 @@ */ double[] residuals = model.estimateResiduals(); RealMatrix I = MatrixUtils.createRealIdentityMatrix(10); - double[] hatResiduals = I.subtract(hat).multiply(model.Y).getColumn(0); + double[] hatResiduals = I.subtract(hat).operate(model.Y).getData(); TestUtils.assertEquals(residuals, hatResiduals, 10e-12); } }