Author: erans Date: Wed Oct 30 21:59:57 2013 New Revision: 1537324 URL: http://svn.apache.org/r1537324 Log: MATH-1047 Added overflow checking to "ArithmeticUtils.pow(long,int)".
Modified:
    commons/proper/math/trunk/src/main/java/org/apache/commons/math3/util/ArithmeticUtils.java
    commons/proper/math/trunk/src/test/java/org/apache/commons/math3/util/ArithmeticUtilsTest.java

Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/util/ArithmeticUtils.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/util/ArithmeticUtils.java?rev=1537324&r1=1537323&r2=1537324&view=diff
==============================================================================
--- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/util/ArithmeticUtils.java (original)
+++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/util/ArithmeticUtils.java Wed Oct 30 21:59:57 2013
@@ -706,25 +706,45 @@ public final class ArithmeticUtils {
      *
      * @param k Number to raise.
      * @param e Exponent (must be positive or zero).
-     * @return k<sup>e</sup>
+     * @return \( k^e \)
      * @throws NotPositiveException if {@code e < 0}.
+     * @throws MathArithmeticException if the result would overflow.
      */
-    public static long pow(final long k, int e) throws NotPositiveException {
+    public static long pow(final long k,
+                           final int e)
+        throws NotPositiveException,
+               MathArithmeticException {
         if (e < 0) {
             throw new NotPositiveException(LocalizedFormats.EXPONENT, e);
         }
 
-        long result = 1l;
-        long k2p = k;
-        while (e != 0) {
-            if ((e & 0x1) != 0) {
-                result *= k2p;
+        try {
+            int exp = e;
+            long result = 1;
+            long k2p = k;
+            while (true) {
+                if ((exp & 0x1) != 0) {
+                    result = mulAndCheck(result, k2p);
+                }
+
+                exp >>= 1;
+                if (exp == 0) {
+                    break;
+                }
+
+                k2p = mulAndCheck(k2p, k2p);
             }
-            k2p *= k2p;
-            e = e >> 1;
-        }
 
-        return result;
+            return result;
+        } catch (MathArithmeticException mae) {
+            // Add context information.
+            mae.getContext().addMessage(LocalizedFormats.OVERFLOW);
+            mae.getContext().addMessage(LocalizedFormats.BASE, k);
+            mae.getContext().addMessage(LocalizedFormats.EXPONENT, e);
+
+            // Rethrow.
+            throw mae;
+        }
     }
 
     /**

Modified: commons/proper/math/trunk/src/test/java/org/apache/commons/math3/util/ArithmeticUtilsTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math3/util/ArithmeticUtilsTest.java?rev=1537324&r1=1537323&r2=1537324&view=diff
==============================================================================
--- commons/proper/math/trunk/src/test/java/org/apache/commons/math3/util/ArithmeticUtilsTest.java (original)
+++ commons/proper/math/trunk/src/test/java/org/apache/commons/math3/util/ArithmeticUtilsTest.java Wed Oct 30 21:59:57 2013
@@ -461,29 +461,33 @@ public class ArithmeticUtilsTest {
     }
 
     @Test(expected=MathArithmeticException.class)
-    public void testPowIntIntOverflow() {
-        final int base = 21;
-        final int exp = 8;
-        ArithmeticUtils.pow(base, exp);
+    public void testPowIntOverflow() {
+        ArithmeticUtils.pow(21, 8);
     }
+
     @Test
-    public void testPowIntIntNoOverflow() {
+    public void testPowInt() {
         final int base = 21;
-        final int exp = 7;
-        ArithmeticUtils.pow(base, exp);
+
+        Assert.assertEquals(85766121L,
+                            ArithmeticUtils.pow(base, 6));
+        Assert.assertEquals(1801088541L,
+                            ArithmeticUtils.pow(base, 7));
     }
 
     @Test(expected=MathArithmeticException.class)
-    public void testPowNegativeIntIntOverflow() {
-        final int base = -21;
-        final int exp = 8;
-        ArithmeticUtils.pow(base, exp);
+    public void testPowNegativeIntOverflow() {
+        ArithmeticUtils.pow(-21, 8);
     }
+
     @Test
-    public void testPowNegativeIntIntNoOverflow() {
+    public void testPowNegativeInt() {
         final int base = -21;
-        final int exp = 7;
-        ArithmeticUtils.pow(base, exp);
+
+        Assert.assertEquals(85766121,
+                            ArithmeticUtils.pow(base, 6));
+        Assert.assertEquals(-1801088541,
+                            ArithmeticUtils.pow(base, 7));
     }
 
     @Test
@@ -504,6 +508,54 @@ public class ArithmeticUtilsTest {
         }
     }
 
+    @Test(expected=MathArithmeticException.class)
+    public void testPowLongOverflow() {
+        ArithmeticUtils.pow(21L, 15); // "L" suffix: must hit pow(long,int), not pow(int,int).
+    }
+
+    @Test
+    public void testPowLong() {
+        final long base = 21;
+
+        Assert.assertEquals(154472377739119461L,
+                            ArithmeticUtils.pow(base, 13));
+        Assert.assertEquals(3243919932521508681L,
+                            ArithmeticUtils.pow(base, 14));
+    }
+
+    @Test(expected=MathArithmeticException.class)
+    public void testPowNegativeLongOverflow() {
+        ArithmeticUtils.pow(-21L, 15);
+    }
+
+    @Test
+    public void testPowNegativeLong() {
+        final long base = -21;
+
+        Assert.assertEquals(-154472377739119461L,
+                            ArithmeticUtils.pow(base, 13));
+        Assert.assertEquals(3243919932521508681L,
+                            ArithmeticUtils.pow(base, 14));
+    }
+
+    @Test
+    public void testPowMinusOneLong() {
+        final long base = -1;
+        for (int i = 0; i < 100; i++) {
+            final long pow = ArithmeticUtils.pow(base, i);
+            Assert.assertEquals("i: " + i, i % 2 == 0 ? 1 : -1, pow);
+        }
+    }
+
+    @Test
+    public void testPowOneLong() {
+        final long base = 1;
+        for (int i = 0; i < 100; i++) {
+            final long pow = ArithmeticUtils.pow(base, i);
+            Assert.assertEquals("i: " + i, 1, pow);
+        }
+    }
+
     @Test
     public void testIsPowerOfTwo() {
         final int n = 1025;