Repository: commons-rng Updated Branches: refs/heads/master f5599152a -> 5e62c08f7
RNG-51: Added a PoissonSamplerCache for the large mean algorithm Project: http://git-wip-us.apache.org/repos/asf/commons-rng/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-rng/commit/fa38dea5 Tree: http://git-wip-us.apache.org/repos/asf/commons-rng/tree/fa38dea5 Diff: http://git-wip-us.apache.org/repos/asf/commons-rng/diff/fa38dea5 Branch: refs/heads/master Commit: fa38dea519c49640c485299d47516f9fb3d51fde Parents: 2a31d87 Author: Alex Herbert <a.herb...@sussex.ac.uk> Authored: Thu Sep 20 23:17:36 2018 +0100 Committer: Alex Herbert <a.herb...@sussex.ac.uk> Committed: Thu Sep 20 23:17:36 2018 +0100 ---------------------------------------------------------------------- .../distribution/LargeMeanPoissonSampler.java | 128 ++++- .../sampling/distribution/PoissonSampler.java | 11 +- .../distribution/PoissonSamplerCache.java | 331 +++++++++++++ .../LargeMeanPoissonSamplerTest.java | 150 ++++++ .../distribution/PoissonSamplerCacheTest.java | 481 +++++++++++++++++++ 5 files changed, 1094 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-rng/blob/fa38dea5/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java ---------------------------------------------------------------------- diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java index a729686..0beb6b4 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java @@ -78,7 +78,7 @@ public class LargeMeanPoissonSampler */ private final double p1; /** - * Algorithm constant: {@code a1 / aSum} with + * Algorithm 
constant: {@code a2 / aSum} with * <ul> * <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li> * <li>{@code aSum = a1 + a2 + 1}</li> @@ -92,15 +92,93 @@ public class LargeMeanPoissonSampler private final DiscreteSampler smallMeanPoissonSampler; /** + * Encapsulate the state of the sampler. The state is valid for construction of + * a sampler in the range {@code lambda <= mean < lambda+1}. + */ + static class LargeMeanPoissonSamplerState { + /** Algorithm constant: {@code Math.floor(mean)}. */ + private final double lambda; + /** Algorithm constant: {@code Math.log(lambda)}. */ + private final double logLambda; + /** Algorithm constant: {@code factorialLog((int) lambda)}. */ + private final double logLambdaFactorial; + /** Algorithm constant: {@code Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1))}. */ + private final double delta; + /** Algorithm constant: {@code delta / 2}. */ + private final double halfDelta; + /** Algorithm constant: {@code 2 * lambda + delta}. */ + private final double twolpd; + /** + * Algorithm constant: {@code a1 / aSum} with + * <ul> + * <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li> + * <li>{@code aSum = a1 + a2 + 1}</li> + * </ul> + */ + private final double p1; + /** + * Algorithm constant: {@code a2 / aSum} with + * <ul> + * <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li> + * <li>{@code aSum = a1 + a2 + 1}</li> + * </ul> + */ + private final double p2; + /** Algorithm constant: {@code 1 / (8 * lambda)}. */ + private final double c1; + + /** + * Creates the state. The state is valid for construction of a sampler in the + * range {@code n <= mean < n+1}. Note: {@code n = 0} yields non-finite constants. + * + * @param n the value n ({@code floor(mean)}) + * @throws IllegalArgumentException if {@code n < 0}.
+ */ + LargeMeanPoissonSamplerState(int n) { + if (n < 0) { + throw new IllegalArgumentException(n + " < " + 0); + } + // Cache values used in the algorithm + // This is deliberately a copy of the code in the + // LargeMeanPoissonSampler constructor: keep the two in sync. + lambda = n; + logLambda = Math.log(lambda); + logLambdaFactorial = NO_CACHE_FACTORIAL_LOG.value(n); + delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1)); + halfDelta = delta / 2; + twolpd = 2 * lambda + delta; + c1 = 1 / (8 * lambda); + final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1); + final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd); + final double aSum = a1 + a2 + 1; + p1 = a1 / aSum; + p2 = a2 / aSum; + } + + /** + * Get the lambda value for the state. Equal to {@code floor(mean)}. + * @return {@code floor(mean)} + */ + int getLambda() { + return (int) lambda; + } + } + + /** * @param rng Generator of uniformly distributed random numbers. * @param mean Mean. - * @throws IllegalArgumentException if {@code mean <= 0}. + * @throws IllegalArgumentException if {@code mean <= 0} or + * {@code mean >} {@link Integer#MAX_VALUE}. */ public LargeMeanPoissonSampler(UniformRandomProvider rng, double mean) { super(rng); if (mean <= 0) { - throw new IllegalArgumentException(mean + " <= " + 0); + throw new IllegalArgumentException(mean + " <= " + 0); + } + // The algorithm is not valid if Math.floor(mean) is not an integer.
+ if (mean > Integer.MAX_VALUE) { + throw new IllegalArgumentException(mean + " > " + Integer.MAX_VALUE); } gaussian = new ZigguratNormalizedGaussianSampler(rng); @@ -110,7 +188,6 @@ public class LargeMeanPoissonSampler // Cache values used in the algorithm lambda = Math.floor(mean); - lambdaFractional = mean - lambda; logLambda = Math.log(lambda); logLambdaFactorial = factorialLog((int) lambda); delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1)); @@ -124,6 +201,49 @@ public class LargeMeanPoissonSampler p2 = a2 / aSum; // The algorithm requires a Poisson sample from the remaining lambda fraction. + lambdaFractional = mean - lambda; + smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ? + null : // Not used. + new SmallMeanPoissonSampler(rng, lambdaFractional); + } + + /** + * Instantiates a sampler using a precomputed state. + * + * @param rng Generator of uniformly distributed random numbers. + * @param state The state for {@code lambda = (int)Math.floor(mean)}. + * @param lambdaFractional The lambda fractional value + * ({@code mean - (int)Math.floor(mean))}. + * @throws IllegalArgumentException + * if {@code lambdaFractional < 0 || lambdaFractional >= 1}. + */ + LargeMeanPoissonSampler(UniformRandomProvider rng, + LargeMeanPoissonSamplerState state, + double lambdaFractional) { + super(rng); + if (lambdaFractional < 0 || lambdaFractional >= 1) { + throw new IllegalArgumentException( + "lambdaFractional must be in the range 0 (inclusive) to 1 (exclusive): " + lambdaFractional); + } + + gaussian = new ZigguratNormalizedGaussianSampler(rng); + exponential = new AhrensDieterExponentialSampler(rng, 1); + // Plain constructor uses the uncached function. 
+ factorialLog = NO_CACHE_FACTORIAL_LOG; + + // Use the state to initialise the algorithm + lambda = state.lambda; + logLambda = state.logLambda; + logLambdaFactorial = state.logLambdaFactorial; + delta = state.delta; + halfDelta = state.halfDelta; + twolpd = state.twolpd; + p1 = state.p1; + p2 = state.p2; + c1 = state.c1; + + // The algorithm requires a Poisson sample from the remaining lambda fraction. + this.lambdaFractional = lambdaFractional; smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ? null : // Not used. new SmallMeanPoissonSampler(rng, lambdaFractional); http://git-wip-us.apache.org/repos/asf/commons-rng/blob/fa38dea5/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java ---------------------------------------------------------------------- diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java index cd9187a..8bab0f0 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSampler.java @@ -39,15 +39,20 @@ import org.apache.commons.rng.UniformRandomProvider; public class PoissonSampler implements DiscreteSampler { - /** Value for switching sampling algorithm. */ - private static final double PIVOT = 40; + /** + * Value for switching sampling algorithm. + * + * <p>Package scope for the {@link PoissonSamplerCache}. + */ + static final double PIVOT = 40; /** The internal Poisson sampler. */ private final DiscreteSampler poissonSampler; /** * @param rng Generator of uniformly distributed random numbers. * @param mean Mean. - * @throws IllegalArgumentException if {@code mean <= 0}. + * @throws IllegalArgumentException if {@code mean <= 0} or + * {@code mean >} {@link Integer#MAX_VALUE}. 
*/ public PoissonSampler(UniformRandomProvider rng, double mean) { http://git-wip-us.apache.org/repos/asf/commons-rng/blob/fa38dea5/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java ---------------------------------------------------------------------- diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java new file mode 100644 index 0000000..2d361ff --- /dev/null +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java @@ -0,0 +1,331 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.distribution.LargeMeanPoissonSampler.LargeMeanPoissonSamplerState; + +/** + * Create a sampler for the + * <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson + * distribution</a> using a cache to minimise construction cost.
+ * + * <p>The cache will return a sampler equivalent to + * {@link PoissonSampler#PoissonSampler(UniformRandomProvider, double)}. + * + * <p>The cache allows the {@link PoissonSampler} construction cost to be minimised + * for small Poisson sample sizes. The cache stores state for a range of integers where + * integer value {@code n} can be used to construct a sampler for the range + * {@code n <= mean < n+1}. + * + * <p>The cache is advantageous under the following conditions: + * + * <ul> + * <li>The mean of the Poisson distribution falls within a known range. + * <li>The sample size to be made with the <strong>same</strong> sampler is + * small. + * <li>The Poisson samples have different means with the same integer + * value(s) after rounding down. + * </ul> + * + * <p>If the sample size to be made with the <strong>same</strong> sampler is large + * then the construction cost is low compared to the sampling time and the cache + * has minimal benefit. + * + * <p>Performance improvement is dependent on the speed of the + * {@link UniformRandomProvider}. A fast provider can obtain a two-fold speed + * improvement for a single-use Poisson sampler. + * + * <p>The cache is thread safe. Note that concurrent threads using the cache + * must ensure a thread safe {@link UniformRandomProvider} is used when creating + * samplers, e.g. a unique sampler per thread. + */ +public class PoissonSamplerCache { + + /** + * The minimum N covered by the cache where + * {@code N = (int)Math.floor(mean)}. + */ + private final int minN; + /** + * The maximum N covered by the cache where + * {@code N = (int)Math.floor(mean)}. + */ + private final int maxN; + /** The cache of states between {@link #minN} and {@link #maxN}. */ + private final LargeMeanPoissonSamplerState[] values; + + /** + * @param minMean The minimum mean covered by the cache. + * @param maxMean The maximum mean covered by the cache.
+ * @throws IllegalArgumentException if {@code maxMean < minMean} or either mean is NaN + */ + public PoissonSamplerCache(double minMean, + double maxMean) { + checkMeanRange(minMean, maxMean); + + // The cache can only be used for the LargeMeanPoissonSampler. + if (maxMean < PoissonSampler.PIVOT) { + // The upper limit is too small so no cache will be used. + // This class will just construct new samplers. + minN = 0; + maxN = 0; + values = null; + } else { + // Convert the mean into integers. + // Note the minimum is clipped to the algorithm switch point. + this.minN = (int) Math.floor(Math.max(minMean, PoissonSampler.PIVOT)); + this.maxN = (int) Math.floor(Math.min(maxMean, Integer.MAX_VALUE)); + values = new LargeMeanPoissonSamplerState[maxN - minN + 1]; + } + } + + /** + * @param minN The minimum N covered by the cache where {@code N = (int)Math.floor(mean)}. + * @param maxN The maximum N covered by the cache where {@code N = (int)Math.floor(mean)}. + * @param states The precomputed states. + */ + private PoissonSamplerCache(int minN, + int maxN, + LargeMeanPoissonSamplerState[] states) { + this.minN = minN; + this.maxN = maxN; + this.values = states.clone(); + } + + /** + * Check the mean range. + * + * @param minMean The minimum mean covered by the cache. + * @param maxMean The maximum mean covered by the cache. + * @throws IllegalArgumentException if {@code maxMean < minMean} or either mean is NaN + */ + private static void checkMeanRange(double minMean, double maxMean) { + // Note: + // Although a mean of 0 is invalid for a Poisson sampler this case + // is handled to make the cache user friendly. Any low means will + // be handled by the SmallMeanPoissonSampler and not cached. + // For this reason it is also OK if the means are negative. + // NaN is rejected: it is unordered so would otherwise pass the range check. + if (Double.isNaN(minMean) || Double.isNaN(maxMean)) { + throw new IllegalArgumentException("Mean range contains NaN: " + minMean + ", " + maxMean); + } + + // Allow minMean == maxMean so that the cache can be used + // to create samplers with distinct RNGs and the same mean.
+ if (maxMean < minMean) { + throw new IllegalArgumentException( + "Max mean: " + maxMean + " < " + minMean); + } + } + + /** + * Creates a new Poisson sampler. + * + * <p>The returned sampler will function exactly the + * same as {@link PoissonSampler#PoissonSampler(UniformRandomProvider, double)}. + * + * @param rng Generator of uniformly distributed random numbers. + * @param mean Mean. + * @return A Poisson sampler + * @throws IllegalArgumentException if {@code mean <= 0} or + * {@code mean >} {@link Integer#MAX_VALUE}. + */ + public DiscreteSampler createPoissonSampler(UniformRandomProvider rng, + double mean) { + // Ensure the same functionality as the PoissonSampler by + // using a SmallMeanPoissonSampler under the switch point. + if (mean < PoissonSampler.PIVOT) { + return new SmallMeanPoissonSampler(rng, mean); + } + // The algorithm is not valid if Math.floor(mean) is not an integer. + if (mean > Integer.MAX_VALUE) { + throw new IllegalArgumentException(mean + " > " + Integer.MAX_VALUE); + } + + // Convert the mean into an integer. + final int n = (int) Math.floor(mean); + // Check maxN first as the cache is likely to be used from min=0 + if (n > maxN || n < minN) { + // Outside the range of the cache. + return new LargeMeanPoissonSampler(rng, mean); + } + + // Look in the cache for a state that can be reused. + // Note: The cache is offset by minN. + final int index = n - minN; + LargeMeanPoissonSamplerState state = values[index]; + if (state == null) { + // Compute and store for reuse. + // Do not worry about thread contention + // as the state is effectively immutable. + // If recomputed and replaced it will be the same.
+ state = new LargeMeanPoissonSamplerState(n); + values[index] = state; + } + // Compute the remaining fraction of the mean + final double lambdaFractional = mean - n; + return new LargeMeanPoissonSampler(rng, state, lambdaFractional); + } + + /** + * Check if the mean is within the range where the cache can minimise the + * construction cost of the {@link PoissonSampler}. + * + * @param mean + * the mean + * @return true, if within the cache range + */ + public boolean withinRange(double mean) { + if (mean < PoissonSampler.PIVOT) { + // Construction is optimal + return true; + } + // Convert the mean into an integer. + final int n = (int) Math.floor(mean); + return n <= maxN && n >= minN; + } + + /** + * Checks if the cache covers a valid range of mean values. + * + * <p>Note that the cache is only valid for one of the Poisson sampling + * algorithms. In the instance that a range was requested that was too + * low then there is nothing to cache and this function returns + * {@code false}. + * + * <p>The cache can still be used to create a {@link PoissonSampler} using + * {@link #createPoissonSampler(UniformRandomProvider, double)}. + * + * <p>This method can be used to determine if the cache has a potential + * performance benefit. + * + * @return true, if the cache covers a range of mean values + */ + public boolean isValidRange() { + return values != null; + } + + /** + * Gets the minimum mean covered by the cache. + * + * <p>This value is the inclusive lower bound and is equal to + * the lowest integer-valued mean that is covered by the cache. + * + * <p>Note that this value may not match the value passed to the constructor + * due to the following reasons: + * + * <ul> + * <li>At small mean values a different algorithm is used for Poisson + * sampling and the cache is unnecessary. + * <li>The minimum is always an integer so may be below the constructor + * minimum mean.
+ * </ul> + * + * <p>If {@link #isValidRange()} returns {@code true} the cache will store + * state to reduce construction cost of samplers in + * the range {@link #getMinMean()} inclusive to {@link #getMaxMean()} + * inclusive. Otherwise this method returns 0. + * + * @return The minimum mean covered by the cache. + */ + public double getMinMean() { + return minN; + } + + /** + * Gets the maximum mean covered by the cache. + * + * <p>This value is the inclusive upper bound and is equal to + * the double value below the first integer-valued mean that is + * above the range covered by the cache. + * + * <p>Note that this value may not match the value passed to the constructor + * due to the following reasons: + * <ul> + * <li>At small mean values a different algorithm is used for Poisson + * sampling and the cache is unnecessary. + * <li>The maximum is always the double value below an integer so + * may be above the constructor maximum mean. + * </ul> + * + * <p>If {@link #isValidRange()} returns {@code true} the cache will store + * state to reduce construction cost of samplers in + * the range {@link #getMinMean()} inclusive to {@link #getMaxMean()} + * inclusive. Otherwise this method returns 0. + * + * @return The maximum mean covered by the cache. + */ + public double getMaxMean() { + if (isValidRange()) { + return Math.nextAfter(maxN + 1.0, -1); + } + return 0; + } + + /** + * Create a new {@link PoissonSamplerCache} with the given range + * reusing the current cache values. + * + * <p>This will create a new object even if the range is smaller or the + * same as the current cache. + * + * @param minMean The minimum mean covered by the cache. + * @param maxMean The maximum mean covered by the cache.
+ * @return the new cache covering the given range + * @throws IllegalArgumentException if {@code maxMean < minMean} or either mean is NaN + */ + public PoissonSamplerCache withRange(double minMean, + double maxMean) { + if (values == null) { + // Nothing to reuse + return new PoissonSamplerCache(minMean, maxMean); + } + checkMeanRange(minMean, maxMean); + + // The cache can only be used for the LargeMeanPoissonSampler. + if (maxMean < PoissonSampler.PIVOT) { + return new PoissonSamplerCache(0, 0); + } + + // Convert the mean into integers. + // Note the minimum is clipped to the algorithm switch point. + final int withMinN = (int) Math.floor(Math.max(minMean, PoissonSampler.PIVOT)); + final int withMaxN = (int) Math.floor(Math.min(maxMean, Integer.MAX_VALUE)); + final LargeMeanPoissonSamplerState[] states = + new LargeMeanPoissonSamplerState[withMaxN - withMinN + 1]; + + // Preserve values from the current array to the next + int currentIndex; + int nextIndex; + if (this.minN <= withMinN) { + // The current array starts before the new array + currentIndex = withMinN - this.minN; + nextIndex = 0; + } else { + // The new array starts before the current array + currentIndex = 0; + nextIndex = this.minN - withMinN; + } + final int length = Math.min(values.length - currentIndex, states.length - nextIndex); + if (length > 0) { + System.arraycopy(values, currentIndex, states, nextIndex, length); + } + + return new PoissonSamplerCache(withMinN, withMaxN, states); + } +} http://git-wip-us.apache.org/repos/asf/commons-rng/blob/fa38dea5/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java ---------------------------------------------------------------------- diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java new file mode 100644 index 0000000..2b1e24e --- /dev/null +++
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import org.apache.commons.rng.RandomProviderState; +import org.apache.commons.rng.RestorableUniformRandomProvider; +import org.apache.commons.rng.sampling.distribution.LargeMeanPoissonSampler.LargeMeanPoissonSamplerState; +import org.apache.commons.rng.simple.RandomSource; +import org.junit.Assert; +import org.junit.Test; + +/** + * This test checks the {@link LargeMeanPoissonSampler} using the + * {@link LargeMeanPoissonSamplerState}. + */ +public class LargeMeanPoissonSamplerTest { + + // Edge cases for construction + + /** + * Test the constructor with a zero mean. + */ + @Test(expected=IllegalArgumentException.class) + public void testConstructorThrowsWithZeroMean() { + final RestorableUniformRandomProvider rng = + RandomSource.create(RandomSource.SPLIT_MIX_64); + @SuppressWarnings("unused") + LargeMeanPoissonSampler sampler = new LargeMeanPoissonSampler(rng, 0); + } + + /** + * Test the constructor with a mean that is too large.
+ */ + @Test(expected=IllegalArgumentException.class) + public void testConstructorThrowsWithNonIntegerMean() { + final RestorableUniformRandomProvider rng = + RandomSource.create(RandomSource.SPLIT_MIX_64); + final double mean = Integer.MAX_VALUE + 1.0; + @SuppressWarnings("unused") + LargeMeanPoissonSampler sampler = new LargeMeanPoissonSampler(rng, mean); + } + + /** + * Test the state cannot be created with a negative n. + */ + @Test(expected=IllegalArgumentException.class) + public void testStateCreationThrowsWithNegativeN() { + @SuppressWarnings("unused") + LargeMeanPoissonSamplerState state = new LargeMeanPoissonSamplerState(-1); + } + + /** + * Test the constructor with a negative fractional mean. + */ + @Test(expected=IllegalArgumentException.class) + public void testConstructorThrowsWithNegativeFractionalMean() { + final RestorableUniformRandomProvider rng = + RandomSource.create(RandomSource.SPLIT_MIX_64); + LargeMeanPoissonSamplerState state = new LargeMeanPoissonSamplerState(0); + @SuppressWarnings("unused") + LargeMeanPoissonSampler sampler = new LargeMeanPoissonSampler(rng, state, -0.1); + } + + /** + * Test the constructor with a non-fractional mean. + */ + @Test(expected=IllegalArgumentException.class) + public void testConstructorThrowsWithNonFractionalMean() { + final RestorableUniformRandomProvider rng = + RandomSource.create(RandomSource.SPLIT_MIX_64); + LargeMeanPoissonSamplerState state = new LargeMeanPoissonSamplerState(0); + @SuppressWarnings("unused") + LargeMeanPoissonSampler sampler = new LargeMeanPoissonSampler(rng, state, 1.1); + } + + /** + * Test the constructor with fractional mean of 1.
+ */ + @Test(expected=IllegalArgumentException.class) + public void testConstructorThrowsWithFractionalMeanOne() { + final RestorableUniformRandomProvider rng = + RandomSource.create(RandomSource.SPLIT_MIX_64); + LargeMeanPoissonSamplerState state = new LargeMeanPoissonSamplerState(0); + @SuppressWarnings("unused") + LargeMeanPoissonSampler sampler = new LargeMeanPoissonSampler(rng, state, 1); + } + + // Sampling tests + + /** + * Test the {@link LargeMeanPoissonSampler} returns the same samples when it + * is created using the {@link LargeMeanPoissonSamplerState}. + */ + @Test + public void testCanComputeSameSamplesWhenConstructedWithState() { + // Two identical RNGs + final RestorableUniformRandomProvider rng1 = + RandomSource.create(RandomSource.MWC_256); + final RandomProviderState state = rng1.saveState(); + final RestorableUniformRandomProvider rng2 = + RandomSource.create(RandomSource.MWC_256); + rng2.restoreState(state); + + // The PoissonSampler uses this sampler for mean >= 40 + for (int i = 40; i < 44; i++) { + // Test integer mean (no SmallMeanPoissonSampler required) + testPoissonSamples(rng1, rng2, i); + // Test non-integer mean (SmallMeanPoissonSampler required) + testPoissonSamples(rng1, rng2, i + 0.5); + } + } + + /** + * Test Poisson samples are the same from a {@link LargeMeanPoissonSampler} + * constructed using the mean and using the equivalent state. The random + * providers must be identical (including state).
+ * + * @param rng1 the first random provider + * @param rng2 the second random provider + * @param mean the mean + */ + private static void testPoissonSamples( + final RestorableUniformRandomProvider rng1, + final RestorableUniformRandomProvider rng2, + double mean) { + final DiscreteSampler s1 = new LargeMeanPoissonSampler(rng1, mean); + final int n = (int) Math.floor(mean); + final double lambdaFractional = mean - n; + final LargeMeanPoissonSamplerState state = new LargeMeanPoissonSamplerState(n); + Assert.assertEquals("Not the correct lambda", n, state.getLambda()); + final DiscreteSampler s2 = new LargeMeanPoissonSampler(rng2, state, lambdaFractional); + for (int j = 0; j < 10; j++) { + Assert.assertEquals("Not the same sample", s1.sample(), s2.sample()); + } + } +} http://git-wip-us.apache.org/repos/asf/commons-rng/blob/fa38dea5/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCacheTest.java ---------------------------------------------------------------------- diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCacheTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCacheTest.java new file mode 100644 index 0000000..e4df43a --- /dev/null +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCacheTest.java @@ -0,0 +1,481 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import org.apache.commons.rng.RandomProviderState;
+import org.apache.commons.rng.RestorableUniformRandomProvider;
+import org.apache.commons.rng.simple.RandomSource;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * This test checks the {@link PoissonSamplerCache} functions exactly like the
+ * constructor of the {@link PoissonSampler}, irrespective of the range
+ * covered by the cache.
+ */
+public class PoissonSamplerCacheTest {
+
+    // Set a range so that the SmallMeanPoissonSampler is also required:
+    // the test range straddles PoissonSampler.PIVOT so means below the
+    // pivot (which never use the cache) are tested alongside those above it.
+
+    /** The minimum of the range of the mean. */
+    private final int minRange = (int) Math.floor(PoissonSampler.PIVOT - 2);
+    /** The maximum of the range of the mean. */
+    private final int maxRange = (int) Math.floor(PoissonSampler.PIVOT + 6);
+    /** The mid-point of the range of the mean. */
+    private final int midRange = (minRange + maxRange) / 2;
+
+    // Edge cases for construction
+
+    /**
+     * Test the cache can be created without a range that requires a cache.
+     * In this case the cache will be a pass through to the constructor
+     * of the SmallMeanPoissonSampler.
+     */
+    @Test
+    public void testConstructorWithNoCache() {
+        final double min = 0;
+        final double max = PoissonSampler.PIVOT - 2;
+        final PoissonSamplerCache cache = createPoissonSamplerCache(min, max);
+        Assert.assertFalse(cache.isValidRange());
+        Assert.assertEquals(0, cache.getMinMean(), 0);
+        Assert.assertEquals(0, cache.getMaxMean(), 0);
+    }
+
+    /**
+     * Test the cache can be created with a range of 1.
+     * In this case the cache will be valid for all mean values
+     * in the range {@code n <= mean < n+1}.
+     */
+    @Test
+    public void testConstructorWhenMaxEqualsMin() {
+        final double min = PoissonSampler.PIVOT + 2;
+        final double max = min;
+        final PoissonSamplerCache cache = createPoissonSamplerCache(min, max);
+        Assert.assertTrue(cache.isValidRange());
+        Assert.assertEquals(min, cache.getMinMean(), 0);
+        Assert.assertEquals(Math.nextAfter(Math.floor(max) + 1, -1),
+                            cache.getMaxMean(), 0);
+    }
+
+    /**
+     * Test the cache can be created with a range where {@code max > min}.
+     * In this case the cache will be valid for all mean values
+     * in the range {@code min <= mean < floor(max)+1}.
+     */
+    @Test
+    public void testConstructorWhenMaxAboveMin() {
+        final double min = PoissonSampler.PIVOT + 2;
+        final double max = min + 10;
+        final PoissonSamplerCache cache = createPoissonSamplerCache(min, max);
+        Assert.assertTrue(cache.isValidRange());
+        Assert.assertEquals(min, cache.getMinMean(), 0);
+        Assert.assertEquals(Math.nextAfter(Math.floor(max) + 1, -1),
+                            cache.getMaxMean(), 0);
+    }
+
+    /**
+     * Test the cache requires a range with {@code max >= min}.
+     */
+    @Test(expected=IllegalArgumentException.class)
+    public void testConstructorThrowsWhenMaxIsLessThanMin() {
+        final double min = PoissonSampler.PIVOT;
+        final double max = Math.nextAfter(min, -1);
+        createPoissonSamplerCache(min, max);
+    }
+
+    /**
+     * Test the cache can be created with a min range below 0.
+     * In this case the minimum of the cache range is raised to
+     * {@link PoissonSampler#PIVOT}.
+     */
+    @Test
+    public void testConstructorWhenMinBelow0() {
+        final double min = -1;
+        final double max = PoissonSampler.PIVOT + 2;
+        final PoissonSamplerCache cache = createPoissonSamplerCache(min, max);
+        Assert.assertTrue(cache.isValidRange());
+        Assert.assertEquals(PoissonSampler.PIVOT, cache.getMinMean(), 0);
+        Assert.assertEquals(Math.nextAfter(Math.floor(max) + 1, -1),
+                            cache.getMaxMean(), 0);
+    }
+
+    /**
+     * Test the cache can be created with a max range below 0.
+     * In this case the range is truncated to 0, i.e. no cache.
+     */
+    @Test
+    public void testConstructorWhenMaxBelow0() {
+        final double min = -10;
+        final double max = -1;
+        final PoissonSamplerCache cache = createPoissonSamplerCache(min, max);
+        Assert.assertFalse(cache.isValidRange());
+        Assert.assertEquals(0, cache.getMinMean(), 0);
+        Assert.assertEquals(0, cache.getMaxMean(), 0);
+    }
+
+    /**
+     * Test the cache can be created without a range that requires a cache.
+     * In this case the cache will be a pass through to the constructor
+     * of the SmallMeanPoissonSampler.
+     */
+    @Test
+    public void testWithRangeConstructorWithNoCache() {
+        final double min = 0;
+        final double max = PoissonSampler.PIVOT - 2;
+        final PoissonSamplerCache cache = createPoissonSamplerCache().withRange(min, max);
+        Assert.assertFalse(cache.isValidRange());
+        Assert.assertEquals(0, cache.getMinMean(), 0);
+        Assert.assertEquals(0, cache.getMaxMean(), 0);
+    }
+
+    /**
+     * Test the cache can be created with a range of 1.
+     * In this case the cache will be valid for all mean values
+     * in the range {@code n <= mean < n+1}.
+     */
+    @Test
+    public void testWithRangeConstructorWhenMaxEqualsMin() {
+        final double min = PoissonSampler.PIVOT + 2;
+        final double max = min;
+        final PoissonSamplerCache cache = createPoissonSamplerCache().withRange(min, max);
+        Assert.assertTrue(cache.isValidRange());
+        Assert.assertEquals(min, cache.getMinMean(), 0);
+        Assert.assertEquals(Math.nextAfter(Math.floor(max) + 1, -1),
+                            cache.getMaxMean(), 0);
+    }
+
+    /**
+     * Test the cache can be created with a range where {@code max > min}.
+     * In this case the cache will be valid for all mean values
+     * in the range {@code min <= mean < floor(max)+1}.
+     */
+    @Test
+    public void testWithRangeConstructorWhenMaxAboveMin() {
+        final double min = PoissonSampler.PIVOT + 2;
+        final double max = min + 10;
+        final PoissonSamplerCache cache = createPoissonSamplerCache().withRange(min, max);
+        Assert.assertTrue(cache.isValidRange());
+        Assert.assertEquals(min, cache.getMinMean(), 0);
+        Assert.assertEquals(Math.nextAfter(Math.floor(max) + 1, -1),
+                            cache.getMaxMean(), 0);
+    }
+
+    /**
+     * Test the cache requires a range with {@code max >= min}.
+     */
+    @Test(expected=IllegalArgumentException.class)
+    public void testWithRangeConstructorThrowsWhenMaxIsLessThanMin() {
+        final double min = PoissonSampler.PIVOT;
+        final double max = Math.nextAfter(min, -1);
+        createPoissonSamplerCache().withRange(min, max);
+    }
+
+    /**
+     * Test the cache can be created with a min range below 0.
+     * In this case the minimum of the cache range is raised to
+     * {@link PoissonSampler#PIVOT}.
+     */
+    @Test
+    public void testWithRangeConstructorWhenMinBelow0() {
+        final double min = -1;
+        final double max = PoissonSampler.PIVOT + 2;
+        final PoissonSamplerCache cache = createPoissonSamplerCache().withRange(min, max);
+        Assert.assertTrue(cache.isValidRange());
+        Assert.assertEquals(PoissonSampler.PIVOT, cache.getMinMean(), 0);
+        Assert.assertEquals(Math.nextAfter(Math.floor(max) + 1, -1),
+                            cache.getMaxMean(), 0);
+    }
+
+    /**
+     * Test a cache with a valid range can be created using
+     * {@link PoissonSamplerCache#withRange(double, double)} from a cache
+     * with no capacity.
+     */
+    @Test
+    public void testWithRangeConstructorWhenCacheHasNoCapcity() {
+        final double min = PoissonSampler.PIVOT + 2;
+        final double max = min + 10;
+        final PoissonSamplerCache cache = createPoissonSamplerCache(0, 0).withRange(min, max);
+        Assert.assertTrue(cache.isValidRange());
+        Assert.assertEquals(min, cache.getMinMean(), 0);
+        Assert.assertEquals(Math.nextAfter(Math.floor(max) + 1, -1),
+                            cache.getMaxMean(), 0);
+    }
+
+    /**
+     * Test the withinRange function of the cache signals when construction
+     * cost is minimal.
+     */
+    @Test
+    public void testWithinRange() {
+        final double min = PoissonSampler.PIVOT + 10;
+        final double max = PoissonSampler.PIVOT + 20;
+        final PoissonSamplerCache cache = createPoissonSamplerCache(min, max);
+        // Under the pivot point is always within range
+        Assert.assertTrue(cache.withinRange(PoissonSampler.PIVOT - 1));
+        Assert.assertFalse(cache.withinRange(min - 1));
+        Assert.assertTrue(cache.withinRange(min));
+        Assert.assertTrue(cache.withinRange(max));
+        Assert.assertFalse(cache.withinRange(max + 10));
+    }
+
+    // Edge cases for creating a Poisson sampler
+
+    /**
+     * Test createPoissonSampler() with a bad mean.
+     *
+     * <p>Note this test actually tests the SmallMeanPoissonSampler throws.
+     */
+    @Test(expected=IllegalArgumentException.class)
+    public void testCreatePoissonSamplerThrowsWithZeroMean() {
+        final RestorableUniformRandomProvider rng =
+                RandomSource.create(RandomSource.SPLIT_MIX_64);
+        final PoissonSamplerCache cache = createPoissonSamplerCache();
+        cache.createPoissonSampler(rng, 0);
+    }
+
+    /**
+     * Test createPoissonSampler() throws with a mean that is too large
+     * ({@code Integer.MAX_VALUE + 1.0}).
+     *
+     * <p>Note: despite the method name the mean used here is integral;
+     * the failure is caused by its magnitude.
+     */
+    @Test(expected=IllegalArgumentException.class)
+    public void testCreatePoissonSamplerThrowsWithNonIntegerMean() {
+        final RestorableUniformRandomProvider rng =
+                RandomSource.create(RandomSource.SPLIT_MIX_64);
+        final PoissonSamplerCache cache = createPoissonSamplerCache();
+        final double mean = Integer.MAX_VALUE + 1.0;
+        cache.createPoissonSampler(rng, mean);
+    }
+
+    // Sampling tests
+
+    /**
+     * Test the cache returns the same samples as the PoissonSampler when it
+     * covers the entire range.
+     */
+    @Test
+    public void testCanComputeSameSamplesAsPoissonSamplerWithFullRangeCache() {
+        checkComputeSameSamplesAsPoissonSampler(minRange,
+                                                maxRange);
+    }
+
+    /**
+     * Test the cache returns the same samples as the PoissonSampler
+     * with no cache.
+     */
+    @Test
+    public void testCanComputeSameSamplesAsPoissonSamplerWithNoCache() {
+        checkComputeSameSamplesAsPoissonSampler(0,
+                                                minRange - 2);
+    }
+
+    /**
+     * Test the cache returns the same samples as the PoissonSampler with
+     * partial cache covering the lower range.
+     */
+    @Test
+    public void testCanComputeSameSamplesAsPoissonSamplerWithPartialCacheCoveringLowerRange() {
+        checkComputeSameSamplesAsPoissonSampler(minRange,
+                                                midRange);
+    }
+
+    /**
+     * Test the cache returns the same samples as the PoissonSampler with
+     * partial cache covering the upper range.
+     */
+    @Test
+    public void testCanComputeSameSamplesAsPoissonSamplerWithPartialCacheCoveringUpperRange() {
+        checkComputeSameSamplesAsPoissonSampler(midRange,
+                                                maxRange);
+    }
+
+    /**
+     * Test the cache returns the same samples as the PoissonSampler with
+     * cache above the upper range.
+     */
+    @Test
+    public void testCanComputeSameSamplesAsPoissonSamplerWithCacheAboveTheUpperRange() {
+        checkComputeSameSamplesAsPoissonSampler(maxRange + 10,
+                                                maxRange + 20);
+    }
+
+    /**
+     * Check Poisson samples are the same from the {@link PoissonSampler}
+     * and {@link PoissonSamplerCache}.
+     *
+     * @param minMean the min mean of the cache
+     * @param maxMean the max mean of the cache
+     */
+    private void checkComputeSameSamplesAsPoissonSampler(int minMean,
+                                                         int maxMean) {
+        // Two identical RNGs
+        final RestorableUniformRandomProvider rng1 =
+                RandomSource.create(RandomSource.WELL_19937_C);
+        final RandomProviderState state = rng1.saveState();
+        final RestorableUniformRandomProvider rng2 =
+                RandomSource.create(RandomSource.WELL_19937_C);
+        rng2.restoreState(state);
+
+        // Create the cache with the given range
+        final PoissonSamplerCache cache =
+                createPoissonSamplerCache(minMean, maxMean);
+        // Test all means in the test range (which may be different
+        // from the cache range).
+        for (int i = minRange; i <= maxRange; i++) {
+            // Test integer mean (no SmallMeanPoissonSampler required)
+            testPoissonSamples(rng1, rng2, cache, i);
+            // Test non-integer mean (SmallMeanPoissonSampler required)
+            testPoissonSamples(rng1, rng2, cache, i + 0.5);
+        }
+    }
+
+    /**
+     * Creates the Poisson sampler cache with the given range.
+     *
+     * @param minMean the min mean
+     * @param maxMean the max mean
+     * @return the Poisson sampler cache
+     */
+    private static PoissonSamplerCache createPoissonSamplerCache(double minMean,
+                                                                 double maxMean) {
+        return new PoissonSamplerCache(minMean, maxMean);
+    }
+
+    /**
+     * Creates a Poisson sampler cache that will have a valid range for the cache
+     * ({@code PIVOT} to {@code PIVOT + 10}).
+     *
+     * @return the Poisson sampler cache
+     */
+    private static PoissonSamplerCache createPoissonSamplerCache() {
+        return new PoissonSamplerCache(PoissonSampler.PIVOT,
+                                       PoissonSampler.PIVOT + 10);
+    }
+
+    /**
+     * Test Poisson samples are the same from the {@link PoissonSampler}
+     * and {@link PoissonSamplerCache}. The random providers must be
+     * identical (including state).
+     *
+     * @param rng1 the first random provider (used by the PoissonSampler)
+     * @param rng2 the second random provider (used by the cache sampler)
+     * @param cache the cache
+     * @param mean the mean
+     */
+    private static void testPoissonSamples(
+            final RestorableUniformRandomProvider rng1,
+            final RestorableUniformRandomProvider rng2,
+            final PoissonSamplerCache cache,
+            final double mean) {
+        final DiscreteSampler s1 = new PoissonSampler(rng1, mean);
+        final DiscreteSampler s2 = cache.createPoissonSampler(rng2, mean);
+        for (int j = 0; j < 10; j++) {
+            Assert.assertEquals(s1.sample(), s2.sample());
+        }
+    }
+
+    /**
+     * Test the cache returns the same samples as the PoissonSampler with
+     * a new cache reusing the entire range.
+     */
+    @Test
+    public void testCanComputeSameSamplesAsPoissonSamplerReusingCacheEntireRange() {
+        checkComputeSameSamplesAsPoissonSamplerReusingCache(midRange,
+                                                            maxRange,
+                                                            midRange,
+                                                            maxRange);
+    }
+
+    /**
+     * Test the cache returns the same samples as the PoissonSampler with
+     * a new cache reusing none of the range.
+     */
+    @Test
+    public void testCanComputeSameSamplesAsPoissonSamplerReusingCacheNoRange() {
+        checkComputeSameSamplesAsPoissonSamplerReusingCache(midRange,
+                                                            maxRange,
+                                                            maxRange + 10,
+                                                            maxRange + 20);
+    }
+
+    /**
+     * Test the cache returns the same samples as the PoissonSampler with
+     * a new cache reusing some of the lower range.
+     */
+    @Test
+    public void testCanComputeSameSamplesAsPoissonSamplerReusingCacheLowerRange() {
+        checkComputeSameSamplesAsPoissonSamplerReusingCache(midRange,
+                                                            maxRange,
+                                                            minRange,
+                                                            midRange + 1);
+    }
+
+    /**
+     * Test the cache returns the same samples as the PoissonSampler with
+     * a new cache reusing some of the upper range.
+     */
+    @Test
+    public void testCanComputeSameSamplesAsPoissonSamplerReusingCacheUpperRange() {
+        checkComputeSameSamplesAsPoissonSamplerReusingCache(midRange,
+                                                            maxRange,
+                                                            maxRange - 1,
+                                                            maxRange + 5);
+    }
+
+    /**
+     * Check Poisson samples are the same from the {@link PoissonSampler}
+     * and a {@link PoissonSamplerCache} created reusing values.
+     *
+     * <p>Note: This cannot check the cache values were reused but ensures
+     * that a new cache created with a range functions correctly.
+     *
+     * @param minMean the min mean of the cache
+     * @param maxMean the max mean of the cache
+     * @param minMean2 the min mean of the second cache
+     * @param maxMean2 the max mean of the second cache
+     */
+    private void checkComputeSameSamplesAsPoissonSamplerReusingCache(int minMean,
+                                                                     int maxMean,
+                                                                     int minMean2,
+                                                                     int maxMean2) {
+        // Two identical RNGs
+        final RestorableUniformRandomProvider rng1 =
+                RandomSource.create(RandomSource.WELL_19937_C);
+        final RandomProviderState state = rng1.saveState();
+        final RestorableUniformRandomProvider rng2 =
+                RandomSource.create(RandomSource.WELL_19937_C);
+
+        // Create the cache with the given range and fill it by creating
+        // a sampler for every integer mean in the cache range.
+        final PoissonSamplerCache cache =
+                createPoissonSamplerCache(minMean, maxMean);
+        for (int i = minMean; i <= maxMean; i++) {
+            cache.createPoissonSampler(rng1, i);
+        }
+
+        final PoissonSamplerCache cache2 = cache.withRange(minMean2, maxMean2);
+        Assert.assertNotSame("WithRange cache is the same object", cache, cache2);
+
+        // Filling the cache consumed values from rng1: reset both RNGs to
+        // the same state before comparing samples.
+        rng1.restoreState(state);
+        rng2.restoreState(state);
+
+        // Test all means in the test range (which may be different
+        // from the cache range).
+        for (int i = minRange; i <= maxRange; i++) {
+            // Test integer mean (no SmallMeanPoissonSampler required)
+            testPoissonSamples(rng1, rng2, cache2, i);
+            // Test non-integer mean (SmallMeanPoissonSampler required)
+            testPoissonSamples(rng1, rng2, cache2, i + 0.5);
+        }
+    }
+}