This is an automated email from the ASF dual-hosted git repository.

aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-math.git


The following commit(s) were added to refs/heads/master by this push:
     new bab00341a Allow fitting single component data
bab00341a is described below

commit bab00341a9c4072015dde0d94d301bf60cd16f3b
Author: Alex Herbert <aherb...@apache.org>
AuthorDate: Mon Mar 11 21:55:54 2024 +0000

    Allow fitting single component data
---
 ...ariateNormalMixtureExpectationMaximization.java |   6 +-
 ...teNormalMixtureExpectationMaximizationTest.java | 139 ++++++++++++---------
 2 files changed, 84 insertions(+), 61 deletions(-)

diff --git a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
index 8b51195ab..a3c7397d2 100644
--- a/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
+++ b/commons-math-legacy/src/main/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximization.java
@@ -294,7 +294,7 @@ public class MultivariateNormalMixtureExpectationMaximization {
      * @return Multivariate normal mixture model estimated from the data
      * @throws NumberIsTooLargeException if {@code numComponents} is greater
      * than the number of data rows.
-     * @throws NumberIsTooSmallException if {@code numComponents < 2}.
+     * @throws NumberIsTooSmallException if {@code numComponents < 1}.
      * @throws NotStrictlyPositiveException if data has less than 2 rows
      * @throws DimensionMismatchException if rows of data have different numbers
      *             of columns
@@ -306,8 +306,8 @@ public class MultivariateNormalMixtureExpectationMaximization {
         if (data.length < 2) {
             throw new NotStrictlyPositiveException(data.length);
         }
-        if (numComponents < 2) {
-            throw new NumberIsTooSmallException(numComponents, 2, true);
+        if (numComponents < 1) {
+            throw new NumberIsTooSmallException(numComponents, 1, true);
         }
         if (numComponents > data.length) {
             throw new NumberIsTooLargeException(numComponents, data.length, true);
diff --git a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
index 801cff467..281036456 100644
--- a/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
+++ b/commons-math-legacy/src/test/java/org/apache/commons/math4/legacy/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
@@ -191,104 +191,127 @@ public class MultivariateNormalMixtureExpectationMaximizationTest {
     }
 
     @Test
-    public void testFit() {
-        // Test that the loglikelihood, weights, and models are determined and
-        // fitted correctly
+    public void testFit2Dimensions2Components() {
         final double[][] data = getTestSamples();
-        final double correctLogLikelihood = -4.292431006791994;
-        final double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 };
 
-        final double[][] correctMeans = new double[][]{
-            {-1.4213112715121132, 1.6924690505757753},
-            {4.213612224374709, 7.975621325853645}
-        };
+        // Fit using the test samples using Matlab R2023b (Update 6):
+        // GMModel = fitgmdist(X,2);
 
-        final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
-        correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
-            { 1.739356907285747, -0.5867644251487614 },
-            { -0.5867644251487614, 1.0232932029324642 } }
-                );
-        correctCovMats[1] = new Array2DRowRealMatrix(new double[][] {
-            { 4.245384898007161, 2.5797798966382155 },
-            { 2.5797798966382155, 3.9200272522448367 } });
+        // Expected results use the component order generated by the CM code for convenience
+        // i.e. ComponentProportion from matlab is reversed: [0.703722, 0.296278]
 
-        final MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
-        correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0], correctCovMats[0].getData());
-        correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1], correctCovMats[1].getData());
+        // NegativeLogLikelihood (CM code use the positive log-likehood divided by the number of observations)
+        final double logLikelihood = -4.292430883324220e+02 / data.length;
+        // ComponentProportion
+        final double[] weights = new double[] {0.2962324189652912, 0.7037675810347089};
+        // mu
+        final double[][] means = new double[][]{
+            {-1.421239458366293, 1.692604555824222},
+            {4.213949861591596, 7.975974466776790}
+        };
+        // Sigma
+        final double[][][] covar = new double[][][] {
+            {{1.739441346307267, -0.586740858187563},
+             {-0.586740858187563, 1.023420964341543}},
+            {{4.243780645051973, 2.578176622652551},
+             {2.578176622652551, 3.918302056479298}}
+        };
 
-        MultivariateNormalMixtureExpectationMaximization fitter
-            = new MultivariateNormalMixtureExpectationMaximization(data);
+        assertFit(data, 2, logLikelihood, weights, means, covar, 1e-3);
+    }
 
