Author: psteitz Date: Mon Aug 23 02:55:01 2010 New Revision: 987983 URL: http://svn.apache.org/viewvc?rev=987983&view=rev Log: Added R-squared and adjusted R-squared statistics to OLSMultipleLinearRegression JIRA: MATH-386
Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java commons/proper/math/trunk/src/site/xdoc/changes.xml commons/proper/math/trunk/src/test/R/multipleOLSRegressionTestCases commons/proper/math/trunk/src/test/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java?rev=987983&r1=987982&r2=987983&view=diff ============================================================================== --- commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java (original) +++ commons/proper/math/trunk/src/main/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegression.java Mon Aug 23 02:55:01 2010 @@ -22,6 +22,7 @@ import org.apache.commons.math.linear.QR import org.apache.commons.math.linear.QRDecompositionImpl; import org.apache.commons.math.linear.RealMatrix; import org.apache.commons.math.linear.RealVector; +import org.apache.commons.math.stat.descriptive.moment.SecondMoment; /** * <p>Implements ordinary least squares (OLS) to estimate the parameters of a @@ -122,6 +123,55 @@ public class OLSMultipleLinearRegression } /** + * Returns the sum of squared deviations of Y from its mean. + * + * @return total sum of squares + */ + public double calculateTotalSumOfSquares() { + return new SecondMoment().evaluate(Y.getData()); + } + + /** + * Returns the sum of square residuals. 
+ * + * @return residual sum of squares + */ + public double calculateResidualSumOfSquares() { + final RealVector residuals = calculateResiduals(); + return residuals.dotProduct(residuals); + } + + /** + * Returns the R-Squared statistic, defined by the formula <pre> + * R<sup>2</sup> = 1 - SSR / SSTO + * </pre> + * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals} + * and SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares} + * + * @return R-square statistic + */ + public double calculateRSquared() { + return 1 - calculateResidualSumOfSquares() / calculateTotalSumOfSquares(); + } + + /** + * Returns the adjusted R-squared statistic, defined by the formula <pre> + * R<sup>2</sup><sub>adj</sub> = 1 - [SSR (n - 1)] / [SSTO (n - p)] + * </pre> + * where SSR is the {@link #calculateResidualSumOfSquares() sum of squared residuals}, + * SSTO is the {@link #calculateTotalSumOfSquares() total sum of squares}, n is the number + * of observations and p is the number of parameters estimated (including the intercept). 
+ * + * @return adjusted R-Squared statistic + */ + public double calculateAdjustedRSquared() { + final double n = X.getRowDimension(); + return 1 - (calculateResidualSumOfSquares() * (n - 1)) / + (calculateTotalSumOfSquares() * (n - X.getColumnDimension())); + // return 1 - ((1 - calculateRSquare()) * (n - 1) / (n - X.getColumnDimension() - 1)); + } + + /** * Loads new x sample data, overriding any previous sample * * @param x the [n,k] array representing the x sample Modified: commons/proper/math/trunk/src/site/xdoc/changes.xml URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/site/xdoc/changes.xml?rev=987983&r1=987982&r2=987983&view=diff ============================================================================== --- commons/proper/math/trunk/src/site/xdoc/changes.xml (original) +++ commons/proper/math/trunk/src/site/xdoc/changes.xml Mon Aug 23 02:55:01 2010 @@ -52,6 +52,9 @@ The <action> type attribute can be add,u If the output is not quite correct, check for invisible trailing spaces! --> <release version="2.2" date="TBD" description="TBD"> + <action dev="psteitz" type="fix" issue="MATH-386"> + Added R-squared and adjusted R-squared statistics to OLSMultipleLinearRegression. 
+ </action> <action dev="psteitz" type="fix" issue="MATH-392" due-to="Mark Devaney"> Corrected the formula used for Y variance returned by calculateYVariance and associated methods in multiple regression classes (AbstractMultipleLinearRegression, Modified: commons/proper/math/trunk/src/test/R/multipleOLSRegressionTestCases URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/R/multipleOLSRegressionTestCases?rev=987983&r1=987982&r2=987983&view=diff ============================================================================== --- commons/proper/math/trunk/src/test/R/multipleOLSRegressionTestCases (original) +++ commons/proper/math/trunk/src/test/R/multipleOLSRegressionTestCases Mon Aug 23 02:55:01 2010 @@ -32,11 +32,13 @@ options(digits=16) # override # function to verify OLS computations verifyRegression <- function(model, expectedBeta, expectedResiduals, - expectedErrors, expectedStdError, modelName) { + expectedErrors, expectedStdError, expectedRSquare, expecteAdjRSquare, modelName) { betaHat <- as.vector(coefficients(model)) residuals <- as.vector(residuals(model)) errors <- as.vector(as.matrix(coefficients(summary(model)))[,2]) stdError <- summary(model)$sigma + rSquare <- summary(model)$r.squared + adjRSquare <- summary(model)$adj.r.squared output <- c("Parameter test dataset = ", modelName) if (assertEquals(expectedBeta,betaHat,tol,"Parameters")) { displayPadded(output, SUCCEEDED, WIDTH) @@ -61,6 +63,18 @@ verifyRegression <- function(model, expe } else { displayPadded(output, FAILED, WIDTH) } + output <- c("RSquared test dataset = ", modelName) + if (assertEquals(expectedRSquare,rSquare,tol,"RSquared")) { + displayPadded(output, SUCCEEDED, WIDTH) + } else { + displayPadded(output, FAILED, WIDTH) + } + output <- c("Adjusted RSquared test dataset = ", modelName) + if (assertEquals(expecteAdjRSquare,adjRSquare,tol,"Adjusted RSquared")) { + displayPadded(output, SUCCEEDED, WIDTH) + } else { + displayPadded(output, FAILED, WIDTH) + } } 
#-------------------------------------------------------------------------- @@ -78,8 +92,10 @@ expectedBeta <- c(11.0,0.5,0.66666666666 expectedResiduals <- c(0,0,0,0,0,0) expectedErrors <- c(NaN,NaN,NaN,NaN,NaN,NaN) expectedStdError <- NaN -verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError, - "perfect fit") +expectedRSquare <- 1 +expectedAdjRSquare <- NaN +verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, + expectedStdError, expectedRSquare, expectedAdjRSquare, "perfect fit") # Longly # @@ -134,9 +150,11 @@ expectedResiduals <- c( 267.340029759711 -39.0550425226967,-155.5499735953195,-85.6713080421283,341.9315139607727, -206.7578251937366) expectedStdError <- 304.8540735619638 +expectedRSquare <- 0.995479004577296 +expectedAdjRSquare <- 0.992465007628826 -verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError, - "Longly") +verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, +expectedStdError, expectedRSquare, expectedAdjRSquare, "Longly") # Swiss Fertility (R dataset named "swiss") @@ -225,9 +243,11 @@ expectedResiduals <- c(7.104426785973051 -0.4515205619767598,-10.2916870903837587,-15.7812984571900063) expectedStdError <- 7.73642194433223 +expectedRSquare <- 0.649789742860228 +expectedAdjRSquare <- 0.6164363850373927 -verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, expectedStdError, - "Swiss Fertility") +verifyRegression(model, expectedBeta, expectedResiduals, expectedErrors, +expectedStdError, expectedRSquare, expectedAdjRSquare, "Swiss Fertility") displayDashes(WIDTH) Modified: commons/proper/math/trunk/src/test/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java?rev=987983&r1=987982&r2=987983&view=diff 
============================================================================== --- commons/proper/math/trunk/src/test/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java (original) +++ commons/proper/math/trunk/src/test/java/org/apache/commons/math/stat/regression/OLSMultipleLinearRegressionTest.java Mon Aug 23 02:55:01 2010 @@ -109,7 +109,7 @@ public class OLSMultipleLinearRegression assertEquals(0.0, errors.subtract(referenceVariance).getNorm(), 5.0e-16 * referenceVariance.getNorm()); - + assertEquals(1, ((OLSMultipleLinearRegression) regression).calculateRSquared(), 1E-12); } @@ -186,6 +186,10 @@ public class OLSMultipleLinearRegression // Check regression standard error against R assertEquals(304.8540735619638, model.estimateRegressionStandardError(), 1E-10); + // Check R-Square statistics against R + assertEquals(0.995479004577296, model.calculateRSquared(), 1E-12); + assertEquals(0.992465007628826, model.calculateAdjustedRSquared(), 1E-12); + checkVarianceConsistency(model); } @@ -294,6 +298,10 @@ public class OLSMultipleLinearRegression // Check regression standard error against R assertEquals(7.73642194433223, model.estimateRegressionStandardError(), 1E-12); + // Check R-Square statistics against R + assertEquals(0.649789742860228, model.calculateRSquared(), 1E-12); + assertEquals(0.6164363850373927, model.calculateAdjustedRSquared(), 1E-12); + checkVarianceConsistency(model); }