RNG-51: Changed representation of LargeMeanPoissonSampler state

Project: http://git-wip-us.apache.org/repos/asf/commons-rng/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-rng/commit/db751b91
Tree: http://git-wip-us.apache.org/repos/asf/commons-rng/tree/db751b91
Diff: http://git-wip-us.apache.org/repos/asf/commons-rng/diff/db751b91

Branch: refs/heads/master
Commit: db751b911b56fecff1bce1ce33320e4308158c6a
Parents: fa38dea
Author: Alex Herbert <a.herb...@sussex.ac.uk>
Authored: Fri Sep 21 00:56:19 2018 +0100
Committer: Alex Herbert <a.herb...@sussex.ac.uk>
Committed: Fri Sep 21 00:56:19 2018 +0100

----------------------------------------------------------------------
 .../distribution/LargeMeanPoissonSampler.java   | 244 ++++++++++++-------
 .../distribution/PoissonSamplerCache.java       |  12 +-
 .../LargeMeanPoissonSamplerTest.java            |  33 +--
 3 files changed, 180 insertions(+), 109 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-rng/blob/db751b91/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java
----------------------------------------------------------------------
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java
index 0beb6b4..802c5be 100644
--- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSampler.java
@@ -92,79 +92,6 @@ public class LargeMeanPoissonSampler
     private final DiscreteSampler smallMeanPoissonSampler;
 
     /**
-     * Encapsulate the state of the sampler. The state is valid for 
construction of
-     * a sampler in the range {@code lambda <= mean < lambda+1}.
-     */
-    static class LargeMeanPoissonSamplerState {
-        /** Algorithm constant: {@code Math.floor(mean)}. */
-        private final double lambda;
-        /** Algorithm constant: {@code Math.log(lambda)}. */
-        private final double logLambda;
-        /** Algorithm constant: {@code factorialLog((int) lambda)}. */
-        private final double logLambdaFactorial;
-        /** Algorithm constant: {@code Math.sqrt(lambda * Math.log(32 * lambda 
/ Math.PI + 1))}. */
-        private final double delta;
-        /** Algorithm constant: {@code delta / 2}. */
-        private final double halfDelta;
-        /** Algorithm constant: {@code 2 * lambda + delta}. */
-        private final double twolpd;
-        /**
-         * Algorithm constant: {@code a1 / aSum} with
-         * <ul>
-         *  <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li>
-         *  <li>{@code aSum = a1 + a2 + 1}</li>
-         * </ul>
-         */
-        private final double p1;
-        /**
-         * Algorithm constant: {@code a2 / aSum} with
-         * <ul>
-         *  <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / 
twolpd)}</li>
-         *  <li>{@code aSum = a1 + a2 + 1}</li>
-         * </ul>
-         */
-        private final double p2;
-        /** Algorithm constant: {@code 1 / (8 * lambda)}. */
-        private final double c1;
-
-        /**
-         * Creates the state. The state is valid for construction of a sampler 
in the
-         * range {@code n <= mean < n+1}.
-         *
-         * @param n the value n ({@code floor(mean)})
-         * @throws IllegalArgumentException if {@code n < 0}.
-         */
-        LargeMeanPoissonSamplerState(int n) {
-            if (n < 0) {
-                throw new IllegalArgumentException(n + " < " + 0);
-            }
-            // Cache values used in the algorithm
-            // This is deliberately a copy of the code in the 
-            // LargeMeanPoissonSampler constructor.
-            lambda = n;
-            logLambda = Math.log(lambda);
-            logLambdaFactorial = NO_CACHE_FACTORIAL_LOG.value(n);
-            delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1));
-            halfDelta = delta / 2;
-            twolpd = 2 * lambda + delta;
-            c1 = 1 / (8 * lambda);
-            final double a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1);
-            final double a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) 
/ twolpd);
-            final double aSum = a1 + a2 + 1;
-            p1 = a1 / aSum;
-            p2 = a2 / aSum;
-        }
-
-        /**
-         * Get the lambda value for the state. Equal to {@code floor(mean)}.
-         * @return {@code floor(mean)}
-         */
-        int getLambda() {
-            return (int) lambda;
-        }
-    }
-
-    /**
      * @param rng  Generator of uniformly distributed random numbers.
      * @param mean Mean.
      * @throws IllegalArgumentException if {@code mean <= 0} or
@@ -188,6 +115,7 @@ public class LargeMeanPoissonSampler
 
         // Cache values used in the algorithm
         lambda = Math.floor(mean);
+        lambdaFractional = mean - lambda;
         logLambda = Math.log(lambda);
         logLambdaFactorial = factorialLog((int) lambda);
         delta = Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1));
@@ -201,7 +129,6 @@ public class LargeMeanPoissonSampler
         p2 = a2 / aSum;
 
         // The algorithm requires a Poisson sample from the remaining lambda 
fraction.
-        lambdaFractional = mean - lambda;
         smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
             null : // Not used.
             new SmallMeanPoissonSampler(rng, lambdaFractional);
@@ -232,18 +159,18 @@ public class LargeMeanPoissonSampler
         factorialLog = NO_CACHE_FACTORIAL_LOG;
 
         // Use the state to initialise the algorithm
-        lambda = state.lambda;
-        logLambda = state.logLambda;
-        logLambdaFactorial = state.logLambdaFactorial;
-        delta = state.delta;
-        halfDelta = state.halfDelta;
-        twolpd = state.twolpd;
-        p1 = state.p1;
-        p2 = state.p2;
-        c1 = state.c1;
+        lambda = state.getLambdaRaw();
+        this.lambdaFractional = lambdaFractional;
+        logLambda = state.getLogLambda();
+        logLambdaFactorial = state.getLogLambdaFactorial();
+        delta = state.getDelta();
+        halfDelta = state.getHalfDelta();
+        twolpd = state.getTwolpd();
+        p1 = state.getP1();
+        p2 = state.getP2();
+        c1 = state.getC1();
 
         // The algorithm requires a Poisson sample from the remaining lambda 
fraction.
-        this.lambdaFractional = lambdaFractional;
         smallMeanPoissonSampler = (lambdaFractional < Double.MIN_VALUE) ?
             null : // Not used.
             new SmallMeanPoissonSampler(rng, lambdaFractional);
@@ -324,4 +251,153 @@ public class LargeMeanPoissonSampler
     public String toString() {
         return "Large Mean Poisson deviate [" + super.toString() + "]";
     }
+
+    /**
+     * Gets the initialisation state of the sampler.
+     *
+     * <p>The state is computed using an integer {@code lambda} value of
+     * {@code lambda = (int)Math.floor(mean)}.
+     *
+     * <p>The state will be suitable for reconstructing a new sampler with a mean
+     * in the range {@code lambda <= mean < lambda+1} using
+     * {@link #LargeMeanPoissonSampler(UniformRandomProvider, LargeMeanPoissonSamplerState, double)}.
+     *
+     * <p>A new state object is returned on each call. It holds only copies of the
+     * precomputed algorithm constants and does not reference this sampler or its
+     * random generator.
+     *
+     * @return the state
+     */
+    LargeMeanPoissonSamplerState getState() {
+        return new LargeMeanPoissonSamplerState(lambda, logLambda, logLambdaFactorial,
+                delta, halfDelta, twolpd, p1, p2, c1);
+    }
+
+    /**
+     * Encapsulate the state of the sampler. The state is valid for construction of
+     * a sampler in the range {@code lambda <= mean < lambda+1}.
+     *
+     * <p>The fractional part of the mean ({@code mean - lambda}) is not part of the
+     * state; it is supplied separately when a sampler is constructed from the state.
+     *
+     * <p>This class is immutable.
+     *
+     * @see #getLambda()
+     */
+    static class LargeMeanPoissonSamplerState {
+        /** Algorithm constant {@code lambda}: {@code Math.floor(mean)}. */
+        private final double lambda;
+        /** Algorithm constant {@code logLambda}: {@code Math.log(lambda)}. */
+        private final double logLambda;
+        /**
+         * Algorithm constant {@code logLambdaFactorial}:
+         * {@code factorialLog((int) lambda)}.
+         */
+        private final double logLambdaFactorial;
+        /**
+         * Algorithm constant {@code delta}:
+         * {@code Math.sqrt(lambda * Math.log(32 * lambda / Math.PI + 1))}.
+         */
+        private final double delta;
+        /** Algorithm constant {@code halfDelta}: {@code delta / 2}. */
+        private final double halfDelta;
+        /** Algorithm constant {@code twolpd}: {@code 2 * lambda + delta}. */
+        private final double twolpd;
+        /**
+         * Algorithm constant {@code p1}: {@code a1 / aSum} with
+         * <ul>
+         *  <li>{@code a1 = Math.sqrt(Math.PI * twolpd) * Math.exp(c1)}</li>
+         *  <li>{@code aSum = a1 + a2 + 1}</li>
+         * </ul>
+         */
+        private final double p1;
+        /**
+         * Algorithm constant {@code p2}: {@code a2 / aSum} with
+         * <ul>
+         *  <li>{@code a2 = (twolpd / delta) * Math.exp(-delta * (1 + delta) / twolpd)}</li>
+         *  <li>{@code aSum = a1 + a2 + 1}</li>
+         * </ul>
+         */
+        private final double p2;
+        /** Algorithm constant {@code c1}: {@code 1 / (8 * lambda)}. */
+        private final double c1;
+
+        /**
+         * Creates the state.
+         *
+         * <p>The state is valid for construction of a sampler in the range
+         * {@code lambda <= mean < lambda+1} where {@code lambda} is an integer.
+         *
+         * <p>The values are stored exactly as given; no validation or
+         * recomputation is performed.
+         *
+         * @param lambda the lambda
+         * @param logLambda the log lambda
+         * @param logLambdaFactorial the log lambda factorial
+         * @param delta the delta
+         * @param halfDelta the half delta
+         * @param twolpd the two lambda plus delta
+         * @param p1 the p1 constant
+         * @param p2 the p2 constant
+         * @param c1 the c1 constant
+         */
+        private LargeMeanPoissonSamplerState(double lambda, double logLambda,
+                double logLambdaFactorial, double delta, double halfDelta, double twolpd,
+                double p1, double p2, double c1) {
+            this.lambda = lambda;
+            this.logLambda = logLambda;
+            this.logLambdaFactorial = logLambdaFactorial;
+            this.delta = delta;
+            this.halfDelta = halfDelta;
+            this.twolpd = twolpd;
+            this.p1 = p1;
+            this.p2 = p2;
+            this.c1 = c1;
+        }
+
+        /**
+         * Get the lambda value for the state.
+         *
+         * <p>Equal to {@code floor(mean)} for a Poisson sampler.
+         *
+         * @return the lambda value
+         */
+        int getLambda() {
+            return (int) getLambdaRaw();
+        }
+
+        /**
+         * Gets the algorithm constant {@code lambda} as a {@code double}.
+         *
+         * @return algorithm constant {@code lambda}
+         */
+        double getLambdaRaw() {
+            return lambda;
+        }
+
+        /**
+         * Gets the algorithm constant {@code logLambda}.
+         *
+         * @return algorithm constant {@code logLambda}
+         */
+        double getLogLambda() {
+            return logLambda;
+        }
+
+        /**
+         * Gets the algorithm constant {@code logLambdaFactorial}.
+         *
+         * @return algorithm constant {@code logLambdaFactorial}
+         */
+        double getLogLambdaFactorial() {
+            return logLambdaFactorial;
+        }
+
+        /**
+         * Gets the algorithm constant {@code delta}.
+         *
+         * @return algorithm constant {@code delta}
+         */
+        double getDelta() {
+            return delta;
+        }
+
+        /**
+         * Gets the algorithm constant {@code halfDelta}.
+         *
+         * @return algorithm constant {@code halfDelta}
+         */
+        double getHalfDelta() {
+            return halfDelta;
+        }
+
+        /**
+         * Gets the algorithm constant {@code twolpd}.
+         *
+         * @return algorithm constant {@code twolpd}
+         */
+        double getTwolpd() {
+            return twolpd;
+        }
+
+        /**
+         * Gets the algorithm constant {@code p1}.
+         *
+         * @return algorithm constant {@code p1}
+         */
+        double getP1() {
+            return p1;
+        }
+
+        /**
+         * Gets the algorithm constant {@code p2}.
+         *
+         * @return algorithm constant {@code p2}
+         */
+        double getP2() {
+            return p2;
+        }
+
+        /**
+         * Gets the algorithm constant {@code c1}.
+         *
+         * @return algorithm constant {@code c1}
+         */
+        double getC1() {
+            return c1;
+        }
+    }
 }

