This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-rng.git
commit 855915bb0f5a2cb23158441869928ad7cd49a165 Author: aherbert <aherb...@apache.org> AuthorDate: Tue Aug 6 14:56:27 2019 +0100 RNG-109: Benchmark enumerated probability distributed samplers. --- .../EnumeratedDistributionSamplersPerformance.java | 555 +++++++++++++++++++++ 1 file changed, 555 insertions(+) diff --git a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/EnumeratedDistributionSamplersPerformance.java b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/EnumeratedDistributionSamplersPerformance.java new file mode 100644 index 0000000..fc28515 --- /dev/null +++ b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/distribution/EnumeratedDistributionSamplersPerformance.java @@ -0,0 +1,555 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/

package org.apache.commons.rng.examples.jmh.distribution;

import org.apache.commons.math3.distribution.BinomialDistribution;
import org.apache.commons.math3.distribution.IntegerDistribution;
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.examples.jmh.RandomSources;
import org.apache.commons.rng.sampling.distribution.AliasMethodDiscreteSampler;
import org.apache.commons.rng.sampling.distribution.DiscreteSampler;
import org.apache.commons.rng.sampling.distribution.GuideTableDiscreteSampler;
import org.apache.commons.rng.sampling.distribution.MarsagliaTsangWangDiscreteSampler;
import org.apache.commons.rng.simple.RandomSource;

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

import java.util.Arrays;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;

/**
 * Executes benchmark to compare the speed of generation of random numbers from an enumerated
 * discrete probability distribution.
 */
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
@State(Scope.Benchmark)
@Fork(value = 1, jvmArgs = {"-server", "-Xms128M", "-Xmx128M"})
public class EnumeratedDistributionSamplersPerformance {
    /**
     * The random source used for testing.
     *
     * <p>This is a concrete {@link State} class so that benchmarks which only need the
     * underlying generator (such as the {@code nextDouble} baseline) can request it directly.
     * JMH cannot instantiate an abstract state class.</p>
     */
    @State(Scope.Benchmark)
    public static class LocalRandomSources {
        /**
         * RNG providers.
         *
         * <p>Use different speeds.</p>
         *
         * @see <a href="https://commons.apache.org/proper/commons-rng/userguide/rng.html">
         * Commons RNG user guide</a>
         */
        @Param({
            //"WELL_44497_B",
            //"ISAAC",
            "XO_RO_SHI_RO_128_PLUS",
        })
        private String randomSourceName;

        /** RNG. */
        private UniformRandomProvider generator;

        /**
         * Gets the generator.
         *
         * @return the RNG.
         */
        public UniformRandomProvider getGenerator() {
            return generator;
        }

        /**
         * Creates the generator from the named {@link RandomSource}.
         *
         * <p>Trial-level set-up runs before any iteration-level set-up, so subclasses
         * may use {@link #getGenerator()} in their own per-iteration set-up.</p>
         */
        @Setup(Level.Trial)
        public void setupGenerator() {
            final RandomSource randomSource = RandomSource.valueOf(randomSourceName);
            generator = RandomSource.create(randomSource);
        }
    }

    /**
     * The {@link DiscreteSampler} samplers to use for testing. Creates the sampler for each
     * {@link RandomSource} in the default {@link RandomSources}.
     *
     * <p>This class is abstract. The probability distribution is created by implementations.</p>
     */
    @State(Scope.Benchmark)
    public abstract static class SamplerSources extends LocalRandomSources {
        /**
         * A factory for creating DiscreteSampler objects.
         */
        interface DiscreteSamplerFactory {
            /**
             * Creates the sampler.
             *
             * @return the sampler
             */
            DiscreteSampler create();
        }

