This is an automated email from the ASF dual-hosted git repository. aherbert pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-statistics.git
commit 3b13c1ed144556362b9f49a9f650d4fdafbcfb32 Author: Alex Herbert <[email protected]> AuthorDate: Tue Jan 2 10:43:35 2024 +0000 Refactor DoubleStatistics API to be consistent with IntStatistics --- .../statistics/descriptive/DoubleStatistics.java | 31 +++--- .../descriptive/DoubleStatisticsTest.java | 108 ++++++++++----------- .../statistics/descriptive/UserGuideTest.java | 38 ++++---- 3 files changed, 88 insertions(+), 89 deletions(-) diff --git a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/DoubleStatistics.java b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/DoubleStatistics.java index 3bdd073..eee279f 100644 --- a/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/DoubleStatistics.java +++ b/commons-statistics-descriptive/src/main/java/org/apache/commons/statistics/descriptive/DoubleStatistics.java @@ -19,7 +19,6 @@ package org.apache.commons.statistics.descriptive; import java.util.Objects; import java.util.Set; import java.util.function.DoubleConsumer; -import java.util.function.DoubleSupplier; import java.util.function.Function; /** @@ -381,7 +380,7 @@ public final class DoubleStatistics implements DoubleConsumer { * @param statistic Statistic. * @return {@code true} if supported * @throws NullPointerException if the {@code statistic} is {@code null} - * @see #get(Statistic) + * @see #getAsDouble(Statistic) */ public boolean isSupported(Statistic statistic) { // Check for the appropriate underlying implementation @@ -414,16 +413,16 @@ public final class DoubleStatistics implements DoubleConsumer { } /** - * Gets the value of the specified {@code statistic}. + * Gets the value of the specified {@code statistic} as a {@code double}. * * @param statistic Statistic. 
* @return the value * @throws IllegalArgumentException if the {@code statistic} is not supported * @see #isSupported(Statistic) - * @see #getSupplier(Statistic) + * @see #getResult(Statistic) */ - public double get(Statistic statistic) { - return getSupplier(statistic).getAsDouble(); + public double getAsDouble(Statistic statistic) { + return getResult(statistic).getAsDouble(); } /** @@ -441,15 +440,15 @@ public final class DoubleStatistics implements DoubleConsumer { * @return the supplier * @throws IllegalArgumentException if the {@code statistic} is not supported * @see #isSupported(Statistic) - * @see #get(Statistic) + * @see #getAsDouble(Statistic) */ - public DoubleSupplier getSupplier(Statistic statistic) { + public StatisticResult getResult(Statistic statistic) { // Locate the implementation. // Statistics that wrap an underlying implementation are created in methods. // The return argument should be a method reference and not an instance // of DoubleStatistic. This ensures the statistic implementation cannot // be updated with new values by casting the result and calling accept(double). 
- DoubleSupplier stat = null; + StatisticResult stat = null; switch (statistic) { case GEOMETRIC_MEAN: stat = getGeometricMean(); @@ -503,7 +502,7 @@ public final class DoubleStatistics implements DoubleConsumer { * * @return a geometric mean supplier (or null if unsupported) */ - private DoubleSupplier getGeometricMean() { + private StatisticResult getGeometricMean() { if (sumOfLogs != null) { // Return a function that has access to the count and sumOfLogs return () -> GeometricMean.computeGeometricMean(count, sumOfLogs); @@ -516,7 +515,7 @@ public final class DoubleStatistics implements DoubleConsumer { * * @return a kurtosis supplier (or null if unsupported) */ - private DoubleSupplier getKurtosis() { + private StatisticResult getKurtosis() { if (moment instanceof SumOfFourthDeviations) { return new Kurtosis((SumOfFourthDeviations) moment) .setBiased(config.isBiased())::getAsDouble; @@ -529,7 +528,7 @@ public final class DoubleStatistics implements DoubleConsumer { * * @return a mean supplier (or null if unsupported) */ - private DoubleSupplier getMean() { + private StatisticResult getMean() { if (moment != null) { // Special case where wrapping with a Mean is not required return moment::getFirstMoment; @@ -542,7 +541,7 @@ public final class DoubleStatistics implements DoubleConsumer { * * @return a skewness supplier (or null if unsupported) */ - private DoubleSupplier getSkewness() { + private StatisticResult getSkewness() { if (moment instanceof SumOfCubedDeviations) { return new Skewness((SumOfCubedDeviations) moment) .setBiased(config.isBiased())::getAsDouble; @@ -555,7 +554,7 @@ public final class DoubleStatistics implements DoubleConsumer { * * @return a standard deviation supplier (or null if unsupported) */ - private DoubleSupplier getStandardDeviation() { + private StatisticResult getStandardDeviation() { if (moment instanceof SumOfSquaredDeviations) { return new StandardDeviation((SumOfSquaredDeviations) moment) 
.setBiased(config.isBiased())::getAsDouble; @@ -568,7 +567,7 @@ public final class DoubleStatistics implements DoubleConsumer { * * @return a variance supplier (or null if unsupported) */ - private DoubleSupplier getVariance() { + private StatisticResult getVariance() { if (moment instanceof SumOfSquaredDeviations) { return new Variance((SumOfSquaredDeviations) moment) .setBiased(config.isBiased())::getAsDouble; @@ -627,7 +626,7 @@ public final class DoubleStatistics implements DoubleConsumer { * @param v Value. * @return {@code this} instance * @throws NullPointerException if the value is null - * @see #getSupplier(Statistic) + * @see #getResult(Statistic) */ public DoubleStatistics setConfiguration(StatisticsConfiguration v) { config = Objects.requireNonNull(v); diff --git a/commons-statistics-descriptive/src/test/java/org/apache/commons/statistics/descriptive/DoubleStatisticsTest.java b/commons-statistics-descriptive/src/test/java/org/apache/commons/statistics/descriptive/DoubleStatisticsTest.java index 096aa4e..69fc8fb 100644 --- a/commons-statistics-descriptive/src/test/java/org/apache/commons/statistics/descriptive/DoubleStatisticsTest.java +++ b/commons-statistics-descriptive/src/test/java/org/apache/commons/statistics/descriptive/DoubleStatisticsTest.java @@ -26,12 +26,12 @@ import java.util.List; import java.util.Set; import java.util.concurrent.ThreadLocalRandom; import java.util.function.BiConsumer; -import java.util.function.DoubleSupplier; import java.util.function.Function; import java.util.function.Supplier; import java.util.function.ToDoubleFunction; import java.util.stream.IntStream; import java.util.stream.Stream; +import org.apache.commons.statistics.distribution.DoubleTolerances; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeAll; @@ -45,6 +45,10 @@ import org.junit.jupiter.params.provider.MethodSource; * * <p>This class verifies that the statistics computed using the summary * 
* class are an exact match to the statistics computed individually. + * + * <p>For simplicity some tests use only the {@code double} result of the statistic. + * The full {@link StatisticResult} interface is asserted in the test of the array or + * stream of input data. */ class DoubleStatisticsTest { /** Empty statistic array. */ @@ -271,6 +275,14 @@ class DoubleStatisticsTest { assertStatistics(stats, data, builder::build, ExpectedResult::getArray); } + /** + * Assert the computed statistics match the expected result. + * + * @param stats Statistics that are computed. + * @param data Test data. + * @param constructor Constructor to create the {@link DoubleStatistics}. + * @param expected Function to obtain the expected result. + */ private static void assertStatistics(EnumSet<Statistic> stats, TestData data, Function<double[], DoubleStatistics> constructor, ToDoubleFunction<ExpectedResult> expected) { @@ -285,35 +297,48 @@ class DoubleStatisticsTest { stats.forEach(s -> computed.addAll(coComputed.get(s))); // Test if the statistics are correctly identified as supported - EnumSet.allOf(Statistic.class).forEach(s -> - Assertions.assertEquals(computed.contains(s), statistics.isSupported(s), - () -> stats + " isSupported -> " + s.toString())); - - // Test the values - computed.forEach(s -> - Assertions.assertEquals(expected.applyAsDouble(expectedResult.get(s).get(id)), statistics.get(s), - () -> stats + " value -> " + s.toString())); + EnumSet.allOf(Statistic.class).forEach(s -> { + final boolean isSupported = computed.contains(s); + Assertions.assertEquals(isSupported, statistics.isSupported(s), + () -> stats + " isSupported -> " + s.toString()); + if (isSupported) { + final double doubleResult = expected.applyAsDouble(expectedResult.get(s).get(id)); + // Test individual values + Assertions.assertEquals(doubleResult, statistics.getAsDouble(s), + () -> stats + " getAsDouble -> " + s.toString()); + // Test the values from the result + final StatisticResult result = () ->
doubleResult; + TestHelper.assertEquals(result, statistics.getResult(s), + DoubleTolerances.equals(), + () -> stats + " getResult -> " + s.toString()); + } else { + Assertions.assertThrows(IllegalArgumentException.class, () -> statistics.getAsDouble(s), + () -> stats + " getAsDouble -> " + s.toString()); + Assertions.assertThrows(IllegalArgumentException.class, () -> statistics.getResult(s), + () -> stats + " getResult -> " + s.toString()); + } + }); } /** * Add all the {@code values} to an aggregator of the {@code statistics}. * - * <p>This method verifies that the {@link DoubleStatistics#get(Statistic)} and - * {@link DoubleStatistics#getSupplier(Statistic)} methods return the same + * <p>This method verifies that the {@link DoubleStatistics#getAsDouble(Statistic)} and + * {@link DoubleStatistics#getResult(Statistic)} methods return the same * result as values are added. * - * @param statistic Statistics. + * @param statistics Statistics. * @param values Values. * @return the statistics */ private static DoubleStatistics acceptAll(Statistic[] statistics, double[] values) { final DoubleStatistics stats = DoubleStatistics.of(statistics); - final DoubleSupplier[] f = getSuppliers(statistics, stats); + final StatisticResult[] f = getResults(statistics, stats); for (final double x : values) { stats.accept(x); for (int i = 0; i < statistics.length; i++) { final Statistic s = statistics[i]; - Assertions.assertEquals(stats.get(s), f[i].getAsDouble(), + Assertions.assertEquals(stats.getAsDouble(s), f[i].getAsDouble(), () -> "Supplier(" + s + ") after value " + x); } } @@ -327,10 +352,10 @@ class DoubleStatisticsTest { * @param stats Statistic aggregator. 
* @return the suppliers */ - private static DoubleSupplier[] getSuppliers(Statistic[] statistics, final DoubleStatistics stats) { - final DoubleSupplier[] f = new DoubleSupplier[statistics.length]; + private static StatisticResult[] getResults(Statistic[] statistics, final DoubleStatistics stats) { + final StatisticResult[] f = new StatisticResult[statistics.length]; for (int i = 0; i < statistics.length; i++) { - final DoubleSupplier supplier = stats.getSupplier(statistics[i]); + final StatisticResult supplier = stats.getResult(statistics[i]); Assertions.assertFalse(supplier instanceof DoubleStatistic, () -> "DoubleStatistic instance: " + supplier.getClass().getSimpleName()); f[i] = supplier; @@ -341,8 +366,8 @@ class DoubleStatisticsTest { /** * Combine the two statistic aggregators. * - * <p>This method verifies that the {@link DoubleStatistics#get(Statistic)} and - * {@link DoubleStatistics#getSupplier(Statistic)} methods return the same + * <p>This method verifies that the {@link DoubleStatistics#getAsDouble(Statistic)} and + * {@link DoubleStatistics#getResult(Statistic)} methods return the same * result after the {@link DoubleStatistics#combine(DoubleStatistics)}. * * @param statistics Statistics to compute. 
@@ -352,11 +377,11 @@ class DoubleStatisticsTest { */ private static DoubleStatistics combine(Statistic[] statistics, DoubleStatistics s1, DoubleStatistics s2) { - final DoubleSupplier[] f = getSuppliers(statistics, s1); + final StatisticResult[] f = getResults(statistics, s1); s1.combine(s2); for (int i = 0; i < statistics.length; i++) { final Statistic s = statistics[i]; - Assertions.assertEquals(s1.get(s), f[i].getAsDouble(), + Assertions.assertEquals(s1.getAsDouble(s), f[i].getAsDouble(), () -> "Supplier(" + s + ") after combine"); } return s1; @@ -392,31 +417,6 @@ class DoubleStatisticsTest { Assertions.assertThrows(NullPointerException.class, () -> s.isSupported(null)); } - @ParameterizedTest - @MethodSource - void testNotSupported(Statistic stat) { - DoubleStatistics statistics = DoubleStatistics.of(stat); - for (final Statistic s : Statistic.values()) { - Assertions.assertEquals(s == stat, statistics.isSupported(s), - () -> stat + " isSupported -> " + s.toString()); - if (s == stat) { - Assertions.assertDoesNotThrow(() -> statistics.get(s), - () -> stat + " get -> " + s.toString()); - Assertions.assertNotNull(statistics.getSupplier(s), - () -> stat + " getSupplier -> " + s.toString()); - } else { - Assertions.assertThrows(IllegalArgumentException.class, () -> statistics.get(s), - () -> stat + " get -> " + s.toString()); - Assertions.assertThrows(IllegalArgumentException.class, () -> statistics.getSupplier(s), - () -> stat + " getSupplier -> " + s.toString()); - } - } - } - - static Statistic[] testNotSupported() { - return new Statistic[] {Statistic.MIN, Statistic.PRODUCT}; - } - @ParameterizedTest @MethodSource void testIncompatibleCombineThrows(EnumSet<Statistic> stat1, EnumSet<Statistic> stat2) { @@ -425,13 +425,13 @@ class DoubleStatisticsTest { DoubleStatistics statistics = DoubleStatistics.of(stat1, v1); DoubleStatistics other = DoubleStatistics.of(stat2, v2); // Store values - final double[] values = 
stat1.stream().mapToDouble(statistics::get).toArray(); + final double[] values = stat1.stream().mapToDouble(statistics::getAsDouble).toArray(); Assertions.assertThrows(IllegalArgumentException.class, () -> statistics.combine(other), () -> stat1 + " " + stat2); // Values should be unchanged final int[] i = {0}; stat1.stream().forEach( - s -> Assertions.assertEquals(values[i[0]++], statistics.get(s), () -> s + " changed")); + s -> Assertions.assertEquals(values[i[0]++], statistics.getAsDouble(s), () -> s + " changed")); } static Stream<Arguments> testIncompatibleCombineThrows() { @@ -459,9 +459,9 @@ class DoubleStatisticsTest { statistics2.combine(other2); // The stats should be the same for (final Statistic s : stat1) { - final double expected = statistics1.get(s); + final double expected = statistics1.getAsDouble(s); assertFinite(expected, s); - Assertions.assertEquals(expected, statistics2.get(s), () -> s.toString()); + Assertions.assertEquals(expected, statistics2.getAsDouble(s), () -> s.toString()); } } @@ -487,7 +487,7 @@ class DoubleStatisticsTest { final DoubleStatistics statistics1 = DoubleStatistics.builder(stat).build(values); StatisticsConfiguration c = StatisticsConfiguration.withDefaults(); - DoubleSupplier s = null; + StatisticResult s = null; // Note the circular loop to check setting back to the start option for (int index = 0; index <= options.length; index++) { final int i = index % options.length; @@ -495,9 +495,9 @@ class DoubleStatisticsTest { c = c.withBiased(value); Assertions.assertSame(statistics1, statistics1.setConfiguration(c)); - Assertions.assertEquals(results[i], statistics1.get(stat), + Assertions.assertEquals(results[i], statistics1.getAsDouble(stat), () -> options[i] + " get: " + BaseDoubleStatisticTest.format(values)); - final DoubleSupplier s1 = statistics1.getSupplier(stat); + final StatisticResult s1 = statistics1.getResult(stat); Assertions.assertEquals(results[i], s1.getAsDouble(), () -> options[i] + " supplier: " + 
BaseDoubleStatisticTest.format(values)); @@ -512,7 +512,7 @@ class DoubleStatisticsTest { // Set through the builder final DoubleStatistics statistics2 = DoubleStatistics.builder(stat) .setConfiguration(c).build(values); - Assertions.assertEquals(results[i], statistics2.get(stat), + Assertions.assertEquals(results[i], statistics2.getAsDouble(stat), () -> options[i] + " get via builder: " + BaseDoubleStatisticTest.format(values)); } } diff --git a/commons-statistics-descriptive/src/test/java/org/apache/commons/statistics/descriptive/UserGuideTest.java b/commons-statistics-descriptive/src/test/java/org/apache/commons/statistics/descriptive/UserGuideTest.java index daeba31..745b35c 100644 --- a/commons-statistics-descriptive/src/test/java/org/apache/commons/statistics/descriptive/UserGuideTest.java +++ b/commons-statistics-descriptive/src/test/java/org/apache/commons/statistics/descriptive/UserGuideTest.java @@ -55,15 +55,15 @@ class UserGuideTest { DoubleStatistics stats = DoubleStatistics.builder( Statistic.MIN, Statistic.MAX, Statistic.VARIANCE) .build(data); - Assertions.assertEquals(1, stats.get(Statistic.MIN)); - Assertions.assertEquals(8, stats.get(Statistic.MAX)); + Assertions.assertEquals(1, stats.getAsDouble(Statistic.MIN)); + Assertions.assertEquals(8, stats.getAsDouble(Statistic.MAX)); // Python numpy 1.24.4 // np.var(np.arange(1, 9), ddof=1) // np.std(np.arange(1, 9), ddof=1) - Assertions.assertEquals(6.0, stats.get(Statistic.VARIANCE), 1e-10); + Assertions.assertEquals(6.0, stats.getAsDouble(Statistic.VARIANCE), 1e-10); // Get other statistics supported by the underlying computations - Assertions.assertEquals(2.449489742783178, stats.get(Statistic.STANDARD_DEVIATION), 1e-10); - Assertions.assertEquals(4.5, stats.get(Statistic.MEAN), 1e-10); + Assertions.assertEquals(2.449489742783178, stats.getAsDouble(Statistic.STANDARD_DEVIATION), 1e-10); + Assertions.assertEquals(4.5, stats.getAsDouble(Statistic.MEAN), 1e-10); } @Test @@ -78,12 +78,12 @@ class 
UserGuideTest { .map(builder::build) .reduce(DoubleStatistics::combine) .get(); - Assertions.assertEquals(1, stats.get(Statistic.MIN)); - Assertions.assertEquals(8, stats.get(Statistic.MAX)); - Assertions.assertEquals(6.0, stats.get(Statistic.VARIANCE), 1e-10); + Assertions.assertEquals(1, stats.getAsDouble(Statistic.MIN)); + Assertions.assertEquals(8, stats.getAsDouble(Statistic.MAX)); + Assertions.assertEquals(6.0, stats.getAsDouble(Statistic.VARIANCE), 1e-10); // Get other statistics supported by the underlying computations - Assertions.assertEquals(2.449489742783178, stats.get(Statistic.STANDARD_DEVIATION), 1e-10); - Assertions.assertEquals(4.5, stats.get(Statistic.MEAN), 1e-10); + Assertions.assertEquals(2.449489742783178, stats.getAsDouble(Statistic.STANDARD_DEVIATION), 1e-10); + Assertions.assertEquals(4.5, stats.getAsDouble(Statistic.MEAN), 1e-10); } @Test @@ -97,12 +97,12 @@ class UserGuideTest { Collector<double[], DoubleStatistics, DoubleStatistics> collector = Collector.of(builder::build, (s, d) -> s.combine(builder.build(d)), DoubleStatistics::combine); DoubleStatistics stats = Arrays.stream(data).collect(collector); - Assertions.assertEquals(1, stats.get(Statistic.MIN)); - Assertions.assertEquals(8, stats.get(Statistic.MAX)); - Assertions.assertEquals(6.0, stats.get(Statistic.VARIANCE), 1e-10); + Assertions.assertEquals(1, stats.getAsDouble(Statistic.MIN)); + Assertions.assertEquals(8, stats.getAsDouble(Statistic.MAX)); + Assertions.assertEquals(6.0, stats.getAsDouble(Statistic.VARIANCE), 1e-10); // Get other statistics supported by the underlying computations - Assertions.assertEquals(2.449489742783178, stats.get(Statistic.STANDARD_DEVIATION), 1e-10); - Assertions.assertEquals(4.5, stats.get(Statistic.MEAN), 1e-10); + Assertions.assertEquals(2.449489742783178, stats.getAsDouble(Statistic.STANDARD_DEVIATION), 1e-10); + Assertions.assertEquals(4.5, stats.getAsDouble(Statistic.MEAN), 1e-10); } @Test @@ -119,15 +119,15 @@ class UserGuideTest { 
DoubleStatistics stats = DoubleStatistics.of( EnumSet.of(Statistic.MIN, Statistic.MAX), 1, 1, 2, 3, 5, 8, 13); - Assertions.assertEquals(1, stats.get(Statistic.MIN)); - Assertions.assertEquals(13, stats.get(Statistic.MAX)); + Assertions.assertEquals(1, stats.getAsDouble(Statistic.MIN)); + Assertions.assertEquals(13, stats.getAsDouble(Statistic.MAX)); } @Test void testDoubleStatistics6() { DoubleStatistics stats = DoubleStatistics.of(Statistic.MEAN, Statistic.MAX); - DoubleSupplier mean = stats.getSupplier(Statistic.MEAN); - DoubleSupplier max = stats.getSupplier(Statistic.MAX); + DoubleSupplier mean = stats.getResult(Statistic.MEAN); + DoubleSupplier max = stats.getResult(Statistic.MAX); IntStream.rangeClosed(1, 5).forEach(x -> { stats.accept(x); Assertions.assertEquals((x + 1.0) / 2, mean.getAsDouble(), "mean");