http://git-wip-us.apache.org/repos/asf/commons-rng/blob/db751b91/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java
----------------------------------------------------------------------
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java
index 2d361ff..4b74084 100644
--- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/PoissonSamplerCache.java
@@ -105,7 +105,8 @@ public class PoissonSamplerCache {
                                 LargeMeanPoissonSamplerState[] states) {
         this.minN = minN;
         this.maxN = maxN;
-        this.values = states.clone();
+        // Stored directly as the states were newly created within this class.
+        this.values = states;
     }
 
     /**
@@ -166,14 +167,15 @@ public class PoissonSamplerCache {
         // Look in the cache for a state that can be reused.
         // Note: The cache is offset by minN.
         final int index = n - minN;
-        LargeMeanPoissonSamplerState state = values[index];
+        final LargeMeanPoissonSamplerState state = values[index];
         if (state == null) {
-            // Compute and store for reuse.
+            // Create a sampler and store the state for reuse.
             // Do not worry about thread contention
             // as the state is effectively immutable.
             // If recomputed and replaced it will the same.
-            state = new LargeMeanPoissonSamplerState(n);
-            values[index] = state;
+            final LargeMeanPoissonSampler sampler = new 
LargeMeanPoissonSampler(rng, mean);
+            values[index] = sampler.getState();
+            return sampler;
         }
         // Compute the remaining fraction of the mean
         final double lambdaFractional = mean - n;

http://git-wip-us.apache.org/repos/asf/commons-rng/blob/db751b91/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java
----------------------------------------------------------------------
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java
index 2b1e24e..480f452 100644
--- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/LargeMeanPoissonSamplerTest.java
@@ -24,8 +24,8 @@ import org.junit.Assert;
 import org.junit.Test;
 
 /**
- * This test checks the {@link LargeMeanPoissonSampler} using the
- * {@link LargeMeanPoissonSamplerState}.
+ * This test checks the {@link LargeMeanPoissonSampler} can be created
+ * from a saved state.
  */
 public class LargeMeanPoissonSamplerTest {
 
@@ -55,22 +55,13 @@ public class LargeMeanPoissonSamplerTest {
     }
 
     /**
-     * Test the state cannot be created with a negative n.
-     */
-    @Test(expected=IllegalArgumentException.class)
-    public void testStateCreationThrowsWithNegativeN() {
-        @SuppressWarnings("unused")
-        LargeMeanPoissonSamplerState state = new 
LargeMeanPoissonSamplerState(-1);
-    }
-
-    /**
      * Test the constructor with a negative fractional mean.
      */
     @Test(expected=IllegalArgumentException.class)
     public void testConstructorThrowsWithNegativeFractionalMean() {
         final RestorableUniformRandomProvider rng =
                 RandomSource.create(RandomSource.SPLIT_MIX_64);
-        LargeMeanPoissonSamplerState state = new 
LargeMeanPoissonSamplerState(0);
+        final LargeMeanPoissonSamplerState state = new 
LargeMeanPoissonSampler(rng, 1).getState();
         @SuppressWarnings("unused")
         LargeMeanPoissonSampler sampler = new LargeMeanPoissonSampler(rng, 
state, -0.1);
     }
@@ -82,7 +73,7 @@ public class LargeMeanPoissonSamplerTest {
     public void testConstructorThrowsWithNonFractionalMean() {
         final RestorableUniformRandomProvider rng =
                 RandomSource.create(RandomSource.SPLIT_MIX_64);
-        LargeMeanPoissonSamplerState state = new 
LargeMeanPoissonSamplerState(0);
+        final LargeMeanPoissonSamplerState state = new 
LargeMeanPoissonSampler(rng, 1).getState();
         @SuppressWarnings("unused")
         LargeMeanPoissonSampler sampler = new LargeMeanPoissonSampler(rng, 
state, 1.1);
     }
@@ -94,7 +85,7 @@ public class LargeMeanPoissonSamplerTest {
     public void testConstructorThrowsWithFractionalMeanOne() {
         final RestorableUniformRandomProvider rng =
                 RandomSource.create(RandomSource.SPLIT_MIX_64);
-        LargeMeanPoissonSamplerState state = new 
LargeMeanPoissonSamplerState(0);
+        final LargeMeanPoissonSamplerState state = new 
LargeMeanPoissonSampler(rng, 1).getState();
         @SuppressWarnings("unused")
         LargeMeanPoissonSampler sampler = new LargeMeanPoissonSampler(rng, 
state, 1);
     }
@@ -103,7 +94,7 @@ public class LargeMeanPoissonSamplerTest {
 
     /**
      * Test the {@link LargeMeanPoissonSampler} returns the same samples when 
it
-     * is created using the {@link LargeMeanPoissonSamplerState}.
+     * is created using the saved state.
      */
     @Test
     public void testCanComputeSameSamplesWhenConstructedWithState() {
@@ -125,8 +116,8 @@ public class LargeMeanPoissonSamplerTest {
     }
 
     /**
-     * Test poisson samples are the same from the {@link PoissonSampler}
-     * and {@link PoissonSamplerCache}. The random providers must be
+     * Test the {@link LargeMeanPoissonSampler} returns the same samples when 
it
+     * is created using the saved state. The random providers must be
      * identical (including state).
      *
      * @param rng1  the first random provider
@@ -141,9 +132,11 @@ public class LargeMeanPoissonSamplerTest {
         final DiscreteSampler s1 = new LargeMeanPoissonSampler(rng1, mean);
         final int n = (int) Math.floor(mean);
         final double lambdaFractional = mean - n;
-        final LargeMeanPoissonSamplerState state = new 
LargeMeanPoissonSamplerState(n);
-        Assert.assertEquals("Not the correct lambda", n, state.getLambda());
-        final DiscreteSampler s2 = new LargeMeanPoissonSampler(rng2, state, 
lambdaFractional);
+        final LargeMeanPoissonSamplerState state1 = 
((LargeMeanPoissonSampler)s1).getState();
+        final DiscreteSampler s2 = new LargeMeanPoissonSampler(rng2, state1, 
lambdaFractional);
+        final LargeMeanPoissonSamplerState state2 = 
((LargeMeanPoissonSampler)s2).getState();
+        Assert.assertEquals("State lambdas are not equal", state1.getLambda(), 
state2.getLambda());
+        Assert.assertNotSame("States are the same object", state1, state2);
         for (int j = 0; j < 10; j++)
             Assert.assertEquals("Not the same sample", s1.sample(), 
s2.sample());
     }

Reply via email to