MATH-1158.

Sampler functionality defined in "EnumeratedDistribution".
Method "createSampler" overridden in "EnumeratedRealDistribution".


Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/a5035d0e
Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/a5035d0e
Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/a5035d0e

Branch: refs/heads/feature-MATH-1158
Commit: a5035d0e1cde068320984d789473e1140adefdc0
Parents: a6eda3d
Author: Gilles <er...@apache.org>
Authored: Fri Mar 11 04:48:18 2016 +0100
Committer: Gilles <er...@apache.org>
Committed: Fri Mar 11 04:48:18 2016 +0100

----------------------------------------------------------------------
 .../distribution/EnumeratedDistribution.java    | 116 +++++++++++++++++++
 .../EnumeratedRealDistribution.java             |  20 ++++
 .../EnumeratedRealDistributionTest.java         |   6 +-
 3 files changed, 140 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/commons-math/blob/a5035d0e/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java b/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java
index 8e1149f..40af6e4 100644
--- a/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java
+++ b/src/main/java/org/apache/commons/math4/distribution/EnumeratedDistribution.java
@@ -31,6 +31,7 @@ import org.apache.commons.math4.exception.NullArgumentException;
 import org.apache.commons.math4.exception.util.LocalizedFormats;
 import org.apache.commons.math4.random.RandomGenerator;
 import org.apache.commons.math4.random.Well19937c;
+import org.apache.commons.math4.rng.UniformRandomProvider;
 import org.apache.commons.math4.util.MathArrays;
 import org.apache.commons.math4.util.Pair;
 
