MATH-1158. New "Sampler" interface and "createSampler" factory method defined in "RealDistribution" interface.
Default sampling implementation defined in "AbstractRealDistribution" (using the "inversion method"). Overridden in "NormalDistribution" (code copied from "BitsStreamGenerator") and "BetaDistribution". Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/6c94c16e Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/6c94c16e Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/6c94c16e Branch: refs/heads/feature-MATH-1158 Commit: 6c94c16e46127c05fbf70b1fdeb0cd8c2ac98537 Parents: ce8c82f Author: Gilles <er...@apache.org> Authored: Fri Mar 11 02:05:49 2016 +0100 Committer: Gilles <er...@apache.org> Committed: Fri Mar 11 02:05:49 2016 +0100 ---------------------------------------------------------------------- .../distribution/AbstractRealDistribution.java | 29 +++- .../math4/distribution/BetaDistribution.java | 143 +++++++++++-------- .../math4/distribution/NormalDistribution.java | 42 ++++++ .../math4/distribution/RealDistribution.java | 25 ++++ .../distribution/BetaDistributionTest.java | 19 ++- .../ConstantRealDistributionTest.java | 9 ++ .../RealDistributionAbstractTest.java | 24 ++++ 7 files changed, 228 insertions(+), 63 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c94c16e/src/main/java/org/apache/commons/math4/distribution/AbstractRealDistribution.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math4/distribution/AbstractRealDistribution.java b/src/main/java/org/apache/commons/math4/distribution/AbstractRealDistribution.java index e43a0fa..4afdb6d 100644 --- a/src/main/java/org/apache/commons/math4/distribution/AbstractRealDistribution.java +++ b/src/main/java/org/apache/commons/math4/distribution/AbstractRealDistribution.java @@ -25,6 +25,7 @@ import 
org.apache.commons.math4.exception.NumberIsTooLargeException; import org.apache.commons.math4.exception.OutOfRangeException; import org.apache.commons.math4.exception.util.LocalizedFormats; import org.apache.commons.math4.random.RandomGenerator; +import org.apache.commons.math4.rng.UniformRandomProvider; import org.apache.commons.math4.util.FastMath; /** @@ -32,10 +33,18 @@ import org.apache.commons.math4.util.FastMath; * Default implementations are provided for some of the methods * that do not vary from distribution to distribution. * + * <p> + * This base class provides a default factory method for creating + * a {@link RealDistribution.Sampler sampler instance} that uses the + * <a href="http://en.wikipedia.org/wiki/Inverse_transform_sampling"> + * inversion method</a> for generating random samples that follow the + * distribution. + * </p> + * * @since 3.0 */ public abstract class AbstractRealDistribution -implements RealDistribution, Serializable { + implements RealDistribution, Serializable { /** Default absolute accuracy for inverse cumulative computation. */ public static final double SOLVER_DEFAULT_ABSOLUTE_ACCURACY = 1e-6; /** Serializable version identifier */ @@ -45,12 +54,14 @@ implements RealDistribution, Serializable { * RNG instance used to generate samples from the distribution. * @since 3.1 */ + @Deprecated protected final RandomGenerator random; /** * @param rng Random number generator. * @since 3.1 */ + @Deprecated protected AbstractRealDistribution(RandomGenerator rng) { random = rng; } @@ -210,6 +221,7 @@ implements RealDistribution, Serializable { /** {@inheritDoc} */ @Override + @Deprecated public void reseedRandomGenerator(long seed) { random.setSeed(seed); } @@ -223,6 +235,7 @@ implements RealDistribution, Serializable { * </a> */ @Override + @Deprecated public double sample() { return inverseCumulativeProbability(random.nextDouble()); } @@ -234,6 +247,7 @@ implements RealDistribution, Serializable { * {@link #sample()} in a loop. 
*/ @Override + @Deprecated public double[] sample(int sampleSize) { if (sampleSize <= 0) { throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, @@ -266,5 +280,16 @@ implements RealDistribution, Serializable { public double logDensity(double x) { return FastMath.log(density(x)); } -} + /**{@inheritDoc} */ + @Override + public RealDistribution.Sampler createSampler(final UniformRandomProvider rng) { + return new RealDistribution.Sampler() { + /** {@inheritDoc} */ + @Override + public double sample() { + return inverseCumulativeProbability(rng.nextDouble()); + } + }; + } +} http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c94c16e/src/main/java/org/apache/commons/math4/distribution/BetaDistribution.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math4/distribution/BetaDistribution.java b/src/main/java/org/apache/commons/math4/distribution/BetaDistribution.java index c8440bd..3009abf 100644 --- a/src/main/java/org/apache/commons/math4/distribution/BetaDistribution.java +++ b/src/main/java/org/apache/commons/math4/distribution/BetaDistribution.java @@ -20,6 +20,7 @@ import org.apache.commons.math4.exception.NumberIsTooSmallException; import org.apache.commons.math4.exception.util.LocalizedFormats; import org.apache.commons.math4.random.RandomGenerator; import org.apache.commons.math4.random.Well19937c; +import org.apache.commons.math4.rng.UniformRandomProvider; import org.apache.commons.math4.special.Beta; import org.apache.commons.math4.special.Gamma; import org.apache.commons.math4.util.FastMath; @@ -96,6 +97,7 @@ public class BetaDistribution extends AbstractRealDistribution { * @param beta Second shape parameter (must be positive). 
* @since 3.3 */ + @Deprecated public BetaDistribution(RandomGenerator rng, double alpha, double beta) { this(rng, alpha, beta, DEFAULT_INVERSE_ABSOLUTE_ACCURACY); } @@ -111,6 +113,7 @@ public class BetaDistribution extends AbstractRealDistribution { * {@link #DEFAULT_INVERSE_ABSOLUTE_ACCURACY}). * @since 3.1 */ + @Deprecated public BetaDistribution(RandomGenerator rng, double alpha, double beta, @@ -267,69 +270,94 @@ public class BetaDistribution extends AbstractRealDistribution { return true; } - /** {@inheritDoc} - * <p> - * Sampling is performed using Cheng algorithms: - * </p> - * <p> - * R. C. H. Cheng, "Generating beta variates with nonintegral shape parameters.". - * Communications of the ACM, 21, 317–322, 1978. - * </p> - */ + /** + * {@inheritDoc} + * + * Sampling is performed using Cheng's algorithm: + * <blockquote> + * <pre> + * R. C. H. Cheng, + * "Generating beta variates with nonintegral shape parameters", + * Communications of the ACM, 21, 317-322, 1978. + * </pre> + * </blockquote> + */ @Override - public double sample() { - return ChengBetaSampler.sample(random, alpha, beta); + public RealDistribution.Sampler createSampler(final UniformRandomProvider rng) { + return new ChengBetaSampler(rng, alpha, beta); } - /** Utility class implementing Cheng's algorithms for beta distribution sampling. - * <p> - * R. C. H. Cheng, "Generating beta variates with nonintegral shape parameters.". - * Communications of the ACM, 21, 317–322, 1978. - * </p> + /** + * Utility class implementing Cheng's algorithms for beta distribution sampling. + * + * <blockquote> + * <pre> + * R. C. H. Cheng, + * "Generating beta variates with nonintegral shape parameters", + * Communications of the ACM, 21, 317-322, 1978. + * </pre> + * </blockquote> + * * @since 3.6 */ - private static final class ChengBetaSampler { + private static class ChengBetaSampler implements RealDistribution.Sampler { + /** RNG (uniform distribution).
*/ + private final UniformRandomProvider rng; + /** First shape parameter. */ + private final double alphaShape; + /** Second shape parameter. */ + private final double betaShape; /** - * Returns one sample using Cheng's sampling algorithm. - * @param random random generator to use - * @param alpha distribution first shape parameter - * @param beta distribution second shape parameter - * @return sampled value + * Creates a sampler instance. + * + * @param generator Generator. + * @param alpha Distribution first shape parameter. + * @param beta Distribution second shape parameter. */ - static double sample(RandomGenerator random, final double alpha, final double beta) { - final double a = FastMath.min(alpha, beta); - final double b = FastMath.max(alpha, beta); + ChengBetaSampler(UniformRandomProvider generator, + double alpha, + double beta) { + rng = generator; + alphaShape = alpha; + betaShape = beta; + } + + /** {@inheritDoc} */ + @Override + public double sample() { + final double a = FastMath.min(alphaShape, betaShape); + final double b = FastMath.max(alphaShape, betaShape); if (a > 1) { - return algorithmBB(random, alpha, a, b); + return algorithmBB(alphaShape, a, b); } else { - return algorithmBC(random, alpha, b, a); + return algorithmBC(alphaShape, b, a); } } /** - * Returns one sample using Cheng's BB algorithm, when both α and β are greater than 1. - * @param random random generator to use - * @param a0 distribution first shape parameter (α) - * @param a min(α, β) where α, β are the two distribution shape parameters - * @param b max(α, β) where α, β are the two distribution shape parameters - * @return sampled value + * Computes one sample using Cheng's BB algorithm, when α and + * β are both larger than 1. + * + * @param a0 First shape parameter (α). + * @param a Min(α, β) where α, β are the shape parameters. + * @param b Max(α, β) where α, β are the shape parameters. + * @return a random sample.
*/ - private static double algorithmBB(RandomGenerator random, - final double a0, - final double a, - final double b) { + private double algorithmBB(double a0, + double a, + double b) { final double alpha = a + b; - final double beta = FastMath.sqrt((alpha - 2.) / (2. * a * b - alpha)); - final double gamma = a + 1. / beta; + final double beta = FastMath.sqrt((alpha - 2) / (2 * a * b - alpha)); + final double gamma = a + 1 / beta; double r; double w; double t; do { - final double u1 = random.nextDouble(); - final double u2 = random.nextDouble(); + final double u1 = rng.nextDouble(); + final double u2 = rng.nextDouble(); final double v = beta * (FastMath.log(u1) - FastMath.log1p(-u1)); w = a * FastMath.exp(v); final double z = u1 * u1 * u2; @@ -346,31 +374,32 @@ public class BetaDistribution extends AbstractRealDistribution { } while (r + alpha * (FastMath.log(alpha) - FastMath.log(b + w)) < t); w = FastMath.min(w, Double.MAX_VALUE); + return Precision.equals(a, a0) ? w / (b + w) : b / (b + w); } /** - * Returns one sample using Cheng's BC algorithm, when at least one of α and β is smaller than 1. - * @param random random generator to use - * @param a0 distribution first shape parameter (α) - * @param a max(α, β) where α, β are the two distribution shape parameters - * @param b min(α, β) where α, β are the two distribution shape parameters - * @return sampled value + * Computes one sample using Cheng's BC algorithm, when at least one + * of α and β is smaller than 1. + * + * @param a0 First shape parameter (α). + * @param a Max(α, β) where α, β are the shape parameters. + * @param b min(α, β) where α, β are the shape parameters. + * @return a random sample. */ - private static double algorithmBC(RandomGenerator random, - final double a0, - final double a, - final double b) { + private double algorithmBC(double a0, + double a, + double b) { final double alpha = a + b; - final double beta = 1. / b; - final double delta = 1. 
+ a - b; + final double beta = 1 / b; + final double delta = 1 + a - b; final double k1 = delta * (0.0138889 + 0.0416667 * b) / (a * beta - 0.777778); final double k2 = 0.25 + (0.5 + 0.25 / delta) * b; double w; - for (;;) { - final double u1 = random.nextDouble(); - final double u2 = random.nextDouble(); + while (true) { + final double u1 = rng.nextDouble(); + final double u2 = rng.nextDouble(); final double y = u1 * u2; final double z = u1 * y; if (u1 < 0.5) { @@ -397,8 +426,8 @@ public class BetaDistribution extends AbstractRealDistribution { } w = FastMath.min(w, Double.MAX_VALUE); + return Precision.equals(a, a0) ? w / (b + w) : b / (b + w); } - } } http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c94c16e/src/main/java/org/apache/commons/math4/distribution/NormalDistribution.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math4/distribution/NormalDistribution.java b/src/main/java/org/apache/commons/math4/distribution/NormalDistribution.java index 5216867..01de515 100644 --- a/src/main/java/org/apache/commons/math4/distribution/NormalDistribution.java +++ b/src/main/java/org/apache/commons/math4/distribution/NormalDistribution.java @@ -23,6 +23,7 @@ import org.apache.commons.math4.exception.OutOfRangeException; import org.apache.commons.math4.exception.util.LocalizedFormats; import org.apache.commons.math4.random.RandomGenerator; import org.apache.commons.math4.random.Well19937c; +import org.apache.commons.math4.rng.UniformRandomProvider; import org.apache.commons.math4.special.Erf; import org.apache.commons.math4.util.FastMath; @@ -116,6 +117,7 @@ public class NormalDistribution extends AbstractRealDistribution { * @throws NotStrictlyPositiveException if {@code sd <= 0}. 
* @since 3.3 */ + @Deprecated public NormalDistribution(RandomGenerator rng, double mean, double sd) throws NotStrictlyPositiveException { this(rng, mean, sd, DEFAULT_INVERSE_ABSOLUTE_ACCURACY); @@ -131,6 +133,7 @@ public class NormalDistribution extends AbstractRealDistribution { * @throws NotStrictlyPositiveException if {@code sd <= 0}. * @since 3.1 */ + @Deprecated public NormalDistribution(RandomGenerator rng, double mean, double sd, @@ -291,7 +294,46 @@ public class NormalDistribution extends AbstractRealDistribution { /** {@inheritDoc} */ @Override + @Deprecated public double sample() { return standardDeviation * random.nextGaussian() + mean; } + + /** {@inheritDoc} */ + @Override + public RealDistribution.Sampler createSampler(final UniformRandomProvider rng) { + return new RealDistribution.Sampler() { + /** Next gaussian. */ + private double nextGaussian = Double.NaN; + + /** {@inheritDoc} */ + @Override + public double sample() { + final double random; + if (Double.isNaN(nextGaussian)) { + // Generate a pair of Gaussian numbers. + + final double x = rng.nextDouble(); + final double y = rng.nextDouble(); + final double alpha = 2 * FastMath.PI * x; + final double r = FastMath.sqrt(-2 * FastMath.log(y)); + + // Return the first element of the generated pair. + random = r * FastMath.cos(alpha); + + // Keep second element of the pair for next invocation. + nextGaussian = r * FastMath.sin(alpha); + } else { + // Use the second element of the pair (generated at the + // previous invocation). + random = nextGaussian; + + // Both elements of the pair have been used. 
+ nextGaussian = Double.NaN; + } + + return standardDeviation * random + mean; + } + }; + } } http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c94c16e/src/main/java/org/apache/commons/math4/distribution/RealDistribution.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math4/distribution/RealDistribution.java b/src/main/java/org/apache/commons/math4/distribution/RealDistribution.java index 5ba3717..f0e453a 100644 --- a/src/main/java/org/apache/commons/math4/distribution/RealDistribution.java +++ b/src/main/java/org/apache/commons/math4/distribution/RealDistribution.java @@ -18,6 +18,7 @@ package org.apache.commons.math4.distribution; import org.apache.commons.math4.exception.NumberIsTooLargeException; import org.apache.commons.math4.exception.OutOfRangeException; +import org.apache.commons.math4.rng.UniformRandomProvider; /** * Base interface for distributions on the reals. @@ -164,6 +165,7 @@ public interface RealDistribution { * * @param seed the new seed */ + @Deprecated void reseedRandomGenerator(long seed); /** @@ -171,6 +173,7 @@ public interface RealDistribution { * * @return a random value. */ + @Deprecated double sample(); /** @@ -181,5 +184,27 @@ public interface RealDistribution { * @throws org.apache.commons.math4.exception.NotStrictlyPositiveException * if {@code sampleSize} is not positive */ + @Deprecated double[] sample(int sampleSize); + + /** + * Creates a sampler. + * + * @param rng Generator of uniformly distributed numbers. + * @return a sampler that produces random numbers according to this + * distribution. + */ + Sampler createSampler(UniformRandomProvider rng); + + /** + * Sampling functionality. + */ + interface Sampler { + /** + * Generates a random value sampled from this distribution. + * + * @return a random value.
+ */ + double sample(); + } } http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c94c16e/src/test/java/org/apache/commons/math4/distribution/BetaDistributionTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math4/distribution/BetaDistributionTest.java b/src/test/java/org/apache/commons/math4/distribution/BetaDistributionTest.java index 632e99f..d3e167d 100644 --- a/src/test/java/org/apache/commons/math4/distribution/BetaDistributionTest.java +++ b/src/test/java/org/apache/commons/math4/distribution/BetaDistributionTest.java @@ -22,6 +22,8 @@ import org.apache.commons.math4.distribution.BetaDistribution; import org.apache.commons.math4.random.RandomGenerator; import org.apache.commons.math4.random.Well1024a; import org.apache.commons.math4.random.Well19937a; +import org.apache.commons.math4.rng.RandomSource; +import org.apache.commons.math4.rng.UniformRandomProvider; import org.apache.commons.math4.stat.StatUtils; import org.apache.commons.math4.stat.inference.KolmogorovSmirnovTest; import org.apache.commons.math4.stat.inference.TestUtils; @@ -340,13 +342,23 @@ public class BetaDistributionTest { @Test public void testGoodnessOfFit() { - RandomGenerator random = new Well19937a(0x237db1db907b089fl); + final UniformRandomProvider rng = RandomSource.create(RandomSource.WELL_19937_A, + 123456789L); + final RandomGenerator random = new Well19937a(0x237db1db907b089fL); + final int numSamples = 1000; final double level = 0.01; for (final double alpha : alphaBetas) { for (final double beta : alphaBetas) { - final BetaDistribution betaDistribution = new BetaDistribution(random, alpha, beta); - final double[] observed = betaDistribution.sample(numSamples); + final BetaDistribution betaDistribution = new BetaDistribution(alpha, beta); + + final RealDistribution.Sampler sampler = betaDistribution.createSampler(rng); + final double[] observed = new double[numSamples]; + + for (int i = 0; i < 
numSamples; i++) { + observed[i] = sampler.sample(); + } + Assert.assertFalse("G goodness-of-fit test rejected null at alpha = " + level, gTest(betaDistribution, observed) < level); Assert.assertFalse("KS goodness-of-fit test rejected null at alpha = " + level, @@ -377,5 +389,4 @@ public class BetaDistributionTest { return TestUtils.gTest(expected, observed); } - } http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c94c16e/src/test/java/org/apache/commons/math4/distribution/ConstantRealDistributionTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math4/distribution/ConstantRealDistributionTest.java b/src/test/java/org/apache/commons/math4/distribution/ConstantRealDistributionTest.java index 2d65447..275d29a 100644 --- a/src/test/java/org/apache/commons/math4/distribution/ConstantRealDistributionTest.java +++ b/src/test/java/org/apache/commons/math4/distribution/ConstantRealDistributionTest.java @@ -83,11 +83,20 @@ public class ConstantRealDistributionTest extends RealDistributionAbstractTest { } @Test + @Override public void testSampling() { ConstantRealDistribution dist = new ConstantRealDistribution(0); for (int i = 0; i < 10; i++) { Assert.assertEquals(0, dist.sample(), 0); } + } + @Test + @Override + public void testSampler() { + ConstantRealDistribution dist = new ConstantRealDistribution(0); + for (int i = 0; i < 10; i++) { + Assert.assertEquals(0, dist.sample(), 0); + } } } http://git-wip-us.apache.org/repos/asf/commons-math/blob/6c94c16e/src/test/java/org/apache/commons/math4/distribution/RealDistributionAbstractTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math4/distribution/RealDistributionAbstractTest.java b/src/test/java/org/apache/commons/math4/distribution/RealDistributionAbstractTest.java index 9fffeaf..fa1542f 100644 --- 
a/src/test/java/org/apache/commons/math4/distribution/RealDistributionAbstractTest.java +++ b/src/test/java/org/apache/commons/math4/distribution/RealDistributionAbstractTest.java @@ -32,6 +32,7 @@ import org.apache.commons.math4.analysis.integration.IterativeLegendreGaussInteg import org.apache.commons.math4.distribution.RealDistribution; import org.apache.commons.math4.exception.MathIllegalArgumentException; import org.apache.commons.math4.exception.NumberIsTooLargeException; +import org.apache.commons.math4.rng.RandomSource; import org.apache.commons.math4.util.FastMath; import org.junit.After; import org.junit.Assert; @@ -339,6 +340,29 @@ public abstract class RealDistributionAbstractTest { TestUtils.assertChiSquareAccept(expected, counts, 0.001); } + + // New design + @Test + public void testSampler() { + final int sampleSize = 1000; + final RealDistribution.Sampler sampler = + distribution.createSampler(RandomSource.create(RandomSource.WELL_19937_C, 123456789L)); + + final double[] sample = new double[sampleSize]; + for (int i = 0; i < sampleSize; i++) { + sample[i] = sampler.sample(); + } + + final double[] quartiles = TestUtils.getDistributionQuartiles(distribution); + final double[] expected = {250, 250, 250, 250}; + final long[] counts = new long[4]; + + for (int i = 0; i < sampleSize; i++) { + TestUtils.updateCounts(sample[i], counts, quartiles); + } + TestUtils.assertChiSquareAccept(expected, counts, 0.001); + } + /** * Verify that density integrals match the distribution. * The (filtered, sorted) cumulativeTestPoints array is used to source