        /**
         * The sampler type.
         */
        @Param({"BinarySearchDiscreteSampler",
                "AliasMethodDiscreteSampler",
                "GuideTableDiscreteSampler",
                "MarsagliaTsangWangDiscreteSampler",

                // Uncomment to test non-default parameters
                //"AliasMethodDiscreteSamplerNoPad", // Not optimal for sampling
                //"AliasMethodDiscreteSamplerAlpha1",
                //"AliasMethodDiscreteSamplerAlpha2",

                // The AliasMethod memory requirement doubles for each alpha increment.
                // A fair comparison is to use 2^alpha for the equivalent guide table method.
                //"GuideTableDiscreteSamplerAlpha2",
                //"GuideTableDiscreteSamplerAlpha4",
                //"GuideTableDiscreteSamplerAlpha8",
        })
        private String samplerType;

        /** The factory. */
        private DiscreteSamplerFactory factory;

        /** The sampler. */
        private DiscreteSampler sampler;

        /**
         * Gets the sampler.
         *
         * @return the sampler.
         */
        public DiscreteSampler getSampler() {
            return sampler;
        }

        /**
         * Creates the distribution (per iteration as it may vary) and instantiates the sampler
         * using the generator created by the trial-level set-up.
         */
        @Setup(Level.Iteration)
        public void setup() {
            final double[] probabilities = createProbabilities();
            createSamplerFactory(getGenerator(), probabilities);
            sampler = factory.create();
        }

        /**
         * Creates the probabilities for the distribution.
         *
         * @return The probabilities.
         */
        protected abstract double[] createProbabilities();

        /**
         * Creates the sampler factory for the configured {@link #samplerType}.
         *
         * @param rng The random generator.
         * @param probabilities The probabilities.
         * @throws IllegalStateException if the sampler type is not recognised.
         */
        private void createSamplerFactory(final UniformRandomProvider rng,
                                          final double[] probabilities) {
            // This would benefit from Java 8 lambda functions
            if ("BinarySearchDiscreteSampler".equals(samplerType)) {
                factory = new DiscreteSamplerFactory() {
                    @Override
                    public DiscreteSampler create() {
                        return new BinarySearchDiscreteSampler(rng, probabilities);
                    }
                };
            } else if ("AliasMethodDiscreteSampler".equals(samplerType)) {
                factory = new DiscreteSamplerFactory() {
                    @Override
                    public DiscreteSampler create() {
                        return AliasMethodDiscreteSampler.of(rng, probabilities);
                    }
                };
            } else if ("AliasMethodDiscreteSamplerNoPad".equals(samplerType)) {
                // The "NoPad" variant is requested with an alpha of -1
                factory = createAliasMethodSamplerFactory(rng, probabilities, -1);
            } else if ("AliasMethodDiscreteSamplerAlpha1".equals(samplerType)) {
                factory = createAliasMethodSamplerFactory(rng, probabilities, 1);
            } else if ("AliasMethodDiscreteSamplerAlpha2".equals(samplerType)) {
                factory = createAliasMethodSamplerFactory(rng, probabilities, 2);
            } else if ("GuideTableDiscreteSampler".equals(samplerType)) {
                factory = new DiscreteSamplerFactory() {
                    @Override
                    public DiscreteSampler create() {
                        return GuideTableDiscreteSampler.of(rng, probabilities);
                    }
                };
            } else if ("GuideTableDiscreteSamplerAlpha2".equals(samplerType)) {
                factory = createGuideTableSamplerFactory(rng, probabilities, 2);
            } else if ("GuideTableDiscreteSamplerAlpha4".equals(samplerType)) {
                // Equivalent memory to the alias method with alpha=2 (2^2 = 4)
                factory = createGuideTableSamplerFactory(rng, probabilities, 4);
            } else if ("GuideTableDiscreteSamplerAlpha8".equals(samplerType)) {
                factory = createGuideTableSamplerFactory(rng, probabilities, 8);
            } else if ("MarsagliaTsangWangDiscreteSampler".equals(samplerType)) {
                factory = new DiscreteSamplerFactory() {
                    @Override
                    public DiscreteSampler create() {
                        return MarsagliaTsangWangDiscreteSampler.Enumerated.of(rng, probabilities);
                    }
                };
            } else {
                throw new IllegalStateException("Unknown sampler type: " + samplerType);
            }
        }

