RNG-30: Unit tests for discrete distributions.
Project: http://git-wip-us.apache.org/repos/asf/commons-rng/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-rng/commit/438b67b3 Tree: http://git-wip-us.apache.org/repos/asf/commons-rng/tree/438b67b3 Diff: http://git-wip-us.apache.org/repos/asf/commons-rng/diff/438b67b3 Branch: refs/heads/master Commit: 438b67b3258427e8c5a20d17b5fad0ee4629f0e8 Parents: 74bbdd2 Author: Gilles <er...@apache.org> Authored: Sat Nov 12 16:51:01 2016 +0100 Committer: Gilles <er...@apache.org> Committed: Sat Nov 12 16:51:01 2016 +0100 ---------------------------------------------------------------------- .../DiscreteSamplerParametricTest.java | 161 +++++++++++++++++ .../distribution/DiscreteSamplerTestData.java | 60 +++++++ .../distribution/DiscreteSamplersList.java | 180 +++++++++++++++++++ 3 files changed, 401 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-rng/blob/438b67b3/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerParametricTest.java ---------------------------------------------------------------------- diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerParametricTest.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerParametricTest.java new file mode 100644 index 0000000..d96fcb1 --- /dev/null +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerParametricTest.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import java.util.Arrays; +import java.util.List; +import java.util.ArrayList; +import java.util.concurrent.Callable; +import java.io.IOException; +import java.io.ObjectOutputStream; +import java.io.ObjectInputStream; +import java.io.ByteArrayOutputStream; +import java.io.ByteArrayInputStream; + +import org.junit.Assert; +import org.junit.Test; +import org.junit.Assume; +import org.junit.Ignore; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +import org.apache.commons.math3.distribution.ChiSquaredDistribution; + +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.RandomProviderState; +import org.apache.commons.rng.RestorableUniformRandomProvider; +import org.apache.commons.rng.core.RandomProviderDefaultState; +import org.apache.commons.rng.sampling.DiscreteSampler; + +/** + * Tests for samplers. + */ +@RunWith(value=Parameterized.class) +public class DiscreteSamplerParametricTest { + /** Sampler under test. */ + private final DiscreteSamplerTestData sampler; + + /** + * Initializes generator instance. + * + * @param rng RNG to be tested. + */ + public DiscreteSamplerParametricTest(DiscreteSamplerTestData data) { + sampler = data; + } + + @Parameters(name = "{index}: data={0}") + public static Iterable<DiscreteSamplerTestData[]> getList() { + return DiscreteSamplersList.list(); + } + + @Test + public void testSampling() { + final int sampleSize = 10000; + + final double[] prob = sampler.getProbabilities(); + final int len = prob.length; + final long[] expected = new long[len]; + for (int i = 0; i < len; i++) { + expected[i] = (long) (prob[i] * sampleSize); + } + check(sampleSize, + sampler.getSampler(), + sampler.getPoints(), + expected); + } + + /** + * Performs a chi-square test of homogeneity of the observed + * distribution with the expected distribution. + * + * @param sampler Sampler. + * @param sampleSize Number of random values to generate. + * @param points Outcomes. + * @param expected Expected counts of the given outcomes. + */ + private void check(long sampleSize, + DiscreteSampler sampler, + int[] points, + long[] expected) { + final int numTests = 50; + + // Run the tests. + int numFailures = 0; + + final int numBins = points.length; + final long[] observed = new long[numBins]; + + // For storing chi2 larger than the critical value. + final List<Double> failedStat = new ArrayList<Double>(); + try { + for (int i = 0; i < numTests; i++) { + Arrays.fill(observed, 0); + SAMPLE: for (long j = 0; j < sampleSize; j++) { + final int value = sampler.sample(); + + for (int k = 0; k < numBins; k++) { + if (value == points[k]) { + ++observed[k]; + continue SAMPLE; + } + } + } + + // Statistics check. XXX + final double chi2stat = chiSquareStat(expected, observed); + if (chi2stat < 0.001) { + failedStat.add(chi2stat); + ++numFailures; + } + } + } catch (Exception e) { + // Should never happen. + throw new RuntimeException("Unexpected", e); + } + + if ((double) numFailures / (double) numTests > 0.02) { + Assert.fail(sampler + ": Too many failures for sample size = " + sampleSize + + " (" + numFailures + " out of " + numTests + " tests failed, " + + "chi2=" + Arrays.toString(failedStat.toArray(new Double[0]))); + } + } + + /** + * @param expected Counts. + * @param observed Counts. + * @return the chi-square statistics. + */ + private static double chiSquareStat(long[] expected, + long[] observed) { + final int numBins = expected.length; + double chi2 = 0; + for (int i = 0; i < numBins; i++) { + final long diff = observed[i] - expected[i]; + chi2 += (diff / (double) expected[i]) * diff; + // System.out.println("bin[" + i + "]" + + // " obs=" + observed[i] + + // " exp=" + expected[i]); + } + + final int dof = numBins - 1; + final ChiSquaredDistribution dist = new ChiSquaredDistribution(null, dof, 1e-8); + + return 1 - dist.cumulativeProbability(chi2); + } +} http://git-wip-us.apache.org/repos/asf/commons-rng/blob/438b67b3/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerTestData.java ---------------------------------------------------------------------- diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerTestData.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerTestData.java new file mode 100644 index 0000000..fd52f29 --- /dev/null +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplerTestData.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import java.util.Arrays; + +import org.apache.commons.rng.sampling.DiscreteSampler; + +/** + * Data store for {@link InverseMethodDiscreteParametricTest}. + */ +class DiscreteSamplerTestData { + private final DiscreteSampler sampler; + private final int[] points; + private final double[] probabilities; + + public DiscreteSamplerTestData(DiscreteSampler sampler, + int[] points, + double[] probabilities) { + this.sampler = sampler; + this.points = points.clone(); + this.probabilities = probabilities.clone(); + } + + public DiscreteSampler getSampler() { + return sampler; + } + + public int[] getPoints() { + return points.clone(); + } + + public double[] getProbabilities() { + return probabilities.clone(); + } + + @Override + public String toString() { + final int len = points.length; + final String[] p = new String[len]; + for (int i = 0; i < len; i++) { + p[i] = "p(" + points[i] + ")=" + probabilities[i]; + } + return sampler.toString() + ": " + Arrays.toString(p); + } +} http://git-wip-us.apache.org/repos/asf/commons-rng/blob/438b67b3/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java ---------------------------------------------------------------------- diff --git a/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java new file mode 100644 index 0000000..7a7425f --- /dev/null +++ b/commons-rng-sampling/src/test/java/org/apache/commons/rng/sampling/distribution/DiscreteSamplersList.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.commons.rng.sampling.distribution; + +import java.util.Arrays; +import java.util.List; +import java.util.ArrayList; +import java.util.Collections; + +import org.apache.commons.math3.util.MathArrays; + +import org.apache.commons.rng.UniformRandomProvider; +import org.apache.commons.rng.sampling.DiscreteSampler; +import org.apache.commons.rng.simple.RandomSource; + +/** + * List of samplers. + */ +public class DiscreteSamplersList { + /** List of all RNGs implemented in the library. */ + private static final List<DiscreteSamplerTestData[]> LIST = + new ArrayList<DiscreteSamplerTestData[]>(); + + static { + try { + // List of distributions to test. + + // Binomial ("inverse method"). + final int trialsBinomial = 20; + final double probSuccessBinomial = 0.67; + add(LIST, new org.apache.commons.math3.distribution.BinomialDistribution(trialsBinomial, probSuccessBinomial), + MathArrays.sequence(8, 9, 1), + RandomSource.create(RandomSource.KISS)); + + // Geometric ("inverse method"). + final double probSuccessGeometric = 0.21; + add(LIST, new org.apache.commons.math3.distribution.GeometricDistribution(probSuccessGeometric), + MathArrays.sequence(10, 0, 1), + RandomSource.create(RandomSource.ISAAC)); + + // Hypergeometric ("inverse method"). + final int popSizeHyper = 34; + final int numSuccessesHyper = 11; + final int sampleSizeHyper = 12; + add(LIST, new org.apache.commons.math3.distribution.HypergeometricDistribution(popSizeHyper, numSuccessesHyper, sampleSizeHyper), + MathArrays.sequence(10, 0, 1), + RandomSource.create(RandomSource.MT)); + + // Pascal ("inverse method"). + final int numSuccessesPascal = 6; + final double probSuccessPascal = 0.2; + add(LIST, new org.apache.commons.math3.distribution.PascalDistribution(numSuccessesPascal, probSuccessPascal), + MathArrays.sequence(18, 1, 1), + RandomSource.create(RandomSource.TWO_CMRES)); + + // Uniform ("inverse method"). + final int loUniform = -3; + final int hiUniform = 4; + add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(loUniform, hiUniform), + MathArrays.sequence(10, -4, 1), + RandomSource.create(RandomSource.SPLIT_MIX_64)); + // Uniform. + add(LIST, new org.apache.commons.math3.distribution.UniformIntegerDistribution(loUniform, hiUniform), + MathArrays.sequence(10, -4, 1), + new DiscreteUniformSampler(RandomSource.create(RandomSource.MT_64), loUniform, hiUniform)); + + // Zipf ("inverse method"). + final int numElementsZipf = 5; + final double exponentZipf = 2.345; + add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(numElementsZipf, exponentZipf), + MathArrays.sequence(5, 0, 1), + RandomSource.create(RandomSource.XOR_SHIFT_1024_S)); + // Zipf. + add(LIST, new org.apache.commons.math3.distribution.ZipfDistribution(numElementsZipf, exponentZipf), + MathArrays.sequence(7, 0, 1), + new RejectionInversionZipfSampler(RandomSource.create(RandomSource.WELL_19937_C), numElementsZipf, exponentZipf)); + + // Poisson ("inverse method"). + final double meanPoisson = 3.21; + add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(meanPoisson), + MathArrays.sequence(10, 0, 1), + RandomSource.create(RandomSource.MWC_256)); + // Poisson. + add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(meanPoisson), + MathArrays.sequence(10, 0, 1), + new PoissonSampler(RandomSource.create(RandomSource.KISS), meanPoisson)); + // Poisson (mean > 40). + final double largeMeanPoisson = 543.21; + add(LIST, new org.apache.commons.math3.distribution.PoissonDistribution(largeMeanPoisson), + MathArrays.sequence(100, (int) (largeMeanPoisson - 50), 1), + new PoissonSampler(RandomSource.create(RandomSource.SPLIT_MIX_64), largeMeanPoisson)); + } catch (Exception e) { + System.err.println("Unexpected exception while creating the list of samplers: " + e); + e.printStackTrace(System.err); + throw new RuntimeException(e); + } + } + + /** + * Class contains only static methods. + */ + private DiscreteSamplersList() {} + + /** + * @param list List of data (one the "parameters" tested by the Junit parametric test). + * @param dist Distribution to which the samples are supposed to conform. + * @param points Outcomes selection. + * @param rng Generator of uniformly distributed sequences. + */ + private static void add(List<DiscreteSamplerTestData[]> list, + final org.apache.commons.math3.distribution.IntegerDistribution dist, + int[] points, + UniformRandomProvider rng) { + final DiscreteSampler inverseMethodSampler = + new InverseMethodDiscreteSampler(rng, + new DiscreteInverseCumulativeProbabilityFunction() { + @Override + public int inverseCumulativeProbability(double p) { + return dist.inverseCumulativeProbability(p); + } + }); + list.add(new DiscreteSamplerTestData[] { new DiscreteSamplerTestData(inverseMethodSampler, + points, + getProbabilities(dist, points)) }); + } + + /** + * @param list List of data (one the "parameters" tested by the Junit parametric test). + * @param dist Distribution to which the samples are supposed to conform. + * @param points Outcomes selection. + * @param sampler Sampler. + */ + private static void add(List<DiscreteSamplerTestData[]> list, + final org.apache.commons.math3.distribution.IntegerDistribution dist, + int[] points, + final DiscreteSampler sampler) { + list.add(new DiscreteSamplerTestData[] { new DiscreteSamplerTestData(sampler, + points, + getProbabilities(dist, points)) }); + } + + /** + * Subclasses that are "parametric" tests can forward the call to + * the "@Parameters"-annotated method to this method. + * + * @return the list of all generators. + */ + public static Iterable<DiscreteSamplerTestData[]> list() { + return Collections.unmodifiableList(LIST); + } + + /** + * @param dist Distribution. + * @param points Points. + * @return the probabilities of the given points according to the distribution. + */ + private static double[] getProbabilities(org.apache.commons.math3.distribution.IntegerDistribution dist, + int[] points) { + final int len = points.length; + final double[] prob = new double[len]; + for (int i = 0; i < len; i++) { + prob[i] = dist.probability(points[i]); + } + return prob; + } +}