Repository: commons-math Updated Branches: refs/heads/MATH_3_X ab2b01168 -> 321269ed9
[MATH-1220] Improve performance of ZipfDistribution.sample(). Thanks to Otmar Ertl. Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/321269ed Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/321269ed Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/321269ed Branch: refs/heads/MATH_3_X Commit: 321269ed9aa84d15b18296ee6e73d53489efb622 Parents: ab2b011 Author: Thomas Neidhart <thomas.neidh...@gmail.com> Authored: Fri May 1 13:24:48 2015 +0200 Committer: Thomas Neidhart <thomas.neidh...@gmail.com> Committed: Fri May 1 13:24:48 2015 +0200 ---------------------------------------------------------------------- pom.xml | 3 + src/changes/changes.xml | 3 + .../math3/distribution/ZipfDistribution.java | 152 ++++++++++++- .../distribution/ZipfDistributionTest.java | 228 ++++++++++++++++++- 4 files changed, 378 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-math/blob/321269ed/pom.xml ---------------------------------------------------------------------- diff --git a/pom.xml b/pom.xml index 7c55d33..6820934 100644 --- a/pom.xml +++ b/pom.xml @@ -207,6 +207,9 @@ <name>Ole Ersoy</name> </contributor> <contributor> + <name>Otmar Ertl</name> + </contributor> + <contributor> <name>Ajo Fod</name> </contributor> <contributor> http://git-wip-us.apache.org/repos/asf/commons-math/blob/321269ed/src/changes/changes.xml ---------------------------------------------------------------------- diff --git a/src/changes/changes.xml b/src/changes/changes.xml index 2e818b2..274cb50 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -51,6 +51,9 @@ If the output is not quite correct, check for invisible trailing spaces! 
</properties> <body> <release version="3.6" date="XXXX-XX-XX" description=""> + <action dev="tn" type="fix" issue="MATH-1220" due-to="Otmar Ertl"> + Improve performance of "ZipfDistribution#sample()" by using a rejection algorithm. + </action> <action dev="tn" type="fix" issue="MATH-1153" due-to="Sergei Lebedev"> Improve performance of "BetaDistribution#sample()" by using Cheng's algorithm. </action> http://git-wip-us.apache.org/repos/asf/commons-math/blob/321269ed/src/main/java/org/apache/commons/math3/distribution/ZipfDistribution.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math3/distribution/ZipfDistribution.java b/src/main/java/org/apache/commons/math3/distribution/ZipfDistribution.java index 18cb2f4..3755407 100644 --- a/src/main/java/org/apache/commons/math3/distribution/ZipfDistribution.java +++ b/src/main/java/org/apache/commons/math3/distribution/ZipfDistribution.java @@ -43,6 +43,8 @@ public class ZipfDistribution extends AbstractIntegerDistribution { private double numericalVariance = Double.NaN; /** Whether or not the numerical variance has been calculated */ private boolean numericalVarianceIsCalculated = false; + /** The sampler to be used for the sample() method */ + private transient ZipfRejectionSampler sampler; /** * Create a new Zipf distribution with the given number of elements and @@ -258,5 +260,153 @@ public class ZipfDistribution extends AbstractIntegerDistribution { public boolean isSupportConnected() { return true; } -} + /** + * {@inheritDoc} + * <p> + * An instrumental distribution g(k) is used to generate random values by + * rejection sampling. g(k) is defined as g(1):= 1 and g(k) := I(-s,k-1/2,k+1/2) + * for k larger than 1, where s denotes the exponent of the Zipf distribution + * and I(r,a,b) is the integral of x^r for x from a to b. 
+ * <p> + * Since 1/x^s is a convex function, Jensen's inequality gives + * I(-s,k-1/2,k+1/2) >= 1/k^s for all positive k and non-negative s. + * In order to limit the rejection rate for large exponents s, + * the instrumental distribution weight is differently defined for value 1. + */ + @Override + public int sample() { + if (sampler == null) { + sampler = new ZipfRejectionSampler(numberOfElements, exponent); + } + return sampler.sample(random); + } + + /** + * Utility class implementing a rejection sampling method for a discrete, + * bounded Zipf distribution. + * + * @since 3.6 + */ + static final class ZipfRejectionSampler { + + /** Number of elements. */ + private final int numberOfElements; + /** Exponent parameter of the distribution. */ + private final double exponent; + /** Cached tail weight of instrumental distribution used for rejection sampling */ + private double instrumentalDistributionTailWeight = Double.NaN; + + ZipfRejectionSampler(final int numberOfElements, final double exponent) { + this.numberOfElements = numberOfElements; + this.exponent = exponent; + } + + int sample(final RandomGenerator random) { + if (Double.isNaN(instrumentalDistributionTailWeight)) { + instrumentalDistributionTailWeight = integratePowerFunction(-exponent, 1.5, numberOfElements+0.5); + } + + while(true) { + final double randomValue = random.nextDouble()*(instrumentalDistributionTailWeight + 1.); + if (randomValue < instrumentalDistributionTailWeight) { + final double q = randomValue / instrumentalDistributionTailWeight; + final int sample = sampleFromInstrumentalDistributionTail(q); + if (random.nextDouble() < acceptanceRateForTailSample(sample)) { + return sample; + } + } + else { + return 1; + } + } + } + + /** + * Returns a sample from the instrumental distribution tail for a given + * uniformly distributed random value. 
+ * + * @param q a uniformly distributed random value taken from [0,1] + * @return a sample in the range [2, {@link #numberOfElements}] + */ + int sampleFromInstrumentalDistributionTail(double q) { + final double a = 1.5; + final double b = numberOfElements + 0.5; + final double logBdviA = FastMath.log(b / a); + + final int result = (int) (a * FastMath.exp(logBdviA * helper1(q, logBdviA * (1. - exponent))) + 0.5); + if (result < 2) { + return 2; + } + if (result > numberOfElements) { + return numberOfElements; + } + return result; + } + + /** + * Helper function that calculates log((1-q)+q*exp(x))/x. + * <p> + * A Taylor series expansion is used, if x is close to 0. + * + * @param q a value in the range [0,1] + * @param x the argument of the exponential; values close to 0 are handled by the series expansion + * @return log((1-q)+q*exp(x))/x + */ + static double helper1(final double q, final double x) { + if (Math.abs(x) > 1e-8) { + return FastMath.log((1.-q)+q*FastMath.exp(x))/x; + } + else { + return q*(1.+(1./2.)*x*(1.-q)*(1+(1./3.)*x*((1.-2.*q) + (1./4.)*x*(6*q*q*(q-1)+1)))); + } + } + + /** + * Helper function to calculate (exp(x)-1)/x. + * <p> + * A Taylor series expansion is used, if x is close to 0. + * + * @return (exp(x)-1)/x if x is non-zero, 1 if x=0 + */ + static double helper2(final double x) { + if (FastMath.abs(x)>1e-8) { + return FastMath.expm1(x)/x; + } + else { + return 1.+x*(1./2.)*(1.+x*(1./3.)*(1.+x*(1./4.))); + } + } + + /** + * Integrates the power function x^r from x=a to b. + * + * @param r the exponent + * @param a the integral lower bound + * @param b the integral upper bound + * @return the calculated integral value + */ + static double integratePowerFunction(final double r, final double a, final double b) { + final double logA = FastMath.log(a); + final double logBdivA = FastMath.log(b/a); + return FastMath.exp((1.+r)*logA)*helper2((1.+r)*logBdivA)*logBdivA; + + } + + /** + * Calculates the acceptance rate for a sample taken from the tail of the instrumental distribution. 
+ * <p> + * The acceptance rate is given by the ratio k^(-s)/I(-s,k-0.5, k+0.5) + * where I(r,a,b) is the integral of x^r for x from a to b. + * + * @param k the value which has been sampled using the instrumental distribution + * @return the acceptance rate + */ + double acceptanceRateForTailSample(int k) { + final double a = FastMath.log1p(1./(2.*k-1.)); + final double b = FastMath.log1p(2./(2.*k-1.)); + return FastMath.exp((1.-exponent)*a)/(k*b*helper2((1.-exponent)*b)); + } + } + +} http://git-wip-us.apache.org/repos/asf/commons-math/blob/321269ed/src/test/java/org/apache/commons/math3/distribution/ZipfDistributionTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math3/distribution/ZipfDistributionTest.java b/src/test/java/org/apache/commons/math3/distribution/ZipfDistributionTest.java index 06ec3c4..3c177ef 100644 --- a/src/test/java/org/apache/commons/math3/distribution/ZipfDistributionTest.java +++ b/src/test/java/org/apache/commons/math3/distribution/ZipfDistributionTest.java @@ -17,17 +17,26 @@ package org.apache.commons.math3.distribution; -import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import org.apache.commons.math3.TestUtils; +import org.apache.commons.math3.analysis.UnivariateFunction; +import org.apache.commons.math3.analysis.integration.SimpsonIntegrator; +import org.apache.commons.math3.distribution.ZipfDistribution.ZipfRejectionSampler; +import org.apache.commons.math3.exception.NotStrictlyPositiveException; +import org.apache.commons.math3.random.AbstractRandomGenerator; +import org.apache.commons.math3.random.RandomGenerator; +import org.apache.commons.math3.random.Well1024a; import org.apache.commons.math3.util.FastMath; import org.junit.Assert; +import org.junit.Ignore; import org.junit.Test; /** * Test cases for {@link ZipfDistribution}. 
- * Extends IntegerDistributionAbstractTest. See class javadoc for - * IntegerDistributionAbstractTest for details. - * + * Extends IntegerDistributionAbstractTest. + * See class javadoc for IntegerDistributionAbstractTest for details. */ public class ZipfDistributionTest extends IntegerDistributionAbstractTest { @@ -37,7 +46,7 @@ public class ZipfDistributionTest extends IntegerDistributionAbstractTest { public ZipfDistributionTest() { setTolerance(1e-12); } - + @Test(expected=NotStrictlyPositiveException.class) public void testPreconditions1() { new ZipfDistribution(0, 1); @@ -62,7 +71,7 @@ public class ZipfDistributionTest extends IntegerDistributionAbstractTest { return new int[] {-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; } - /** + /** * Creates the default probability density test expected values. * Reference values are from R, version 2.15.3 (VGAM package 0.9-0). */ @@ -72,7 +81,7 @@ public class ZipfDistributionTest extends IntegerDistributionAbstractTest { 0.0569028586912, 0.0487738788782, 0.0426771440184, 0.0379352391275, 0.0341417152147, 0}; } - /** + /** * Creates the default logarithmic probability density test expected values. * Reference values are from R, version 2.14.1. */ @@ -119,4 +128,209 @@ public class ZipfDistributionTest extends IntegerDistributionAbstractTest { Assert.assertEquals(dist.getNumericalMean(), FastMath.sqrt(2), tol); Assert.assertEquals(dist.getNumericalVariance(), 0.24264068711928521, tol); } + + /** + * Test sampling for various number of points and exponents. + */ + @Test + public void testSamplingExtended() { + int sampleSize = 1000; + + int[] numPointsValues = { + 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 30, 35, 40, 45, 50, 60, 70, 80, 90, 100 + }; + double[] exponentValues = { + 1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, + 1. - 1e-9, 1.0, 1. + 1e-9, 1.1, 1.2, 1.3, 1.5, 1.6, 1.7, 1.8, 2.0, + 2.5, 3.0, 4., 5., 6., 7., 8., 9., 10., 20., 30. 
+ }; + + for (int numPoints : numPointsValues) { + for (double exponent : exponentValues) { + double weightSum = 0.; + double[] weights = new double[numPoints]; + for (int i = numPoints; i>=1; i-=1) { + weights[i-1] = Math.pow(i, -exponent); + weightSum += weights[i-1]; + } + + ZipfDistribution distribution = new ZipfDistribution(numPoints, exponent); + distribution.reseedRandomGenerator(6); // use fixed seed, the test is expected to fail for more than 50% of all seeds because each test case can fail with probability 0.001, the chance that all test cases do not fail is 0.999^(32*22) = 0.49442874426 + + double[] expectedCounts = new double[numPoints]; + long[] observedCounts = new long[numPoints]; + for (int i = 0; i < numPoints; i++) { + expectedCounts[i] = sampleSize * (weights[i]/weightSum); + } + int[] sample = distribution.sample(sampleSize); + for (int s : sample) { + observedCounts[s-1]++; + } + TestUtils.assertChiSquareAccept(expectedCounts, observedCounts, 0.001); + } + } + } + + @Test + public void testSamplerIntegratePowerFunction() { + final double tol = 1e-6; + final double[] exponents = { + -1e-5, -1e-4, -1e-3, -1e-2, -1e-1, -1e0, -1e1 + }; + final double[] limits = { + 0.5, 1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6.0, 6.5, 7.0, + 7.5, 8.0, 8.5, 9.0, 9.5, 10.0 + }; + + for (final double exponent : exponents) { + for (int lowerLimitIndex = 0; lowerLimitIndex < limits.length; ++lowerLimitIndex) { + final double lowerLimit = limits[lowerLimitIndex]; + for (int upperLimitIndex = lowerLimitIndex+1; upperLimitIndex < limits.length; ++upperLimitIndex) { + final double upperLimit = limits[upperLimitIndex]; + final double result1 = new SimpsonIntegrator().integrate(10000, new UnivariateFunction() { + public double value(double x) { + return Math.pow(x, exponent); + } + }, lowerLimit, upperLimit); + + final double result2 = ZipfRejectionSampler.integratePowerFunction(exponent, lowerLimit, upperLimit); + + assertEquals(result1, result2, 
(result1+result2)*tol); + + } + } + } + } + + @Test + public void testSamplerAcceptanceRate() { + final double tol = 1e-12; + final double[] exponents = { + 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 2e0, 5e0, 1e1, 1e2, 1e3 + }; + final int[] values = { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14 + }; + final int numberOfElements = 1000; + for (final double exponent : exponents) { + ZipfRejectionSampler sampler = new ZipfRejectionSampler(numberOfElements, exponent); + for (final int value : values) { + double expected = FastMath.pow(value, -exponent); + double result = sampler.acceptanceRateForTailSample(value) * + ZipfRejectionSampler.integratePowerFunction(-exponent, value - 0.5, value + 0.5); + TestUtils.assertRelativelyEquals(expected, result, tol); + assertTrue(result <=1.); // test Jensen's inequality + } + } + } + + @Test + public void testSamplerInverseInstrumentalDistribution() { + final double tol = 1e-14; + final double[] exponentValues = { + 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 2E0, 3e0, 4e0, 5e0, 6., 7., 8., 9., 10., 50. 
+ }; + final double[] qValues = { + 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0 + }; + final int[] numberOfElementsValues = { + 2, 3, 4, 5, 6, 7, 8, 9, 10, 20, 30, 40, 100 + }; + + for (final double exponent : exponentValues) { + for (final int numberOfElements : numberOfElementsValues) { + final ZipfRejectionSampler sampler = new ZipfRejectionSampler(numberOfElements, exponent); + for (final double q : qValues) { + int result = sampler.sampleFromInstrumentalDistributionTail(q); + double total = + ZipfRejectionSampler.integratePowerFunction(-exponent, 1.5, numberOfElements + 0.5); + double lowerBound = + ZipfRejectionSampler.integratePowerFunction(-exponent, 1.5, result - 0.5) / total; + double upperBound = + ZipfRejectionSampler.integratePowerFunction(-exponent, 1.5, result + 0.5) / total; + assertTrue(lowerBound <= q*(1.+tol)); + assertTrue(upperBound >= q*(1.-tol)); + } + } + } + } + + @Test + public void testSamplerHelper1() { + final double tol = 1e-14; + final double[] qValues = { + 0., 1e-12, 1e-11, 1e-10, 1e-9, 9e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, + 1e-3, 1e-2, 1e-1, 1e0 + }; + final double[] xValues = { + -Double.MAX_VALUE, -1e10, -1e9, -1e8, -1e7, -1e6, -1e5, -1e4, -1e3, + -1e2, -1e1, -1e0, -1e-1, -1e-2, -1e-3, -1e-4, -1e-5, -1e-6, -1e-7, + -1e-8, -1e-9, -1e-10, -Double.MIN_VALUE, 0.0, Double.MIN_VALUE, + 1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, + 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, Double.MAX_VALUE + }; + + for (final double q : qValues) { + for(final double x : xValues) { + double calculated = ZipfRejectionSampler.helper1(q, x); + TestUtils.assertRelativelyEquals((1.-q)+q*Math.exp(x), FastMath.exp(calculated*x), tol); + } + } + } + + @Test + public void testSamplerHelper2() { + final double tol = 1e-12; + final double[] testValues = { + -1e0, -1e-1, -1e-2, -1e-3, -1e-4, -1e-5, -1e-6, -1e-7, -1e-8, + -1e-9, -1e-10, -1e-11, 0., 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6, + 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0 + }; + 
for (double testValue : testValues) { + final double expected = FastMath.expm1(testValue); + TestUtils.assertRelativelyEquals(expected, ZipfRejectionSampler.helper2(testValue)*testValue, tol); + } + } + + @Ignore + @Test + public void testSamplerPerformance() { + int[] numPointsValues = {1, 2, 5, 10, 100, 1000, 10000}; + double[] exponentValues = {1e-3, 1e-2, 1e-1, 1., 2., 5., 10.}; + int numGeneratedSamples = 1000000; + + long sum = 0; + + for (int numPoints : numPointsValues) { + for (double exponent : exponentValues) { + long start = System.currentTimeMillis(); + final int[] randomNumberCounter = new int[1]; + + RandomGenerator randomGenerator = new AbstractRandomGenerator() { + + private final RandomGenerator r = new Well1024a(0L); + + @Override + public void setSeed(long seed) { + } + + @Override + public double nextDouble() { + randomNumberCounter[0]+=1; + return r.nextDouble(); + } + }; + + final ZipfDistribution distribution = new ZipfDistribution(randomGenerator, numPoints, exponent); + for (int i = 0; i < numGeneratedSamples; ++i) { + sum += distribution.sample(); + } + + long end = System.currentTimeMillis(); + System.out.println("n = " + numPoints + ", exponent = " + exponent + ", avg number consumed random values = " + (double)(randomNumberCounter[0])/numGeneratedSamples + ", measured time = " + (end-start)/1000. + "s"); + } + } + System.out.println(sum); + } + }