MATH-1158. Sampler functionality defined in "EnumeratedDistribution". Method "createSampler" overridden in "EnumeratedRealDistribution".
Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/a5035d0e Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/a5035d0e Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/a5035d0e Branch: refs/heads/feature-MATH-1158 Commit: a5035d0e1cde068320984d789473e1140adefdc0 Parents: a6eda3d Author: Gilles <er...@apache.org> Authored: Fri Mar 11 04:48:18 2016 +0100 Committer: Gilles <er...@apache.org> Committed: Fri Mar 11 04:48:18 2016 +0100 ---------------------------------------------------------------------- .../distribution/EnumeratedDistribution.java | 116 +++++++++++++++++++ .../EnumeratedRealDistribution.java | 20 ++++ .../EnumeratedRealDistributionTest.java | 6 +- 3 files changed, 140 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-math/blob/a5035d0e/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java b/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java index 8e1149f..40af6e4 100644 --- a/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java +++ b/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java @@ -31,6 +31,7 @@ import org.apache.commons.math4.exception.NullArgumentException; import org.apache.commons.math4.exception.util.LocalizedFormats; import org.apache.commons.math4.random.RandomGenerator; import org.apache.commons.math4.random.Well19937c; +import org.apache.commons.math4.rng.UniformRandomProvider; import org.apache.commons.math4.util.MathArrays; import org.apache.commons.math4.util.Pair; @@ -59,6 +60,7 @@ public class EnumeratedDistribution<T> implements Serializable { /** * RNG instance used to generate samples from the distribution. */ + @Deprecated protected final RandomGenerator random; /** @@ -113,6 +115,7 @@ public class EnumeratedDistribution<T> implements Serializable { * @throws NotANumberException if any of the probabilities are NaN. * @throws MathArithmeticException all of the probabilities are 0. */ + @Deprecated public EnumeratedDistribution(final RandomGenerator rng, final List<Pair<T, Double>> pmf) throws NotPositiveException, MathArithmeticException, NotFiniteNumberException, NotANumberException { random = rng; @@ -151,6 +154,7 @@ public class EnumeratedDistribution<T> implements Serializable { * * @param seed the new seed */ + @Deprecated public void reseedRandomGenerator(long seed) { random.setSeed(seed); } @@ -205,6 +209,7 @@ public class EnumeratedDistribution<T> implements Serializable { * * @return a random value. */ + @Deprecated public T sample() { final double randomValue = random.nextDouble(); @@ -233,6 +238,7 @@ public class EnumeratedDistribution<T> implements Serializable { * @throws NotStrictlyPositiveException if {@code sampleSize} is not * positive. */ + @Deprecated public Object[] sample(int sampleSize) throws NotStrictlyPositiveException { if (sampleSize <= 0) { throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, @@ -262,6 +268,7 @@ public class EnumeratedDistribution<T> implements Serializable { * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive. * @throws NullArgumentException if {@code array} is null */ + @Deprecated public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException { if (sampleSize <= 0) { throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize); @@ -288,4 +295,113 @@ public class EnumeratedDistribution<T> implements Serializable { } + /** + * Creates a {@link Sampler}. + * + * @param rng Random number generator. + */ + public Sampler createSampler(final UniformRandomProvider rng) { + return new Sampler(rng); + } + + /** + * Sampler functionality. + */ + public class Sampler { + /** RNG. */ + private final UniformRandomProvider random; + + /** + * @param rng Random number generator. + */ + Sampler(UniformRandomProvider rng) { + random = rng; + } + + /** + * Generates a random value sampled from this distribution. + * + * @return a random value. + */ + public T sample() { + final double randomValue = random.nextDouble(); + + int index = Arrays.binarySearch(cumulativeProbabilities, randomValue); + if (index < 0) { + index = -index - 1; + } + + if (index >= 0 && + index < probabilities.length && + randomValue < cumulativeProbabilities[index]) { + return singletons.get(index); + } + + // This should never happen, but it ensures we will return a correct + // object in case there is some floating point inequality problem + // wrt the cumulative probabilities. + return singletons.get(singletons.size() - 1); + } + + /** + * Generates a random sample from the distribution. + * + * @param sampleSize the number of random values to generate. + * @return an array representing the random sample. + * @throws NotStrictlyPositiveException if {@code sampleSize} is not + * positive. + */ + public Object[] sample(int sampleSize) throws NotStrictlyPositiveException { + if (sampleSize <= 0) { + throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, + sampleSize); + } + + final Object[] out = new Object[sampleSize]; + + for (int i = 0; i < sampleSize; i++) { + out[i] = sample(); + } + + return out; + } + + /** + * Generates a random sample from the distribution. + * <p> + * If the requested samples fit in the specified array, it is returned + * therein. Otherwise, a new array is allocated with the runtime type of + * the specified array and the size of this collection. + * + * @param sampleSize the number of random values to generate. + * @param array the array to populate. + * @return an array representing the random sample. + * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive. + * @throws NullArgumentException if {@code array} is null + */ + public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException { + if (sampleSize <= 0) { + throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize); + } + + if (array == null) { + throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY); + } + + T[] out; + if (array.length < sampleSize) { + @SuppressWarnings("unchecked") // safe as both are of type T + final T[] unchecked = (T[]) Array.newInstance(array.getClass().getComponentType(), sampleSize); + out = unchecked; + } else { + out = array; + } + + for (int i = 0; i < sampleSize; i++) { + out[i] = sample(); + } + + return out; + } + } } http://git-wip-us.apache.org/repos/asf/commons-math/blob/a5035d0e/src/main/java/org/apache/commons/math4/distribution/EnumeratedRealDistribution.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math4/distribution/EnumeratedRealDistribution.java b/src/main/java/org/apache/commons/math4/distribution/EnumeratedRealDistribution.java index 688b3fd..9e03b2b 100644 --- a/src/main/java/org/apache/commons/math4/distribution/EnumeratedRealDistribution.java +++ b/src/main/java/org/apache/commons/math4/distribution/EnumeratedRealDistribution.java @@ -30,6 +30,7 @@ import org.apache.commons.math4.exception.NotPositiveException; import org.apache.commons.math4.exception.OutOfRangeException; import org.apache.commons.math4.random.RandomGenerator; import org.apache.commons.math4.random.Well19937c; +import org.apache.commons.math4.rng.UniformRandomProvider; import org.apache.commons.math4.util.Pair; /** @@ -93,6 +94,7 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution { * @throws NotANumberException if any of the probabilities are NaN. * @throws MathArithmeticException all of the probabilities are 0. */ + @Deprecated public EnumeratedRealDistribution(final RandomGenerator rng, final double[] singletons, final double[] probabilities) throws DimensionMismatchException, NotPositiveException, MathArithmeticException, @@ -111,6 +113,7 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution { * @param data input dataset * @since 3.6 */ + @Deprecated public EnumeratedRealDistribution(final RandomGenerator rng, final double[] data) { super(rng); final Map<Double, Integer> dataMap = new HashMap<Double, Integer>(); @@ -319,7 +322,24 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution { * {@inheritDoc} */ @Override + @Deprecated public double sample() { return innerDistribution.sample(); } + + /** {@inheritDoc} */ + @Override + public RealDistribution.Sampler createSampler(final UniformRandomProvider rng) { + return new RealDistribution.Sampler() { + /** Delegate. */ + private final EnumeratedDistribution<Double>.Sampler inner = + innerDistribution.createSampler(rng); + + /** {@inheritDoc} */ + @Override + public double sample() { + return inner.sample(); + } + }; + } } http://git-wip-us.apache.org/repos/asf/commons-math/blob/a5035d0e/src/test/java/org/apache/commons/math4/distribution/EnumeratedRealDistributionTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math4/distribution/EnumeratedRealDistributionTest.java b/src/test/java/org/apache/commons/math4/distribution/EnumeratedRealDistributionTest.java index f1cf652..0300e5e 100644 --- a/src/test/java/org/apache/commons/math4/distribution/EnumeratedRealDistributionTest.java +++ b/src/test/java/org/apache/commons/math4/distribution/EnumeratedRealDistributionTest.java @@ -30,6 +30,7 @@ import org.apache.commons.math4.exception.NotFiniteNumberException; import org.apache.commons.math4.exception.NotPositiveException; import org.apache.commons.math4.util.FastMath; import org.apache.commons.math4.util.Pair; +import org.apache.commons.math4.rng.RandomSource; import org.junit.Assert; import org.junit.Test; @@ -175,8 +176,9 @@ public class EnumeratedRealDistributionTest { @Test public void testSample() { final int n = 1000000; - testDistribution.reseedRandomGenerator(-334759360); // fixed seed - final double[] samples = testDistribution.sample(n); + final RealDistribution.Sampler sampler = + testDistribution.createSampler(RandomSource.create(RandomSource.WELL_1024_A, -123456789)); + final double[] samples = AbstractRealDistribution.sample(n, sampler); Assert.assertEquals(n, samples.length); double sum = 0; double sumOfSquares = 0;