Author: psteitz Date: Sun Mar 17 04:28:04 2013 New Revision: 1457372 URL: http://svn.apache.org/r1457372 Log: Made the EmpiricalDistribution smoothing kernel pluggable: subclasses can now override the protected getKernel(SummaryStatistics) method to supply their own within-bin kernel (the default remains Gaussian). To let subclasses pass the enclosing distribution's underlying RandomGenerator to kernel distribution constructors, two additional changes were required: * In EmpiricalDistribution, the RandomDataGenerator field (randomData) was changed from private to protected. * The private getRan() method in RandomDataGenerator, which returns the underlying RandomGenerator, was renamed to getRandomGenerator() and made public. JIRA: MATH-671
Modified: commons/proper/math/trunk/src/changes/changes.xml commons/proper/math/trunk/src/main/java/org/apache/commons/math3/random/EmpiricalDistribution.java commons/proper/math/trunk/src/main/java/org/apache/commons/math3/random/RandomDataGenerator.java commons/proper/math/trunk/src/test/java/org/apache/commons/math3/random/EmpiricalDistributionTest.java Modified: commons/proper/math/trunk/src/changes/changes.xml URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/changes/changes.xml?rev=1457372&r1=1457371&r2=1457372&view=diff ============================================================================== --- commons/proper/math/trunk/src/changes/changes.xml (original) +++ commons/proper/math/trunk/src/changes/changes.xml Sun Mar 17 04:28:04 2013 @@ -55,6 +55,9 @@ This is a minor release: It combines bug Changes to existing features were made in a backwards-compatible way such as to allow drop-in replacement of the v3.1[.1] JAR file. "> + <action dev="psteitz" type="update" issue="MATH-671"> + Made EmpiricalDistribution smoothing kernel pluggable. + </action> <action dev="psteitz" type="add" issue="MATH-946" due-to="Jared Becksfort"> Added array-scaling methods to MathArrays. 
</action> Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/random/EmpiricalDistribution.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/random/EmpiricalDistribution.java?rev=1457372&r1=1457371&r2=1457372&view=diff ============================================================================== --- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/random/EmpiricalDistribution.java (original) +++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/random/EmpiricalDistribution.java Sun Mar 17 04:28:04 2013 @@ -110,6 +110,9 @@ public class EmpiricalDistribution exten /** Serializable version identifier */ private static final long serialVersionUID = 5729073523949762654L; + /** RandomDataGenerator instance to use in repeated calls to getNext() */ + protected final RandomDataGenerator randomData; + /** List of SummaryStatistics objects characterizing the bins */ private final List<SummaryStatistics> binStats; @@ -134,9 +137,6 @@ public class EmpiricalDistribution exten /** upper bounds of subintervals in (0,1) "belonging" to the bins */ private double[] upperBounds = null; - /** RandomDataGenerator instance to use in repeated calls to getNext() */ - private final RandomDataGenerator randomData; - /** * Creates a new EmpiricalDistribution with the default bin count. 
*/ @@ -487,8 +487,7 @@ public class EmpiricalDistribution exten SummaryStatistics stats = binStats.get(i); if (stats.getN() > 0) { if (stats.getStandardDeviation() > 0) { // more than one obs - return randomData.nextGaussian(stats.getMean(), - stats.getStandardDeviation()); + return getKernel(stats).sample(); } else { return stats.getMean(); // only one obs in bin } @@ -842,9 +841,10 @@ public class EmpiricalDistribution exten * @param bStats summary statistics for the bin * @return within-bin kernel parameterized by bStats */ - private RealDistribution getKernel(SummaryStatistics bStats) { - // For now, hard-code Gaussian (only kernel supported) - return new NormalDistribution( - bStats.getMean(), bStats.getStandardDeviation()); + protected RealDistribution getKernel(SummaryStatistics bStats) { + // Default to Gaussian + return new NormalDistribution(randomData.getRandomGenerator(), + bStats.getMean(), bStats.getStandardDeviation(), + NormalDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY); } } Modified: commons/proper/math/trunk/src/main/java/org/apache/commons/math3/random/RandomDataGenerator.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/main/java/org/apache/commons/math3/random/RandomDataGenerator.java?rev=1457372&r1=1457371&r2=1457372&view=diff ============================================================================== --- commons/proper/math/trunk/src/main/java/org/apache/commons/math3/random/RandomDataGenerator.java (original) +++ commons/proper/math/trunk/src/main/java/org/apache/commons/math3/random/RandomDataGenerator.java Sun Mar 17 04:28:04 2013 @@ -163,7 +163,7 @@ public class RandomDataGenerator impleme } // Get a random number generator - RandomGenerator ran = getRan(); + RandomGenerator ran = getRandomGenerator(); // Initialize output buffer StringBuilder outBuffer = new StringBuilder(); @@ -202,7 +202,7 @@ public class RandomDataGenerator impleme if (max <= 0) { // the range is too wide to fit in a positive int (larger than 
2^31); as it covers // more than half the integer range, we use directly a simple rejection method - final RandomGenerator rng = getRan(); + final RandomGenerator rng = getRandomGenerator(); while (true) { final int r = rng.nextInt(); if (r >= lower && r <= upper) { @@ -211,7 +211,7 @@ public class RandomDataGenerator impleme } } else { // we can shift the range and generate directly a positive int - return lower + getRan().nextInt(max); + return lower + getRandomGenerator().nextInt(max); } } @@ -225,7 +225,7 @@ public class RandomDataGenerator impleme if (max <= 0) { // the range is too wide to fit in a positive long (larger than 2^63); as it covers // more than half the long range, we use directly a simple rejection method - final RandomGenerator rng = getRan(); + final RandomGenerator rng = getRandomGenerator(); while (true) { final long r = rng.nextLong(); if (r >= lower && r <= upper) { @@ -234,10 +234,10 @@ public class RandomDataGenerator impleme } } else if (max < Integer.MAX_VALUE){ // we can shift the range and generate directly a positive int - return lower + getRan().nextInt((int) max); + return lower + getRandomGenerator().nextInt((int) max); } else { // we can shift the range and generate directly a positive long - return lower + nextLong(getRan(), max); + return lower + nextLong(getRandomGenerator(), max); } } @@ -433,7 +433,7 @@ public class RandomDataGenerator impleme * @throws NotStrictlyPositiveException if {@code len <= 0} */ public long nextPoisson(double mean) throws NotStrictlyPositiveException { - return new PoissonDistribution(getRan(), mean, + return new PoissonDistribution(getRandomGenerator(), mean, PoissonDistribution.DEFAULT_EPSILON, PoissonDistribution.DEFAULT_MAX_ITERATIONS).sample(); } @@ -443,7 +443,7 @@ public class RandomDataGenerator impleme if (sigma <= 0) { throw new NotStrictlyPositiveException(LocalizedFormats.STANDARD_DEVIATION, sigma); } - return sigma * getRan().nextGaussian() + mu; + return sigma * 
getRandomGenerator().nextGaussian() + mu; } /** @@ -458,7 +458,7 @@ public class RandomDataGenerator impleme * </p> */ public double nextExponential(double mean) throws NotStrictlyPositiveException { - return new ExponentialDistribution(getRan(), mean, + return new ExponentialDistribution(getRandomGenerator(), mean, ExponentialDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample(); } @@ -485,7 +485,7 @@ public class RandomDataGenerator impleme * {@code scale <= 0}. */ public double nextGamma(double shape, double scale) throws NotStrictlyPositiveException { - return new GammaDistribution(getRan(),shape, scale, + return new GammaDistribution(getRandomGenerator(),shape, scale, GammaDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample(); } @@ -502,7 +502,7 @@ public class RandomDataGenerator impleme * @throws NotPositiveException if {@code numberOfSuccesses < 0}. */ public int nextHypergeometric(int populationSize, int numberOfSuccesses, int sampleSize) throws NotPositiveException, NotStrictlyPositiveException, NumberIsTooLargeException { - return new HypergeometricDistribution(getRan(),populationSize, + return new HypergeometricDistribution(getRandomGenerator(),populationSize, numberOfSuccesses, sampleSize).sample(); } @@ -517,7 +517,7 @@ public class RandomDataGenerator impleme * range {@code [0, 1]}. */ public int nextPascal(int r, double p) throws NotStrictlyPositiveException, OutOfRangeException { - return new PascalDistribution(getRan(), r, p).sample(); + return new PascalDistribution(getRandomGenerator(), r, p).sample(); } /** @@ -528,7 +528,7 @@ public class RandomDataGenerator impleme * @throws NotStrictlyPositiveException if {@code df <= 0} */ public double nextT(double df) throws NotStrictlyPositiveException { - return new TDistribution(getRan(), df, + return new TDistribution(getRandomGenerator(), df, TDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample(); } @@ -542,7 +542,7 @@ public class RandomDataGenerator impleme * {@code scale <= 0}. 
*/ public double nextWeibull(double shape, double scale) throws NotStrictlyPositiveException { - return new WeibullDistribution(getRan(), shape, scale, + return new WeibullDistribution(getRandomGenerator(), shape, scale, WeibullDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample(); } @@ -556,7 +556,7 @@ public class RandomDataGenerator impleme * or {@code exponent <= 0}. */ public int nextZipf(int numberOfElements, double exponent) throws NotStrictlyPositiveException { - return new ZipfDistribution(getRan(), numberOfElements, exponent).sample(); + return new ZipfDistribution(getRandomGenerator(), numberOfElements, exponent).sample(); } /** @@ -567,7 +567,7 @@ public class RandomDataGenerator impleme * @return random value sampled from the beta(alpha, beta) distribution */ public double nextBeta(double alpha, double beta) { - return new BetaDistribution(getRan(), alpha, beta, + return new BetaDistribution(getRandomGenerator(), alpha, beta, BetaDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample(); } @@ -579,7 +579,7 @@ public class RandomDataGenerator impleme * @return random value sampled from the Binomial(numberOfTrials, probabilityOfSuccess) distribution */ public int nextBinomial(int numberOfTrials, double probabilityOfSuccess) { - return new BinomialDistribution(getRan(), numberOfTrials, probabilityOfSuccess).sample(); + return new BinomialDistribution(getRandomGenerator(), numberOfTrials, probabilityOfSuccess).sample(); } /** @@ -590,7 +590,7 @@ public class RandomDataGenerator impleme * @return random value sampled from the Cauchy(median, scale) distribution */ public double nextCauchy(double median, double scale) { - return new CauchyDistribution(getRan(), median, scale, + return new CauchyDistribution(getRandomGenerator(), median, scale, CauchyDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample(); } @@ -601,7 +601,7 @@ public class RandomDataGenerator impleme * @return random value sampled from the ChiSquare(df) distribution */ public double 
nextChiSquare(double df) { - return new ChiSquaredDistribution(getRan(), df, + return new ChiSquaredDistribution(getRandomGenerator(), df, ChiSquaredDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample(); } @@ -615,7 +615,7 @@ public class RandomDataGenerator impleme * {@code numeratorDf <= 0} or {@code denominatorDf <= 0}. */ public double nextF(double numeratorDf, double denominatorDf) throws NotStrictlyPositiveException { - return new FDistribution(getRan(), numeratorDf, denominatorDf, + return new FDistribution(getRandomGenerator(), numeratorDf, denominatorDf, FDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY).sample(); } @@ -671,7 +671,7 @@ public class RandomDataGenerator impleme throw new NotANumberException(); } - final RandomGenerator generator = getRan(); + final RandomGenerator generator = getRandomGenerator(); // ensure nextDouble() isn't 0.0 double u = generator.nextDouble(); @@ -758,7 +758,7 @@ public class RandomDataGenerator impleme * @param seed the seed value to use */ public void reSeed(long seed) { - getRan().setSeed(seed); + getRandomGenerator().setSeed(seed); } /** @@ -789,7 +789,7 @@ public class RandomDataGenerator impleme * {@code System.currentTimeMillis() + System.identityHashCode(this))}. 
*/ public void reSeed() { - getRan().setSeed(System.currentTimeMillis() + System.identityHashCode(this)); + getRandomGenerator().setSeed(System.currentTimeMillis() + System.identityHashCode(this)); } /** @@ -823,7 +823,7 @@ public class RandomDataGenerator impleme * * @return the Random used to generate random data */ - private RandomGenerator getRan() { + public RandomGenerator getRandomGenerator() { if (rand == null) { initRan(); } Modified: commons/proper/math/trunk/src/test/java/org/apache/commons/math3/random/EmpiricalDistributionTest.java URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math3/random/EmpiricalDistributionTest.java?rev=1457372&r1=1457371&r2=1457372&view=diff ============================================================================== --- commons/proper/math/trunk/src/test/java/org/apache/commons/math3/random/EmpiricalDistributionTest.java (original) +++ commons/proper/math/trunk/src/test/java/org/apache/commons/math3/random/EmpiricalDistributionTest.java Sun Mar 17 04:28:04 2013 @@ -22,15 +22,19 @@ import java.io.IOException; import java.io.InputStreamReader; import java.net.URL; import java.util.ArrayList; +import java.util.Arrays; import org.apache.commons.math3.TestUtils; import org.apache.commons.math3.analysis.UnivariateFunction; import org.apache.commons.math3.analysis.integration.BaseAbstractUnivariateIntegrator; import org.apache.commons.math3.analysis.integration.IterativeLegendreGaussIntegrator; +import org.apache.commons.math3.distribution.AbstractRealDistribution; import org.apache.commons.math3.distribution.NormalDistribution; import org.apache.commons.math3.distribution.RealDistribution; import org.apache.commons.math3.distribution.RealDistributionAbstractTest; +import org.apache.commons.math3.distribution.UniformRealDistribution; import org.apache.commons.math3.exception.NullArgumentException; +import org.apache.commons.math3.exception.OutOfRangeException; import 
org.apache.commons.math3.stat.descriptive.SummaryStatistics; import org.junit.Assert; import org.junit.Before; @@ -428,4 +432,152 @@ public final class EmpiricalDistribution return new NormalDistribution((upper + lower + 1) / 2d, 3.0276503540974917); } } + + @Test + public void testKernelOverrideConstant() { + final EmpiricalDistribution dist = new ConstantKernelEmpiricalDistribution(5); + final double[] data = {1d,2d,3d, 4d,5d,6d, 7d,8d,9d, 10d,11d,12d, 13d,14d,15d}; + dist.load(data); + // Bin masses concentrated on 2, 5, 8, 11, 14 <- effectively discrete uniform distribution over these + double[] values = {2d, 5d, 8d, 11d, 14d}; + for (int i = 0; i < 20; i++) { + Assert.assertTrue(Arrays.binarySearch(values, dist.sample()) >= 0); + } + final double tol = 10E-12; + Assert.assertEquals(0.0, dist.cumulativeProbability(1), tol); + Assert.assertEquals(0.2, dist.cumulativeProbability(2), tol); + Assert.assertEquals(0.6, dist.cumulativeProbability(10), tol); + Assert.assertEquals(0.8, dist.cumulativeProbability(12), tol); + Assert.assertEquals(0.8, dist.cumulativeProbability(13), tol); + Assert.assertEquals(1.0, dist.cumulativeProbability(15), tol); + + Assert.assertEquals(2.0, dist.inverseCumulativeProbability(0.1), tol); + Assert.assertEquals(2.0, dist.inverseCumulativeProbability(0.2), tol); + Assert.assertEquals(5.0, dist.inverseCumulativeProbability(0.3), tol); + Assert.assertEquals(5.0, dist.inverseCumulativeProbability(0.4), tol); + Assert.assertEquals(8.0, dist.inverseCumulativeProbability(0.5), tol); + Assert.assertEquals(8.0, dist.inverseCumulativeProbability(0.6), tol); + } + + @Test + public void testKernelOverrideUniform() { + final EmpiricalDistribution dist = new UniformKernelEmpiricalDistribution(5); + final double[] data = {1d,2d,3d, 4d,5d,6d, 7d,8d,9d, 10d,11d,12d, 13d,14d,15d}; + dist.load(data); + // Kernels are uniform distributions on [1,3], [4,6], [7,9], [10,12], [13,15] + final double bounds[] = {3d, 6d, 9d, 12d}; + final double tol = 10E-12; + 
for (int i = 0; i < 20; i++) { + final double v = dist.sample(); + // Make sure v is not in the excluded range between bins - that is (bounds[i], bounds[i] + 1) + for (int j = 0; j < bounds.length; j++) { + Assert.assertFalse(v > bounds[j] + tol && v < bounds[j] + 1 - tol); + } + } + Assert.assertEquals(0.0, dist.cumulativeProbability(1), tol); + Assert.assertEquals(0.1, dist.cumulativeProbability(2), tol); + Assert.assertEquals(0.6, dist.cumulativeProbability(10), tol); + Assert.assertEquals(0.8, dist.cumulativeProbability(12), tol); + Assert.assertEquals(0.8, dist.cumulativeProbability(13), tol); + Assert.assertEquals(1.0, dist.cumulativeProbability(15), tol); + + Assert.assertEquals(2.0, dist.inverseCumulativeProbability(0.1), tol); + Assert.assertEquals(3.0, dist.inverseCumulativeProbability(0.2), tol); + Assert.assertEquals(5.0, dist.inverseCumulativeProbability(0.3), tol); + Assert.assertEquals(6.0, dist.inverseCumulativeProbability(0.4), tol); + Assert.assertEquals(8.0, dist.inverseCumulativeProbability(0.5), tol); + Assert.assertEquals(9.0, dist.inverseCumulativeProbability(0.6), tol); + } + + + /** + * Empirical distribution using a constant smoothing kernel. + */ + private class ConstantKernelEmpiricalDistribution extends EmpiricalDistribution { + private static final long serialVersionUID = 1L; + public ConstantKernelEmpiricalDistribution(int i) { + super(i); + } + // Use constant distribution equal to bin mean within bin + protected RealDistribution getKernel(SummaryStatistics bStats) { + return new ConstantDistribution(bStats.getMean()); + } + } + + /** + * Empirical distribution using a uniform smoothing kernel. 
+ */ + private class UniformKernelEmpiricalDistribution extends EmpiricalDistribution { + public UniformKernelEmpiricalDistribution(int i) { + super(i); + } + protected RealDistribution getKernel(SummaryStatistics bStats) { + return new UniformRealDistribution(randomData.getRandomGenerator(), bStats.getMin(), bStats.getMax(), + UniformRealDistribution.DEFAULT_INVERSE_ABSOLUTE_ACCURACY); + } + } + + /** + * Distribution that takes just one value. + */ + private class ConstantDistribution extends AbstractRealDistribution { + private static final long serialVersionUID = 1L; + + /** Singleton value in the sample space */ + private final double c; + + public ConstantDistribution(double c) { + this.c = c; + } + + public double density(double x) { + return 0; + } + + public double cumulativeProbability(double x) { + return x < c ? 0 : 1; + } + + @Override + public double inverseCumulativeProbability(double p) { + if (p < 0.0 || p > 1.0) { + throw new OutOfRangeException(p, 0, 1); + } + return c; + } + + public double getNumericalMean() { + return c; + } + + public double getNumericalVariance() { + return 0; + } + + public double getSupportLowerBound() { + return c; + } + + public double getSupportUpperBound() { + return c; + } + + public boolean isSupportLowerBoundInclusive() { + return false; + } + + public boolean isSupportUpperBoundInclusive() { + return true; + } + + public boolean isSupportConnected() { + return true; + } + + @Override + public double sample() { + return c; + } + + } }