This is an automated email from the ASF dual-hosted git repository.
aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git
The following commit(s) were added to refs/heads/master by this push:
new 3cb2c9b RNG-153: Update the sampling algorithm in the UnitBallSampler
3cb2c9b is described below
commit 3cb2c9b718a0878a6fe7572f560f697dffeba758
Author: aherbert <[email protected]>
AuthorDate: Thu Jul 8 13:07:37 2021 +0100
RNG-153: Update the sampling algorithm in the UnitBallSampler
---
.../sampling/shape/UnitBallSamplerBenchmark.java | 21 +++++++------
.../rng/sampling/shape/UnitBallSampler.java | 36 ++++++++++++----------
src/changes/changes.xml | 4 +++
3 files changed, 36 insertions(+), 25 deletions(-)
diff --git a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/shape/UnitBallSamplerBenchmark.java b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/shape/UnitBallSamplerBenchmark.java
index 142d75d..f4523e0 100644
--- a/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/shape/UnitBallSamplerBenchmark.java
+++ b/commons-rng-examples/examples-jmh/src/main/java/org/apache/commons/rng/examples/jmh/sampling/shape/UnitBallSamplerBenchmark.java
@@ -18,9 +18,10 @@
package org.apache.commons.rng.examples.jmh.sampling.shape;
import org.apache.commons.rng.UniformRandomProvider;
-import
org.apache.commons.rng.sampling.distribution.AhrensDieterExponentialSampler;
+import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
import org.apache.commons.rng.sampling.distribution.NormalizedGaussianSampler;
import
org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler;
+import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
import org.apache.commons.rng.simple.RandomSource;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
@@ -426,7 +427,7 @@ public class UnitBallSamplerBenchmark {
/** The normal distribution. */
private final NormalizedGaussianSampler normal;
/** The exponential distribution. */
- private final AhrensDieterExponentialSampler exp;
+ private final ContinuousSampler exp;
/**
* @param rng the source of randomness
@@ -436,7 +437,8 @@ public class UnitBallSamplerBenchmark {
normal = new ZigguratNormalizedGaussianSampler(rng);
// Exponential(mean=2) == Chi-squared distribution(degrees
freedom=2)
// thus is the equivalent of the HypersphereDiscardSampler.
- exp = new AhrensDieterExponentialSampler(rng, 2.0);
+ // Here we use mean = 1 and scale the output later.
+ exp = ZigguratSampler.Exponential.of(rng);
}
@Override
@@ -444,8 +446,8 @@ public class UnitBallSamplerBenchmark {
final double x = normal.sample();
final double y = normal.sample();
final double z = normal.sample();
- // Include the exponential sample
- final double sum = exp.sample() + x * x + y * y + z * z;
+ // Include the exponential sample. It has mean 1 so multiply
by 2.
+ final double sum = exp.sample() * 2 + x * x + y * y + z * z;
// Note: Handle the possibility of a zero sum and invalid
inverse
if (sum == 0) {
return sample();
@@ -608,7 +610,7 @@ public class UnitBallSamplerBenchmark {
/** The normal distribution. */
private final NormalizedGaussianSampler normal;
/** The exponential distribution. */
- private final AhrensDieterExponentialSampler exp;
+ private final ContinuousSampler exp;
/**
* @param rng the source of randomness
@@ -620,14 +622,15 @@ public class UnitBallSamplerBenchmark {
normal = new ZigguratNormalizedGaussianSampler(rng);
// Exponential(mean=2) == Chi-squared distribution(degrees
freedom=2)
// thus is the equivalent of the HypersphereDiscardSampler.
- exp = new AhrensDieterExponentialSampler(rng, 2.0);
+ // Here we use mean = 1 and scale the output later.
+ exp = ZigguratSampler.Exponential.of(rng);
}
@Override
public double[] sample() {
final double[] sample = new double[dimension];
- // Include the exponential sample
- double sum = exp.sample();
+ // Include the exponential sample. It has mean 1 so multiply
by 2.
+ double sum = exp.sample() * 2;
for (int i = 0; i < dimension; i++) {
final double x = normal.sample();
sum += x * x;
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/shape/UnitBallSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/shape/UnitBallSampler.java
index 7aa8efc..80c836c 100644
--- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/shape/UnitBallSampler.java
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/shape/UnitBallSampler.java
@@ -19,8 +19,10 @@ package org.apache.commons.rng.sampling.shape;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.SharedStateObjectSampler;
+import org.apache.commons.rng.sampling.distribution.ContinuousSampler;
import org.apache.commons.rng.sampling.distribution.NormalizedGaussianSampler;
import
org.apache.commons.rng.sampling.distribution.ZigguratNormalizedGaussianSampler;
+import org.apache.commons.rng.sampling.distribution.ZigguratSampler;
/**
* Generate coordinates <a
href="http://mathworld.wolfram.com/BallPointPicking.html">
@@ -113,25 +115,28 @@ public abstract class UnitBallSampler implements SharedStateObjectSampler<double
* {@link UnitBallSamplerND} for performance.
*/
private static class UnitBallSampler3D extends UnitBallSampler {
- /** The normal distribution. */
+ /** The standard normal distribution. */
private final NormalizedGaussianSampler normal;
+ /** The exponential distribution (mean=1). */
+ private final ContinuousSampler exp;
/**
* @param rng Source of randomness.
*/
UnitBallSampler3D(UniformRandomProvider rng) {
normal = new ZigguratNormalizedGaussianSampler(rng);
+ // Require an Exponential(mean=2).
+ // Here we use mean = 1 and scale the output later.
+ exp = ZigguratSampler.Exponential.of(rng);
}
@Override
public double[] sample() {
- // Discard 2 samples from the coordinate but include in the sum
- final double x0 = normal.sample();
- final double x1 = normal.sample();
final double x = normal.sample();
final double y = normal.sample();
final double z = normal.sample();
- final double sum = x0 * x0 + x1 * x1 + x * x + y * y + z * z;
+ // Include the exponential sample. It has mean 1 so multiply by 2.
+ final double sum = exp.sample() * 2 + x * x + y * y + z * z;
// Note: Handle the possibility of a zero sum and invalid inverse
if (sum == 0) {
return sample();
@@ -147,18 +152,16 @@ public abstract class UnitBallSampler implements SharedStateObjectSampler<double
}
/**
- * Sample uniformly from a unit n-ball.
- * Take a random point on the (n+1)-dimensional hypersphere and drop two
coordinates.
- * Remember that the (n+1)-hypersphere is the unit sphere of R^(n+2), i.e.
the surface
- * of the (n+2)-dimensional ball.
- * @see <a
href="https://mathoverflow.net/questions/309567/sampling-a-uniformly-distributed-point-inside-a-hypersphere">
- * Sampling a uniformly distributed point INSIDE a hypersphere?</a>
+ * Sample using ball point picking.
+ * @see <a href="https://mathworld.wolfram.com/BallPointPicking.html">Ball
point picking</a>
*/
private static class UnitBallSamplerND extends UnitBallSampler {
/** The dimension. */
private final int dimension;
- /** The normal distribution. */
+ /** The standard normal distribution. */
private final NormalizedGaussianSampler normal;
+ /** The exponential distribution (mean=1). */
+ private final ContinuousSampler exp;
/**
* @param dimension Space dimension.
@@ -167,15 +170,16 @@ public abstract class UnitBallSampler implements SharedStateObjectSampler<double
UnitBallSamplerND(int dimension, UniformRandomProvider rng) {
this.dimension = dimension;
normal = new ZigguratNormalizedGaussianSampler(rng);
+ // Require an Exponential(mean=2).
+ // Here we use mean = 1 and scale the output later.
+ exp = ZigguratSampler.Exponential.of(rng);
}
@Override
public double[] sample() {
final double[] sample = new double[dimension];
- // Discard 2 samples from the coordinate but include in the sum
- final double x0 = normal.sample();
- final double x1 = normal.sample();
- double sum = x0 * x0 + x1 * x1;
+ // Include the exponential sample. It has mean 1 so multiply by 2.
+ double sum = exp.sample() * 2;
for (int i = 0; i < dimension; i++) {
final double x = normal.sample();
sum += x * x;
diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index 356e912..8641760 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -77,6 +77,10 @@ re-run tests that fail, and pass the build if they succeed
within the allotted number of reruns (the test will be marked
as 'flaky' in the report).
">
+ <action dev="aherbert" type="update" issue="153">
+ "UnitBallSampler": Update to use the ZigguratSampler for an exponential deviate for
+ ball point picking.
+ </action>
<action dev="aherbert" type="update" issue="150">
Update "LargeMeanPoissonSampler" and "GeometricSampler" to use the
ZigguratSampler for
exponential deviates.