This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-numbers.git
commit 2ec97e42be39f410a06a3ba9c60f89ddea65614c Author: Alex Herbert <aherb...@apache.org> AuthorDate: Mon Nov 7 17:25:35 2022 +0000 Numbers-191: Compute Stirling number of the first kind --- .../commons/numbers/combinatorics/Stirling.java | 174 ++++++++++++++--- .../numbers/combinatorics/StirlingTest.java | 211 +++++++++++++++++++-- 2 files changed, 343 insertions(+), 42 deletions(-) diff --git a/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/Stirling.java b/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/Stirling.java index d5250fc4..2d301eae 100644 --- a/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/Stirling.java +++ b/commons-numbers-combinatorics/src/main/java/org/apache/commons/numbers/combinatorics/Stirling.java @@ -23,11 +23,47 @@ package org.apache.commons.numbers.combinatorics; * @since 1.2 */ public final class Stirling { + /** Stirling S1 error message. */ + private static final String S1_ERROR_FORMAT = "s(n=%d, k=%d)"; /** Stirling S2 error message. */ private static final String S2_ERROR_FORMAT = "S(n=%d, k=%d)"; + /** Overflow threshold for n when computing s(n, 1). */ + private static final int S1_OVERFLOW_K_EQUALS_1 = 21; + /** Overflow threshold for n when computing s(n, n-2). */ + private static final int S1_OVERFLOW_K_EQUALS_NM2 = 92682; + /** Overflow threshold for n when computing s(n, n-3). */ + private static final int S1_OVERFLOW_K_EQUALS_NM3 = 2761; /** Overflow threshold for n when computing S(n, n-2). */ private static final int S2_OVERFLOW_K_EQUALS_NM2 = 92683; + /** + * Precomputed Stirling numbers of the first kind. + * Provides a thread-safe lazy initialization of the cache. + */ + private static class StirlingS1Cache { + /** Maximum n to compute (exclusive). + * As s(21,3) = 13803759753640704000 is larger than Long.MAX_VALUE + * we must stop computation at row 21. */ + static final int MAX_N = 21; + /** Stirling numbers of the first kind. */ + static final long[][] S1; + + static { + S1 = new long[MAX_N][]; + // Initialise first two rows to allow s(2, 1) to use s(1, 1) + S1[0] = new long[] {1}; + S1[1] = new long[] {0, 1}; + for (int n = 2; n < S1.length; n++) { + S1[n] = new long[n + 1]; + S1[n][0] = 0; + S1[n][n] = 1; + for (int k = 1; k < n; k++) { + S1[n][k] = S1[n - 1][k - 1] - (n - 1) * S1[n - 1][k]; + } + } + } + } + /** * Precomputed Stirling numbers of the second kind. * Provides a thread-safe lazy initialization of the cache. @@ -38,18 +74,18 @@ public final class Stirling { * we must stop computation at row 26. */ static final int MAX_N = 26; /** Stirling numbers of the second kind. */ - static final long[][] STIRLING_S2; + static final long[][] S2; static { - STIRLING_S2 = new long[MAX_N][]; - STIRLING_S2[0] = new long[] {1}; - for (int n = 1; n < STIRLING_S2.length; n++) { - STIRLING_S2[n] = new long[n + 1]; - STIRLING_S2[n][0] = 0; - STIRLING_S2[n][1] = 1; - STIRLING_S2[n][n] = 1; + S2 = new long[MAX_N][]; + S2[0] = new long[] {1}; + for (int n = 1; n < S2.length; n++) { + S2[n] = new long[n + 1]; + S2[n][0] = 0; + S2[n][1] = 1; + S2[n][n] = 1; for (int k = 2; k < n; k++) { - STIRLING_S2[n][k] = k * STIRLING_S2[n - 1][k] + STIRLING_S2[n - 1][k - 1]; + S2[n][k] = k * S2[n - 1][k] + S2[n - 1][k - 1]; } } } @@ -60,6 +96,81 @@ public final class Stirling { // intentionally empty. } + /** + * Returns the <em>signed</em> <a + * href="https://mathworld.wolfram.com/StirlingNumberoftheFirstKind.html"> + * Stirling number of the first kind</a>, "{@code s(n,k)}". The number of permutations of + * {@code n} elements which contain exactly {@code k} permutation cycles is the + * nonnegative number: {@code |s(n,k)| = (-1)^(n-k) s(n,k)} + * + * @param n Size of the set + * @param k Number of permutation cycles ({@code 0 <= k <= n}) + * @return {@code s(n,k)} + * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}. + * @throws ArithmeticException if some overflow happens, typically for n exceeding 20 + * (s(n,n-1) is handled specifically and does not overflow) + */ + public static long stirlingS1(int n, int k) { + checkArguments(n, k); + + if (n < StirlingS1Cache.MAX_N) { + // The number is in the small cache + return StirlingS1Cache.S1[n][k]; + } + + // Simple cases + // https://en.wikipedia.org/wiki/Stirling_numbers_of_the_first_kind#Simple_identities + if (k == 0) { + return 0; + } else if (k == n) { + return 1; + } else if (k == 1) { + checkN(n, k, S1_OVERFLOW_K_EQUALS_1, S1_ERROR_FORMAT); + // Note: Only occurs for n=21 so avoid computing the sign with pow(-1, n-1) * (n-1)! + return Factorial.value(n - 1); + } else if (k == n - 1) { + return -BinomialCoefficient.value(n, 2); + } else if (k == n - 2) { + checkN(n, k, S1_OVERFLOW_K_EQUALS_NM2, S1_ERROR_FORMAT); + // (3n-1) * binom(n, 3) / 4 + final long a = 3L * n - 1; + final long b = BinomialCoefficient.value(n, 3); + // Compute (a*b/4) without intermediate overflow. + // The product (a*b) must be an exact multiple of 4. + // Conditional branch on b which is typically large and even (a is 50% even) + // If b is even: ((b/2) * a) / 2 + // If b is odd then a must be even to make a*b even: ((a/2) * b) / 2 + return (b & 1) == 0 ? ((b >>> 1) * a) >>> 1 : ((a >>> 1) * b) >>> 1; + } else if (k == n - 3) { + checkN(n, k, S1_OVERFLOW_K_EQUALS_NM3, S1_ERROR_FORMAT); + return -BinomialCoefficient.value(n, 2) * BinomialCoefficient.value(n, 4); + } + + // Compute using: + // s(n + 1, k) = s(n, k - 1) - n * s(n, k) + // s(n, k) = s(n - 1, k - 1) - (n - 1) * s(n - 1, k) + + // n >= 21 (MAX_N) + // 2 <= k <= n-4 + + // Start at the largest easily computed value: n < MAX_N or k < 2 + final int reduction = Math.min(n - StirlingS1Cache.MAX_N, k - 2) + 1; + int n0 = n - reduction; + int k0 = k - reduction; + + long sum = stirlingS1(n0, k0); + while (n0 < n) { + k0++; + sum = Math.subtractExact( + sum, + Math.multiplyExact(n0, stirlingS1(n0, k0)) + ); + n0++; + } + + return sum; + } + /** * Returns the <a * href="https://mathworld.wolfram.com/StirlingNumberoftheSecondKind.html"> @@ -70,21 +181,16 @@ public final class Stirling { * @param n Size of the set * @param k Number of non-empty subsets ({@code 0 <= k <= n}) * @return {@code S(n,k)} - * @throws IllegalArgumentException if {@code k < 0} or {@code k > n}. + * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}. * @throws ArithmeticException if some overflow happens, typically for n exceeding 25 and * k between 20 and n-2 (S(n,n-1) is handled specifically and does not overflow) */ public static long stirlingS2(int n, int k) { - if (k < 0) { - throw new CombinatoricsException(CombinatoricsException.NEGATIVE, k); - } - if (k > n) { - throw new CombinatoricsException(CombinatoricsException.OUT_OF_RANGE, k, 0, n); - } + checkArguments(n, k); if (n < StirlingS2Cache.MAX_N) { // The number is in the small cache - return StirlingS2Cache.STIRLING_S2[n][k]; + return StirlingS2Cache.S2[n][k]; } // Simple cases @@ -93,7 +199,7 @@ public final class Stirling { } else if (k == 1 || k == n) { return 1; } else if (k == 2) { - checkN(n, k, 64); + checkN(n, k, 64, S2_ERROR_FORMAT); return (1L << (n - 1)) - 1L; } else if (k == n - 1) { return BinomialCoefficient.value(n, 2); @@ -108,7 +214,7 @@ public final class Stirling { // for i in [1, k]: // sum (i * binom(i+1, 2)) // Avoid overflow checks using the known limit for n when k=n-2 - checkN(n, k, S2_OVERFLOW_K_EQUALS_NM2); + checkN(n, k, S2_OVERFLOW_K_EQUALS_NM2, S2_ERROR_FORMAT); long binom = BinomialCoefficient.value(k + 1, 2); long sum = 0; for (int i = k; i > 0; i--) { @@ -130,28 +236,50 @@ public final class Stirling { long sum = stirlingS2(n0, k0); while (n0 < n) { - n0++; k0++; sum = Math.addExact( - Math.multiplyExact(k0, stirlingS2(n0 - 1, k0)), + Math.multiplyExact(k0, stirlingS2(n0, k0)), sum ); + n0++; } return sum; } + /** + * Check {@code 0 <= k <= n}. + * + * @param n N + * @param k K + * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0} or {@code k > n}. + */ + private static void checkArguments(int n, int k) { + // Combine all checks with a single branch: + // 0 <= n; 0 <= k <= n + // Note: If n >= 0 && k >= 0 && n - k < 0 then k > n. + // Bitwise or will detect a negative sign bit in any of the numbers + if ((n | k | (n - k)) < 0) { + // Raise the correct exception + if (n < 0) { + throw new CombinatoricsException(CombinatoricsException.NEGATIVE, n); + } + throw new CombinatoricsException(CombinatoricsException.OUT_OF_RANGE, k, 0, n); + } + } + /** * Check {@code n <= threshold}, or else throw an {@link ArithmeticException}. * * @param n N * @param k K * @param threshold Threshold for {@code n} + * @param msgFormat Error message format * @throws ArithmeticException if overflow is expected to happen */ - private static void checkN(int n, int k, int threshold) { + private static void checkN(int n, int k, int threshold, String msgFormat) { if (n > threshold) { - throw new ArithmeticException(String.format(S2_ERROR_FORMAT, n, k)); + throw new ArithmeticException(String.format(msgFormat, n, k)); } } } diff --git a/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/StirlingTest.java b/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/StirlingTest.java index 5d797db8..26587619 100644 --- a/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/StirlingTest.java +++ b/commons-numbers-combinatorics/src/test/java/org/apache/commons/numbers/combinatorics/StirlingTest.java @@ -16,23 +16,209 @@ */ package org.apache.commons.numbers.combinatorics; +import java.util.stream.Stream; +import org.apache.commons.numbers.core.ArithmeticUtils; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.CsvSource; +import org.junit.jupiter.params.provider.MethodSource; /** * Test cases for the {@link Stirling} class. */ class StirlingTest { + /** + * Arguments that are illegal for the Stirling number computations. + * + * @return the arguments + */ + static Stream<Arguments> stirlingIllegalArguments() { + return Stream.of( + Arguments.of(1, -1), + Arguments.of(-1, -1), + Arguments.of(-1, 1), + Arguments.of(10, 15), + Arguments.of(Integer.MIN_VALUE, 1), + Arguments.of(1, Integer.MIN_VALUE), + Arguments.of(Integer.MIN_VALUE, Integer.MIN_VALUE), + Arguments.of(Integer.MAX_VALUE - 1, Integer.MAX_VALUE) + ); + } + + /** + * Arguments that should easily overflow the Stirling number computations. + * Used to verify the exception is correct + * (e.g. no StackOverflowError occurs due to recursion). + * + * @return the arguments + */ + static Stream<Arguments> stirlingOverflowArguments() { + return Stream.of( + Arguments.of(123, 32), + Arguments.of(612534, 56123), + Arguments.of(261388631, 213), + Arguments.of(678688997, 213879), + Arguments.of(1000000002, 1000000000), + Arguments.of(1000000003, 1000000000), + Arguments.of(1000000004, 1000000000), + Arguments.of(1000000005, 1000000000), + Arguments.of(1000000010, 1000000000), + Arguments.of(1000000100, 1000000000) + ); + } + + @ParameterizedTest + @MethodSource(value = {"stirlingIllegalArguments"}) + void testStirlingS1IllegalArgument(int n, int k) { + Assertions.assertThrows(IllegalArgumentException.class, () -> Stirling.stirlingS1(n, k)); + } + + @Test + void testStirlingS1StandardCases() { + Assertions.assertEquals(1, Stirling.stirlingS1(0, 0)); + + for (int n = 1; n < 64; ++n) { + Assertions.assertEquals(0, Stirling.stirlingS1(n, 0)); + if (n < 21) { + Assertions.assertEquals(ArithmeticUtils.pow(-1, n - 1) * Factorial.value(n - 1), + Stirling.stirlingS1(n, 1)); + if (n > 2) { + Assertions.assertEquals(-BinomialCoefficient.value(n, 2), + Stirling.stirlingS1(n, n - 1)); + } + } + Assertions.assertEquals(1, Stirling.stirlingS1(n, n)); + } + } + @ParameterizedTest @CsvSource({ - "1, -1", - "-1, -1", - "-1, 1", - "10, 15", + // Data verified using Mathematica StirlingS1[n, k] + "5, 3, 35", + "6, 3, -225", + "6, 4, 85", + "7, 3, 1624", + "7, 4, -735", + "7, 5, 175", + "8, 3, -13132", + "8, 4, 6769", + "8, 5, -1960", + "8, 6, 322", + "9, 3, 118124", + "9, 4, -67284", + "9, 5, 22449", + "9, 6, -4536", + "9, 7, 546", + "10, 3, -1172700", + "10, 4, 723680", + "10, 5, -269325", + "10, 6, 63273", + "10, 7, -9450", + "10, 8, 870", + // n >= 21 is not cached + // ... k in [1, 7] require n <= 21 + "21, 8, -311333643161390640", + "21, 9, 63030812099294896", + "22, 10, 276019109275035346", + "22, 11, -37600535086859745", + "23, 12, -129006659818331295", + "23, 13, 12363045847086207", + "24, 14, 34701806448704206", + "25, 15, 92446911376173550", + "25, 16, -5700586321864500", + "26, 17, -12972753318542875", + "27, 18, -28460103232088385", + "28, 19, -60383004803151030", + "29, 20, -124243455209483610", + // k in [n-8, n-2] + "33, 25, 42669229615802790", + "40, 33, -16386027912368400", + "66, 60, 98715435586436240", + "155, 150, -1849441185054164625", + "404, 400, 1793805203416799170", + "1003, 1000, -21063481189500750", + "10002, 10000, 1250583420837500", + // Limits for k in [n-1, n] use n = Integer.MAX_VALUE + "2147483647, 2147483646, -2305843005992468481", + "2147483647, 2147483647, 1", + // Data for s(n, n-2) + "21, 19, 20615", + "22, 20, 25025", + "23, 21, 30107", + "24, 22, 35926", + "25, 23, 42550", + "26, 24, 50050", + "27, 25, 58500", + "92679, 92677, 9221886003909976111", + "92680, 92678, 9222284027979459010", + "92681, 92679, 9222682064933083810", + // Data for s(n, n-3) + "21, 18, -1256850", + "22, 19, -1689765", + "23, 20, -2240315", + "24, 21, -2932776", + "25, 22, -3795000", + "26, 23, -4858750", + "27, 24, -6160050", + "2758, 2755, -9145798629595485585", + "2759, 2756, -9165721700732052911", + "2760, 2757, -9185680925511388200", }) + void testStirlingS1(int n, int k, long expected) { + Assertions.assertEquals(expected, Stirling.stirlingS1(n, k)); + } + + @ParameterizedTest + @CsvSource({ + // Upper limits for n with k in [1, 20] + "21, 1, 2432902008176640000", + "21, 2, -8752948036761600000", + "20, 3, -668609730341153280", + "20, 4, 610116075740491776", + "21, 5, 8037811822645051776", + "21, 6, -3599979517947607200", + "21, 7, 1206647803780373360", + "22, 8, 7744654310169576800", + "22, 9, -1634980697246583456", + "23, 10, -7707401101297361068", + "23, 11, 1103230881185949736", + "24, 12, 4070384057007569521", + "24, 13, -413356714301314056", + "25, 14, -1246200069070215000", + "26, 15, -3557372853474553750", + "26, 16, 234961569422786050", + "27, 17, 572253155704900800", + "28, 18, 1340675942971287195", + "29, 19, 3031400077459516035", + "30, 20, 6634460278534540725", + // Upper limits for n with k in [n-9, n-2] + "35, 26, -5576855646887454930", + "44, 36, 6364808704290634598", + "61, 54, -8424028440309413250", + "95, 89, 8864929183170733205", + "181, 176, -8872439767850041020", + "495, 491, 9161199664152744351", + "2761, 2758, -9205676356399769400", + "92682, 92680, 9223080114771128550", + }) + void testStirlingS1LimitsN(int n, int k, long expected) { + Assertions.assertEquals(expected, Stirling.stirlingS1(n, k)); + Assertions.assertThrows(ArithmeticException.class, () -> Stirling.stirlingS1(n + 1, k)); + Assertions.assertThrows(ArithmeticException.class, () -> Stirling.stirlingS1(n + 100, k)); + Assertions.assertThrows(ArithmeticException.class, () -> Stirling.stirlingS1(n + 10000, k)); + } + + @ParameterizedTest + @MethodSource(value = {"stirlingOverflowArguments"}) + void testStirlingS1Overflow(int n, int k) { + Assertions.assertThrows(ArithmeticException.class, () -> Stirling.stirlingS1(n, k)); + } + + @ParameterizedTest + @MethodSource(value = {"stirlingIllegalArguments"}) void testStirlingS2IllegalArgument(int n, int k) { Assertions.assertThrows(IllegalArgumentException.class, () -> Stirling.stirlingS2(n, k)); } @@ -129,7 +315,7 @@ class StirlingTest { "30, 20, 581535955088511150", "31, 21, 1359760239259935240", "32, 22, 3069483578649883980", - // Upper limits for n with with k in [n-10, n-2] + // Upper limits for n with k in [n-10, n-2] "33, 23, 6708404338089491700", "38, 29, 6766081393022256030", "47, 39, 8248929419122431611", @@ -148,20 +334,7 @@ class StirlingTest { } @ParameterizedTest - @CsvSource({ - // Large numbers that should easily overflow. Verifies the exception is correct - // (e.g. no StackOverflowError occurs due to recursion) - "123, 32", - "612534, 56123", - "261388631, 213", - "678688997, 213879", - "1000000002, 1000000000", - "1000000003, 1000000000", - "1000000004, 1000000000", - "1000000005, 1000000000", - "1000000010, 1000000000", - "1000000100, 1000000000", - }) + @MethodSource(value = {"stirlingOverflowArguments"}) void testStirlingS2Overflow(int n, int k) { Assertions.assertThrows(ArithmeticException.class, () -> Stirling.stirlingS2(n, k)); }