This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-statistics.git
commit 63d477c41f102aed39b7469625a9471e55601720 Author: aherbert <aherb...@apache.org> AuthorDate: Wed Jul 28 17:50:09 2021 +0100 STATISTICS-32: New "survivalProbability" function for all discrete distributions. A naive implementation would simply be `1 - cumulativeProbability`, but that loses precision when the cumulative probability is close to 1. For many of the current discrete distributions a higher precision survival probability is calculated. For the others, it is simply `1 - cumulativeProbability`. Many tests were added to verify the following: - the precision of cumulativeProbability - the precision of survivalProbability - that survivalProbability is close to 1 - cumulativeProbability - that the survival and cumulative probabilities are complementary. Through this development, certain distributions were found to lack precision in their cumulative probabilities and were improved. These were: - BinomialDistribution - HypergeometricDistribution. Expanding the tests for the Pascal distribution to cover the degenerate cases found a bug in the p=1 degenerate case when x=0. The NaN log probability returned in that case has been corrected, so the log probability is now either 0 (x=0) or -infinity (x!=0). 
--- .../distribution/BinomialDistribution.java | 20 +- .../distribution/DiscreteDistribution.java | 17 ++ .../distribution/GeometricDistribution.java | 9 + .../distribution/HypergeometricDistribution.java | 84 +++++-- .../distribution/PascalDistribution.java | 20 ++ .../distribution/PoissonDistribution.java | 13 ++ .../distribution/RegularizedBetaUtils.java | 60 +++++ .../distribution/BinomialDistributionTest.java | 50 +++++ .../DiscreteDistributionAbstractTest.java | 241 ++++++++++++++++++++- .../distribution/DiscreteDistributionTest.java | 8 +- .../distribution/GeometricDistributionTest.java | 11 + .../HypergeometricDistributionTest.java | 27 +++ .../distribution/PascalDistributionTest.java | 17 ++ .../distribution/PoissonDistributionTest.java | 20 ++ .../distribution/RegularizedBetaUtilsTest.java | 51 +++++ 15 files changed, 615 insertions(+), 33 deletions(-) diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/BinomialDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/BinomialDistribution.java index e3090c0..0e382ab 100644 --- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/BinomialDistribution.java +++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/BinomialDistribution.java @@ -101,8 +101,24 @@ public class BinomialDistribution extends AbstractDiscreteDistribution { } else if (x >= numberOfTrials) { ret = 1.0; } else { - ret = 1.0 - RegularizedBeta.value(probabilityOfSuccess, - x + 1.0, (double) numberOfTrials - x); + // Use a helper function to compute the complement of the survival probability + ret = RegularizedBetaUtils.complement(probabilityOfSuccess, + x + 1.0, (double) numberOfTrials - x); + } + return ret; + } + + /** {@inheritDoc} */ + @Override + public double survivalProbability(int x) { + double ret; + if (x < 0) { + ret = 1.0; + } else if (x >= 
numberOfTrials) { + ret = 0.0; + } else { + ret = RegularizedBeta.value(probabilityOfSuccess, + x + 1.0, (double) numberOfTrials - x); } return ret; } diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/DiscreteDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/DiscreteDistribution.java index e2d6cb9..292828c 100644 --- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/DiscreteDistribution.java +++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/DiscreteDistribution.java @@ -73,6 +73,23 @@ public interface DiscreteDistribution { double cumulativeProbability(int x); /** + * For a random variable {@code X} whose values are distributed according + * to this distribution, this method returns {@code P(X > x)}. + * In other words, this method represents the complementary cumulative + * distribution function. + * <p> + * By default, this is defined as {@code 1 - cumulativeProbability(x)}, but + * the specific implementation may be more accurate. + * + * @param x Point at which the survival function is evaluated. + * @return the probability that a random variable with this + * distribution takes a value greater than {@code x}. + */ + default double survivalProbability(int x) { + return 1.0 - cumulativeProbability(x); + } + + /** * Computes the quantile function of this distribution. 
* For a random variable {@code X} distributed according to this distribution, * the returned value is diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/GeometricDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/GeometricDistribution.java index d9e88cc..5e28975 100644 --- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/GeometricDistribution.java +++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/GeometricDistribution.java @@ -79,6 +79,15 @@ public class GeometricDistribution extends AbstractDiscreteDistribution { return -Math.expm1(log1mProbabilityOfSuccess * (x + 1)); } + /** {@inheritDoc} */ + @Override + public double survivalProbability(int x) { + if (x < 0) { + return 1.0; + } + return Math.exp(log1mProbabilityOfSuccess * (x + 1)); + } + /** * {@inheritDoc} * diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/HypergeometricDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/HypergeometricDistribution.java index cdeab67..8d4596c 100644 --- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/HypergeometricDistribution.java +++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/HypergeometricDistribution.java @@ -79,7 +79,24 @@ public class HypergeometricDistribution extends AbstractDiscreteDistribution { } else if (x >= domain[1]) { ret = 1.0; } else { - ret = innerCumulativeProbability(domain[0], x, 1); + ret = innerCumulativeProbability(domain[0], x); + } + + return ret; + } + + /** {@inheritDoc} */ + @Override + public double survivalProbability(int x) { + double ret; + + final int[] domain = getDomain(populationSize, numberOfSuccesses, sampleSize); + if (x < domain[0]) { + ret = 
1.0; + } else if (x >= domain[1]) { + ret = 0.0; + } else { + ret = innerCumulativeProbability(domain[1], x + 1); } return ret; @@ -168,22 +185,36 @@ public class HypergeometricDistribution extends AbstractDiscreteDistribution { } else { final double p = (double) sampleSize / (double) populationSize; final double q = (double) (populationSize - sampleSize) / (double) populationSize; - final double p1 = SaddlePointExpansionUtils.logBinomialProbability(x, - numberOfSuccesses, p, q); - final double p2 = - SaddlePointExpansionUtils.logBinomialProbability(sampleSize - x, - populationSize - numberOfSuccesses, p, q); - final double p3 = - SaddlePointExpansionUtils.logBinomialProbability(sampleSize, populationSize, p, q); - ret = p1 + p2 - p3; + ret = logProbability(x, p, q); } return ret; } /** + * Compute the log probability. + * + * @param x Value. + * @param p sample size / population size. + * @param q (population size - sample size) / population size + * @return log(P(X = x)) + */ + private double logProbability(int x, double p, double q) { + final double p1 = + SaddlePointExpansionUtils.logBinomialProbability(x, numberOfSuccesses, p, q); + final double p2 = + SaddlePointExpansionUtils.logBinomialProbability(sampleSize - x, + populationSize - numberOfSuccesses, p, q); + final double p3 = + SaddlePointExpansionUtils.logBinomialProbability(sampleSize, populationSize, p, q); + return p1 + p2 - p3; + } + + /** * For this distribution, {@code X}, this method returns {@code P(X >= x)}. * + * <p>Note: This is not equals to {@link #survivalProbability(int)} which computes {@code P(X > x)}. + * * @param x Value at which the CDF is evaluated. * @return the upper tail CDF for this distribution. 
*/ @@ -196,7 +227,7 @@ public class HypergeometricDistribution extends AbstractDiscreteDistribution { } else if (x > domain[1]) { ret = 0.0; } else { - ret = innerCumulativeProbability(domain[1], x, -1); + ret = innerCumulativeProbability(domain[1], x); } return ret; @@ -206,21 +237,32 @@ public class HypergeometricDistribution extends AbstractDiscreteDistribution { * For this distribution, {@code X}, this method returns * {@code P(x0 <= X <= x1)}. * This probability is computed by summing the point probabilities for the - * values {@code x0, x0 + 1, x0 + 2, ..., x1}, in the order directed by - * {@code dx}. + * values {@code x0, x0 + dx, x0 + 2 * dx, ..., x1}; the direction {@code dx} is determined + * using a comparison of the input bounds. + * This should be called by using {@code x0} as the domain limit and {@code x1} + * as the internal value. This will result in an initial sum of increasing larger magnitudes. * - * @param x0 Inclusive lower bound. - * @param x1 Inclusive upper bound. - * @param dx Direction of summation (1 indicates summing from x0 to x1, and - * 0 indicates summing from x1 to x0). + * @param x0 Inclusive domain bound. + * @param x1 Inclusive internal bound. * @return {@code P(x0 <= X <= x1)}. */ - private double innerCumulativeProbability(int x0, int x1, int dx) { + private double innerCumulativeProbability(int x0, int x1) { + // Assume the range is within the domain. + // Reuse the computation for probability(x) but avoid checking the domain for each call. 
+ final double p = (double) sampleSize / (double) populationSize; + final double q = (double) (populationSize - sampleSize) / (double) populationSize; int x = x0; - double ret = probability(x); - while (x != x1) { - x += dx; - ret += probability(x); + double ret = Math.exp(logProbability(x, p, q)); + if (x0 < x1) { + while (x != x1) { + x++; + ret += Math.exp(logProbability(x, p, q)); + } + } else { + while (x != x1) { + x--; + ret += Math.exp(logProbability(x, p, q)); + } } return ret; } diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/PascalDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/PascalDistribution.java index d2cdfdc..94d385b 100644 --- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/PascalDistribution.java +++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/PascalDistribution.java @@ -108,6 +108,9 @@ public class PascalDistribution extends AbstractDiscreteDistribution { double ret; if (x < 0) { ret = 0.0; + } else if (x == 0) { + // Special case exploiting cancellation. + ret = Math.pow(probabilityOfSuccess, numberOfSuccesses); } else { ret = BinomialCoefficientDouble.value(x + numberOfSuccesses - 1, numberOfSuccesses - 1) * @@ -123,6 +126,9 @@ public class PascalDistribution extends AbstractDiscreteDistribution { double ret; if (x < 0) { ret = Double.NEGATIVE_INFINITY; + } else if (x == 0) { + // Special case exploiting cancellation. 
+ ret = logProbabilityOfSuccess * numberOfSuccesses; } else { ret = LogBinomialCoefficient.value(x + numberOfSuccesses - 1, numberOfSuccesses - 1) + @@ -145,6 +151,20 @@ public class PascalDistribution extends AbstractDiscreteDistribution { return ret; } + /** {@inheritDoc} */ + @Override + public double survivalProbability(int x) { + double ret; + if (x < 0) { + ret = 1.0; + } else { + // Use a helper function to compute the complement of the cumulative probability + ret = RegularizedBetaUtils.complement(probabilityOfSuccess, + numberOfSuccesses, x + 1.0); + } + return ret; + } + /** * {@inheritDoc} * diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/PoissonDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/PoissonDistribution.java index a1cd928..c2649fa 100644 --- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/PoissonDistribution.java +++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/PoissonDistribution.java @@ -108,6 +108,19 @@ public class PoissonDistribution extends AbstractDiscreteDistribution { maxIterations); } + /** {@inheritDoc} */ + @Override + public double survivalProbability(int x) { + if (x < 0) { + return 1; + } + if (x == Integer.MAX_VALUE) { + return 0; + } + return RegularizedGamma.P.value((double) x + 1, mean, epsilon, + maxIterations); + } + /** * Calculates the Poisson distribution function using a normal * approximation. 
The {@code N(mean, sqrt(mean))} distribution is used diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/RegularizedBetaUtils.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/RegularizedBetaUtils.java new file mode 100644 index 0000000..eb8db88 --- /dev/null +++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/RegularizedBetaUtils.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.statistics.distribution; + +import org.apache.commons.numbers.gamma.RegularizedBeta; + +/** + * Utilities for the <a href="http://mathworld.wolfram.com/RegularizedBetaFunction.html"> + * Regularized Beta function</a> {@code I(x, a, b)}. + */ +final class RegularizedBetaUtils { + /** No instances. */ + private RegularizedBetaUtils() {} + + /** + * Compute the complement of the regularized beta function {@code I(x, a, b)}. + * <pre> + * 1 - I(x, a, b) = I(1 - x, b, a) + * </pre> + * + * @param x the value. + * @param a Parameter {@code a}. + * @param b Parameter {@code b}. + * @return the complement of the regularized beta function 1 - I(x, a, b). 
+ */ + static double complement(double x, double a, double b) { + // Identity of the regularized beta function: 1 - I_z(a, b) = I_{1-x}(b, a) + // Ideally call RegularizedBeta.value(1 - x, b, a) to maximise precision. + // + // The implementation of the beta function will use the complement based on a condition. + // Here we repeat the condition with a and b switched and testing 1 - x. + // This will avoid double inversion of the parameters. + final double mxp1 = 1 - x; + if (mxp1 > (b + 1) / (2 + b + a)) { + // Note: This drops the addition test '&& x <= (a + 1) / (2 + b + a)' + // The test is to avoid infinite method call recursion which does not apply + // in this case. See MATH-1067. + + // Direct computation of the complement with the input x. + // Avoids loss of precision when x != 1 - (1-x) + return 1.0 - RegularizedBeta.value(x, a, b); + } + // Use the identity which should be computed directly by the RegularizedBeta implementation. + return RegularizedBeta.value(mxp1, b, a); + } +} diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/BinomialDistributionTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/BinomialDistributionTest.java index 32147d1..f3fdac8 100644 --- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/BinomialDistributionTest.java +++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/BinomialDistributionTest.java @@ -78,6 +78,29 @@ class BinomialDistributionTest extends DiscreteDistributionAbstractTest { //-------------------- Additional test cases ------------------------------- + /** Test case n = 10, p = 0.3. 
*/ + @Test + void testSmallPValue() { + final BinomialDistribution dist = new BinomialDistribution(10, 0.3); + setDistribution(dist); + // computed using R version 3.4.4 + setCumulativeTestValues(new double[] {0.00000000000000000000, 0.02824752489999998728, 0.14930834590000002793, + 0.38278278639999974153, 0.64961071840000017552, 0.84973166740000016794, 0.95265101260000006889, + 0.98940792160000001765, 0.99840961360000002323, 0.99985631409999997654, 0.99999409509999992451, + 1.00000000000000000000, 1.00000000000000000000}); + setDensityTestValues(new double[] {0.0000000000000000000e+00, 2.8247524899999980341e-02, + 1.2106082099999991575e-01, 2.3347444049999999116e-01, 2.6682793199999993439e-01, 2.0012094900000007569e-01, + 1.0291934520000002584e-01, 3.6756909000000004273e-02, 9.0016919999999864960e-03, 1.4467005000000008035e-03, + 1.3778099999999990615e-04, 5.9048999999999949131e-06, 0.0000000000000000000e+00}); + setInverseCumulativeTestValues(new int[] {0, 0, 0, 0, 1, 1, 8, 7, 6, 5, 5, 10}); + verifyDensities(); + verifyLogDensities(); + verifyCumulativeProbabilities(); + verifySurvivalProbability(); + verifySurvivalAndCumulativeProbabilityComplement(); + verifyInverseCumulativeProbabilities(); + } + /** Test degenerate case p = 0 */ @Test void testDegenerate0() { @@ -90,7 +113,10 @@ class BinomialDistributionTest extends DiscreteDistributionAbstractTest { setInverseCumulativeTestPoints(new double[] {0.1d, 0.5d}); setInverseCumulativeTestValues(new int[] {0, 0}); verifyDensities(); + verifyLogDensities(); verifyCumulativeProbabilities(); + verifySurvivalProbability(); + verifySurvivalAndCumulativeProbabilityComplement(); verifyInverseCumulativeProbabilities(); Assertions.assertEquals(0, dist.getSupportLowerBound()); Assertions.assertEquals(0, dist.getSupportUpperBound()); @@ -108,7 +134,10 @@ class BinomialDistributionTest extends DiscreteDistributionAbstractTest { setInverseCumulativeTestPoints(new double[] {0.1d, 0.5d}); setInverseCumulativeTestValues(new int[] 
{5, 5}); verifyDensities(); + verifyLogDensities(); verifyCumulativeProbabilities(); + verifySurvivalProbability(); + verifySurvivalAndCumulativeProbabilityComplement(); verifyInverseCumulativeProbabilities(); Assertions.assertEquals(5, dist.getSupportLowerBound()); Assertions.assertEquals(5, dist.getSupportUpperBound()); @@ -126,7 +155,10 @@ class BinomialDistributionTest extends DiscreteDistributionAbstractTest { setInverseCumulativeTestPoints(new double[] {0.1d, 0.5d}); setInverseCumulativeTestValues(new int[] {0, 0}); verifyDensities(); + verifyLogDensities(); verifyCumulativeProbabilities(); + verifySurvivalProbability(); + verifySurvivalAndCumulativeProbabilityComplement(); verifyInverseCumulativeProbabilities(); Assertions.assertEquals(0, dist.getSupportLowerBound()); Assertions.assertEquals(0, dist.getSupportUpperBound()); @@ -184,4 +216,22 @@ class BinomialDistributionTest extends DiscreteDistributionAbstractTest { Assertions.assertEquals(trials / 2, p); } } + + @Test + void testHighPrecisionCumulativeProbabilities() { + // computed using R version 3.4.4 + setDistribution(new BinomialDistribution(100, 0.99)); + setCumulativePrecisionTestPoints(new int[] {82, 81}); + setCumulativePrecisionTestValues(new double[] {1.4061271955993513664e-17, 6.1128083336354843707e-19}); + verifyCumulativeProbabilityPrecision(); + } + + @Test + void testHighPrecisionSurvivalProbabilities() { + // computed using R version 3.4.4 + setDistribution(new BinomialDistribution(100, 0.01)); + setSurvivalPrecisionTestPoints(new int[] {18, 19}); + setSurvivalPrecisionTestValues(new double[] {6.1128083336353977038e-19, 2.4944165604029235392e-20}); + verifySurvivalProbabilityPrecision(); + } } diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionAbstractTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionAbstractTest.java index 521dcf5..0df01dd 100644 
--- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionAbstractTest.java +++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionAbstractTest.java @@ -16,6 +16,7 @@ */ package org.apache.commons.statistics.distribution; +import java.util.Arrays; import org.apache.commons.rng.simple.RandomSource; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; @@ -38,6 +39,20 @@ import org.junit.jupiter.api.Test; * makeInverseCumulativeTestPoints() -- arguments used to test inverse cdf evaluation * makeInverseCumulativeTestValues() -- expected inverse cdf values * <p> + * If the discrete distribution provides higher precision implementations of cumulativeProbability + * and/or survivalProbability, the following methods should be implemented to provide testing. + * To use these tests, calculate the cumulativeProbability and survivalProbability such that their naive + * complement is exceptionally close to `1` and consequently could lose precision due to floating point + * arithmetic. + * + * NOTE: The default high-precision threshold is 1e-22. + * <pre> + * makeCumulativePrecisionTestPoints() -- high precision test inputs + * makeCumulativePrecisionTestValues() -- high precision expected results + * makeSurvivalPrecisionTestPoints() -- high precision test inputs + * makeSurvivalPrecisionTestValues() -- high precision expected results + * </pre> + * <p> * To implement additional test cases with different distribution instances and test data, * use the setXxx methods for the instance data in test cases and call the verifyXxx methods * to verify results. @@ -51,6 +66,9 @@ abstract class DiscreteDistributionAbstractTest { /** Tolerance used in comparing expected and returned values. */ private double tolerance = 1e-12; + /** Tolerance used in high precision tests. 
*/ + private double highPrecisionTolerance = 1e-22; + /** Arguments used to test probability density calculations. */ private int[] densityTestPoints; @@ -66,6 +84,18 @@ abstract class DiscreteDistributionAbstractTest { /** Values used to test cumulative probability density calculations. */ private double[] cumulativeTestValues; + /** Arguments used to test cumulative probability precision, effectively any x where 1-cdf(x) would result in 1. */ + private int[] cumulativePrecisionTestPoints; + + /** Values used to test cumulative probability precision, usually exceptionally tiny values. */ + private double[] cumulativePrecisionTestValues; + + /** Arguments used to test survival probability precision, effectively any x where 1-sf(x) would result in 1. */ + private int[] survivalPrecisionTestPoints; + + /** Values used to test survival probability precision, usually exceptionally tiny values. */ + private double[] survivalPrecisionTestValues; + /** Arguments used to test inverse cumulative probability density calculations. */ private double[] inverseCumulativeTestPoints; @@ -91,12 +121,7 @@ abstract class DiscreteDistributionAbstractTest { * @return double[] the default logarithmic probability density test expected values. */ public double[] makeLogDensityTestValues() { - final double[] density = makeDensityTestValues(); - final double[] logDensity = new double[density.length]; - for (int i = 0; i < density.length; i++) { - logDensity[i] = Math.log(density[i]); - } - return logDensity; + return Arrays.stream(makeDensityTestValues()).map(Math::log).toArray(); } /** Creates the default cumulative probability density test input values. */ @@ -105,6 +130,34 @@ abstract class DiscreteDistributionAbstractTest { /** Creates the default cumulative probability density test expected values. */ public abstract double[] makeCumulativeTestValues(); + /** Creates the default cumulative probability precision test input values. 
*/ + public int[] makeCumulativePrecisionTestPoints() { + return new int[0]; + } + + /** + * Creates the default cumulative probability precision test expected values. + * Note: The default threshold is 1e-22, any expected values with much higher precision may + * not test the desired results without increasing precision threshold. + */ + public double[] makeCumulativePrecisionTestValues() { + return new double[0]; + } + + /** Creates the default survival probability precision test input values. */ + public int[] makeSurvivalPrecisionTestPoints() { + return new int[0]; + } + + /** + * Creates the default survival probability precision test expected values. + * Note: The default threshold is 1e-22, any expected values with much higher precision may + * not test the desired results without increasing precision threshold. + */ + public double[] makeSurvivalPrecisionTestValues() { + return new double[0]; + } + /** Creates the default inverse cumulative probability test input values. */ public abstract double[] makeInverseCumulativeTestPoints(); @@ -124,6 +177,10 @@ abstract class DiscreteDistributionAbstractTest { logDensityTestValues = makeLogDensityTestValues(); cumulativeTestPoints = makeCumulativeTestPoints(); cumulativeTestValues = makeCumulativeTestValues(); + cumulativePrecisionTestPoints = makeCumulativePrecisionTestPoints(); + cumulativePrecisionTestValues = makeCumulativePrecisionTestValues(); + survivalPrecisionTestPoints = makeSurvivalPrecisionTestPoints(); + survivalPrecisionTestValues = makeSurvivalPrecisionTestValues(); inverseCumulativeTestPoints = makeInverseCumulativeTestPoints(); inverseCumulativeTestValues = makeInverseCumulativeTestValues(); } @@ -139,6 +196,10 @@ abstract class DiscreteDistributionAbstractTest { logDensityTestValues = null; cumulativeTestPoints = null; cumulativeTestValues = null; + cumulativePrecisionTestPoints = null; + cumulativePrecisionTestValues = null; + survivalPrecisionTestPoints = null; + survivalPrecisionTestValues = 
null; inverseCumulativeTestPoints = null; inverseCumulativeTestValues = null; } @@ -164,10 +225,9 @@ abstract class DiscreteDistributionAbstractTest { */ protected void verifyLogDensities() { for (int i = 0; i < densityTestPoints.length; i++) { - // FIXME: when logProbability methods are added to DiscreteDistribution in 4.0, remove cast below final int testPoint = densityTestPoints[i]; Assertions.assertEquals(logDensityTestValues[i], - ((AbstractDiscreteDistribution) distribution).logProbability(testPoint), tolerance, + distribution.logProbability(testPoint), tolerance, () -> "Incorrect log density value returned for " + testPoint); } } @@ -185,6 +245,57 @@ abstract class DiscreteDistributionAbstractTest { } } + protected void verifySurvivalProbability() { + for (int i = 0; i < cumulativeTestPoints.length; i++) { + final int x = cumulativeTestPoints[i]; + Assertions.assertEquals( + 1 - cumulativeTestValues[i], + distribution.survivalProbability(cumulativeTestPoints[i]), + getTolerance(), + () -> "Incorrect survival probability value returned for " + x); + } + } + + protected void verifySurvivalAndCumulativeProbabilityComplement() { + for (final int x : cumulativeTestPoints) { + Assertions.assertEquals( + 1.0, + distribution.survivalProbability(x) + distribution.cumulativeProbability(x), + getTolerance(), + () -> "survival + cumulative probability were not close to 1.0 for " + x); + } + } + + /** + * Verifies that survival is simply not 1-cdf by testing calculations that would underflow that calculation and + * result in an inaccurate answer. 
+ */ + protected void verifySurvivalProbabilityPrecision() { + for (int i = 0; i < survivalPrecisionTestPoints.length; i++) { + final int x = survivalPrecisionTestPoints[i]; + Assertions.assertEquals( + survivalPrecisionTestValues[i], + distribution.survivalProbability(x), + getHighPrecisionTolerance(), + () -> "survival probability is not precise for " + x); + } + } + + /** + * Verifies that CDF is simply not 1-survival function by testing values that would result with inaccurate results + * if simply calculating 1-survival function. + */ + protected void verifyCumulativeProbabilityPrecision() { + for (int i = 0; i < cumulativePrecisionTestPoints.length; i++) { + final int x = cumulativePrecisionTestPoints[i]; + Assertions.assertEquals( + cumulativePrecisionTestValues[i], + distribution.cumulativeProbability(x), + getHighPrecisionTolerance(), + () -> "cumulative probability is not precise for " + x); + } + } + /** * Verifies that inverse cumulative probability density calculations match expected values * using current test instance data. @@ -227,6 +338,26 @@ abstract class DiscreteDistributionAbstractTest { verifyCumulativeProbabilities(); } + @Test + void testSurvivalProbability() { + verifySurvivalProbability(); + } + + @Test + void testSurvivalAndCumulativeProbabilitiesAreComplementary() { + verifySurvivalAndCumulativeProbabilityComplement(); + } + + @Test + void testCumulativeProbabilityPrecision() { + verifyCumulativeProbabilityPrecision(); + } + + @Test + void testSurvivalProbabilityPrecision() { + verifySurvivalProbabilityPrecision(); + } + /** * Verifies that inverse cumulative probability density calculations match expected values * using default test instance data. 
@@ -240,9 +371,11 @@ abstract class DiscreteDistributionAbstractTest { void testConsistencyAtSupportBounds() { final int lower = distribution.getSupportLowerBound(); Assertions.assertEquals(0.0, distribution.cumulativeProbability(lower - 1), 0.0, - "Cumulative probability mmust be 0 below support lower bound."); + "Cumulative probability must be 0 below support lower bound."); Assertions.assertEquals(distribution.probability(lower), distribution.cumulativeProbability(lower), getTolerance(), "Cumulative probability of support lower bound must be equal to probability mass at this point."); + Assertions.assertEquals(1.0, distribution.survivalProbability(lower - 1), 0.0, + "Survival probability must be 1.0 below support lower bound."); Assertions.assertEquals(lower, distribution.inverseCumulativeProbability(0.0), "Inverse cumulative probability of 0 must be equal to support lower bound."); @@ -250,6 +383,8 @@ abstract class DiscreteDistributionAbstractTest { if (upper != Integer.MAX_VALUE) { Assertions.assertEquals(1.0, distribution.cumulativeProbability(upper), 0.0, "Cumulative probability of support upper bound must be equal to 1."); + Assertions.assertEquals(0.0, distribution.survivalProbability(upper), 0.0, + "Survival probability of support upper bound must be equal to 0."); } Assertions.assertEquals(upper, distribution.inverseCumulativeProbability(1.0), "Inverse cumulative probability of 1 must be equal to support upper bound."); @@ -357,10 +492,84 @@ abstract class DiscreteDistributionAbstractTest { } /** + * Set the density test values. + * For convenience this recomputes the log density test values using {@link Math#log(double)}. + * * @param densityTestValues The densityTestValues to set. */ protected void setDensityTestValues(double[] densityTestValues) { this.densityTestValues = densityTestValues; + logDensityTestValues = Arrays.stream(densityTestValues).map(Math::log).toArray(); + } + + /** + * @return Returns the logDensityTestValues. 
+ */ + protected double[] getLogDensityTestValues() { + return logDensityTestValues; + } + + /** + * @param logDensityTestValues The logDensityTestValues to set. + */ + protected void setLogDensityTestValues(double[] logDensityTestValues) { + this.logDensityTestValues = logDensityTestValues; + } + + /** + * @return Returns the cumulativePrecisionTestPoints. + */ + protected int[] getCumulativePrecisionTestPoints() { + return cumulativePrecisionTestPoints; + } + + /** + * @param cumulativePrecisionTestPoints The cumulativePrecisionTestPoints to set. + */ + protected void setCumulativePrecisionTestPoints(int[] cumulativePrecisionTestPoints) { + this.cumulativePrecisionTestPoints = cumulativePrecisionTestPoints; + } + + /** + * @return Returns the cumulativePrecisionTestValues. + */ + protected double[] getCumulativePrecisionTestValues() { + return cumulativePrecisionTestValues; + } + + /** + * @param cumulativePrecisionTestValues The cumulativePrecisionTestValues to set. + */ + protected void setCumulativePrecisionTestValues(double[] cumulativePrecisionTestValues) { + this.cumulativePrecisionTestValues = cumulativePrecisionTestValues; + } + + /** + * @return Returns the survivalPrecisionTestPoints. + */ + protected int[] getSurvivalPrecisionTestPoints() { + return survivalPrecisionTestPoints; + } + + /** + * @param survivalPrecisionTestPoints The survivalPrecisionTestPoints to set. + */ + protected void setSurvivalPrecisionTestPoints(int[] survivalPrecisionTestPoints) { + this.survivalPrecisionTestPoints = survivalPrecisionTestPoints; + } + + /** + * @return Returns the survivalPrecisionTestValues. + */ + protected double[] getSurvivalPrecisionTestValues() { + return survivalPrecisionTestValues; + } + + /** + * @param survivalPrecisionTestValues The survivalPrecisionTestValues to set. 
*/ + protected void setSurvivalPrecisionTestValues(double[] survivalPrecisionTestValues) { + this.survivalPrecisionTestValues = survivalPrecisionTestValues; } /** @@ -420,6 +629,20 @@ abstract class DiscreteDistributionAbstractTest { } /** + * @return Returns the high precision tolerance. + */ + protected double getHighPrecisionTolerance() { + return highPrecisionTolerance; + } + + /** + * @param highPrecisionTolerance The high precision tolerance to set. + */ + protected void setHighPrecisionTolerance(double highPrecisionTolerance) { + this.highPrecisionTolerance = highPrecisionTolerance; + } + + /** * The expected value for {@link DiscreteDistribution#isSupportConnected()}. * The default is {@code true}. Test class should override this when the distribution * is not support connected. diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionTest.java index aa4960a..621f7aa 100644 --- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionTest.java +++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/DiscreteDistributionTest.java @@ -40,7 +40,11 @@ class DiscreteDistributionTest { } @Override public double cumulativeProbability(int x) { - return 0; + // Return some different values to allow the survival probability to be tested + if (x < 0) { + return x < -5 ? 0.25 : 0.5; + } + return x > 5 ?
1.0 : 0.75; } @Override public int inverseCumulativeProbability(double p) { @@ -75,6 +79,8 @@ class DiscreteDistributionTest { for (final int x : new int[] {Integer.MIN_VALUE, -1, 0, 1, 2, Integer.MAX_VALUE}) { // Return the log of the density Assertions.assertEquals(Math.log(x), dist.logProbability(x)); + // Must return 1 - CDF(x) + Assertions.assertEquals(1.0 - dist.cumulativeProbability(x), dist.survivalProbability(x)); } } } diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/GeometricDistributionTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/GeometricDistributionTest.java index 45dc8d5..6efdb67 100644 --- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/GeometricDistributionTest.java +++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/GeometricDistributionTest.java @@ -138,6 +138,17 @@ class GeometricDistributionTest extends DiscreteDistributionAbstractTest { }; } + @Override + public int[] makeSurvivalPrecisionTestPoints() { + return new int[] {74, 81}; + } + + @Override + public double[] makeSurvivalPrecisionTestValues() { + // computed using R version 3.4.4 + return new double[] {2.2979669527522718895e-17, 6.4328367688565960968e-19}; + } + //-------------------- Additional test cases ------------------------------- @Test diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/HypergeometricDistributionTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/HypergeometricDistributionTest.java index d2ce7ab..7dc251c 100644 --- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/HypergeometricDistributionTest.java +++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/HypergeometricDistributionTest.java @@ 
-101,7 +101,10 @@ class HypergeometricDistributionTest extends DiscreteDistributionAbstractTest { setInverseCumulativeTestPoints(new double[] {0.1d, 0.5d}); setInverseCumulativeTestValues(new int[] {3, 3}); verifyDensities(); + verifyLogDensities(); verifyCumulativeProbabilities(); + verifySurvivalProbability(); + verifySurvivalAndCumulativeProbabilityComplement(); verifyInverseCumulativeProbabilities(); Assertions.assertEquals(3, dist.getSupportLowerBound()); Assertions.assertEquals(3, dist.getSupportUpperBound()); @@ -119,7 +122,10 @@ class HypergeometricDistributionTest extends DiscreteDistributionAbstractTest { setInverseCumulativeTestPoints(new double[] {0.1d, 0.5d}); setInverseCumulativeTestValues(new int[] {0, 0}); verifyDensities(); + verifyLogDensities(); verifyCumulativeProbabilities(); + verifySurvivalProbability(); + verifySurvivalAndCumulativeProbabilityComplement(); verifyInverseCumulativeProbabilities(); Assertions.assertEquals(0, dist.getSupportLowerBound()); Assertions.assertEquals(0, dist.getSupportUpperBound()); @@ -137,7 +143,10 @@ class HypergeometricDistributionTest extends DiscreteDistributionAbstractTest { setInverseCumulativeTestPoints(new double[] {0.1d, 0.5d}); setInverseCumulativeTestValues(new int[] {3, 3}); verifyDensities(); + verifyLogDensities(); verifyCumulativeProbabilities(); + verifySurvivalProbability(); + verifySurvivalAndCumulativeProbabilityComplement(); verifyInverseCumulativeProbabilities(); Assertions.assertEquals(3, dist.getSupportLowerBound()); Assertions.assertEquals(3, dist.getSupportUpperBound()); @@ -321,4 +330,22 @@ class HypergeometricDistributionTest extends DiscreteDistributionAbstractTest { Assertions.assertTrue(sample <= n, () -> "sample=" + sample); } } + + @Test + void testHighPrecisionCumulativeProbabilities() { + // computed using R version 3.4.4 + setDistribution(new HypergeometricDistribution(500, 70, 300)); + setCumulativePrecisionTestPoints(new int[] {10, 8}); + setCumulativePrecisionTestValues(new
double[] {2.4055720603264525e-17, 1.2848174992266236e-19}); + verifyCumulativeProbabilityPrecision(); + } + + @Test + void testHighPrecisionSurvivalProbabilities() { + // computed using R version 3.4.4 + setDistribution(new HypergeometricDistribution(500, 70, 300)); + setSurvivalPrecisionTestPoints(new int[] {68, 69}); + setSurvivalPrecisionTestValues(new double[] {4.570379934029859e-16, 7.4187180434325268e-18}); + verifySurvivalProbabilityPrecision(); + } } diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/PascalDistributionTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/PascalDistributionTest.java index 66ecfe6..6d6467f 100644 --- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/PascalDistributionTest.java +++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/PascalDistributionTest.java @@ -78,6 +78,17 @@ class PascalDistributionTest extends DiscreteDistributionAbstractTest { return new int[] {0, 0, 0, 0, 1, 1, 14, 11, 10, 9, 8, Integer.MAX_VALUE}; } + @Override + public int[] makeSurvivalPrecisionTestPoints() { + return new int[] {47, 52}; + } + + @Override + public double[] makeSurvivalPrecisionTestValues() { + // computed using R version 3.4.4 + return new double[] {3.1403888119656772712e-17, 1.7075879020163069251e-19}; + } + //-------------------- Additional test cases ------------------------------- /** Test degenerate case p = 0 */ @@ -91,7 +102,10 @@ class PascalDistributionTest extends DiscreteDistributionAbstractTest { setInverseCumulativeTestPoints(new double[] {0.1d, 0.5d}); setInverseCumulativeTestValues(new int[] {Integer.MAX_VALUE, Integer.MAX_VALUE}); verifyDensities(); + verifyLogDensities(); verifyCumulativeProbabilities(); + verifySurvivalProbability(); + verifySurvivalAndCumulativeProbabilityComplement(); verifyInverseCumulativeProbabilities(); } @@ -106,7
+120,10 @@ class PascalDistributionTest extends DiscreteDistributionAbstractTest { setInverseCumulativeTestPoints(new double[] {0.1d, 0.5d}); setInverseCumulativeTestValues(new int[] {0, 0}); verifyDensities(); + verifyLogDensities(); verifyCumulativeProbabilities(); + verifySurvivalProbability(); + verifySurvivalAndCumulativeProbabilityComplement(); verifyInverseCumulativeProbabilities(); } diff --git a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/PoissonDistributionTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/PoissonDistributionTest.java index f30f697..b908926 100644 --- a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/PoissonDistributionTest.java +++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/PoissonDistributionTest.java @@ -95,6 +95,17 @@ class PoissonDistributionTest extends DiscreteDistributionAbstractTest { return new int[] {0, 0, 1, 1, 2, 2, 3, 3, 4, 5, 10, 20}; } + @Override + public int[] makeSurvivalPrecisionTestPoints() { + return new int[] {30, 32}; + } + + @Override + public double[] makeSurvivalPrecisionTestValues() { + // computed using R version 3.4.4 + return new double[] {1.1732435431464340474e-17, 1.7630174687875970627e-19}; + } + //-------------------- Additional test cases ------------------------------- /** @@ -221,4 +232,13 @@ class PoissonDistributionTest extends DiscreteDistributionAbstractTest { mean *= 10.0; } } + + @Test + void testLargeMeanHighPrecisionCumulativeProbabilities() { + // computed using R version 3.4.4 + setDistribution(new PoissonDistribution(100)); + setCumulativePrecisionTestPoints(new int[] {28, 25}); + setCumulativePrecisionTestValues(new double[] {1.6858675763053070496e-17, 3.184075559619425735e-19}); + verifyCumulativeProbabilityPrecision(); + } } diff --git 
a/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/RegularizedBetaUtilsTest.java b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/RegularizedBetaUtilsTest.java new file mode 100644 index 0000000..7756145 --- /dev/null +++ b/commons-statistics-distribution/src/test/java/org/apache/commons/statistics/distribution/RegularizedBetaUtilsTest.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.statistics.distribution; + +import org.apache.commons.numbers.gamma.RegularizedBeta; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +/** + * Test for {@link RegularizedBetaUtils}. 
*/ class RegularizedBetaUtilsTest { + @Test + void testComplement() { + final double[] xs = {0, 0.1, 0.2, 0.25, 0.3, 1.0 / 3, 0.4, 0.5, 0.6, 2.0 / 3, 0.7, 0.75, 0.8, 0.9, 1}; + // Called in PascalDistribution with a >= 1; b >= 1 + // Called in BinomialDistribution with a >= 1; b >= 1 + final double[] as = {1, 2, 3, 4, 5, 10, 20, 100, 1000}; + final double[] bs = {1, 2, 3, 4, 5, 10, 20, 100, 1000}; + for (final double x : xs) { + for (final double a : as) { + for (final double b : bs) { + assertComplement(x, a, b); + } + } + } + } + + private static void assertComplement(double x, double a, double b) { + final double expected1 = 1.0 - RegularizedBeta.value(x, a, b); + final double expected2 = RegularizedBeta.value(1 - x, b, a); + final double actual = RegularizedBetaUtils.complement(x, a, b); + // Actual must exactly equal one of the two expected values: 1 - I(x; a, b) or I(1 - x; b, a) + Assertions.assertTrue(expected1 == actual || expected2 == actual, + () -> String.format("I(%s, %s, %s) Expected %s or %s: Actual %s", x, a, b, expected1, expected2, actual)); + } +}