        /**
         * Creates a factory for an alias method sampler with an explicit alpha parameter.
         *
         * @param rng The random generator.
         * @param probabilities The probabilities.
         * @param alpha The alpha parameter passed to {@link AliasMethodDiscreteSampler}.
         * @return the factory
         */
        private static DiscreteSamplerFactory createAliasMethodSamplerFactory(
                final UniformRandomProvider rng,
                final double[] probabilities,
                final int alpha) {
            return new DiscreteSamplerFactory() {
                @Override
                public DiscreteSampler create() {
                    return AliasMethodDiscreteSampler.of(rng, probabilities, alpha);
                }
            };
        }

        /**
         * Creates a factory for a guide table sampler with an explicit alpha parameter.
         *
         * @param rng The random generator.
         * @param probabilities The probabilities.
         * @param alpha The alpha parameter passed to {@link GuideTableDiscreteSampler}.
         * @return the factory
         */
        private static DiscreteSamplerFactory createGuideTableSamplerFactory(
                final UniformRandomProvider rng,
                final double[] probabilities,
                final int alpha) {
            return new DiscreteSamplerFactory() {
                @Override
                public DiscreteSampler create() {
                    return GuideTableDiscreteSampler.of(rng, probabilities, alpha);
                }
            };
        }

        /**
         * Creates a new instance of the sampler.
         *
         * @return The sampler.
         */
        public DiscreteSampler createSampler() {
            return factory.create();
        }
    }

    /**
     * Define known probability distributions for testing. These are expected to have well
     * behaved cumulative probability functions.
     */
    @State(Scope.Benchmark)
    public static class KnownDistributionSources extends SamplerSources {
        /** The cumulative probability limit for unbounded distributions. */
        private static final double CUMULATIVE_PROBABILITY_LIMIT = 1 - 1e-9;

        /** The maximum number of values used to approximate the geometric distribution. */
        private static final int GEOMETRIC_MAX_SIZE = 100;

        /**
         * The distribution.
         */
        @Param({"Binomial_N67_P0.7",
                "Geometric_P0.2",
                "4SidedLoadedDie",
                "Poisson_Mean3.14",
                "Poisson_Mean10_Mean20",
        })
        private String distribution;

        /**
         * {@inheritDoc}
         *
         * @throws IllegalStateException if the distribution name is not recognised.
         */
        @Override
        protected double[] createProbabilities() {
            if ("Binomial_N67_P0.7".equals(distribution)) {
                final int trials = 67;
                final double probabilityOfSuccess = 0.7;
                final BinomialDistribution dist = new BinomialDistribution(null, trials, probabilityOfSuccess);
                return createProbabilities(dist, 0, trials);
            } else if ("Geometric_P0.2".equals(distribution)) {
                return createGeometricProbabilities(0.2);
            } else if ("4SidedLoadedDie".equals(distribution)) {
                return new double[] {1.0 / 2, 1.0 / 3, 1.0 / 12, 1.0 / 12};
            } else if ("Poisson_Mean3.14".equals(distribution)) {
                final double mean = 3.14;
                final IntegerDistribution dist = createPoissonDistribution(mean);
                final int max = dist.inverseCumulativeProbability(CUMULATIVE_PROBABILITY_LIMIT);
                return createProbabilities(dist, 0, max);
            } else if ("Poisson_Mean10_Mean20".equals(distribution)) {
                // Create a bimodal distribution using two Poisson distributions.
                // The support is truncated using the distribution with the larger mean.
                final IntegerDistribution distMean10 = createPoissonDistribution(10);
                final IntegerDistribution distMean20 = createPoissonDistribution(20);
                final int max = distMean20.inverseCumulativeProbability(CUMULATIVE_PROBABILITY_LIMIT);
                final double[] probabilities = createProbabilities(distMean20, 0, max);
                final double[] probabilitiesMean10 = createProbabilities(distMean10, 0, max);
                for (int i = 0; i < probabilities.length; i++) {
                    probabilities[i] += probabilitiesMean10[i];
                }
                // The sum is not normalised; this is left to the sampler
                // (BinarySearchDiscreteSampler normalises its cumulative table)
                return probabilities;
            }
            throw new IllegalStateException("Unknown distribution: " + distribution);
        }

