This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-rng.git
commit 341af54ecbef6988b28be8ca4ec1db246362f4d5 Author: aherbert <[email protected]> AuthorDate: Tue Aug 6 14:54:16 2019 +0100 RNG-109: Delegate sampling in probability collection sampler. --- .../DiscreteProbabilityCollectionSampler.java | 119 ++++++--------------- .../DiscreteProbabilityCollectionSamplerTest.java | 52 ++++++--- 2 files changed, 70 insertions(+), 101 deletions(-) diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java index 4bc50b4..69b45bc 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSampler.java @@ -19,11 +19,11 @@ package org.apache.commons.rng.sampling; import java.util.List; import java.util.Map; -import java.util.HashMap; import java.util.ArrayList; -import java.util.Arrays; import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler; +import org.apache.commons.rng.sampling.distribution.SharedStateDiscreteSampler; /** * Sampling from a collection of items with user-defined @@ -40,12 +40,12 @@ import org.apache.commons.rng.UniformRandomProvider; */ public class DiscreteProbabilityCollectionSampler<T> implements SharedStateSampler<DiscreteProbabilityCollectionSampler<T>> { + /** The error message for an empty collection. */ + private static final String EMPTY_COLLECTION = "Empty collection"; /** Collection to be sampled from. */ private final List<T> items; - /** RNG. */ - private final UniformRandomProvider rng; - /** Cumulative probabilities. */ - private final double[] cumulativeProbabilities; + /** Sampler for the probabilities. */ + private final SharedStateDiscreteSampler sampler; /** * Creates a sampler. @@ -64,43 +64,22 @@ public class DiscreteProbabilityCollectionSampler<T> public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng, Map<T, Double> collection) { if (collection.isEmpty()) { - throw new IllegalArgumentException("Empty collection"); + throw new IllegalArgumentException(EMPTY_COLLECTION); } - this.rng = rng; + // Extract the items and probabilities final int size = collection.size(); items = new ArrayList<T>(size); - cumulativeProbabilities = new double[size]; + final double[] probabilities = new double[size]; - double sumProb = 0; int count = 0; for (final Map.Entry<T, Double> e : collection.entrySet()) { items.add(e.getKey()); - - final double prob = e.getValue(); - if (prob < 0 || - Double.isInfinite(prob) || - Double.isNaN(prob)) { - throw new IllegalArgumentException("Invalid probability: " + - prob); - } - - // Temporarily store probability. - cumulativeProbabilities[count++] = prob; - sumProb += prob; + probabilities[count++] = e.getValue(); } - if (sumProb <= 0) { - throw new IllegalArgumentException("Invalid sum of probabilities"); - } - - // Compute and store cumulative probability. - for (int i = 0; i < size; i++) { - cumulativeProbabilities[i] /= sumProb; - if (i > 0) { - cumulativeProbabilities[i] += cumulativeProbabilities[i - 1]; - } - } + // Delegate sampling + sampler = createSampler(rng, probabilities); } /** @@ -122,7 +101,19 @@ public class DiscreteProbabilityCollectionSampler<T> public DiscreteProbabilityCollectionSampler(UniformRandomProvider rng, List<T> collection, double[] probabilities) { - this(rng, consolidate(collection, probabilities)); + if (collection.isEmpty()) { + throw new IllegalArgumentException(EMPTY_COLLECTION); + } + final int len = probabilities.length; + if (len != collection.size()) { + throw new IllegalArgumentException("Size mismatch: " + + len + " != " + + collection.size()); + } + // Shallow copy the list + items = new ArrayList<T>(collection); + // Delegate sampling + sampler = createSampler(rng, probabilities); } /** @@ -131,9 +122,8 @@ public class DiscreteProbabilityCollectionSampler<T> */ private DiscreteProbabilityCollectionSampler(UniformRandomProvider rng, DiscreteProbabilityCollectionSampler<T> source) { - this.rng = rng; this.items = source.items; - this.cumulativeProbabilities = source.cumulativeProbabilities; + this.sampler = source.sampler.withUniformRandomProvider(rng); } /** @@ -142,22 +132,7 @@ public class DiscreteProbabilityCollectionSampler<T> * @return a random sample. */ public T sample() { - final double rand = rng.nextDouble(); - - int index = Arrays.binarySearch(cumulativeProbabilities, rand); - if (index < 0) { - index = -index - 1; - } - - if (index < cumulativeProbabilities.length && - rand < cumulativeProbabilities[index]) { - return items.get(index); - } - - // This should never happen, but it ensures we will return a correct - // object in case there is some floating point inequality problem - // wrt the cumulative probabilities. - return items.get(items.size() - 1); + return items.get(sampler.sample()); } /** {@inheritDoc} */ @@ -167,38 +142,14 @@ public class DiscreteProbabilityCollectionSampler<T> } /** - * @param collection Collection to be sampled. - * @param probabilities Probability associated to each item of the - * {@code collection}. - * @return a consolidated map (where probabilities of equal items - * have been summed). - * @throws IllegalArgumentException if the number of items in the - * {@code collection} is not equal to the number of provided - * {@code probabilities}. - * @param <T> Type of items in the collection. + * Creates the sampler of the enumerated probability distribution. + * + * @param rng Generator of uniformly distributed random numbers. + * @param probabilities Probability associated to each item. + * @return the sampler */ - private static <T> Map<T, Double> consolidate(List<T> collection, - double[] probabilities) { - final int len = probabilities.length; - if (len != collection.size()) { - throw new IllegalArgumentException("Size mismatch: " + - len + " != " + - collection.size()); - } - - final Map<T, Double> map = new HashMap<T, Double>(); - for (int i = 0; i < len; i++) { - final T item = collection.get(i); - final Double prob = probabilities[i]; - - Double currentProb = map.get(item); - if (currentProb == null) { - currentProb = 0d; - } - - map.put(item, currentProb + prob); - } - - return map; + private static SharedStateDiscreteSampler createSampler(UniformRandomProvider rng, + double[] probabilities) { + return GuideTableDiscreteSampler.of(rng, probabilities); } } diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java index 78a4391..e5ffc6a 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/DiscreteProbabilityCollectionSamplerTest.java @@ -18,8 +18,11 @@ package org.apache.commons.rng.sampling; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.TreeMap; import org.junit.Assert; import org.junit.Test; @@ -74,6 +77,13 @@ public class DiscreteProbabilityCollectionSamplerTest { new DiscreteProbabilityCollectionSampler<Double>(rng, new HashMap<Double, Double>()); } + @Test(expected = IllegalArgumentException.class) + public void testPrecondition7() { + // Empty List<T> not allowed + new DiscreteProbabilityCollectionSampler<Double>(rng, + Collections.<Double>emptyList(), + new double[0]); + } @Test public void testSample() { @@ -99,30 +109,38 @@ public class DiscreteProbabilityCollectionSamplerTest { Assert.assertEquals(expectedVariance, variance, 2e-3); } - /** - * Edge-case test: - * Create a sampler that will return 1 for nextDouble() forcing the binary search to - * identify the end item of the cumulative probability array. - */ + @Test - public void testSampleWithProbabilityAtLastItem() { - sampleWithProbabilityForLastItem(false); + public void testSampleUsingMap() { + final UniformRandomProvider rng1 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final UniformRandomProvider rng2 = RandomSource.create(RandomSource.SPLIT_MIX_64, 0L); + final List<Integer> items = Arrays.asList(1, 3, 4, 6, 9); + final double[] probabilities = {0.1, 0.2, 0.3, 0.4, 0.5}; + final DiscreteProbabilityCollectionSampler<Integer> sampler1 = + new DiscreteProbabilityCollectionSampler<Integer>(rng1, items, probabilities); + + // Create a map version. The map iterator must be ordered so use a TreeMap. + final Map<Integer, Double> map = new TreeMap<Integer, Double>(); + for (int i = 0; i < probabilities.length; i++) { + map.put(items.get(i), probabilities[i]); + } + final DiscreteProbabilityCollectionSampler<Integer> sampler2 = + new DiscreteProbabilityCollectionSampler<Integer>(rng2, map); + + for (int i = 0; i < 50; i++) { + Assert.assertEquals(sampler1.sample(), sampler2.sample()); + } } /** * Edge-case test: - * Create a sampler that will return over 1 for nextDouble() forcing the binary search to - * identify insertion at the end of the cumulative probability array. + * Create a sampler that will return 1 for nextDouble() forcing the search to + * identify the end item of the cumulative probability array. */ @Test - public void testSampleWithProbabilityPastLastItem() { - sampleWithProbabilityForLastItem(true); - } - - private static void sampleWithProbabilityForLastItem(boolean pastLast) { + public void testSampleWithProbabilityAtLastItem() { // Ensure the samples pick probability 0 (the first item) and then // a probability (for the second item) that hits an edge case. - final double probability = pastLast ? 1.1 : 1; final UniformRandomProvider dummyRng = new UniformRandomProvider() { private int count; // CHECKSTYLE: stop all @@ -132,7 +150,7 @@ public class DiscreteProbabilityCollectionSamplerTest { public int nextInt() { return 0; } public float nextFloat() { return 0; } // Return 0 then the given probability - public double nextDouble() { return (count++ == 0) ? 0 : probability; } + public double nextDouble() { return (count++ == 0) ? 0 : 1.0; } public void nextBytes(byte[] bytes, int start, int len) {} public void nextBytes(byte[] bytes) {} public boolean nextBoolean() { return false; } @@ -164,7 +182,7 @@ public class DiscreteProbabilityCollectionSamplerTest { final DiscreteProbabilityCollectionSampler<Double> sampler1 = new DiscreteProbabilityCollectionSampler<Double>(rng1, items, - new double[] {0.1, 0.2, 0.3, 04}); + new double[] {0.1, 0.2, 0.3, 0.4}); final DiscreteProbabilityCollectionSampler<Double> sampler2 = sampler1.withUniformRandomProvider(rng2); RandomAssert.assertProduceSameSequence( new RandomAssert.Sampler<Double>() {