@@ -59,6 +60,7 @@ public class EnumeratedDistribution<T> implements Serializable {
     /**
      * RNG instance used to generate samples from the distribution.
+     *
+     * @deprecated Only used by the deprecated {@code reseedRandomGenerator}
+     * and {@code sample} methods; use
+     * {@link #createSampler(UniformRandomProvider)} instead.
      */
+    @Deprecated
     protected final RandomGenerator random;
 
     /**
@@ -113,6 +115,7 @@ public class EnumeratedDistribution<T> implements Serializable {
      * @throws NotANumberException if any of the probabilities are NaN.
      * @throws MathArithmeticException all of the probabilities are 0.
      */
+    @Deprecated
     public EnumeratedDistribution(final RandomGenerator rng, final 
List<Pair<T, Double>> pmf)
         throws NotPositiveException, MathArithmeticException, 
NotFiniteNumberException, NotANumberException {
         random = rng;
@@ -151,6 +154,7 @@ public class EnumeratedDistribution<T> implements Serializable {
     *
      * @param seed the new seed
+     * @deprecated Use {@link #createSampler(UniformRandomProvider)} with a
+     * {@code UniformRandomProvider} created from the desired seed.
      */
+    @Deprecated
     public void reseedRandomGenerator(long seed) {
         random.setSeed(seed);
     }
@@ -205,6 +209,7 @@ public class EnumeratedDistribution<T> implements Serializable {
      *
      * @return a random value.
      */
+    @Deprecated
     public T sample() {
         final double randomValue = random.nextDouble();
 
@@ -233,6 +238,7 @@ public class EnumeratedDistribution<T> implements Serializable {
      * @throws NotStrictlyPositiveException if {@code sampleSize} is not
      * positive.
      */
+    @Deprecated
     public Object[] sample(int sampleSize) throws NotStrictlyPositiveException 
{
         if (sampleSize <= 0) {
             throw new 
NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
@@ -262,6 +268,7 @@ public class EnumeratedDistribution<T> implements Serializable {
      * @throws NotStrictlyPositiveException if {@code sampleSize} is not 
positive.
      * @throws NullArgumentException if {@code array} is null
      */
+    @Deprecated
     public T[] sample(int sampleSize, final T[] array) throws 
NotStrictlyPositiveException {
         if (sampleSize <= 0) {
             throw new 
NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
@@ -288,4 +295,113 @@ public class EnumeratedDistribution<T> implements Serializable {
 
     }
 
+    /**
+     * Creates a {@link Sampler} that draws values from this distribution
+     * using the given source of randomness.
+     *
+     * @param rng Random number generator.
+     * @return a new sampler bound to this distribution and to {@code rng}.
+     */
+    public Sampler createSampler(final UniformRandomProvider rng) {
+        return new Sampler(rng);
+    }
+
+    /**
+     * Sampler functionality.
+     * <p>
+     * This is a (non-static) inner class: every instance reads the
+     * cumulative probabilities and singletons of the enclosing distribution,
+     * and draws its uniform deviates from the {@code UniformRandomProvider}
+     * supplied at construction time.
+     */
+    public class Sampler {
+        /** RNG. */
+        private final UniformRandomProvider random;
+
+        /**
+         * Creates a sampler backed by the given generator.
+         *
+         * @param rng Random number generator.
+         */
+        Sampler(UniformRandomProvider rng) {
+            random = rng;
+        }
+
+        /**
+         * Generates a random value sampled from this distribution.
+         * <p>
+         * A uniform deviate {@code u} is mapped to the first singleton whose
+         * cumulative probability is strictly greater than {@code u}.
+         *
+         * @return a random value.
+         */
+        public T sample() {
+            final double randomValue = random.nextDouble();
+
+            int index = Arrays.binarySearch(cumulativeProbabilities, randomValue);
+            if (index < 0) {
+                // Not found: the insertion point is the index of the first
+                // cumulative probability greater than "randomValue".
+                index = -index - 1;
+            } else {
+                // Exact hit on a cumulative probability: the value belongs to
+                // the next entry whose cumulative probability is strictly
+                // greater (skipping any entries with an equal value, e.g.
+                // zero-probability singletons).
+                do {
+                    ++index;
+                } while (index < cumulativeProbabilities.length &&
+                         cumulativeProbabilities[index] <= randomValue);
+            }
+
+            if (index < probabilities.length &&
+                randomValue < cumulativeProbabilities[index]) {
+                return singletons.get(index);
+            }
+
+            // This should never happen, but it ensures we will return a correct
+            // object in case there is some floating point inequality problem
+            // wrt the cumulative probabilities.
+            return singletons.get(singletons.size() - 1);
+        }
+
+        /**
+         * Generates a random sample from the distribution.
+         *
+         * @param sampleSize the number of random values to generate.
+         * @return an array representing the random sample.
+         * @throws NotStrictlyPositiveException if {@code sampleSize} is not
+         * positive.
+         */
+        public Object[] sample(int sampleSize) throws NotStrictlyPositiveException {
+            if (sampleSize <= 0) {
+                throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
+                                                       sampleSize);
+            }
+
+            final Object[] out = new Object[sampleSize];
+
+            for (int i = 0; i < sampleSize; i++) {
+                out[i] = sample();
+            }
+
+            return out;
+        }
+
+        /**
+         * Generates a random sample from the distribution.
+         * <p>
+         * If {@code array} has room for the requested samples, they are stored
+         * therein (starting at index 0) and {@code array} itself is returned;
+         * elements at indices {@code sampleSize} and above are left untouched.
+         * Otherwise, a new array with the runtime component type of
+         * {@code array} and length {@code sampleSize} is allocated and returned.
+         *
+         * @param sampleSize the number of random values to generate.
+         * @param array the array to populate.
+         * @return an array representing the random sample.
+         * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive.
+         * @throws NullArgumentException if {@code array} is null
+         */
+        public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException {
+            if (sampleSize <= 0) {
+                throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
+            }
+
+            if (array == null) {
+                throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
+            }
+
+            T[] out;
+            if (array.length < sampleSize) {
+                @SuppressWarnings("unchecked") // safe as both are of type T
+                final T[] unchecked =
+                    (T[]) Array.newInstance(array.getClass().getComponentType(), sampleSize);
+                out = unchecked;
+            } else {
+                out = array;
+            }
+
+            for (int i = 0; i < sampleSize; i++) {
+                out[i] = sample();
+            }
+
+            return out;
+        }
+    }
 }

http://git-wip-us.apache.org/repos/asf/commons-math/blob/a5035d0e/src/main/java/org/apache/commons/math4/distribution/EnumeratedRealDistribution.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/commons/math4/distribution/EnumeratedRealDistribution.java b/src/main/java/org/apache/commons/math4/distribution/EnumeratedRealDistribution.java
index 688b3fd..9e03b2b 100644
--- a/src/main/java/org/apache/commons/math4/distribution/EnumeratedRealDistribution.java
+++ b/src/main/java/org/apache/commons/math4/distribution/EnumeratedRealDistribution.java
@@ -30,6 +30,7 @@ import org.apache.commons.math4.exception.NotPositiveException;
 import org.apache.commons.math4.exception.OutOfRangeException;
 import org.apache.commons.math4.random.RandomGenerator;
 import org.apache.commons.math4.random.Well19937c;
+import org.apache.commons.math4.rng.UniformRandomProvider;
 import org.apache.commons.math4.util.Pair;
 
 /**
@@ -93,6 +94,7 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution {
      * @throws NotANumberException if any of the probabilities are NaN.
      * @throws MathArithmeticException all of the probabilities are 0.
      */
+    @Deprecated
     public EnumeratedRealDistribution(final RandomGenerator rng,
                                     final double[] singletons, final double[] 
probabilities)
         throws DimensionMismatchException, NotPositiveException, 
MathArithmeticException,
@@ -111,6 +113,7 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution {
      * @param data input dataset
      * @since 3.6
      */
+    @Deprecated
     public EnumeratedRealDistribution(final RandomGenerator rng, final 
double[] data) {
         super(rng);
         final Map<Double, Integer> dataMap = new HashMap<Double, Integer>();
@@ -319,7 +322,24 @@ public class EnumeratedRealDistribution extends AbstractRealDistribution {
      * {@inheritDoc}
+     *
+     * @deprecated Use {@link #createSampler(UniformRandomProvider)} instead.
      */
     @Override
+    @Deprecated
     public double sample() {
         return innerDistribution.sample();
     }
+
+    /**
+     * {@inheritDoc}
+     * <p>
+     * The returned sampler delegates to a sampler of the wrapped
+     * {@code EnumeratedDistribution<Double>} built on {@code rng}, and
+     * unboxes each value it produces.
+     */
+    @Override
+    public RealDistribution.Sampler createSampler(final UniformRandomProvider rng) {
+        // Created once, here, and captured by the anonymous class below.
+        final EnumeratedDistribution<Double>.Sampler delegate =
+            innerDistribution.createSampler(rng);
+
+        return new RealDistribution.Sampler() {
+            /** {@inheritDoc} */
+            @Override
+            public double sample() {
+                return delegate.sample();
+            }
+        };
+    }
 }

http://git-wip-us.apache.org/repos/asf/commons-math/blob/a5035d0e/src/test/java/org/apache/commons/math4/distribution/EnumeratedRealDistributionTest.java
----------------------------------------------------------------------
diff --git a/src/test/java/org/apache/commons/math4/distribution/EnumeratedRealDistributionTest.java b/src/test/java/org/apache/commons/math4/distribution/EnumeratedRealDistributionTest.java
index f1cf652..0300e5e 100644
--- a/src/test/java/org/apache/commons/math4/distribution/EnumeratedRealDistributionTest.java
+++ b/src/test/java/org/apache/commons/math4/distribution/EnumeratedRealDistributionTest.java
@@ -30,6 +30,7 @@ import org.apache.commons.math4.exception.NotFiniteNumberException;
 import org.apache.commons.math4.exception.NotPositiveException;
 import org.apache.commons.math4.util.FastMath;
 import org.apache.commons.math4.util.Pair;
+import org.apache.commons.math4.rng.RandomSource;
 import org.junit.Assert;
 import org.junit.Test;
 
@@ -175,8 +176,9 @@ public class EnumeratedRealDistributionTest {
     @Test
     public void testSample() {
         final int n = 1000000;
-        testDistribution.reseedRandomGenerator(-334759360); // fixed seed
-        final double[] samples = testDistribution.sample(n);
+        final RealDistribution.Sampler sampler =
+            
testDistribution.createSampler(RandomSource.create(RandomSource.WELL_1024_A, 
-123456789));
+        final double[] samples = AbstractRealDistribution.sample(n, sampler);
         Assert.assertEquals(n, samples.length);
         double sum = 0;
         double sumOfSquares = 0;

Reply via email to