This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-rng.git
The following commit(s) were added to refs/heads/master by this push: new 6edeb10 RNG-146: Prevent infinite standard deviation 6edeb10 is described below commit 6edeb102c5310895480781e14c14facaf12ed864 Author: aherbert <aherb...@apache.org> AuthorDate: Fri Jul 9 11:29:17 2021 +0100 RNG-146: Prevent infinite standard deviation --- .../rng/sampling/distribution/GaussianSampler.java | 22 +++++++-- .../sampling/distribution/GaussianSamplerTest.java | 54 +++++++++++++++++++++- src/changes/changes.xml | 3 ++ 3 files changed, 74 insertions(+), 5 deletions(-) diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java index 5540018..38e0537 100644 --- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java +++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/GaussianSampler.java @@ -22,6 +22,14 @@ import org.apache.commons.rng.UniformRandomProvider; * Sampling from a Gaussian distribution with given mean and * standard deviation. * + * <h2>Note</h2> + * + * <p>The mean and standard deviation are validated to ensure they are finite. This prevents + * generation of NaN samples by avoiding invalid arithmetic (inf * 0 or inf - inf). + * However use of an extremely large standard deviation and/or mean may result in samples that are + * infinite; that is the parameters are not validated to prevent truncation of the output + * distribution. + * * @since 1.1 */ public class GaussianSampler implements SharedStateContinuousSampler { @@ -36,14 +44,19 @@ public class GaussianSampler implements SharedStateContinuousSampler { * @param normalized Generator of N(0,1) Gaussian distributed random numbers. * @param mean Mean of the Gaussian distribution. * @param standardDeviation Standard deviation of the Gaussian distribution. 
- * @throws IllegalArgumentException if {@code standardDeviation <= 0} + * @throws IllegalArgumentException if {@code standardDeviation <= 0} or is infinite; + * or {@code mean} is infinite */ public GaussianSampler(NormalizedGaussianSampler normalized, double mean, double standardDeviation) { - if (standardDeviation <= 0) { + if (!(standardDeviation > 0 && standardDeviation < Double.POSITIVE_INFINITY)) { throw new IllegalArgumentException( - "standard deviation is not strictly positive: " + standardDeviation); + "standard deviation is not strictly positive and finite: " + standardDeviation); + } + // To be replaced by JDK 1.8 Double.isFinite. This will detect NaN values. + if (!(Math.abs(mean) <= Double.MAX_VALUE)) { + throw new IllegalArgumentException("mean is not finite: " + mean); } this.normalized = normalized; this.mean = mean; @@ -102,7 +115,8 @@ public class GaussianSampler implements SharedStateContinuousSampler { * @param mean Mean of the Gaussian distribution. * @param standardDeviation Standard deviation of the Gaussian distribution. * @return the sampler - * @throws IllegalArgumentException if {@code standardDeviation <= 0} + * @throws IllegalArgumentException if {@code standardDeviation <= 0} or is infinite; + * or {@code mean} is infinite * @see #withUniformRandomProvider(UniformRandomProvider) * @since 1.3 */ diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GaussianSamplerTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GaussianSamplerTest.java index 9bfabba..2ad15f8 100644 --- a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GaussianSamplerTest.java +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/GaussianSamplerTest.java @@ -28,7 +28,7 @@ import org.junit.Test; */ public class GaussianSamplerTest { /** - * Test the constructor with a bad standard deviation. 
+ * Test the constructor with a zero standard deviation. */ @Test(expected = IllegalArgumentException.class) public void testConstructorThrowsWithZeroStandardDeviation() { @@ -41,6 +41,58 @@ public class GaussianSamplerTest { } /** + * Test the constructor with an infinite standard deviation. + */ + @Test(expected = IllegalArgumentException.class) + public void testConstructorThrowsWithInfiniteStandardDeviation() { + final RestorableUniformRandomProvider rng = + RandomSource.SPLIT_MIX_64.create(0L); + final NormalizedGaussianSampler gauss = new ZigguratNormalizedGaussianSampler(rng); + final double mean = 1; + final double standardDeviation = Double.POSITIVE_INFINITY; + GaussianSampler.of(gauss, mean, standardDeviation); + } + + /** + * Test the constructor with a NaN standard deviation. + */ + @Test(expected = IllegalArgumentException.class) + public void testConstructorThrowsWithNaNStandardDeviation() { + final RestorableUniformRandomProvider rng = + RandomSource.SPLIT_MIX_64.create(0L); + final NormalizedGaussianSampler gauss = new ZigguratNormalizedGaussianSampler(rng); + final double mean = 1; + final double standardDeviation = Double.NaN; + GaussianSampler.of(gauss, mean, standardDeviation); + } + + /** + * Test the constructor with an infinite mean. + */ + @Test(expected = IllegalArgumentException.class) + public void testConstructorThrowsWithInfiniteMean() { + final RestorableUniformRandomProvider rng = + RandomSource.SPLIT_MIX_64.create(0L); + final NormalizedGaussianSampler gauss = new ZigguratNormalizedGaussianSampler(rng); + final double mean = Double.POSITIVE_INFINITY; + final double standardDeviation = 1; + GaussianSampler.of(gauss, mean, standardDeviation); + } + + /** + * Test the constructor with a NaN mean. 
+ */ + @Test(expected = IllegalArgumentException.class) + public void testConstructorThrowsWithNaNMean() { + final RestorableUniformRandomProvider rng = + RandomSource.SPLIT_MIX_64.create(0L); + final NormalizedGaussianSampler gauss = new ZigguratNormalizedGaussianSampler(rng); + final double mean = Double.NaN; + final double standardDeviation = 1; + GaussianSampler.of(gauss, mean, standardDeviation); + } + + /** * Test the SharedStateSampler implementation. */ @Test diff --git a/src/changes/changes.xml b/src/changes/changes.xml index e211d01..c71270f 100644 --- a/src/changes/changes.xml +++ b/src/changes/changes.xml @@ -77,6 +77,9 @@ re-run tests that fail, and pass the build if they succeed within the allotted number of reruns (the test will be marked as 'flaky' in the report). "> + <action dev="aherbert" type="fix" issue="146"> + "GaussianSampler": Prevent infinite mean and standard deviation. + </action> <action dev="aherbert" type="update" issue="154"> Update Gaussian samplers to avoid infinity in the tails of the distribution. Applies to: ZigguratNormalisedGaussianSampler; BoxMullerNormalizedGaussianSampler; and