This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-rng.git
The following commit(s) were added to refs/heads/master by this push: new 3cb2c9b RNG-153: Update the sampling algorithm in the UnitBallSampler 3cb2c9b is described below commit 3cb2c9b718a0878a6fe7572f560f697dffeba758 Author: aherbert <aherb...@apache.org> AuthorDate: Thu Jul 8 13:07:37 2021 +0100 RNG-153: Update the sampling algorithm in the UnitBallSampler --- .../sampling/shape/UnitBallSamplerBenchmark.java | 21 +++++++------ .../rng/sampling/shape/UnitBallSampler.java | 36 ++++++++++++---------- src/changes/changes.xml | 4 +++ 3 files changed, 36 insertions(+), 25 deletions(-) diff --git a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/shape/UnitBallSamplerBenchmark.java b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/shape/UnitBallSamplerBenchmark.java index 142d75d..f4523e0 100644 --- a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/shape/UnitBallSamplerBenchmark.java +++ b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/shape/UnitBallSamplerBenchmark.java @@ -18,9 +18,10 @@ package org.apache.commons.rng.examples.jmh.sampling.shape; import org.apache.commons.rng.UniformRandomProvider; -import org.apache.commons.rng.sampling.distribution.AhrensDieterExponentialSampler; +import org.apache.commons.rng.sampling.distribution.ContinuousSampler; import org.apache.commons.rng.sampling.distribution.NormalizedGaussianSampler; import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler; +import org.apache.commons.rng.sampling.distribution.ZigguratSampler; import org.apache.commons.rng.simple.RandomSource; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; @@ -426,7 +427,7 @@ public class UnitBallSamplerBenchmark { /** The normal distribution. */ private final NormalizedGaussianSampler normal; /** The exponential distribution. */ - private final AhrensDieterExponentialSampler exp; + private final ContinuousSampler exp; /** * @param rng the source of randomness @@ -436,7 +437,8 @@ public class UnitBallSamplerBenchmark { normal = new ZigguratNormalizedGaussianSampler(rng); // Exponential(mean=2) == Chi-squared distribution(degrees freedom=2) // thus is the equivalent of the HypersphereDiscardSampler. - exp = new AhrensDieterExponentialSampler(rng, 2.0); + // Here we use mean = 1 and scale the output later. + exp = ZigguratSampler.Exponential.of(rng); } @Override @@ -444,8 +446,8 @@ public class UnitBallSamplerBenchmark { final double x = normal.sample(); final double y = normal.sample(); final double z = normal.sample(); - // Include the exponential sample - final double sum = exp.sample() + x * x + y * y + z * z; + // Include the exponential sample. It has mean 1 so multiply by 2. + final double sum = exp.sample() * 2 + x * x + y * y + z * z; // Note: Handle the possibility of a zero sum and invalid inverse if (sum == 0) { return sample(); @@ -608,7 +610,7 @@ public class UnitBallSamplerBenchmark { /** The normal distribution. */ private final NormalizedGaussianSampler normal; /** The exponential distribution. */ - private final AhrensDieterExponentialSampler exp; + private final ContinuousSampler exp; /** * @param rng the source of randomness @@ -620,14 +622,15 @@ public class UnitBallSamplerBenchmark { normal = new ZigguratNormalizedGaussianSampler(rng); // Exponential(mean=2) == Chi-squared distribution(degrees freedom=2) // thus is the equivalent of the HypersphereDiscardSampler. - exp = new AhrensDieterExponentialSampler(rng, 2.0); + // Here we use mean = 1 and scale the output later. + exp = ZigguratSampler.Exponential.of(rng); } @Override public double[] sample() { final double[] sample = new double[dimension]; - // Include the exponential sample - double sum = exp.sample(); + // Include the exponential sample. It has mean 1 so multiply by 2. + double sum = exp.sample() * 2; for (int i = 0; i < dimension; i++) { final double x = normal.sample(); sum += x * x; diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/shape/UnitBallSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/shape/UnitBallSampler.java index 7aa8efc..80c836c 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/shape/UnitBallSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/shape/UnitBallSampler.java @@ -19,8 +19,10 @@ package org.apache.commons.rng.sampling.shape; import org.apache.commons.rng.UniformRandomProvider; import org.apache.commons.rng.sampling.SharedStateObjectSampler; +import org.apache.commons.rng.sampling.distribution.ContinuousSampler; import org.apache.commons.rng.sampling.distribution.NormalizedGaussianSampler; import org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler; +import org.apache.commons.rng.sampling.distribution.ZigguratSampler; /** * Generate coordinates <a href="http://mathworld.wolfram.com/BallPointPicking.html"> @@ -113,25 +115,28 @@ public abstract class UnitBallSampler implements SharedStateObjectSampler<double * {@link UnitBallSamplerND} for performance. */ private static class UnitBallSampler3D extends UnitBallSampler { - /** The normal distribution. */ + /** The standard normal distribution. */ private final NormalizedGaussianSampler normal; + /** The exponential distribution (mean=1). */ + private final ContinuousSampler exp; /** * @param rng Source of randomness. */ UnitBallSampler3D(UniformRandomProvider rng) { normal = new ZigguratNormalizedGaussianSampler(rng); + // Require an Exponential(mean=2). + // Here we use mean = 1 and scale the output later. + exp = ZigguratSampler.Exponential.of(rng); } @Override public double[] sample() { - // Discard 2 samples from the coordinate but include in the sum - final double x0 = normal.sample(); - final double x1 = normal.sample(); final double x = normal.sample(); final double y = normal.sample(); final double z = normal.sample(); - final double sum = x0 * x0 + x1 * x1 + x * x + y * y + z * z; + // Include the exponential sample. It has mean 1 so multiply by 2. + final double sum = exp.sample() * 2 + x * x + y * y + z * z; // Note: Handle the possibility of a zero sum and invalid inverse if (sum == 0) { return sample(); @@ -147,18 +152,16 @@ public abstract class UnitBallSampler implements SharedStateObjectSampler<double } /** - * Sample uniformly from a unit n-ball. - * Take a random point on the (n+1)-dimensional hypersphere and drop two coordinates. - * Remember that the (n+1)-hypersphere is the unit sphere of R^(n+2), i.e. the surface - * of the (n+2)-dimensional ball. - * @see <a href="https://mathoverflow.net/questions/309567/sampling-a-uniformly-distributed-point-inside-a-hypersphere"> - * Sampling a uniformly distributed point INSIDE a hypersphere?</a> + * Sample using ball point picking. + * @see <a href="https://mathworld.wolfram.com/BallPointPicking.html">Ball point picking</a> */ private static class UnitBallSamplerND extends UnitBallSampler { /** The dimension. */ private final int dimension; - /** The normal distribution. */ + /** The standard normal distribution. */ private final NormalizedGaussianSampler normal; + /** The exponential distribution (mean=1). */ + private final ContinuousSampler exp; /** * @param dimension Space dimension. @@ -167,15 +170,16 @@ public abstract class UnitBallSampler implements SharedStateObjectSampler<double UnitBallSamplerND(int dimension, UniformRandomProvider rng) { this.dimension = dimension; normal = new ZigguratNormalizedGaussianSampler(rng); + // Require an Exponential(mean=2). + // Here we use mean = 1 and scale the output later. + exp = ZigguratSampler.Exponential.of(rng); } @Override public double[] sample() { final double[] sample = new double[dimension]; - // Discard 2 samples from the coordinate but include in the sum - final double x0 = normal.sample(); - final double x1 = normal.sample(); - double sum = x0 * x0 + x1 * x1; + // Include the exponential sample. It has mean 1 so multiply by 2. + double sum = exp.sample() * 2; for (int i = 0; i < dimension; i++) { final double x = normal.sample(); sum += x * x; diff --git a/src/changes/changes.xml b/src/changes/changes.xml index 356e912..8641760 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -77,6 +77,10 @@ re-run tests that fail, and pass the build if they succeed within the allotted number of reruns (the test will be marked as 'flaky' in the report). "> + <action dev="aherbert" type="update" issue="153"> + "UnitBallSampler": Update to use the ZigguratSampler for an exponential deviate for + ball point picking. + </action> <action dev="aherbert" type="update" issue="150"> Update "LargeMeanPoissonSampler" and "GeometricSampler" to use the ZigguratSampler for exponential deviates.