This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-rng.git
commit 00b7cc08f706e9d4fd26c3de6ecdd8c926984538 Author: aherbert <aherb...@apache.org> AuthorDate: Tue May 7 17:24:02 2019 +0100 RNG-101: Add MarsagliaTsangWang discrete probability sampler. This adds support for a generic distribution defined by an array of probabilities and also a Poisson and Binomial distribution. --- .../MarsagliaTsangWangBinomialSampler.java | 257 ++++++++++ .../MarsagliaTsangWangDiscreteSampler.java | 540 +++++++++++++++++++++ .../MarsagliaTsangWangSmallMeanPoissonSampler.java | 218 +++++++++ .../distribution/DiscreteSamplersList.java | 27 ++ .../MarsagliaTsangWangBinomialSamplerTest.java | 242 +++++++++ .../MarsagliaTsangWangDiscreteSamplerTest.java | 332 +++++++++++++ ...sagliaTsangWangSmallMeanPoissonSamplerTest.java | 119 +++++ 7 files changed, 1735 insertions(+) diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangBinomialSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangBinomialSampler.java new file mode 100644 index 0000000..5b13155 --- /dev/null +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangBinomialSampler.java @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import org.apache.commons.rng.UniformRandomProvider; + +/** + * Sampler for the <a href="https://en.wikipedia.org/wiki/Binomial_distribution">Binomial + * distribution</a> using an optimised look-up table. + * + * <ul> + * <li> + * A Binomial process is simulated using pre-tabulated probabilities, as + * described in George Marsaglia, Wai Wan Tsang, Jingbo Wang (2004) Fast Generation of + * Discrete Random Variables. Journal of Statistical Software. Vol. 11, Issue. 3, pp. 1-11. + * </li> + * </ul> + * + * <p>The sampler will fail on construction if the distribution cannot be computed. This + * occurs when {@code trials} is large and probability of success is close to {@code 0.5}. + * The exact failure condition is:</p> + * + * <pre> + * {@code Math.exp(trials * Math.log(Math.max(p, 1 - p))) < Double.MIN_VALUE} + * </pre> + * + * <p>In this case the distribution can be approximated using a limiting distribution + * of either a Poisson or a Normal distribution as appropriate.</p> + * + * <p>Note: The algorithm ignores any observation where for a sample size of + * 2<sup>31</sup> the expected number of occurrences is {@code < 0.5}.</p> + * + * <p>Sampling uses 1 call to {@link UniformRandomProvider#nextInt()}. Storage + * requirements depend on the probabilities and are capped at 2<sup>17</sup> bytes, or 131 + * kB.</p> + * + * @see <a href="http://dx.doi.org/10.18637/jss.v011.i03">Marsaglia, et al (2004) JSS Vol.
+ * 11, Issue 3</a> + * @since 1.3 + */ +public class MarsagliaTsangWangBinomialSampler implements DiscreteSampler { + /** + * The value 2<sup>30</sup> as an {@code int}. + */ + private static final int INT_30 = 1 << 30; + /** + * The value 2<sup>16</sup> as an {@code int}. + */ + private static final int INT_16 = 1 << 16; + /** + * The value 2<sup>31</sup> as a {@code double}. + */ + private static final double DOUBLE_31 = 1L << 31; + + /** The delegate. */ + private final DiscreteSampler delegate; + + /** + * Return a fixed result for the Binomial distribution. + */ + private static class FixedResultDiscreteSampler implements DiscreteSampler { + /** The result. */ + private final int result; + + /** + * @param result Result. + */ + FixedResultDiscreteSampler(int result) { + this.result = result; + } + + @Override + public int sample() { + return result; + } + + @Override + public String toString() { + return "Binomial deviate"; + } + } + + /** + * Return an inversion result for the Binomial distribution. This assumes the + * following: + * + * <pre> + * Binomial(n, p) = n - Binomial(n, 1 - p) + * </pre> + */ + private static class InversionBinomialDiscreteSampler implements DiscreteSampler { + /** The number of trials. */ + private final int trials; + /** The Binomial distribution sampler. */ + private final DiscreteSampler sampler; + + /** + * @param trials Number of trials. + * @param sampler Binomial distribution sampler. + */ + InversionBinomialDiscreteSampler(int trials, DiscreteSampler sampler) { + this.trials = trials; + this.sampler = sampler; + } + + @Override + public int sample() { + return trials - sampler.sample(); + } + + @Override + public String toString() { + return sampler.toString(); + } + } + + /** + * Create a new instance. + * + * @param rng Generator of uniformly distributed random numbers. + * @param trials Number of trials. + * @param p Probability of success.
+ * @throws IllegalArgumentException if {@code trials < 0} or {@code trials >= 2^16}, + * {@code p} is not in the range {@code [0-1]}, or the probability distribution cannot + * be computed. + */ + public MarsagliaTsangWangBinomialSampler(UniformRandomProvider rng, int trials, double p) { + if (trials < 0) { + throw new IllegalArgumentException("Trials is not positive: " + trials); + } + if (p < 0 || p > 1) { + throw new IllegalArgumentException("Probability is not in range [0,1]: " + p); + } + + // Handle edge cases + if (p == 0) { + delegate = new FixedResultDiscreteSampler(0); + return; + } + if (p == 1) { + delegate = new FixedResultDiscreteSampler(trials); + return; + } + + // A simple check using the supported index size. + if (trials >= INT_16) { + throw new IllegalArgumentException("Unsupported number of trials: " + trials); + } + + // The maximum supported value for Math.exp is approximately -744. + // This occurs when trials is large and p is close to 1. + // Handle this by using an inversion: generate j=Binomial(n,1-p), return n-j + final boolean inversion = p > 0.5; + if (inversion) { + p = 1 - p; + } + + // Check if the distribution can be computed + final double p0 = Math.exp(trials * Math.log(1 - p)); + if (p0 < Double.MIN_VALUE) { + throw new IllegalArgumentException("Unable to compute distribution"); + } + + // First find size of probability array + double t = p0; + final double h = p / (1 - p); + // Find first probability + int begin = 0; + if (t * DOUBLE_31 < 1) { + // Somewhere after p(0) + // Note: + // If this loop is entered p(0) is < 2^-31. + // This has been tested at the extreme for p(0)=Double.MIN_VALUE and either + // p=0.5 or trials=2^16-1 and does not fail to find the beginning. 
+ for (int i = 1; i <= trials; i++) { + t *= (trials + 1 - i) * h / i; + if (t * DOUBLE_31 >= 1) { + begin = i; + break; + } + } + } + // Find last probability + int end = trials; + for (int i = begin + 1; i <= trials; i++) { + t *= (trials + 1 - i) * h / i; + if (t * DOUBLE_31 < 1) { + end = i - 1; + break; + } + } + final int size = end - begin + 1; + final int offset = begin; + + // Then assign probability values as 30-bit integers + final int[] prob = new int[size]; + t = p0; + for (int i = 1; i <= begin; i++) { + t *= (trials + 1 - i) * h / i; + } + int sum = toUnsignedInt30(t); + prob[0] = sum; + for (int i = begin + 1; i <= end; i++) { + t *= (trials + 1 - i) * h / i; + prob[i - begin] = toUnsignedInt30(t); + sum += prob[i - begin]; + } + + // If the sum is < 2^30 add the remaining sum to the mode (floor((n+1)p))). + final int mode = (int) ((trials + 1) * p) - offset; + prob[mode] += Math.max(0, INT_30 - sum); + + final MarsagliaTsangWangDiscreteSampler sampler = new MarsagliaTsangWangDiscreteSampler(rng, prob, offset); + + if (inversion) { + delegate = new InversionBinomialDiscreteSampler(trials, sampler); + } else { + delegate = sampler; + } + } + + /** + * Convert the probability to an unsigned integer in the range [0,2^30]. 
+ * + * @param p the probability + * @return the integer + */ + private static int toUnsignedInt30(double p) { + return (int) (p * INT_30 + 0.5); + } + + /** {@inheritDoc} */ + @Override + public int sample() { + return delegate.sample(); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return "Binomial " + delegate.toString(); + } +} diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java new file mode 100644 index 0000000..a1fc5a7 --- /dev/null +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSampler.java @@ -0,0 +1,540 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import org.apache.commons.rng.UniformRandomProvider; + +/** + * Sampler for a discrete distribution using an optimised look-up table. 
+ * + * <ul> + * <li> + * The method requires 30-bit integer probabilities that sum to 2<sup>30</sup> as described + * in George Marsaglia, Wai Wan Tsang, Jingbo Wang (2004) Fast Generation of Discrete + * Random Variables. Journal of Statistical Software. Vol. 11, Issue. 3, pp. 1-11. + * </li> + * </ul> + * + * <p>Sampling uses 1 call to {@link UniformRandomProvider#nextInt()}.</p> + * + * <p>Memory requirements depend on the maximum number of possible sample values, {@code n}, + * and the values for the probabilities. Storage is optimised for {@code n}. The worst case + * scenario is a uniform distribution of the maximum sample size. This is capped at 0.06MB for + * {@code n <= } 2<sup>8</sup>, 17.0MB for {@code n <= } 2<sup>16</sup>, and 4.3GB for + * {@code n <=} 2<sup>30</sup>. Realistic requirements will be in the kB range.</p> + * + * @since 1.3 + * @see <a href="http://dx.doi.org/10.18637/jss.v011.i03">Marsaglia, et al (2004) JSS Vol. + * 11, Issue 3</a> + */ +public class MarsagliaTsangWangDiscreteSampler implements DiscreteSampler { + /** The exclusive upper bound for an unsigned 8-bit integer. */ + private static final int UNSIGNED_INT_8 = 1 << 8; + /** The exclusive upper bound for an unsigned 16-bit integer. */ + private static final int UNSIGNED_INT_16 = 1 << 16; + + /** Limit for look-up table 1. */ + private final int t1; + /** Limit for look-up table 2. */ + private final int t2; + /** Limit for look-up table 3. */ + private final int t3; + /** Limit for look-up table 4. */ + private final int t4; + + /** Index look-up table. */ + private final IndexTable indexTable; + + /** Underlying source of randomness. */ + private final UniformRandomProvider rng; + + /** + * An index table contains the sample values. This is efficiently accessed for any index in the + * range {@code [0,2^30)} by using an algorithm based on the decomposition of the index into + * 5 base-64 digits.
+ * + * <p>This interface defines the methods for the filling and accessing values from 5 tables. + * It allows a concrete implementation to allocate appropriate tables to optimise memory + * requirements.</p> + */ + private interface IndexTable { + /** + * @param from Lower bound index (inclusive). + * @param to Upper bound index (exclusive). + * @param value Value. + */ + void fillTable1(int from, int to, int value); + /** + * @param from Lower bound index (inclusive). + * @param to Upper bound index (exclusive). + * @param value Value. + */ + void fillTable2(int from, int to, int value); + /** + * @param from Lower bound index (inclusive). + * @param to Upper bound index (exclusive). + * @param value Value. + */ + void fillTable3(int from, int to, int value); + /** + * @param from Lower bound index (inclusive). + * @param to Upper bound index (exclusive). + * @param value Value. + */ + void fillTable4(int from, int to, int value); + /** + * @param from Lower bound index (inclusive). + * @param to Upper bound index (exclusive). + * @param value Value. + */ + void fillTable5(int from, int to, int value); + + /** + * @param index Index. + * @return Value. + */ + int getTable1(int index); + /** + * @param index Index. + * @return Value. + */ + int getTable2(int index); + /** + * @param index Index. + * @return Value. + */ + int getTable3(int index); + /** + * @param index Index. + * @return Value. + */ + int getTable4(int index); + /** + * @param index Index. + * @return Value. + */ + int getTable5(int index); + } + + /** + * Index table for an 8-bit index. + */ + private static class IndexTable8 implements IndexTable { + /** The mask to convert a {@code byte} to an unsigned 8-bit integer. */ + private static final int MASK = 0xff; + + /** Look-up table table1. */ + private final byte[] table1; + /** Look-up table table2. */ + private final byte[] table2; + /** Look-up table table3. */ + private final byte[] table3; + /** Look-up table table4. 
*/ + private final byte[] table4; + /** Look-up table table5. */ + private final byte[] table5; + + /** + * @param n1 Size of table 1. + * @param n2 Size of table 2. + * @param n3 Size of table 3. + * @param n4 Size of table 4. + * @param n5 Size of table 5. + */ + IndexTable8(int n1, int n2, int n3, int n4, int n5) { + table1 = new byte[n1]; + table2 = new byte[n2]; + table3 = new byte[n3]; + table4 = new byte[n4]; + table5 = new byte[n5]; + } + + @Override + public void fillTable1(int from, int to, int value) { fill(table1, from, to, value); } + @Override + public void fillTable2(int from, int to, int value) { fill(table2, from, to, value); } + @Override + public void fillTable3(int from, int to, int value) { fill(table3, from, to, value); } + @Override + public void fillTable4(int from, int to, int value) { fill(table4, from, to, value); } + @Override + public void fillTable5(int from, int to, int value) { fill(table5, from, to, value); } + + /** + * Fill the table with the value. + * + * @param table Table. + * @param from Lower bound index (inclusive) + * @param to Upper bound index (exclusive) + * @param value Value. + */ + private static void fill(byte[] table, int from, int to, int value) { + while (from < to) { + // Primitive type conversion will extract lower 8 bits + table[from++] = (byte) value; + } + } + + @Override + public int getTable1(int index) { return table1[index] & MASK; } + @Override + public int getTable2(int index) { return table2[index] & MASK; } + @Override + public int getTable3(int index) { return table3[index] & MASK; } + @Override + public int getTable4(int index) { return table4[index] & MASK; } + @Override + public int getTable5(int index) { return table5[index] & MASK; } + } + + /** + * Index table for a 16-bit index. + */ + private static class IndexTable16 implements IndexTable { + /** The mask to convert a {@code short} to an unsigned 16-bit integer. */ + private static final int MASK = 0xffff; + + /** Look-up table table1. 
*/ + private final short[] table1; + /** Look-up table table2. */ + private final short[] table2; + /** Look-up table table3. */ + private final short[] table3; + /** Look-up table table4. */ + private final short[] table4; + /** Look-up table table5. */ + private final short[] table5; + + /** + * @param n1 Size of table 1. + * @param n2 Size of table 2. + * @param n3 Size of table 3. + * @param n4 Size of table 4. + * @param n5 Size of table 5. + */ + IndexTable16(int n1, int n2, int n3, int n4, int n5) { + table1 = new short[n1]; + table2 = new short[n2]; + table3 = new short[n3]; + table4 = new short[n4]; + table5 = new short[n5]; + } + + @Override + public void fillTable1(int from, int to, int value) { fill(table1, from, to, value); } + @Override + public void fillTable2(int from, int to, int value) { fill(table2, from, to, value); } + @Override + public void fillTable3(int from, int to, int value) { fill(table3, from, to, value); } + @Override + public void fillTable4(int from, int to, int value) { fill(table4, from, to, value); } + @Override + public void fillTable5(int from, int to, int value) { fill(table5, from, to, value); } + + /** + * Fill the table with the value. + * + * @param table Table. + * @param from Lower bound index (inclusive) + * @param to Upper bound index (exclusive) + * @param value Value. + */ + private static void fill(short[] table, int from, int to, int value) { + while (from < to) { + // Primitive type conversion will extract lower 16 bits + table[from++] = (short) value; + } + } + + @Override + public int getTable1(int index) { return table1[index] & MASK; } + @Override + public int getTable2(int index) { return table2[index] & MASK; } + @Override + public int getTable3(int index) { return table3[index] & MASK; } + @Override + public int getTable4(int index) { return table4[index] & MASK; } + @Override + public int getTable5(int index) { return table5[index] & MASK; } + } + + /** + * Index table for a 32-bit index. 
+ */ + private static class IndexTable32 implements IndexTable { + /** Look-up table table1. */ + private final int[] table1; + /** Look-up table table2. */ + private final int[] table2; + /** Look-up table table3. */ + private final int[] table3; + /** Look-up table table4. */ + private final int[] table4; + /** Look-up table table5. */ + private final int[] table5; + + /** + * @param n1 Size of table 1. + * @param n2 Size of table 2. + * @param n3 Size of table 3. + * @param n4 Size of table 4. + * @param n5 Size of table 5. + */ + IndexTable32(int n1, int n2, int n3, int n4, int n5) { + table1 = new int[n1]; + table2 = new int[n2]; + table3 = new int[n3]; + table4 = new int[n4]; + table5 = new int[n5]; + } + + @Override + public void fillTable1(int from, int to, int value) { fill(table1, from, to, value); } + @Override + public void fillTable2(int from, int to, int value) { fill(table2, from, to, value); } + @Override + public void fillTable3(int from, int to, int value) { fill(table3, from, to, value); } + @Override + public void fillTable4(int from, int to, int value) { fill(table4, from, to, value); } + @Override + public void fillTable5(int from, int to, int value) { fill(table5, from, to, value); } + + /** + * Fill the table with the value. + * + * @param table Table. + * @param from Lower bound index (inclusive) + * @param to Upper bound index (exclusive) + * @param value Value. 
+ */ + private static void fill(int[] table, int from, int to, int value) { + while (from < to) { + table[from++] = value; + } + } + + @Override + public int getTable1(int index) { return table1[index]; } + @Override + public int getTable2(int index) { return table2[index]; } + @Override + public int getTable3(int index) { return table3[index]; } + @Override + public int getTable4(int index) { return table4[index]; } + @Override + public int getTable5(int index) { return table5[index]; } + } + + /** + * Create a new instance for probabilities {@code p(i)} where the sample value {@code x} is + * {@code i + offset}. + * + * <p>The sum of the probabilities must be >= 2<sup>30</sup>. Only the + * values for cumulative probability up to 2<sup>30</sup> will be sampled.</p> + * + * <p>Note: This is package-private for use by discrete distribution samplers that can + * compute their probability distribution.</p> + * + * @param rng Generator of uniformly distributed random numbers. + * @param prob The probabilities. + * @param offset The offset (must be positive). + * @throws IllegalArgumentException if the offset is negative or the maximum sample index + * exceeds the maximum positive {@code int} value (2<sup>31</sup> - 1). 
+ */ + MarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng, + int[] prob, + int offset) { + if (offset < 0) { + throw new IllegalArgumentException("Unsupported offset: " + offset); + } + if ((long) prob.length + offset > Integer.MAX_VALUE) { + throw new IllegalArgumentException("Unsupported sample index: " + (prob.length + offset)); + } + + this.rng = rng; + + // Get table sizes for each base-64 digit + int n1 = 0; + int n2 = 0; + int n3 = 0; + int n4 = 0; + int n5 = 0; + for (final int m : prob) { + n1 += getBase64Digit(m, 1); + n2 += getBase64Digit(m, 2); + n3 += getBase64Digit(m, 3); + n4 += getBase64Digit(m, 4); + n5 += getBase64Digit(m, 5); + } + + // Allocate tables based on the maximum index + final int maxIndex = prob.length + offset - 1; + if (maxIndex < UNSIGNED_INT_8) { + indexTable = new IndexTable8(n1, n2, n3, n4, n5); + } else if (maxIndex < UNSIGNED_INT_16) { + indexTable = new IndexTable16(n1, n2, n3, n4, n5); + } else { + indexTable = new IndexTable32(n1, n2, n3, n4, n5); + } + + // Compute offsets + t1 = n1 << 24; + t2 = t1 + (n2 << 18); + t3 = t2 + (n3 << 12); + t4 = t3 + (n4 << 6); + n1 = n2 = n3 = n4 = n5 = 0; + + // Fill tables + for (int i = 0; i < prob.length; i++) { + final int m = prob[i]; + final int k = i + offset; + indexTable.fillTable1(n1, n1 += getBase64Digit(m, 1), k); + indexTable.fillTable2(n2, n2 += getBase64Digit(m, 2), k); + indexTable.fillTable3(n3, n3 += getBase64Digit(m, 3), k); + indexTable.fillTable4(n4, n4 += getBase64Digit(m, 4), k); + indexTable.fillTable5(n5, n5 += getBase64Digit(m, 5), k); + } + } + + /** + * Creates a sampler. + * + * <p>The probabilities will be normalised using their sum. The only requirement is the sum + * is positive.</p> + * + * <p>The sum of the probabilities is normalised to 2<sup>30</sup>. Any probability less + * than 2<sup>-30</sup> will not be observed in samples. 
An adjustment is made to the maximum + * probability to compensate for round-off during conversion.</p> + * + * @param rng Generator of uniformly distributed random numbers. + * @param probabilities The list of probabilities. + * @throws IllegalArgumentException if {@code probabilities} is null or empty, a + * probability is negative, infinite or {@code NaN}, or the sum of all + * probabilities is not strictly positive. + */ + public MarsagliaTsangWangDiscreteSampler(UniformRandomProvider rng, + double[] probabilities) { + this(rng, normaliseProbabilities(probabilities), 0); + } + + /** + * Normalise the probabilities to integers that sum to 2<sup>30</sup>. + * + * @param probabilities The list of probabilities. + * @return the normalised probabilities. + * @throws IllegalArgumentException if {@code probabilities} is null or empty, a + * probability is negative, infinite or {@code NaN}, or the sum of all + * probabilities is not strictly positive. + */ + private static int[] normaliseProbabilities(double[] probabilities) { + final double sumProb = validateProbabilities(probabilities); + + // Compute the normalisation: 2^30 / sum + final double normalisation = (1 << 30) / sumProb; + final int[] prob = new int[probabilities.length]; + int sum = 0; + int max = 0; + int mode = 0; + for (int i = 0; i < prob.length; i++) { + // Add 0.5 for rounding + final int p = (int) (probabilities[i] * normalisation + 0.5); + sum += p; + // Find the mode (maximum probability) + if (max < p) { + max = p; + mode = i; + } + prob[i] = p; + } + + // The sum must be >= 2^30. + // Here just compensate the difference onto the highest probability. + prob[mode] += (1 << 30) - sum; + + return prob; + } + + /** + * Validate the probabilities sum to a finite positive number. 
+ * + * @param probabilities the probabilities + * @return the sum + * @throws IllegalArgumentException if {@code probabilities} is null or empty, a + * probability is negative, infinite or {@code NaN}, or the sum of all + * probabilities is not strictly positive. + */ + private static double validateProbabilities(double[] probabilities) { + if (probabilities == null || probabilities.length == 0) { + throw new IllegalArgumentException("Probabilities must not be empty."); + } + + double sumProb = 0; + for (final double prob : probabilities) { + if (prob < 0 || + Double.isInfinite(prob) || + Double.isNaN(prob)) { + throw new IllegalArgumentException("Invalid probability: " + + prob); + } + sumProb += prob; + } + + if (Double.isInfinite(sumProb) || sumProb <= 0) { + throw new IllegalArgumentException("Invalid sum of probabilities: " + sumProb); + } + return sumProb; + } + + /** + * Gets the k<sup>th</sup> base 64 digit of {@code m}. + * + * @param m the value m. + * @param k the digit. + * @return the base 64 digit + */ + private static int getBase64Digit(int m, int k) { + return (m >>> (30 - 6 * k)) & 63; + } + + /** {@inheritDoc} */ + @Override + public int sample() { + final int j = rng.nextInt() >>> 2; + if (j < t1) { + return indexTable.getTable1(j >>> 24); + } + if (j < t2) { + return indexTable.getTable2((j - t1) >>> 18); + } + if (j < t3) { + return indexTable.getTable3((j - t2) >>> 12); + } + if (j < t4) { + return indexTable.getTable4((j - t3) >>> 6); + } + // Note the tables are filled on the assumption that the sum of the probabilities. + // is >=2^30. If this is not true then the final table table5 will be smaller by the + // difference. So the tables *must* be constructed correctly. 
+ return indexTable.getTable5(j - t4); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return "Marsaglia Tsang Wang discrete deviate [" + rng.toString() + "]"; + } +} diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangSmallMeanPoissonSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangSmallMeanPoissonSampler.java new file mode 100644 index 0000000..4ac66e8 --- /dev/null +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangSmallMeanPoissonSampler.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import org.apache.commons.rng.UniformRandomProvider; + +/** + * Sampler for the <a href="http://mathworld.wolfram.com/PoissonDistribution.html">Poisson + * distribution</a> using an optimised look-up table. + * + * <ul> + * <li> + * A Poisson process is simulated using pre-tabulated probabilities, as described + * in George Marsaglia, Wai Wan Tsang, Jingbo Wang (2004) Fast Generation of Discrete + * Random Variables. Journal of Statistical Software. Vol. 
11, Issue. 3, pp. 1-11. + * </li> + * </ul> + * + * <p>This sampler is suitable for {@code mean <= 1024}. Larger means accumulate errors + * when tabulating the Poisson probability. For large means, {@link LargeMeanPoissonSampler} + * should be used instead.</p> + * + * <p>Note: The algorithm ignores any observation where for a sample size of + * 2<sup>31</sup> the expected number of occurrences is {@code < 0.5}.</p> + * + * <p>Sampling uses 1 call to {@link UniformRandomProvider#nextInt()}. Storage requirements + * depend on the tabulated probability values. Example storage requirements are listed below.</p> + * + * <pre> + * mean table size kB + * 0.25 882 0.88 + * 0.5 1135 1.14 + * 1 1200 1.20 + * 2 1451 1.45 + * 4 1955 1.96 + * 8 2961 2.96 + * 16 4410 4.41 + * 32 6115 6.11 + * 64 8499 8.50 + * 128 11528 11.53 + * 256 15935 31.87 + * 512 20912 41.82 + * 1024 30614 61.23 + * </pre> + * + * <p>Note: Storage changes to 2 bytes per index when {@code mean=256}.</p> + * + * @since 1.3 + * @see <a href="http://dx.doi.org/10.18637/jss.v011.i03">Marsaglia, et al (2004) JSS Vol. + * 11, Issue 3</a> + */ +public class MarsagliaTsangWangSmallMeanPoissonSampler implements DiscreteSampler { + /** + * The value 2<sup>30</sup> as an {@code int}. + */ + private static final int INT_30 = 1 << 30; + /** + * The value 2<sup>31</sup> as a {@code double}. + */ + private static final double DOUBLE_31 = 1L << 31; + /** + * Upper bound to avoid exceeding the table sizes. + * + * <p>The number of possible values of the distribution should not exceed 2^16.</p> + * + * <p>The original source code provided in Marsaglia, et al (2004) has no explicit + * limit but the code fails at mean >= 1941 as the transform to compute p(x=mode) + * produces infinity. Use a conservative limit of 1024.</p> + */ + private static final double MAX_MEAN = 1024; + + /** The delegate. */ + private final DiscreteSampler delegate; + + /** + * Create a new instance.
+ * + * @param rng Generator of uniformly distributed random numbers. + * @param mean Mean. + * @throws IllegalArgumentException if {@code mean <= 0} or {@code mean > 1024}. + */ + public MarsagliaTsangWangSmallMeanPoissonSampler(UniformRandomProvider rng, double mean) { + if (mean <= 0) { + throw new IllegalArgumentException("mean is not strictly positive: " + mean); + } + // Limit the mean to avoid exceeding the supported table sizes (see MAX_MEAN). + if (mean > MAX_MEAN) { + throw new IllegalArgumentException("mean " + mean + " > " + MAX_MEAN); + } + + // Probabilities are 30-bit integers, assumed denominator 2^30 + int[] prob; + // This is the minimum sample value: prob[x - offset] = p(x) + int offset; + + // Generate P's from 0 if mean < 21.4 + if (mean < 21.4) { + final double p0 = Math.exp(-mean); + + // Recursive update of Poisson probability until the value is too small + // p(x + 1) = p(x) * mean / (x + 1) + double p = p0; + int i; + for (i = 1; p * DOUBLE_31 >= 1; i++) { + p *= mean / i; + } + + // Fill P as (30-bit integers) + offset = 0; + final int size = i - 1; + prob = new int[size]; + + p = p0; + prob[0] = toUnsignedInt30(p); + // The sum should be at least 2^30. In edge cases this is false due to round-off. + int sum = prob[0]; + for (i = 1; i < prob.length; i++) { + p *= mean / i; + prob[i] = toUnsignedInt30(p); + sum += prob[i]; + } + + // If the sum is < 2^30 add the remaining sum to the mode (floor(mean)). + prob[(int) mean] += Math.max(0, INT_30 - sum); + } else { + // If mean >= 21.4, generate from largest p-value up, then largest down. + // The largest p-value will be at the mode (floor(mean)). + + // Find p(x=mode) + final int mode = (int) mean; + // This transform is stable until mean >= 1941 where p will result in Infinity + // before the divisor i is large enough to start reducing the product (i.e. i > c).
+ final double c = mean * Math.exp(-mean / mode); + double p = 1.0; + int i; + for (i = 1; i <= mode; i++) { + p *= c / i; + } + final double pX = p; + // Note this will exit when i overflows to negative so no check on the range + for (i = mode + 1; p * DOUBLE_31 >= 1; i++) { + p *= mean / i; + } + final int last = i - 2; + p = pX; + int j = -1; + for (i = mode - 1; i >= 0; i--) { + p *= (i + 1) / mean; + if (p * DOUBLE_31 < 1) { + j = i; + break; + } + } + + // Fill P as (30-bit integers) + offset = j + 1; + final int size = last - offset + 1; + prob = new int[size]; + + p = pX; + prob[mode - offset] = toUnsignedInt30(p); + // The sum must exceed 2^30. In edges cases this is false due to round-off. + int sum = prob[mode - offset]; + for (i = mode + 1; i <= last; i++) { + p *= mean / i; + prob[i - offset] = toUnsignedInt30(p); + sum += prob[i - offset]; + } + p = pX; + for (i = mode - 1; i >= offset; i--) { + p *= (i + 1) / mean; + prob[i - offset] = toUnsignedInt30(p); + sum += prob[i - offset]; + } + + // If the sum is < 2^30 add the remaining sum to the mode + prob[mode - offset] += Math.max(0, INT_30 - sum); + } + + delegate = new MarsagliaTsangWangDiscreteSampler(rng, prob, offset); + } + + /** + * Convert the probability to an unsigned integer in the range [0,2^30]. 
+ * + * @param p the probability + * @return the integer + */ + private static int toUnsignedInt30(double p) { + return (int) (p * INT_30 + 0.5); + } + + /** {@inheritDoc} */ + @Override + public int sample() { + return delegate.sample(); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + return "Small Mean Poisson " + delegate.toString(); + } +} diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java index 5dab832..6158a2d 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java @@ -50,6 +50,15 @@ public class DiscreteSamplersList { add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, probSuccessBinomial), MathArrays.sequence(8, 9, 1), RandomSource.create(RandomSource.KISS)); + add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, probSuccessBinomial), + // range [9,16] + MathArrays.sequence(8, 9, 1), + new MarsagliaTsangWangBinomialSampler(RandomSource.create(RandomSource.WELL_19937_A), trialsBinomial, probSuccessBinomial)); + // Inverted + add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(unusedRng, trialsBinomial, 1 - probSuccessBinomial), + // range [4,11] = [20-16, 20-9] + MathArrays.sequence(8, 4, 1), + new MarsagliaTsangWangBinomialSampler(RandomSource.create(RandomSource.WELL_19937_C), trialsBinomial, 1 - probSuccessBinomial)); // Geometric ("inverse method"). 
final double probSuccessGeometric = 0.21; @@ -146,6 +155,11 @@ public class DiscreteSamplersList { add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(unusedRng, veryLargeMeanPoisson, epsilonPoisson, maxIterationsPoisson), MathArrays.sequence(100, (int) (veryLargeMeanPoisson - 50), 1), new LargeMeanPoissonSampler(RandomSource.create(RandomSource.SPLIT_MIX_64), veryLargeMeanPoisson)); + + // Any discrete distribution + double[] discreteProbabilities = new double[] { 0.1, 0.2, 0.3, 0.4 }; + add(LIST, discreteProbabilities, + new MarsagliaTsangWangDiscreteSampler(RandomSource.create(RandomSource.XO_SHI_RO_512_PLUS), discreteProbabilities)); } catch (Exception e) { System.err.println("Unexpected exception while creating the list of samplers: " + e); e.printStackTrace(System.err); @@ -201,6 +215,19 @@ public class DiscreteSamplersList { } /** + * @param list List of data (one the "parameters" tested by the Junit parametric test). + * @param probabilities Probability distribution to which the samples are supposed to conform. + * @param sampler Sampler. + */ + private static void add(List<DiscreteSamplerTestData[]> list, + final double[] probabilities, + final DiscreteSampler sampler) { + list.add(new DiscreteSamplerTestData[] { new DiscreteSamplerTestData(sampler, + MathArrays.natural(probabilities.length), + probabilities) }); + } + + /** * Subclasses that are "parametric" tests can forward the call to * the "@Parameters"-annotated method to this method. 
* diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangBinomialSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangBinomialSamplerTest.java new file mode 100644 index 0000000..bfe5052 --- /dev/null +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangBinomialSamplerTest.java @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import org.apache.commons.rng.UniformRandomProvider; +import org.junit.Test; + +import org.junit.Assert; + +/** + * Test for the {@link MarsagliaTsangWangBinomialSampler}. The tests hit edge cases for + * the sampler. 
+ */
+public class MarsagliaTsangWangBinomialSamplerTest {
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithTrialsBelow0() {
+        final UniformRandomProvider rng = new FixedRNG(0);
+        final int trials = -1;
+        final double p = 0.5;
+        @SuppressWarnings("unused")
+        final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithTrialsAboveMax() {
+        final UniformRandomProvider rng = new FixedRNG(0);
+        final int trials = 1 << 16; // 2^16
+        final double p = 0.5;
+        @SuppressWarnings("unused")
+        final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithProbabilityBelow0() {
+        final UniformRandomProvider rng = new FixedRNG(0);
+        final int trials = 1;
+        final double p = -0.5;
+        @SuppressWarnings("unused")
+        final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithProbabilityAbove1() {
+        final UniformRandomProvider rng = new FixedRNG(0);
+        final int trials = 1;
+        final double p = 1.5;
+        @SuppressWarnings("unused")
+        final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+    }
+
+    /**
+     * Test the constructor with distribution parameters that create a very small p(0)
+     * with a high probability of success.
+     */
+    @Test
+    public void testSamplerWithSmallestP0ValueAndHighestProbabilityOfSuccess() {
+        final UniformRandomProvider rng = new FixedRNG(0xffffffff);
+        // p(0) = Math.exp(trials * Math.log(1-p))
+        // p(0) will be smaller as Math.log(1-p) is more negative, which occurs when p is
+        // larger.
+        // Since the sampler uses inversion the largest value for p is 0.5.
+        // At the extreme for p = 0.5:
+        // trials = Math.log(p(0)) / Math.log(1-p)
+        // = Math.log(Double.MIN_VALUE) / Math.log(0.5)
+        // = 1074
+        final int trials = (int) Math.floor(Math.log(Double.MIN_VALUE) / Math.log(0.5));
+        final double p = 0.5;
+        // Validate set-up
+        Assert.assertEquals("Invalid test set-up for p(0)", Double.MIN_VALUE, getP0(trials, p), 0);
+        Assert.assertEquals("Invalid test set-up for p(0)", 0, getP0(trials + 1, p), 0);
+
+        // This will throw if the table does not sum to 2^30
+        final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+        sampler.sample();
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWhenP0IsZero() {
+        final UniformRandomProvider rng = new FixedRNG(0);
+        // As above but increase the trials so p(0) should be zero
+        final int trials = 1 + (int) Math.floor(Math.log(Double.MIN_VALUE) / Math.log(0.5));
+        final double p = 0.5;
+        // Validate set-up
+        Assert.assertEquals("Invalid test set-up for p(0)", 0, getP0(trials, p), 0);
+        @SuppressWarnings("unused")
+        final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+    }
+
+    /**
+     * Test the constructor with distribution parameters that create a very small p(0)
+     * with a high number of trials.
+     */
+    @Test
+    public void testSamplerWithLargestTrialsAndSmallestProbabilityOfSuccess() {
+        final UniformRandomProvider rng = new FixedRNG(0xffffffff);
+        // p(0) = Math.exp(trials * Math.log(1-p))
+        // p(0) will be smaller as Math.log(1-p) is more negative, which occurs when p is
+        // larger.
+        // Since the sampler uses inversion the largest value for p is 0.5.
+        // At the extreme for trials = 2^16-1:
+        // p = 1 - Math.exp(Math.log(p(0)) / trials)
+        // = 1 - Math.exp(Math.log(Double.MIN_VALUE) / trials)
+        // = 0.011295152668039599
+        final int trials = (1 << 16) - 1;
+        double p = 1 - Math.exp(Math.log(Double.MIN_VALUE) / trials);
+
+        // Validate set-up
+        Assert.assertEquals("Invalid test set-up for p(0)", Double.MIN_VALUE, getP0(trials, p), 0);
+
+        // Search for larger p until Math.nextAfter(p, 1) produces 0
+        double upper = p * 2;
+        Assert.assertEquals("Invalid test set-up for p(0)", 0, getP0(trials, upper), 0);
+
+        double lower = p;
+        while (Double.doubleToRawLongBits(lower) + 1 < Double.doubleToRawLongBits(upper)) {
+            final double mid = (upper + lower) / 2;
+            if (getP0(trials, mid) == 0) {
+                upper = mid;
+            } else {
+                lower = mid;
+            }
+        }
+        p = lower;
+
+        // Re-validate
+        Assert.assertEquals("Invalid test set-up for p(0)", Double.MIN_VALUE, getP0(trials, p), 0);
+        Assert.assertEquals("Invalid test set-up for p(0)", 0, getP0(trials, Math.nextAfter(p, 1)), 0);
+
+        final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+        // This will throw if the table does not sum to 2^30
+        sampler.sample();
+    }
+
+    /**
+     * Gets the p(0) value.
+     *
+     * @param trials the trials
+     * @param probabilityOfSuccess the probability of success
+     * @return the p(0) value
+     */
+    private static double getP0(int trials, double probabilityOfSuccess) {
+        return Math.exp(trials * Math.log(1 - probabilityOfSuccess));
+    }
+
+    @Test
+    public void testSamplerWithProbability0() {
+        final UniformRandomProvider rng = new FixedRNG(0);
+        final int trials = 1000000;
+        final double p = 0;
+        final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+        for (int i = 0; i < 5; i++) {
+            Assert.assertEquals(0, sampler.sample());
+        }
+        // Hit the toString() method
+        Assert.assertTrue(sampler.toString().contains("Binomial"));
+    }
+
+    @Test
+    public void testSamplerWithProbability1() {
+        final UniformRandomProvider rng = new FixedRNG(0);
+        final int trials = 1000000;
+        final double p = 1;
+        final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+        for (int i = 0; i < 5; i++) {
+            Assert.assertEquals(trials, sampler.sample());
+        }
+        // Hit the toString() method
+        Assert.assertTrue(sampler.toString().contains("Binomial"));
+    }
+
+    /**
+     * Test the sampler with a large number of trials. This tests the sampler can create the
+     * Binomial distribution for a large size when a limiting distribution (e.g. the Normal distribution)
+     * could be used instead.
+     */
+    @Test
+    public void testSamplerWithLargeNumberOfTrials() {
+        final UniformRandomProvider rng = new FixedRNG(0xffffffff);
+        final int trials = 65000;
+        final double p = 0.01;
+        final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+        // This will throw if the table does not sum to 2^30
+        sampler.sample();
+    }
+
+    /**
+     * Test the sampler with a probability of 0.5. This should hit the edge case in the loop to
+     * search for the last probability of the Binomial distribution.
+     */
+    @Test
+    public void testSamplerWithProbability0_5() {
+        final UniformRandomProvider rng = new FixedRNG(0xffffffff);
+        final int trials = 10;
+        final double p = 0.5;
+        final MarsagliaTsangWangBinomialSampler sampler = new MarsagliaTsangWangBinomialSampler(rng, trials, p);
+        // This will throw if the table does not sum to 2^30
+        sampler.sample();
+    }
+
+    /**
+     * A RNG returning a fixed value.
+     */
+    private static class FixedRNG implements UniformRandomProvider {
+        /** The value. */
+        private final int value;
+
+        /**
+         * @param value the value returned by {@link #nextInt()}
+         */
+        FixedRNG(int value) {
+            this.value = value;
+        }
+
+        @Override
+        public int nextInt() {
+            return value;
+        }
+
+        @Override public void nextBytes(byte[] bytes) {}
+        @Override public void nextBytes(byte[] bytes, int start, int len) {}
+        @Override public int nextInt(int n) { return 0; }
+        @Override public long nextLong() { return 0; }
+        @Override public long nextLong(long n) { return 0; }
+        @Override public boolean nextBoolean() { return false; }
+        @Override public float nextFloat() { return 0; }
+        @Override public double nextDouble() { return 0; }
+    }
+}
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSamplerTest.java
new file mode 100644
index 0000000..d1fce42
--- /dev/null
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangDiscreteSamplerTest.java
@@ -0,0 +1,332 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import org.apache.commons.math3.stat.inference.ChiSquareTest;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.core.source32.IntProvider;
+import org.apache.commons.rng.core.source64.SplitMix64;
+import org.apache.commons.rng.simple.RandomSource;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Test for the {@link MarsagliaTsangWangDiscreteSampler}. The tests hit edge cases for
+ * the sampler.
+ */
+public class MarsagliaTsangWangDiscreteSamplerTest {
+    // Tests for the package-private constructor using int[] + offset
+
+    /**
+     * Test constructor throws with max index above integer max.
+     */
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithMaxIndexAboveIntegerMax() {
+        final int[] prob = new int[1];
+        final int offset = Integer.MAX_VALUE;
+        createSampler(prob, offset);
+    }
+
+    /**
+     * Test constructor throws with negative offset.
+     */
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithNegativeOffset() {
+        final int[] prob = new int[1];
+        final int offset = -1;
+        createSampler(prob, offset);
+    }
+
+    /**
+     * Test construction is allowed when max index equals integer max.
+     */
+    @Test
+    public void testConstructorWhenMaxIndexEqualsIntegerMax() {
+        final int[] prob = new int[1];
+        prob[0] = 1 << 30; // So the total probability is 2^30
+        final int offset = Integer.MAX_VALUE - 1;
+        createSampler(prob, offset);
+    }
+
+    /**
+     * Creates the sampler.
+     *
+     * @param probabilities the probabilities
+     * @param offset the offset
+     * @return the sampler
+     */
+    private static MarsagliaTsangWangDiscreteSampler createSampler(final int[] probabilities, int offset) {
+        final UniformRandomProvider rng = new SplitMix64(0L);
+        return new MarsagliaTsangWangDiscreteSampler(rng, probabilities, offset);
+    }
+
+    // Tests for the public constructor using double[]
+
+    @Test(expected=IllegalArgumentException.class)
+    public void testConstructorThrowsWithNullProbabilities() {
+        createSampler(null);
+    }
+
+    @Test(expected=IllegalArgumentException.class)
+    public void testConstructorThrowsWithZeroLengthProbabilities() {
+        createSampler(new double[0]);
+    }
+
+    @Test(expected=IllegalArgumentException.class)
+    public void testConstructorThrowsWithNegativeProbabilities() {
+        createSampler(new double[] { -1, 0.1, 0.2 });
+    }
+
+    @Test(expected=IllegalArgumentException.class)
+    public void testConstructorThrowsWithNaNProbabilities() {
+        createSampler(new double[] { 0.1, Double.NaN, 0.2 });
+    }
+
+    @Test(expected=IllegalArgumentException.class)
+    public void testConstructorThrowsWithInfiniteProbabilities() {
+        createSampler(new double[] { 0.1, Double.POSITIVE_INFINITY, 0.2 });
+    }
+
+    @Test(expected=IllegalArgumentException.class)
+    public void testConstructorThrowsWithInfiniteSumProbabilities() {
+        createSampler(new double[] { Double.MAX_VALUE, Double.MAX_VALUE });
+    }
+
+    @Test(expected=IllegalArgumentException.class)
+    public void testConstructorThrowsWithZeroSumProbabilities() {
+        createSampler(new double[4]);
+    }
+
+    /**
+     * Creates the sampler.
+     *
+     * @param probabilities the probabilities
+     * @return the sampler
+     */
+    private static MarsagliaTsangWangDiscreteSampler createSampler(double[] probabilities) {
+        final UniformRandomProvider rng = new SplitMix64(0L);
+        return new MarsagliaTsangWangDiscreteSampler(rng, probabilities);
+    }
+
+    // Sampling tests
+
+    /**
+     * Test offset samples. This test hits all code paths in the sampler for 8, 16, and 32-bit
+     * storage using different offsets to control the maximum sample value.
+     */
+    @Test
+    public void testOffsetSamples() {
+        // This is filled with probabilities to hit all edge cases in the fill procedure.
+        // The probabilities must have a digit from each of the 5 possible.
+        // Note: parentheses are required as '+' binds tighter than '<<'.
+        final int[] prob = new int[6];
+        prob[0] = 1;
+        prob[1] = 1 + (1 << 6);
+        prob[2] = 1 + (1 << 12);
+        prob[3] = 1 + (1 << 18);
+        prob[4] = 1 + (1 << 24);
+        // Ensure probabilities sum to 2^30
+        prob[5] = (1 << 30) - (prob[0] + prob[1] + prob[2] + prob[3] + prob[4]);
+
+        // To hit all samples requires integers that are under the look-up table limits.
+        // So compute the limits here.
+        int n1 = 0;
+        int n2 = 0;
+        int n3 = 0;
+        int n4 = 0;
+        for (final int m : prob) {
+            n1 += getBase64Digit(m, 1);
+            n2 += getBase64Digit(m, 2);
+            n3 += getBase64Digit(m, 3);
+            n4 += getBase64Digit(m, 4);
+        }
+
+        final int t1 = n1 << 24;
+        final int t2 = t1 + (n2 << 18);
+        final int t3 = t2 + (n3 << 12);
+        final int t4 = t3 + (n4 << 6);
+
+        // Create values under the limits and bit shift by 2 to reverse what the sampler does.
+        final int[] values = new int[] { 0, t1, t2, t3, t4, 0xffffffff };
+        for (int i = 0; i < values.length; i++) {
+            values[i] <<= 2;
+        }
+
+        final UniformRandomProvider rng1 = new FixedSequenceIntProvider(values);
+        final UniformRandomProvider rng2 = new FixedSequenceIntProvider(values);
+        final UniformRandomProvider rng3 = new FixedSequenceIntProvider(values);
+
+        // Create offsets to force storage as 8, 16, or 32-bit
+        final int offset1 = 1;
+        final int offset2 = 1 << 8;
+        final int offset3 = 1 << 16;
+
+        final MarsagliaTsangWangDiscreteSampler sampler1 = new MarsagliaTsangWangDiscreteSampler(rng1, prob, offset1);
+        final MarsagliaTsangWangDiscreteSampler sampler2 = new MarsagliaTsangWangDiscreteSampler(rng2, prob, offset2);
+        final MarsagliaTsangWangDiscreteSampler sampler3 = new MarsagliaTsangWangDiscreteSampler(rng3, prob, offset3);
+
+        for (int i = 0; i < values.length; i++) {
+            // Remove offsets
+            final int s1 = sampler1.sample() - offset1;
+            final int s2 = sampler2.sample() - offset2;
+            final int s3 = sampler3.sample() - offset3;
+            Assert.assertEquals("Offset sample 1 and 2 do not match", s1, s2);
+            Assert.assertEquals("Offset sample 1 and 3 do not match", s1, s3);
+        }
+    }
+
+    /**
+     * Test samples from a distribution expressed using {@code double} probabilities.
+     */
+    @Test
+    public void testRealProbabilityDistributionSamples() {
+        // These do not have to sum to 1
+        final double[] probabilities = new double[11];
+        final UniformRandomProvider rng = RandomSource.create(RandomSource.SPLIT_MIX_64);
+        for (int i = 0; i < probabilities.length; i++) {
+            probabilities[i] = rng.nextDouble();
+        }
+
+        // First test the table is completely filled to 2^30
+        final UniformRandomProvider dummyRng = new FixedSequenceIntProvider(new int[] { 0xffffffff});
+        final MarsagliaTsangWangDiscreteSampler dummySampler = new MarsagliaTsangWangDiscreteSampler(dummyRng, probabilities);
+        // This will throw if the table is incomplete as it hits the upper limit
+        dummySampler.sample();
+
+        // Do a test of the actual sampler
+        final MarsagliaTsangWangDiscreteSampler sampler = new MarsagliaTsangWangDiscreteSampler(rng, probabilities);
+
+        final int numberOfSamples = 10000;
+        final long[] samples = new long[probabilities.length];
+        for (int i = 0; i < numberOfSamples; i++) {
+            samples[sampler.sample()]++;
+        }
+
+        final ChiSquareTest chiSquareTest = new ChiSquareTest();
+        // Pass if we cannot reject null hypothesis that the distributions are the same.
+        Assert.assertFalse(chiSquareTest.chiSquareTest(probabilities, samples, 0.001));
+    }
+
+    /**
+     * Test the storage requirements for a worst case set of 2^8 probabilities. This tests the
+     * limits described in the class Javadoc are correct.
+     */
+    @Test
+    public void testStorageRequirements8() {
+        // Max digits from 2^22:
+        // (2^4 + 2^6 + 2^6 + 2^6)
+        // Storage in bytes
+        // = (15 + 3 * 63) * 2^8
+        // = 52224 B
+        // = 0.0522 MB
+        checkStorageRequirements(8, 0.06);
+    }
+
+    /**
+     * Test the storage requirements for a worst case set of 2^16 probabilities. This tests the
+     * limits described in the class Javadoc are correct.
+     */
+    @Test
+    public void testStorageRequirements16() {
+        // Max digits from 2^14:
+        // (2^2 + 2^6 + 2^6)
+        // Storage in bytes
+        // = 2 * (3 + 2 * 63) * 2^16
+        // = 16908288 B
+        // = 16.91 MB
+        checkStorageRequirements(16, 17.0);
+    }
+
+    /**
+     * Test the storage requirements for a worst case set of 2^k probabilities. This
+     * tests the limits described in the class Javadoc are correct.
+     *
+     * @param k Base is 2^k.
+     * @param expectedLimitMB the expected limit in MB
+     */
+    private static void checkStorageRequirements(int k, double expectedLimitMB) {
+        // Worst case scenario is a uniform distribution of 2^k samples each with the highest
+        // mask set for base 64 digits.
+        // The max number of samples: 2^k
+        final int maxSamples = (1 << k);
+
+        // The highest value for each sample:
+        // 2^30 / 2^k = 2^(30-k)
+        // The highest mask is all bits set
+        final int m = (1 << (30 - k)) - 1;
+
+        // Check the sum is less than 2^30
+        final long sum = (long) maxSamples * m;
+        final int total = 1 << 30;
+        Assert.assertTrue("Worst case uniform distribution is above 2^30", sum < total);
+
+        // Get the digits as per the sampler and compute storage
+        final int d1 = getBase64Digit(m, 1);
+        final int d2 = getBase64Digit(m, 2);
+        final int d3 = getBase64Digit(m, 3);
+        final int d4 = getBase64Digit(m, 4);
+        final int d5 = getBase64Digit(m, 5);
+        // Compute storage in MB using the index size needed for 2^k samples (1, 2 or 4 bytes)
+        int bytes;
+        if (k <= 8) {
+            bytes = 1;
+        } else if (k <= 16) {
+            bytes = 2;
+        } else {
+            bytes = 4;
+        }
+        final double storageMB = bytes * 1e-6 * (d1 + d2 + d3 + d4 + d5) * maxSamples;
+        Assert.assertTrue(
+            "Worst case uniform distribution storage " + storageMB + "MB is above expected limit: " + expectedLimitMB,
+            storageMB < expectedLimitMB);
+    }
+
+    /**
+     * Gets the k<sup>th</sup> base 64 digit of {@code m}.
+     *
+     * @param m the value m.
+     * @param k the digit.
+     * @return the base 64 digit
+     */
+    private static int getBase64Digit(int m, int k) {
+        return (m >>> (30 - 6 * k)) & 63;
+    }
+
+    /**
+     * Return a fixed sequence of {@code int} output.
+     */
+    private static class FixedSequenceIntProvider extends IntProvider {
+        /** The count of values output. */
+        private int count;
+        /** The values. */
+        private final int[] values;
+
+        /**
+         * Instantiates a new fixed sequence int provider.
+         *
+         * @param values Values.
+         */
+        FixedSequenceIntProvider(int[] values) {
+            this.values = values;
+        }
+
+        @Override
+        public int next() {
+            // This should not be called enough to overflow count
+            return values[count++ % values.length];
+        }
+    }
+}
diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangSmallMeanPoissonSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangSmallMeanPoissonSamplerTest.java
new file mode 100644
index 0000000..207840f
--- /dev/null
+++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/MarsagliaTsangWangSmallMeanPoissonSamplerTest.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.commons.rng.sampling.distribution;
+
+import org.apache.commons.rng.UniformRandomProvider;
+import org.junit.Test;
+
+/**
+ * Test for the {@link MarsagliaTsangWangSmallMeanPoissonSampler}. The tests hit edge
+ * cases for the sampler.
+ */
+public class MarsagliaTsangWangSmallMeanPoissonSamplerTest {
+    /**
+     * Test the constructor with a mean above the upper bound of 1024.
+     */
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithMeanLargerThanUpperBound() {
+        final UniformRandomProvider rng = new FixedRNG(0);
+        final double mean = 1025;
+        @SuppressWarnings("unused")
+        final MarsagliaTsangWangSmallMeanPoissonSampler sampler = new MarsagliaTsangWangSmallMeanPoissonSampler(rng,
+                mean);
+    }
+
+    /**
+     * Test the constructor with a zero mean (not strictly positive).
+     */
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithZeroMean() {
+        final UniformRandomProvider rng = new FixedRNG(0);
+        final double mean = 0;
+        @SuppressWarnings("unused")
+        final MarsagliaTsangWangSmallMeanPoissonSampler sampler = new MarsagliaTsangWangSmallMeanPoissonSampler(rng,
+                mean);
+    }
+
+    /**
+     * Test the constructor with the maximum mean.
+     */
+    @Test
+    public void testConstructorWithMaximumMean() {
+        final UniformRandomProvider rng = new FixedRNG(0);
+        final double mean = 1024;
+        @SuppressWarnings("unused")
+        final MarsagliaTsangWangSmallMeanPoissonSampler sampler = new MarsagliaTsangWangSmallMeanPoissonSampler(rng,
+                mean);
+    }
+
+    /**
+     * Test the constructor with a small mean that hits the edge case where the
+     * probability sum is not 2^30.
+     */
+    @Test
+    public void testConstructorWithSmallMean() {
+        final UniformRandomProvider rng = new FixedRNG(0xffffffff);
+        final double mean = 0.25;
+        final MarsagliaTsangWangSmallMeanPoissonSampler sampler = new MarsagliaTsangWangSmallMeanPoissonSampler(rng,
+                mean);
+        // This will throw if the table does not sum to 2^30
+        sampler.sample();
+    }
+
+    /**
+     * Test the constructor with a medium mean that is at the switch point for how the probability
+     * distribution is computed.
+     */
+    @Test
+    public void testConstructorWithMediumMean() {
+        final UniformRandomProvider rng = new FixedRNG(0xffffffff);
+        final double mean = 21.4;
+        final MarsagliaTsangWangSmallMeanPoissonSampler sampler = new MarsagliaTsangWangSmallMeanPoissonSampler(rng,
+                mean);
+        // This will throw if the table does not sum to 2^30
+        sampler.sample();
+    }
+
+    /**
+     * A RNG returning a fixed value.
+     */
+    private static class FixedRNG implements UniformRandomProvider {
+        /** The value. */
+        private final int value;
+
+        /**
+         * @param value the value returned by {@link #nextInt()}
+         */
+        FixedRNG(int value) {
+            this.value = value;
+        }
+
+        @Override
+        public int nextInt() {
+            return value;
+        }
+
+        @Override public void nextBytes(byte[] bytes) {}
+        @Override public void nextBytes(byte[] bytes, int start, int len) {}
+        @Override public int nextInt(int n) { return 0; }
+        @Override public long nextLong() { return 0; }
+        @Override public long nextLong(long n) { return 0; }
+        @Override public boolean nextBoolean() { return false; }
+        @Override public float nextFloat() { return 0; }
+        @Override public double nextDouble() { return 0; }
+    }
+}