This is an automated email from the ASF dual-hosted git repository.
aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-statistics.git
The following commit(s) were added to refs/heads/master by this push:
new 9040f25 STATISTICS-55: TruncatedNormalDistribution sampler
9040f25 is described below
commit 9040f25a27726ff039bf251d12e2e597d2012ca8
Author: Alex Herbert <[email protected]>
AuthorDate: Mon Sep 26 23:48:45 2022 +0100
STATISTICS-55: TruncatedNormalDistribution sampler
Add a sampler using rejection sampling of a standard normal deviate.
This is used when the truncation covers more than a set fraction of the CDF
of the parent normal distribution.
A JMH benchmark indicates that this fraction can be conservatively set at 20%.
---
.../distribution/TruncatedNormalDistribution.java | 60 +++++++++
.../jmh/distribution/NormalSamplerPerformance.java | 146 +++++++++++++++++++++
2 files changed, 206 insertions(+)
diff --git
a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java
b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java
index 8eb439c..a8e3bec 100644
---
a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java
+++
b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java
@@ -17,9 +17,12 @@
package org.apache.commons.statistics.distribution;
+import java.util.function.DoubleSupplier;
import org.apache.commons.numbers.gamma.Erf;
import org.apache.commons.numbers.gamma.ErfDifference;
import org.apache.commons.numbers.gamma.Erfcx;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
/**
* Implementation of the truncated normal distribution.
@@ -57,6 +60,16 @@ public final class TruncatedNormalDistribution extends
AbstractContinuousDistrib
/** Normalisation constant sqrt(2 pi) / 2 = sqrt(pi / 2). */
private static final double ROOT_PI_2 =
1.253314137315500251207882642405522626;
+    /**
+     * Minimum fraction of the parent normal CDF that the truncated range must
+     * cover for samples to be generated by rejection from a normalized Gaussian
+     * deviate rather than by inverse transform sampling. Rejection sampling is
+     * the more efficient choice above this fraction: benchmarks show that a
+     * normalized Gaussian sampler can be up to 10 times faster than inverse
+     * transform sampling when a fast random generator is used.
+     * See STATISTICS-55.
+     */
+    private static final double REJECTION_THRESHOLD = 0.2;
+
/** Parent normal distribution. */
private final NormalDistribution parentNormal;
/** Lower bound of this distribution. */
@@ -217,6 +230,53 @@ public final class TruncatedNormalDistribution extends
AbstractContinuousDistrib
return clipToRange(x);
}
+    /**
+     * {@inheritDoc}
+     *
+     * <p>If the truncated range covers a sufficient fraction of the parent
+     * normal distribution then a rejection sampler is returned. It draws
+     * standard normal deviates until one lies within the standardized bounds,
+     * then maps it back to the scale of the parent distribution. If the range
+     * lies entirely on one side of the parent mean, the deviates are mirrored
+     * onto that side so that fewer are rejected. Otherwise the inverse CDF
+     * sampler of the parent class is used.
+     */
+    @Override
+    public Sampler createSampler(UniformRandomProvider rng) {
+        // Map the bounds to a standard normal distribution.
+        // All symmetry decisions below must use these standardized bounds:
+        // the raw bounds are only symmetric about zero when the mean is zero.
+        final double u = parentNormal.getMean();
+        final double s = parentNormal.getStandardDeviation();
+        final double a = (lower - u) / s;
+        final double b = (upper - u) / s;
+
+        // Truncation entirely in the upper (a >= 0) or lower (b <= 0) half of
+        // the standard normal allows mirroring of the deviate.
+        final boolean upperHalf = a >= 0;
+        final boolean lowerHalf = b <= 0;
+
+        // If the truncation covers a reasonable amount of the normal distribution
+        // then a rejection sampler can be used.
+        double threshold = REJECTION_THRESHOLD;
+        // Mirroring means twice the samples can be used, so halve the threshold
+        if (upperHalf || lowerHalf) {
+            threshold *= 0.5;
+        }
+
+        // cdfDelta: presumably the parent CDF mass inside [lower, upper]
+        if (cdfDelta > threshold) {
+            // Create the rejection sampler
+            final ZigguratSampler.NormalizedGaussian sampler = ZigguratSampler.NormalizedGaussian.of(rng);
+            final DoubleSupplier gen;
+            if (upperHalf) {
+                // Return the upper-half of the Gaussian
+                gen = () -> Math.abs(sampler.sample());
+            } else if (lowerHalf) {
+                // Return the lower-half of the Gaussian
+                gen = () -> -Math.abs(sampler.sample());
+            } else {
+                // Return the full range of the Gaussian
+                gen = sampler::sample;
+            }
+            // Sample in [a, b] using rejection
+            return () -> {
+                double x = gen.getAsDouble();
+                while (x < a || x > b) {
+                    x = gen.getAsDouble();
+                }
+                // Avoid floating-point error when mapping back
+                return clipToRange(u + x * s);
+            };
+        }
+
+        // Default to an inverse CDF sampler
+        return super.createSampler(rng);
+    }
+
/**
* {@inheritDoc}
*
diff --git
a/commons-statistics-examples/examples-jmh/src/main/java/org/apache/commons/statistics/examples/jmh/distribution/NormalSamplerPerformance.java
b/commons-statistics-examples/examples-jmh/src/main/java/org/apache/commons/statistics/examples/jmh/distribution/NormalSamplerPerformance.java
new file mode 100644
index 0000000..b525265
--- /dev/null
+++
b/commons-statistics-examples/examples-jmh/src/main/java/org/apache/commons/statistics/examples/jmh/distribution/NormalSamplerPerformance.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.statistics.examples.jmh.distribution;
+
+import java.util.concurrent.TimeUnit;
+import java.util.function.DoubleSupplier;
+import org.apache.commons.rng.UniformRandomProvider;
+import
org.apache.commons.rng.sampling.distribution.InverseTransformContinuousSampler;
+import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
+import org.apache.commons.rng.simple.RandomSource;
+import org.apache.commons.statistics.distribution.NormalDistribution;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+
+/**
+ * Benchmarks the generation of samples from a normal distribution.
+ *
+ * <p>The results are used to choose the sampler for a truncated normal
+ * distribution:
+ * <ul>
+ * <li>Rejection sampling from a normal distribution (discard samples outside
+ * the truncated range)
+ * <li>Inverse transform sampling
+ * </ul>
+ *
+ * <p>Rejection sampling is viable when the truncated distribution covers a
+ * large enough proportion of the standard normal distribution that few samples
+ * are rejected. Both methods take approximately the same time when:
+ *
+ * <pre>
+ * t1 * n = t2
+ * </pre>
+ *
+ * <p>where {@code t1} is the time to generate one normal deviate, {@code n} is
+ * the expected number of deviates required to obtain one value within the
+ * truncated range, and {@code t2} is the time to generate one inverse transform
+ * sample. Rearranging gives the crossover point:
+ *
+ * <pre>
+ * t1 / t2 = 1 / n
+ * </pre>
+ *
+ * <p>where {@code 1 / n} is the fraction of the CDF covered by the truncated
+ * normal distribution.
+ */
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.NANOSECONDS)
+@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
+@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
+@State(Scope.Benchmark)
+@Fork(value = 1, jvmArgs = {"-server", "-Xms512M", "-Xmx512M"})
+public class NormalSamplerPerformance {
+    /** Value returned by the baseline. Must NOT be final to prevent JVM optimisation! */
+    private double value;
+
+    /**
+     * Provides the function used to generate each normal sample.
+     */
+    @State(Scope.Benchmark)
+    public static class Source {
+        /** Name of the sampling method under test. */
+        @Param({"normal", "inverse_transform"})
+        private String method;
+
+        /**
+         * Names of the random sources to use; these have different speeds.
+         *
+         * @see <a href="https://commons.apache.org/proper/commons-rng/userguide/rng.html">
+         * Commons RNG user guide</a>
+         */
+        @Param({"XO_RO_SHI_RO_128_PP",
+                "MWC_256",
+                "JDK"})
+        private String randomSourceName;
+
+        /** Supplies the next sample value. */
+        private DoubleSupplier generator;
+
+        /**
+         * Generate the next sample.
+         *
+         * @return the next value
+         */
+        public double next() {
+            return generator.getAsDouble();
+        }
+
+        /**
+         * Build the sample generator from the configured method and random source.
+         */
+        @Setup
+        public void setup() {
+            generator = createGenerator(method, RandomSource.valueOf(randomSourceName).create());
+        }
+
+        /**
+         * Create the generator of standard normal samples for the named method.
+         *
+         * @param name Name of the sampling method.
+         * @param source Source of randomness.
+         * @return the sample generator
+         * @throws IllegalStateException if the method name is not recognised
+         */
+        private static DoubleSupplier createGenerator(String name, UniformRandomProvider source) {
+            if ("normal".equals(name)) {
+                return ZigguratSampler.NormalizedGaussian.of(source)::sample;
+            }
+            if ("inverse_transform".equals(name)) {
+                final NormalDistribution standardNormal = NormalDistribution.of(0, 1);
+                return InverseTransformContinuousSampler.of(source,
+                    standardNormal::inverseCumulativeProbability)::sample;
+            }
+            throw new IllegalStateException("Unknown method: " + name);
+        }
+    }
+
+    /**
+     * Baseline measuring the overhead of a JMH method call returning a {@code double}.
+     *
+     * @return the value
+     */
+    @Benchmark
+    public double baseline() {
+        return value;
+    }
+
+    /**
+     * Generate one sample using the configured source.
+     *
+     * @param source Source of the sample.
+     * @return the value
+     */
+    @Benchmark
+    public double sample(Source source) {
+        return source.next();
+    }
+}