This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-rng.git


The following commit(s) were added to refs/heads/master by this push:
     new 6edeb10  RNG-146: Prevent infinite standard deviation
6edeb10 is described below

commit 6edeb102c5310895480781e14c14facaf12ed864
Author: aherbert <aherb...@apache.org>
AuthorDate: Fri Jul 9 11:29:17 2021 +0100

    RNG-146: Prevent infinite standard deviation
---
 .../rng/sampling/distribution/GaussianSampler.java | 22 +++++++--
 .../sampling/distribution/GaussianSamplerTest.java | 54 +++++++++++++++++++++-
 src/changes/changes.xml                            |  3 ++
 3 files changed, 74 insertions(+), 5 deletions(-)

diff --git 
a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java
 
b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java
index 5540018..38e0537 100644
--- 
a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java
+++ 
b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java
@@ -22,6 +22,14 @@ import org.apache.commons.rng.UniformRandomProvider;
  * Sampling from a Gaussian distribution with given mean and
  * standard deviation.
  *
+ * <h2>Note</h2>
+ *
+ * <p>The mean and standard deviation are validated to ensure they are finite. This prevents
+ * generation of NaN samples by avoiding invalid arithmetic (inf * 0 or inf - inf).
+ * However use of an extremely large standard deviation and/or mean may result in samples that are
+ * infinite; that is the parameters are not validated to prevent truncation of the output
+ * distribution.
+ *
  * @since 1.1
  */
 public class GaussianSampler implements SharedStateContinuousSampler {
@@ -36,14 +44,19 @@ public class GaussianSampler implements 
SharedStateContinuousSampler {
      * @param normalized Generator of N(0,1) Gaussian distributed random numbers.
      * @param mean Mean of the Gaussian distribution.
      * @param standardDeviation Standard deviation of the Gaussian distribution.
-     * @throws IllegalArgumentException if {@code standardDeviation <= 0}
+     * @throws IllegalArgumentException if {@code standardDeviation <= 0} or is infinite;
+     * or {@code mean} is infinite
      */
     public GaussianSampler(NormalizedGaussianSampler normalized,
                            double mean,
                            double standardDeviation) {
-        if (standardDeviation <= 0) {
+        if (!(standardDeviation > 0 && standardDeviation < 
Double.POSITIVE_INFINITY)) {
             throw new IllegalArgumentException(
-                "standard deviation is not strictly positive: " + 
standardDeviation);
+                "standard deviation is not strictly positive and finite: " + 
standardDeviation);
+        }
+        // To be replaced by JDK 1.8 Double.isFinite. This will detect NaN 
values.
+        if (!(Math.abs(mean) <= Double.MAX_VALUE)) {
+            throw new IllegalArgumentException("mean is not finite: " + mean);
         }
         this.normalized = normalized;
         this.mean = mean;
@@ -102,7 +115,8 @@ public class GaussianSampler implements 
SharedStateContinuousSampler {
      * @param mean Mean of the Gaussian distribution.
      * @param standardDeviation Standard deviation of the Gaussian 
distribution.
      * @return the sampler
-     * @throws IllegalArgumentException if {@code standardDeviation <= 0}
+     * @throws IllegalArgumentException if {@code standardDeviation <= 0} or is infinite;
+     * or {@code mean} is infinite
      * @see #withUniformRandomProvider(UniformRandomProvider)
      * @since 1.3
      */
diff --git 
a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GaussianSamplerTest.java
 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GaussianSamplerTest.java
index 9bfabba..2ad15f8 100644
--- 
a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GaussianSamplerTest.java
+++ 
b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GaussianSamplerTest.java
@@ -28,7 +28,7 @@ import org.junit.Test;
  */
 public class GaussianSamplerTest {
     /**
-     * Test the constructor with a bad standard deviation.
+     * Test the constructor with a zero standard deviation.
      */
     @Test(expected = IllegalArgumentException.class)
     public void testConstructorThrowsWithZeroStandardDeviation() {
@@ -41,6 +41,58 @@ public class GaussianSamplerTest {
     }
 
     /**
+     * Test the constructor with an infinite standard deviation.
+     */
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithInfiniteStandardDeviation() {
+        final RestorableUniformRandomProvider rng =
+            RandomSource.SPLIT_MIX_64.create(0L);
+        final NormalizedGaussianSampler gauss = new 
ZigguratNormalizedGaussianSampler(rng);
+        final double mean = 1;
+        final double standardDeviation = Double.POSITIVE_INFINITY;
+        GaussianSampler.of(gauss, mean, standardDeviation);
+    }
+
+    /**
+     * Test the constructor with a NaN standard deviation.
+     */
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithNaNStandardDeviation() {
+        final RestorableUniformRandomProvider rng =
+            RandomSource.SPLIT_MIX_64.create(0L);
+        final NormalizedGaussianSampler gauss = new 
ZigguratNormalizedGaussianSampler(rng);
+        final double mean = 1;
+        final double standardDeviation = Double.NaN;
+        GaussianSampler.of(gauss, mean, standardDeviation);
+    }
+
+    /**
+     * Test the constructor with an infinite mean.
+     */
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithInfiniteMean() {
+        final RestorableUniformRandomProvider rng =
+            RandomSource.SPLIT_MIX_64.create(0L);
+        final NormalizedGaussianSampler gauss = new 
ZigguratNormalizedGaussianSampler(rng);
+        final double mean = Double.POSITIVE_INFINITY;
+        final double standardDeviation = 1;
+        GaussianSampler.of(gauss, mean, standardDeviation);
+    }
+
+    /**
+     * Test the constructor with a NaN mean.
+     */
+    @Test(expected = IllegalArgumentException.class)
+    public void testConstructorThrowsWithNaNMean() {
+        final RestorableUniformRandomProvider rng =
+            RandomSource.SPLIT_MIX_64.create(0L);
+        final NormalizedGaussianSampler gauss = new 
ZigguratNormalizedGaussianSampler(rng);
+        final double mean = Double.NaN;
+        final double standardDeviation = 1;
+        GaussianSampler.of(gauss, mean, standardDeviation);
+    }
+
+    /**
      * Test the SharedStateSampler implementation.
      */
     @Test
diff --git a/src/changes/changes.xml b/src/changes/changes.xml
index e211d01..c71270f 100644
--- a/src/changes/changes.xml
+++ b/src/changes/changes.xml
@@ -77,6 +77,9 @@ re-run tests that fail, and pass the build if they succeed
 within the allotted number of reruns (the test will be marked
 as 'flaky' in the report).
 ">
+      <action dev="aherbert" type="fix" issue="146">
+        "GaussianSampler": Prevent infinite mean and standard deviation.
+      </action>
       <action dev="aherbert" type="update" issue="154">
         Update Gaussian samplers to avoid infinity in the tails of the distribution. Applies
         to: ZigguratNormalizedGaussianSampler; BoxMullerNormalizedGaussianSampler; and

Reply via email to