        /**
         * Creates the probabilities of the geometric distribution, truncated once the
         * cumulative probability exceeds {@link #CUMULATIVE_PROBABILITY_LIMIT} or
         * {@link #GEOMETRIC_MAX_SIZE} values have been created.
         *
         * @param probabilityOfSuccess the probability of success
         * @return the probabilities
         */
        private static double[] createGeometricProbabilities(double probabilityOfSuccess) {
            final double probabilityOfFailure = 1 - probabilityOfSuccess;
            // https://en.wikipedia.org/wiki/Geometric_distribution
            // PMF = (1-p)^k * p
            // k is number of failures before a success
            double p = 1.0; // (1-p)^0
            // Build until the cumulative function is big
            final double[] probabilities = new double[GEOMETRIC_MAX_SIZE];
            double sum = 0;
            int k = 0;
            while (k < probabilities.length) {
                probabilities[k] = p * probabilityOfSuccess;
                sum += probabilities[k++];
                if (sum > CUMULATIVE_PROBABILITY_LIMIT) {
                    break;
                }
                // For the next PMF
                p *= probabilityOfFailure;
            }
            return Arrays.copyOf(probabilities, k);
        }

        /**
         * Creates the poisson distribution.
         *
         * @param mean the mean
         * @return the distribution
         */
        private static IntegerDistribution createPoissonDistribution(double mean) {
            return new PoissonDistribution(null, mean,
                PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS);
        }

        /**
         * Creates the probabilities from the distribution.
         *
         * @param dist the distribution
         * @param lower the lower bounds (inclusive)
         * @param upper the upper bounds (inclusive)
         * @return the probabilities
         */
        private static double[] createProbabilities(IntegerDistribution dist, int lower, int upper) {
            final double[] probabilities = new double[upper - lower + 1];
            for (int i = 0, x = lower; x <= upper; i++, x++) {
                probabilities[i] = dist.probability(x);
            }
            return probabilities;
        }
    }

    /**
     * Define random probability distributions of known size for testing. These are random but
     * the average cumulative probability function will be a straight line since the average
     * increment is 0.5.
     */
    @State(Scope.Benchmark)
    public static class RandomDistributionSources extends SamplerSources {
        /**
         * The distribution size.
         * These are spaced half-way between powers-of-2 to minimise the advantage of
         * padding by the Alias method sampler.
         */
        @Param({"6",
                //"12",
                //"24",
                //"48",
                "96",
                //"192",
                //"384",
                // Above 2048 forces the Alias method to use more than 64-bits for sampling
                "3072"
        })
        private int randomNonUniformSize;

        /** {@inheritDoc} */
        @Override
        protected double[] createProbabilities() {
            final double[] probabilities = new double[randomNonUniformSize];
            final ThreadLocalRandom rng = ThreadLocalRandom.current();
            for (int i = 0; i < probabilities.length; i++) {
                probabilities[i] = rng.nextDouble();
            }
            return probabilities;
        }
    }

