This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-statistics.git
The following commit(s) were added to refs/heads/master by this push: new 4187c2d Enforce symmetry in the combine method for the moments 4187c2d is described below commit 4187c2d8aeaa3708d853a6bf7d36d95cbe1bdbc8 Author: aherbert <aherb...@apache.org> AuthorDate: Tue Oct 3 13:00:41 2023 +0100 Enforce symmetry in the combine method for the moments Do not use and maintain the FirstMoment accept variables in the combine method. These are not required for higher order combine methods. --- .../statistics/descriptive/FirstMoment.java | 63 +++++++++++++++------- .../descriptive/SumOfSquaredDeviations.java | 6 ++- .../commons/statistics/descriptive/MeanTest.java | 2 +- .../statistics/descriptive/VarianceTest.java | 2 +- 4 files changed, 50 insertions(+), 23 deletions(-) diff --git a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/FirstMoment.java b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/FirstMoment.java index ab2b08e..dbea047 100644 --- a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/FirstMoment.java +++ b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/FirstMoment.java @@ -68,6 +68,7 @@ class FirstMoment implements DoubleConsumer { /** * Deviation of most recently added value from the previous first moment. * Retained to prevent repeated computation in higher order moments. + * Note: This value is not used in the {@link #combine(FirstMoment)} method. */ protected double dev; @@ -75,6 +76,7 @@ class FirstMoment implements DoubleConsumer { * Deviation of most recently added value from the previous first moment, * normalized by current sample size. Retained to prevent repeated * computation in higher order moments. + * Note: This value is not used in the {@link #combine(FirstMoment)} method. */ protected double nDev; @@ -183,26 +185,49 @@ class FirstMoment implements DoubleConsumer { * @return {@code this} instance after combining {@code other}. */ FirstMoment combine(FirstMoment other) { - if (n == 0) { - n = other.n; - nonFiniteValue = other.nonFiniteValue; - dev = other.dev; - nDev = other.nDev; - m1 = other.m1; - } else if (other.n != 0) { - n += other.n; - nonFiniteValue += other.nonFiniteValue; - dev = other.m1 * 0.5 - m1 * 0.5; - // In contrast to the accept method, here nDev can be close to MAX_VALUE - // if the weight (other.n / n) approaches 1. So we cannot yet rescale nDev and - // instead have to combine it with the scaled-down value of m1. - nDev = dev * ((double) other.n / n); - m1 = m1 * 0.5 + nDev; - // Scale up the terms. - m1 *= 2; - dev *= 2; - nDev *= 2; + nonFiniteValue += other.nonFiniteValue; + final double mu1 = this.m1; + final double mu2 = other.m1; + final long n1 = n; + final long n2 = other.n; + n = n1 + n2; + // Adjust the mean with the weighted difference: + // m1 = m1 + (m2 - m1) * n2 / (n1 + n2) + // The difference between means can be 2 * MAX_VALUE so the computation optionally + // scales by a factor of 2. Avoiding scaling if possible preserves sub-normals. + if (n1 == n2) { + // Optimisation for equal sizes: m1 = (m1 + m2) / 2 + // Use scaling for a large sum + final double sum = mu1 + mu2; + m1 = Double.isFinite(sum) ? + sum * 0.5 : + mu1 * 0.5 + mu2 * 0.5; + } else { + // Use scaling for a large difference + if (Double.isFinite(mu2 - mu1)) { + m1 = combine(mu1, mu2, n1, n2); + } else { + m1 = 2 * combine(mu1 * 0.5, mu2 * 0.5, n1, n2); + } } return this; } + + /** + * Combine the moments. This method is used to enforce symmetry. It assumes that + * the two sizes are not identical, and at least one size is non-zero. + * + * @param m1 Moment 1. + * @param m2 Moment 2. + * @param n1 Size of sample 1. + * @param n2 Size of sample 2. + * @return the combined first moment + */ + private static double combine(double m1, double m2, long n1, long n2) { + // Note: If either size is zero the weighted difference is zero and + // the other moment is unchanged. + return n2 < n1 ? + m1 + (m2 - m1) * ((double) n2 / (n1 + n2)) : + m2 + (m1 - m2) * ((double) n1 / (n1 + n2)); + } } diff --git a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfSquaredDeviations.java b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfSquaredDeviations.java index 8ce1604..2fe9a97 100644 --- a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfSquaredDeviations.java +++ b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/SumOfSquaredDeviations.java @@ -157,9 +157,11 @@ class SumOfSquaredDeviations extends FirstMoment { } else if (m != 0) { // "Updating one-pass algorithm" // See: Chan et al (1983) Equation 1.5b (modified for the mean) - final double diffOfMean = other.getFirstMoment() - m1; + final double diffOfMean = other.m1 - m1; final double sqDiffOfMean = diffOfMean * diffOfMean; - sumSquaredDev += other.sumSquaredDev + sqDiffOfMean * (((double) n * m) / ((double) n + m)); + // Enforce symmetry + sumSquaredDev = (sumSquaredDev + other.sumSquaredDev) + + sqDiffOfMean * (((double) n * m) / ((double) n + m)); } super.combine(other); return this; diff --git a/commons-statistics-descriptive/src/test/java/org/apache/commons/statistics/descriptive/MeanTest.java b/commons-statistics-descriptive/src/test/java/org/apache/commons/statistics/descriptive/MeanTest.java index effdfc6..61daa51 100644 --- a/commons-statistics-descriptive/src/test/java/org/apache/commons/statistics/descriptive/MeanTest.java +++ b/commons-statistics-descriptive/src/test/java/org/apache/commons/statistics/descriptive/MeanTest.java @@ -149,7 +149,7 @@ final class MeanTest { Mean mean1b = Mean.create(); Arrays.stream(array1).forEach(mean1b); mean2.combine(mean1b); - TestHelper.assertEquals(expected, mean2.getAsDouble(), ULP_COMBINE, () -> "combine"); + Assertions.assertEquals(mean1.getAsDouble(), mean2.getAsDouble(), () -> "combine reversed"); Assertions.assertEquals(mean1BeforeCombine, mean1b.getAsDouble()); } diff --git a/commons-statistics-descriptive/src/test/java/org/apache/commons/statistics/descriptive/VarianceTest.java b/commons-statistics-descriptive/src/test/java/org/apache/commons/statistics/descriptive/VarianceTest.java index 1eb7376..d0b24e2 100644 --- a/commons-statistics-descriptive/src/test/java/org/apache/commons/statistics/descriptive/VarianceTest.java +++ b/commons-statistics-descriptive/src/test/java/org/apache/commons/statistics/descriptive/VarianceTest.java @@ -149,7 +149,7 @@ final class VarianceTest { Variance var1b = Variance.create(); Arrays.stream(array1).forEach(var1b); var2.combine(var1b); - TestHelper.assertEquals(expected, var2.getAsDouble(), ULP_COMBINE_ACCEPT, () -> "combine"); + Assertions.assertEquals(var1.getAsDouble(), var2.getAsDouble(), () -> "combine reversed"); Assertions.assertEquals(var1BeforeCombine, var1b.getAsDouble()); }