This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-statistics.git


The following commit(s) were added to refs/heads/master by this push:
     new 9040f25  STATISTICS-55: TruncatedNormalDistribution sampler
9040f25 is described below

commit 9040f25a27726ff039bf251d12e2e597d2012ca8
Author: Alex Herbert <aherb...@apache.org>
AuthorDate: Mon Sep 26 23:48:45 2022 +0100

    STATISTICS-55: TruncatedNormalDistribution sampler
    
    Add a sampler using rejection sampling of a standard normal deviate.
    
    This is used when the truncated range covers more than a set fraction of
    the CDF of the parent normal distribution.
    
    A JMH benchmark indicates that this can be conservatively set at 20%.
---
 .../distribution/TruncatedNormalDistribution.java  |  60 +++++++++
 .../jmh/distribution/NormalSamplerPerformance.java | 146 +++++++++++++++++++++
 2 files changed, 206 insertions(+)

diff --git a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java
index 8eb439c..a8e3bec 100644
--- a/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java
+++ b/commons-statistics-distribution/src/main/java/org/apache/commons/statistics/distribution/TruncatedNormalDistribution.java
@@ -17,9 +17,12 @@
 
 package org.apache.commons.statistics.distribution;
 
+import java.util.function.DoubleSupplier;
 import org.apache.commons.numbers.gamma.Erf;
 import org.apache.commons.numbers.gamma.ErfDifference;
 import org.apache.commons.numbers.gamma.Erfcx;
+import org.apache.commons.rng.UniformRandomProvider;
+import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
 
 /**
  * Implementation of the truncated normal distribution.
@@ -57,6 +60,16 @@ public final class TruncatedNormalDistribution extends AbstractContinuousDistrib
     /** Normalisation constant sqrt(2 pi) / 2 = sqrt(pi / 2). */
     private static final double ROOT_PI_2 = 1.253314137315500251207882642405522626;
 
+    /**
+     * The threshold to switch to a rejection sampler. When the truncated distribution
+     * covers more than this fraction of the CDF then rejection sampling will be more
+     * efficient than inverse CDF sampling. The threshold is halved when the truncation
+     * lies entirely in the upper or lower half, as half-normal samples can then be used.
+     * Performance benchmarks indicate that a normalized Gaussian sampler is up to 10 times
+     * faster than inverse transform sampling using a fast random generator. See STATISTICS-55.
+     */
+    private static final double REJECTION_THRESHOLD = 0.2;
+
     /** Parent normal distribution. */
     private final NormalDistribution parentNormal;
     /** Lower bound of this distribution. */
@@ -217,6 +230,53 @@ public final class TruncatedNormalDistribution extends AbstractContinuousDistrib
         return clipToRange(x);
     }
 
