Repository: commons-math Updated Branches: refs/heads/feature-MATH-1158 880b04814 -> 3066a8085
MATH-1351 New sampling API for multivariate distributions (similar to changes performed for MATH-1158). Unit test file renamed in accordance with the class being tested. One failing test "@Ignore"d (see comments on the bug-tracking system). Project: http://git-wip-us.apache.org/repos/asf/commons-math/repo Commit: http://git-wip-us.apache.org/repos/asf/commons-math/commit/3066a808 Tree: http://git-wip-us.apache.org/repos/asf/commons-math/tree/3066a808 Diff: http://git-wip-us.apache.org/repos/asf/commons-math/diff/3066a808 Branch: refs/heads/feature-MATH-1158 Commit: 3066a8085f86b743da14a161427c403a7038e8b0 Parents: 880b048 Author: Gilles <er...@apache.org> Authored: Mon Mar 28 13:45:42 2016 +0200 Committer: Gilles <er...@apache.org> Committed: Mon Mar 28 13:45:42 2016 +0200 ---------------------------------------------------------------------- .../AbstractMultivariateRealDistribution.java | 44 ++- .../MixtureMultivariateNormalDistribution.java | 60 ++-- .../MixtureMultivariateRealDistribution.java | 124 ++++---- .../MultivariateNormalDistribution.java | 73 ++--- .../MultivariateRealDistribution.java | 37 ++- ...xtureMultivariateNormalDistributionTest.java | 268 +++++++++++++++++ .../MultivariateNormalDistributionTest.java | 6 +- ...riateNormalMixtureModelDistributionTest.java | 300 ------------------- 8 files changed, 413 insertions(+), 499 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/commons-math/blob/3066a808/src/main/java/org/apache/commons/math4/distribution/AbstractMultivariateRealDistribution.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math4/distribution/AbstractMultivariateRealDistribution.java b/src/main/java/org/apache/commons/math4/distribution/AbstractMultivariateRealDistribution.java index 93e4b7b..1c4adef 100644 --- 
a/src/main/java/org/apache/commons/math4/distribution/AbstractMultivariateRealDistribution.java +++ b/src/main/java/org/apache/commons/math4/distribution/AbstractMultivariateRealDistribution.java @@ -18,7 +18,7 @@ package org.apache.commons.math4.distribution; import org.apache.commons.math4.exception.NotStrictlyPositiveException; import org.apache.commons.math4.exception.util.LocalizedFormats; -import org.apache.commons.math4.random.RandomGenerator; +import org.apache.commons.math4.rng.UniformRandomProvider; /** * Base class for multivariate probability distributions. @@ -27,48 +27,46 @@ import org.apache.commons.math4.random.RandomGenerator; */ public abstract class AbstractMultivariateRealDistribution implements MultivariateRealDistribution { - /** RNG instance used to generate samples from the distribution. */ - protected final RandomGenerator random; /** The number of dimensions or columns in the multivariate distribution. */ private final int dimension; /** - * @param rng Random number generator. * @param n Number of dimensions. */ - protected AbstractMultivariateRealDistribution(RandomGenerator rng, - int n) { - random = rng; + protected AbstractMultivariateRealDistribution(int n) { dimension = n; } /** {@inheritDoc} */ @Override - public void reseedRandomGenerator(long seed) { - random.setSeed(seed); - } - - /** {@inheritDoc} */ - @Override public int getDimension() { return dimension; } /** {@inheritDoc} */ @Override - public abstract double[] sample(); + public abstract Sampler createSampler(UniformRandomProvider rng); - /** {@inheritDoc} */ - @Override - public double[][] sample(final int sampleSize) { - if (sampleSize <= 0) { + /** + * Utility function for creating {@code n} vectors generated by the + * given {@code sampler}. + * + * @param n Number of samples. + * @param sampler Sampler. + * @return an array of size {@code n} whose elements are random vectors + * sampled from this distribution. 
+ */ + public static double[][] sample(int n, + MultivariateRealDistribution.Sampler sampler) { + if (n <= 0) { throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, - sampleSize); + n); } - final double[][] out = new double[sampleSize][dimension]; - for (int i = 0; i < sampleSize; i++) { - out[i] = sample(); + + final double[][] samples = new double[n][]; + for (int i = 0; i < n; i++) { + samples[i] = sampler.sample(); } - return out; + return samples; } } http://git-wip-us.apache.org/repos/asf/commons-math/blob/3066a808/src/main/java/org/apache/commons/math4/distribution/MixtureMultivariateNormalDistribution.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math4/distribution/MixtureMultivariateNormalDistribution.java b/src/main/java/org/apache/commons/math4/distribution/MixtureMultivariateNormalDistribution.java index d7cd4cd..e24a2ac 100644 --- a/src/main/java/org/apache/commons/math4/distribution/MixtureMultivariateNormalDistribution.java +++ b/src/main/java/org/apache/commons/math4/distribution/MixtureMultivariateNormalDistribution.java @@ -21,7 +21,6 @@ import java.util.List; import org.apache.commons.math4.exception.DimensionMismatchException; import org.apache.commons.math4.exception.NotPositiveException; -import org.apache.commons.math4.random.RandomGenerator; import org.apache.commons.math4.util.Pair; /** @@ -33,63 +32,42 @@ import org.apache.commons.math4.util.Pair; */ public class MixtureMultivariateNormalDistribution extends MixtureMultivariateRealDistribution<MultivariateNormalDistribution> { - - /** - * Creates a multivariate normal mixture distribution. - * <p> - * <b>Note:</b> this constructor will implicitly create an instance of - * {@link org.apache.commons.math4.random.Well19937c Well19937c} as random - * generator to be used for sampling only (see {@link #sample()} and - * {@link #sample(int)}). 
In case no sampling is needed for the created - * distribution, it is advised to pass {@code null} as random generator via - * the appropriate constructors to avoid the additional initialisation - * overhead. - * - * @param weights Weights of each component. - * @param means Mean vector for each component. - * @param covariances Covariance matrix for each component. - */ - public MixtureMultivariateNormalDistribution(double[] weights, - double[][] means, - double[][][] covariances) { - super(createComponents(weights, means, covariances)); - } - /** * Creates a mixture model from a list of distributions and their * associated weights. - * <p> - * <b>Note:</b> this constructor will implicitly create an instance of - * {@link org.apache.commons.math4.random.Well19937c Well19937c} as random - * generator to be used for sampling only (see {@link #sample()} and - * {@link #sample(int)}). In case no sampling is needed for the created - * distribution, it is advised to pass {@code null} as random generator via - * the appropriate constructors to avoid the additional initialisation - * overhead. * - * @param components List of (weight, distribution) pairs from which to sample. + * @param components Distributions from which to sample. + * @throws NotPositiveException if any of the weights is negative. + * @throws DimensionMismatchException if not all components have the same + * number of variables. */ - public MixtureMultivariateNormalDistribution(List<Pair<Double, MultivariateNormalDistribution>> components) { + public MixtureMultivariateNormalDistribution(List<Pair<Double, MultivariateNormalDistribution>> components) + throws NotPositiveException, + DimensionMismatchException { super(components); } /** - * Creates a mixture model from a list of distributions and their - * associated weights. + * Creates a multivariate normal mixture distribution. * - * @param rng Random number generator. - * @param components Distributions from which to sample. 
+ * @param weights Weights of each component. + * @param means Mean vector for each component. + * @param covariances Covariance matrix for each component. * @throws NotPositiveException if any of the weights is negative. * @throws DimensionMismatchException if not all components have the same * number of variables. */ - public MixtureMultivariateNormalDistribution(RandomGenerator rng, - List<Pair<Double, MultivariateNormalDistribution>> components) - throws NotPositiveException, DimensionMismatchException { - super(rng, components); + public MixtureMultivariateNormalDistribution(double[] weights, + double[][] means, + double[][][] covariances) + throws NotPositiveException, + DimensionMismatchException { + this(createComponents(weights, means, covariances)); } /** + * Creates components of the mixture model. + * * @param weights Weights of each component. * @param means Mean vector for each component. * @param covariances Covariance matrix for each component. http://git-wip-us.apache.org/repos/asf/commons-math/blob/3066a808/src/main/java/org/apache/commons/math4/distribution/MixtureMultivariateRealDistribution.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math4/distribution/MixtureMultivariateRealDistribution.java b/src/main/java/org/apache/commons/math4/distribution/MixtureMultivariateRealDistribution.java index ce8c7d9..4caee3f 100644 --- a/src/main/java/org/apache/commons/math4/distribution/MixtureMultivariateRealDistribution.java +++ b/src/main/java/org/apache/commons/math4/distribution/MixtureMultivariateRealDistribution.java @@ -23,8 +23,7 @@ import org.apache.commons.math4.exception.DimensionMismatchException; import org.apache.commons.math4.exception.MathArithmeticException; import org.apache.commons.math4.exception.NotPositiveException; import org.apache.commons.math4.exception.util.LocalizedFormats; -import org.apache.commons.math4.random.RandomGenerator; -import 
org.apache.commons.math4.random.Well19937c; +import org.apache.commons.math4.rng.UniformRandomProvider; import org.apache.commons.math4.util.Pair; /** @@ -45,33 +44,14 @@ public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistr /** * Creates a mixture model from a list of distributions and their * associated weights. - * <p> - * <b>Note:</b> this constructor will implicitly create an instance of - * {@link Well19937c} as random generator to be used for sampling only (see - * {@link #sample()} and {@link #sample(int)}). In case no sampling is - * needed for the created distribution, it is advised to pass {@code null} - * as random generator via the appropriate constructors to avoid the - * additional initialisation overhead. * - * @param components List of (weight, distribution) pairs from which to sample. - */ - public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) { - this(new Well19937c(), components); - } - - /** - * Creates a mixture model from a list of distributions and their - * associated weights. - * - * @param rng Random number generator. * @param components Distributions from which to sample. * @throws NotPositiveException if any of the weights is negative. * @throws DimensionMismatchException if not all components have the same * number of variables. */ - public MixtureMultivariateRealDistribution(RandomGenerator rng, - List<Pair<Double, T>> components) { - super(rng, components.get(0).getSecond().getDimension()); + public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) { + super(components.get(0).getSecond().getDimension()); final int numComp = components.size(); final int dim = getDimension(); @@ -112,61 +92,75 @@ public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistr return p; } - /** {@inheritDoc} */ - @Override - public double[] sample() { - // Sampled values. - double[] vals = null; - - // Determine which component to sample from. 
- final double randomValue = random.nextDouble(); - double sum = 0; + /** + * Gets the distributions that make up the mixture model. + * + * @return the component distributions and associated weights. + */ + public List<Pair<Double, T>> getComponents() { + final List<Pair<Double, T>> list = new ArrayList<Pair<Double, T>>(weight.length); for (int i = 0; i < weight.length; i++) { - sum += weight[i]; - if (randomValue <= sum) { - // pick model i - vals = distribution.get(i).sample(); - break; - } - } - - if (vals == null) { - // This should never happen, but it ensures we won't return a null in - // case the loop above has some floating point inequality problem on - // the final iteration. - vals = distribution.get(weight.length - 1).sample(); + list.add(new Pair<Double, T>(weight[i], distribution.get(i))); } - return vals; + return list; } /** {@inheritDoc} */ @Override - public void reseedRandomGenerator(long seed) { - // Seed needs to be propagated to underlying components - // in order to maintain consistency between runs. - super.reseedRandomGenerator(seed); - - for (int i = 0; i < distribution.size(); i++) { - // Make each component's seed different in order to avoid - // using the same sequence of random numbers. - distribution.get(i).reseedRandomGenerator(i + 1 + seed); - } + public MultivariateRealDistribution.Sampler createSampler(UniformRandomProvider rng) { + return new MixtureSampler(rng); } /** - * Gets the distributions that make up the mixture model. - * - * @return the component distributions and associated weights. + * Sampler. 
*/ - public List<Pair<Double, T>> getComponents() { - final List<Pair<Double, T>> list = new ArrayList<Pair<Double, T>>(weight.length); - - for (int i = 0; i < weight.length; i++) { - list.add(new Pair<Double, T>(weight[i], distribution.get(i))); + private class MixtureSampler implements MultivariateRealDistribution.Sampler { + /** RNG */ + private final UniformRandomProvider rng; + /** Sampler for each of the distributions in the mixture. */ + private final MultivariateRealDistribution.Sampler[] samplers; + + /** + * @param generator RNG. + */ + MixtureSampler(UniformRandomProvider generator) { + rng = generator; + + samplers = new MultivariateRealDistribution.Sampler[weight.length]; + for (int i = 0; i < weight.length; i++) { + samplers[i] = distribution.get(i).createSampler(rng); + } } - return list; + /** {@inheritDoc} */ + @Override + public double[] sample() { + // Sampled values. + double[] vals = null; + + // Determine which component to sample from. + final double randomValue = rng.nextDouble(); + double sum = 0; + + for (int i = 0; i < weight.length; i++) { + sum += weight[i]; + if (randomValue <= sum) { + // pick model i + vals = samplers[i].sample(); + break; + } + } + + if (vals == null) { + // This should never happen, but it ensures we won't return a null in + // case the loop above has some floating point inequality problem on + // the final iteration. 
+ vals = samplers[weight.length - 1].sample(); + } + + return vals; + } } } http://git-wip-us.apache.org/repos/asf/commons-math/blob/3066a808/src/main/java/org/apache/commons/math4/distribution/MultivariateNormalDistribution.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math4/distribution/MultivariateNormalDistribution.java b/src/main/java/org/apache/commons/math4/distribution/MultivariateNormalDistribution.java index 212fb2a..da270ad 100644 --- a/src/main/java/org/apache/commons/math4/distribution/MultivariateNormalDistribution.java +++ b/src/main/java/org/apache/commons/math4/distribution/MultivariateNormalDistribution.java @@ -22,8 +22,7 @@ import org.apache.commons.math4.linear.EigenDecomposition; import org.apache.commons.math4.linear.NonPositiveDefiniteMatrixException; import org.apache.commons.math4.linear.RealMatrix; import org.apache.commons.math4.linear.SingularMatrixException; -import org.apache.commons.math4.random.RandomGenerator; -import org.apache.commons.math4.random.Well19937c; +import org.apache.commons.math4.rng.UniformRandomProvider; import org.apache.commons.math4.util.FastMath; import org.apache.commons.math4.util.MathArrays; @@ -53,44 +52,12 @@ public class MultivariateNormalDistribution /** * Creates a multivariate normal distribution with the given mean vector and * covariance matrix. - * <br/> - * The number of dimensions is equal to the length of the mean vector - * and to the number of rows and columns of the covariance matrix. - * It is frequently written as "p" in formulae. * <p> - * <b>Note:</b> this constructor will implicitly create an instance of - * {@link Well19937c} as random generator to be used for sampling only (see - * {@link #sample()} and {@link #sample(int)}). 
In case no sampling is - * needed for the created distribution, it is advised to pass {@code null} - * as random generator via the appropriate constructors to avoid the - * additional initialisation overhead. - * - * @param means Vector of means. - * @param covariances Covariance matrix. - * @throws DimensionMismatchException if the arrays length are - * inconsistent. - * @throws SingularMatrixException if the eigenvalue decomposition cannot - * be performed on the provided covariance matrix. - * @throws NonPositiveDefiniteMatrixException if any of the eigenvalues is - * negative. - */ - public MultivariateNormalDistribution(final double[] means, - final double[][] covariances) - throws SingularMatrixException, - DimensionMismatchException, - NonPositiveDefiniteMatrixException { - this(new Well19937c(), means, covariances); - } - - /** - * Creates a multivariate normal distribution with the given mean vector and - * covariance matrix. - * <br/> * The number of dimensions is equal to the length of the mean vector * and to the number of rows and columns of the covariance matrix. * It is frequently written as "p" in formulae. + * </p> * - * @param rng Random Number Generator. * @param means Vector of means. * @param covariances Covariance matrix. * @throws DimensionMismatchException if the arrays length are @@ -100,13 +67,12 @@ public class MultivariateNormalDistribution * @throws NonPositiveDefiniteMatrixException if any of the eigenvalues is * negative. 
*/ - public MultivariateNormalDistribution(RandomGenerator rng, - final double[] means, + public MultivariateNormalDistribution(final double[] means, final double[][] covariances) throws SingularMatrixException, DimensionMismatchException, NonPositiveDefiniteMatrixException { - super(rng, means.length); + super(means.length); final int dim = means.length; @@ -210,21 +176,30 @@ public class MultivariateNormalDistribution /** {@inheritDoc} */ @Override - public double[] sample() { - final int dim = getDimension(); - final double[] normalVals = new double[dim]; + public MultivariateRealDistribution.Sampler createSampler(final UniformRandomProvider rng) { + return new MultivariateRealDistribution.Sampler() { + /** Normal distribution. */ + private final RealDistribution.Sampler gauss = new NormalDistribution().createSampler(rng); - for (int i = 0; i < dim; i++) { - normalVals[i] = random.nextGaussian(); - } + /** {@inheritDoc} */ + @Override + public double[] sample() { + final int dim = getDimension(); + final double[] normalVals = new double[dim]; - final double[] vals = samplingMatrix.operate(normalVals); + for (int i = 0; i < dim; i++) { + normalVals[i] = gauss.sample(); + } - for (int i = 0; i < dim; i++) { - vals[i] += means[i]; - } + final double[] vals = samplingMatrix.operate(normalVals); + + for (int i = 0; i < dim; i++) { + vals[i] += means[i]; + } - return vals; + return vals; + } + }; } /** http://git-wip-us.apache.org/repos/asf/commons-math/blob/3066a808/src/main/java/org/apache/commons/math4/distribution/MultivariateRealDistribution.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/commons/math4/distribution/MultivariateRealDistribution.java b/src/main/java/org/apache/commons/math4/distribution/MultivariateRealDistribution.java index d734d96..eaaf35e 100644 --- a/src/main/java/org/apache/commons/math4/distribution/MultivariateRealDistribution.java +++ 
b/src/main/java/org/apache/commons/math4/distribution/MultivariateRealDistribution.java @@ -16,7 +16,7 @@ */ package org.apache.commons.math4.distribution; -import org.apache.commons.math4.exception.NotStrictlyPositiveException; +import org.apache.commons.math4.rng.UniformRandomProvider; /** * Base interface for multivariate distributions on the reals. @@ -42,13 +42,6 @@ public interface MultivariateRealDistribution { double density(double[] x); /** - * Reseeds the random generator used to generate samples. - * - * @param seed Seed with which to initialize the random number generator. - */ - void reseedRandomGenerator(long seed); - - /** * Gets the number of random variables of the distribution. * It is the size of the array returned by the {@link #sample() sample} * method. @@ -58,21 +51,27 @@ public interface MultivariateRealDistribution { int getDimension(); /** - * Generates a random value vector sampled from this distribution. + * Creates a sampler. + * + * @param rng Generator of uniformly distributed numbers. + * @return a sampler that produces random numbers according to this + * distribution. * - * @return a random value vector. + * @since 4.0 */ - double[] sample(); + Sampler createSampler(UniformRandomProvider rng); /** - * Generates a list of a random value vectors from the distribution. - * - * @param sampleSize the number of random vectors to generate. - * @return an array representing the random samples. - * @throws org.apache.commons.math4.exception.NotStrictlyPositiveException - * if {@code sampleSize} is not positive. + * Sampling functionality. * - * @see #sample() + * @since 4.0 */ - double[][] sample(int sampleSize) throws NotStrictlyPositiveException; + interface Sampler { + /** + * Generates a random value vector sampled from this distribution. + * + * @return a random value vector. 
+ */ + double[] sample(); + } } http://git-wip-us.apache.org/repos/asf/commons-math/blob/3066a808/src/test/java/org/apache/commons/math4/distribution/MixtureMultivariateNormalDistributionTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math4/distribution/MixtureMultivariateNormalDistributionTest.java b/src/test/java/org/apache/commons/math4/distribution/MixtureMultivariateNormalDistributionTest.java new file mode 100644 index 0000000..c4d3a8f --- /dev/null +++ b/src/test/java/org/apache/commons/math4/distribution/MixtureMultivariateNormalDistributionTest.java @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.commons.math4.distribution; + +import java.util.List; +import java.util.ArrayList; + +import org.apache.commons.math4.distribution.MixtureMultivariateRealDistribution; +import org.apache.commons.math4.distribution.MultivariateNormalDistribution; +import org.apache.commons.math4.exception.MathArithmeticException; +import org.apache.commons.math4.exception.NotPositiveException; +import org.apache.commons.math4.rng.RandomSource; +import org.apache.commons.math4.util.Pair; +import org.junit.Assert; +import org.junit.Test; +import org.junit.Ignore; + +/** + * Test case {@link MixtureMultivariateNormalDistribution}. + */ +public class MixtureMultivariateNormalDistributionTest { + + @Test + public void testNonUnitWeightSum() { + final double[] weights = { 1, 2 }; + final double[][] means = { { -1.5, 2.0 }, + { 4.0, 8.2 } }; + final double[][][] covariances = { { { 2.0, -1.1 }, + { -1.1, 2.0 } }, + { { 3.5, 1.5 }, + { 1.5, 3.5 } } }; + final MixtureMultivariateNormalDistribution d + = new MixtureMultivariateNormalDistribution(weights, means, covariances); + + final List<Pair<Double, MultivariateNormalDistribution>> comp = d.getComponents(); + + Assert.assertEquals(1d / 3, comp.get(0).getFirst().doubleValue(), Math.ulp(1d)); + Assert.assertEquals(2d / 3, comp.get(1).getFirst().doubleValue(), Math.ulp(1d)); + } + + @Test(expected=MathArithmeticException.class) + public void testWeightSumOverFlow() { + final double[] weights = { 0.5 * Double.MAX_VALUE, 0.51 * Double.MAX_VALUE }; + final double[][] means = { { -1.5, 2.0 }, + { 4.0, 8.2 } }; + final double[][][] covariances = { { { 2.0, -1.1 }, + { -1.1, 2.0 } }, + { { 3.5, 1.5 }, + { 1.5, 3.5 } } }; + new MixtureMultivariateNormalDistribution(weights, means, covariances); + } + + @Test(expected=NotPositiveException.class) + public void testPreconditionPositiveWeights() { + final double[] negativeWeights = { -0.5, 1.5 }; + final double[][] means = { { -1.5, 2.0 }, + { 4.0, 8.2 } }; + final double[][][] 
covariances = { { { 2.0, -1.1 }, + { -1.1, 2.0 } }, + { { 3.5, 1.5 }, + { 1.5, 3.5 } } }; + new MixtureMultivariateNormalDistribution(negativeWeights, means, covariances); + } + + /** + * Test the accuracy of the density calculation. + */ + @Test + public void testDensities() { + final double[] weights = { 0.3, 0.7 }; + final double[][] means = { { -1.5, 2.0 }, + { 4.0, 8.2 } }; + final double[][][] covariances = { { { 2.0, -1.1 }, + { -1.1, 2.0 } }, + { { 3.5, 1.5 }, + { 1.5, 3.5 } } }; + final MixtureMultivariateNormalDistribution d + = new MixtureMultivariateNormalDistribution(weights, means, covariances); + + // Test vectors + final double[][] testValues = { { -1.5, 2 }, + { 4, 8.2 }, + { 1.5, -2 }, + { 0, 0 } }; + + // Densities that we should get back. + // Calculated by assigning weights to multivariate normal distribution + // and summing + // values from dmvnorm function in R 2.15 CRAN package Mixtools v0.4. + // Like: .3*dmvnorm(val,mu1,sigma1)+.7*dmvnorm(val,mu2,sigma2) + final double[] correctDensities = { 0.02862037278930575, + 0.03523044847314091, + 0.000416241365629767, + 0.009932042831700297 }; + + for (int i = 0; i < testValues.length; i++) { + Assert.assertEquals(correctDensities[i], d.density(testValues[i]), Math.ulp(1d)); + } + } + + /** + * Test the accuracy of sampling from the distribution. 
+ */ + @Ignore@Test + public void testSampling() { + final double[] weights = { 0.3, 0.7 }; + final double[][] means = { { -1.5, 2.0 }, + { 4.0, 8.2 } }; + final double[][][] covariances = { { { 2.0, -1.1 }, + { -1.1, 2.0 } }, + { { 3.5, 1.5 }, + { 1.5, 3.5 } } }; + final MixtureMultivariateNormalDistribution d = + new MixtureMultivariateNormalDistribution(weights, means, covariances); + final MultivariateRealDistribution.Sampler sampler = + d.createSampler(RandomSource.create(RandomSource.WELL_19937_C, 50)); + + final double[][] correctSamples = getCorrectSamples(); + final int n = correctSamples.length; + final double[][] samples = AbstractMultivariateRealDistribution.sample(n, sampler); + + for (int i = 0; i < n; i++) { + for (int j = 0; j < samples[i].length; j++) { + Assert.assertEquals("sample[" + j + "]", + correctSamples[i][j], samples[i][j], 1e-16); + } + } + } + + /** + * Values used in {@link #testSampling()}. + */ + private double[][] getCorrectSamples() { + // These were sampled from the MultivariateNormalMixtureModelDistribution class + // with seed 50. + // + // They were then fit to a MVN mixture model in R using mixtools. + // + // The optimal parameters were: + // - component weights: {0.3595186, 0.6404814} + // - mean vectors: {-1.645879, 1.989797}, {3.474328, 7.782232} + // - covariance matrices: + // { 1.397738 -1.167732 + // -1.167732 1.801782 } + // and + // { 3.934593 2.354787 + // 2.354787 4.428024 } + // + // It is considered fairly close to the actual test parameters, + // considering that the sample size is only 100. 
+ return new double[][] { + { 6.259990922080121, 11.972954175355897 }, + { -2.5296544304801847, 1.0031292519854365 }, + { 0.49037886081440396, 0.9758251727325711 }, + { 5.022970993312015, 9.289348879616787 }, + { -1.686183146603914, 2.007244382745706 }, + { -1.4729253946002685, 2.762166644212484 }, + { 4.329788143963888, 11.514016497132253 }, + { 3.008674596114442, 4.960246550446107 }, + { 3.342379304090846, 5.937630105198625 }, + { 2.6993068328674754, 7.42190871572571 }, + { -2.446569340219571, 1.9687117791378763 }, + { 1.922417883170056, 4.917616702617099 }, + { -1.1969741543898518, 2.4576126277884387 }, + { 2.4216948702967196, 8.227710158117134 }, + { 6.701424725804463, 9.098666475042428 }, + { 2.9890253545698964, 9.643807939324331 }, + { 0.7162632354907799, 8.978811120287553 }, + { -2.7548699149775877, 4.1354812280794215 }, + { 8.304528180745018, 11.602319388898287 }, + { -2.7633253389165926, 2.786173883989795 }, + { 1.3322228389460813, 5.447481218602913 }, + { -1.8120096092851508, 1.605624499560037 }, + { 3.6546253437206504, 8.195304526564376 }, + { -2.312349539658588, 1.868941220444169 }, + { -1.882322136356522, 2.033795570464242 }, + { 4.562770714939441, 7.414967958885031 }, + { 4.731882017875329, 8.890676665580747 }, + { 3.492186010427425, 8.9005225241848 }, + { -1.619700190174894, 3.314060142479045 }, + { 3.5466090064003315, 7.75182101001913 }, + { 5.455682472787392, 8.143119287755635 }, + { -2.3859602945473197, 1.8826732217294837 }, + { 3.9095306088680015, 9.258129209626317 }, + { 7.443020189508173, 7.837840713329312 }, + { 2.136004873917428, 6.917636475958297 }, + { -1.7203379410395119, 2.3212878757611524 }, + { 4.618991257611526, 12.095065976419436 }, + { -0.4837044029854387, 0.8255970441255125 }, + { -4.438938966557163, 4.948666297280241 }, + { -0.4539625134045906, 4.700922454655341 }, + { 2.1285488271265356, 8.457941480487563 }, + { 3.4873561871454393, 11.99809827845933 }, + { 4.723049431412658, 7.813095742563365 }, + { 1.1245583037967455, 
5.20587873556688 }, + { 1.3411933634409197, 6.069796875785409 }, + { 4.585119332463686, 7.967669543767418 }, + { 1.3076522817963823, -0.647431033653445 }, + { -1.4449446442803178, 1.9400424267464862 }, + { -2.069794456383682, 3.5824162107496544 }, + { -0.15959481421417276, 1.5466782303315405 }, + { -2.0823081278810136, 3.0914366458581437 }, + { 3.521944615248141, 10.276112932926408 }, + { 1.0164326704884257, 4.342329556442856 }, + { 5.3718868590295275, 8.374761158360922 }, + { 0.3673656866959396, 8.75168581694866 }, + { -2.250268955954753, 1.4610850300996527 }, + { -2.312739727403522, 1.5921126297576362 }, + { 3.138993360831055, 6.7338392374947365 }, + { 2.6978650950790115, 7.941857288979095 }, + { 4.387985088655384, 8.253499976968 }, + { -1.8928961721456705, 0.23631082388724223 }, + { 4.43509029544109, 8.565290285488782 }, + { 4.904728034106502, 5.79936660133754 }, + { -1.7640371853739507, 2.7343727594167433 }, + { 2.4553674733053463, 7.875871017408807 }, + { -2.6478965122565006, 4.465127753193949 }, + { 3.493873671142299, 10.443093773532448 }, + { 1.1321916197409103, 7.127108479263268 }, + { -1.7335075535240392, 2.550629648463023 }, + { -0.9772679734368084, 4.377196298969238 }, + { 3.6388366973980357, 6.947299283206256 }, + { 0.27043799318823325, 6.587978599614367 }, + { 5.356782352010253, 7.388957912116327 }, + { -0.09187745751354681, 0.23612399246659743 }, + { 2.903203580353435, 3.8076727621794415 }, + { 5.297014824937293, 8.650985262326508 }, + { 4.934508602170976, 9.164571423190052 }, + { -1.0004911869654256, 4.797064194444461 }, + { 6.782491700298046, 11.852373338280497 }, + { 2.8983678524536014, 8.303837362117521 }, + { 4.805003269830865, 6.790462904325329 }, + { -0.8815799740744226, 1.3015810062131394 }, + { 5.115138859802104, 6.376895810201089 }, + { 4.301239328205988, 8.60546337560793 }, + { 3.276423626317666, 9.889429652591947 }, + { -4.001924973153122, 4.3353864592328515 }, + { 3.9571892554119517, 4.500569057308562 }, + { 4.783067027436208, 
7.451125480601317 }, + { 4.79065438272821, 9.614122776979698 }, + { 2.677655270279617, 6.8875223698210135 }, + { -1.3714746289327362, 2.3992153193382437 }, + { 3.240136859745249, 7.748339397522042 }, + { 5.107885374416291, 8.508324480583724 }, + { -1.5830830226666048, 0.9139127045208315 }, + { -1.1596156791652918, -0.04502759384531929 }, + { -0.4670021307952068, 3.6193633227841624 }, + { -0.7026065228267798, 0.4811423031997131 }, + { -2.719979836732917, 2.5165041618080104 }, + { 1.0336754331123372, -0.34966029029320644 }, + { 4.743217291882213, 5.750060115251131 } + }; + } +} http://git-wip-us.apache.org/repos/asf/commons-math/blob/3066a808/src/test/java/org/apache/commons/math4/distribution/MultivariateNormalDistributionTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math4/distribution/MultivariateNormalDistributionTest.java b/src/test/java/org/apache/commons/math4/distribution/MultivariateNormalDistributionTest.java index 41d526c..3e6d9ff 100644 --- a/src/test/java/org/apache/commons/math4/distribution/MultivariateNormalDistributionTest.java +++ b/src/test/java/org/apache/commons/math4/distribution/MultivariateNormalDistributionTest.java @@ -20,6 +20,7 @@ package org.apache.commons.math4.distribution; import org.apache.commons.math4.distribution.MultivariateNormalDistribution; import org.apache.commons.math4.distribution.NormalDistribution; import org.apache.commons.math4.linear.RealMatrix; +import org.apache.commons.math4.rng.RandomSource; import org.apache.commons.math4.stat.correlation.Covariance; import java.util.Random; @@ -75,11 +76,12 @@ public class MultivariateNormalDistributionTest { final double[][] sigma = { { 2, -1.1 }, { -1.1, 2 } }; final MultivariateNormalDistribution d = new MultivariateNormalDistribution(mu, sigma); - d.reseedRandomGenerator(50); + final MultivariateRealDistribution.Sampler sampler = + d.createSampler(RandomSource.create(RandomSource.WELL_19937_C, 
50)); final int n = 500000; + final double[][] samples = AbstractMultivariateRealDistribution.sample(n, sampler); - final double[][] samples = d.sample(n); final int dim = d.getDimension(); final double[] sampleMeans = new double[dim]; http://git-wip-us.apache.org/repos/asf/commons-math/blob/3066a808/src/test/java/org/apache/commons/math4/distribution/MultivariateNormalMixtureModelDistributionTest.java ---------------------------------------------------------------------- diff --git a/src/test/java/org/apache/commons/math4/distribution/MultivariateNormalMixtureModelDistributionTest.java b/src/test/java/org/apache/commons/math4/distribution/MultivariateNormalMixtureModelDistributionTest.java deleted file mode 100644 index 8bed770..0000000 --- a/src/test/java/org/apache/commons/math4/distribution/MultivariateNormalMixtureModelDistributionTest.java +++ /dev/null @@ -1,300 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.commons.math4.distribution; - -import java.util.List; -import java.util.ArrayList; - -import org.apache.commons.math4.distribution.MixtureMultivariateRealDistribution; -import org.apache.commons.math4.distribution.MultivariateNormalDistribution; -import org.apache.commons.math4.exception.MathArithmeticException; -import org.apache.commons.math4.exception.NotPositiveException; -import org.apache.commons.math4.util.Pair; -import org.junit.Assert; -import org.junit.Test; - -/** - * Test that demonstrates the use of {@link MixtureMultivariateRealDistribution} - * in order to create a mixture model composed of {@link MultivariateNormalDistribution - * normal distributions}. - */ -public class MultivariateNormalMixtureModelDistributionTest { - - @Test - public void testNonUnitWeightSum() { - final double[] weights = { 1, 2 }; - final double[][] means = { { -1.5, 2.0 }, - { 4.0, 8.2 } }; - final double[][][] covariances = { { { 2.0, -1.1 }, - { -1.1, 2.0 } }, - { { 3.5, 1.5 }, - { 1.5, 3.5 } } }; - final MultivariateNormalMixtureModelDistribution d - = create(weights, means, covariances); - - final List<Pair<Double, MultivariateNormalDistribution>> comp = d.getComponents(); - - Assert.assertEquals(1d / 3, comp.get(0).getFirst().doubleValue(), Math.ulp(1d)); - Assert.assertEquals(2d / 3, comp.get(1).getFirst().doubleValue(), Math.ulp(1d)); - } - - @Test(expected=MathArithmeticException.class) - public void testWeightSumOverFlow() { - final double[] weights = { 0.5 * Double.MAX_VALUE, 0.51 * Double.MAX_VALUE }; - final double[][] means = { { -1.5, 2.0 }, - { 4.0, 8.2 } }; - final double[][][] covariances = { { { 2.0, -1.1 }, - { -1.1, 2.0 } }, - { { 3.5, 1.5 }, - { 1.5, 3.5 } } }; - create(weights, means, covariances); - } - - @Test(expected=NotPositiveException.class) - public void testPreconditionPositiveWeights() { - final double[] negativeWeights = { -0.5, 1.5 }; - final double[][] means = { { -1.5, 2.0 }, - { 4.0, 8.2 } }; - final double[][][] 
covariances = { { { 2.0, -1.1 }, - { -1.1, 2.0 } }, - { { 3.5, 1.5 }, - { 1.5, 3.5 } } }; - create(negativeWeights, means, covariances); - } - - /** - * Test the accuracy of the density calculation. - */ - @Test - public void testDensities() { - final double[] weights = { 0.3, 0.7 }; - final double[][] means = { { -1.5, 2.0 }, - { 4.0, 8.2 } }; - final double[][][] covariances = { { { 2.0, -1.1 }, - { -1.1, 2.0 } }, - { { 3.5, 1.5 }, - { 1.5, 3.5 } } }; - final MultivariateNormalMixtureModelDistribution d - = create(weights, means, covariances); - - // Test vectors - final double[][] testValues = { { -1.5, 2 }, - { 4, 8.2 }, - { 1.5, -2 }, - { 0, 0 } }; - - // Densities that we should get back. - // Calculated by assigning weights to multivariate normal distribution - // and summing - // values from dmvnorm function in R 2.15 CRAN package Mixtools v0.4. - // Like: .3*dmvnorm(val,mu1,sigma1)+.7*dmvnorm(val,mu2,sigma2) - final double[] correctDensities = { 0.02862037278930575, - 0.03523044847314091, - 0.000416241365629767, - 0.009932042831700297 }; - - for (int i = 0; i < testValues.length; i++) { - Assert.assertEquals(correctDensities[i], d.density(testValues[i]), Math.ulp(1d)); - } - } - - /** - * Test the accuracy of sampling from the distribution. 
- */ - @Test - public void testSampling() { - final double[] weights = { 0.3, 0.7 }; - final double[][] means = { { -1.5, 2.0 }, - { 4.0, 8.2 } }; - final double[][][] covariances = { { { 2.0, -1.1 }, - { -1.1, 2.0 } }, - { { 3.5, 1.5 }, - { 1.5, 3.5 } } }; - final MultivariateNormalMixtureModelDistribution d - = create(weights, means, covariances); - d.reseedRandomGenerator(50); - - final double[][] correctSamples = getCorrectSamples(); - final int n = correctSamples.length; - final double[][] samples = d.sample(n); - - for (int i = 0; i < n; i++) { - for (int j = 0; j < samples[i].length; j++) { - Assert.assertEquals(correctSamples[i][j], samples[i][j], 1e-16); - } - } - } - - /** - * Creates a mixture of Gaussian distributions. - * - * @param weights Weights. - * @param means Means. - * @param covariances Covariances. - * @return the mixture distribution. - */ - private MultivariateNormalMixtureModelDistribution create(double[] weights, - double[][] means, - double[][][] covariances) { - final List<Pair<Double, MultivariateNormalDistribution>> mvns - = new ArrayList<Pair<Double, MultivariateNormalDistribution>>(); - - for (int i = 0; i < weights.length; i++) { - final MultivariateNormalDistribution dist - = new MultivariateNormalDistribution(means[i], covariances[i]); - mvns.add(new Pair<Double, MultivariateNormalDistribution>(weights[i], dist)); - } - - return new MultivariateNormalMixtureModelDistribution(mvns); - } - - /** - * Values used in {@link #testSampling()}. - */ - private double[][] getCorrectSamples() { - // These were sampled from the MultivariateNormalMixtureModelDistribution class - // with seed 50. - // - // They were then fit to a MVN mixture model in R using mixtools. 
- // - // The optimal parameters were: - // - component weights: {0.3595186, 0.6404814} - // - mean vectors: {-1.645879, 1.989797}, {3.474328, 7.782232} - // - covariance matrices: - // { 1.397738 -1.167732 - // -1.167732 1.801782 } - // and - // { 3.934593 2.354787 - // 2.354787 4.428024 } - // - // It is considered fairly close to the actual test parameters, - // considering that the sample size is only 100. - return new double[][] { - { 6.259990922080121, 11.972954175355897 }, - { -2.5296544304801847, 1.0031292519854365 }, - { 0.49037886081440396, 0.9758251727325711 }, - { 5.022970993312015, 9.289348879616787 }, - { -1.686183146603914, 2.007244382745706 }, - { -1.4729253946002685, 2.762166644212484 }, - { 4.329788143963888, 11.514016497132253 }, - { 3.008674596114442, 4.960246550446107 }, - { 3.342379304090846, 5.937630105198625 }, - { 2.6993068328674754, 7.42190871572571 }, - { -2.446569340219571, 1.9687117791378763 }, - { 1.922417883170056, 4.917616702617099 }, - { -1.1969741543898518, 2.4576126277884387 }, - { 2.4216948702967196, 8.227710158117134 }, - { 6.701424725804463, 9.098666475042428 }, - { 2.9890253545698964, 9.643807939324331 }, - { 0.7162632354907799, 8.978811120287553 }, - { -2.7548699149775877, 4.1354812280794215 }, - { 8.304528180745018, 11.602319388898287 }, - { -2.7633253389165926, 2.786173883989795 }, - { 1.3322228389460813, 5.447481218602913 }, - { -1.8120096092851508, 1.605624499560037 }, - { 3.6546253437206504, 8.195304526564376 }, - { -2.312349539658588, 1.868941220444169 }, - { -1.882322136356522, 2.033795570464242 }, - { 4.562770714939441, 7.414967958885031 }, - { 4.731882017875329, 8.890676665580747 }, - { 3.492186010427425, 8.9005225241848 }, - { -1.619700190174894, 3.314060142479045 }, - { 3.5466090064003315, 7.75182101001913 }, - { 5.455682472787392, 8.143119287755635 }, - { -2.3859602945473197, 1.8826732217294837 }, - { 3.9095306088680015, 9.258129209626317 }, - { 7.443020189508173, 7.837840713329312 }, - { 2.136004873917428, 
6.917636475958297 }, - { -1.7203379410395119, 2.3212878757611524 }, - { 4.618991257611526, 12.095065976419436 }, - { -0.4837044029854387, 0.8255970441255125 }, - { -4.438938966557163, 4.948666297280241 }, - { -0.4539625134045906, 4.700922454655341 }, - { 2.1285488271265356, 8.457941480487563 }, - { 3.4873561871454393, 11.99809827845933 }, - { 4.723049431412658, 7.813095742563365 }, - { 1.1245583037967455, 5.20587873556688 }, - { 1.3411933634409197, 6.069796875785409 }, - { 4.585119332463686, 7.967669543767418 }, - { 1.3076522817963823, -0.647431033653445 }, - { -1.4449446442803178, 1.9400424267464862 }, - { -2.069794456383682, 3.5824162107496544 }, - { -0.15959481421417276, 1.5466782303315405 }, - { -2.0823081278810136, 3.0914366458581437 }, - { 3.521944615248141, 10.276112932926408 }, - { 1.0164326704884257, 4.342329556442856 }, - { 5.3718868590295275, 8.374761158360922 }, - { 0.3673656866959396, 8.75168581694866 }, - { -2.250268955954753, 1.4610850300996527 }, - { -2.312739727403522, 1.5921126297576362 }, - { 3.138993360831055, 6.7338392374947365 }, - { 2.6978650950790115, 7.941857288979095 }, - { 4.387985088655384, 8.253499976968 }, - { -1.8928961721456705, 0.23631082388724223 }, - { 4.43509029544109, 8.565290285488782 }, - { 4.904728034106502, 5.79936660133754 }, - { -1.7640371853739507, 2.7343727594167433 }, - { 2.4553674733053463, 7.875871017408807 }, - { -2.6478965122565006, 4.465127753193949 }, - { 3.493873671142299, 10.443093773532448 }, - { 1.1321916197409103, 7.127108479263268 }, - { -1.7335075535240392, 2.550629648463023 }, - { -0.9772679734368084, 4.377196298969238 }, - { 3.6388366973980357, 6.947299283206256 }, - { 0.27043799318823325, 6.587978599614367 }, - { 5.356782352010253, 7.388957912116327 }, - { -0.09187745751354681, 0.23612399246659743 }, - { 2.903203580353435, 3.8076727621794415 }, - { 5.297014824937293, 8.650985262326508 }, - { 4.934508602170976, 9.164571423190052 }, - { -1.0004911869654256, 4.797064194444461 }, - { 6.782491700298046, 
11.852373338280497 }, - { 2.8983678524536014, 8.303837362117521 }, - { 4.805003269830865, 6.790462904325329 }, - { -0.8815799740744226, 1.3015810062131394 }, - { 5.115138859802104, 6.376895810201089 }, - { 4.301239328205988, 8.60546337560793 }, - { 3.276423626317666, 9.889429652591947 }, - { -4.001924973153122, 4.3353864592328515 }, - { 3.9571892554119517, 4.500569057308562 }, - { 4.783067027436208, 7.451125480601317 }, - { 4.79065438272821, 9.614122776979698 }, - { 2.677655270279617, 6.8875223698210135 }, - { -1.3714746289327362, 2.3992153193382437 }, - { 3.240136859745249, 7.748339397522042 }, - { 5.107885374416291, 8.508324480583724 }, - { -1.5830830226666048, 0.9139127045208315 }, - { -1.1596156791652918, -0.04502759384531929 }, - { -0.4670021307952068, 3.6193633227841624 }, - { -0.7026065228267798, 0.4811423031997131 }, - { -2.719979836732917, 2.5165041618080104 }, - { 1.0336754331123372, -0.34966029029320644 }, - { 4.743217291882213, 5.750060115251131 } - }; - } -} - -/** - * Class that implements a mixture of Gaussian ditributions. - */ -class MultivariateNormalMixtureModelDistribution - extends MixtureMultivariateRealDistribution<MultivariateNormalDistribution> { - - public MultivariateNormalMixtureModelDistribution(List<Pair<Double, MultivariateNormalDistribution>> components) { - super(components); - } -}