This is an automated email from the ASF dual-hosted git repository.
xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 575d89f0c47 feat(agg): add ListAggMv (multi-value list aggregation)
and distinct variant; wire enum/factory; add unit and query tests. Also
enables SQL syntax: listAggMv(col, 'sep'[, true]). (#17155)
575d89f0c47 is described below
commit 575d89f0c47f342ff187dfab4c766b18cb44657b
Author: Xiang Fu <[email protected]>
AuthorDate: Fri Nov 7 17:04:49 2025 -0800
feat(agg): add ListAggMv (multi-value list aggregation) and distinct
variant; wire enum/factory; add unit and query tests. Also enables SQL
syntax: listAggMv(col, 'sep'[, true]). (#17155)
---
.../function/AggregationFunctionFactory.java | 9 +-
.../function/array/ListAggFunction.java | 76 ++++++--
.../aggregation/function/ListAggFunctionTest.java | 153 ++++++++++++++++
.../apache/pinot/queries/ListAggQueriesTest.java | 193 +++++++++++++++++++++
4 files changed, 409 insertions(+), 22 deletions(-)
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index af18064d185..0fed9fe5e8b 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -268,16 +268,19 @@ public class AggregationFunctionFactory {
}
case LISTAGG: {
Preconditions.checkArgument(numArguments == 2 || numArguments == 3,
- "LISTAGG expects 2 arguments, got: %s. The function can be
used as "
- + "listAgg([distinct] expression, 'separator')",
numArguments);
+ "LISTAGG expects 2 or 3 arguments, got: %s. The function can
be used as "
+ + "listAgg(expression, 'separator'[, true|false])",
numArguments);
ExpressionContext separatorExpression = arguments.get(1);
Preconditions.checkArgument(separatorExpression.getType() ==
ExpressionContext.Type.LITERAL,
"LISTAGG expects the 2nd argument to be literal, got: %s. The
function can be used as "
- + "listAgg([distinct] expression, 'separator')",
separatorExpression.getType());
+ + "listAgg(expression, 'separator'[, true|false])",
separatorExpression.getType());
String separator =
separatorExpression.getLiteral().getStringValue();
boolean isDistinct = false;
if (numArguments == 3) {
ExpressionContext isDistinctListAggExp = arguments.get(2);
+ Preconditions.checkArgument(isDistinctListAggExp.getType() ==
ExpressionContext.Type.LITERAL,
+ "LISTAGG expects the 3rd argument to be a boolean literal
(true/false), got: %s. The function can "
+ + "be used as listAgg(expression, 'separator'[,
true|false])", isDistinctListAggExp.getType());
isDistinct = isDistinctListAggExp.getLiteral().getBooleanValue();
}
if (isDistinct) {
diff --git
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ListAggFunction.java
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ListAggFunction.java
index 644b744b4ec..67c21d7426f 100644
---
a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ListAggFunction.java
+++
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/array/ListAggFunction.java
@@ -65,10 +65,21 @@ public class ListAggFunction extends
NullableSingleInputAggregationFunction<Obje
Map<ExpressionContext, BlockValSet> blockValSetMap) {
ObjectCollection<String> valueSet =
getObjectCollection(aggregationResultHolder);
BlockValSet blockValSet = blockValSetMap.get(_expression);
- String[] values = blockValSet.getStringValuesSV();
- forEachNotNull(length, blockValSet, (from, to) -> {
- valueSet.addAll(Arrays.asList(values).subList(from, to));
- });
+ if (blockValSet.isSingleValue()) {
+ String[] values = blockValSet.getStringValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ valueSet.addAll(Arrays.asList(values).subList(from, to));
+ });
+ } else {
+ String[][] valuesArray = blockValSet.getStringValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ for (String v : valuesArray[i]) {
+ valueSet.add(v);
+ }
+ }
+ });
+ }
}
protected ObjectCollection<String>
getObjectCollection(AggregationResultHolder aggregationResultHolder) {
@@ -94,28 +105,55 @@ public class ListAggFunction extends
NullableSingleInputAggregationFunction<Obje
public void aggregateGroupBySV(int length, int[] groupKeyArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- String[] values = blockValSet.getStringValuesSV();
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- ObjectCollection<String> groupValueList =
getObjectCollection(groupByResultHolder, groupKeyArray[i]);
- groupValueList.add(values[i]);
- }
- });
+ if (blockValSet.isSingleValue()) {
+ String[] values = blockValSet.getStringValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ ObjectCollection<String> groupValueList =
getObjectCollection(groupByResultHolder, groupKeyArray[i]);
+ groupValueList.add(values[i]);
+ }
+ });
+ } else {
+ String[][] valuesArray = blockValSet.getStringValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ ObjectCollection<String> groupValueList =
getObjectCollection(groupByResultHolder, groupKeyArray[i]);
+ for (String v : valuesArray[i]) {
+ groupValueList.add(v);
+ }
+ }
+ });
+ }
}
@Override
public void aggregateGroupByMV(int length, int[][] groupKeysArray,
GroupByResultHolder groupByResultHolder,
Map<ExpressionContext, BlockValSet> blockValSetMap) {
BlockValSet blockValSet = blockValSetMap.get(_expression);
- String[] values = blockValSet.getStringValuesSV();
- forEachNotNull(length, blockValSet, (from, to) -> {
- for (int i = from; i < to; i++) {
- for (int groupKey : groupKeysArray[i]) {
- ObjectCollection<String> groupValueList =
getObjectCollection(groupByResultHolder, groupKey);
- groupValueList.add(values[i]);
+ if (blockValSet.isSingleValue()) {
+ String[] values = blockValSet.getStringValuesSV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ for (int groupKey : groupKeysArray[i]) {
+ ObjectCollection<String> groupValueList =
getObjectCollection(groupByResultHolder, groupKey);
+ groupValueList.add(values[i]);
+ }
}
- }
- });
+ });
+ } else {
+ String[][] valuesArray = blockValSet.getStringValuesMV();
+ forEachNotNull(length, blockValSet, (from, to) -> {
+ for (int i = from; i < to; i++) {
+ int[] groupKeys = groupKeysArray[i];
+ for (int groupKey : groupKeys) {
+ ObjectCollection<String> groupValueList =
getObjectCollection(groupByResultHolder, groupKey);
+ for (String v : valuesArray[i]) {
+ groupValueList.add(v);
+ }
+ }
+ }
+ });
+ }
}
@Override
diff --git
a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/ListAggFunctionTest.java
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/ListAggFunctionTest.java
new file mode 100644
index 00000000000..1dfff07ef7f
--- /dev/null
+++
b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/ListAggFunctionTest.java
@@ -0,0 +1,153 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import java.util.Map;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.core.common.SyntheticBlockValSets;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import
org.apache.pinot.core.query.aggregation.function.array.ListAggDistinctFunction;
+import org.apache.pinot.core.query.aggregation.function.array.ListAggFunction;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import
org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+
+
+public class ListAggFunctionTest extends AbstractAggregationFunctionTest {
+
+ private static class TestStringSVBlock extends SyntheticBlockValSets.Base {
+ private final String[] _values;
+
+ TestStringSVBlock(String[] values) {
+ _values = values;
+ }
+
+ @Override
+ public boolean isSingleValue() {
+ return true;
+ }
+
+ @Override
+ public String[] getStringValuesSV() {
+ return _values;
+ }
+
+ @Override
+ public FieldSpec.DataType getValueType() {
+ return FieldSpec.DataType.STRING;
+ }
+ }
+
+ private static class TestStringMVBlock extends SyntheticBlockValSets.Base {
+ private final String[][] _values;
+
+ TestStringMVBlock(String[][] values) {
+ _values = values;
+ }
+
+ @Override
+ public boolean isSingleValue() {
+ return false;
+ }
+
+ @Override
+ public String[][] getStringValuesMV() {
+ return _values;
+ }
+
+ @Override
+ public FieldSpec.DataType getValueType() {
+ return FieldSpec.DataType.STRING;
+ }
+ }
+
+ @Test
+ public void testListAggAggregate() {
+ ListAggFunction fn = new
ListAggFunction(ExpressionContext.forIdentifier("myField"), ",", false);
+ AggregationResultHolder holder = fn.createAggregationResultHolder();
+ fn.aggregate(2, holder,
+ Map.of(ExpressionContext.forIdentifier("myField"), new
TestStringMVBlock(new String[][]{{"A", "B"}, {"C"}})));
+ fn.aggregate(2, holder,
+ Map.of(ExpressionContext.forIdentifier("myField"), new
TestStringMVBlock(new String[][]{{"B"}, {"D"}})));
+ String result = fn.extractFinalResult(holder.getResult());
+ assertEquals(result, "A,B,C,B,D");
+ }
+
+ @Test
+ public void testListAggAggregateSV() {
+ ListAggFunction fn = new
ListAggFunction(ExpressionContext.forIdentifier("svField"), "|", false);
+ AggregationResultHolder holder = fn.createAggregationResultHolder();
+ fn.aggregate(3, holder,
+ Map.of(ExpressionContext.forIdentifier("svField"), new
TestStringSVBlock(new String[]{"A", "B", "C"})));
+ fn.aggregate(2, holder,
+ Map.of(ExpressionContext.forIdentifier("svField"), new
TestStringSVBlock(new String[]{"B", "D"})));
+ String result = fn.extractFinalResult(holder.getResult());
+ assertEquals(result, "A|B|C|B|D");
+ }
+
+ @Test
+ public void testListAggDistinctAggregate() {
+ ListAggDistinctFunction fn =
+ new
ListAggDistinctFunction(ExpressionContext.forIdentifier("myField"), ",", false);
+ AggregationResultHolder holder = fn.createAggregationResultHolder();
+ fn.aggregate(2, holder,
+ Map.of(ExpressionContext.forIdentifier("myField"), new
TestStringMVBlock(new String[][]{{"A", "B"}, {"C"}})));
+ fn.aggregate(2, holder,
+ Map.of(ExpressionContext.forIdentifier("myField"), new
TestStringMVBlock(new String[][]{{"B"}, {"A"}})));
+ String result = fn.extractFinalResult(holder.getResult());
+ assertEquals(result, "A,B,C");
+ }
+
+ @Test
+ public void testListAggDistinctAggregateSV() {
+ ListAggDistinctFunction fn =
+ new
ListAggDistinctFunction(ExpressionContext.forIdentifier("svField"), ",", false);
+ AggregationResultHolder holder = fn.createAggregationResultHolder();
+ fn.aggregate(3, holder,
+ Map.of(ExpressionContext.forIdentifier("svField"), new
TestStringSVBlock(new String[]{"A", "B", "C"})));
+ fn.aggregate(3, holder,
+ Map.of(ExpressionContext.forIdentifier("svField"), new
TestStringSVBlock(new String[]{"B", "A", "D"})));
+ String result = fn.extractFinalResult(holder.getResult());
+ assertEquals(result, "A,B,C,D");
+ }
+
+ @Test
+ public void testGroupByPaths() {
+ ListAggFunction fn = new
ListAggFunction(ExpressionContext.forIdentifier("myField"), ";", false);
+ GroupByResultHolder gb = new ObjectGroupByResultHolder(4, 4);
+ fn.aggregateGroupBySV(2, new int[]{0, 1}, gb,
+ Map.of(ExpressionContext.forIdentifier("myField"), new
TestStringMVBlock(new String[][]{{"X"}, {"Y", "Z"}})));
+ assertEquals(fn.extractFinalResult(gb.getResult(0)), "X");
+ assertEquals(fn.extractFinalResult(gb.getResult(1)), "Y;Z");
+ }
+
+ @Test
+ public void testGroupByMVKeysOnMVColumn() {
+ ListAggFunction fn = new
ListAggFunction(ExpressionContext.forIdentifier("mvField"), ":", false);
+ GroupByResultHolder gb = new ObjectGroupByResultHolder(4, 4);
+ int[][] groupKeysArray = new int[][]{{0, 1}, {1}};
+ fn.aggregateGroupByMV(2, groupKeysArray, gb,
+ Map.of(ExpressionContext.forIdentifier("mvField"), new
TestStringMVBlock(new String[][]{{"A"}, {"B", "C"}})));
+ assertEquals(fn.extractFinalResult(gb.getResult(0)), "A");
+ assertEquals(fn.extractFinalResult(gb.getResult(1)), "A:B:C");
+ }
+}
diff --git
a/pinot-core/src/test/java/org/apache/pinot/queries/ListAggQueriesTest.java
b/pinot-core/src/test/java/org/apache/pinot/queries/ListAggQueriesTest.java
new file mode 100644
index 00000000000..f705d6e7bb2
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/ListAggQueriesTest.java
@@ -0,0 +1,193 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.queries;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.commons.io.FileUtils;
+import org.apache.pinot.common.response.broker.ResultTable;
+import org.apache.pinot.core.operator.blocks.results.AggregationResultsBlock;
+import org.apache.pinot.core.operator.query.AggregationOperator;
+import
org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoader;
+import
org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl;
+import org.apache.pinot.segment.local.segment.readers.GenericRowRecordReader;
+import org.apache.pinot.segment.spi.ImmutableSegment;
+import org.apache.pinot.segment.spi.IndexSegment;
+import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.utils.ReadMode;
+import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+
+
+public class ListAggQueriesTest extends BaseQueriesTest {
+ private static final File INDEX_DIR = new File(FileUtils.getTempDirectory(),
"ListAggQueriesTest");
+ private static final String RAW_TABLE_NAME = "testTableListAgg";
+ private static final String SEGMENT_NAME = "testSegment";
+
+ private static final int NUM_RECORDS = 200;
+
+ private static final String STR_MV = "strMV";
+ private static final String STR_SV = "strSV";
+ private static final String GROUP_BY_COLUMN = "groupKey";
+
+ private static final Schema SCHEMA = new
Schema.SchemaBuilder().addMultiValueDimension(STR_MV, DataType.STRING)
+ .addSingleValueDimension(STR_SV, DataType.STRING)
+ .addSingleValueDimension(GROUP_BY_COLUMN, DataType.STRING).build();
+
+ private IndexSegment _indexSegment;
+ private List<IndexSegment> _indexSegments;
+
+ @Override
+ protected String getFilter() {
+ return "";
+ }
+
+ @Override
+ protected IndexSegment getIndexSegment() {
+ return _indexSegment;
+ }
+
+ @Override
+ protected List<IndexSegment> getIndexSegments() {
+ return _indexSegments;
+ }
+
+ @BeforeClass
+ public void setUp()
+ throws Exception {
+ FileUtils.deleteDirectory(INDEX_DIR);
+
+ List<GenericRow> records = new ArrayList<>(NUM_RECORDS);
+ for (int i = 0; i < NUM_RECORDS; i++) {
+ GenericRow record = new GenericRow();
+ record.putValue(STR_MV, new String[]{"A", (i % 2 == 0) ? "B" : "C"});
+ record.putValue(STR_SV, (i % 2 == 0) ? "X" : "Y");
+ record.putValue(GROUP_BY_COLUMN, String.valueOf(i % 10));
+ records.add(record);
+ }
+
+ SegmentGeneratorConfig conf = new SegmentGeneratorConfig(
+ new
TableConfigBuilder(org.apache.pinot.spi.config.table.TableType.OFFLINE).setTableName(RAW_TABLE_NAME)
+ .build(), SCHEMA);
+ conf.setTableName(RAW_TABLE_NAME);
+ conf.setSegmentName(SEGMENT_NAME);
+ conf.setOutDir(INDEX_DIR.getPath());
+ SegmentIndexCreationDriverImpl driver = new
SegmentIndexCreationDriverImpl();
+ driver.init(conf, new GenericRowRecordReader(records));
+ driver.build();
+
+ ImmutableSegment immutableSegment = ImmutableSegmentLoader.load(new
File(INDEX_DIR, SEGMENT_NAME), ReadMode.mmap);
+ _indexSegment = immutableSegment;
+ _indexSegments = Arrays.asList(immutableSegment, immutableSegment);
+ }
+
+ @Test
+ public void testListAggNonDistinct() {
+ String q = "SELECT listAgg(strMV, ',') FROM testTableListAgg";
+ AggregationOperator op = getOperator(q);
+ AggregationResultsBlock block = op.nextBlock();
+ List<Object> res = block.getResults();
+ assertNotNull(res);
+ // Each row contributes 2 elements; inner-segment result is the
intermediate collection
+ Object partial = res.get(0);
+ assertEquals(((java.util.Collection<?>) partial).size(), 2 * NUM_RECORDS);
+
+ ResultTable table = getBrokerResponse(q).getResultTable();
+ assertEquals(table.getRows().get(0).length, 1);
+ // Each row has 2 MV values; with 2 segments and 2 servers, expect 8 *
NUM_RECORDS
+ assertEquals(((String) table.getRows().get(0)[0]).split(",").length, 8 *
NUM_RECORDS);
+ }
+
+ @Test
+ public void testListAggDistinct() {
+ String q = "SELECT listAgg(strMV, ',', true) FROM testTableListAgg";
+ AggregationOperator op = getOperator(q);
+ AggregationResultsBlock block = op.nextBlock();
+ List<Object> res = block.getResults();
+ assertNotNull(res);
+ // Distinct values are {A,B,C}; inner-segment result is the intermediate
set
+ Object partial = res.get(0);
+ assertEquals(((java.util.Collection<?>) partial).size(), 3);
+
+ // Inter-segment (broker) result is the final string
+ ResultTable table = getBrokerResponse(q).getResultTable();
+ assertEquals(((String) table.getRows().get(0)[0]).split(",").length, 3);
+ }
+
+ @Test
+ public void testListAggExplicitFalseOnMV() {
+ String q = "SELECT listAgg(strMV, ',', false) FROM testTableListAgg";
+ AggregationOperator op = getOperator(q);
+ AggregationResultsBlock block = op.nextBlock();
+ List<Object> res = block.getResults();
+ assertNotNull(res);
+ Object partial = res.get(0);
+ assertEquals(((java.util.Collection<?>) partial).size(), 2 * NUM_RECORDS);
+
+ ResultTable table = getBrokerResponse(q).getResultTable();
+ assertEquals(((String) table.getRows().get(0)[0]).split(",").length, 8 *
NUM_RECORDS);
+ }
+
+ @Test
+ public void testListAggSVNonDistinct() {
+ String q = "SELECT listAgg(strSV, '|') FROM testTableListAgg";
+ AggregationOperator op = getOperator(q);
+ AggregationResultsBlock block = op.nextBlock();
+ List<Object> res = block.getResults();
+ assertNotNull(res);
+ Object partial = res.get(0);
+ assertEquals(((java.util.Collection<?>) partial).size(), NUM_RECORDS);
+
+ ResultTable table = getBrokerResponse(q).getResultTable();
+ assertEquals(((String) table.getRows().get(0)[0]).split("\\|").length, 4 *
NUM_RECORDS);
+ }
+
+ @Test
+ public void testListAggSVDistinct() {
+ String q = "SELECT listAgg(strSV, ',', true) FROM testTableListAgg";
+ AggregationOperator op = getOperator(q);
+ AggregationResultsBlock block = op.nextBlock();
+ List<Object> res = block.getResults();
+ assertNotNull(res);
+ Object partial = res.get(0);
+ // Distinct values are {X,Y}
+ assertEquals(((java.util.Collection<?>) partial).size(), 2);
+
+ ResultTable table = getBrokerResponse(q).getResultTable();
+ assertEquals(((String) table.getRows().get(0)[0]).split(",").length, 2);
+ }
+
+ @AfterClass
+ public void tearDown()
+ throws IOException {
+ _indexSegment.destroy();
+ FileUtils.deleteDirectory(INDEX_DIR);
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]