RNG-51: Changed representation of LargeMeanPoissonSampler state Project: http://git-wip-us.apache.org/repos/asf/commons-rng/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-rng/commit/db751b91 Tree: http://git-wip-us.apache.org/repos/asf/commons-rng/tree/db751b91 Diff: http://git-wip-us.apache.org/repos/asf/commons-rng/diff/db751b91
Branch: refs/heads/master Commit: db751b911b56fecff1bce1ce33320e4308158c6a Parents: fa38dea Author: Alex Herbert <a.herb...@sussex.ac.uk> Authored: Fri Sep 21 00:56:19 2018 +0100 Committer: Alex Herbert <a.herb...@sussex.ac.uk> Committed: Fri Sep 21 00:56:19 2018 +0100 ---------------------------------------------------------------------- .../distribution/LargeMeanPoissonSampler.java | 244 ++++++++++++------- .../distribution/PoissonSamplerCache.java | 12 +- .../LargeMeanPoissonSamplerTest.java | 33 +-- 3 files changed, 180 insertions(+), 109 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-rng/blob/db751b91/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java ---------------------------------------------------------------------- diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java index 0beb6b4..802c5be 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java @@ -92,79 +92,6 @@ public class LargeMeanPoissonSampler private final DiscreteSampler smallMeanPoissonSampler; /** - * Encapsulate the state of the sampler. The state is valid for construction of - * a sampler in the range {@code lambda <= mean < lambda+1}. - */ - static class LargeMeanPoissonSamplerState { - /** Algorithm constant: {@code Math.floor(mean)}. */ - private final double lambda; - /** Algorithm constant: {@code Math.log(lambda)}. */ - private final double logLambda; - /** Algorithm constant: {@code factorialLog((int) lambda)}. */ - private final double logLambdaFactorial; - /** Algorithm constant: {@code Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1))}. */ - private final double delta; - /** Algorithm constant: {@code delta / 2}. */ - private final double halfDelta; - /** Algorithm constant: {@code 2 * lambda + delta}. */ - private final double twolpd; - /** - * Algorithm constant: {@code a1 / aSum} with - * <ul> - * <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li> - * <li>{@code aSum = a1 + a2 + 1}</li> - * </ul> - */ - private final double p1; - /** - * Algorithm constant: {@code a2 / aSum} with - * <ul> - * <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li> - * <li>{@code aSum = a1 + a2 + 1}</li> - * </ul> - */ - private final double p2; - /** Algorithm constant: {@code 1 / (8 * lambda)}. */ - private final double c1; - - /** - * Creates the state. The state is valid for construction of a sampler in the - * range {@code n <= mean < n+1}. - * - * @param n the value n ({@code floor(mean)}) - * @throws IllegalArgumentException if {@code n < 0}. - */ - LargeMeanPoissonSamplerState(int n) { - if (n < 0) { - throw new IllegalArgumentException(n + " < " + 0); - } - // Cache values used in the algorithm - // This is deliberately a copy of the code in the - // LargeMeanPoissonSampler constructor. - lambda = n; - logLambda = Math.log(lambda); - logLambdaFactorial = NO_CACHE_FACTORIAL_LOG.value(n); - delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1)); - halfDelta = delta / 2; - twolpd = 2 * lambda + delta; - c1 = 1 / (8 * lambda); - final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1); - final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd); - final double aSum = a1 + a2 + 1; - p1 = a1 / aSum; - p2 = a2 / aSum; - } - - /** - * Get the lambda value for the state. Equal to {@code floor(mean)}. - * @return {@code floor(mean)} - */ - int getLambda() { - return (int) lambda; - } - } - - /** * @param rng Generator of uniformly distributed random numbers. * @param mean Mean. * @throws IllegalArgumentException if {@code mean <= 0} or @@ -188,6 +115,7 @@ public class LargeMeanPoissonSampler // Cache values used in the algorithm lambda = Math.floor(mean); + lambdaFractional = mean - lambda; logLambda = Math.log(lambda); logLambdaFactorial = factorialLog((int) lambda); delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1)); @@ -201,7 +129,6 @@ public class LargeMeanPoissonSampler p2 = a2 / aSum; // The algorithm requires a Poisson sample from the remaining lambda fraction. - lambdaFractional = mean - lambda; smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ? null : // Not used. new SmallMeanPoissonSampler(rng, lambdaFractional); @@ -232,18 +159,18 @@ public class LargeMeanPoissonSampler factorialLog = NO_CACHE_FACTORIAL_LOG; // Use the state to initialise the algorithm - lambda = state.lambda; - logLambda = state.logLambda; - logLambdaFactorial = state.logLambdaFactorial; - delta = state.delta; - halfDelta = state.halfDelta; - twolpd = state.twolpd; - p1 = state.p1; - p2 = state.p2; - c1 = state.c1; + lambda = state.getLambdaRaw(); + this.lambdaFractional = lambdaFractional; + logLambda = state.getLogLambda(); + logLambdaFactorial = state.getLogLambdaFactorial(); + delta = state.getDelta(); + halfDelta = state.getHalfDelta(); + twolpd = state.getTwolpd(); + p1 = state.getP1(); + p2 = state.getP2(); + c1 = state.getC1(); // The algorithm requires a Poisson sample from the remaining lambda fraction. - this.lambdaFractional = lambdaFractional; smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ? null : // Not used. new SmallMeanPoissonSampler(rng, lambdaFractional); @@ -324,4 +251,153 @@ public class LargeMeanPoissonSampler public String toString() { return "Large Mean Poisson deviate [" + super.toString() + "]"; } + + /** + * Gets the initialisation state of the sampler. + * + * <p>The state is computed using an integer {@code lambda} value of + * {@code lambda = (int)Math.floor(mean)}. + * + * <p>The state will be suitable for reconstructing a new sampler with a mean + * in the range {@code lambda <= mean < lambda+1} using + * {@link #LargeMeanPoissonSampler(UniformRandomProvider, LargeMeanPoissonSamplerState, double)}. + * + * @return the state + */ + LargeMeanPoissonSamplerState getState() { + return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial, + delta, halfDelta, twolpd, p1, p2, c1); + } + + /** + * Encapsulate the state of the sampler. The state is valid for construction of + * a sampler in the range {@code lambda <= mean < lambda+1}. + * + * <p>This class is immutable. + * + * @see #getLambda() + */ + static class LargeMeanPoissonSamplerState { + /** Algorithm constant {@code lambda}. */ + private final double lambda; + /** Algorithm constant {@code logLambda}. */ + private final double logLambda; + /** Algorithm constant {@code logLambdaFactorial}. */ + private final double logLambdaFactorial; + /** Algorithm constant {@code delta}. */ + private final double delta; + /** Algorithm constant {@code halfDelta}. */ + private final double halfDelta; + /** Algorithm constant {@code twolpd}. */ + private final double twolpd; + /** Algorithm constant {@code p1}. */ + private final double p1; + /** Algorithm constant {@code p2}. */ + private final double p2; + /** Algorithm constant {@code c1}. */ + private final double c1; + + /** + * Creates the state. + * + * <p>The state is valid for construction of a sampler in the range + * {@code lambda <= mean < lambda+1} where {@code lambda} is an integer. + * + * @param lambda the lambda + * @param logLambda the log lambda + * @param logLambdaFactorial the log lambda factorial + * @param delta the delta + * @param halfDelta the half delta + * @param twolpd the two lambda plus delta + * @param p1 the p1 constant + * @param p2 the p2 constant + * @param c1 the c1 constant + */ + private LargeMeanPoissonSamplerState(double lambda, double logLambda, + double logLambdaFactorial, double delta, double halfDelta, double twolpd, + double p1, double p2, double c1) { + this.lambda = lambda; + this.logLambda = logLambda; + this.logLambdaFactorial = logLambdaFactorial; + this.delta = delta; + this.halfDelta = halfDelta; + this.twolpd = twolpd; + this.p1 = p1; + this.p2 = p2; + this.c1 = c1; + } + + /** + * Get the lambda value for the state. + * + * <p>Equal to {@code floor(mean)} for a Poisson sampler. + * @return the lambda value + */ + int getLambda() { + return (int) getLambdaRaw(); + } + + /** + * @return algorithm constant {@code lambda} + */ + double getLambdaRaw() { + return lambda; + } + + /** + * @return algorithm constant {@code logLambda} + */ + double getLogLambda() { + return logLambda; + } + + /** + * @return algorithm constant {@code logLambdaFactorial} + */ + double getLogLambdaFactorial() { + return logLambdaFactorial; + } + + /** + * @return algorithm constant {@code delta} + */ + double getDelta() { + return delta; + } + + /** + * @return algorithm constant {@code halfDelta} + */ + double getHalfDelta() { + return halfDelta; + } + + /** + * @return algorithm constant {@code twolpd} + */ + double getTwolpd() { + return twolpd; + } + + /** + * @return algorithm constant {@code p1} + */ + double getP1() { + return p1; + } + + /** + * @return algorithm constant {@code p2} + */ + double getP2() { + return p2; + } + + /** + * @return algorithm constant {@code c1} + */ + double getC1() { + return c1; + } + } } http://git-wip-us.apache.org/repos/asf/commons-rng/blob/db751b91/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java ---------------------------------------------------------------------- diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java index 2d361ff..4b74084 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java @@ -105,7 +105,8 @@ public class PoissonSamplerCache { LargeMeanPoissonSamplerState[] states) { this.minN = minN; this.maxN = maxN; - this.values = states.clone(); + // Stored directly as the states were newly created within this class. + this.values = states; } /** @@ -166,14 +167,15 @@ public class PoissonSamplerCache { // Look in the cache for a state that can be reused. // Note: The cache is offset by minN. final int index = n - minN; - LargeMeanPoissonSamplerState state = values[index]; + final LargeMeanPoissonSamplerState state = values[index]; if (state == null) { - // Compute and store for reuse. + // Create a sampler and store the state for reuse. // Do not worry about thread contention // as the state is effectively immutable. // If recomputed and replaced it will the same. - state = new LargeMeanPoissonSamplerState(n); - values[index] = state; + final LargeMeanPoissonSampler sampler = new LargeMeanPoissonSampler(rng, mean); + values[index] = sampler.getState(); + return sampler; } // Compute the remaining fraction of the mean final double lambdaFractional = mean - n; http://git-wip-us.apache.org/repos/asf/commons-rng/blob/db751b91/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java ---------------------------------------------------------------------- diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java index 2b1e24e..480f452 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java @@ -24,8 +24,8 @@ import org.junit.Assert; import org.junit.Test; /** - * This test checks the {@link LargeMeanPoissonSampler} using the - * {@link LargeMeanPoissonSamplerState}. + * This test checks the {@link LargeMeanPoissonSampler} can be created + * from a saved state. */ public class LargeMeanPoissonSamplerTest { @@ -55,22 +55,13 @@ public class LargeMeanPoissonSamplerTest { } /** - * Test the state cannot be created with a negative n. - */ - @Test(expected=IllegalArgumentException.class) - public void testStateCreationThrowsWithNegativeN() { - @SuppressWarnings("unused") - LargeMeanPoissonSamplerState state = new LargeMeanPoissonSamplerState(-1); - } - - /** * Test the constructor with a negative fractional mean. */ @Test(expected=IllegalArgumentException.class) public void testConstructorThrowsWithNegativeFractionalMean() { final RestorableUniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64); - LargeMeanPoissonSamplerState state = new LargeMeanPoissonSamplerState(0); + final LargeMeanPoissonSamplerState state = new LargeMeanPoissonSampler(rng, 1).getState(); @SuppressWarnings("unused") LargeMeanPoissonSampler sampler = new LargeMeanPoissonSampler(rng, state, -0.1); } @@ -82,7 +73,7 @@ public class LargeMeanPoissonSamplerTest { public void testConstructorThrowsWithNonFractionalMean() { final RestorableUniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64); - LargeMeanPoissonSamplerState state = new LargeMeanPoissonSamplerState(0); + final LargeMeanPoissonSamplerState state = new LargeMeanPoissonSampler(rng, 1).getState(); @SuppressWarnings("unused") LargeMeanPoissonSampler sampler = new LargeMeanPoissonSampler(rng, state, 1.1); } @@ -94,7 +85,7 @@ public class LargeMeanPoissonSamplerTest { public void testConstructorThrowsWithFractionalMeanOne() { final RestorableUniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64); - LargeMeanPoissonSamplerState state = new LargeMeanPoissonSamplerState(0); + final LargeMeanPoissonSamplerState state = new LargeMeanPoissonSampler(rng, 1).getState(); @SuppressWarnings("unused") LargeMeanPoissonSampler sampler = new LargeMeanPoissonSampler(rng, state, 1); } @@ -103,7 +94,7 @@ public class LargeMeanPoissonSamplerTest { /** * Test the {@link LargeMeanPoissonSampler} returns the same samples when it - * is created using the {@link LargeMeanPoissonSamplerState}. + * is created using the saved state. */ @Test public void testCanComputeSameSamplesWhenConstructedWithState() { @@ -125,8 +116,8 @@ public class LargeMeanPoissonSamplerTest { } /** - * Test poisson samples are the same from the {@link PoissonSampler} - * and {@link PoissonSamplerCache}. The random providers must be + * Test the {@link LargeMeanPoissonSampler} returns the same samples when it + * is created using the saved state. The random providers must be * identical (including state). * * @param rng1 the first random provider @@ -141,9 +132,11 @@ public class LargeMeanPoissonSamplerTest { final DiscreteSampler s1 = new LargeMeanPoissonSampler(rng1, mean); final int n = (int) Math.floor(mean); final double lambdaFractional = mean - n; - final LargeMeanPoissonSamplerState state = new LargeMeanPoissonSamplerState(n); - Assert.assertEquals("Not the correct lambda", n, state.getLambda()); - final DiscreteSampler s2 = new LargeMeanPoissonSampler(rng2, state, lambdaFractional); + final LargeMeanPoissonSamplerState state1 = ((LargeMeanPoissonSampler)s1).getState(); + final DiscreteSampler s2 = new LargeMeanPoissonSampler(rng2, state1, lambdaFractional); + final LargeMeanPoissonSamplerState state2 = ((LargeMeanPoissonSampler)s2).getState(); + Assert.assertEquals("State lambdas are not equal", state1.getLambda(), state2.getLambda()); + Assert.assertNotSame("States are the same object", state1, state2); for (int j = 0; j < 10; j++) Assert.assertEquals("Not the same sample", s1.sample(), s2.sample()); }