This is an automated email from the ASF dual-hosted git repository.
xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 8de6fa9b9a2 Enhance ArrayAgg to support aggregation on multi-value columns (#17153)
8de6fa9b9a2 is described below
commit 8de6fa9b9a24c4949307ba78d7f3894eaf59fe04
Author: Xiang Fu <[email protected]>
AuthorDate: Sun Nov 9 19:44:36 2025 -0800
Enhance ArrayAgg to support aggregation on multi-value columns (#17153)
---
.../array/ArrayAggDistinctDoubleFunction.java | 25 +-
.../array/ArrayAggDistinctFloatFunction.java | 25 +-
.../array/ArrayAggDistinctIntFunction.java | 25 +-
.../array/ArrayAggDistinctLongFunction.java | 25 +-
.../array/ArrayAggDistinctStringFunction.java | 21 +-
.../function/array/ArrayAggDoubleFunction.java | 25 +-
.../function/array/ArrayAggFloatFunction.java | 25 +-
.../function/array/ArrayAggIntFunction.java | 25 +-
.../function/array/ArrayAggLongFunction.java | 25 +-
.../function/array/ArrayAggStringFunction.java | 21 +-
.../function/array/BaseArrayAggDoubleFunction.java | 58 ++--
.../function/array/BaseArrayAggFloatFunction.java | 58 ++--
.../function/array/BaseArrayAggIntFunction.java | 58 ++--
.../function/array/BaseArrayAggLongFunction.java | 58 ++--
.../function/array/BaseArrayAggStringFunction.java | 58 ++--
.../function/ArrayAggMvFunctionTest.java | 299 +++++++++++++++++++++
.../pinot/queries/ArrayAggMvQueriesTest.java | 171 ++++++++++++
.../pinot/integration/tests/custom/ArrayTest.java | 73 ++++-
.../pinot/segment/spi/AggregationFunctionType.java | 21 +-
19 files changed, 950 insertions(+), 146 deletions(-)
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctDoubleFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctDoubleFunction.java
index 131a67c62fa..2136d38dd5e 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctDoubleFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctDoubleFunction.java
@@ -38,15 +38,26 @@ public class ArrayAggDistinctDoubleFunction extends BaseArrayAggDoubleFunction<D
public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- double[] value = blockValSet.getDoubleValuesSV();
DoubleOpenHashSet valueSet = aggregationResultHolder.getResult() != null ?
aggregationResultHolder.getResult()
: new DoubleOpenHashSet(length);
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- valueSet.add(value[i]);
- }
- });
+ if (blockValSet.isSingleValue()) {
+ double[] values = blockValSet.getDoubleValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ valueSet.add(values[i]);
+ }
+ });
+ } else {
+ double[][] valuesArray = blockValSet.getDoubleValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ double[] values = valuesArray[i];
+ for (double v : values) {
+ valueSet.add(v);
+ }
+ }
+ });
+ }
aggregationResultHolder.setValue(valueSet);
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctFloatFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctFloatFunction.java
index bd65decab01..f3adc68881e 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctFloatFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctFloatFunction.java
@@ -38,15 +38,26 @@ public class ArrayAggDistinctFloatFunction extends BaseArrayAggFloatFunction<Flo
public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- float[] value = blockValSet.getFloatValuesSV();
FloatOpenHashSet valueSet = aggregationResultHolder.getResult() != null ?
aggregationResultHolder.getResult()
: new FloatOpenHashSet(length);
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- valueSet.add(value[i]);
- }
- });
+ if (blockValSet.isSingleValue()) {
+ float[] values = blockValSet.getFloatValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ valueSet.add(values[i]);
+ }
+ });
+ } else {
+ float[][] valuesArray = blockValSet.getFloatValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ float[] values = valuesArray[i];
+ for (float v : values) {
+ valueSet.add(v);
+ }
+ }
+ });
+ }
aggregationResultHolder.setValue(valueSet);
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctIntFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctIntFunction.java
index 1e8a4a9f350..deef27763ad 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctIntFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctIntFunction.java
@@ -40,15 +40,26 @@ public class ArrayAggDistinctIntFunction extends BaseArrayAggIntFunction<IntSet>
public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- int[] value = blockValSet.getIntValuesSV();
IntOpenHashSet valueSet =
aggregationResultHolder.getResult() != null ?
aggregationResultHolder.getResult() : new IntOpenHashSet(length);
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- valueSet.add(value[i]);
- }
- });
+ if (blockValSet.isSingleValue()) {
+ int[] values = blockValSet.getIntValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ valueSet.add(values[i]);
+ }
+ });
+ } else {
+ int[][] valuesArray = blockValSet.getIntValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int[] values = valuesArray[i];
+ for (int v : values) {
+ valueSet.add(v);
+ }
+ }
+ });
+ }
aggregationResultHolder.setValue(valueSet);
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctLongFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctLongFunction.java
index 78a1b0e95cc..b8e0de6ff2b 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctLongFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctLongFunction.java
@@ -40,15 +40,26 @@ public class ArrayAggDistinctLongFunction extends BaseArrayAggLongFunction<LongS
public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- long[] value = blockValSet.getLongValuesSV();
LongOpenHashSet valueSet =
aggregationResultHolder.getResult() != null ?
aggregationResultHolder.getResult() : new LongOpenHashSet(length);
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- valueSet.add(value[i]);
- }
- });
+ if (blockValSet.isSingleValue()) {
+ long[] values = blockValSet.getLongValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ valueSet.add(values[i]);
+ }
+ });
+ } else {
+ long[][] valuesArray = blockValSet.getLongValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ long[] values = valuesArray[i];
+ for (long v : values) {
+ valueSet.add(v);
+ }
+ }
+ });
+ }
aggregationResultHolder.setValue(valueSet);
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctStringFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctStringFunction.java
index 533cf89ceb2..5a47ee75a67 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctStringFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDistinctStringFunction.java
@@ -20,7 +20,6 @@ package org.apache.pinot.core.query.aggregation.function.array;
import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet;
import it.unimi.dsi.fastutil.objects.ObjectSet;
-import java.util.Arrays;
import java.util.Map;
import org.apache.pinot.common.CustomObject;
import org.apache.pinot.common.request.context.ExpressionContext;
@@ -39,11 +38,27 @@ public class ArrayAggDistinctStringFunction extends BaseArrayAggStringFunction<O
public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- String[] value = blockValSet.getStringValuesSV();
ObjectOpenHashSet<String> valueSet =
aggregationResultHolder.getResult() != null ?
aggregationResultHolder.getResult()
: new ObjectOpenHashSet<>(length);
- forEachNotNull(length, blockValSet, (from, to) ->
valueSet.addAll(Arrays.asList(value).subList(from, to)));
+ if (blockValSet.isSingleValue()) {
+ String[] values = blockValSet.getStringValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ valueSet.add(values[i]);
+ }
+ });
+ } else {
+ String[][] valuesArray = blockValSet.getStringValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ String[] values = valuesArray[i];
+ for (String v : values) {
+ valueSet.add(v);
+ }
+ }
+ });
+ }
aggregationResultHolder.setValue(valueSet);
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDoubleFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDoubleFunction.java
index a0fbe7945e5..afc5a6b250a 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDoubleFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggDoubleFunction.java
@@ -37,15 +37,26 @@ public class ArrayAggDoubleFunction extends BaseArrayAggDoubleFunction<DoubleArr
public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- double[] value = blockValSet.getDoubleValuesSV();
DoubleArrayList valueArray =
aggregationResultHolder.getResult() != null ?
aggregationResultHolder.getResult() : new DoubleArrayList(length);
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- valueArray.add(value[i]);
- }
- });
+ if (blockValSet.isSingleValue()) {
+ double[] values = blockValSet.getDoubleValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ valueArray.add(values[i]);
+ }
+ });
+ } else {
+ double[][] valuesArray = blockValSet.getDoubleValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ double[] values = valuesArray[i];
+ for (double v : values) {
+ valueArray.add(v);
+ }
+ }
+ });
+ }
aggregationResultHolder.setValue(valueArray);
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggFloatFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggFloatFunction.java
index 98c114158a1..e48560412e5 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggFloatFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggFloatFunction.java
@@ -37,15 +37,26 @@ public class ArrayAggFloatFunction extends BaseArrayAggFloatFunction<FloatArrayL
public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- float[] value = blockValSet.getFloatValuesSV();
FloatArrayList valueArray =
aggregationResultHolder.getResult() != null ?
aggregationResultHolder.getResult() : new FloatArrayList(length);
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- valueArray.add(value[i]);
- }
- });
+ if (blockValSet.isSingleValue()) {
+ float[] values = blockValSet.getFloatValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ valueArray.add(values[i]);
+ }
+ });
+ } else {
+ float[][] valuesArray = blockValSet.getFloatValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ float[] values = valuesArray[i];
+ for (float v : values) {
+ valueArray.add(v);
+ }
+ }
+ });
+ }
aggregationResultHolder.setValue(valueArray);
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggIntFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggIntFunction.java
index 7f8d7f8e07b..09f23c816f3 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggIntFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggIntFunction.java
@@ -38,15 +38,26 @@ public class ArrayAggIntFunction extends BaseArrayAggIntFunction<IntArrayList> {
public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- int[] value = blockValSet.getIntValuesSV();
IntArrayList valueArray =
aggregationResultHolder.getResult() != null ?
aggregationResultHolder.getResult() : new IntArrayList(length);
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- valueArray.add(value[i]);
- }
- });
+ if (blockValSet.isSingleValue()) {
+ int[] values = blockValSet.getIntValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ valueArray.add(values[i]);
+ }
+ });
+ } else {
+ int[][] valuesArray = blockValSet.getIntValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int[] values = valuesArray[i];
+ for (int v : values) {
+ valueArray.add(v);
+ }
+ }
+ });
+ }
aggregationResultHolder.setValue(valueArray);
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggLongFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggLongFunction.java
index 2fab4eb166b..6e89ee94a5f 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggLongFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggLongFunction.java
@@ -38,15 +38,26 @@ public class ArrayAggLongFunction extends BaseArrayAggLongFunction<LongArrayList
public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- long[] value = blockValSet.getLongValuesSV();
LongArrayList valueArray =
aggregationResultHolder.getResult() != null ?
aggregationResultHolder.getResult() : new LongArrayList(length);
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- valueArray.add(value[i]);
- }
- });
+ if (blockValSet.isSingleValue()) {
+ long[] values = blockValSet.getLongValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ valueArray.add(values[i]);
+ }
+ });
+ } else {
+ long[][] valuesArray = blockValSet.getLongValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ long[] values = valuesArray[i];
+ for (long v : values) {
+ valueArray.add(v);
+ }
+ }
+ });
+ }
aggregationResultHolder.setValue(valueArray);
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggStringFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggStringFunction.java
index 1556890b16c..cfa4cdd4b08 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggStringFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ArrayAggStringFunction.java
@@ -19,7 +19,6 @@
package org.apache.pinot.core.query.aggregation.function.array;
import it.unimi.dsi.fastutil.objects.ObjectArrayList;
-import java.util.Arrays;
import java.util.Map;
import org.apache.pinot.common.CustomObject;
import org.apache.pinot.common.request.context.ExpressionContext;
@@ -38,11 +37,27 @@ public class ArrayAggStringFunction extends BaseArrayAggStringFunction<ObjectArr
public void aggregate(int length, AggregationResultHolder
aggregationResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- String[] value = blockValSet.getStringValuesSV();
ObjectArrayList<String> valueArray =
aggregationResultHolder.getResult() != null ?
aggregationResultHolder.getResult()
: new ObjectArrayList<>(length);
- forEachNotNull(length, blockValSet, (from, to) ->
valueArray.addAll(Arrays.asList(value).subList(from, to)));
+ if (blockValSet.isSingleValue()) {
+ String[] values = blockValSet.getStringValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ valueArray.add(values[i]);
+ }
+ });
+ } else {
+ String[][] valuesArray = blockValSet.getStringValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ String[] values = valuesArray[i];
+ for (String v : values) {
+ valueArray.add(v);
+ }
+ }
+ });
+ }
aggregationResultHolder.setValue(valueArray);
}
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggDoubleFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggDoubleFunction.java
index 62c683c5808..d4307ed4032 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggDoubleFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggDoubleFunction.java
@@ -39,29 +39,55 @@ public abstract class BaseArrayAggDoubleFunction<I extends DoubleCollection>
public void aggregateGroupBySV(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- double[] values = blockValSet.getDoubleValuesSV();
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]);
- }
- });
+ if (blockValSet.isSingleValue()) {
+ double[] values = blockValSet.getDoubleValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]);
+ }
+ });
+ } else {
+ double[][] valuesArray = blockValSet.getDoubleValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int groupKey = groupKeyArray[i];
+ double[] values = valuesArray[i];
+ for (double v : values) {
+ setGroupByResult(groupByResultHolder, groupKey, v);
+ }
+ }
+ });
+ }
}
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- double[] values = blockValSet.getDoubleValuesSV();
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- int[] groupKeys = groupKeysArray[i];
- for (int groupKey : groupKeys) {
- setGroupByResult(groupByResultHolder, groupKey, values[i]);
+ if (blockValSet.isSingleValue()) {
+ double[] values = blockValSet.getDoubleValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int[] groupKeys = groupKeysArray[i];
+ for (int groupKey : groupKeys) {
+ setGroupByResult(groupByResultHolder, groupKey, values[i]);
+ }
+ }
+ });
+ } else {
+ double[][] valuesArray = blockValSet.getDoubleValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int[] groupKeys = groupKeysArray[i];
+ double[] values = valuesArray[i];
+ for (int groupKey : groupKeys) {
+ for (double v : values) {
+ setGroupByResult(groupByResultHolder, groupKey, v);
+ }
+ }
}
- }
- });
+ });
+ }
}
@Override
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggFloatFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggFloatFunction.java
index 78191c9349a..47dbf7e2c97 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggFloatFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggFloatFunction.java
@@ -39,29 +39,55 @@ public abstract class BaseArrayAggFloatFunction<I extends FloatCollection>
public void aggregateGroupBySV(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- float[] values = blockValSet.getFloatValuesSV();
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]);
- }
- });
+ if (blockValSet.isSingleValue()) {
+ float[] values = blockValSet.getFloatValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]);
+ }
+ });
+ } else {
+ float[][] valuesArray = blockValSet.getFloatValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int groupKey = groupKeyArray[i];
+ float[] values = valuesArray[i];
+ for (float v : values) {
+ setGroupByResult(groupByResultHolder, groupKey, v);
+ }
+ }
+ });
+ }
}
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- float[] values = blockValSet.getFloatValuesSV();
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- int[] groupKeys = groupKeysArray[i];
- for (int groupKey : groupKeys) {
- setGroupByResult(groupByResultHolder, groupKey, values[i]);
+ if (blockValSet.isSingleValue()) {
+ float[] values = blockValSet.getFloatValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int[] groupKeys = groupKeysArray[i];
+ for (int groupKey : groupKeys) {
+ setGroupByResult(groupByResultHolder, groupKey, values[i]);
+ }
+ }
+ });
+ } else {
+ float[][] valuesArray = blockValSet.getFloatValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int[] groupKeys = groupKeysArray[i];
+ float[] values = valuesArray[i];
+ for (int groupKey : groupKeys) {
+ for (float v : values) {
+ setGroupByResult(groupByResultHolder, groupKey, v);
+ }
+ }
}
- }
- });
+ });
+ }
}
@Override
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggIntFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggIntFunction.java
index e29e0cc73a6..808e4a5eb9d 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggIntFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggIntFunction.java
@@ -40,29 +40,55 @@ public abstract class BaseArrayAggIntFunction<I extends IntCollection>
public void aggregateGroupBySV(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- int[] values = blockValSet.getIntValuesSV();
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]);
- }
- });
+ if (blockValSet.isSingleValue()) {
+ int[] values = blockValSet.getIntValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]);
+ }
+ });
+ } else {
+ int[][] valuesArray = blockValSet.getIntValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int groupKey = groupKeyArray[i];
+ int[] values = valuesArray[i];
+ for (int v : values) {
+ setGroupByResult(groupByResultHolder, groupKey, v);
+ }
+ }
+ });
+ }
}
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- int[] values = blockValSet.getIntValuesSV();
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- int[] groupKeys = groupKeysArray[i];
- for (int groupKey : groupKeys) {
- setGroupByResult(groupByResultHolder, groupKey, values[i]);
+ if (blockValSet.isSingleValue()) {
+ int[] values = blockValSet.getIntValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int[] groupKeys = groupKeysArray[i];
+ for (int groupKey : groupKeys) {
+ setGroupByResult(groupByResultHolder, groupKey, values[i]);
+ }
+ }
+ });
+ } else {
+ int[][] valuesArray = blockValSet.getIntValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int[] groupKeys = groupKeysArray[i];
+ int[] values = valuesArray[i];
+ for (int groupKey : groupKeys) {
+ for (int v : values) {
+ setGroupByResult(groupByResultHolder, groupKey, v);
+ }
+ }
}
- }
- });
+ });
+ }
}
@Override
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggLongFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggLongFunction.java
index 165457c1ac1..2ae70253a7f 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggLongFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggLongFunction.java
@@ -40,29 +40,55 @@ public abstract class BaseArrayAggLongFunction<I extends LongCollection>
public void aggregateGroupBySV(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- long[] values = blockValSet.getLongValuesSV();
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]);
- }
- });
+ if (blockValSet.isSingleValue()) {
+ long[] values = blockValSet.getLongValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]);
+ }
+ });
+ } else {
+ long[][] valuesArray = blockValSet.getLongValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int groupKey = groupKeyArray[i];
+ long[] values = valuesArray[i];
+ for (long v : values) {
+ setGroupByResult(groupByResultHolder, groupKey, v);
+ }
+ }
+ });
+ }
}
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- long[] values = blockValSet.getLongValuesSV();
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- int[] groupKeys = groupKeysArray[i];
- for (int groupKey : groupKeys) {
- setGroupByResult(groupByResultHolder, groupKey, values[i]);
+ if (blockValSet.isSingleValue()) {
+ long[] values = blockValSet.getLongValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int[] groupKeys = groupKeysArray[i];
+ for (int groupKey : groupKeys) {
+ setGroupByResult(groupByResultHolder, groupKey, values[i]);
+ }
+ }
+ });
+ } else {
+ long[][] valuesArray = blockValSet.getLongValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int[] groupKeys = groupKeysArray[i];
+ long[] values = valuesArray[i];
+ for (int groupKey : groupKeys) {
+ for (long v : values) {
+ setGroupByResult(groupByResultHolder, groupKey, v);
+ }
+ }
}
- }
- });
+ });
+ }
}
@Override
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggStringFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggStringFunction.java
index 1f4d790415b..fd384c6a55a 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggStringFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/BaseArrayAggStringFunction.java
@@ -40,29 +40,55 @@ public abstract class BaseArrayAggStringFunction<I extends ObjectCollection<Stri
public void aggregateGroupBySV(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- String[] values = blockValSet.getStringValuesSV();
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]);
- }
- });
+ if (blockValSet.isSingleValue()) {
+ String[] values = blockValSet.getStringValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ setGroupByResult(groupByResultHolder, groupKeyArray[i], values[i]);
+ }
+ });
+ } else {
+ String[][] valuesArray = blockValSet.getStringValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int groupKey = groupKeyArray[i];
+ String[] values = valuesArray[i];
+ for (String v : values) {
+ setGroupByResult(groupByResultHolder, groupKey, v);
+ }
+ }
+ });
+ }
}
+ /**
+ * Group-by on a multi-value group key: row i belongs to every group listed in groupKeysArray[i]. Its value
+ * (single-value input) or each entry of its array (multi-value input) is recorded in each of those groups via
+ * setGroupByResult.
+ */
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- String[] values = blockValSet.getStringValuesSV();
-
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- int[] groupKeys = groupKeysArray[i];
- for (int groupKey : groupKeys) {
- setGroupByResult(groupByResultHolder, groupKey, values[i]);
+ if (blockValSet.isSingleValue()) {
+ // Single-value input: values[i] is recorded once in each of row i's groups.
+ String[] values = blockValSet.getStringValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int[] groupKeys = groupKeysArray[i];
+ for (int groupKey : groupKeys) {
+ setGroupByResult(groupByResultHolder, groupKey, values[i]);
+ }
+ }
+ });
+ } else {
+ // Multi-value input: every (group of row i, entry of row i's array) pair is recorded.
+ String[][] valuesArray = blockValSet.getStringValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int[] groupKeys = groupKeysArray[i];
+ String[] values = valuesArray[i];
+ for (int groupKey : groupKeys) {
+ for (String v : values) {
+ setGroupByResult(groupByResultHolder, groupKey, v);
+ }
+ }
}
- }
- });
+ });
+ }
}
@Override
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/ArrayAggMvFunctionTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/ArrayAggMvFunctionTest.java
new file mode 100644
index 00000000000..d6b5ece48f5
--- /dev/null
+++
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/ArrayAggMvFunctionTest.java
@@ -0,0 +1,299 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
+import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet;
+import it.unimi.dsi.fastutil.floats.FloatArrayList;
+import it.unimi.dsi.fastutil.floats.FloatOpenHashSet;
+import it.unimi.dsi.fastutil.ints.IntArrayList;
+import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
+import it.unimi.dsi.fastutil.longs.LongArrayList;
+import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
+import it.unimi.dsi.fastutil.objects.ObjectArrayList;
+import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet;
+import java.util.Map;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.core.common.ObjectSerDeUtils;
+import org.apache.pinot.core.common.SyntheticBlockValSets;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import
org.apache.pinot.core.query.aggregation.function.array.ArrayAggDistinctDoubleFunction;
+import
org.apache.pinot.core.query.aggregation.function.array.ArrayAggDistinctFloatFunction;
+import
org.apache.pinot.core.query.aggregation.function.array.ArrayAggDistinctIntFunction;
+import
org.apache.pinot.core.query.aggregation.function.array.ArrayAggDistinctLongFunction;
+import
org.apache.pinot.core.query.aggregation.function.array.ArrayAggDistinctStringFunction;
+import
org.apache.pinot.core.query.aggregation.function.array.ArrayAggDoubleFunction;
+import
org.apache.pinot.core.query.aggregation.function.array.ArrayAggFloatFunction;
+import
org.apache.pinot.core.query.aggregation.function.array.ArrayAggIntFunction;
+import
org.apache.pinot.core.query.aggregation.function.array.ArrayAggLongFunction;
+import
org.apache.pinot.core.query.aggregation.function.array.ArrayAggStringFunction;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import
org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+
+
+/**
+ * Unit tests for the ArrayAgg and distinct ArrayAgg aggregation functions fed with multi-value (MV) input blocks.
+ * They cover aggregation across several blocks, group-by on single-value and multi-value group keys, and ser/de of
+ * intermediate results, asserting exact contents rather than only sizes.
+ */
+public class ArrayAggMvFunctionTest extends AbstractAggregationFunctionTest {
+  // Every test aggregates the same identifier, so build the expression once.
+  private static final ExpressionContext FIELD_EXPRESSION = ExpressionContext.forIdentifier("myField");
+
+  /** Synthetic MV DOUBLE block that serves a fixed {@code double[][]}. */
+  private static final class TestDoubleMVBlock extends SyntheticBlockValSets.Base {
+    private final double[][] _values;
+
+    TestDoubleMVBlock(double[][] values) {
+      _values = values;
+    }
+
+    @Override
+    public boolean isSingleValue() {
+      return false;
+    }
+
+    @Override
+    public double[][] getDoubleValuesMV() {
+      return _values;
+    }
+
+    @Override
+    public FieldSpec.DataType getValueType() {
+      return FieldSpec.DataType.DOUBLE;
+    }
+  }
+
+  /** Synthetic MV LONG block that serves a fixed {@code long[][]}. */
+  private static final class TestLongMVBlock extends SyntheticBlockValSets.Base {
+    private final long[][] _values;
+
+    TestLongMVBlock(long[][] values) {
+      _values = values;
+    }
+
+    @Override
+    public boolean isSingleValue() {
+      return false;
+    }
+
+    @Override
+    public long[][] getLongValuesMV() {
+      return _values;
+    }
+
+    @Override
+    public FieldSpec.DataType getValueType() {
+      return FieldSpec.DataType.LONG;
+    }
+  }
+
+  /** Synthetic MV INT block that serves a fixed {@code int[][]}. */
+  private static final class TestIntMVBlock extends SyntheticBlockValSets.Base {
+    private final int[][] _values;
+
+    TestIntMVBlock(int[][] values) {
+      _values = values;
+    }
+
+    @Override
+    public boolean isSingleValue() {
+      return false;
+    }
+
+    @Override
+    public int[][] getIntValuesMV() {
+      return _values;
+    }
+
+    @Override
+    public FieldSpec.DataType getValueType() {
+      return FieldSpec.DataType.INT;
+    }
+  }
+
+  /** Synthetic MV FLOAT block that serves a fixed {@code float[][]}. */
+  private static final class TestFloatMVBlock extends SyntheticBlockValSets.Base {
+    private final float[][] _values;
+
+    TestFloatMVBlock(float[][] values) {
+      _values = values;
+    }
+
+    @Override
+    public boolean isSingleValue() {
+      return false;
+    }
+
+    @Override
+    public float[][] getFloatValuesMV() {
+      return _values;
+    }
+
+    @Override
+    public FieldSpec.DataType getValueType() {
+      return FieldSpec.DataType.FLOAT;
+    }
+  }
+
+  /** Synthetic MV STRING block that serves a fixed {@code String[][]}. */
+  private static final class TestStringMVBlock extends SyntheticBlockValSets.Base {
+    private final String[][] _values;
+
+    TestStringMVBlock(String[][] values) {
+      _values = values;
+    }
+
+    @Override
+    public boolean isSingleValue() {
+      return false;
+    }
+
+    @Override
+    public String[][] getStringValuesMV() {
+      return _values;
+    }
+
+    @Override
+    public FieldSpec.DataType getValueType() {
+      return FieldSpec.DataType.STRING;
+    }
+  }
+
+  @Test
+  public void testDoubleArrayAggMvMultipleBlocks() {
+    ArrayAggDistinctDoubleFunction distinctFn = new ArrayAggDistinctDoubleFunction(FIELD_EXPRESSION, false);
+    AggregationResultHolder holder = distinctFn.createAggregationResultHolder();
+    distinctFn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestDoubleMVBlock(new double[][]{{1.0, 2.0}, {2.0}})));
+    distinctFn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestDoubleMVBlock(new double[][]{{2.0, 3.0}, {3.0}})));
+    DoubleOpenHashSet distinct = holder.getResult();
+    assertEquals(distinct, new DoubleOpenHashSet(new double[]{1.0, 2.0, 3.0}));
+
+    ArrayAggDoubleFunction fn = new ArrayAggDoubleFunction(FIELD_EXPRESSION, false);
+    holder = fn.createAggregationResultHolder();
+    fn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestDoubleMVBlock(new double[][]{{1.0, 2.0}, {2.0}})));
+    fn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestDoubleMVBlock(new double[][]{{2.0, 3.0}, {3.0}})));
+    DoubleArrayList result = holder.getResult();
+    // Non-distinct aggregation keeps every MV entry, in row order, across both blocks.
+    assertEquals(result, new DoubleArrayList(new double[]{1.0, 2.0, 2.0, 2.0, 3.0, 3.0}));
+
+    // Round-trip the intermediate result through ser/de; the contents must survive unchanged.
+    AggregationFunction.SerializedIntermediateResult ser = fn.serializeIntermediateResult(result);
+    DoubleArrayList deser =
+        ObjectSerDeUtils.deserialize(ser.getBytes(), ObjectSerDeUtils.ObjectType.DoubleArrayList);
+    assertEquals(deser, result);
+  }
+
+  @Test
+  public void testLongArrayAggMvMultipleBlocks() {
+    ArrayAggDistinctLongFunction distinctFn =
+        new ArrayAggDistinctLongFunction(FIELD_EXPRESSION, FieldSpec.DataType.LONG, false);
+    AggregationResultHolder holder = distinctFn.createAggregationResultHolder();
+    distinctFn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestLongMVBlock(new long[][]{{1L, 2L}, {2L}})));
+    distinctFn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestLongMVBlock(new long[][]{{2L, 3L}, {3L}})));
+    LongOpenHashSet distinct = holder.getResult();
+    assertEquals(distinct, new LongOpenHashSet(new long[]{1L, 2L, 3L}));
+
+    ArrayAggLongFunction fn = new ArrayAggLongFunction(FIELD_EXPRESSION, FieldSpec.DataType.LONG, false);
+    holder = fn.createAggregationResultHolder();
+    fn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestLongMVBlock(new long[][]{{1L, 2L}, {2L}})));
+    fn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestLongMVBlock(new long[][]{{2L, 3L}, {3L}})));
+    LongArrayList result = holder.getResult();
+    assertEquals(result, new LongArrayList(new long[]{1L, 2L, 2L, 2L, 3L, 3L}));
+
+    // Group-by on a single-value group key: all of row i's values land in group groupKeys[i].
+    GroupByResultHolder svGroupHolder = new ObjectGroupByResultHolder(4, 4);
+    fn.aggregateGroupBySV(2, new int[]{0, 1}, svGroupHolder,
+        Map.of(FIELD_EXPRESSION, new TestLongMVBlock(new long[][]{{5L}, {6L, 7L}})));
+    assertEquals((LongArrayList) svGroupHolder.getResult(0), new LongArrayList(new long[]{5L}));
+    assertEquals((LongArrayList) svGroupHolder.getResult(1), new LongArrayList(new long[]{6L, 7L}));
+
+    // Group-by on a multi-value group key with distinct semantics: row i's values land in every group of row i,
+    // and duplicates within a group collapse.
+    GroupByResultHolder mvGroupHolder = new ObjectGroupByResultHolder(4, 4);
+    distinctFn.aggregateGroupByMV(2, new int[][]{{0, 1}, {1}}, mvGroupHolder,
+        Map.of(FIELD_EXPRESSION, new TestLongMVBlock(new long[][]{{5L, 5L}, {6L}})));
+    assertEquals((LongOpenHashSet) mvGroupHolder.getResult(0), new LongOpenHashSet(new long[]{5L}));
+    assertEquals((LongOpenHashSet) mvGroupHolder.getResult(1), new LongOpenHashSet(new long[]{5L, 6L}));
+  }
+
+  @Test
+  public void testIntArrayAggMvMultipleBlocks() {
+    ArrayAggDistinctIntFunction distinctFn =
+        new ArrayAggDistinctIntFunction(FIELD_EXPRESSION, FieldSpec.DataType.INT, false);
+    AggregationResultHolder holder = distinctFn.createAggregationResultHolder();
+    distinctFn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestIntMVBlock(new int[][]{{1, 2}, {2}})));
+    distinctFn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestIntMVBlock(new int[][]{{2, 3}, {3}})));
+    IntOpenHashSet distinct = holder.getResult();
+    assertEquals(distinct, new IntOpenHashSet(new int[]{1, 2, 3}));
+
+    ArrayAggIntFunction fn = new ArrayAggIntFunction(FIELD_EXPRESSION, FieldSpec.DataType.INT, false);
+    holder = fn.createAggregationResultHolder();
+    fn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestIntMVBlock(new int[][]{{1, 2}, {2}})));
+    fn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestIntMVBlock(new int[][]{{2, 3}, {3}})));
+    IntArrayList result = holder.getResult();
+    assertEquals(result, new IntArrayList(new int[]{1, 2, 2, 2, 3, 3}));
+  }
+
+  @Test
+  public void testFloatArrayAggMvMultipleBlocks() {
+    ArrayAggDistinctFloatFunction distinctFn = new ArrayAggDistinctFloatFunction(FIELD_EXPRESSION, false);
+    AggregationResultHolder holder = distinctFn.createAggregationResultHolder();
+    distinctFn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestFloatMVBlock(new float[][]{{1.0f, 2.0f}, {2.0f}})));
+    distinctFn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestFloatMVBlock(new float[][]{{2.0f, 3.0f}, {3.0f}})));
+    FloatOpenHashSet distinct = holder.getResult();
+    assertEquals(distinct, new FloatOpenHashSet(new float[]{1.0f, 2.0f, 3.0f}));
+
+    ArrayAggFloatFunction fn = new ArrayAggFloatFunction(FIELD_EXPRESSION, false);
+    holder = fn.createAggregationResultHolder();
+    fn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestFloatMVBlock(new float[][]{{1.0f, 2.0f}, {2.0f}})));
+    fn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestFloatMVBlock(new float[][]{{2.0f, 3.0f}, {3.0f}})));
+    FloatArrayList result = holder.getResult();
+    assertEquals(result, new FloatArrayList(new float[]{1.0f, 2.0f, 2.0f, 2.0f, 3.0f, 3.0f}));
+  }
+
+  @Test
+  public void testStringArrayAggMvMultipleBlocks() {
+    ArrayAggDistinctStringFunction distinctFn = new ArrayAggDistinctStringFunction(FIELD_EXPRESSION, false);
+    AggregationResultHolder holder = distinctFn.createAggregationResultHolder();
+    distinctFn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestStringMVBlock(new String[][]{{"A", "B"}, {"B"}})));
+    distinctFn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestStringMVBlock(new String[][]{{"B", "C"}, {"C"}})));
+    ObjectOpenHashSet<String> distinct = holder.getResult();
+    assertEquals(distinct, new ObjectOpenHashSet<String>(new String[]{"A", "B", "C"}));
+
+    ArrayAggStringFunction fn = new ArrayAggStringFunction(FIELD_EXPRESSION, false);
+    holder = fn.createAggregationResultHolder();
+    fn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestStringMVBlock(new String[][]{{"A", "B"}, {"B"}})));
+    fn.aggregate(2, holder,
+        Map.of(FIELD_EXPRESSION, new TestStringMVBlock(new String[][]{{"B", "C"}, {"C"}})));
+    ObjectArrayList<String> result = holder.getResult();
+    assertEquals(result, new ObjectArrayList<String>(new String[]{"A", "B", "B", "B", "C", "C"}));
+
+    // Group-by on a multi-value group key: each of row i's values lands in every group listed for row i.
+    GroupByResultHolder mvGroupHolder = new ObjectGroupByResultHolder(4, 4);
+    fn.aggregateGroupByMV(2, new int[][]{{0, 1}, {1}}, mvGroupHolder,
+        Map.of(FIELD_EXPRESSION, new TestStringMVBlock(new String[][]{{"A", "B"}, {"C"}})));
+    assertEquals((ObjectArrayList<?>) mvGroupHolder.getResult(0),
+        new ObjectArrayList<String>(new String[]{"A", "B"}));
+    assertEquals((ObjectArrayList<?>) mvGroupHolder.getResult(1),
+        new ObjectArrayList<String>(new String[]{"A", "B", "C"}));
+  }
+}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/ArrayAggMvQueriesTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/ArrayAggMvQueriesTest.java
new file mode 100644
index 00000000000..118d27d143c
--- /dev/null
+++
b/pinot-core/src/test/java/org/apache/pinot/queries/ArrayAggMvQueriesTest.java
@@ -0,0 +1,171 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.queries;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Set;
+import org.apache.commons.io.FileUtils;
+import org.apache.pinot.common.response.broker.ResultTable;
+import org.apache.pinot.core.operator.blocks.results.AggregationResultsBlock;
+import org.apache.pinot.core.operator.query.AggregationOperator;
+import
org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoader;
+import
org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl;
+import org.apache.pinot.segment.local.segment.readers.GenericRowRecordReader;
+import org.apache.pinot.segment.spi.ImmutableSegment;
+import org.apache.pinot.segment.spi.IndexSegment;
+import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.utils.ReadMode;
+import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+
+
+/**
+ * Query-level tests for ArrayAgg over multi-value columns of every supported element type, with and without distinct
+ * semantics, for both aggregation-only and group-by queries.
+ */
+public class ArrayAggMvQueriesTest extends BaseQueriesTest {
+  private static final File INDEX_DIR = new File(FileUtils.getTempDirectory(), "ArrayAggMvQueriesTest");
+  private static final String RAW_TABLE_NAME = "testTableMv";
+  private static final String SEGMENT_NAME = "testSegment";
+
+  private static final int NUM_RECORDS = 2000;
+  // Record i gets group-by value (i % NUM_GROUPS), so every group holds NUM_RECORDS / NUM_GROUPS records.
+  private static final int NUM_GROUPS = 10;
+  // Number of ArrayAgg expressions (one per MV column) selected by the aggregation-only queries.
+  private static final int NUM_AGGREGATIONS = 5;
+
+  private static final String INT_MV = "intMV";
+  private static final String LONG_MV = "longMV";
+  private static final String FLOAT_MV = "floatMV";
+  private static final String DOUBLE_MV = "doubleMV";
+  private static final String STRING_MV = "stringMV";
+  private static final String GROUP_BY_COLUMN = "groupKey";
+
+  private static final Schema SCHEMA = new Schema.SchemaBuilder().addMultiValueDimension(INT_MV, DataType.INT)
+      .addMultiValueDimension(LONG_MV, DataType.LONG).addMultiValueDimension(FLOAT_MV, DataType.FLOAT)
+      .addMultiValueDimension(DOUBLE_MV, DataType.DOUBLE).addMultiValueDimension(STRING_MV, DataType.STRING)
+      .addSingleValueDimension(GROUP_BY_COLUMN, DataType.STRING).build();
+
+  private IndexSegment _indexSegment;
+  private List<IndexSegment> _indexSegments;
+
+  @Override
+  protected String getFilter() {
+    return "";
+  }
+
+  @Override
+  protected IndexSegment getIndexSegment() {
+    return _indexSegment;
+  }
+
+  @Override
+  protected List<IndexSegment> getIndexSegments() {
+    return _indexSegments;
+  }
+
+  @BeforeClass
+  public void setUp()
+      throws Exception {
+    FileUtils.deleteDirectory(INDEX_DIR);
+
+    // Every MV row holds two values drawn from the disjoint ranges [0, NUM_RECORDS) and
+    // [NUM_RECORDS + 1, 2 * NUM_RECORDS], so all 2 * NUM_RECORDS values of a column are distinct.
+    List<GenericRow> records = new ArrayList<>(NUM_RECORDS);
+    for (int i = 0; i < NUM_RECORDS; i++) {
+      GenericRow record = new GenericRow();
+      record.putValue(INT_MV, new Integer[]{i, i + NUM_RECORDS + 1});
+      record.putValue(LONG_MV, new Long[]{(long) i, (long) i + NUM_RECORDS + 1});
+      record.putValue(FLOAT_MV, new Float[]{(float) i, (float) i + NUM_RECORDS + 1});
+      record.putValue(DOUBLE_MV, new Double[]{(double) i, (double) i + NUM_RECORDS + 1});
+      record.putValue(STRING_MV, new String[]{Integer.toString(i), Integer.toString(i + NUM_RECORDS + 1)});
+      record.putValue(GROUP_BY_COLUMN, String.valueOf(i % NUM_GROUPS));
+      records.add(record);
+    }
+
+    SegmentGeneratorConfig conf =
+        new SegmentGeneratorConfig(new TableConfigBuilder(org.apache.pinot.spi.config.table.TableType.OFFLINE)
+            .setTableName(RAW_TABLE_NAME).build(), SCHEMA);
+    conf.setTableName(RAW_TABLE_NAME);
+    conf.setSegmentName(SEGMENT_NAME);
+    conf.setOutDir(INDEX_DIR.getPath());
+    SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl();
+    driver.init(conf, new GenericRowRecordReader(records));
+    driver.build();
+
+    ImmutableSegment immutableSegment = ImmutableSegmentLoader.load(new File(INDEX_DIR, SEGMENT_NAME), ReadMode.mmap);
+    _indexSegment = immutableSegment;
+    // The same segment is served twice so that intermediate results from two segments get merged.
+    _indexSegments = Arrays.asList(immutableSegment, immutableSegment);
+  }
+
+  @Test
+  public void testArrayAggMvNonDistinct() {
+    String query = "SELECT ArrayAgg(intMV, 'INT'), ArrayAgg(longMV, 'LONG'), ArrayAgg(floatMV, 'FLOAT'), "
+        + "ArrayAgg(doubleMV, 'DOUBLE'), ArrayAgg(stringMV, 'STRING') FROM testTableMv";
+
+    // Single segment: every record contributes both of its MV values.
+    AggregationOperator aggregationOperator = getOperator(query);
+    AggregationResultsBlock resultsBlock = aggregationOperator.nextBlock();
+    List<Object> aggregationResult = resultsBlock.getResults();
+    assertNotNull(aggregationResult);
+    assertEquals(aggregationResult.size(), NUM_AGGREGATIONS);
+    for (int i = 0; i < NUM_AGGREGATIONS; i++) {
+      assertEquals(((List<?>) aggregationResult.get(i)).size(), 2 * NUM_RECORDS);
+    }
+
+    ResultTable resultTable = getBrokerResponse(query).getResultTable();
+    Object[] row = resultTable.getRows().get(0);
+    assertEquals(row.length, NUM_AGGREGATIONS);
+    // The final result flattens MV values across every segment copy the broker sees; with this setup that is
+    // 8 x NUM_RECORDS (2 values per record, each record counted 4 times).
+    assertEquals(((int[]) row[0]).length, 8 * NUM_RECORDS);
+    assertEquals(((long[]) row[1]).length, 8 * NUM_RECORDS);
+    assertEquals(((float[]) row[2]).length, 8 * NUM_RECORDS);
+    assertEquals(((double[]) row[3]).length, 8 * NUM_RECORDS);
+    assertEquals(((String[]) row[4]).length, 8 * NUM_RECORDS);
+  }
+
+  @Test
+  public void testArrayAggMvDistinct() {
+    String query = "SELECT ArrayAgg(intMV, 'INT', true), ArrayAgg(longMV, 'LONG', true), "
+        + "ArrayAgg(floatMV, 'FLOAT', true), ArrayAgg(doubleMV, 'DOUBLE', true), "
+        + "ArrayAgg(stringMV, 'STRING', true) FROM testTableMv";
+
+    AggregationOperator aggregationOperator = getOperator(query);
+    AggregationResultsBlock resultsBlock = aggregationOperator.nextBlock();
+    List<Object> aggregationResult = resultsBlock.getResults();
+    assertNotNull(aggregationResult);
+    assertEquals(aggregationResult.size(), NUM_AGGREGATIONS);
+    for (int i = 0; i < NUM_AGGREGATIONS; i++) {
+      assertEquals(((Set<?>) aggregationResult.get(i)).size(), 2 * NUM_RECORDS);
+    }
+
+    // Duplicated segment copies add no new values, so distinct arrays stay at 2 x NUM_RECORDS.
+    ResultTable resultTable = getBrokerResponse(query).getResultTable();
+    Object[] row = resultTable.getRows().get(0);
+    assertEquals(row.length, NUM_AGGREGATIONS);
+    assertEquals(((int[]) row[0]).length, 2 * NUM_RECORDS);
+    assertEquals(((long[]) row[1]).length, 2 * NUM_RECORDS);
+    assertEquals(((float[]) row[2]).length, 2 * NUM_RECORDS);
+    assertEquals(((double[]) row[3]).length, 2 * NUM_RECORDS);
+    assertEquals(((String[]) row[4]).length, 2 * NUM_RECORDS);
+  }
+
+  @Test
+  public void testArrayAggMvGroupBy() {
+    String query = "SELECT groupKey, ArrayAgg(intMV, 'INT'), ArrayAgg(stringMV, 'STRING'), "
+        + "ArrayAgg(intMV, 'INT', true), ArrayAgg(stringMV, 'STRING', true) FROM testTableMv "
+        + "GROUP BY groupKey LIMIT 100";
+    int recordsPerGroup = NUM_RECORDS / NUM_GROUPS;
+
+    List<Object[]> rows = getBrokerResponse(query).getResultTable().getRows();
+    assertEquals(rows.size(), NUM_GROUPS);
+    for (Object[] row : rows) {
+      assertEquals(row.length, 5);
+      // Same fan-out as the aggregation-only test: 2 values per record, each record counted 4 times.
+      assertEquals(((int[]) row[1]).length, 8 * recordsPerGroup);
+      assertEquals(((String[]) row[2]).length, 8 * recordsPerGroup);
+      // Distinct values per group: both MV values of each of the group's records.
+      assertEquals(((int[]) row[3]).length, 2 * recordsPerGroup);
+      assertEquals(((String[]) row[4]).length, 2 * recordsPerGroup);
+    }
+  }
+
+  @AfterClass
+  public void tearDown()
+      throws IOException {
+    // Guard against a failed setUp() so its original failure is not masked by an NPE here.
+    if (_indexSegment != null) {
+      _indexSegment.destroy();
+    }
+    FileUtils.deleteDirectory(INDEX_DIR);
+  }
+}
diff --git
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java
index 20dde176be7..6ddc610a087 100644
---
a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java
+++
b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/ArrayTest.java
@@ -85,6 +85,77 @@ public class ArrayTest extends
CustomDataQueryClusterIntegrationTest {
}
}
+  /**
+   * Non-distinct arrayAgg over the LONG and DOUBLE MV columns without group-by: every MV entry of every row must
+   * appear in the single result row.
+   */
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testArrayAggMvQueries(boolean useMultiStageQueryEngine)
+      throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+    String query = String.format("SELECT arrayAgg(%s, 'LONG'), arrayAgg(%s, 'DOUBLE') FROM %s LIMIT %d",
+        LONG_ARRAY_COLUMN, DOUBLE_ARRAY_COLUMN, getTableName(), getCountStarResult());
+    JsonNode result = postQuery(query).get("resultTable");
+    JsonNode rows = result.get("rows");
+    assertEquals(rows.size(), 1);
+    JsonNode row = rows.get(0);
+    assertEquals(row.size(), 2);
+    // The expectation assumes every row holds 4 entries in both MV columns, so each flattened array holds
+    // 4 * getCountStarResult() values (no hard-coded row count).
+    assertEquals(row.get(0).size(), 4 * getCountStarResult());
+    assertEquals(row.get(1).size(), 4 * getCountStarResult());
+  }
+
+  /**
+   * Distinct arrayAgg over the LONG and DOUBLE MV columns without group-by: the single result row must carry each
+   * column's distinct values exactly once.
+   */
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testArrayAggMvDistinctQueries(boolean useMultiStageQueryEngine)
+      throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+    String sql = String.format("SELECT arrayAgg(%s, 'LONG', true), arrayAgg(%s, 'DOUBLE', true) "
+        + "FROM %s LIMIT %d", LONG_ARRAY_COLUMN, DOUBLE_ARRAY_COLUMN, getTableName(), getCountStarResult());
+    JsonNode resultRows = postQuery(sql).get("resultTable").get("rows");
+    assertEquals(resultRows.size(), 1);
+    JsonNode onlyRow = resultRows.get(0);
+    assertEquals(onlyRow.size(), 2);
+    // Both MV columns are expected to expose exactly 4 distinct values across the table.
+    for (int column = 0; column < 2; column++) {
+      assertEquals(onlyRow.get(column).size(), 4);
+    }
+  }
+
+  /**
+   * Non-distinct arrayAgg over the LONG and DOUBLE MV columns grouped by GROUP_BY_COLUMN: each group's arrays must
+   * contain every MV entry of that group's rows.
+   */
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testArrayAggMvGroupByQueries(boolean useMultiStageQueryEngine)
+      throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+    String sql = String.format("SELECT arrayAgg(%s, 'LONG'), arrayAgg(%s, 'DOUBLE'), %s "
+            + "FROM %s GROUP BY %s LIMIT %d", LONG_ARRAY_COLUMN, DOUBLE_ARRAY_COLUMN, GROUP_BY_COLUMN,
+        getTableName(), GROUP_BY_COLUMN, getCountStarResult());
+    JsonNode resultRows = postQuery(sql).get("resultTable").get("rows");
+    // The expectations assume GROUP_BY_COLUMN splits the table into 10 equally sized groups.
+    assertEquals(resultRows.size(), 10);
+    for (JsonNode groupRow : resultRows) {
+      assertEquals(groupRow.size(), 3);
+      // 4 MV entries per row times the rows in one group.
+      assertEquals(groupRow.get(0).size(), 4 * (getCountStarResult() / 10));
+      assertEquals(groupRow.get(1).size(), 4 * (getCountStarResult() / 10));
+    }
+  }
+
+  /**
+   * Distinct arrayAgg over the LONG and DOUBLE MV columns grouped by GROUP_BY_COLUMN: every group is expected to
+   * expose the same 4 distinct values per column.
+   */
+  @Test(dataProvider = "useBothQueryEngines")
+  public void testArrayAggMvDistinctGroupByQueries(boolean useMultiStageQueryEngine)
+      throws Exception {
+    setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+    String sql = String.format("SELECT arrayAgg(%s, 'LONG', true), arrayAgg(%s, 'DOUBLE', true), %s "
+            + "FROM %s GROUP BY %s LIMIT %d", LONG_ARRAY_COLUMN, DOUBLE_ARRAY_COLUMN, GROUP_BY_COLUMN,
+        getTableName(), GROUP_BY_COLUMN, getCountStarResult());
+    JsonNode resultRows = postQuery(sql).get("resultTable").get("rows");
+    assertEquals(resultRows.size(), 10);
+    for (JsonNode groupRow : resultRows) {
+      assertEquals(groupRow.size(), 3);
+      for (int column = 0; column < 2; column++) {
+        assertEquals(groupRow.get(column).size(), 4);
+      }
+    }
+  }
+
@Test(dataProvider = "useBothQueryEngines")
public void testArrayAggQueries(boolean useMultiStageQueryEngine)
throws Exception {
@@ -946,7 +1017,7 @@ public class ArrayTest extends
CustomDataQueryClusterIntegrationTest {
assertEquals(row.size(), 4);
assertEquals(row.get(0).asInt() % 4 < 2, row.get(1).asBoolean());
assertEquals(row.get(1).asBoolean(), row.get(2).asBoolean());
- assertEquals(row.get(2).asBoolean(), row.get(2).asBoolean());
+ assertEquals(row.get(2).asBoolean(), row.get(3).asBoolean());
}
}
diff --git
a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
index 91b8fd9b1b9..c33c28002de 100644
---
a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
+++
b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
@@ -180,9 +180,7 @@ public enum AggregationFunctionType {
ReturnTypes.ARG1, OperandTypes.VARIADIC, SqlTypeName.OTHER,
SqlTypeName.BIGINT),
// Array aggregate functions
- ARRAYAGG("arrayAgg", ReturnTypes.TO_ARRAY,
- OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER,
SqlTypeFamily.BOOLEAN), i -> i == 2),
- SqlTypeName.OTHER),
+ ARRAYAGG("arrayAgg", new ArrayOfComponentReturnTypeInference(),
OperandTypes.VARIADIC, SqlTypeName.OTHER),
LISTAGG("listAgg", SqlTypeName.OTHER, SqlTypeName.VARCHAR),
SUMARRAYLONG("sumArrayLong", new
ArrayReturnTypeInference(SqlTypeName.BIGINT), OperandTypes.ARRAY,
SqlTypeName.OTHER),
@@ -403,6 +401,23 @@ public enum AggregationFunctionType {
}
}
+  /**
+   * Return-type inference for {@code arrayAgg}. If the first operand is itself a collection (e.g. an ARRAY from a
+   * multi-value column), the result is an ARRAY of its component type. Otherwise the result is an ARRAY of the first
+   * operand's own type.
+   */
+  private static class ArrayOfComponentReturnTypeInference implements SqlReturnTypeInference {
+    @Override
+    public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
+      // ARRAYAGG is declared with OperandTypes.VARIADIC, which does not enforce a minimum operand count. Fail with a
+      // descriptive error instead of an IndexOutOfBoundsException from getOperandType(0).
+      if (opBinding.getOperandCount() == 0) {
+        throw new IllegalArgumentException("arrayAgg requires at least one operand");
+      }
+      RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
+      RelDataType operandType = opBinding.getOperandType(0);
+      // getComponentType() is non-null only for collection types such as ARRAY.
+      RelDataType componentType = operandType.getComponentType();
+      RelDataType elementType = componentType != null ? componentType : operandType;
+      return typeFactory.createArrayType(elementType, -1);
+    }
+  }
+
// Used for aggregation functions that always return BIGINT. The "IfEmpty"
logic ensures that the return type is
// nullable for pure aggregation queries (no group-by) and filtered
aggregation queries. Return values can be null
// if there are no matching rows (even if the operand type is not nullable).
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]