This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-rng.git
commit 2112fa3d258a827cba42efbea15743dd64e0225a Author: aherbert <aherb...@apache.org> AuthorDate: Tue Feb 26 14:33:17 2019 +0000 RNG-74: DiscreteUniformSampler can be optimised for the algorithm --- .../distribution/DiscreteUniformSampler.java | 134 ++++++++++++++++----- 1 file changed, 107 insertions(+), 27 deletions(-) diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSampler.java index 54442a0..00e308b 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/DiscreteUniformSampler.java @@ -27,12 +27,103 @@ import org.apache.commons.rng.UniformRandomProvider; public class DiscreteUniformSampler extends SamplerBase implements DiscreteSampler { - /** Lower bound. */ - private final int lower; - /** Upper bound. */ - private final int upper; - /** Underlying source of randomness. */ - private final UniformRandomProvider rng; + + /** The appropriate uniform sampler for the parameters. */ + private final DiscreteSampler delegate; + + /** + * Base class for a sampler from a discrete uniform distribution. + */ + private abstract static class AbstractDiscreteUniformSampler + implements DiscreteSampler { + + /** Underlying source of randomness. */ + protected final UniformRandomProvider rng; + /** Lower bound. */ + protected final int lower; + + /** + * @param rng Generator of uniformly distributed random numbers. + * @param lower Lower bound (inclusive) of the distribution. + */ + AbstractDiscreteUniformSampler(UniformRandomProvider rng, + int lower) { + this.rng = rng; + this.lower = lower; + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return "Uniform deviate [" + rng.toString() + "]"; + } + } + + /** + * Discrete uniform distribution sampler when the range between lower and upper is small + * enough to fit in a positive integer. + */ + private static class SmallRangeDiscreteUniformSampler + extends AbstractDiscreteUniformSampler { + + /** Maximum range of the sample from the lower bound (exclusive). */ + private final int range; + + /** + * @param rng Generator of uniformly distributed random numbers. + * @param lower Lower bound (inclusive) of the distribution. + * @param range Maximum range of the sample from the lower bound (exclusive). + */ + SmallRangeDiscreteUniformSampler(UniformRandomProvider rng, + int lower, + int range) { + super(rng, lower); + this.range = range; + } + + @Override + public int sample() { + return lower + rng.nextInt(range); + } + } + + /** + * Discrete uniform distribution sampler when the range between lower and upper is too large + * to fit in a positive integer. + */ + private static class LargeRangeDiscreteUniformSampler + extends AbstractDiscreteUniformSampler { + + /** Upper bound. */ + private final int upper; + + /** + * @param rng Generator of uniformly distributed random numbers. + * @param lower Lower bound (inclusive) of the distribution. + * @param upper Upper bound (inclusive) of the distribution. + */ + LargeRangeDiscreteUniformSampler(UniformRandomProvider rng, + int lower, + int upper) { + super(rng, lower); + this.upper = upper; + } + + @Override + public int sample() { + // Use a simple rejection method. + // This is used when (upper-lower) >= Integer.MAX_VALUE. + // This will loop on average 2 times in the worst case scenario + // when (upper-lower) == Integer.MAX_VALUE. + while (true) { + final int r = rng.nextInt(); + if (r >= lower && + r <= upper) { + return r; + } + } + } + } /** * @param rng Generator of uniformly distributed random numbers. @@ -44,39 +135,28 @@ public class DiscreteUniformSampler int lower, int upper) { super(null); - this.rng = rng; if (lower > upper) { throw new IllegalArgumentException(lower + " > " + upper); } - - this.lower = lower; - this.upper = upper; + // Choose the algorithm depending on the range + final int range = (upper - lower) + 1; + delegate = range <= 0 ? + // The range is too wide to fit in a positive int (larger + // than 2^31); use a simple rejection method. + new LargeRangeDiscreteUniformSampler(rng, lower, upper) : + // Use a sample from the range added to the lower bound. + new SmallRangeDiscreteUniformSampler(rng, lower, range); } /** {@inheritDoc} */ @Override public int sample() { - final int max = (upper - lower) + 1; - if (max <= 0) { - // The range is too wide to fit in a positive int (larger - // than 2^31); as it covers more than half the integer range, - // we use a simple rejection method. - while (true) { - final int r = rng.nextInt(); - if (r >= lower && - r <= upper) { - return r; - } - } - } else { - // We can shift the range and directly generate a positive int. - return lower + rng.nextInt(max); - } + return delegate.sample(); } /** {@inheritDoc} */ @Override public String toString() { - return "Uniform deviate [" + rng.toString() + "]"; + return delegate.toString(); } }