    /**
     * Compute a sample by binary search of the cumulative probability distribution.
     */
    static final class BinarySearchDiscreteSampler
        implements DiscreteSampler {
        /** Underlying source of randomness. */
        private final UniformRandomProvider rng;
        /**
         * The cumulative probability table, normalised so the final entry is 1.0.
         */
        private final double[] cumulativeProbabilities;

        /**
         * @param rng Generator of uniformly distributed random numbers.
         * @param probabilities The probabilities.
         * @throws IllegalArgumentException if {@code probabilities} is null or empty, a
         * probability is negative, infinite or {@code NaN}, or the sum of all
         * probabilities is not strictly positive.
         */
        BinarySearchDiscreteSampler(UniformRandomProvider rng,
                                    double[] probabilities) {
            // Minimal set-up validation
            if (probabilities == null || probabilities.length == 0) {
                throw new IllegalArgumentException("Probabilities must not be empty.");
            }

            final int size = probabilities.length;
            cumulativeProbabilities = new double[size];

            double sumProb = 0;
            int count = 0;
            for (final double prob : probabilities) {
                if (prob < 0 ||
                    Double.isInfinite(prob) ||
                    Double.isNaN(prob)) {
                    throw new IllegalArgumentException("Invalid probability: " +
                                                       prob);
                }

                // Compute and store cumulative probability.
                sumProb += prob;
                cumulativeProbabilities[count++] = sumProb;
            }

            if (Double.isInfinite(sumProb) || sumProb <= 0) {
                throw new IllegalArgumentException("Invalid sum of probabilities: " + sumProb);
            }

            this.rng = rng;

            // Normalise cumulative probability.
            for (int i = 0; i < size; i++) {
                final double norm = cumulativeProbabilities[i] / sumProb;
                cumulativeProbabilities[i] = (norm < 1) ? norm : 1.0;
            }
        }

        /** {@inheritDoc} */
        @Override
        public int sample() {
            final double u = rng.nextDouble();

            // Binary search within known cumulative probability table.
            // Find x so that u > f[x-1] and u <= f[x].
            // This is a looser search than Arrays.binarySearch (which would require mapping a
            // negative insertion point to -index - 1 and clamping to the last index):
            // - The output is x = upper.
            // - The table stores probabilities where f[0] is >= 0 and the max == 1.0.
            // - u should be >= 0 and <= 1 (or the random generator is broken).
            // - It avoids comparisons using Double.doubleToLongBits.
            // - It avoids the low likelihood of equality between two doubles for fast exit
            //   so uses only 1 compare per loop.
            int lower = 0;
            int upper = cumulativeProbabilities.length - 1;
            while (lower < upper) {
                final int mid = (lower + upper) >>> 1;
                final double midVal = cumulativeProbabilities[mid];
                if (u > midVal) {
                    // Change lower such that
                    // u > f[lower - 1]
                    lower = mid + 1;
                } else {
                    // Change upper such that
                    // u <= f[upper]
                    upper = mid;
                }
            }
            return upper;
        }
    }

    /**
     * The value for the baseline generation of an {@code int} value.
     *
     * <p>This must NOT be final!</p>
     */
    private int value;

    // Benchmarks methods below.

    /**
     * Baseline for the JMH timing overhead for production of an {@code int} value.
     *
     * @return the {@code int} value
     */
    @Benchmark
    public int baselineInt() {
        return value;
    }

    /**
     * Baseline for the production of a {@code double} value.
     * This is used to assess the performance of the underlying random source.
     *
     * <p>Uses the concrete {@link LocalRandomSources} state; the abstract
     * {@link SamplerSources} cannot be instantiated by JMH.</p>
     *
     * @param sources Source of randomness.
     * @return the {@code int} value
     */
    @Benchmark
    public int baselineNextDouble(LocalRandomSources sources) {
        return sources.getGenerator().nextDouble() < 0.5 ? 1 : 0;
    }

    /**
     * Run the sampler.
     *
     * @param sources Source of randomness.
     * @return the sample value
     */
    @Benchmark
    public int sampleKnown(KnownDistributionSources sources) {
        return sources.getSampler().sample();
    }

    /**
     * Create a new sampler and take a single sample (includes construction cost).
     *
     * @param sources Source of randomness.
     * @return the sample value
     */
    @Benchmark
    public int singleSampleKnown(KnownDistributionSources sources) {
        return sources.createSampler().sample();
    }

    /**
     * Run the sampler.
     *
     * @param sources Source of randomness.
     * @return the sample value
     */
    @Benchmark
    public int sampleRandom(RandomDistributionSources sources) {
        return sources.getSampler().sample();
    }

    /**
     * Create a new sampler and take a single sample (includes construction cost).
     *
     * @param sources Source of randomness.
     * @return the sample value
     */
    @Benchmark
    public int singleSampleRandom(RandomDistributionSources sources) {
        return sources.createSampler().sample();
    }
}