+    /** {@inheritDoc} */
+    @Override
+    public Sampler createSampler(UniformRandomProvider rng) {
+        // Map the bounds to a standard normal distribution (relative to the parent mean)
+        final double u = parentNormal.getMean();
+        final double s = parentNormal.getStandardDeviation();
+        final double a = (lower - u) / s;
+        final double b = (upper - u) / s;
+        // If the truncation covers a reasonable amount of the normal distribution
+        // then a rejection sampler can be used.
+        double threshold = REJECTION_THRESHOLD;
+        // If the standardized truncation is entirely in the upper or lower half then
+        // adjust the threshold as twice the samples can be used
+        if (a >= 0 || b <= 0) {
+            threshold *= 0.5;
+        }
+
+        if (cdfDelta > threshold) {
+            // Create the rejection sampler
+            final ZigguratSampler.NormalizedGaussian sampler = ZigguratSampler.NormalizedGaussian.of(rng);
+            final DoubleSupplier gen;
+            // Use mirroring if possible: the halves are split at the parent mean, not zero
+            if (a >= 0) {
+                // Return the upper-half of the Gaussian
+                gen = () -> Math.abs(sampler.sample());
+            } else if (b <= 0) {
+                // Return the lower-half of the Gaussian
+                gen = () -> -Math.abs(sampler.sample());
+            } else {
+                // Return the full range of the Gaussian
+                gen = sampler::sample;
+            }
+            // Sample in [a, b] using rejection
+            return () -> {
+                double x = gen.getAsDouble();
+                while (x < a || x > b) {
+                    x = gen.getAsDouble();
+                }
+                // Avoid floating-point error when mapping back
+                return clipToRange(u + x * s);
+            };
+        }
+
+        // Default to an inverse CDF sampler
+        return super.createSampler(rng);
+    }
+
     /**
      * {@inheritDoc}
      *
diff --git a/commons-statistics-examples/examples-jmh/src/main/java/org/apache/commons/statistics/examples/jmh/distribution/NormalSamplerPerformance.java b/commons-statistics-examples/examples-jmh/src/main/java/org/apache/commons/statistics/examples/jmh/distribution/NormalSamplerPerformance.java
new file mode 100644
index 0000000..b525265
--- /dev/null
+++ b/commons-statistics-examples/examples-jmh/src/main/java/org/apache/commons/statistics/examples/jmh/distribution/NormalSamplerPerformance.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.statistics.examples.jmh.distribution;
+
+import java.util.concurrent.TimeUnit;
+import java.util.function.DoubleSupplier;
+import org.apache.commons.rng.UniformRandomProvider;
+import 
org.apache.commons.rng.sampling.distribution.InverseTransformContinuousSampler;
+import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
+import org.apache.commons.rng.simple.RandomSource;
+import org.apache.commons.statistics.distribution.NormalDistribution;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+
+/**
+ * Executes a benchmark of the sampling from a normal distribution.
+ *
+ * <p>This benchmark is used to determine what sampler to use in a truncated
+ * normal distribution:
+ * <ul>
+ *  <li>Rejection sampling from a normal distribution (ignore samples outside the truncated range)
+ *  <li>Inverse transform sampling
+ * </ul>
+ *
+ * <p>Rejection sampling can be used when the truncated distribution covers a
+ * reasonable proportion of the standard normal distribution and so the
+ * rejection rate is low. The speed of each method will be approximately equal
+ * when:
+ *
+ * <pre>
+ * t1 * n = t2
+ * </pre>
+ *
+ * <p>Where {@code t1} is the time per sample of the normal distribution sampler, {@code n} is the
+ * expected number of samples required to generate a value within the truncated range and {@code t2} is
+ * the time per sample of the inverse transform sampler. The crossover point occurs at approximately:
+ *
+ * <pre>
+ * t1 / t2 = 1 / n
+ * </pre>
+ *
+ * <p>Where {@code 1 / n} is the fraction of the CDF covered by the truncated normal distribution.
+ */
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.NANOSECONDS)
+@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
+@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
+@State(Scope.Benchmark)
+@Fork(value = 1, jvmArgs = {"-server", "-Xms512M", "-Xmx512M"})
+public class NormalSamplerPerformance {
+    /** The value. Must NOT be final to prevent JVM optimisation! */
+    private double value;
+
+    /**
+     * Source of a function to compute a sample from a normal distribution.
+     */
+    @State(Scope.Benchmark)
+    public static class Source {
+        /** The method. */
+        @Param({"normal", "inverse_transform"})
+        private String method;
+        /**
+         * RNG providers.
+         *
+         * <p>Use different speeds.</p>
+         *
+         * @see <a href="https://commons.apache.org/proper/commons-rng/userguide/rng.html">
+         *      Commons RNG user guide</a>
+         */
+        @Param({"XO_RO_SHI_RO_128_PP",
+                "MWC_256",
+                "JDK"})
+        private String randomSourceName;
+
+        /** The generator to supply the next sample value. */
+        private DoubleSupplier gen;
+
+        /**
+         * @return the next value
+         */
+        public double next() {
+            return gen.getAsDouble();
+        }
+
+        /**
+         * Create the sampler for the normal distribution.
+         */
+        @Setup
+        public void setup() {
+            final UniformRandomProvider rng = RandomSource.valueOf(randomSourceName).create();
+            if ("normal".equals(method)) {
+                gen = ZigguratSampler.NormalizedGaussian.of(rng)::sample;
+            } else if ("inverse_transform".equals(method)) {
+                final NormalDistribution dist = NormalDistribution.of(0, 1);
+                gen = InverseTransformContinuousSampler.of(rng, dist::inverseCumulativeProbability)::sample;
+            } else {
+                throw new IllegalStateException("Unknown method: " + method);
+            }
+        }
+    }
+
+    /**
+     * Baseline for a JMH method call returning a {@code double}.
+     *
+     * @return the value
+     */
+    @Benchmark
+    public double baseline() {
+        return value;
+    }
+
+    /**
+     * Compute a sample.
+     *
+     * @param source Source of the sample.
+     * @return the value
+     */
+    @Benchmark
+    public double sample(Source source) {
+        return source.next();
+    }
+}

Reply via email to