Repository: commons-math Updated Branches: refs/heads/MATH_3_X 2011e11e5 -> f5d028ca6
[MATH-1153] Improve performance of BetaDistribution.sample(). Thanks to Sergei Lebedev. Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/f5d028ca Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/f5d028ca Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/f5d028ca Branch: refs/heads/MATH_3_X Commit: f5d028ca6af5591ca51785da7c15d7bd81d4215f Parents: 2011e11 Author: Thomas Neidhart <thomas.neidh...@gmail.com> Authored: Fri May 1 12:07:52 2015 +0200 Committer: Thomas Neidhart <thomas.neidh...@gmail.com> Committed: Fri May 1 12:07:52 2015 +0200 ---------------------------------------------------------------------- pom.xml | 3 + src/changes/changes.xml | 3 + .../math3/distribution/BetaDistribution.java | 134 +++++++++++++++++++ .../distribution/BetaDistributionTest.java | 73 ++++++++++ .../math3/random/RandomDataGeneratorTest.java | 94 +++++-------- 5 files changed, 248 insertions(+), 59 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-math/blob/f5d028ca/pom.xml ---------------------------------------------------------------------- diff --git a/pom.xml b/pom.xml index 223b316..7c55d33 100644 --- a/pom.xml +++ b/pom.xml @@ -252,6 +252,9 @@ <name>Piotr Kochanski</name> </contributor> <contributor> + <name>Sergei Lebedev</name> + </contributor> + <contributor> <name>Bob MacCallum</name> </contributor> <contributor> http://git-wip-us.apache.org/repos/asf/commons-math/blob/f5d028ca/src/changes/changes.xml ---------------------------------------------------------------------- diff --git a/src/changes/changes.xml b/src/changes/changes.xml index 0759e8e..2e818b2 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -51,6 +51,9 @@ If the output is not quite correct, check for invisible trailing spaces! 
</properties> <body> <release version="3.6" date="XXXX-XX-XX" description=""> + <action dev="tn" type="fix" issue="MATH-1153" due-to="Sergei Lebedev"> + Improve performance of "BetaDistribution#sample()" by using Cheng's algorithm. + </action> <action dev="tn" type="fix" issue="MATH-1197"> Computation of 2-sample Kolmogorov-Smirnov statistic in case of ties was not correct. http://git-wip-us.apache.org/repos/asf/commons-math/blob/f5d028ca/src/main/java/org/apache/commons/math3/distribution/BetaDistribution.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math3/distribution/BetaDistribution.java b/src/main/java/org/apache/commons/math3/distribution/BetaDistribution.java index 3f62f64..19b19e0 100644 --- a/src/main/java/org/apache/commons/math3/distribution/BetaDistribution.java +++ b/src/main/java/org/apache/commons/math3/distribution/BetaDistribution.java @@ -23,6 +23,7 @@ import org.apache.commons.math3.random.Well19937c; import org.apache.commons.math3.special.Beta; import org.apache.commons.math3.special.Gamma; import org.apache.commons.math3.util.FastMath; +import org.apache.commons.math3.util.Precision; /** * Implements the Beta distribution. @@ -148,6 +149,7 @@ public class BetaDistribution extends AbstractRealDistribution { } /** {@inheritDoc} */ + @Override public double density(double x) { final double logDensity = logDensity(x); return logDensity == Double.NEGATIVE_INFINITY ? 0 : FastMath.exp(logDensity); @@ -177,6 +179,7 @@ public class BetaDistribution extends AbstractRealDistribution { } /** {@inheritDoc} */ + @Override public double cumulativeProbability(double x) { if (x <= 0) { return 0; @@ -205,6 +208,7 @@ public class BetaDistribution extends AbstractRealDistribution { * For first shape parameter {@code alpha} and second shape parameter * {@code beta}, the mean is {@code alpha / (alpha + beta)}. 
*/ + @Override public double getNumericalMean() { final double a = getAlpha(); return a / (a + getBeta()); @@ -217,6 +221,7 @@ public class BetaDistribution extends AbstractRealDistribution { * {@code beta}, the variance is * {@code (alpha * beta) / [(alpha + beta)^2 * (alpha + beta + 1)]}. */ + @Override public double getNumericalVariance() { final double a = getAlpha(); final double b = getBeta(); @@ -231,6 +236,7 @@ public class BetaDistribution extends AbstractRealDistribution { * * @return lower bound of the support (always 0) */ + @Override public double getSupportLowerBound() { return 0; } @@ -242,16 +248,19 @@ public class BetaDistribution extends AbstractRealDistribution { * * @return upper bound of the support (always 1) */ + @Override public double getSupportUpperBound() { return 1; } /** {@inheritDoc} */ + @Override public boolean isSupportLowerBoundInclusive() { return false; } /** {@inheritDoc} */ + @Override public boolean isSupportUpperBoundInclusive() { return false; } @@ -263,7 +272,132 @@ public class BetaDistribution extends AbstractRealDistribution { * * @return {@code true} */ + @Override public boolean isSupportConnected() { return true; } + + + /** {@inheritDoc} + * <p> + * Sampling is performed using Cheng algorithms: + * </p> + * <p> + * R. C. H. Cheng, "Generating beta variates with nonintegral shape parameters.". + * Communications of the ACM, 21, 317–322, 1978. + * </p> + */ + @Override + public double sample() { + return ChengBetaSampler.sample(random, alpha, beta); + } + + /** Utility class implementing Cheng's algorithms for beta distribution sampling. + * <p> + * R. C. H. Cheng, "Generating beta variates with nonintegral shape parameters.". + * Communications of the ACM, 21, 317–322, 1978. + * </p> + * @since 3.6 + */ + private static final class ChengBetaSampler { + + /** + * Returns one sample using Cheng's sampling algorithm. 
+ * @param random random generator to use + * @param alpha distribution first shape parameter + * @param beta distribution second shape parameter + * @return sampled value + */ + static double sample(RandomGenerator random, final double alpha, final double beta) { + final double a = FastMath.min(alpha, beta); + final double b = FastMath.max(alpha, beta); + + if (a > 1) { + return algorithmBB(random, alpha, a, b); + } else { + return algorithmBC(random, alpha, b, a); + } + } + + /** + * Returns one sample using Cheng's BB algorithm, when both α and β are greater than 1. + */ + private static double algorithmBB(RandomGenerator random, + final double a0, + final double a, + final double b) { + final double alpha = a + b; + final double beta = FastMath.sqrt((alpha - 2.) / (2. * a * b - alpha)); + final double gamma = a + 1. / beta; + + double r, w, t; + do { + final double u1 = random.nextDouble(); + final double u2 = random.nextDouble(); + final double v = beta * (FastMath.log(u1) - FastMath.log1p(-u1)); + w = a * FastMath.exp(v); + final double z = u1 * u1 * u2; + r = gamma * v - 1.3862944; + final double s = a + r - w; + if (s + 2.609438 >= 5 * z) { + break; + } + + t = FastMath.log(z); + if (s >= t) { + break; + } + } while (r + alpha * (FastMath.log(alpha) - FastMath.log(b + w)) < t); + + w = FastMath.min(w, Double.MAX_VALUE); + return Precision.equals(a, a0) ? w / (b + w) : b / (b + w); + } + + /** + * Returns one sample using Cheng's BC algorithm, when at least one of α and β is smaller than 1. + */ + private static double algorithmBC(RandomGenerator random, + final double a0, + final double a, + final double b) { + final double alpha = a + b; + final double beta = 1. / b; + final double delta = 1. 
+ a - b; + final double k1 = delta * (0.0138889 + 0.0416667 * b) / (a * beta - 0.777778); + final double k2 = 0.25 + (0.5 + 0.25 / delta) * b; + + double w; + for (;;) { + final double u1 = random.nextDouble(); + final double u2 = random.nextDouble(); + final double y = u1 * u2; + final double z = u1 * y; + if (u1 < 0.5) { + if (0.25 * u2 + z - y >= k1) { + continue; + } + } else { + if (z <= 0.25) { + final double v = beta * (FastMath.log(u1) - FastMath.log1p(-u1)); + w = a * FastMath.exp(v); + break; + } + + if (z >= k2) { + continue; + } + } + + final double v = beta * (FastMath.log(u1) - FastMath.log1p(-u1)); + w = a * FastMath.exp(v); + if (alpha * (FastMath.log(alpha) - FastMath.log(b + w) + v) - 1.3862944 >= FastMath.log(z)) { + break; + } + } + + w = FastMath.min(w, Double.MAX_VALUE); + return Precision.equals(a, a0) ? w / (b + w) : b / (b + w); + } + + } } http://git-wip-us.apache.org/repos/asf/commons-math/blob/f5d028ca/src/test/java/org/apache/commons/math3/distribution/BetaDistributionTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math3/distribution/BetaDistributionTest.java b/src/test/java/org/apache/commons/math3/distribution/BetaDistributionTest.java index 217ae66..3778bfe 100644 --- a/src/test/java/org/apache/commons/math3/distribution/BetaDistributionTest.java +++ b/src/test/java/org/apache/commons/math3/distribution/BetaDistributionTest.java @@ -16,10 +16,22 @@ */ package org.apache.commons.math3.distribution; +import java.util.Arrays; + +import org.apache.commons.math3.random.RandomGenerator; +import org.apache.commons.math3.random.Well1024a; +import org.apache.commons.math3.random.Well19937a; +import org.apache.commons.math3.stat.StatUtils; +import org.apache.commons.math3.stat.inference.KolmogorovSmirnovTest; +import org.apache.commons.math3.stat.inference.TestUtils; import org.junit.Assert; import org.junit.Test; public class BetaDistributionTest { + + static final 
double[] alphaBetas = {0.1, 1, 10, 100, 1000}; + static final double epsilon = StatUtils.min(alphaBetas); + @Test public void testCumulative() { double[] x = new double[]{-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1}; @@ -303,4 +315,65 @@ public class BetaDistributionTest { Assert.assertEquals(dist.getNumericalMean(), 2.0 / 7.0, tol); Assert.assertEquals(dist.getNumericalVariance(), 10.0 / (49.0 * 8.0), tol); } + + @Test + public void testMomentsSampling() { + RandomGenerator random = new Well1024a(0x7829862c82fec2dal); + final int numSamples = 1000; + for (final double alpha : alphaBetas) { + for (final double beta : alphaBetas) { + final BetaDistribution betaDistribution = new BetaDistribution(random, alpha, beta); + final double[] observed = new BetaDistribution(alpha, beta).sample(numSamples); + Arrays.sort(observed); + + final String distribution = String.format("Beta(%.2f, %.2f)", alpha, beta); + Assert.assertEquals(String.format("E[%s]", distribution), + betaDistribution.getNumericalMean(), + StatUtils.mean(observed), epsilon); + Assert.assertEquals(String.format("Var[%s]", distribution), + betaDistribution.getNumericalVariance(), + StatUtils.variance(observed), epsilon); + } + } + } + + @Test + public void testGoodnessOfFit() { + RandomGenerator random = new Well19937a(0x237db1db907b089fl); + final int numSamples = 1000; + final double level = 0.01; + for (final double alpha : alphaBetas) { + for (final double beta : alphaBetas) { + final BetaDistribution betaDistribution = new BetaDistribution(random, alpha, beta); + final double[] observed = betaDistribution.sample(numSamples); + Assert.assertFalse("G goodness-of-fit test rejected null at alpha = " + level, + gTest(betaDistribution, observed) < level); + Assert.assertFalse("KS goodness-of-fit test rejected null at alpha = " + level, + new KolmogorovSmirnovTest(random).kolmogorovSmirnovTest(betaDistribution, observed) < level); + } + } + } + + private double gTest(final RealDistribution 
expectedDistribution, final double[] values) { + final int numBins = values.length / 30; + final double[] breaks = new double[numBins]; + for (int b = 0; b < breaks.length; b++) { + breaks[b] = expectedDistribution.inverseCumulativeProbability((double) b / numBins); + } + + final long[] observed = new long[numBins]; + for (final double value : values) { + int b = 0; + do { + b++; + } while (b < numBins && value >= breaks[b]); + + observed[b - 1]++; + } + + final double[] expected = new double[numBins]; + Arrays.fill(expected, (double) values.length / numBins); + + return TestUtils.gTest(expected, observed); + } } http://git-wip-us.apache.org/repos/asf/commons-math/blob/f5d028ca/src/test/java/org/apache/commons/math3/random/RandomDataGeneratorTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math3/random/RandomDataGeneratorTest.java b/src/test/java/org/apache/commons/math3/random/RandomDataGeneratorTest.java index a0b6e26..839b1e6 100644 --- a/src/test/java/org/apache/commons/math3/random/RandomDataGeneratorTest.java +++ b/src/test/java/org/apache/commons/math3/random/RandomDataGeneratorTest.java @@ -83,7 +83,7 @@ public class RandomDataGeneratorTest { long y = randomData.nextLong(Long.MIN_VALUE, Long.MAX_VALUE); Assert.assertFalse(x == y); } - + @Test public void testNextUniformExtremeValues() { double x = randomData.nextUniform(-Double.MAX_VALUE, Double.MAX_VALUE); @@ -94,7 +94,7 @@ public class RandomDataGeneratorTest { Assert.assertFalse(Double.isInfinite(x)); Assert.assertFalse(Double.isInfinite(y)); } - + @Test public void testNextIntIAE() { try { @@ -104,7 +104,7 @@ public class RandomDataGeneratorTest { // ignored } } - + @Test public void testNextIntNegativeToPositiveRange() { for (int i = 0; i < 5; i++) { @@ -113,7 +113,7 @@ public class RandomDataGeneratorTest { } } - @Test + @Test public void testNextIntNegativeRange() { for (int i = 0; i < 5; i++) { checkNextIntUniform(-7, -4); 
@@ -122,7 +122,7 @@ public class RandomDataGeneratorTest { } } - @Test + @Test public void testNextIntPositiveRange() { for (int i = 0; i < 5; i++) { checkNextIntUniform(0, 3); @@ -148,7 +148,7 @@ public class RandomDataGeneratorTest { for (int i = 0; i < len; i++) { expected[i] = 1d / len; } - + TestUtils.assertChiSquareAccept(expected, observed, 0.001); } @@ -169,7 +169,7 @@ public class RandomDataGeneratorTest { (((double) upper) - ((double) lower)); Assert.assertTrue(ratio > 0.99999); } - + @Test public void testNextLongIAE() { try { @@ -188,7 +188,7 @@ public class RandomDataGeneratorTest { } } - @Test + @Test public void testNextLongNegativeRange() { for (int i = 0; i < 5; i++) { checkNextLongUniform(-7, -4); @@ -197,7 +197,7 @@ public class RandomDataGeneratorTest { } } - @Test + @Test public void testNextLongPositiveRange() { for (int i = 0; i < 5; i++) { checkNextLongUniform(0, 3); @@ -223,7 +223,7 @@ public class RandomDataGeneratorTest { for (int i = 0; i < len; i++) { expected[i] = 1d / len; } - + TestUtils.assertChiSquareAccept(expected, observed, 0.01); } @@ -244,7 +244,7 @@ public class RandomDataGeneratorTest { (((double) upper) - ((double) lower)); Assert.assertTrue(ratio > 0.99999); } - + @Test public void testNextSecureLongIAE() { try { @@ -254,7 +254,7 @@ public class RandomDataGeneratorTest { // ignored } } - + @Test @Retry(3) public void testNextSecureLongNegativeToPositiveRange() { @@ -263,7 +263,7 @@ public class RandomDataGeneratorTest { checkNextSecureLongUniform(-3, 6); } } - + @Test @Retry(3) public void testNextSecureLongNegativeRange() { @@ -272,7 +272,7 @@ public class RandomDataGeneratorTest { checkNextSecureLongUniform(-15, -2); } } - + @Test @Retry(3) public void testNextSecureLongPositiveRange() { @@ -281,7 +281,7 @@ public class RandomDataGeneratorTest { checkNextSecureLongUniform(2, 12); } } - + private void checkNextSecureLongUniform(int min, int max) { final Frequency freq = new Frequency(); for (int i = 0; i < 
smallSampleSize; i++) { @@ -298,7 +298,7 @@ public class RandomDataGeneratorTest { for (int i = 0; i < len; i++) { expected[i] = 1d / len; } - + TestUtils.assertChiSquareAccept(expected, observed, 0.0001); } @@ -311,7 +311,7 @@ public class RandomDataGeneratorTest { // ignored } } - + @Test @Retry(3) public void testNextSecureIntNegativeToPositiveRange() { @@ -320,7 +320,7 @@ public class RandomDataGeneratorTest { checkNextSecureIntUniform(-3, 6); } } - + @Test @Retry(3) public void testNextSecureIntNegativeRange() { @@ -329,8 +329,8 @@ public class RandomDataGeneratorTest { checkNextSecureIntUniform(-15, -2); } } - - @Test + + @Test @Retry(3) public void testNextSecureIntPositiveRange() { for (int i = 0; i < 5; i++) { @@ -338,7 +338,7 @@ public class RandomDataGeneratorTest { checkNextSecureIntUniform(2, 12); } } - + private void checkNextSecureIntUniform(int min, int max) { final Frequency freq = new Frequency(); for (int i = 0; i < smallSampleSize; i++) { @@ -355,11 +355,11 @@ public class RandomDataGeneratorTest { for (int i = 0; i < len; i++) { expected[i] = 1d / len; } - + TestUtils.assertChiSquareAccept(expected, observed, 0.0001); } - - + + /** * Make sure that empirical distribution of random Poisson(4)'s has P(X <= @@ -386,7 +386,7 @@ public class RandomDataGeneratorTest { } catch (MathIllegalArgumentException ex) { // ignored } - + final double mean = 4.0d; final int len = 5; PoissonDistribution poissonDistribution = new PoissonDistribution(mean); @@ -403,7 +403,7 @@ public class RandomDataGeneratorTest { for (int i = 0; i < len; i++) { expected[i] = poissonDistribution.probability(i + 1) * largeSampleSize; } - + TestUtils.assertChiSquareAccept(expected, observed, 0.0001); } @@ -683,35 +683,35 @@ public class RandomDataGeneratorTest { // ignored } } - + @Test public void testNextUniformUniformPositiveBounds() { for (int i = 0; i < 5; i++) { checkNextUniformUniform(0, 10); } } - + @Test public void testNextUniformUniformNegativeToPositiveBounds() { for 
(int i = 0; i < 5; i++) { checkNextUniformUniform(-3, 5); } } - + @Test public void testNextUniformUniformNegaiveBounds() { for (int i = 0; i < 5; i++) { checkNextUniformUniform(-7, -3); } } - + @Test public void testNextUniformUniformMaximalInterval() { for (int i = 0; i < 5; i++) { checkNextUniformUniform(-Double.MAX_VALUE, Double.MAX_VALUE); } } - + private void checkNextUniformUniform(double min, double max) { // Set up bin bounds - min, binBound[0], ..., binBound[binCount-2], max final int binCount = 5; @@ -721,7 +721,7 @@ public class RandomDataGeneratorTest { for (int i = 1; i < binCount - 1; i++) { binBounds[i] = binBounds[i - 1] + binSize; // + instead of * to avoid overflow in extreme case } - + final Frequency freq = new Frequency(); for (int i = 0; i < smallSampleSize; i++) { final double value = randomData.nextUniform(min, max); @@ -733,7 +733,7 @@ public class RandomDataGeneratorTest { } freq.addValue(j); } - + final long[] observed = new long[binCount]; for (int i = 0; i < binCount; i++) { observed[i] = freq.getCount(i); @@ -742,7 +742,7 @@ public class RandomDataGeneratorTest { for (int i = 0; i < binCount; i++) { expected[i] = 1d / binCount; } - + TestUtils.assertChiSquareAccept(expected, observed, 0.01); } @@ -951,7 +951,7 @@ public class RandomDataGeneratorTest { int[] perm = randomData.nextPermutation(3, 3); observed[findPerm(p, perm)]++; } - + String[] labels = {"{0, 1, 2}", "{ 0, 2, 1 }", "{ 1, 0, 2 }", "{ 1, 2, 0 }", "{ 2, 0, 1 }", "{ 2, 1, 0 }"}; TestUtils.assertChiSquareAccept(labels, expected, observed, 0.001); @@ -1010,30 +1010,6 @@ public class RandomDataGeneratorTest { } @Test - public void testNextInversionDeviate() { - // Set the seed for the default random generator - RandomGenerator rg = new Well19937c(100); - RandomDataGenerator rdg = new RandomDataGenerator(rg); - double[] quantiles = new double[10]; - for (int i = 0; i < 10; i++) { - quantiles[i] = rdg.nextUniform(0, 1); - } - // Reseed again so the inversion generator gets the 
same sequence - rg.setSeed(100); - BetaDistribution betaDistribution = new BetaDistribution(rg, 2, 4, - BetaDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY); - /* - * Generate a sequence of deviates using inversion - the distribution function - * evaluated at the random value from the distribution should match the uniform - * random value used to generate it, which is stored in the quantiles[] array. - */ - for (int i = 0; i < 10; i++) { - double value = betaDistribution.sample(); - Assert.assertEquals(betaDistribution.cumulativeProbability(value), quantiles[i], 10E-9); - } - } - - @Test public void testNextBeta() { double[] quartiles = TestUtils.getDistributionQuartiles(new BetaDistribution(2,5)); long[] counts = new long[4];