This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 800220b860 VarianceAggregationFunction NULL support. (#11365)
800220b860 is described below

commit 800220b860a31243b179554c9d87732f9456e372
Author: Shen Yu <s...@startree.ai>
AuthorDate: Tue Aug 22 17:55:46 2023 +0000

    VarianceAggregationFunction NULL support. (#11365)
---
 .../function/AggregationFunctionFactory.java       |   8 +-
 .../function/VarianceAggregationFunction.java      |  84 +++++++++++---
 .../queries/NullHandlingEnabledQueriesTest.java    | 122 +++++++++++++++++++++
 3 files changed, 196 insertions(+), 18 deletions(-)

diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index a79f01a0f1..0f03ee8723 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -317,13 +317,13 @@ public class AggregationFunctionFactory {
           case BOOLOR:
             return new BooleanOrAggregationFunction(firstArgument, 
nullHandlingEnabled);
           case VARPOP:
-            return new VarianceAggregationFunction(firstArgument, false, 
false);
+            return new VarianceAggregationFunction(firstArgument, false, 
false, nullHandlingEnabled);
           case VARSAMP:
-            return new VarianceAggregationFunction(firstArgument, true, false);
+            return new VarianceAggregationFunction(firstArgument, true, false, 
nullHandlingEnabled);
           case STDDEVPOP:
-            return new VarianceAggregationFunction(firstArgument, false, true);
+            return new VarianceAggregationFunction(firstArgument, false, true, 
nullHandlingEnabled);
           case STDDEVSAMP:
-            return new VarianceAggregationFunction(firstArgument, true, true);
+            return new VarianceAggregationFunction(firstArgument, true, true, 
nullHandlingEnabled);
           case SKEWNESS:
             return new FourthMomentAggregationFunction(firstArgument, 
FourthMomentAggregationFunction.Type.SKEWNESS);
           case KURTOSIS:
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunction.java
index c86269b7ce..5498731442 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/VarianceAggregationFunction.java
@@ -29,6 +29,7 @@ import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder
 import 
org.apache.pinot.core.query.aggregation.utils.StatisticalAggregationFunctionUtils;
 import org.apache.pinot.segment.local.customobject.VarianceTuple;
 import org.apache.pinot.segment.spi.AggregationFunctionType;
+import org.roaringbitmap.RoaringBitmap;
 
 
 /**
@@ -41,13 +42,15 @@ import org.apache.pinot.segment.spi.AggregationFunctionType;
 public class VarianceAggregationFunction extends 
BaseSingleInputAggregationFunction<VarianceTuple, Double> {
   private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY;
   protected final boolean _isSample;
-
   protected final boolean _isStdDev;
+  protected final boolean _nullHandlingEnabled;
 
-  public VarianceAggregationFunction(ExpressionContext expression, boolean 
isSample, boolean isStdDev) {
+  public VarianceAggregationFunction(ExpressionContext expression, boolean 
isSample, boolean isStdDev,
+      boolean nullHandlingEnabled) {
     super(expression);
     _isSample = isSample;
     _isStdDev = isStdDev;
+    _nullHandlingEnabled = nullHandlingEnabled;
   }
 
   @Override
@@ -72,18 +75,38 @@ public class VarianceAggregationFunction extends BaseSingleInputAggregationFunct
   public void aggregate(int length, AggregationResultHolder 
aggregationResultHolder,
       Map<ExpressionContext, BlockValSet> blockValSetMap) {
     double[] values = 
StatisticalAggregationFunctionUtils.getValSet(blockValSetMap, _expression);
+    RoaringBitmap nullBitmap = null;
+    if (_nullHandlingEnabled) {
+      nullBitmap = blockValSetMap.get(_expression).getNullBitmap();
+    }
 
     long count = 0;
     double sum = 0.0;
     double variance = 0.0;
-    for (int i = 0; i < length; i++) {
-      count++;
-      sum += values[i];
-      if (count > 1) {
-        variance = computeIntermediateVariance(count, sum, variance, 
values[i]);
+    if (nullBitmap != null && !nullBitmap.isEmpty()) {
+      for (int i = 0; i < length; i++) {
+        if (!nullBitmap.contains(i)) {
+          count++;
+          sum += values[i];
+          if (count > 1) {
+            variance = computeIntermediateVariance(count, sum, variance, 
values[i]);
+          }
+        }
+      }
+    } else {
+      for (int i = 0; i < length; i++) {
+        count++;
+        sum += values[i];
+        if (count > 1) {
+          variance = computeIntermediateVariance(count, sum, variance, 
values[i]);
+        }
       }
     }
-    setAggregationResult(aggregationResultHolder, length, sum, variance);
+
+    if (_nullHandlingEnabled && count == 0) {
+      return;
+    }
+    setAggregationResult(aggregationResultHolder, count, sum, variance);
   }
 
   private double computeIntermediateVariance(long count, double sum, double 
m2, double value) {
@@ -116,8 +139,20 @@ public class VarianceAggregationFunction extends BaseSingleInputAggregationFunct
   public void aggregateGroupBySV(int length, int[] groupKeyArray, 
GroupByResultHolder groupByResultHolder,
       Map<ExpressionContext, BlockValSet> blockValSetMap) {
     double[] values = 
StatisticalAggregationFunctionUtils.getValSet(blockValSetMap, _expression);
-    for (int i = 0; i < length; i++) {
-      setGroupByResult(groupKeyArray[i], groupByResultHolder, 1L, values[i], 
0.0);
+    RoaringBitmap nullBitmap = null;
+    if (_nullHandlingEnabled) {
+      nullBitmap = blockValSetMap.get(_expression).getNullBitmap();
+    }
+    if (nullBitmap != null && !nullBitmap.isEmpty()) {
+      for (int i = 0; i < length; i++) {
+        if (!nullBitmap.contains(i)) {
+          setGroupByResult(groupKeyArray[i], groupByResultHolder, 1L, 
values[i], 0.0);
+        }
+      }
+    } else {
+      for (int i = 0; i < length; i++) {
+        setGroupByResult(groupKeyArray[i], groupByResultHolder, 1L, values[i], 
0.0);
+      }
     }
   }
 
@@ -125,9 +160,23 @@ public class VarianceAggregationFunction extends BaseSingleInputAggregationFunct
   public void aggregateGroupByMV(int length, int[][] groupKeysArray, 
GroupByResultHolder groupByResultHolder,
       Map<ExpressionContext, BlockValSet> blockValSetMap) {
     double[] values = 
StatisticalAggregationFunctionUtils.getValSet(blockValSetMap, _expression);
-    for (int i = 0; i < length; i++) {
-      for (int groupKey : groupKeysArray[i]) {
-        setGroupByResult(groupKey, groupByResultHolder, 1L, values[i], 0.0);
+    RoaringBitmap nullBitmap = null;
+    if (_nullHandlingEnabled) {
+      nullBitmap = blockValSetMap.get(_expression).getNullBitmap();
+    }
+    if (nullBitmap != null && !nullBitmap.isEmpty()) {
+      for (int i = 0; i < length; i++) {
+        if (!nullBitmap.contains(i)) {
+          for (int groupKey : groupKeysArray[i]) {
+            setGroupByResult(groupKey, groupByResultHolder, 1L, values[i], 
0.0);
+          }
+        }
+      }
+    } else {
+      for (int i = 0; i < length; i++) {
+        for (int groupKey : groupKeysArray[i]) {
+          setGroupByResult(groupKey, groupByResultHolder, 1L, values[i], 0.0);
+        }
       }
     }
   }
@@ -136,7 +185,7 @@ public class VarianceAggregationFunction extends BaseSingleInputAggregationFunct
   public VarianceTuple extractAggregationResult(AggregationResultHolder 
aggregationResultHolder) {
     VarianceTuple varianceTuple = aggregationResultHolder.getResult();
     if (varianceTuple == null) {
-      return new VarianceTuple(0L, 0.0, 0.0);
+      return _nullHandlingEnabled ? null : new VarianceTuple(0L, 0.0, 0.0);
     } else {
       return varianceTuple;
     }
@@ -149,6 +198,13 @@ public class VarianceAggregationFunction extends BaseSingleInputAggregationFunct
 
   @Override
   public VarianceTuple merge(VarianceTuple intermediateResult1, VarianceTuple 
intermediateResult2) {
+    if (_nullHandlingEnabled) {
+      if (intermediateResult1 == null) {
+        return intermediateResult2;
+      } else if (intermediateResult2 == null) {
+        return intermediateResult1;
+      }
+    }
     intermediateResult1.apply(intermediateResult2);
     return intermediateResult1;
   }
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/NullHandlingEnabledQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/NullHandlingEnabledQueriesTest.java
index b53bfa17fa..9830f35f7e 100644
--- a/pinot-core/src/test/java/org/apache/pinot/queries/NullHandlingEnabledQueriesTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/NullHandlingEnabledQueriesTest.java
@@ -1400,4 +1400,126 @@ public class NullHandlingEnabledQueriesTest extends BaseQueriesTest {
     assertEquals(rows.size(), NUM_OF_SEGMENT_COPIES);
     assertArrayEquals(rows.get(0), new Object[]{null});
   }
+
+  @Test(dataProvider = "NumberTypes")
+  public void testStddevPop(FieldSpec.DataType dataType)
+      throws Exception {
+    initializeRows();
+    insertRow(null);
+    insertRow(1);
+    insertRow(2);
+    TableConfig tableConfig = new 
TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+    Schema schema = new 
Schema.SchemaBuilder().addSingleValueDimension(COLUMN1, dataType).build();
+    setUpSegments(tableConfig, schema);
+    String query = String.format("SELECT STDDEV_POP(%s) FROM testTable", 
COLUMN1);
+
+    BrokerResponseNative brokerResponse = getBrokerResponse(query, 
QUERY_OPTIONS);
+
+    ResultTable resultTable = brokerResponse.getResultTable();
+    List<Object[]> rows = resultTable.getRows();
+    assertEquals(rows.size(), 1);
+    assertEquals(rows.get(0)[0], 0.5);
+  }
+
+  @Test(dataProvider = "NumberTypes")
+  public void testGroupByStddevPop(FieldSpec.DataType dataType)
+      throws Exception {
+    initializeRows();
+    insertRowWithTwoColumns(null, "key");
+    insertRowWithTwoColumns(1, "key");
+    insertRowWithTwoColumns(2, "key");
+    TableConfig tableConfig = new 
TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+    Schema schema = new 
Schema.SchemaBuilder().addSingleValueDimension(COLUMN1, dataType)
+        .addSingleValueDimension(COLUMN2, FieldSpec.DataType.STRING).build();
+    setUpSegments(tableConfig, schema);
+    String query = String.format("SELECT STDDEV_POP(%s), %s FROM testTable 
GROUP BY %s", COLUMN1, COLUMN2, COLUMN2);
+
+    BrokerResponseNative brokerResponse = getBrokerResponse(query, 
QUERY_OPTIONS);
+
+    ResultTable resultTable = brokerResponse.getResultTable();
+    List<Object[]> rows = resultTable.getRows();
+    assertEquals(rows.size(), 1);
+    assertArrayEquals(rows.get(0), new Object[]{0.5, "key"});
+  }
+
+  @Test(dataProvider = "NumberTypes")
+  public void testGroupByMvStddevPop(FieldSpec.DataType dataType)
+      throws Exception {
+    initializeRows();
+    insertRowWithTwoColumns(null, new String[]{"key1", "key2"});
+    insertRowWithTwoColumns(1, new String[]{"key1", "key2"});
+    insertRowWithTwoColumns(2, new String[]{"key1"});
+    TableConfig tableConfig = new 
TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+    Schema schema = new 
Schema.SchemaBuilder().addSingleValueDimension(COLUMN1, dataType)
+        .addMultiValueDimension(COLUMN2, FieldSpec.DataType.STRING).build();
+    setUpSegments(tableConfig, schema);
+    String query =
+        String.format("SELECT STDDEV_POP(%s), %s FROM testTable GROUP BY %s 
ORDER BY %s", COLUMN1, COLUMN2, COLUMN2,
+            COLUMN2);
+
+    BrokerResponseNative brokerResponse = getBrokerResponse(query, 
QUERY_OPTIONS);
+
+    ResultTable resultTable = brokerResponse.getResultTable();
+    List<Object[]> rows = resultTable.getRows();
+    assertEquals(rows.size(), 2);
+    assertArrayEquals(rows.get(0), new Object[]{0.5, "key1"});
+    assertArrayEquals(rows.get(1), new Object[]{0.0, "key2"});
+  }
+
+  @Test
+  public void testAllNullGroupByStddevPopReturnsNull()
+      throws Exception {
+    initializeRows();
+    insertRowWithTwoColumns(null, "key1");
+    TableConfig tableConfig = new 
TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+    Schema schema = new 
Schema.SchemaBuilder().addSingleValueDimension(COLUMN1, FieldSpec.DataType.INT)
+        .addSingleValueDimension(COLUMN2, FieldSpec.DataType.STRING).build();
+    setUpSegments(tableConfig, schema);
+    String query =
+        String.format("SELECT STDDEV_POP(%s), %s FROM testTable GROUP BY %s 
ORDER BY %s", COLUMN1, COLUMN2, COLUMN2,
+            COLUMN2);
+
+    BrokerResponseNative brokerResponse = getBrokerResponse(query, 
QUERY_OPTIONS);
+
+    ResultTable resultTable = brokerResponse.getResultTable();
+    List<Object[]> rows = resultTable.getRows();
+    assertEquals(rows.size(), 1);
+    assertEquals(rows.get(0)[0], null);
+  }
+
+  @Test
+  public void testAllNullStddevPopReturnsNull()
+      throws Exception {
+    initializeRows();
+    insertRow(null);
+    TableConfig tableConfig = new 
TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+    Schema schema = new 
Schema.SchemaBuilder().addSingleValueDimension(COLUMN1, 
FieldSpec.DataType.DOUBLE).build();
+    setUpSegments(tableConfig, schema);
+    String query = String.format("SELECT STDDEV_POP(%s) FROM testTable", 
COLUMN1);
+
+    BrokerResponseNative brokerResponse = getBrokerResponse(query, 
QUERY_OPTIONS);
+
+    ResultTable resultTable = brokerResponse.getResultTable();
+    List<Object[]> rows = resultTable.getRows();
+    assertEquals(rows.size(), 1);
+    assertEquals(rows.get(0)[0], null);
+  }
+
+  @Test
+  public void testNoMatchingRowNullHandlingDisabledStddevPopReturnsNull()
+      throws Exception {
+    initializeRows();
+    insertRow(1);
+    TableConfig tableConfig = new 
TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+    Schema schema = new 
Schema.SchemaBuilder().addSingleValueDimension(COLUMN1, 
FieldSpec.DataType.DOUBLE).build();
+    setUpSegments(tableConfig, schema);
+    String query = String.format("SELECT STDDEV_POP(%s) FROM testTable WHERE 
%s != 1", COLUMN1, COLUMN1);
+
+    BrokerResponseNative brokerResponse = getBrokerResponse(query);
+
+    ResultTable resultTable = brokerResponse.getResultTable();
+    List<Object[]> rows = resultTable.getRows();
+    assertEquals(rows.size(), 1);
+    assertEquals(rows.get(0)[0], Double.NEGATIVE_INFINITY);
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org

Reply via email to