This is an automated email from the ASF dual-hosted git repository. jackie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new 800220b860 VarianceAggregationFunction NULL support. (#11365) 800220b860 is described below commit 800220b860a31243b179554c9d87732f9456e372 Author: Shen Yu <s...@startree.ai> AuthorDate: Tue Aug 22 17:55:46 2023 +0000 VarianceAggregationFunction NULL support. (#11365) --- .../function/AggregationFunctionFactory.java | 8 +- .../function/VarianceAggregationFunction.java | 84 +++++++++++--- .../queries/NullHandlingEnabledQueriesTest.java | 122 +++++++++++++++++++++ 3 files changed, 196 insertions(+), 18 deletions(-) diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java index a79f01a0f1..0f03ee8723 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java @@ -317,13 +317,13 @@ public class AggregationFunctionFactory { case BOOLOR: return new BooleanOrAggregationFunction(firstArgument, nullHandlingEnabled); case VARPOP: - return new VarianceAggregationFunction(firstArgument, false, false); + return new VarianceAggregationFunction(firstArgument, false, false, nullHandlingEnabled); case VARSAMP: - return new VarianceAggregationFunction(firstArgument, true, false); + return new VarianceAggregationFunction(firstArgument, true, false, nullHandlingEnabled); case STDDEVPOP: - return new VarianceAggregationFunction(firstArgument, false, true); + return new VarianceAggregationFunction(firstArgument, false, true, nullHandlingEnabled); case STDDEVSAMP: - return new VarianceAggregationFunction(firstArgument, true, true); + return new VarianceAggregationFunction(firstArgument, true, true, nullHandlingEnabled); case SKEWNESS: return new FourthMomentAggregationFunction(firstArgument, FourthMomentAggregationFunction.Type.SKEWNESS); case KURTOSIS: diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunction.java index c86269b7ce..5498731442 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunction.java @@ -29,6 +29,7 @@ import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder import org.apache.pinot.core.query.aggregation.utils.StatisticalAggregationFunctionUtils; import org.apache.pinot.segment.local.customobject.VarianceTuple; import org.apache.pinot.segment.spi.AggregationFunctionType; +import org.roaringbitmap.RoaringBitmap; /** @@ -41,13 +42,15 @@ import org.apache.pinot.segment.spi.AggregationFunctionType; public class VarianceAggregationFunction extends BaseSingleInputAggregationFunction<VarianceTuple, Double> { private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY; protected final boolean _isSample; - protected final boolean _isStdDev; + protected final boolean _nullHandlingEnabled; - public VarianceAggregationFunction(ExpressionContext expression, boolean isSample, boolean isStdDev) { + public VarianceAggregationFunction(ExpressionContext expression, boolean isSample, boolean isStdDev, + boolean nullHandlingEnabled) { super(expression); _isSample = isSample; _isStdDev = isStdDev; + _nullHandlingEnabled = nullHandlingEnabled; } @Override @@ -72,18 +75,38 @@ public class VarianceAggregationFunction extends BaseSingleInputAggregationFunct public void aggregate(int length, AggregationResultHolder aggregationResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) { double[] values = StatisticalAggregationFunctionUtils.getValSet(blockValSetMap, _expression); + RoaringBitmap nullBitmap = null; + if (_nullHandlingEnabled) { + nullBitmap = blockValSetMap.get(_expression).getNullBitmap(); + } long count = 0; double sum = 0.0; double variance = 0.0; - for (int i = 0; i < length; i++) { - count++; - sum += values[i]; - if (count > 1) { - variance = computeIntermediateVariance(count, sum, variance, values[i]); + if (nullBitmap != null && !nullBitmap.isEmpty()) { + for (int i = 0; i < length; i++) { + if (!nullBitmap.contains(i)) { + count++; + sum += values[i]; + if (count > 1) { + variance = computeIntermediateVariance(count, sum, variance, values[i]); + } + } + } + } else { + for (int i = 0; i < length; i++) { + count++; + sum += values[i]; + if (count > 1) { + variance = computeIntermediateVariance(count, sum, variance, values[i]); + } } } - setAggregationResult(aggregationResultHolder, length, sum, variance); + + if (_nullHandlingEnabled && count == 0) { + return; + } + setAggregationResult(aggregationResultHolder, count, sum, variance); } private double computeIntermediateVariance(long count, double sum, double m2, double value) { @@ -116,8 +139,20 @@ public class VarianceAggregationFunction extends BaseSingleInputAggregationFunct public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) { double[] values = StatisticalAggregationFunctionUtils.getValSet(blockValSetMap, _expression); - for (int i = 0; i < length; i++) { - setGroupByResult(groupKeyArray[i], groupByResultHolder, 1L, values[i], 0.0); + RoaringBitmap nullBitmap = null; + if (_nullHandlingEnabled) { + nullBitmap = blockValSetMap.get(_expression).getNullBitmap(); + } + if (nullBitmap != null && !nullBitmap.isEmpty()) { + for (int i = 0; i < length; i++) { + if (!nullBitmap.contains(i)) { + setGroupByResult(groupKeyArray[i], groupByResultHolder, 1L, values[i], 0.0); + } + } + } else { + for (int i = 0; i < length; i++) { + setGroupByResult(groupKeyArray[i], groupByResultHolder, 1L, values[i], 0.0); + } } } @@ -125,9 +160,23 @@ public class VarianceAggregationFunction extends BaseSingleInputAggregationFunct public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, Map<ExpressionContext, BlockValSet> blockValSetMap) { double[] values = StatisticalAggregationFunctionUtils.getValSet(blockValSetMap, _expression); - for (int i = 0; i < length; i++) { - for (int groupKey : groupKeysArray[i]) { - setGroupByResult(groupKey, groupByResultHolder, 1L, values[i], 0.0); + RoaringBitmap nullBitmap = null; + if (_nullHandlingEnabled) { + nullBitmap = blockValSetMap.get(_expression).getNullBitmap(); + } + if (nullBitmap != null && !nullBitmap.isEmpty()) { + for (int i = 0; i < length; i++) { + if (!nullBitmap.contains(i)) { + for (int groupKey : groupKeysArray[i]) { + setGroupByResult(groupKey, groupByResultHolder, 1L, values[i], 0.0); + } + } + } + } else { + for (int i = 0; i < length; i++) { + for (int groupKey : groupKeysArray[i]) { + setGroupByResult(groupKey, groupByResultHolder, 1L, values[i], 0.0); + } } } } @@ -136,7 +185,7 @@ public class VarianceAggregationFunction extends BaseSingleInputAggregationFunct public VarianceTuple extractAggregationResult(AggregationResultHolder aggregationResultHolder) { VarianceTuple varianceTuple = aggregationResultHolder.getResult(); if (varianceTuple == null) { - return new VarianceTuple(0L, 0.0, 0.0); + return _nullHandlingEnabled ? null : new VarianceTuple(0L, 0.0, 0.0); } else { return varianceTuple; } @@ -149,6 +198,13 @@ public class VarianceAggregationFunction extends BaseSingleInputAggregationFunct @Override public VarianceTuple merge(VarianceTuple intermediateResult1, VarianceTuple intermediateResult2) { + if (_nullHandlingEnabled) { + if (intermediateResult1 == null) { + return intermediateResult2; + } else if (intermediateResult2 == null) { + return intermediateResult1; + } + } intermediateResult1.apply(intermediateResult2); return intermediateResult1; } diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/NullHandlingEnabledQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/NullHandlingEnabledQueriesTest.java index b53bfa17fa..9830f35f7e 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/NullHandlingEnabledQueriesTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/NullHandlingEnabledQueriesTest.java @@ -1400,4 +1400,126 @@ public class NullHandlingEnabledQueriesTest extends BaseQueriesTest { assertEquals(rows.size(), NUM_OF_SEGMENT_COPIES); assertArrayEquals(rows.get(0), new Object[]{null}); } + + @Test(dataProvider = "NumberTypes") + public void testStddevPop(FieldSpec.DataType dataType) + throws Exception { + initializeRows(); + insertRow(null); + insertRow(1); + insertRow(2); + TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build(); + Schema schema = new Schema.SchemaBuilder().addSingleValueDimension(COLUMN1, dataType).build(); + setUpSegments(tableConfig, schema); + String query = String.format("SELECT STDDEV_POP(%s) FROM testTable", COLUMN1); + + BrokerResponseNative brokerResponse = getBrokerResponse(query, QUERY_OPTIONS); + + ResultTable resultTable = brokerResponse.getResultTable(); + List<Object[]> rows = resultTable.getRows(); + assertEquals(rows.size(), 1); + assertEquals(rows.get(0)[0], 0.5); + } + + @Test(dataProvider = "NumberTypes") + public void testGroupByStddevPop(FieldSpec.DataType dataType) + throws Exception { + initializeRows(); + insertRowWithTwoColumns(null, "key"); + insertRowWithTwoColumns(1, "key"); + insertRowWithTwoColumns(2, "key"); + TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build(); + Schema schema = new Schema.SchemaBuilder().addSingleValueDimension(COLUMN1, dataType) + .addSingleValueDimension(COLUMN2, FieldSpec.DataType.STRING).build(); + setUpSegments(tableConfig, schema); + String query = String.format("SELECT STDDEV_POP(%s), %s FROM testTable GROUP BY %s", COLUMN1, COLUMN2, COLUMN2); + + BrokerResponseNative brokerResponse = getBrokerResponse(query, QUERY_OPTIONS); + + ResultTable resultTable = brokerResponse.getResultTable(); + List<Object[]> rows = resultTable.getRows(); + assertEquals(rows.size(), 1); + assertArrayEquals(rows.get(0), new Object[]{0.5, "key"}); + } + + @Test(dataProvider = "NumberTypes") + public void testGroupByMvStddevPop(FieldSpec.DataType dataType) + throws Exception { + initializeRows(); + insertRowWithTwoColumns(null, new String[]{"key1", "key2"}); + insertRowWithTwoColumns(1, new String[]{"key1", "key2"}); + insertRowWithTwoColumns(2, new String[]{"key1"}); + TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build(); + Schema schema = new Schema.SchemaBuilder().addSingleValueDimension(COLUMN1, dataType) + .addMultiValueDimension(COLUMN2, FieldSpec.DataType.STRING).build(); + setUpSegments(tableConfig, schema); + String query = + String.format("SELECT STDDEV_POP(%s), %s FROM testTable GROUP BY %s ORDER BY %s", COLUMN1, COLUMN2, COLUMN2, + COLUMN2); + + BrokerResponseNative brokerResponse = getBrokerResponse(query, QUERY_OPTIONS); + + ResultTable resultTable = brokerResponse.getResultTable(); + List<Object[]> rows = resultTable.getRows(); + assertEquals(rows.size(), 2); + assertArrayEquals(rows.get(0), new Object[]{0.5, "key1"}); + assertArrayEquals(rows.get(1), new Object[]{0.0, "key2"}); + } + + @Test + public void testAllNullGroupByStddevPopReturnsNull() + throws Exception { + initializeRows(); + insertRowWithTwoColumns(null, "key1"); + TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build(); + Schema schema = new Schema.SchemaBuilder().addSingleValueDimension(COLUMN1, FieldSpec.DataType.INT) + .addSingleValueDimension(COLUMN2, FieldSpec.DataType.STRING).build(); + setUpSegments(tableConfig, schema); + String query = + String.format("SELECT STDDEV_POP(%s), %s FROM testTable GROUP BY %s ORDER BY %s", COLUMN1, COLUMN2, COLUMN2, + COLUMN2); + + BrokerResponseNative brokerResponse = getBrokerResponse(query, QUERY_OPTIONS); + + ResultTable resultTable = brokerResponse.getResultTable(); + List<Object[]> rows = resultTable.getRows(); + assertEquals(rows.size(), 1); + assertEquals(rows.get(0)[0], null); + } + + @Test + public void testAllNullStddevPopReturnsNull() + throws Exception { + initializeRows(); + insertRow(null); + TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build(); + Schema schema = new Schema.SchemaBuilder().addSingleValueDimension(COLUMN1, FieldSpec.DataType.DOUBLE).build(); + setUpSegments(tableConfig, schema); + String query = String.format("SELECT STDDEV_POP(%s) FROM testTable", COLUMN1); + + BrokerResponseNative brokerResponse = getBrokerResponse(query, QUERY_OPTIONS); + + ResultTable resultTable = brokerResponse.getResultTable(); + List<Object[]> rows = resultTable.getRows(); + assertEquals(rows.size(), 1); + assertEquals(rows.get(0)[0], null); + } + + @Test + public void testNoMatchingRowNullHandlingDisabledStddevPopReturnsNull() + throws Exception { + initializeRows(); + insertRow(1); + TableConfig tableConfig = new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build(); + Schema schema = new Schema.SchemaBuilder().addSingleValueDimension(COLUMN1, FieldSpec.DataType.DOUBLE).build(); + setUpSegments(tableConfig, schema); + String query = String.format("SELECT STDDEV_POP(%s) FROM testTable WHERE %s != 1", COLUMN1, COLUMN1); + + BrokerResponseNative brokerResponse = getBrokerResponse(query); + + ResultTable resultTable = brokerResponse.getResultTable(); + List<Object[]> rows = resultTable.getRows(); + assertEquals(rows.size(), 1); + assertEquals(rows.get(0)[0], Double.NEGATIVE_INFINITY); + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org