This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-rng.git
commit ef4cee1eb58d44f134ccb74bbcfc7e17c97fb3f0
Author: Alex Herbert <aherb...@apache.org>
AuthorDate: Fri Jul 9 17:13:33 2021 +0100

    Added a test for sequential calls to the ziggurat sampler
---
 .../distribution/ZigguratSamplerPerformance.java   | 252 ++++++++++++++++++++-
 1 file changed, 248 insertions(+), 4 deletions(-)

diff --git a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java
index 7a2f8af..45e56c1 100644
--- a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java
+++ b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/distribution/ZigguratSamplerPerformance.java
@@ -48,6 +48,13 @@ import java.util.concurrent.TimeUnit;
 @State(Scope.Benchmark)
 @Fork(value = 1, jvmArgs = {"-server", "-Xms128M", "-Xmx128M"})
 public class ZigguratSamplerPerformance {
+    /** The name for the {@link ZigguratNormalizedGaussianSampler}. */
+    private static final String GAUSSIAN_128 = "Gaussian128";
+    /** The name for a copy of the {@link ZigguratNormalizedGaussianSampler} with a table of size 256. */
+    private static final String GAUSSIAN_256 = "Gaussian256";
+    /** The name for the {@link ZigguratSampler.NormalizedGaussian}. */
+    private static final String MOD_GAUSSIAN = "ModGaussian";
+
     /**
      * The value.
      *
@@ -177,7 +184,7 @@ public class ZigguratSamplerPerformance {
         /**
          * The sampler type.
          */
-        @Param({"Gaussian128", "Gaussian256", "Exponential", "ModGaussian", "ModExponential",
+        @Param({GAUSSIAN_128, GAUSSIAN_256, "Exponential", MOD_GAUSSIAN, "ModExponential",
                 "ModGaussian2", "ModExponential2"})
         private String type;
 
@@ -196,13 +203,13 @@ public class ZigguratSamplerPerformance {
         public void setup() {
             final RandomSource randomSource = RandomSource.valueOf(randomSourceName);
             final UniformRandomProvider rng = randomSource.create();
-            if ("Gaussian128".equals(type)) {
+            if (GAUSSIAN_128.equals(type)) {
                 sampler = ZigguratNormalizedGaussianSampler.of(rng);
-            } else if ("Gaussian256".equals(type)) {
+            } else if (GAUSSIAN_256.equals(type)) {
                 sampler = new ZigguratNormalizedGaussianSampler256(rng);
             } else if ("Exponential".equals(type)) {
                 sampler = new ZigguratExponentialSampler(rng);
-            } else if ("ModGaussian".equals(type)) {
+            } else if (MOD_GAUSSIAN.equals(type)) {
                 sampler = ZigguratSampler.NormalizedGaussian.of(rng);
             } else if ("ModExponential".equals(type)) {
                 sampler = ZigguratSampler.Exponential.of(rng);
@@ -217,6 +224,232 @@ public class ZigguratSamplerPerformance {
         }
     }
     /**
+     * The samplers to use for testing the ziggurat method with sequential sample generation.
+     * Defines the RandomSource and the sampler type.
+     *
+     * <p>This specifically targets the Gaussian sampler. The modified ziggurat sampler
+     * for the exponential distribution is always faster than the standard ziggurat sampler.
+     * The modified ziggurat sampler is faster on single samples than the standard sampler
+     * but on repeat calls to generate multiple deviates the standard sampler can be faster
+     * depending on the JDK (modern JDKs are faster with the 'old' sampler).</p>
+     */
+    @State(Scope.Benchmark)
+    public static class SequentialSources {
+        /**
+         * RNG providers.
+         *
+         * <p>Use different speeds.</p>
+         *
+         * @see <a href="https://commons.apache.org/proper/commons-rng/userguide/rng.html">
+         *      Commons RNG user guide</a>
+         */
+        @Param({"XO_RO_SHI_RO_128_PP",
+                //"MWC_256",
+                //"JDK"
+        })
+        private String randomSourceName;
+
+        /** The sampler type. */
+        @Param({GAUSSIAN_128, GAUSSIAN_256, MOD_GAUSSIAN})
+        private String type;
+
+        /** The size. */
+        @Param({"1", "2", "3", "4", "5", "10", "20", "40"})
+        private int size;
+
+        /** The sampler. */
+        private ContinuousSampler sampler;
+
+        /**
+         * @return the sampler.
+         */
+        public ContinuousSampler getSampler() {
+            return sampler;
+        }
+
+        /** Instantiates sampler. */
+        @Setup
+        public void setup() {
+            final RandomSource randomSource = RandomSource.valueOf(randomSourceName);
+            final UniformRandomProvider rng = randomSource.create();
+            ContinuousSampler s = null;
+            if (GAUSSIAN_128.equals(type)) {
+                s = ZigguratNormalizedGaussianSampler.of(rng);
+            } else if (GAUSSIAN_256.equals(type)) {
+                s = new ZigguratNormalizedGaussianSampler256(rng);
+            } else if (MOD_GAUSSIAN.equals(type)) {
+                s = ZigguratSampler.NormalizedGaussian.of(rng);
+            } else {
+                throwIllegalStateException(type);
+            }
+            sampler = createSampler(size, s);
+        }
+
+        /**
+         * Creates the sampler for the specified number of samples.
+         *
+         * @param size the size
+         * @param s the sampler to create the samples
+         * @return the sampler
+         */
+        private static ContinuousSampler createSampler(int size, ContinuousSampler s) {
+            // Create size samples
+            switch (size) {
+            case 1:
+                return new Size1Sampler(s);
+            case 2:
+                return new Size2Sampler(s);
+            case 3:
+                return new Size3Sampler(s);
+            case 4:
+                return new Size4Sampler(s);
+            case 5:
+                return new Size5Sampler(s);
+            default:
+                return new SizeNSampler(s, size);
+            }
+        }
+
+        /**
+         * Create a specified number of samples from an underlying sampler.
+         */
+        abstract static class SizeSampler implements ContinuousSampler {
+            /** The sampler. */
+            protected final ContinuousSampler delegate;
+
+            /**
+             * @param delegate the sampler to create the samples
+             */
+            SizeSampler(ContinuousSampler delegate) {
+                this.delegate = delegate;
+            }
+        }
+
+        /**
+         * Create 1 sample from the sampler.
+         */
+        static class Size1Sampler extends SizeSampler {
+            /**
+             * @param delegate the sampler to create the samples
+             */
+            Size1Sampler(ContinuousSampler delegate) {
+                super(delegate);
+            }
+
+            @Override
+            public double sample() {
+                return delegate.sample();
+            }
+        }
+
+        /**
+         * Create 2 samples from the sampler.
+         */
+        static class Size2Sampler extends SizeSampler {
+            /**
+             * @param delegate the sampler to create the samples
+             */
+            Size2Sampler(ContinuousSampler delegate) {
+                super(delegate);
+            }
+
+            @Override
+            public double sample() {
+                delegate.sample();
+                return delegate.sample();
+            }
+        }
+
+        /**
+         * Create 3 samples from the sampler.
+         */
+        static class Size3Sampler extends SizeSampler {
+            /**
+             * @param delegate the sampler to create the samples
+             */
+            Size3Sampler(ContinuousSampler delegate) {
+                super(delegate);
+            }
+
+            @Override
+            public double sample() {
+                delegate.sample();
+                delegate.sample();
+                return delegate.sample();
+            }
+        }
+
+        /**
+         * Create 4 samples from the sampler.
+         */
+        static class Size4Sampler extends SizeSampler {
+            /**
+             * @param delegate the sampler to create the samples
+             */
+            Size4Sampler(ContinuousSampler delegate) {
+                super(delegate);
+            }
+
+            @Override
+            public double sample() {
+                delegate.sample();
+                delegate.sample();
+                delegate.sample();
+                return delegate.sample();
+            }
+        }
+
+        /**
+         * Create 5 samples from the sampler.
+         */
+        static class Size5Sampler extends SizeSampler {
+            /**
+             * @param delegate the sampler to create the samples
+             */
+            Size5Sampler(ContinuousSampler delegate) {
+                super(delegate);
+            }
+
+            @Override
+            public double sample() {
+                delegate.sample();
+                delegate.sample();
+                delegate.sample();
+                delegate.sample();
+                return delegate.sample();
+            }
+        }
+
+        /**
+         * Create N samples from the sampler.
+         */
+        static class SizeNSampler extends SizeSampler {
+            /** The number of samples. */
+            private final int size;
+
+            /**
+             * @param delegate the sampler to create the samples
+             * @param size the size
+             */
+            SizeNSampler(ContinuousSampler delegate, int size) {
+                super(delegate);
+                if (size < 1) {
+                    throw new IllegalArgumentException("Size must be above zero: " + size);
+                }
+                this.size = size;
+            }
+
+            @Override
+            public double sample() {
+                for (int i = size - 1; i != 0; i--) {
+                    delegate.sample();
+                }
+                return delegate.sample();
+            }
+        }
+    }
+
+    /**
      * <a href="https://en.wikipedia.org/wiki/Ziggurat_algorithm">
      * Marsaglia and Tsang "Ziggurat" method</a> for sampling from a NormalizedGaussian
      * distribution with mean 0 and standard deviation 1.
@@ -1255,4 +1488,15 @@ public class ZigguratSamplerPerformance {
     public double sample(Sources sources) {
         return sources.getSampler().sample();
     }
+
+    /**
+     * Run the sampler to generate a number of samples sequentially.
+     *
+     * @param sources Source of randomness.
+     * @return the sample value
+     */
+    @Benchmark
+    public double sequentialSample(SequentialSources sources) {
+        return sources.getSampler().sample();
+    }
 }