-        MixtureMultivariateNormalDistribution initialMix
-            = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
-        fitter.fit(initialMix);
-        MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel();
-        List<Pair<Double, MultivariateNormalDistribution>> components = fittedMix.getComponents();
+    @Test
+    public void testFit1Dimension2Components() {
+        // Use only the first column of the test data
+        final double[][] data = Arrays.stream(getTestSamples())
+            .map(x -> new double[] {x[0]}).toArray(double[][]::new);
+
+        // Fit the first column of test samples using Matlab R2023b (Update 6):
+        // GMModel = fitgmdist(X,2);
 
-        Assert.assertEquals(correctLogLikelihood,
-                            fitter.getLogLikelihood(),
-                            Math.ulp(1d));
+        // NegativeLogLikelihood (CM code use the positive log-likehood divided by the number of observations)
+        final double logLikelihood = -2.512197016873482e+02 / data.length;
+        // ComponentProportion
+        final double[] weights = new double[] {0.240510201974078, 0.759489798025922};
+        // Since data has 1 dimension the means and covariances are single values
+        // mu
+        final double[][] means = new double[][]{
+            {-1.736139126623031},
+            {3.899886984922886}
+        };
+        // Sigma
+        final double[][][] covar = new double[][][] {
+            {{1.371327786710623}},
+            {{5.254286022455004}}
+        };
 
-        int i = 0;
-        for (Pair<Double, MultivariateNormalDistribution> component : components) {
-            final double weight = component.getFirst();
-            final MultivariateNormalDistribution mvn = component.getSecond();
-            final double[] mean = mvn.getMeans();
-            final RealMatrix covMat = mvn.getCovariances();
-            Assert.assertEquals(correctWeights[i], weight, Math.ulp(1d));
-            Assert.assertArrayEquals(correctMeans[i], mean, 0.0);
-            Assert.assertEquals(correctCovMats[i], covMat);
-            i++;
-        }
+        assertFit(data, 2, logLikelihood, weights, means, covar, 0.05);
     }
 
     @Test
-    public void testFit1() {
-        // Test that the fit can be performed on data with a single dimension
+    public void testFit1Dimension1Component() {
         // Use only the first column of the test data
         final double[][] data = Arrays.stream(getTestSamples())
             .map(x -> new double[] {x[0]}).toArray(double[][]::new);
 
         // Fit the first column of test samples using Matlab R2023b (Update 6):
-        // GMModel = fitgmdist(X,2);
+        // GMModel = fitgmdist(X,1);
 
         // NegativeLogLikelihood (CM code use the positive log-likehood divided by the number of observations)
-        final double correctLogLikelihood = -2.512197016873482e+02 / data.length;
+        final double logLikelihood = -2.576329329354790e+02 / data.length;
         // ComponentProportion
-        final double[] correctWeights = new double[] {0.240510201974078, 0.759489798025922};
+        final double[] weights = new double[] {1.0};
         // Since data has 1 dimension the means and covariances are single values
         // mu
-        final double[] correctMeans = new double[] {-1.736139126623031, 3.899886984922886};
+        final double[][] means = new double[][]{
+            {2.544365206503801},
+        };
         // Sigma
-        final double[] correctCov = new double[] {1.371327786710623, 5.254286022455004};
+        final double[][][] covar = new double[][][] {
+            {{10.122711799089901}},
+        };
 
+        assertFit(data, 1, logLikelihood, weights, means, covar, 1e-3);
+    }
+
+    private static void assertFit(double[][] data, int numComponents,
+            double logLikelihood, double[] weights,
+            double[][] means, double[][][] covar, double relError) {
         MultivariateNormalMixtureExpectationMaximization fitter
             = new MultivariateNormalMixtureExpectationMaximization(data);
 
         MixtureMultivariateNormalDistribution initialMix
-            = MultivariateNormalMixtureExpectationMaximization.estimate(data, 2);
+            = MultivariateNormalMixtureExpectationMaximization.estimate(data, numComponents);
         fitter.fit(initialMix);
         MixtureMultivariateNormalDistribution fittedMix = fitter.getFittedModel();
         List<Pair<Double, MultivariateNormalDistribution>> components = fittedMix.getComponents();
 
-        final double relError = 0.05;
-        Assert.assertEquals(correctLogLikelihood,
-                            fitter.getLogLikelihood(),
-                            Math.abs(correctLogLikelihood) * relError);
+        Assert.assertEquals(logLikelihood,
+            fitter.getLogLikelihood(),
+            Math.abs(logLikelihood) * relError);
 
         int i = 0;
         for (Pair<Double, MultivariateNormalDistribution> component : components) {
             final double weight = component.getFirst();
             final MultivariateNormalDistribution mvn = component.getSecond();
-            final double[] mean = mvn.getMeans();
-            final RealMatrix covMat = mvn.getCovariances();
-            Assert.assertEquals(correctWeights[i], weight, correctWeights[i] * relError);
-            Assert.assertEquals(correctMeans[i], mean[0], Math.abs(correctMeans[i]) * relError);
-            Assert.assertEquals(correctCov[i], covMat.getEntry(0, 0), correctCov[i] * relError);
+            Assert.assertEquals(weights[i], weight, weights[i] * relError);
+            assertArrayEquals(means[i], mvn.getMeans(), relError);
+            final double[][] c = mvn.getCovariances().getData();
+            Assert.assertEquals(covar[i].length, c.length);
+            for (int j = 0; j < covar[i].length; j++) {
+                assertArrayEquals(covar[i][j], c[j], relError);
+            }
             i++;
         }
     }
 
+    private static void assertArrayEquals(double[] e, double[] a, double relError) {
+        Assert.assertEquals("length", e.length, a.length);
+        for (int i = 0; i < e.length; i++) {
+            Assert.assertEquals(e[i], a[i], Math.abs(e[i]) * relError);
+        }
+    }
+
     private double[][] getTestSamples() {
         // generated using R Mixtools rmvnorm with mean vectors [-1.5, 2] and
         // [4, 8.2]

Reply via email to