Repository: commons-math Updated Branches: refs/heads/master 9b2772e38 -> 5597ed7ea
[MATH-1153] Improve performance of BetaDistribution#sample. Thanks to Sergei Lebedev. Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/5597ed7e Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/5597ed7e Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/5597ed7e Branch: refs/heads/master Commit: 5597ed7ea300ae3d08cd893b0133bce26038a7df Parents: 9b2772e Author: Thomas Neidhart <thomas.neidh...@gmail.com> Authored: Fri May 1 11:57:54 2015 +0200 Committer: Thomas Neidhart <thomas.neidh...@gmail.com> Committed: Fri May 1 11:57:54 2015 +0200 ---------------------------------------------------------------------- pom.xml | 3 + src/changes/changes.xml | 5 +- .../math4/distribution/BetaDistribution.java | 130 ++++++++++++++++++- .../distribution/BetaDistributionTest.java | 74 +++++++++++ .../math4/random/RandomDataGeneratorTest.java | 96 +++++--------- 5 files changed, 244 insertions(+), 64 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-math/blob/5597ed7e/pom.xml ---------------------------------------------------------------------- diff --git a/pom.xml b/pom.xml index 42747fd..ae022e8 100644 --- a/pom.xml +++ b/pom.xml @@ -252,6 +252,9 @@ <name>Piotr Kochanski</name> </contributor> <contributor> + <name>Sergei Lebedev</name> + </contributor> + <contributor> <name>Bob MacCallum</name> </contributor> <contributor> http://git-wip-us.apache.org/repos/asf/commons-math/blob/5597ed7e/src/changes/changes.xml ---------------------------------------------------------------------- diff --git a/src/changes/changes.xml b/src/changes/changes.xml index cc38aa8..27705e2 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -54,12 +54,15 @@ If the output is not quite correct, check for invisible trailing spaces! 
</release> <release version="4.0" date="XXXX-XX-XX" description=""> + <action dev="tn" type="fix" issue="MATH-1153" due-to="Sergei Lebedev"> <!-- backported to 3.6 --> + Improve performance of "BetaDistribution#sample()" by using Cheng's algorithm. + </action> <action dev="tn" type="update" issue="MATH-853"> "MathRuntimeException" is now the base class for all commons-math exceptions (except for "NullArgumentException" which extends "NullPointerException"). </action> - <action dev="tn" type="fix" issue="MATH-1197"> + <action dev="tn" type="fix" issue="MATH-1197"> <!-- backported to 3.6 --> Computation of 2-sample Kolmogorov-Smirnov statistic in case of ties was not correct. </action> http://git-wip-us.apache.org/repos/asf/commons-math/blob/5597ed7e/src/main/java/org/apache/commons/math4/distribution/BetaDistribution.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math4/distribution/BetaDistribution.java b/src/main/java/org/apache/commons/math4/distribution/BetaDistribution.java index d800510..953bb36 100644 --- a/src/main/java/org/apache/commons/math4/distribution/BetaDistribution.java +++ b/src/main/java/org/apache/commons/math4/distribution/BetaDistribution.java @@ -23,6 +23,7 @@ import org.apache.commons.math4.random.Well19937c; import org.apache.commons.math4.special.Beta; import org.apache.commons.math4.special.Gamma; import org.apache.commons.math4.util.FastMath; +import org.apache.commons.math4.util.Precision; /** * Implements the Beta distribution. 
@@ -162,12 +163,14 @@ public class BetaDistribution extends AbstractRealDistribution { return Double.NEGATIVE_INFINITY; } else if (x == 0) { if (alpha < 1) { - throw new NumberIsTooSmallException(LocalizedFormats.CANNOT_COMPUTE_BETA_DENSITY_AT_0_FOR_SOME_ALPHA, alpha, 1, false); + throw new NumberIsTooSmallException(LocalizedFormats.CANNOT_COMPUTE_BETA_DENSITY_AT_0_FOR_SOME_ALPHA, + alpha, 1, false); } return Double.NEGATIVE_INFINITY; } else if (x == 1) { if (beta < 1) { - throw new NumberIsTooSmallException(LocalizedFormats.CANNOT_COMPUTE_BETA_DENSITY_AT_1_FOR_SOME_BETA, beta, 1, false); + throw new NumberIsTooSmallException(LocalizedFormats.CANNOT_COMPUTE_BETA_DENSITY_AT_1_FOR_SOME_BETA, + beta, 1, false); } return Double.NEGATIVE_INFINITY; } else { @@ -263,4 +266,127 @@ public class BetaDistribution extends AbstractRealDistribution { public boolean isSupportConnected() { return true; } + + /** {@inheritDoc} + * <p> + * Sampling is performed using Cheng algorithms: + * </p> + * <p> + * R. C. H. Cheng, "Generating beta variates with nonintegral shape parameters.". + * Communications of the ACM, 21, 317–322, 1978. + * </p> + */ + @Override + public double sample() { + return ChengBetaSampler.sample(random, alpha, beta); + } + + /** Utility class implementing Cheng's algorithms for beta distribution sampling. + * <p> + * R. C. H. Cheng, "Generating beta variates with nonintegral shape parameters.". + * Communications of the ACM, 21, 317–322, 1978. + * </p> + * @since 3.6 + */ + private static final class ChengBetaSampler { + + /** + * Returns one sample using Cheng's sampling algorithm. 
+ * @param random random generator to use + * @param alpha distribution first shape parameter + * @param beta distribution second shape parameter + * @return sampled value + */ + static double sample(RandomGenerator random, final double alpha, final double beta) { + final double a = FastMath.min(alpha, beta); + final double b = FastMath.max(alpha, beta); + + if (a > 1) { + return algorithmBB(random, alpha, a, b); + } else { + return algorithmBC(random, alpha, b, a); + } + } + + /** + * Returns one sample using Cheng's BB algorithm, when both α and β are greater than 1. + */ + private static double algorithmBB(RandomGenerator random, + final double a0, + final double a, + final double b) { + final double alpha = a + b; + final double beta = FastMath.sqrt((alpha - 2.) / (2. * a * b - alpha)); + final double gamma = a + 1. / beta; + + double r, w, t; + do { + final double u1 = random.nextDouble(); + final double u2 = random.nextDouble(); + final double v = beta * (FastMath.log(u1) - FastMath.log1p(-u1)); + w = a * FastMath.exp(v); + final double z = u1 * u1 * u2; + r = gamma * v - 1.3862944; + final double s = a + r - w; + if (s + 2.609438 >= 5 * z) { + break; + } + + t = FastMath.log(z); + if (s >= t) { + break; + } + } while (r + alpha * (FastMath.log(alpha) - FastMath.log(b + w)) < t); + + w = FastMath.min(w, Double.MAX_VALUE); + return Precision.equals(a, a0) ? w / (b + w) : b / (b + w); + } + + /** + * Returns one sample using Cheng's BC algorithm, when at least one of α and β is smaller than 1. + */ + private static double algorithmBC(RandomGenerator random, + final double a0, + final double a, + final double b) { + final double alpha = a + b; + final double beta = 1. / b; + final double delta = 1. 
+ a - b; + final double k1 = delta * (0.0138889 + 0.0416667 * b) / (a * beta - 0.777778); + final double k2 = 0.25 + (0.5 + 0.25 / delta) * b; + + double w; + for (;;) { + final double u1 = random.nextDouble(); + final double u2 = random.nextDouble(); + final double y = u1 * u2; + final double z = u1 * y; + if (u1 < 0.5) { + if (0.25 * u2 + z - y >= k1) { + continue; + } + } else { + if (z <= 0.25) { + final double v = beta * (FastMath.log(u1) - FastMath.log1p(-u1)); + w = a * FastMath.exp(v); + break; + } + + if (z >= k2) { + continue; + } + } + + final double v = beta * (FastMath.log(u1) - FastMath.log1p(-u1)); + w = a * FastMath.exp(v); + if (alpha * (FastMath.log(alpha) - FastMath.log(b + w) + v) - 1.3862944 >= FastMath.log(z)) { + break; + } + } + + w = FastMath.min(w, Double.MAX_VALUE); + return Precision.equals(a, a0) ? w / (b + w) : b / (b + w); + } + + } } http://git-wip-us.apache.org/repos/asf/commons-math/blob/5597ed7e/src/test/java/org/apache/commons/math4/distribution/BetaDistributionTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math4/distribution/BetaDistributionTest.java b/src/test/java/org/apache/commons/math4/distribution/BetaDistributionTest.java index 9c64cf9..632e99f 100644 --- a/src/test/java/org/apache/commons/math4/distribution/BetaDistributionTest.java +++ b/src/test/java/org/apache/commons/math4/distribution/BetaDistributionTest.java @@ -16,11 +16,23 @@ */ package org.apache.commons.math4.distribution; +import java.util.Arrays; + import org.apache.commons.math4.distribution.BetaDistribution; +import org.apache.commons.math4.random.RandomGenerator; +import org.apache.commons.math4.random.Well1024a; +import org.apache.commons.math4.random.Well19937a; +import org.apache.commons.math4.stat.StatUtils; +import org.apache.commons.math4.stat.inference.KolmogorovSmirnovTest; +import org.apache.commons.math4.stat.inference.TestUtils; import org.junit.Assert; import 
org.junit.Test; public class BetaDistributionTest { + + static final double[] alphaBetas = {0.1, 1, 10, 100, 1000}; + static final double epsilon = StatUtils.min(alphaBetas); + @Test public void testCumulative() { double[] x = new double[]{-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}; @@ -304,4 +316,66 @@ public class BetaDistributionTest { Assert.assertEquals(dist.getNumericalMean(), 2.0 / 7.0, tol); Assert.assertEquals(dist.getNumericalVariance(), 10.0 / (49.0 * 8.0), tol); } + + @Test + public void testMomentsSampling() { + RandomGenerator random = new Well1024a(0x7829862c82fec2dal); + final int numSamples = 1000; + for (final double alpha : alphaBetas) { + for (final double beta : alphaBetas) { + final BetaDistribution betaDistribution = new BetaDistribution(random, alpha, beta); + final double[] observed = new BetaDistribution(alpha, beta).sample(numSamples); + Arrays.sort(observed); + + final String distribution = String.format("Beta(%.2f, %.2f)", alpha, beta); + Assert.assertEquals(String.format("E[%s]", distribution), + betaDistribution.getNumericalMean(), + StatUtils.mean(observed), epsilon); + Assert.assertEquals(String.format("Var[%s]", distribution), + betaDistribution.getNumericalVariance(), + StatUtils.variance(observed), epsilon); + } + } + } + + @Test + public void testGoodnessOfFit() { + RandomGenerator random = new Well19937a(0x237db1db907b089fl); + final int numSamples = 1000; + final double level = 0.01; + for (final double alpha : alphaBetas) { + for (final double beta : alphaBetas) { + final BetaDistribution betaDistribution = new BetaDistribution(random, alpha, beta); + final double[] observed = betaDistribution.sample(numSamples); + Assert.assertFalse("G goodness-of-fit test rejected null at alpha = " + level, + gTest(betaDistribution, observed) < level); + Assert.assertFalse("KS goodness-of-fit test rejected null at alpha = " + level, + new KolmogorovSmirnovTest(random).kolmogorovSmirnovTest(betaDistribution, observed) 
< level); + } + } + } + + private double gTest(final RealDistribution expectedDistribution, final double[] values) { + final int numBins = values.length / 30; + final double[] breaks = new double[numBins]; + for (int b = 0; b < breaks.length; b++) { + breaks[b] = expectedDistribution.inverseCumulativeProbability((double) b / numBins); + } + + final long[] observed = new long[numBins]; + for (final double value : values) { + int b = 0; + do { + b++; + } while (b < numBins && value >= breaks[b]); + + observed[b - 1]++; + } + + final double[] expected = new double[numBins]; + Arrays.fill(expected, (double) values.length / numBins); + + return TestUtils.gTest(expected, observed); + } + } http://git-wip-us.apache.org/repos/asf/commons-math/blob/5597ed7e/src/test/java/org/apache/commons/math4/random/RandomDataGeneratorTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math4/random/RandomDataGeneratorTest.java b/src/test/java/org/apache/commons/math4/random/RandomDataGeneratorTest.java index c8f8813..2982efb 100644 --- a/src/test/java/org/apache/commons/math4/random/RandomDataGeneratorTest.java +++ b/src/test/java/org/apache/commons/math4/random/RandomDataGeneratorTest.java @@ -44,8 +44,6 @@ import org.apache.commons.math4.distribution.ZipfDistribution; import org.apache.commons.math4.distribution.ZipfDistributionTest; import org.apache.commons.math4.exception.MathIllegalArgumentException; import org.apache.commons.math4.random.RandomDataGenerator; -import org.apache.commons.math4.random.RandomGenerator; -import org.apache.commons.math4.random.Well19937c; import org.apache.commons.math4.stat.Frequency; import org.apache.commons.math4.stat.inference.ChiSquareTest; import org.apache.commons.math4.util.FastMath; @@ -86,7 +84,7 @@ public class RandomDataGeneratorTest { long y = randomData.nextLong(Long.MIN_VALUE, Long.MAX_VALUE); Assert.assertFalse(x == y); } - + @Test public void 
testNextUniformExtremeValues() { double x = randomData.nextUniform(-Double.MAX_VALUE, Double.MAX_VALUE); @@ -97,7 +95,7 @@ public class RandomDataGeneratorTest { Assert.assertFalse(Double.isInfinite(x)); Assert.assertFalse(Double.isInfinite(y)); } - + @Test public void testNextIntIAE() { try { @@ -107,7 +105,7 @@ public class RandomDataGeneratorTest { // ignored } } - + @Test public void testNextIntNegativeToPositiveRange() { for (int i = 0; i < 5; i++) { @@ -116,7 +114,7 @@ public class RandomDataGeneratorTest { } } - @Test + @Test public void testNextIntNegativeRange() { for (int i = 0; i < 5; i++) { checkNextIntUniform(-7, -4); @@ -125,7 +123,7 @@ public class RandomDataGeneratorTest { } } - @Test + @Test public void testNextIntPositiveRange() { for (int i = 0; i < 5; i++) { checkNextIntUniform(0, 3); @@ -151,7 +149,7 @@ public class RandomDataGeneratorTest { for (int i = 0; i < len; i++) { expected[i] = 1d / len; } - + TestUtils.assertChiSquareAccept(expected, observed, 0.001); } @@ -172,7 +170,7 @@ public class RandomDataGeneratorTest { (((double) upper) - ((double) lower)); Assert.assertTrue(ratio > 0.99999); } - + @Test public void testNextLongIAE() { try { @@ -191,7 +189,7 @@ public class RandomDataGeneratorTest { } } - @Test + @Test public void testNextLongNegativeRange() { for (int i = 0; i < 5; i++) { checkNextLongUniform(-7, -4); @@ -200,7 +198,7 @@ public class RandomDataGeneratorTest { } } - @Test + @Test public void testNextLongPositiveRange() { for (int i = 0; i < 5; i++) { checkNextLongUniform(0, 3); @@ -226,7 +224,7 @@ public class RandomDataGeneratorTest { for (int i = 0; i < len; i++) { expected[i] = 1d / len; } - + TestUtils.assertChiSquareAccept(expected, observed, 0.01); } @@ -247,7 +245,7 @@ public class RandomDataGeneratorTest { (((double) upper) - ((double) lower)); Assert.assertTrue(ratio > 0.99999); } - + @Test public void testNextSecureLongIAE() { try { @@ -257,7 +255,7 @@ public class RandomDataGeneratorTest { // ignored } } - + @Test 
@Retry(3) public void testNextSecureLongNegativeToPositiveRange() { @@ -266,7 +264,7 @@ public class RandomDataGeneratorTest { checkNextSecureLongUniform(-3, 6); } } - + @Test @Retry(3) public void testNextSecureLongNegativeRange() { @@ -275,7 +273,7 @@ public class RandomDataGeneratorTest { checkNextSecureLongUniform(-15, -2); } } - + @Test @Retry(3) public void testNextSecureLongPositiveRange() { @@ -284,7 +282,7 @@ public class RandomDataGeneratorTest { checkNextSecureLongUniform(2, 12); } } - + private void checkNextSecureLongUniform(int min, int max) { final Frequency freq = new Frequency(); for (int i = 0; i < smallSampleSize; i++) { @@ -301,7 +299,7 @@ public class RandomDataGeneratorTest { for (int i = 0; i < len; i++) { expected[i] = 1d / len; } - + TestUtils.assertChiSquareAccept(expected, observed, 0.0001); } @@ -314,7 +312,7 @@ public class RandomDataGeneratorTest { // ignored } } - + @Test @Retry(3) public void testNextSecureIntNegativeToPositiveRange() { @@ -323,7 +321,7 @@ public class RandomDataGeneratorTest { checkNextSecureIntUniform(-3, 6); } } - + @Test @Retry(3) public void testNextSecureIntNegativeRange() { @@ -332,8 +330,8 @@ public class RandomDataGeneratorTest { checkNextSecureIntUniform(-15, -2); } } - - @Test + + @Test @Retry(3) public void testNextSecureIntPositiveRange() { for (int i = 0; i < 5; i++) { @@ -341,7 +339,7 @@ public class RandomDataGeneratorTest { checkNextSecureIntUniform(2, 12); } } - + private void checkNextSecureIntUniform(int min, int max) { final Frequency freq = new Frequency(); for (int i = 0; i < smallSampleSize; i++) { @@ -358,11 +356,11 @@ public class RandomDataGeneratorTest { for (int i = 0; i < len; i++) { expected[i] = 1d / len; } - + TestUtils.assertChiSquareAccept(expected, observed, 0.0001); } - - + + /** * Make sure that empirical distribution of random Poisson(4)'s has P(X <= @@ -389,7 +387,7 @@ public class RandomDataGeneratorTest { } catch (MathIllegalArgumentException ex) { // ignored } - + final 
double mean = 4.0d; final int len = 5; PoissonDistribution poissonDistribution = new PoissonDistribution(mean); @@ -406,7 +404,7 @@ public class RandomDataGeneratorTest { for (int i = 0; i < len; i++) { expected[i] = poissonDistribution.probability(i + 1) * largeSampleSize; } - + TestUtils.assertChiSquareAccept(expected, observed, 0.0001); } @@ -686,35 +684,35 @@ public class RandomDataGeneratorTest { // ignored } } - + @Test public void testNextUniformUniformPositiveBounds() { for (int i = 0; i < 5; i++) { checkNextUniformUniform(0, 10); } } - + @Test public void testNextUniformUniformNegativeToPositiveBounds() { for (int i = 0; i < 5; i++) { checkNextUniformUniform(-3, 5); } } - + @Test public void testNextUniformUniformNegaiveBounds() { for (int i = 0; i < 5; i++) { checkNextUniformUniform(-7, -3); } } - + @Test public void testNextUniformUniformMaximalInterval() { for (int i = 0; i < 5; i++) { checkNextUniformUniform(-Double.MAX_VALUE, Double.MAX_VALUE); } } - + private void checkNextUniformUniform(double min, double max) { // Set up bin bounds - min, binBound[0], ..., binBound[binCount-2], max final int binCount = 5; @@ -724,7 +722,7 @@ public class RandomDataGeneratorTest { for (int i = 1; i < binCount - 1; i++) { binBounds[i] = binBounds[i - 1] + binSize; // + instead of * to avoid overflow in extreme case } - + final Frequency freq = new Frequency(); for (int i = 0; i < smallSampleSize; i++) { final double value = randomData.nextUniform(min, max); @@ -736,7 +734,7 @@ public class RandomDataGeneratorTest { } freq.addValue(j); } - + final long[] observed = new long[binCount]; for (int i = 0; i < binCount; i++) { observed[i] = freq.getCount(i); @@ -745,7 +743,7 @@ public class RandomDataGeneratorTest { for (int i = 0; i < binCount; i++) { expected[i] = 1d / binCount; } - + TestUtils.assertChiSquareAccept(expected, observed, 0.01); } @@ -954,7 +952,7 @@ public class RandomDataGeneratorTest { int[] perm = randomData.nextPermutation(3, 3); observed[findPerm(p, 
perm)]++; } - + String[] labels = {"{0, 1, 2}", "{ 0, 2, 1 }", "{ 1, 0, 2 }", "{ 1, 2, 0 }", "{ 2, 0, 1 }", "{ 2, 1, 0 }"}; TestUtils.assertChiSquareAccept(labels, expected, observed, 0.001); @@ -1013,30 +1011,6 @@ public class RandomDataGeneratorTest { } @Test - public void testNextInversionDeviate() { - // Set the seed for the default random generator - RandomGenerator rg = new Well19937c(100); - RandomDataGenerator rdg = new RandomDataGenerator(rg); - double[] quantiles = new double[10]; - for (int i = 0; i < 10; i++) { - quantiles[i] = rdg.nextUniform(0, 1); - } - // Reseed again so the inversion generator gets the same sequence - rg.setSeed(100); - BetaDistribution betaDistribution = new BetaDistribution(rg, 2, 4, - BetaDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY); - /* - * Generate a sequence of deviates using inversion - the distribution function - * evaluated at the random value from the distribution should match the uniform - * random value used to generate it, which is stored in the quantiles[] array. - */ - for (int i = 0; i < 10; i++) { - double value = betaDistribution.sample(); - Assert.assertEquals(betaDistribution.cumulativeProbability(value), quantiles[i], 10E-9); - } - } - - @Test public void testNextBeta() { double[] quartiles = TestUtils.getDistributionQuartiles(new BetaDistribution(2,5)); long[] counts = new long[4];