This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
     new f1da16473b  Add DISTINCT_COUNT_OFF_HEAP aggregate function (#15469)
f1da16473b is described below

commit f1da16473b3f46f2eb1afeb1191e586172e67902
Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com>
AuthorDate: Mon Apr 7 12:27:24 2025 -0600

    Add DISTINCT_COUNT_OFF_HEAP aggregate function (#15469)
---
 .../query/NonScanBasedAggregationOperator.java     |   5 +
 .../pinot/core/plan/AggregationPlanNode.java       |   8 +-
 .../function/AggregationFunctionFactory.java       |   4 +-
 .../function/DistinctCountAggregationFunction.java |   2 +-
 .../DistinctCountMVAggregationFunction.java        |   2 +-
 .../DistinctCountOffHeapAggregationFunction.java   | 533 +++++++++++++++++++++
 .../pinot/queries/DistinctCountQueriesTest.java    |  77 +++
 .../pinot/segment/spi/AggregationFunctionType.java |   3 +
 8 files changed, 627 insertions(+), 7 deletions(-)
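A quick orientation before the diff: the new function is called like the existing DISTINCT_COUNT variants and accepts an optional second literal argument for tuning. The sketch below uses the fixture table and columns from the test added in this commit, so treat the names as illustrative; the parameter keys and defaults (initial capacity 10000, 64-bit hashing for hashed types such as strings and bytes) come from the Parameters class in the new file further down.

    -- Defaults: initialcapacity=10000; hashbits=64 (hashing applies to variable-width types)
    SELECT DISTINCTCOUNTOFFHEAP(stringColumn) FROM testTable;

    -- Optional tuning literal: ';'-separated 'key=value' pairs with case-insensitive keys
    SELECT DISTINCTCOUNTOFFHEAP(stringColumn, 'initialcapacity=10;hashbits=128') FROM testTable;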
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/NonScanBasedAggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/NonScanBasedAggregationOperator.java
index fb6e8d02f6..04f388cd46 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/NonScanBasedAggregationOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/NonScanBasedAggregationOperator.java
@@ -38,6 +38,7 @@ import org.apache.pinot.core.operator.blocks.results.AggregationResultsBlock;
 import org.apache.pinot.core.query.aggregation.function.AggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.DistinctCountHLLAggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.DistinctCountHLLPlusAggregationFunction;
+import org.apache.pinot.core.query.aggregation.function.DistinctCountOffHeapAggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.DistinctCountRawHLLAggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.DistinctCountRawHLLPlusAggregationFunction;
 import org.apache.pinot.core.query.aggregation.function.DistinctCountSmartHLLAggregationFunction;
@@ -111,6 +112,10 @@ public class NonScanBasedAggregationOperator extends BaseOperator<AggregationRes
       case DISTINCTAVGMV:
         result = getDistinctValueSet(Objects.requireNonNull(dataSource.getDictionary()));
         break;
+      case DISTINCTCOUNTOFFHEAP:
+        result = ((DistinctCountOffHeapAggregationFunction) aggregationFunction).extractAggregationResult(
+            Objects.requireNonNull(dataSource.getDictionary()));
+        break;
       case DISTINCTCOUNTHLL:
       case DISTINCTCOUNTHLLMV:
         result = getDistinctCountHLLResult(Objects.requireNonNull(dataSource.getDictionary()),
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
index f5157112a8..cca14f2704 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/AggregationPlanNode.java
@@ -49,10 +49,10 @@ import static org.apache.pinot.segment.spi.AggregationFunctionType.*;
 @SuppressWarnings("rawtypes")
 public class AggregationPlanNode implements PlanNode {
   private static final EnumSet<AggregationFunctionType> DICTIONARY_BASED_FUNCTIONS =
-      EnumSet.of(MIN, MINMV, MAX, MAXMV, MINMAXRANGE, MINMAXRANGEMV, DISTINCTCOUNT, DISTINCTCOUNTMV, DISTINCTCOUNTHLL,
-          DISTINCTCOUNTHLLMV, DISTINCTCOUNTRAWHLL, DISTINCTCOUNTRAWHLLMV, SEGMENTPARTITIONEDDISTINCTCOUNT,
-          DISTINCTCOUNTSMARTHLL, DISTINCTSUM, DISTINCTAVG, DISTINCTSUMMV, DISTINCTAVGMV, DISTINCTCOUNTHLLPLUS,
-          DISTINCTCOUNTHLLPLUSMV, DISTINCTCOUNTRAWHLLPLUS, DISTINCTCOUNTRAWHLLPLUSMV);
+      EnumSet.of(MIN, MINMV, MAX, MAXMV, MINMAXRANGE, MINMAXRANGEMV, DISTINCTCOUNT, DISTINCTCOUNTMV, DISTINCTSUM,
+          DISTINCTSUMMV, DISTINCTAVG, DISTINCTAVGMV, DISTINCTCOUNTOFFHEAP, DISTINCTCOUNTHLL, DISTINCTCOUNTHLLMV,
+          DISTINCTCOUNTRAWHLL, DISTINCTCOUNTRAWHLLMV, DISTINCTCOUNTHLLPLUS, DISTINCTCOUNTHLLPLUSMV,
+          DISTINCTCOUNTRAWHLLPLUS, DISTINCTCOUNTRAWHLLPLUSMV, SEGMENTPARTITIONEDDISTINCTCOUNT, DISTINCTCOUNTSMARTHLL);
 
   // DISTINCTCOUNT excluded because consuming segment metadata contains unknown cardinality when there is no dictionary
   private static final EnumSet<AggregationFunctionType> METADATA_BASED_FUNCTIONS =
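Why the EnumSet change matters: DICTIONARY_BASED_FUNCTIONS is the planner's allow-list for serving an aggregation straight from dictionary values, which is what lets the new DISTINCTCOUNTOFFHEAP case in NonScanBasedAggregationOperator fire at all. A minimal standalone sketch of that kind of gate; the class, method, and gating conditions here are illustrative assumptions, not the actual Pinot planner logic:

    import java.util.EnumSet;

    public class DictionaryGateSketch {
      // Toy stand-in for AggregationFunctionType; the real code uses the Pinot enum.
      enum Fn {
        MIN, MAX, DISTINCTCOUNT, DISTINCTCOUNTOFFHEAP, AVG
      }

      // Functions whose result can be derived from dictionary values alone.
      private static final EnumSet<Fn> DICTIONARY_BASED_FUNCTIONS =
          EnumSet.of(Fn.MIN, Fn.MAX, Fn.DISTINCTCOUNT, Fn.DISTINCTCOUNTOFFHEAP);

      // Assumed gating conditions: the column must be dictionary-encoded, and the filter
      // must match all documents so the dictionary holds exactly the aggregated values.
      static boolean canSkipScan(Fn fn, boolean hasDictionary, boolean filterMatchesAllDocs) {
        return DICTIONARY_BASED_FUNCTIONS.contains(fn) && hasDictionary && filterMatchesAllDocs;
      }

      public static void main(String[] args) {
        System.out.println(canSkipScan(Fn.DISTINCTCOUNTOFFHEAP, true, true)); // true
        System.out.println(canSkipScan(Fn.AVG, true, true)); // false: AVG needs the actual rows
      }
    }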
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
index 205f1ae71a..b15896de7c 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java
@@ -58,7 +58,7 @@ public class AggregationFunctionFactory {
 
   /**
    * Given the function information, returns a new instance of the corresponding aggregation function.
-   * <p>NOTE: Underscores in the function name are ignored in V1.
+   * <p>NOTE: Underscores in the function name are ignored.
    */
   public static AggregationFunction getAggregationFunction(FunctionContext function, boolean nullHandlingEnabled) {
     try {
@@ -360,6 +360,8 @@ public class AggregationFunctionFactory {
           return new MinMaxRangeAggregationFunction(arguments, nullHandlingEnabled);
         case DISTINCTCOUNT:
           return new DistinctCountAggregationFunction(arguments, nullHandlingEnabled);
+        case DISTINCTCOUNTOFFHEAP:
+          return new DistinctCountOffHeapAggregationFunction(arguments, nullHandlingEnabled);
         case DISTINCTCOUNTBITMAP:
           return new DistinctCountBitmapAggregationFunction(arguments);
         case SEGMENTPARTITIONEDDISTINCTCOUNT:
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java
index 076bc2ccda..7d1b1b5592 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java
@@ -30,7 +30,7 @@ import org.apache.pinot.segment.spi.AggregationFunctionType;
 
 
 /**
- * Aggregation function to compute the average of distinct values for an SV column
+ * Aggregation function to compute the count of distinct values for an SV column.
  */
 public class DistinctCountAggregationFunction extends BaseDistinctAggregateAggregationFunction<Integer> {
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java
index aa1cd6da66..9940ec3080 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java
@@ -30,7 +30,7 @@ import org.apache.pinot.segment.spi.AggregationFunctionType;
 
 
 /**
- * Aggregation function to compute the average of distinct values for an MV column
+ * Aggregation function to compute the count of distinct values for an MV column.
  */
 public class DistinctCountMVAggregationFunction extends BaseDistinctAggregateAggregationFunction<Integer> {
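On the factory javadoc tweak above ("Underscores in the function name are ignored"): the lookup canonicalizes the SQL-level name before resolving the enum constant, which is why DISTINCT_COUNT_OFF_HEAP in the commit title and DISTINCTCOUNTOFFHEAP in the code name the same function. A standalone sketch of that canonicalization; the exact form is an assumption, and the real factory code differs in details:

    public class FunctionNameSketch {
      // Assumed canonicalization: drop underscores, upper-case, then resolve the enum constant.
      static String canonicalize(String functionName) {
        return functionName.replace("_", "").toUpperCase();
      }

      public static void main(String[] args) {
        // All three print DISTINCTCOUNTOFFHEAP, matching AggregationFunctionType.DISTINCTCOUNTOFFHEAP.
        System.out.println(canonicalize("DISTINCT_COUNT_OFF_HEAP"));
        System.out.println(canonicalize("distinctCountOffHeap"));
        System.out.println(canonicalize("distinct_count_off_heap"));
      }
    }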
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountOffHeapAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountOffHeapAggregationFunction.java
new file mode 100644
index 0000000000..f0352e79d0
--- /dev/null
+++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountOffHeapAggregationFunction.java
@@ -0,0 +1,533 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.query.aggregation.function;
+
+import com.google.common.base.Preconditions;
+import java.util.BitSet;
+import java.util.List;
+import java.util.Map;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.pinot.common.CustomObject;
+import org.apache.pinot.common.request.context.ExpressionContext;
+import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
+import org.apache.pinot.core.common.BlockValSet;
+import org.apache.pinot.core.query.aggregation.AggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder;
+import org.apache.pinot.core.query.aggregation.function.distinct.BaseOffHeapSet;
+import org.apache.pinot.core.query.aggregation.function.distinct.OffHeap128BitSet;
+import org.apache.pinot.core.query.aggregation.function.distinct.OffHeap32BitSet;
+import org.apache.pinot.core.query.aggregation.function.distinct.OffHeap64BitSet;
+import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder;
+import org.apache.pinot.segment.spi.AggregationFunctionType;
+import org.apache.pinot.segment.spi.index.reader.Dictionary;
+import org.apache.pinot.spi.data.FieldSpec.DataType;
+
+
+/// Aggregation function to compute the count of distinct values for a column using off-heap memory.
+public class DistinctCountOffHeapAggregationFunction
+    extends NullableSingleInputAggregationFunction<BaseOffHeapSet, Integer> {
+  // Use empty OffHeap32BitSet as a placeholder for empty result
+  // NOTE: It is okay to close it (multiple times) since we are never adding values into it
+  private static final OffHeap32BitSet EMPTY_PLACEHOLDER = new OffHeap32BitSet(0);
+
+  private final int _initialCapacity;
+  private final int _hashBits;
+
+  public DistinctCountOffHeapAggregationFunction(List<ExpressionContext> arguments, boolean nullHandlingEnabled) {
+    super(arguments.get(0), nullHandlingEnabled);
+    if (arguments.size() > 1) {
+      Parameters parameters = new Parameters(arguments.get(1).getLiteral().getStringValue());
+      _initialCapacity = parameters._initialCapacity;
+      _hashBits = parameters._hashBits;
+    } else {
+      _initialCapacity = Parameters.DEFAULT_INITIAL_CAPACITY;
+      _hashBits = Parameters.DEFAULT_HASH_BITS;
+    }
+  }
+
+  @Override
+  public AggregationFunctionType getType() {
+    return AggregationFunctionType.DISTINCTCOUNTOFFHEAP;
+  }
+
+  @Override
+  public AggregationResultHolder createAggregationResultHolder() {
+    return new ObjectAggregationResultHolder();
+  }
+
+  @Override
+  public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) {
+    throw new UnsupportedOperationException(
+        "DISTINCT_COUNT_OFF_HEAP cannot be applied to group-by queries. Use DISTINCT_COUNT instead.");
+  }
+
+  @Override
+  public void aggregate(int length, AggregationResultHolder aggregationResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    BlockValSet blockValSet = blockValSetMap.get(_expression);
+    Dictionary dictionary = blockValSet.getDictionary();
+    if (dictionary != null) {
+      // For dictionary-encoded expression, store dictionary ids into the bitmap
+      if (blockValSet.isSingleValue()) {
+        int[] dictIds = blockValSet.getDictionaryIdsSV();
+        BitSet dictIdBitSet = getDictIdBitSet(aggregationResultHolder, dictionary);
+        forEachNotNull(length, blockValSet, (from, to) -> {
+          for (int i = from; i < to; i++) {
+            dictIdBitSet.set(dictIds[i]);
+          }
+        });
+      } else {
+        int[][] dictIds = blockValSet.getDictionaryIdsMV();
+        BitSet dictIdBitSet = getDictIdBitSet(aggregationResultHolder, dictionary);
+        for (int i = 0; i < length; i++) {
+          for (int dictId : dictIds[i]) {
+            dictIdBitSet.set(dictId);
+          }
+        }
+      }
+    } else {
+      // For non-dictionary-encoded expression, add values into the value set
+      BaseOffHeapSet valueSet = aggregationResultHolder.getResult();
+      if (valueSet == null) {
+        valueSet = createValueSet(blockValSet.getValueType().getStoredType());
+        aggregationResultHolder.setValue(valueSet);
+      }
+      if (blockValSet.isSingleValue()) {
+        addToValueSetSV(length, blockValSet, valueSet);
+      } else {
+        addToValueSetMV(length, blockValSet, valueSet);
+      }
+    }
+  }
+
+  private static BitSet getDictIdBitSet(AggregationResultHolder aggregationResultHolder, Dictionary dictionary) {
+    DictIdsWrapper dictIdsWrapper = aggregationResultHolder.getResult();
+    if (dictIdsWrapper == null) {
+      dictIdsWrapper = new DictIdsWrapper(dictionary);
+      aggregationResultHolder.setValue(dictIdsWrapper);
+    }
+    return dictIdsWrapper._bitSet;
+  }
+
+  private BaseOffHeapSet createValueSet(DataType storedType) {
+    switch (storedType) {
+      case INT:
+      case FLOAT:
+        return new OffHeap32BitSet(_initialCapacity);
+      case LONG:
+      case DOUBLE:
+        return new OffHeap64BitSet(_initialCapacity);
+      default:
+        switch (_hashBits) {
+          case 32:
+            return new OffHeap32BitSet(_initialCapacity);
+          case 64:
+            return new OffHeap64BitSet(_initialCapacity);
+          case 128:
+            return new OffHeap128BitSet(_initialCapacity);
+          default:
+            throw new IllegalStateException();
+        }
+    }
+  }
+
+  private void addToValueSetSV(int length, BlockValSet blockValSet, BaseOffHeapSet valueSet) {
+    DataType storedType = blockValSet.getValueType().getStoredType();
+    switch (storedType) {
+      case INT:
+        OffHeap32BitSet intSet = (OffHeap32BitSet) valueSet;
+        int[] intValues = blockValSet.getIntValuesSV();
+        forEachNotNull(length, blockValSet, (from, to) -> {
+          for (int i = from; i < to; i++) {
+            intSet.add(intValues[i]);
+          }
+        });
+        break;
+      case LONG:
+        OffHeap64BitSet longSet = (OffHeap64BitSet) valueSet;
+        long[] longValues = blockValSet.getLongValuesSV();
+        forEachNotNull(length, blockValSet, (from, to) -> {
+          for (int i = from; i < to; i++) {
+            longSet.add(longValues[i]);
+          }
+        });
+        break;
+      case FLOAT:
+        OffHeap32BitSet floatSet = (OffHeap32BitSet) valueSet;
+        float[] floatValues = blockValSet.getFloatValuesSV();
+        forEachNotNull(length, blockValSet, (from, to) -> {
+          for (int i = from; i < to; i++) {
+            floatSet.add(Float.floatToRawIntBits(floatValues[i]));
+          }
+        });
+        break;
+      case DOUBLE:
+        OffHeap64BitSet doubleSet = (OffHeap64BitSet) valueSet;
+        double[] doubleValues = blockValSet.getDoubleValuesSV();
+        forEachNotNull(length, blockValSet, (from, to) -> {
+          for (int i = from; i < to; i++) {
+            doubleSet.add(Double.doubleToRawLongBits(doubleValues[i]));
+          }
+        });
+        break;
+      default:
+        switch (_hashBits) {
+          case 32:
+            OffHeap32BitSet valueSet32 = (OffHeap32BitSet) valueSet;
+            int[] hashValues32 = blockValSet.get32BitsMurmur3HashValuesSV();
+            forEachNotNull(length, blockValSet, (from, to) -> {
+              for (int i = from; i < to; i++) {
+                valueSet32.add(hashValues32[i]);
+              }
+            });
+            break;
+          case 64:
+            OffHeap64BitSet valueSet64 = (OffHeap64BitSet) valueSet;
+            long[] hashValues64 = blockValSet.get64BitsMurmur3HashValuesSV();
+            forEachNotNull(length, blockValSet, (from, to) -> {
+              for (int i = from; i < to; i++) {
+                valueSet64.add(hashValues64[i]);
+              }
+            });
+            break;
+          case 128:
+            OffHeap128BitSet valueSet128 = (OffHeap128BitSet) valueSet;
+            long[][] hashValues128 = blockValSet.get128BitsMurmur3HashValuesSV();
+            forEachNotNull(length, blockValSet, (from, to) -> {
+              for (int i = from; i < to; i++) {
+                long[] hashValue = hashValues128[i];
+                valueSet128.add(hashValue[0], hashValue[1]);
+              }
+            });
+            break;
+          default:
+            throw new IllegalStateException();
+        }
+        break;
+    }
+  }
+
+  private void addToValueSetMV(int length, BlockValSet blockValSet, BaseOffHeapSet valueSet) {
+    DataType storedType = blockValSet.getValueType().getStoredType();
+    switch (storedType) {
+      case INT:
+        OffHeap32BitSet intSet = (OffHeap32BitSet) valueSet;
+        int[][] intValues = blockValSet.getIntValuesMV();
+        for (int i = 0; i < length; i++) {
+          for (int intValue : intValues[i]) {
+            intSet.add(intValue);
+          }
+        }
+        break;
+      case LONG:
+        OffHeap64BitSet longSet = (OffHeap64BitSet) valueSet;
+        long[][] longValues = blockValSet.getLongValuesMV();
+        for (int i = 0; i < length; i++) {
+          for (long longValue : longValues[i]) {
+            longSet.add(longValue);
+          }
+        }
+        break;
+      case FLOAT:
+        OffHeap32BitSet floatSet = (OffHeap32BitSet) valueSet;
+        float[][] floatValues = blockValSet.getFloatValuesMV();
+        for (int i = 0; i < length; i++) {
+          for (float floatValue : floatValues[i]) {
+            floatSet.add(Float.floatToRawIntBits(floatValue));
+          }
+        }
+        break;
+      case DOUBLE:
+        OffHeap64BitSet doubleSet = (OffHeap64BitSet) valueSet;
+        double[][] doubleValues = blockValSet.getDoubleValuesMV();
+        for (int i = 0; i < length; i++) {
+          for (double doubleValue : doubleValues[i]) {
+            doubleSet.add(Double.doubleToRawLongBits(doubleValue));
+          }
+        }
+        break;
+      default:
+        throw new UnsupportedOperationException(
+            "DISTINCT_COUNT_OFF_HEAP does not support MV columns of type: " + blockValSet.getValueType()
+                + ". Use DISTINCT_COUNT instead.");
+    }
+  }
+
+  @Override
+  public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder,
+      Map<ExpressionContext, BlockValSet> blockValSetMap) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public BaseOffHeapSet extractAggregationResult(AggregationResultHolder aggregationResultHolder) {
+    Object result = aggregationResultHolder.getResult();
+    if (result == null) {
+      return EMPTY_PLACEHOLDER;
+    }
+    if (result instanceof DictIdsWrapper) {
+      return extractAggregationResult((DictIdsWrapper) result);
+    } else {
+      return (BaseOffHeapSet) result;
+    }
+  }
+
+  private BaseOffHeapSet extractAggregationResult(DictIdsWrapper dictIdsWrapper) {
+    BitSet bitSet = dictIdsWrapper._bitSet;
+    int length = bitSet.cardinality();
+    Dictionary dictionary = dictIdsWrapper._dictionary;
+    DataType storedType = dictionary.getValueType();
+    switch (storedType) {
+      case INT:
+        OffHeap32BitSet intSet = new OffHeap32BitSet(length);
+        for (int i = bitSet.nextSetBit(0); i >= 0; i = bitSet.nextSetBit(i + 1)) {
+          intSet.add(dictionary.getIntValue(i));
+        }
+        return intSet;
+      case LONG:
+        OffHeap64BitSet longSet = new OffHeap64BitSet(length);
+        for (int i = bitSet.nextSetBit(0); i >= 0; i = bitSet.nextSetBit(i + 1)) {
+          longSet.add(dictionary.getLongValue(i));
+        }
+        return longSet;
+      case FLOAT:
+        OffHeap32BitSet floatSet = new OffHeap32BitSet(length);
+        for (int i = bitSet.nextSetBit(0); i >= 0; i = bitSet.nextSetBit(i + 1)) {
+          floatSet.add(Float.floatToRawIntBits(dictionary.getFloatValue(i)));
+        }
+        return floatSet;
+      case DOUBLE:
+        OffHeap64BitSet doubleSet = new OffHeap64BitSet(length);
+        for (int i = bitSet.nextSetBit(0); i >= 0; i = bitSet.nextSetBit(i + 1)) {
+          doubleSet.add(Double.doubleToRawLongBits(dictionary.getDoubleValue(i)));
+        }
+        return doubleSet;
+      default:
+        switch (_hashBits) {
+          case 32:
+            OffHeap32BitSet valueSet32 = new OffHeap32BitSet(length);
+            for (int i = bitSet.nextSetBit(0); i >= 0; i = bitSet.nextSetBit(i + 1)) {
+              valueSet32.add(dictionary.get32BitsMurmur3HashValue(i));
+            }
+            return valueSet32;
+          case 64:
+            OffHeap64BitSet valueSet64 = new OffHeap64BitSet(length);
+            for (int i = bitSet.nextSetBit(0); i >= 0; i = bitSet.nextSetBit(i + 1)) {
+              valueSet64.add(dictionary.get64BitsMurmur3HashValue(i));
+            }
+            return valueSet64;
+          case 128:
+            OffHeap128BitSet valueSet128 = new OffHeap128BitSet(length);
+            for (int i = bitSet.nextSetBit(0); i >= 0; i = bitSet.nextSetBit(i + 1)) {
+              long[] hashValue = dictionary.get128BitsMurmur3HashValue(i);
+              valueSet128.add(hashValue[0], hashValue[1]);
+            }
+            return valueSet128;
+          default:
+            throw new IllegalStateException();
+        }
+    }
+  }
+
+  /// Extracts the value set from the dictionary.
+  public BaseOffHeapSet extractAggregationResult(Dictionary dictionary) {
+    int length = dictionary.length();
+    DataType storedType = dictionary.getValueType();
+    switch (storedType) {
+      case INT:
+        OffHeap32BitSet intSet = new OffHeap32BitSet(length);
+        for (int i = 0; i < length; i++) {
+          intSet.add(dictionary.getIntValue(i));
+        }
+        return intSet;
+      case LONG:
+        OffHeap64BitSet longSet = new OffHeap64BitSet(length);
+        for (int i = 0; i < length; i++) {
+          longSet.add(dictionary.getLongValue(i));
+        }
+        return longSet;
+      case FLOAT:
+        OffHeap32BitSet floatSet = new OffHeap32BitSet(length);
+        for (int i = 0; i < length; i++) {
+          floatSet.add(Float.floatToRawIntBits(dictionary.getFloatValue(i)));
+        }
+        return floatSet;
+      case DOUBLE:
+        OffHeap64BitSet doubleSet = new OffHeap64BitSet(length);
+        for (int i = 0; i < length; i++) {
+          doubleSet.add(Double.doubleToRawLongBits(dictionary.getDoubleValue(i)));
+        }
+        return doubleSet;
+      default:
+        switch (_hashBits) {
+          case 32:
+            OffHeap32BitSet valueSet32 = new OffHeap32BitSet(length);
+            for (int i = 0; i < length; i++) {
+              valueSet32.add(dictionary.get32BitsMurmur3HashValue(i));
+            }
+            return valueSet32;
+          case 64:
+            OffHeap64BitSet valueSet64 = new OffHeap64BitSet(length);
+            for (int i = 0; i < length; i++) {
+              valueSet64.add(dictionary.get64BitsMurmur3HashValue(i));
+            }
+            return valueSet64;
+          case 128:
+            OffHeap128BitSet valueSet128 = new OffHeap128BitSet(length);
+            for (int i = 0; i < length; i++) {
+              long[] hashValue = dictionary.get128BitsMurmur3HashValue(i);
+              valueSet128.add(hashValue[0], hashValue[1]);
+            }
+            return valueSet128;
+          default:
+            throw new IllegalStateException();
+        }
+    }
+  }
+
+  @Override
+  public BaseOffHeapSet extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public BaseOffHeapSet merge(BaseOffHeapSet intermediateResult1, BaseOffHeapSet intermediateResult2) {
+    assert intermediateResult1 != null && intermediateResult2 != null;
+    if (intermediateResult1.isEmpty()) {
+      intermediateResult1.close();
+      return intermediateResult2;
+    }
+    intermediateResult1.merge(intermediateResult2);
+    intermediateResult2.close();
+    return intermediateResult1;
+  }
+
+  @Override
+  public ColumnDataType getIntermediateResultColumnType() {
+    return ColumnDataType.OBJECT;
+  }
+
+  @Override
+  public SerializedIntermediateResult serializeIntermediateResult(BaseOffHeapSet set) {
+    int type;
+    if (set instanceof OffHeap32BitSet) {
+      type = 0;
+    } else if (set instanceof OffHeap64BitSet) {
+      type = 1;
+    } else if (set instanceof OffHeap128BitSet) {
+      type = 2;
+    } else {
+      throw new IllegalStateException();
+    }
+    byte[] bytes = set.serialize();
+    set.close();
+    return new SerializedIntermediateResult(type, bytes);
+  }
+
+  @Override
+  public BaseOffHeapSet deserializeIntermediateResult(CustomObject customObject) {
+    switch (customObject.getType()) {
+      case 0:
+        return OffHeap32BitSet.deserialize(customObject.getBuffer());
+      case 1:
+        return OffHeap64BitSet.deserialize(customObject.getBuffer());
+      case 2:
+        return OffHeap128BitSet.deserialize(customObject.getBuffer());
+      default:
+        throw new IllegalStateException();
+    }
+  }
+
+  @Override
+  public ColumnDataType getFinalResultColumnType() {
+    return ColumnDataType.INT;
+  }
+
+  @Override
+  public Integer extractFinalResult(BaseOffHeapSet set) {
+    assert set != null;
+    int size = set.size();
+    set.close();
+    return size;
+  }
+
+  @Override
+  public Integer mergeFinalResult(Integer finalResult1, Integer finalResult2) {
+    return finalResult1 + finalResult2;
+  }
+
+  /// Helper class to wrap the dictionary ids.
+  /// Different from the BaseDistinctAggregateAggregationFunction.DictIdsWrapper, here we use a pre-allocated BitSet
+  /// instead of RoaringBitmap for better performance on high cardinality distinct count.
+  private static final class DictIdsWrapper {
+    final Dictionary _dictionary;
+    final BitSet _bitSet;
+
+    DictIdsWrapper(Dictionary dictionary) {
+      _dictionary = dictionary;
+      _bitSet = new BitSet(dictionary.length());
+    }
+  }
+
+  /// Helper class to wrap the parameters.
+  private static class Parameters {
+    static final char PARAMETER_DELIMITER = ';';
+    static final char PARAMETER_KEY_VALUE_SEPARATOR = '=';
+
+    static final String INITIAL_CAPACITY_KEY = "INITIALCAPACITY";
+    static final int DEFAULT_INITIAL_CAPACITY = 10_000;
+
+    static final String HASH_BITS_KEY = "HASHBITS";
+    static final int DEFAULT_HASH_BITS = 64;
+
+    int _initialCapacity = DEFAULT_INITIAL_CAPACITY;
+    int _hashBits = DEFAULT_HASH_BITS;
+
+    Parameters(String parametersString) {
+      parametersString = StringUtils.deleteWhitespace(parametersString);
+      String[] keyValuePairs = StringUtils.split(parametersString, PARAMETER_DELIMITER);
+      for (String keyValuePair : keyValuePairs) {
+        String[] keyAndValue = StringUtils.split(keyValuePair, PARAMETER_KEY_VALUE_SEPARATOR);
+        Preconditions.checkArgument(keyAndValue.length == 2, "Invalid parameter: %s", keyValuePair);
+        String key = keyAndValue[0];
+        String value = keyAndValue[1];
+        switch (key.toUpperCase()) {
+          case INITIAL_CAPACITY_KEY:
+            _initialCapacity = Integer.parseInt(value);
+            Preconditions.checkArgument(_initialCapacity > 0, "Initial capacity must be > 0, got: %s",
+                _initialCapacity);
+            break;
+          case HASH_BITS_KEY:
+            _hashBits = Integer.parseInt(value);
+            Preconditions.checkArgument(_hashBits == 32 || _hashBits == 64 || _hashBits == 128,
+                "Hash bits must be 32, 64 or 128, got: %s", _hashBits);
+            break;
+          default:
+            throw new IllegalArgumentException("Invalid parameter key: " + key);
+        }
+      }
+    }
+  }
+}
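A note on the addToValueSetSV/addToValueSetMV cases above: FLOAT and DOUBLE values are folded to their raw IEEE-754 bit patterns, so a fixed-width 32/64-bit set can hold them without hashing and without losing distinctness. A pure-JDK demonstration:

    public class RawBitsSketch {
      public static void main(String[] args) {
        // The mapping is one-to-one: distinct float values yield distinct int keys ...
        System.out.println(Float.floatToRawIntBits(1.5f)); // 1069547520 (0x3FC00000)
        System.out.println(Float.floatToRawIntBits(-0.0f)); // -2147483648 (differs from +0.0f's 0)
        // ... and reversible, so no information is lost by storing the bits instead of the value.
        System.out.println(Float.intBitsToFloat(Float.floatToRawIntBits(1.5f))); // 1.5
        System.out.println(Double.doubleToRawLongBits(1.5d)); // 4609434218613702656 (0x3FF8000000000000)
      }
    }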
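The merge()/serializeIntermediateResult()/extractFinalResult() trio above also encodes a strict ownership rule for the native memory backing each set: every set is closed exactly once, either when it is consumed (merged away or serialized) or when the final count is extracted. A lifecycle sketch using only API calls that appear in this diff (OffHeap64BitSet(int), add(long), merge(...), size(), close()); exact semantics beyond that are assumptions:

    import org.apache.pinot.core.query.aggregation.function.distinct.OffHeap64BitSet;

    public class OffHeapLifecycleSketch {
      public static void main(String[] args) {
        OffHeap64BitSet left = new OffHeap64BitSet(16);
        OffHeap64BitSet right = new OffHeap64BitSet(16);
        left.add(1L);
        right.add(1L); // duplicate across segments, counted once
        right.add(2L);

        left.merge(right); // fold right's entries into left, mirroring merge() above
        right.close();     // right has been consumed; release its native memory

        System.out.println(left.size()); // 2 distinct values
        left.close();      // closing the surviving set ends the lifecycle, as extractFinalResult() does
      }
    }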
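And on the Parameters class above: the optional literal is a ';'-separated list of 'key=value' pairs with case-insensitive keys and whitespace ignored. A standalone re-implementation of that grammar for illustration only (plain JDK; the real class uses Commons Lang and Guava Preconditions):

    public class ParametersGrammarSketch {
      // Returns {initialCapacity, hashBits}; defaults mirror the Parameters class above.
      static int[] parse(String parametersString) {
        int initialCapacity = 10_000;
        int hashBits = 64;
        for (String pair : parametersString.replaceAll("\\s", "").split(";")) {
          String[] keyAndValue = pair.split("=");
          if (keyAndValue.length != 2) {
            throw new IllegalArgumentException("Invalid parameter: " + pair);
          }
          switch (keyAndValue[0].toUpperCase()) {
            case "INITIALCAPACITY":
              initialCapacity = Integer.parseInt(keyAndValue[1]);
              break;
            case "HASHBITS":
              hashBits = Integer.parseInt(keyAndValue[1]);
              break;
            default:
              throw new IllegalArgumentException("Invalid parameter key: " + keyAndValue[0]);
          }
        }
        return new int[] {initialCapacity, hashBits};
      }

      public static void main(String[] args) {
        int[] parsed = parse("initialcapacity=10; hashbits=128");
        System.out.println(parsed[0] + " / " + parsed[1]); // 10 / 128
      }
    }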
"DISTINCTCOUNTOFFHEAP(doubleColumn), " + + "DISTINCTCOUNTOFFHEAP(stringColumn), " + + "DISTINCTCOUNTOFFHEAP(bytesColumn) " + + "FROM testTable"; + + // Inner segment + for (Object operator : Arrays.asList(getOperator(query), getOperatorWithFilter(query))) { + assertTrue(operator instanceof NonScanBasedAggregationOperator); + AggregationResultsBlock resultsBlock = ((NonScanBasedAggregationOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(((Operator) operator).getExecutionStatistics(), NUM_RECORDS, + 0, 0, NUM_RECORDS); + List<Object> aggregationResult = resultsBlock.getResults(); + assertNotNull(aggregationResult); + assertEquals(aggregationResult.size(), 6); + for (int i = 0; i < 6; i++) { + assertEquals(((BaseOffHeapSet) aggregationResult.get(i)).size(), _values.size()); + } + } + + // Inter segments + Object[] expectedResults = Collections.nCopies(6, _values.size()).toArray(); + for (BrokerResponseNative brokerResponse : Arrays.asList(getBrokerResponse(query), + getBrokerResponseWithFilter(query))) { + QueriesTestUtils.testInterSegmentsResult(brokerResponse, 4 * NUM_RECORDS, 0, 0, 4 * NUM_RECORDS, expectedResults); + } + + // Regular aggregation + query = query + " WHERE intColumn >= 500"; + + // Inner segment + int expectedResult = 0; + for (Integer value : _values) { + if (value >= 500) { + expectedResult++; + } + } + AggregationOperator aggregationOperator = getOperator(query); + List<Object> aggregationResult = aggregationOperator.nextBlock().getResults(); + assertNotNull(aggregationResult); + assertEquals(aggregationResult.size(), 6); + for (int i = 0; i < 6; i++) { + assertEquals(((BaseOffHeapSet) aggregationResult.get(i)).size(), expectedResult); + } + + // Inter segment + expectedResults = Collections.nCopies(6, expectedResult).toArray(); + QueriesTestUtils.testInterSegmentsResult(getBrokerResponse(query), expectedResults); + + // Change parameters + query = "SELECT DISTINCTCOUNTOFFHEAP(stringColumn, 'initialcapacity=10;hashbits=128') FROM testTable"; + NonScanBasedAggregationOperator nonScanOperator = getOperator(query); + aggregationResult = nonScanOperator.nextBlock().getResults(); + assertNotNull(aggregationResult); + assertEquals(aggregationResult.size(), 1); + assertTrue(aggregationResult.get(0) instanceof OffHeap128BitSet); + assertEquals(((OffHeap128BitSet) aggregationResult.get(0)).size(), _values.size()); + + query = "SELECT DISTINCTCOUNTOFFHEAP(bytesColumn, 'initialcapacity=100') FROM testTable " + + "WHERE intColumn >= 500"; + aggregationOperator = getOperator(query); + aggregationResult = aggregationOperator.nextBlock().getResults(); + assertNotNull(aggregationResult); + assertEquals(aggregationResult.size(), 1); + assertTrue(aggregationResult.get(0) instanceof OffHeap64BitSet); + assertEquals(((OffHeap64BitSet) aggregationResult.get(0)).size(), expectedResult); + } + @Test public void testHLL() { // Dictionary based diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java index 623ecaf8ba..cace39ff1a 100644 --- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java +++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java @@ -70,6 +70,9 @@ public enum AggregationFunctionType { * (2) count(distinct ...) 
diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
index 623ecaf8ba..cace39ff1a 100644
--- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
+++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java
@@ -70,6 +70,9 @@ public enum AggregationFunctionType {
    * (2) count(distinct ...) support multi-argument and will be converted into DISTINCT + COUNT
    */
   DISTINCTCOUNT("distinctCount", ReturnTypes.BIGINT, OperandTypes.ANY, SqlTypeName.OTHER, SqlTypeName.INTEGER),
+  DISTINCTCOUNTOFFHEAP("distinctCountOffHeap", ReturnTypes.BIGINT,
+      OperandTypes.family(List.of(SqlTypeFamily.ANY, SqlTypeFamily.CHARACTER), i -> i == 1), SqlTypeName.OTHER,
+      SqlTypeName.INTEGER),
   DISTINCTSUM("distinctSum", ReturnTypes.AGG_SUM, OperandTypes.NUMERIC, SqlTypeName.OTHER, SqlTypeName.DOUBLE),
   DISTINCTAVG("distinctAvg", ReturnTypes.DOUBLE, OperandTypes.NUMERIC, SqlTypeName.OTHER),
   DISTINCTCOUNTBITMAP("distinctCountBitmap", ReturnTypes.BIGINT, OperandTypes.ANY, SqlTypeName.OTHER,