This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-statistics.git
The following commit(s) were added to refs/heads/master by this push: new 9040f25 STATISTICS-55: TruncatedNormalDistribution sampler 9040f25 is described below commit 9040f25a27726ff039bf251d12e2e597d2012ca8 Author: Alex Herbert <aherb...@apache.org> AuthorDate: Mon Sep 26 23:48:45 2022 +0100 STATISTICS-55: TruncatedNormalDistribution sampler Add a sampler using rejection sampling of a standard normal deviate. This is used when the truncation covers a set fraction of the CDF of the parent normal distribution. A JMH benchmark indicates that this can be conservatively set at 20%. --- .../distribution/TruncatedNormalDistribution.java | 60 +++++++++ .../jmh/distribution/NormalSamplerPerformance.java | 146 +++++++++++++++++++++ 2 files changed, 206 insertions(+) diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java index 8eb439c..a8e3bec 100644 --- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java +++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java @@ -17,9 +17,12 @@ package org.apache.commons.statistics.distribution; +import java.util.function.DoubleSupplier; import org.apache.commons.numbers.gamma.Erf; import org.apache.commons.numbers.gamma.ErfDifference; import org.apache.commons.numbers.gamma.Erfcx; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.distribution.ZigguratSampler; /** * Implementation of the truncated normal distribution. @@ -57,6 +60,16 @@ public final class TruncatedNormalDistribution extends AbstractContinuousDistrib /** Normalisation constant sqrt(2 pi) / 2 = sqrt(pi / 2). 
*/ private static final double ROOT_PI_2 = 1.253314137315500251207882642405522626; + /** + * The threshold to switch to a rejection sampler. When the truncated + * distribution covers more than this fraction of the CDF then rejection + * sampling will be more efficient than inverse CDF sampling. Performance + * benchmarks indicate that a normalized Gaussian sampler is up to 10 times + * faster than inverse transform sampling using a fast random generator. See + * STATISTICS-55. + */ + private static final double REJECTION_THRESHOLD = 0.2; + /** Parent normal distribution. */ private final NormalDistribution parentNormal; /** Lower bound of this distribution. */ @@ -217,6 +230,53 @@ public final class TruncatedNormalDistribution extends AbstractContinuousDistrib return clipToRange(x); } + /** {@inheritDoc} */ + @Override + public Sampler createSampler(UniformRandomProvider rng) { + // If the truncation covers a reasonable amount of the normal distribution + // then a rejection sampler can be used. 
+ double threshold = REJECTION_THRESHOLD; + // If the truncation is entirely in the upper or lower half then adjust the + // threshold as twice the samples can be used + if (lower >= 0 || upper <= 0) { + threshold *= 0.5; + } + + if (cdfDelta > threshold) { + // Create the rejection sampler + final ZigguratSampler.NormalizedGaussian sampler = ZigguratSampler.NormalizedGaussian.of(rng); + DoubleSupplier gen; + // Use mirroring if possible + if (lower >= 0) { + // Return the upper-half of the Gaussian + gen = () -> Math.abs(sampler.sample()); + } else if (upper <= 0) { + // Return the lower-half of the Gaussian + gen = () -> -Math.abs(sampler.sample()); + } else { + // Return the full range of the Gaussian + gen = sampler::sample; + } + // Map the bounds to a standard normal distribution + final double u = parentNormal.getMean(); + final double s = parentNormal.getStandardDeviation(); + final double a = (lower - u) / s; + final double b = (upper - u) / s; + // Sample in [a, b] using rejection + return () -> { + double x = gen.getAsDouble(); + while (x < a || x > b) { + x = gen.getAsDouble(); + } + // Avoid floating-point error when mapping back + return clipToRange(u + x * s); + }; + } + + // Default to an inverse CDF sampler + return super.createSampler(rng); + } + /** * {@inheritDoc} * diff --git a/commons-statistics-examples/examples-jmh/src/main/java/org/apache/commons/statistics/examples/jmh/distribution/NormalSamplerPerformance.java b/commons-statistics-examples/examples-jmh/src/main/java/org/apache/commons/statistics/examples/jmh/distribution/NormalSamplerPerformance.java new file mode 100644 index 0000000..b525265 --- /dev/null +++ b/commons-statistics-examples/examples-jmh/src/main/java/org/apache/commons/statistics/examples/jmh/distribution/NormalSamplerPerformance.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.commons.statistics.examples.jmh.distribution; + +import java.util.concurrent.TimeUnit; +import java.util.function.DoubleSupplier; +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.distribution.InverseTransformContinuousSampler; +import org.apache.commons.rng.sampling.distribution.ZigguratSampler; +import org.apache.commons.rng.simple.RandomSource; +import org.apache.commons.statistics.distribution.NormalDistribution; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; + +/** + * Executes a benchmark of sampling from a standard normal distribution.
+ * + * <p>This benchmark is used to determine what sampler to use in a truncated + * normal distribution: + * <ul> + * <li>Rejection sampling from a normal distribution (ignore samples outside the truncated range) + * <li>Inverse transform sampling + * </ul> + * + * <p>Rejection sampling can be used when the truncated distribution covers a + * reasonable proportion of the standard normal distribution and so the + * rejection rate is low. The time to produce a sample will be approximately equal + * when: + * + * <pre> + * t1 * n = t2 + * </pre> + * + * <p>Where {@code t1} is the time per sample of the normal distribution sampler, {@code n} is the + * number of samples required to generate a value within the truncated range and {@code t2} is + * the time per sample of the inverse transform sampler. The crossover point occurs at approximately: + * + * <pre> + * t1 / t2 = 1 / n + * </pre> + * + * <p>Where {@code 1 / n} is the fraction of the CDF covered by the truncated normal distribution. + */ +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.NANOSECONDS) +@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS) +@State(Scope.Benchmark) +@Fork(value = 1, jvmArgs = {"-server", "-Xms512M", "-Xmx512M"}) +public class NormalSamplerPerformance { + /** The value returned by the baseline. Must NOT be final to prevent JVM optimisation! */ + private double value; + + /** Source of a function to compute a sample from a standard normal distribution. */ + @State(Scope.Benchmark) + public static class Source { + /** The sampling method: ziggurat Gaussian sampler or inverse transform sampler. */ + @Param({"normal", "inverse_transform"}) + private String method; + /** + * RNG providers. Use generators with different speeds. + * + * @see <a href="https://commons.apache.org/proper/commons-rng/userguide/rng.html"> + * Commons RNG user guide</a> + */ + @Param({"XO_RO_SHI_RO_128_PP", + "MWC_256", + "JDK"}) + private String randomSourceName; + + /** The generator to supply the next sample value.
*/ + private DoubleSupplier gen; + + /** + * Gets the next sample value from the generator. + * + * @return the next value + */ + public double next() { + return gen.getAsDouble(); + } + + /** + * Create the standard normal sampler selected by the {@code method} parameter. + * + * @throws IllegalStateException if the method is not recognised + */ + @Setup + public void setup() { + final UniformRandomProvider rng = RandomSource.valueOf(randomSourceName).create(); + if ("normal".equals(method)) { + gen = ZigguratSampler.NormalizedGaussian.of(rng)::sample; + } else if ("inverse_transform".equals(method)) { + final NormalDistribution dist = NormalDistribution.of(0, 1); + gen = InverseTransformContinuousSampler.of(rng, dist::inverseCumulativeProbability)::sample; + } else { + throw new IllegalStateException("Unknown method: " + method); + } + } + } + + /** + * Baseline for a JMH method call returning a {@code double}. + * + * @return the value + */ + @Benchmark + public double baseline() { + return value; + } + + /** + * Compute a sample from the configured source. + * + * @param source Source of the sample. + * @return the value + */ + @Benchmark + public double sample(Source source) { + return source.next(); + } +}