This is an automated email from the ASF dual-hosted git repository. jackie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push: new aa5f318 add mode aggregation function (#7318) aa5f318 is described below commit aa5f318d0708b9b0a3e570706c4236df94d29141 Author: Yash Agarwal <yash.0...@gmail.com> AuthorDate: Thu Aug 19 23:15:29 2021 +0530 add mode aggregation function (#7318) Add support for Mode Function. Mode accepts an additional parameter to reduce multiple modes to a single value: MIN/MAX/AVG --- .../function/AggregationFunctionTypeTest.java | 1 + .../apache/pinot/core/common/ObjectSerDeUtils.java | 152 +++- .../function/AggregationFunctionFactory.java | 2 + .../function/ModeAggregationFunction.java | 691 +++++++++++++++ .../pinot/core/common/ObjectSerDeUtilsTest.java | 68 ++ .../function/AggregationFunctionFactoryTest.java | 7 + .../org/apache/pinot/queries/ModeQueriesTest.java | 949 +++++++++++++++++++++ .../pinot/segment/spi/AggregationFunctionType.java | 1 + 8 files changed, 1869 insertions(+), 2 deletions(-) diff --git a/pinot-common/src/test/java/org/apache/pinot/common/function/AggregationFunctionTypeTest.java b/pinot-common/src/test/java/org/apache/pinot/common/function/AggregationFunctionTypeTest.java index 70a25b5..e325b0a 100644 --- a/pinot-common/src/test/java/org/apache/pinot/common/function/AggregationFunctionTypeTest.java +++ b/pinot-common/src/test/java/org/apache/pinot/common/function/AggregationFunctionTypeTest.java @@ -32,6 +32,7 @@ public class AggregationFunctionTypeTest { Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("MaX"), AggregationFunctionType.MAX); Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("SuM"), AggregationFunctionType.SUM); Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("AvG"), AggregationFunctionType.AVG); + Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("MoDe"), AggregationFunctionType.MODE); Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("MiNmAxRaNgE"), AggregationFunctionType.MINMAXRANGE); Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("DiStInCtCoUnT"), AggregationFunctionType.DISTINCTCOUNT); Assert.assertEquals(AggregationFunctionType.getAggregationFunctionType("DiStInCtCoUnThLl"), AggregationFunctionType.DISTINCTCOUNTHLL); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java index 7417daa..123b03b 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java @@ -22,16 +22,24 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog; import com.google.common.primitives.Longs; import com.tdunning.math.stats.MergingDigest; import com.tdunning.math.stats.TDigest; +import it.unimi.dsi.fastutil.doubles.Double2LongMap; +import it.unimi.dsi.fastutil.doubles.Double2LongOpenHashMap; import it.unimi.dsi.fastutil.doubles.DoubleArrayList; import it.unimi.dsi.fastutil.doubles.DoubleIterator; import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet; import it.unimi.dsi.fastutil.doubles.DoubleSet; +import it.unimi.dsi.fastutil.floats.Float2LongMap; +import it.unimi.dsi.fastutil.floats.Float2LongOpenHashMap; import it.unimi.dsi.fastutil.floats.FloatIterator; import it.unimi.dsi.fastutil.floats.FloatOpenHashSet; import it.unimi.dsi.fastutil.floats.FloatSet; +import it.unimi.dsi.fastutil.ints.Int2LongMap; +import it.unimi.dsi.fastutil.ints.Int2LongOpenHashMap; import it.unimi.dsi.fastutil.ints.IntIterator; import it.unimi.dsi.fastutil.ints.IntOpenHashSet; import it.unimi.dsi.fastutil.ints.IntSet; +import it.unimi.dsi.fastutil.longs.Long2LongMap; +import it.unimi.dsi.fastutil.longs.Long2LongOpenHashMap; import it.unimi.dsi.fastutil.longs.LongIterator; import it.unimi.dsi.fastutil.longs.LongOpenHashSet; import it.unimi.dsi.fastutil.longs.LongSet; @@ -97,7 +105,11 @@ public class ObjectSerDeUtils { BytesSet(19), IdSet(20), List(21), - BigDecimal(22); + BigDecimal(22), + Int2LongMap(23), + Long2LongMap(24), + Float2LongMap(25), + Double2LongMap(26); private final int _value; ObjectType(int value) { @@ -127,6 +139,14 @@ public class ObjectSerDeUtils { return ObjectType.HyperLogLog; } else if (value instanceof QuantileDigest) { return ObjectType.QuantileDigest; + } else if (value instanceof Int2LongMap) { + return ObjectType.Int2LongMap; + } else if (value instanceof Long2LongMap) { + return ObjectType.Long2LongMap; + } else if (value instanceof Float2LongMap) { + return ObjectType.Float2LongMap; + } else if (value instanceof Double2LongMap) { + return ObjectType.Double2LongMap; } else if (value instanceof Map) { return ObjectType.Map; } else if (value instanceof IntSet) { @@ -874,6 +894,130 @@ public class ObjectSerDeUtils { } }; + public static final ObjectSerDe<Int2LongMap> INT_2_LONG_MAP_SER_DE = new ObjectSerDe<Int2LongMap>() { + + @Override + public byte[] serialize(Int2LongMap map) { + int size = map.size(); + byte[] bytes = new byte[Integer.BYTES + size * (Integer.BYTES + Long.BYTES)]; + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + byteBuffer.putInt(size); + for (Int2LongMap.Entry entry : map.int2LongEntrySet()) { + byteBuffer.putInt(entry.getIntKey()); + byteBuffer.putLong(entry.getLongValue()); + } + return bytes; + } + + @Override + public Int2LongOpenHashMap deserialize(byte[] bytes) { + return deserialize(ByteBuffer.wrap(bytes)); + } + + @Override + public Int2LongOpenHashMap deserialize(ByteBuffer byteBuffer) { + int size = byteBuffer.getInt(); + Int2LongOpenHashMap map = new Int2LongOpenHashMap(size); + for (int i = 0; i < size; i++) { + map.put(byteBuffer.getInt(), byteBuffer.getLong()); + } + return map; + } + }; + + public static final ObjectSerDe<Long2LongMap> LONG_2_LONG_MAP_SER_DE = new ObjectSerDe<Long2LongMap>() { + + @Override + public byte[] serialize(Long2LongMap map) { + int size = map.size(); + byte[] bytes = new byte[Integer.BYTES + size * (Long.BYTES + Long.BYTES)]; + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + byteBuffer.putInt(size); + for (Long2LongMap.Entry entry : map.long2LongEntrySet()) { + byteBuffer.putLong(entry.getLongKey()); + byteBuffer.putLong(entry.getLongValue()); + } + return bytes; + } + + @Override + public Long2LongOpenHashMap deserialize(byte[] bytes) { + return deserialize(ByteBuffer.wrap(bytes)); + } + + @Override + public Long2LongOpenHashMap deserialize(ByteBuffer byteBuffer) { + int size = byteBuffer.getInt(); + Long2LongOpenHashMap map = new Long2LongOpenHashMap(size); + for (int i = 0; i < size; i++) { + map.put(byteBuffer.getLong(), byteBuffer.getLong()); + } + return map; + } + }; + + public static final ObjectSerDe<Float2LongMap> FLOAT_2_LONG_MAP_SER_DE = new ObjectSerDe<Float2LongMap>() { + + @Override + public byte[] serialize(Float2LongMap map) { + int size = map.size(); + byte[] bytes = new byte[Integer.BYTES + size * (Float.BYTES + Long.BYTES)]; + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + byteBuffer.putInt(size); + for (Float2LongMap.Entry entry : map.float2LongEntrySet()) { + byteBuffer.putFloat(entry.getFloatKey()); + byteBuffer.putLong(entry.getLongValue()); + } + return bytes; + } + + @Override + public Float2LongOpenHashMap deserialize(byte[] bytes) { + return deserialize(ByteBuffer.wrap(bytes)); + } + + @Override + public Float2LongOpenHashMap deserialize(ByteBuffer byteBuffer) { + int size = byteBuffer.getInt(); + Float2LongOpenHashMap map = new Float2LongOpenHashMap(size); + for (int i = 0; i < size; i++) { + map.put(byteBuffer.getFloat(), byteBuffer.getLong()); + } + return map; + } + }; + + public static final ObjectSerDe<Double2LongMap> DOUBLE_2_LONG_MAP_SER_DE = new ObjectSerDe<Double2LongMap>() { + + @Override + public byte[] serialize(Double2LongMap map) { + int size = map.size(); + byte[] bytes = new byte[Integer.BYTES + size * (Double.BYTES + Long.BYTES)]; + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + byteBuffer.putInt(size); + for (Double2LongMap.Entry entry : map.double2LongEntrySet()) { + byteBuffer.putDouble(entry.getDoubleKey()); + byteBuffer.putLong(entry.getLongValue()); + } + return bytes; + } + + @Override + public Double2LongOpenHashMap deserialize(byte[] bytes) { + return deserialize(ByteBuffer.wrap(bytes)); + } + + @Override + public Double2LongOpenHashMap deserialize(ByteBuffer byteBuffer) { + int size = byteBuffer.getInt(); + Double2LongOpenHashMap map = new Double2LongOpenHashMap(size); + for (int i = 0; i < size; i++) { + map.put(byteBuffer.getDouble(), byteBuffer.getLong()); + } + return map; + } + }; + // NOTE: DO NOT change the order, it has to be the same order as the ObjectType //@formatter:off private static final ObjectSerDe[] SER_DES = { @@ -899,7 +1043,11 @@ public class ObjectSerDeUtils { BYTES_SET_SER_DE, ID_SET_SER_DE, LIST_SER_DE, - BIGDECIMAL_SER_DE + BIGDECIMAL_SER_DE, + INT_2_LONG_MAP_SER_DE, + LONG_2_LONG_MAP_SER_DE, + FLOAT_2_LONG_MAP_SER_DE, + DOUBLE_2_LONG_MAP_SER_DE }; //@formatter:on diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java index bf7c5aa..ccd45fc 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactory.java @@ -122,6 +122,8 @@ public class AggregationFunctionFactory { return new SumPrecisionAggregationFunction(arguments); case AVG: return new AvgAggregationFunction(firstArgument); + case MODE: + return new ModeAggregationFunction(arguments); case MINMAXRANGE: return new MinMaxRangeAggregationFunction(firstArgument); case DISTINCTCOUNT: diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ModeAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ModeAggregationFunction.java new file mode 100644 index 0000000..b67152b --- /dev/null +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/ModeAggregationFunction.java @@ -0,0 +1,691 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.query.aggregation.function; + +import com.google.common.base.Preconditions; +import it.unimi.dsi.fastutil.doubles.Double2LongMap; +import it.unimi.dsi.fastutil.doubles.Double2LongOpenHashMap; +import it.unimi.dsi.fastutil.floats.Float2LongMap; +import it.unimi.dsi.fastutil.floats.Float2LongOpenHashMap; +import it.unimi.dsi.fastutil.ints.Int2IntMap; +import it.unimi.dsi.fastutil.ints.Int2IntMaps; +import it.unimi.dsi.fastutil.ints.Int2IntOpenHashMap; +import it.unimi.dsi.fastutil.ints.Int2LongMap; +import it.unimi.dsi.fastutil.ints.Int2LongOpenHashMap; +import it.unimi.dsi.fastutil.longs.Long2LongMap; +import it.unimi.dsi.fastutil.longs.Long2LongOpenHashMap; +import it.unimi.dsi.fastutil.objects.ObjectIterator; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.core.common.BlockValSet; +import org.apache.pinot.core.query.aggregation.AggregationResultHolder; +import org.apache.pinot.core.query.aggregation.ObjectAggregationResultHolder; +import org.apache.pinot.core.query.aggregation.groupby.GroupByResultHolder; +import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder; +import org.apache.pinot.segment.spi.AggregationFunctionType; +import org.apache.pinot.segment.spi.index.reader.Dictionary; +import org.apache.pinot.spi.data.FieldSpec.DataType; + + +/** + * This function is used for Mode calculations. + * <p>The function can be used as MODE(expression, multiModeReducerType) + * <p>Following arguments are supported: + * <ul> + * <li>Expression: expression that contains the column to be calculated mode on, can be any Numeric column</li> + * <li>MultiModeReducerType (optional): the reducer to use in case of multiple modes present in data</li> + * </ul> + */ +@SuppressWarnings({"rawtypes", "unchecked"}) +public class ModeAggregationFunction extends BaseSingleInputAggregationFunction<Map<? extends Number, Long>, Double> { + + private static final double DEFAULT_FINAL_RESULT = Double.NEGATIVE_INFINITY; + + private final MultiModeReducerType _multiModeReducerType; + + public ModeAggregationFunction(List<ExpressionContext> arguments) { + super(arguments.get(0)); + + int numArguments = arguments.size(); + Preconditions.checkArgument(numArguments <= 2, "Mode expects at most 2 arguments, got: %s", numArguments); + if (numArguments > 1) { + _multiModeReducerType = MultiModeReducerType.valueOf(arguments.get(1).getLiteral()); + } else { + _multiModeReducerType = MultiModeReducerType.MIN; + } + } + + /** + * Helper method to create a value map for the given value type. + */ + private static Map<? extends Number, Long> getValueMap(DataType valueType) { + switch (valueType) { + case INT: + return new Int2LongOpenHashMap(); + case LONG: + return new Long2LongOpenHashMap(); + case FLOAT: + return new Float2LongOpenHashMap(); + case DOUBLE: + return new Double2LongOpenHashMap(); + default: + throw new IllegalStateException("Illegal data type for MODE aggregation function: " + valueType); + } + } + + /** + * Returns the value map from the result holder or creates a new one if it does not exist. + */ + private static Map<? extends Number, Long> getValueMap(AggregationResultHolder aggregationResultHolder, + DataType valueType) { + Map<? extends Number, Long> valueMap = aggregationResultHolder.getResult(); + if (valueMap == null) { + valueMap = getValueMap(valueType); + aggregationResultHolder.setValue(valueMap); + } + return valueMap; + } + + /** + * Helper method to set INT value for the given group keys into the result holder. + */ + private static void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int groupKey, int value) { + Int2LongOpenHashMap valueMap = groupByResultHolder.getResult(groupKey); + if (valueMap == null) { + valueMap = new Int2LongOpenHashMap(); + groupByResultHolder.setValueForKey(groupKey, valueMap); + } + valueMap.merge(value, 1, Long::sum); + } + + /** + * Helper method to set LONG value for the given group keys into the result holder. + */ + private static void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int groupKey, long value) { + Long2LongOpenHashMap valueMap = groupByResultHolder.getResult(groupKey); + if (valueMap == null) { + valueMap = new Long2LongOpenHashMap(); + groupByResultHolder.setValueForKey(groupKey, valueMap); + } + valueMap.merge(value, 1, Long::sum); + } + + /** + * Helper method to set FLOAT value for the given group keys into the result holder. + */ + private static void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int groupKey, float value) { + Float2LongOpenHashMap valueMap = groupByResultHolder.getResult(groupKey); + if (valueMap == null) { + valueMap = new Float2LongOpenHashMap(); + groupByResultHolder.setValueForKey(groupKey, valueMap); + } + valueMap.merge(value, 1, Long::sum); + } + + /** + * Helper method to set DOUBLE value for the given group keys into the result holder. + */ + private static void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int groupKey, double value) { + Double2LongOpenHashMap valueMap = groupByResultHolder.getResult(groupKey); + if (valueMap == null) { + valueMap = new Double2LongOpenHashMap(); + groupByResultHolder.setValueForKey(groupKey, valueMap); + } + valueMap.merge(value, 1, Long::sum); + } + + /** + * Returns the dictionary id count map from the result holder or creates a new one if it does not exist. + */ + protected static Int2IntOpenHashMap getDictIdCountMap(AggregationResultHolder aggregationResultHolder, + Dictionary dictionary) { + ModeAggregationFunction.DictIdsWrapper dictIdsWrapper = aggregationResultHolder.getResult(); + if (dictIdsWrapper == null) { + dictIdsWrapper = new ModeAggregationFunction.DictIdsWrapper(dictionary); + aggregationResultHolder.setValue(dictIdsWrapper); + } + return dictIdsWrapper._dictIdCountMap; + } + + /** + * Returns the dictionary id count map for the given group key or creates a new one if it does not exist. + */ + protected static Int2IntOpenHashMap getDictIdCountMap(GroupByResultHolder groupByResultHolder, int groupKey, + Dictionary dictionary) { + ModeAggregationFunction.DictIdsWrapper dictIdsWrapper = groupByResultHolder.getResult(groupKey); + if (dictIdsWrapper == null) { + dictIdsWrapper = new ModeAggregationFunction.DictIdsWrapper(dictionary); + groupByResultHolder.setValueForKey(groupKey, dictIdsWrapper); + } + return dictIdsWrapper._dictIdCountMap; + } + + /** + * Helper method to read dictionary and convert dictionary ids to values for dictionary-encoded expression. + */ + private static Map<? extends Number, Long> convertToValueMap(DictIdsWrapper dictIdsWrapper) { + Dictionary dictionary = dictIdsWrapper._dictionary; + Int2IntOpenHashMap dictIdCountMap = dictIdsWrapper._dictIdCountMap; + int numValues = dictIdCountMap.size(); + ObjectIterator<Int2IntMap.Entry> iterator = Int2IntMaps.fastIterator(dictIdCountMap); + DataType storedType = dictionary.getValueType(); + switch (storedType) { + case INT: + Int2LongOpenHashMap intValueMap = new Int2LongOpenHashMap(numValues); + while (iterator.hasNext()) { + Int2IntMap.Entry next = iterator.next(); + intValueMap.put(dictionary.getIntValue(next.getIntKey()), next.getIntValue()); + } + return intValueMap; + case LONG: + Long2LongOpenHashMap longValueMap = new Long2LongOpenHashMap(numValues); + while (iterator.hasNext()) { + Int2IntMap.Entry next = iterator.next(); + longValueMap.put(dictionary.getLongValue(next.getIntKey()), next.getIntValue()); + } + return longValueMap; + case FLOAT: + Float2LongOpenHashMap floatValueMap = new Float2LongOpenHashMap(numValues); + while (iterator.hasNext()) { + Int2IntMap.Entry next = iterator.next(); + floatValueMap.put(dictionary.getFloatValue(next.getIntKey()), next.getIntValue()); + } + return floatValueMap; + case DOUBLE: + Double2LongOpenHashMap doubleValueMap = new Double2LongOpenHashMap(numValues); + while (iterator.hasNext()) { + Int2IntMap.Entry next = iterator.next(); + doubleValueMap.put(dictionary.getDoubleValue(next.getIntKey()), next.getIntValue()); + } + return doubleValueMap; + default: + throw new IllegalStateException("Illegal data type for MODE aggregation function: " + storedType); + } + } + + /** + * Helper method to extract segment level intermediate result from the inner segment result. + */ + private static Map<? extends Number, Long> extractIntermediateResult(@Nullable Object result) { + if (result == null) { + // NOTE: Return an empty Int2LongOpenHashMap for empty result. + return new Int2LongOpenHashMap(); + } + + if (result instanceof DictIdsWrapper) { + // For dictionary-encoded expression, convert dictionary ids to values + return convertToValueMap((DictIdsWrapper) result); + } + assert result instanceof Map; + // For non-dictionary-encoded expression, directly return the value set + return (Map) result; + } + + @Override + public AggregationFunctionType getType() { + return AggregationFunctionType.MODE; + } + + @Override + public AggregationResultHolder createAggregationResultHolder() { + return new ObjectAggregationResultHolder(); + } + + @Override + public GroupByResultHolder createGroupByResultHolder(int initialCapacity, int maxCapacity) { + return new ObjectGroupByResultHolder(initialCapacity, maxCapacity); + } + + @Override + public void aggregate(int length, AggregationResultHolder aggregationResultHolder, + Map<ExpressionContext, BlockValSet> blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); + + // For dictionary-encoded expression, store dictionary ids into the dictId map + Dictionary dictionary = blockValSet.getDictionary(); + if (dictionary != null) { + int[] dictIds = blockValSet.getDictionaryIdsSV(); + Int2IntOpenHashMap dictIdValueMap = getDictIdCountMap(aggregationResultHolder, dictionary); + for (int i = 0; i < length; i++) { + dictIdValueMap.merge(dictIds[i], 1, Integer::sum); + } + return; + } + + // For non-dictionary-encoded expression, store values into the value map + DataType storedType = blockValSet.getValueType().getStoredType(); + Map<? extends Number, Long> valueMap = getValueMap(aggregationResultHolder, storedType); + switch (storedType) { + case INT: + Int2LongOpenHashMap intMap = (Int2LongOpenHashMap) valueMap; + int[] intValues = blockValSet.getIntValuesSV(); + for (int i = 0; i < length; i++) { + intMap.merge(intValues[i], 1, Long::sum); + } + break; + case LONG: + Long2LongOpenHashMap longMap = (Long2LongOpenHashMap) valueMap; + long[] longValues = blockValSet.getLongValuesSV(); + for (int i = 0; i < length; i++) { + longMap.merge(longValues[i], 1, Long::sum); + } + break; + case FLOAT: + Float2LongOpenHashMap floatMap = (Float2LongOpenHashMap) valueMap; + float[] floatValues = blockValSet.getFloatValuesSV(); + for (int i = 0; i < length; i++) { + floatMap.merge(floatValues[i], 1, Long::sum); + } + break; + case DOUBLE: + Double2LongOpenHashMap doubleMap = (Double2LongOpenHashMap) valueMap; + double[] doubleValues = blockValSet.getDoubleValuesSV(); + for (int i = 0; i < length; i++) { + doubleMap.merge(doubleValues[i], 1, Long::sum); + } + break; + default: + throw new IllegalStateException("Illegal data type for MODE aggregation function: " + storedType); + } + } + + @Override + public void aggregateGroupBySV(int length, int[] groupKeyArray, GroupByResultHolder groupByResultHolder, + Map<ExpressionContext, BlockValSet> blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); + + // For dictionary-encoded expression, store dictionary ids into the dictId map + Dictionary dictionary = blockValSet.getDictionary(); + if (dictionary != null) { + int[] dictIds = blockValSet.getDictionaryIdsSV(); + for (int i = 0; i < length; i++) { + getDictIdCountMap(groupByResultHolder, groupKeyArray[i], dictionary).merge(dictIds[i], 1, Integer::sum); + } + return; + } + + // For non-dictionary-encoded expression, store values into the value map + DataType storedType = blockValSet.getValueType().getStoredType(); + switch (storedType) { + case INT: + int[] intValues = blockValSet.getIntValuesSV(); + for (int i = 0; i < length; i++) { + setValueForGroupKeys(groupByResultHolder, groupKeyArray[i], intValues[i]); + } + break; + case LONG: + long[] longValues = blockValSet.getLongValuesSV(); + for (int i = 0; i < length; i++) { + setValueForGroupKeys(groupByResultHolder, groupKeyArray[i], longValues[i]); + } + break; + case FLOAT: + float[] floatValues = blockValSet.getFloatValuesSV(); + for (int i = 0; i < length; i++) { + setValueForGroupKeys(groupByResultHolder, groupKeyArray[i], floatValues[i]); + } + break; + case DOUBLE: + double[] doubleValues = blockValSet.getDoubleValuesSV(); + for (int i = 0; i < length; i++) { + setValueForGroupKeys(groupByResultHolder, groupKeyArray[i], doubleValues[i]); + } + break; + default: + throw new IllegalStateException("Illegal data type for MODE aggregation function: " + storedType); + } + } + + @Override + public void aggregateGroupByMV(int length, int[][] groupKeysArray, GroupByResultHolder groupByResultHolder, + Map<ExpressionContext, BlockValSet> blockValSetMap) { + BlockValSet blockValSet = blockValSetMap.get(_expression); + + // For dictionary-encoded expression, store dictionary ids into the dictId map + Dictionary dictionary = blockValSet.getDictionary(); + if (dictionary != null) { + int[] dictIds = blockValSet.getDictionaryIdsSV(); + for (int i = 0; i < length; i++) { + for (int groupKey : groupKeysArray[i]) { + getDictIdCountMap(groupByResultHolder, groupKey, dictionary).merge(dictIds[i], 1, Integer::sum); + } + } + return; + } + + // For non-dictionary-encoded expression, store values into the value map + DataType storedType = blockValSet.getValueType().getStoredType(); + switch (storedType) { + case INT: + int[] intValues = blockValSet.getIntValuesSV(); + for (int i = 0; i < length; i++) { + for (int groupKey : groupKeysArray[i]) { + setValueForGroupKeys(groupByResultHolder, groupKey, intValues[i]); + } + } + break; + case LONG: + long[] longValues = blockValSet.getLongValuesSV(); + for (int i = 0; i < length; i++) { + for (int groupKey : groupKeysArray[i]) { + setValueForGroupKeys(groupByResultHolder, groupKey, longValues[i]); + } + } + break; + case FLOAT: + float[] floatValues = blockValSet.getFloatValuesSV(); + for (int i = 0; i < length; i++) { + for (int groupKey : groupKeysArray[i]) { + setValueForGroupKeys(groupByResultHolder, groupKey, floatValues[i]); + } + } + break; + case DOUBLE: + double[] doubleValues = blockValSet.getDoubleValuesSV(); + for (int i = 0; i < length; i++) { + for (int groupKey : groupKeysArray[i]) { + setValueForGroupKeys(groupByResultHolder, groupKey, doubleValues[i]); + } + } + break; + default: + throw new IllegalStateException("Illegal data type for MODE aggregation function: " + storedType); + } + } + + @Override + public Map<? extends Number, Long> extractAggregationResult(AggregationResultHolder aggregationResultHolder) { + return extractIntermediateResult(aggregationResultHolder.getResult()); + } + + @Override + public Map<? extends Number, Long> extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) { + return extractIntermediateResult(groupByResultHolder.getResult(groupKey)); + } + + @Override + public Map<? extends Number, Long> merge(Map<? extends Number, Long> intermediateResult1, + Map<? extends Number, Long> intermediateResult2) { + if (intermediateResult1.isEmpty()) { + return intermediateResult2; + } + if (intermediateResult2.isEmpty()) { + return intermediateResult1; + } + if (intermediateResult1 instanceof Int2LongOpenHashMap && intermediateResult2 instanceof Int2LongOpenHashMap) { + ((Int2LongOpenHashMap) intermediateResult2).int2LongEntrySet().fastForEach( + e -> ((Int2LongOpenHashMap) intermediateResult1).merge(e.getIntKey(), e.getLongValue(), Long::sum)); + } else if (intermediateResult1 instanceof Long2LongOpenHashMap + && intermediateResult2 instanceof Long2LongOpenHashMap) { + ((Long2LongOpenHashMap) intermediateResult2).long2LongEntrySet().fastForEach( + e -> ((Long2LongOpenHashMap) intermediateResult1).merge(e.getLongKey(), e.getLongValue(), Long::sum)); + } else if (intermediateResult1 instanceof Float2LongOpenHashMap + && intermediateResult2 instanceof Float2LongOpenHashMap) { + ((Float2LongOpenHashMap) intermediateResult2).float2LongEntrySet().fastForEach( + e -> ((Float2LongOpenHashMap) intermediateResult1).merge(e.getFloatKey(), e.getLongValue(), Long::sum)); + } else if (intermediateResult1 instanceof Double2LongOpenHashMap + && intermediateResult2 instanceof Double2LongOpenHashMap) { + ((Double2LongOpenHashMap) intermediateResult2).double2LongEntrySet().fastForEach( + e -> ((Double2LongOpenHashMap) intermediateResult1).merge(e.getDoubleKey(), e.getLongValue(), Long::sum)); + } else { + throw new IllegalStateException( + "Illegal data type for Intermediate Result of MODE aggregation function: " + intermediateResult1.getClass() + .getSimpleName() + ", " + intermediateResult2.getClass().getSimpleName()); + } + return intermediateResult1; + } + + @Override + public boolean isIntermediateResultComparable() { + return false; + } + + @Override + public ColumnDataType getIntermediateResultColumnType() { + return ColumnDataType.OBJECT; + } + + @Override + public ColumnDataType getFinalResultColumnType() { + return ColumnDataType.DOUBLE; + } + + @Override + public Double extractFinalResult(Map<? extends Number, Long> intermediateResult) { + if (intermediateResult.isEmpty()) { + return DEFAULT_FINAL_RESULT; + } else if (intermediateResult instanceof Int2LongOpenHashMap) { + return extractFinalResult((Int2LongOpenHashMap) intermediateResult); + } else if (intermediateResult instanceof Long2LongOpenHashMap) { + return extractFinalResult((Long2LongOpenHashMap) intermediateResult); + } else if (intermediateResult instanceof Float2LongOpenHashMap) { + return extractFinalResult((Float2LongOpenHashMap) intermediateResult); + } else if (intermediateResult instanceof Double2LongOpenHashMap) { + return extractFinalResult((Double2LongOpenHashMap) intermediateResult); + } else { + throw new IllegalStateException( + "Illegal data type for Intermediate Result of MODE aggregation function: " + intermediateResult.getClass() + .getSimpleName()); + } + } + + public double extractFinalResult(Int2LongOpenHashMap intermediateResult) { + ObjectIterator<Int2LongMap.Entry> iterator = intermediateResult.int2LongEntrySet().fastIterator(); + Int2LongMap.Entry first = iterator.next(); + long maxFrequency = first.getLongValue(); + switch (_multiModeReducerType) { + case MIN: + int min = first.getIntKey(); + while (iterator.hasNext()) { + Int2LongMap.Entry next = iterator.next(); + if ((next.getLongValue() > maxFrequency) || (next.getLongValue() == maxFrequency && min > next.getIntKey())) { + maxFrequency = next.getLongValue(); + min = next.getIntKey(); + } + } + return min; + case MAX: + int max = first.getIntKey(); + while (iterator.hasNext()) { + Int2LongMap.Entry next = iterator.next(); + if ((next.getLongValue() > maxFrequency) || (next.getLongValue() == maxFrequency && max < next.getIntKey())) { + maxFrequency = next.getLongValue(); + max = next.getIntKey(); + } + } + return max; + case AVG: + double sum = first.getIntKey(); + int count = 1; + while (iterator.hasNext()) { + Int2LongMap.Entry next = iterator.next(); + if ((next.getLongValue() > maxFrequency)) { + maxFrequency = next.getLongValue(); + sum = next.getIntKey(); + count = 1; + } else if (next.getLongValue() == maxFrequency) { + sum += next.getIntKey(); + count += 1; + } + } + return sum / count; + default: + throw new IllegalStateException("Illegal reducer type for MODE aggregation function: " + _multiModeReducerType); + } + } + + public double extractFinalResult(Long2LongOpenHashMap intermediateResult) { + ObjectIterator<Long2LongMap.Entry> iterator = intermediateResult.long2LongEntrySet().fastIterator(); + Long2LongMap.Entry first = iterator.next(); + long maxFrequency = first.getLongValue(); + switch (_multiModeReducerType) { + case MIN: + long min = first.getLongKey(); + while (iterator.hasNext()) { + Long2LongMap.Entry next = iterator.next(); + if ((next.getLongValue() > maxFrequency) || (next.getLongValue() == maxFrequency + && min > next.getLongKey())) { + maxFrequency = next.getLongValue(); + min = next.getLongKey(); + } + } + return min; + case MAX: + long max = first.getLongKey(); + while (iterator.hasNext()) { + Long2LongMap.Entry next = iterator.next(); + if ((next.getLongValue() > maxFrequency) || (next.getLongValue() == maxFrequency + && max < next.getLongKey())) { + maxFrequency = next.getLongValue(); + max = next.getLongKey(); + } + } + return max; + case AVG: + double sum = first.getLongKey(); + int count = 1; + while (iterator.hasNext()) { + Long2LongMap.Entry next = iterator.next(); + if ((next.getLongValue() > maxFrequency)) { + maxFrequency = next.getLongValue(); + sum = next.getLongKey(); + count = 1; + } else if (next.getLongValue() == maxFrequency) { + sum += next.getLongKey(); + count += 1; + } + } + return sum / count; + default: + throw new IllegalStateException("Illegal reducer type for MODE aggregation function: " + _multiModeReducerType); + } + } + + public double extractFinalResult(Float2LongOpenHashMap intermediateResult) { + ObjectIterator<Float2LongMap.Entry> iterator = intermediateResult.float2LongEntrySet().fastIterator(); + Float2LongMap.Entry first = iterator.next(); + long maxFrequency = first.getLongValue(); + switch (_multiModeReducerType) { + case MIN: + float min = first.getFloatKey(); + while (iterator.hasNext()) { + Float2LongMap.Entry next = iterator.next(); + if ((next.getLongValue() > maxFrequency) || (next.getLongValue() == maxFrequency + && min > next.getFloatKey())) { + maxFrequency = next.getLongValue(); + min = next.getFloatKey(); + } + } + return min; + case MAX: + float max = first.getFloatKey(); + while (iterator.hasNext()) { + Float2LongMap.Entry next = iterator.next(); + if ((next.getLongValue() > maxFrequency) || (next.getLongValue() == maxFrequency + && max < next.getFloatKey())) { + maxFrequency = next.getLongValue(); + max = next.getFloatKey(); + } + } + return max; + case AVG: + double sum = first.getFloatKey(); + int count = 1; + while (iterator.hasNext()) { + Float2LongMap.Entry next = iterator.next(); + if ((next.getLongValue() > maxFrequency)) { + maxFrequency = next.getLongValue(); + sum = next.getFloatKey(); + count = 1; + } else if (next.getLongValue() == maxFrequency) { + sum += next.getFloatKey(); + count += 1; + } + } + return sum / count; + default: + throw new IllegalStateException("Illegal reducer type for MODE aggregation function: " + _multiModeReducerType); + } + } + + public Double extractFinalResult(Double2LongOpenHashMap intermediateResult) { + ObjectIterator<Double2LongMap.Entry> iterator = intermediateResult.double2LongEntrySet().fastIterator(); + Double2LongMap.Entry first = iterator.next(); + long maxFrequency = first.getLongValue(); + switch (_multiModeReducerType) { + case MIN: + double min = first.getDoubleKey(); + while (iterator.hasNext()) { + Double2LongMap.Entry next = iterator.next(); + if ((next.getLongValue() > maxFrequency) || (next.getLongValue() == maxFrequency + && min > next.getDoubleKey())) { + maxFrequency = next.getLongValue(); + min = next.getDoubleKey(); + } + } + return min; + case MAX: + double max = first.getDoubleKey(); + while (iterator.hasNext()) { + Double2LongMap.Entry next = iterator.next(); + if ((next.getLongValue() > maxFrequency) || (next.getLongValue() == maxFrequency + && max < next.getDoubleKey())) { + maxFrequency = next.getLongValue(); + max = next.getDoubleKey(); + } + } + return max; + case AVG: + double sum = first.getDoubleKey(); + int count = 1; + while (iterator.hasNext()) { + Double2LongMap.Entry next = iterator.next(); + if ((next.getLongValue() > maxFrequency)) { + maxFrequency = next.getLongValue(); + sum = next.getDoubleKey(); + count = 1; + } else if (next.getLongValue() == maxFrequency) { + sum += next.getDoubleKey(); + count += 1; + } + } + return sum / count; + default: + throw new IllegalStateException("Illegal reducer type for MODE aggregation function: " + _multiModeReducerType); + } + } + + private enum MultiModeReducerType { + MIN, MAX, AVG + } + + private static final class DictIdsWrapper { + + final Dictionary _dictionary; + final Int2IntOpenHashMap _dictIdCountMap; + + private DictIdsWrapper(Dictionary dictionary) { + _dictionary = dictionary; + _dictIdCountMap = new Int2IntOpenHashMap(); + } + } +} diff --git a/pinot-core/src/test/java/org/apache/pinot/core/common/ObjectSerDeUtilsTest.java b/pinot-core/src/test/java/org/apache/pinot/core/common/ObjectSerDeUtilsTest.java index b019698..8e4c6df 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/common/ObjectSerDeUtilsTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/common/ObjectSerDeUtilsTest.java @@ -20,9 +20,13 @@ package org.apache.pinot.core.common; import com.clearspring.analytics.stream.cardinality.HyperLogLog; import com.tdunning.math.stats.TDigest; +import it.unimi.dsi.fastutil.doubles.Double2LongOpenHashMap; import it.unimi.dsi.fastutil.doubles.DoubleArrayList; +import it.unimi.dsi.fastutil.floats.Float2LongOpenHashMap; +import it.unimi.dsi.fastutil.ints.Int2LongOpenHashMap; import it.unimi.dsi.fastutil.ints.IntOpenHashSet; import it.unimi.dsi.fastutil.ints.IntSet; +import it.unimi.dsi.fastutil.longs.Long2LongOpenHashMap; import java.util.HashMap; import java.util.Map; import java.util.Random; @@ -202,4 +206,68 @@ public class ObjectSerDeUtilsTest { } } } + + @Test + public void testInt2LongMap() { + for (int i = 0; i < NUM_ITERATIONS; i++) { + int size = RANDOM.nextInt(100); + Int2LongOpenHashMap expected = new Int2LongOpenHashMap(size); + for (int j = 0; j < size; j++) { + expected.put(RANDOM.nextInt(20), RANDOM.nextLong()); + } + + byte[] bytes = ObjectSerDeUtils.serialize(expected); + Int2LongOpenHashMap actual = ObjectSerDeUtils.deserialize(bytes, ObjectSerDeUtils.ObjectType.Int2LongMap); + + assertEquals(actual, expected, ERROR_MESSAGE); + } + } + + @Test + public void testLong2LongMap() { + for (int i = 0; i < NUM_ITERATIONS; i++) { + int size = RANDOM.nextInt(100); + Long2LongOpenHashMap expected = new Long2LongOpenHashMap(size); + for (int j = 0; j < size; j++) { + expected.put(RANDOM.nextLong(), RANDOM.nextLong()); + } + + byte[] bytes = ObjectSerDeUtils.serialize(expected); + Long2LongOpenHashMap actual = ObjectSerDeUtils.deserialize(bytes, ObjectSerDeUtils.ObjectType.Long2LongMap); + + assertEquals(actual, expected, ERROR_MESSAGE); + } + } + + @Test + public void testFloat2LongMap() { + for (int i = 0; i < NUM_ITERATIONS; i++) { + int size = RANDOM.nextInt(100); + Float2LongOpenHashMap expected = new Float2LongOpenHashMap(size); + for (int j = 0; j < size; j++) { + expected.put(RANDOM.nextFloat(), RANDOM.nextLong()); + } + + byte[] bytes = ObjectSerDeUtils.serialize(expected); + Float2LongOpenHashMap actual = ObjectSerDeUtils.deserialize(bytes, ObjectSerDeUtils.ObjectType.Float2LongMap); + + assertEquals(actual, expected, ERROR_MESSAGE); + } + } + + @Test + public void testDouble2LongMap() { + for (int i = 0; i < NUM_ITERATIONS; i++) { + int size = RANDOM.nextInt(100); + Double2LongOpenHashMap expected = new Double2LongOpenHashMap(size); + for (int j = 0; j < size; j++) { + expected.put(RANDOM.nextDouble(), RANDOM.nextLong()); + } + + byte[] bytes = ObjectSerDeUtils.serialize(expected); + Double2LongOpenHashMap actual = ObjectSerDeUtils.deserialize(bytes, ObjectSerDeUtils.ObjectType.Double2LongMap); + + assertEquals(actual, expected, ERROR_MESSAGE); + } + } } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java index bb847b1..b855806 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/aggregation/function/AggregationFunctionFactoryTest.java @@ -80,6 +80,13 @@ public class AggregationFunctionFactoryTest { assertEquals(aggregationFunction.getColumnName(), "avg_column"); assertEquals(aggregationFunction.getResultColumnName(), function.toString()); + function = getFunction("MoDe"); + aggregationFunction = AggregationFunctionFactory.getAggregationFunction(function, DUMMY_QUERY_CONTEXT); + assertTrue(aggregationFunction instanceof ModeAggregationFunction); + assertEquals(aggregationFunction.getType(), AggregationFunctionType.MODE); + assertEquals(aggregationFunction.getColumnName(), "mode_column"); + assertEquals(aggregationFunction.getResultColumnName(), function.toString()); + function = getFunction("MiNmAxRaNgE"); aggregationFunction = AggregationFunctionFactory.getAggregationFunction(function, DUMMY_QUERY_CONTEXT); assertTrue(aggregationFunction instanceof MinMaxRangeAggregationFunction); diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/ModeQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/ModeQueriesTest.java new file mode 100644 index 0000000..381127a --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/queries/ModeQueriesTest.java @@ -0,0 +1,949 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.queries; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import it.unimi.dsi.fastutil.doubles.Double2LongOpenHashMap; +import it.unimi.dsi.fastutil.floats.Float2LongOpenHashMap; +import it.unimi.dsi.fastutil.ints.Int2LongOpenHashMap; +import it.unimi.dsi.fastutil.longs.Long2LongOpenHashMap; +import java.io.File; +import java.io.IOException; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Random; +import java.util.stream.Collectors; +import org.apache.commons.io.FileUtils; +import org.apache.pinot.common.response.broker.AggregationResult; +import org.apache.pinot.common.response.broker.BrokerResponseNative; +import org.apache.pinot.common.response.broker.GroupByResult; +import org.apache.pinot.common.utils.HashUtil; +import org.apache.pinot.core.common.Operator; +import org.apache.pinot.core.operator.blocks.IntermediateResultsBlock; +import org.apache.pinot.core.operator.query.AggregationGroupByOperator; +import org.apache.pinot.core.operator.query.AggregationOperator; +import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult; +import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator; +import org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoader; +import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl; +import org.apache.pinot.segment.local.segment.readers.GenericRowRecordReader; +import org.apache.pinot.segment.spi.ImmutableSegment; +import org.apache.pinot.segment.spi.IndexSegment; +import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig; +import org.apache.pinot.spi.config.table.TableConfig; +import org.apache.pinot.spi.config.table.TableType; +import org.apache.pinot.spi.data.FieldSpec.DataType; +import org.apache.pinot.spi.data.Schema; +import org.apache.pinot.spi.data.readers.GenericRow; +import org.apache.pinot.spi.utils.ReadMode; +import org.apache.pinot.spi.utils.builder.TableConfigBuilder; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + + +/** + * Queries test for MODE queries. + */ +@SuppressWarnings("rawtypes") +public class ModeQueriesTest extends BaseQueriesTest { + private static final File INDEX_DIR = new File(FileUtils.getTempDirectory(), "ModeQueriesTest"); + private static final String RAW_TABLE_NAME = "testTable"; + private static final String SEGMENT_NAME = "testSegment"; + private static final Random RANDOM = new Random(); + + private static final int NUM_RECORDS = 2000; + private static final int MAX_VALUE = 1000; + + private static final String INT_COLUMN = "intColumn"; + private static final String INT_MV_COLUMN = "intMvColumn"; + private static final String LONG_COLUMN = "longColumn"; + private static final String FLOAT_COLUMN = "floatColumn"; + private static final String DOUBLE_COLUMN = "doubleColumn"; + private static final String INT_NO_DICT_COLUMN = "intNoDictColumn"; + private static final String LONG_NO_DICT_COLUMN = "longNoDictColumn"; + private static final String FLOAT_NO_DICT_COLUMN = "floatNoDictColumn"; + private static final String DOUBLE_NO_DICT_COLUMN = "doubleNoDictColumn"; + private static final Schema SCHEMA = new Schema.SchemaBuilder().addSingleValueDimension(INT_COLUMN, DataType.INT) + .addMultiValueDimension(INT_MV_COLUMN, DataType.INT).addSingleValueDimension(INT_NO_DICT_COLUMN, DataType.INT) + .addSingleValueDimension(LONG_COLUMN, DataType.LONG).addSingleValueDimension(LONG_NO_DICT_COLUMN, DataType.LONG) + .addSingleValueDimension(FLOAT_COLUMN, DataType.FLOAT) + .addSingleValueDimension(FLOAT_NO_DICT_COLUMN, DataType.FLOAT) + .addSingleValueDimension(DOUBLE_COLUMN, DataType.DOUBLE) + .addSingleValueDimension(DOUBLE_NO_DICT_COLUMN, DataType.DOUBLE).build(); + private static final TableConfig TABLE_CONFIG = new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME) + .setNoDictionaryColumns( + Lists.newArrayList(INT_NO_DICT_COLUMN, LONG_NO_DICT_COLUMN, FLOAT_NO_DICT_COLUMN, DOUBLE_NO_DICT_COLUMN)) + .build(); + private static final double DELTA = 0.00001; + + private HashMap<Integer, Long> _values; + private Double _expectedResultMin; + private Double _expectedResultMax; + private Double _expectedResultAvg; + private IndexSegment _indexSegment; + private List<IndexSegment> _indexSegments; + + @Override + protected String getFilter() { + // NOTE: Use a match all filter to switch between DictionaryBasedAggregationOperator and AggregationOperator + return " WHERE intColumn >= 0"; + } + + @Override + protected IndexSegment getIndexSegment() { + return _indexSegment; + } + + @Override + protected List<IndexSegment> getIndexSegments() { + return _indexSegments; + } + + @BeforeClass + public void setUp() + throws Exception { + FileUtils.deleteDirectory(INDEX_DIR); + + List<GenericRow> records = new ArrayList<>(NUM_RECORDS); + int hashMapCapacity = HashUtil.getHashMapCapacity(MAX_VALUE); + _values = new HashMap<>(hashMapCapacity); + for (int i = 0; i < NUM_RECORDS; i++) { + int value = RANDOM.nextInt(MAX_VALUE); + GenericRow record = new GenericRow(); + _values.merge(value, 1L, Long::sum); + record.putValue(INT_COLUMN, value); + record.putValue(INT_MV_COLUMN, new Integer[]{value, value}); + record.putValue(INT_NO_DICT_COLUMN, value); + record.putValue(LONG_COLUMN, (long) value); + record.putValue(LONG_NO_DICT_COLUMN, (long) value); + record.putValue(FLOAT_COLUMN, (float) value); + record.putValue(FLOAT_NO_DICT_COLUMN, (float) value); + record.putValue(DOUBLE_COLUMN, (double) value); + record.putValue(DOUBLE_NO_DICT_COLUMN, (double) value); + records.add(record); + } + _expectedResultMin = _values.keySet().stream() + .filter(key -> Objects.equals(_values.get(key), _values.values().stream().max(Long::compareTo).get())) + .mapToDouble(Integer::doubleValue).min().orElse(Double.NEGATIVE_INFINITY); + _expectedResultMax = _values.keySet().stream() + .filter(key -> Objects.equals(_values.get(key), _values.values().stream().max(Long::compareTo).get())) + .mapToDouble(Integer::doubleValue).max().orElse(Double.NEGATIVE_INFINITY); + _expectedResultAvg = _values.keySet().stream() + .filter(key -> Objects.equals(_values.get(key), _values.values().stream().max(Long::compareTo).get())) + .mapToDouble(Integer::doubleValue).average().orElse(Double.NEGATIVE_INFINITY); + + SegmentGeneratorConfig segmentGeneratorConfig = new SegmentGeneratorConfig(TABLE_CONFIG, SCHEMA); + segmentGeneratorConfig.setTableName(RAW_TABLE_NAME); + segmentGeneratorConfig.setSegmentName(SEGMENT_NAME); + segmentGeneratorConfig.setOutDir(INDEX_DIR.getPath()); + + SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl(); + driver.init(segmentGeneratorConfig, new GenericRowRecordReader(records)); + driver.build(); + + ImmutableSegment immutableSegment = ImmutableSegmentLoader.load(new File(INDEX_DIR, SEGMENT_NAME), ReadMode.mmap); + _indexSegment = immutableSegment; + _indexSegments = Arrays.asList(immutableSegment, immutableSegment); + } + + @Test + public void testAggregationOnly() { + String query = "SELECT MODE(intColumn), MODE(longColumn), MODE(floatColumn), MODE(doubleColumn) FROM testTable"; + + // Inner segment + Operator operator = getOperatorForPqlQuery(query); + assertTrue(operator instanceof AggregationOperator); + IntermediateResultsBlock resultsBlock = ((AggregationOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, + 4 * NUM_RECORDS, NUM_RECORDS); + List<Object> aggregationResultsWithoutFilter = resultsBlock.getAggregationResult(); + + operator = getOperatorForPqlQueryWithFilter(query); + assertTrue(operator instanceof AggregationOperator); + IntermediateResultsBlock resultsBlockWithFilter = ((AggregationOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, + 4 * NUM_RECORDS, NUM_RECORDS); + List<Object> aggregationResultWithFilter = resultsBlockWithFilter.getAggregationResult(); + + assertNotNull(aggregationResultsWithoutFilter); + assertNotNull(aggregationResultWithFilter); + assertEquals(aggregationResultsWithoutFilter, aggregationResultWithFilter); + assertTrue(Maps.difference((Int2LongOpenHashMap) aggregationResultsWithoutFilter.get(0), _values).areEqual()); + assertTrue(Maps.difference((Long2LongOpenHashMap) aggregationResultsWithoutFilter.get(1), + _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().longValue(), Map.Entry::getValue))) + .areEqual()); + assertTrue(Maps.difference((Float2LongOpenHashMap) aggregationResultsWithoutFilter.get(2), + _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().floatValue(), Map.Entry::getValue))) + .areEqual()); + assertTrue(Maps.difference((Double2LongOpenHashMap) aggregationResultsWithoutFilter.get(3), + _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().doubleValue(), Map.Entry::getValue))) + .areEqual()); + + // Inter segments (expect 4 * inner segment result) + double[] expectedResults = new double[4]; + for (int i = 0; i < 4; i++) { + expectedResults[i] = _expectedResultMin; + } + BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query); + + Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0); + Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS); + List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults(); + Assert.assertEquals(aggregationResults.size(), expectedResults.length); + for (int i = 0; i < expectedResults.length; i++) { + AggregationResult aggregationResult = aggregationResults.get(i); + double expectedAggregationResult = expectedResults[i]; + Serializable value = aggregationResult.getValue(); + Assert.assertEquals(Double.parseDouble(value.toString()), expectedAggregationResult, DELTA); + } + + brokerResponse = getBrokerResponseForPqlQueryWithFilter(query); + Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0); + Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS); + aggregationResults = brokerResponse.getAggregationResults(); + Assert.assertEquals(aggregationResults.size(), expectedResults.length); + for (int i = 0; i < expectedResults.length; i++) { + AggregationResult aggregationResult = aggregationResults.get(i); + double expectedAggregationResult = expectedResults[i]; + Serializable value = aggregationResult.getValue(); + Assert.assertEquals(Double.parseDouble(value.toString()), expectedAggregationResult, DELTA); + } + } + + @Test + public void testAggregationOnlyNoDictionary() { + String query = + "SELECT MODE(intNoDictColumn), MODE(longNoDictColumn), MODE(floatNoDictColumn), MODE(doubleNoDictColumn) FROM testTable"; + + // Inner segment + Operator operator = getOperatorForPqlQuery(query); + assertTrue(operator instanceof AggregationOperator); + IntermediateResultsBlock resultsBlock = ((AggregationOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, + 4 * NUM_RECORDS, NUM_RECORDS); + List<Object> aggregationResultsWithoutFilter = resultsBlock.getAggregationResult(); + + operator = getOperatorForPqlQueryWithFilter(query); + assertTrue(operator instanceof AggregationOperator); + IntermediateResultsBlock resultsBlockWithFilter = ((AggregationOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, + 4 * NUM_RECORDS, NUM_RECORDS); + List<Object> aggregationResultWithFilter = resultsBlockWithFilter.getAggregationResult(); + + assertNotNull(aggregationResultsWithoutFilter); + assertNotNull(aggregationResultWithFilter); + assertEquals(aggregationResultsWithoutFilter, aggregationResultWithFilter); + assertTrue(Maps.difference((Int2LongOpenHashMap) aggregationResultsWithoutFilter.get(0), _values).areEqual()); + assertTrue(Maps.difference((Long2LongOpenHashMap) aggregationResultsWithoutFilter.get(1), + _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().longValue(), Map.Entry::getValue))) + .areEqual()); + assertTrue(Maps.difference((Float2LongOpenHashMap) aggregationResultsWithoutFilter.get(2), + _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().floatValue(), Map.Entry::getValue))) + .areEqual()); + assertTrue(Maps.difference((Double2LongOpenHashMap) aggregationResultsWithoutFilter.get(3), + _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().doubleValue(), Map.Entry::getValue))) + .areEqual()); + + // Inter segments (expect 4 * inner segment result) + double[] expectedResults = new double[4]; + for (int i = 0; i < 4; i++) { + expectedResults[i] = _expectedResultMin; + } + BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query); + + Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0); + Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS); + List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults(); + Assert.assertEquals(aggregationResults.size(), expectedResults.length); + for (int i = 0; i < expectedResults.length; i++) { + AggregationResult aggregationResult = aggregationResults.get(i); + double expectedAggregationResult = expectedResults[i]; + Serializable value = aggregationResult.getValue(); + Assert.assertEquals(Double.parseDouble(value.toString()), expectedAggregationResult, DELTA); + } + + brokerResponse = getBrokerResponseForPqlQueryWithFilter(query); + Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0); + Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS); + aggregationResults = brokerResponse.getAggregationResults(); + Assert.assertEquals(aggregationResults.size(), expectedResults.length); + for (int i = 0; i < expectedResults.length; i++) { + AggregationResult aggregationResult = aggregationResults.get(i); + double expectedAggregationResult = expectedResults[i]; + Serializable value = aggregationResult.getValue(); + Assert.assertEquals(Double.parseDouble(value.toString()), expectedAggregationResult, DELTA); + } + } + + @Test + public void testAggregationOnlyWithMultiModeReducerOptionMIN() { + String query = + "SELECT MODE(intColumn, 'MIN'), MODE(longColumn, 'MIN'), MODE(floatColumn, 'MIN'), MODE(doubleColumn, 'MIN') FROM testTable"; + + // Inner segment + Operator operator = getOperatorForPqlQuery(query); + assertTrue(operator instanceof AggregationOperator); + IntermediateResultsBlock resultsBlock = ((AggregationOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, + 4 * NUM_RECORDS, NUM_RECORDS); + List<Object> aggregationResultsWithoutFilter = resultsBlock.getAggregationResult(); + + operator = getOperatorForPqlQueryWithFilter(query); + assertTrue(operator instanceof AggregationOperator); + IntermediateResultsBlock resultsBlockWithFilter = ((AggregationOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, + 4 * NUM_RECORDS, NUM_RECORDS); + List<Object> aggregationResultWithFilter = resultsBlockWithFilter.getAggregationResult(); + + assertNotNull(aggregationResultsWithoutFilter); + assertNotNull(aggregationResultWithFilter); + assertEquals(aggregationResultsWithoutFilter, aggregationResultWithFilter); + assertTrue(Maps.difference((Int2LongOpenHashMap) aggregationResultsWithoutFilter.get(0), _values).areEqual()); + assertTrue(Maps.difference((Long2LongOpenHashMap) aggregationResultsWithoutFilter.get(1), + _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().longValue(), Map.Entry::getValue))) + .areEqual()); + assertTrue(Maps.difference((Float2LongOpenHashMap) aggregationResultsWithoutFilter.get(2), + _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().floatValue(), Map.Entry::getValue))) + .areEqual()); + assertTrue(Maps.difference((Double2LongOpenHashMap) aggregationResultsWithoutFilter.get(3), + _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().doubleValue(), Map.Entry::getValue))) + .areEqual()); + + // Inter segments (expect 4 * inner segment result) + double[] expectedResults = new double[4]; + for (int i = 0; i < 4; i++) { + expectedResults[i] = _expectedResultMin; + } + BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query); + + Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0); + Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS); + List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults(); + Assert.assertEquals(aggregationResults.size(), expectedResults.length); + for (int i = 0; i < expectedResults.length; i++) { + AggregationResult aggregationResult = aggregationResults.get(i); + double expectedAggregationResult = expectedResults[i]; + Serializable value = aggregationResult.getValue(); + Assert.assertEquals(Double.parseDouble(value.toString()), expectedAggregationResult, DELTA); + } + + brokerResponse = getBrokerResponseForPqlQueryWithFilter(query); + Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0); + Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS); + aggregationResults = brokerResponse.getAggregationResults(); + Assert.assertEquals(aggregationResults.size(), expectedResults.length); + for (int i = 0; i < expectedResults.length; i++) { + AggregationResult aggregationResult = aggregationResults.get(i); + double expectedAggregationResult = expectedResults[i]; + Serializable value = aggregationResult.getValue(); + Assert.assertEquals(Double.parseDouble(value.toString()), expectedAggregationResult, DELTA); + } + } + + @Test + public void testAggregationOnlyWithMultiModeReducerOptionMAX() { + String query = + "SELECT MODE(intColumn, 'MAX'), MODE(longColumn, 'MAX'), MODE(floatColumn, 'MAX'), MODE(doubleColumn, 'MAX') FROM testTable"; + + // Inner segment + Operator operator = getOperatorForPqlQuery(query); + assertTrue(operator instanceof AggregationOperator); + IntermediateResultsBlock resultsBlock = ((AggregationOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, + 4 * NUM_RECORDS, NUM_RECORDS); + List<Object> aggregationResultsWithoutFilter = resultsBlock.getAggregationResult(); + + operator = getOperatorForPqlQueryWithFilter(query); + assertTrue(operator instanceof AggregationOperator); + IntermediateResultsBlock resultsBlockWithFilter = ((AggregationOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, + 4 * NUM_RECORDS, NUM_RECORDS); + List<Object> aggregationResultWithFilter = resultsBlockWithFilter.getAggregationResult(); + + assertNotNull(aggregationResultsWithoutFilter); + assertNotNull(aggregationResultWithFilter); + assertEquals(aggregationResultsWithoutFilter, aggregationResultWithFilter); + assertTrue(Maps.difference((Int2LongOpenHashMap) aggregationResultsWithoutFilter.get(0), _values).areEqual()); + assertTrue(Maps.difference((Long2LongOpenHashMap) aggregationResultsWithoutFilter.get(1), + _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().longValue(), Map.Entry::getValue))) + .areEqual()); + assertTrue(Maps.difference((Float2LongOpenHashMap) aggregationResultsWithoutFilter.get(2), + _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().floatValue(), Map.Entry::getValue))) + .areEqual()); + assertTrue(Maps.difference((Double2LongOpenHashMap) aggregationResultsWithoutFilter.get(3), + _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().doubleValue(), Map.Entry::getValue))) + .areEqual()); + + // Inter segments (expect 4 * inner segment result) + double[] expectedResults = new double[4]; + for (int i = 0; i < 4; i++) { + expectedResults[i] = _expectedResultMax; + } + BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query); + + Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0); + Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS); + List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults(); + Assert.assertEquals(aggregationResults.size(), expectedResults.length); + for (int i = 0; i < expectedResults.length; i++) { + AggregationResult aggregationResult = aggregationResults.get(i); + double expectedAggregationResult = expectedResults[i]; + Serializable value = aggregationResult.getValue(); + Assert.assertEquals(Double.parseDouble(value.toString()), expectedAggregationResult, DELTA); + } + + brokerResponse = getBrokerResponseForPqlQueryWithFilter(query); + Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0); + Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS); + aggregationResults = brokerResponse.getAggregationResults(); + Assert.assertEquals(aggregationResults.size(), expectedResults.length); + for (int i = 0; i < expectedResults.length; i++) { + AggregationResult aggregationResult = aggregationResults.get(i); + double expectedAggregationResult = expectedResults[i]; + Serializable value = aggregationResult.getValue(); + Assert.assertEquals(Double.parseDouble(value.toString()), expectedAggregationResult, DELTA); + } + } + + @Test + public void testAggregationOnlyWithMultiModeReducerOptionAVG() { + String query = + "SELECT MODE(intColumn, 'AVG'), MODE(longColumn, 'AVG'), MODE(floatColumn, 'AVG'), MODE(doubleColumn, 'AVG') FROM testTable"; + + // Inner segment + Operator operator = getOperatorForPqlQuery(query); + assertTrue(operator instanceof AggregationOperator); + IntermediateResultsBlock resultsBlock = ((AggregationOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, + 4 * NUM_RECORDS, NUM_RECORDS); + List<Object> aggregationResultsWithoutFilter = resultsBlock.getAggregationResult(); + + operator = getOperatorForPqlQueryWithFilter(query); + assertTrue(operator instanceof AggregationOperator); + IntermediateResultsBlock resultsBlockWithFilter = ((AggregationOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, + 4 * NUM_RECORDS, NUM_RECORDS); + List<Object> aggregationResultWithFilter = resultsBlockWithFilter.getAggregationResult(); + + assertNotNull(aggregationResultsWithoutFilter); + assertNotNull(aggregationResultWithFilter); + assertEquals(aggregationResultsWithoutFilter, aggregationResultWithFilter); + assertTrue(Maps.difference((Int2LongOpenHashMap) aggregationResultsWithoutFilter.get(0), _values).areEqual()); + assertTrue(Maps.difference((Long2LongOpenHashMap) aggregationResultsWithoutFilter.get(1), + _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().longValue(), Map.Entry::getValue))) + .areEqual()); + assertTrue(Maps.difference((Float2LongOpenHashMap) aggregationResultsWithoutFilter.get(2), + _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().floatValue(), Map.Entry::getValue))) + .areEqual()); + assertTrue(Maps.difference((Double2LongOpenHashMap) aggregationResultsWithoutFilter.get(3), + _values.entrySet().stream().collect(Collectors.toMap(e -> e.getKey().doubleValue(), Map.Entry::getValue))) + .areEqual()); + + // Inter segments (expect 4 * inner segment result) + double[] expectedResults = new double[4]; + for (int i = 0; i < 4; i++) { + expectedResults[i] = _expectedResultAvg; + } + BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query); + + Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0); + Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS); + List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults(); + Assert.assertEquals(aggregationResults.size(), expectedResults.length); + for (int i = 0; i < expectedResults.length; i++) { + AggregationResult aggregationResult = aggregationResults.get(i); + double expectedAggregationResult = expectedResults[i]; + Serializable value = aggregationResult.getValue(); + Assert.assertEquals(Double.parseDouble(value.toString()), expectedAggregationResult, DELTA); + } + + brokerResponse = getBrokerResponseForPqlQueryWithFilter(query); + Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0); + Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS); + aggregationResults = brokerResponse.getAggregationResults(); + Assert.assertEquals(aggregationResults.size(), expectedResults.length); + for (int i = 0; i < expectedResults.length; i++) { + AggregationResult aggregationResult = aggregationResults.get(i); + double expectedAggregationResult = expectedResults[i]; + Serializable value = aggregationResult.getValue(); + Assert.assertEquals(Double.parseDouble(value.toString()), expectedAggregationResult, DELTA); + } + } + + @Test + public void testAggregationGroupBySv() { + String query = + "SELECT MODE(intColumn), MODE(longColumn), MODE(floatColumn), MODE(doubleColumn) FROM testTable GROUP BY intColumn"; + + // Inner segment + Operator operator = getOperatorForPqlQuery(query); + assertTrue(operator instanceof AggregationGroupByOperator); + IntermediateResultsBlock resultsBlock = ((AggregationGroupByOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, + 4 * NUM_RECORDS, NUM_RECORDS); + AggregationGroupByResult aggregationGroupByResult = resultsBlock.getAggregationGroupByResult(); + assertNotNull(aggregationGroupByResult); + int numGroups = 0; + Iterator<GroupKeyGenerator.GroupKey> groupKeyIterator = aggregationGroupByResult.getGroupKeyIterator(); + while (groupKeyIterator.hasNext()) { + numGroups++; + GroupKeyGenerator.GroupKey groupKey = groupKeyIterator.next(); + Integer key = (Integer) groupKey._keys[0]; + assertTrue(_values.containsKey(key)); + assertTrue( + Maps.difference((Int2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(0, groupKey._groupId), + Collections.singletonMap(key, _values.get(key))).areEqual()); + assertTrue( + Maps.difference((Long2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(1, groupKey._groupId), + Collections.singletonMap(key.longValue(), _values.get(key))).areEqual()); + assertTrue( + Maps.difference((Float2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(2, groupKey._groupId), + Collections.singletonMap(key.floatValue(), _values.get(key))).areEqual()); + assertTrue( + Maps.difference((Double2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(3, groupKey._groupId), + Collections.singletonMap(key.doubleValue(), _values.get(key))).areEqual()); + } + assertEquals(numGroups, _values.size()); + + // Inter segments (expect 4 * inner segment result) + BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query); + Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0); + Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS); + // size of this array will be equal to number of aggregation functions since + // we return each aggregation function separately + List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults(); + int numAggregationColumns = aggregationResults.size(); + Assert.assertEquals(numAggregationColumns, 4); + for (AggregationResult aggregationResult : aggregationResults) { + Assert.assertNull(aggregationResult.getValue()); + List<GroupByResult> groupByResults = aggregationResult.getGroupByResult(); + numGroups = groupByResults.size(); + for (int i = 0; i < numGroups; i++) { + GroupByResult groupByResult = groupByResults.get(i); + List<String> group = groupByResult.getGroup(); + assertEquals(group.size(), 1); + assertTrue(_values.containsKey(Integer.parseInt(group.get(0)))); + assertEquals(Double.parseDouble(groupByResult.getValue().toString()), Double.parseDouble(group.get(0)), DELTA); + } + } + } + + @Test + public void testAggregationGroupByMv() { + String query = + "SELECT MODE(intColumn), MODE(longColumn), MODE(floatColumn), MODE(doubleColumn) FROM testTable GROUP BY intMvColumn"; + + // Inner segment + Operator operator = getOperatorForPqlQuery(query); + assertTrue(operator instanceof AggregationGroupByOperator); + IntermediateResultsBlock resultsBlock = ((AggregationGroupByOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, + 5 * NUM_RECORDS, NUM_RECORDS); + AggregationGroupByResult aggregationGroupByResult = resultsBlock.getAggregationGroupByResult(); + assertNotNull(aggregationGroupByResult); + int numGroups = 0; + Iterator<GroupKeyGenerator.GroupKey> groupKeyIterator = aggregationGroupByResult.getGroupKeyIterator(); + while (groupKeyIterator.hasNext()) { + numGroups++; + GroupKeyGenerator.GroupKey groupKey = groupKeyIterator.next(); + Integer key = (Integer) groupKey._keys[0]; + assertTrue(_values.containsKey(key)); + assertTrue( + Maps.difference((Int2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(0, groupKey._groupId), + Collections.singletonMap(key, _values.get(key) * 2)).areEqual()); + assertTrue( + Maps.difference((Long2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(1, groupKey._groupId), + Collections.singletonMap(key.longValue(), _values.get(key) * 2)).areEqual()); + assertTrue( + Maps.difference((Float2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(2, groupKey._groupId), + Collections.singletonMap(key.floatValue(), _values.get(key) * 2)).areEqual()); + assertTrue( + Maps.difference((Double2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(3, groupKey._groupId), + Collections.singletonMap(key.doubleValue(), _values.get(key) * 2)).areEqual()); + } + assertEquals(numGroups, _values.size()); + + // Inter segments (expect 4 * inner segment result) + BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query); + Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0); + Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 5 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS); + // size of this array will be equal to number of aggregation functions since + // we return each aggregation function separately + List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults(); + int numAggregationColumns = aggregationResults.size(); + Assert.assertEquals(numAggregationColumns, 4); + for (AggregationResult aggregationResult : aggregationResults) { + Assert.assertNull(aggregationResult.getValue()); + List<GroupByResult> groupByResults = aggregationResult.getGroupByResult(); + numGroups = groupByResults.size(); + for (int i = 0; i < numGroups; i++) { + GroupByResult groupByResult = groupByResults.get(i); + List<String> group = groupByResult.getGroup(); + assertEquals(group.size(), 1); + assertTrue(_values.containsKey(Integer.parseInt(group.get(0)))); + assertEquals(Double.parseDouble(groupByResult.getValue().toString()), Double.parseDouble(group.get(0)), DELTA); + } + } + } + + @Test + public void testAggregationGroupBySvNoDictionary() { + String query = + "SELECT MODE(intNoDictColumn), MODE(longNoDictColumn), MODE(floatNoDictColumn), MODE(doubleNoDictColumn) FROM testTable GROUP BY intNoDictColumn"; + + // Inner segment + Operator operator = getOperatorForPqlQuery(query); + assertTrue(operator instanceof AggregationGroupByOperator); + IntermediateResultsBlock resultsBlock = ((AggregationGroupByOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, + 4 * NUM_RECORDS, NUM_RECORDS); + AggregationGroupByResult aggregationGroupByResult = resultsBlock.getAggregationGroupByResult(); + assertNotNull(aggregationGroupByResult); + int numGroups = 0; + Iterator<GroupKeyGenerator.GroupKey> groupKeyIterator = aggregationGroupByResult.getGroupKeyIterator(); + while (groupKeyIterator.hasNext()) { + numGroups++; + GroupKeyGenerator.GroupKey groupKey = groupKeyIterator.next(); + Integer key = (Integer) groupKey._keys[0]; + assertTrue(_values.containsKey(key)); + assertTrue( + Maps.difference((Int2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(0, groupKey._groupId), + Collections.singletonMap(key, _values.get(key))).areEqual()); + assertTrue( + Maps.difference((Long2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(1, groupKey._groupId), + Collections.singletonMap(key.longValue(), _values.get(key))).areEqual()); + assertTrue( + Maps.difference((Float2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(2, groupKey._groupId), + Collections.singletonMap(key.floatValue(), _values.get(key))).areEqual()); + assertTrue( + Maps.difference((Double2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(3, groupKey._groupId), + Collections.singletonMap(key.doubleValue(), _values.get(key))).areEqual()); + } + assertEquals(numGroups, _values.size()); + + // Inter segments (expect 4 * inner segment result) + BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query); + Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0); + Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS); + // size of this array will be equal to number of aggregation functions since + // we return each aggregation function separately + List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults(); + int numAggregationColumns = aggregationResults.size(); + Assert.assertEquals(numAggregationColumns, 4); + for (AggregationResult aggregationResult : aggregationResults) { + Assert.assertNull(aggregationResult.getValue()); + List<GroupByResult> groupByResults = aggregationResult.getGroupByResult(); + numGroups = groupByResults.size(); + for (int i = 0; i < numGroups; i++) { + GroupByResult groupByResult = groupByResults.get(i); + List<String> group = groupByResult.getGroup(); + assertEquals(group.size(), 1); + assertTrue(_values.containsKey(Integer.parseInt(group.get(0)))); + assertEquals(Double.parseDouble(groupByResult.getValue().toString()), Double.parseDouble(group.get(0)), DELTA); + } + } + } + + @Test + public void testAggregationGroupByMvNoDictionary() { + String query = + "SELECT MODE(intNoDictColumn), MODE(longNoDictColumn), MODE(floatNoDictColumn), MODE(doubleNoDictColumn) FROM testTable GROUP BY intMvColumn"; + + // Inner segment + Operator operator = getOperatorForPqlQuery(query); + assertTrue(operator instanceof AggregationGroupByOperator); + IntermediateResultsBlock resultsBlock = ((AggregationGroupByOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, + 5 * NUM_RECORDS, NUM_RECORDS); + AggregationGroupByResult aggregationGroupByResult = resultsBlock.getAggregationGroupByResult(); + assertNotNull(aggregationGroupByResult); + int numGroups = 0; + Iterator<GroupKeyGenerator.GroupKey> groupKeyIterator = aggregationGroupByResult.getGroupKeyIterator(); + while (groupKeyIterator.hasNext()) { + numGroups++; + GroupKeyGenerator.GroupKey groupKey = groupKeyIterator.next(); + Integer key = (Integer) groupKey._keys[0]; + assertTrue(_values.containsKey(key)); + assertTrue( + Maps.difference((Int2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(0, groupKey._groupId), + Collections.singletonMap(key, _values.get(key) * 2)).areEqual()); + assertTrue( + Maps.difference((Long2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(1, groupKey._groupId), + Collections.singletonMap(key.longValue(), _values.get(key) * 2)).areEqual()); + assertTrue( + Maps.difference((Float2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(2, groupKey._groupId), + Collections.singletonMap(key.floatValue(), _values.get(key) * 2)).areEqual()); + assertTrue( + Maps.difference((Double2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(3, groupKey._groupId), + Collections.singletonMap(key.doubleValue(), _values.get(key) * 2)).areEqual()); + } + assertEquals(numGroups, _values.size()); + + // Inter segments (expect 4 * inner segment result) + BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query); + Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0); + Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 5 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS); + // size of this array will be equal to number of aggregation functions since + // we return each aggregation function separately + List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults(); + int numAggregationColumns = aggregationResults.size(); + Assert.assertEquals(numAggregationColumns, 4); + for (AggregationResult aggregationResult : aggregationResults) { + Assert.assertNull(aggregationResult.getValue()); + List<GroupByResult> groupByResults = aggregationResult.getGroupByResult(); + numGroups = groupByResults.size(); + for (int i = 0; i < numGroups; i++) { + GroupByResult groupByResult = groupByResults.get(i); + List<String> group = groupByResult.getGroup(); + assertEquals(group.size(), 1); + assertTrue(_values.containsKey(Integer.parseInt(group.get(0)))); + assertEquals(Double.parseDouble(groupByResult.getValue().toString()), Double.parseDouble(group.get(0)), DELTA); + } + } + } + + @Test + public void testAggregationGroupBySvWithMultiModeReducerOptionMIN() { + String query = + "SELECT MODE(intColumn, 'MIN'), MODE(longColumn, 'MIN'), MODE(floatColumn, 'MIN'), MODE(doubleColumn, 'MIN') FROM testTable GROUP BY intColumn"; + + // Inner segment + Operator operator = getOperatorForPqlQuery(query); + assertTrue(operator instanceof AggregationGroupByOperator); + IntermediateResultsBlock resultsBlock = ((AggregationGroupByOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, + 4 * NUM_RECORDS, NUM_RECORDS); + AggregationGroupByResult aggregationGroupByResult = resultsBlock.getAggregationGroupByResult(); + assertNotNull(aggregationGroupByResult); + int numGroups = 0; + Iterator<GroupKeyGenerator.GroupKey> groupKeyIterator = aggregationGroupByResult.getGroupKeyIterator(); + while (groupKeyIterator.hasNext()) { + numGroups++; + GroupKeyGenerator.GroupKey groupKey = groupKeyIterator.next(); + Integer key = (Integer) groupKey._keys[0]; + assertTrue(_values.containsKey(key)); + assertTrue( + Maps.difference((Int2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(0, groupKey._groupId), + Collections.singletonMap(key, _values.get(key))).areEqual()); + assertTrue( + Maps.difference((Long2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(1, groupKey._groupId), + Collections.singletonMap(key.longValue(), _values.get(key))).areEqual()); + assertTrue( + Maps.difference((Float2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(2, groupKey._groupId), + Collections.singletonMap(key.floatValue(), _values.get(key))).areEqual()); + assertTrue( + Maps.difference((Double2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(3, groupKey._groupId), + Collections.singletonMap(key.doubleValue(), _values.get(key))).areEqual()); + } + assertEquals(numGroups, _values.size()); + + // Inter segments (expect 4 * inner segment result) + BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query); + Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0); + Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS); + // size of this array will be equal to number of aggregation functions since + // we return each aggregation function separately + List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults(); + int numAggregationColumns = aggregationResults.size(); + Assert.assertEquals(numAggregationColumns, 4); + for (AggregationResult aggregationResult : aggregationResults) { + Assert.assertNull(aggregationResult.getValue()); + List<GroupByResult> groupByResults = aggregationResult.getGroupByResult(); + numGroups = groupByResults.size(); + for (int i = 0; i < numGroups; i++) { + GroupByResult groupByResult = groupByResults.get(i); + List<String> group = groupByResult.getGroup(); + assertEquals(group.size(), 1); + assertTrue(_values.containsKey(Integer.parseInt(group.get(0)))); + assertEquals(Double.parseDouble(groupByResult.getValue().toString()), Double.parseDouble(group.get(0)), DELTA); + } + } + } + + @Test + public void testAggregationGroupBySvWithMultiModeReducerOptionMAX() { + String query = + "SELECT MODE(intColumn, 'MAX'), MODE(longColumn, 'MAX'), MODE(floatColumn, 'MAX'), MODE(doubleColumn, 'MAX') FROM testTable GROUP BY intColumn"; + + // Inner segment + Operator operator = getOperatorForPqlQuery(query); + assertTrue(operator instanceof AggregationGroupByOperator); + IntermediateResultsBlock resultsBlock = ((AggregationGroupByOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, + 4 * NUM_RECORDS, NUM_RECORDS); + AggregationGroupByResult aggregationGroupByResult = resultsBlock.getAggregationGroupByResult(); + assertNotNull(aggregationGroupByResult); + int numGroups = 0; + Iterator<GroupKeyGenerator.GroupKey> groupKeyIterator = aggregationGroupByResult.getGroupKeyIterator(); + while (groupKeyIterator.hasNext()) { + numGroups++; + GroupKeyGenerator.GroupKey groupKey = groupKeyIterator.next(); + Integer key = (Integer) groupKey._keys[0]; + assertTrue(_values.containsKey(key)); + assertTrue( + Maps.difference((Int2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(0, groupKey._groupId), + Collections.singletonMap(key, _values.get(key))).areEqual()); + assertTrue( + Maps.difference((Long2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(1, groupKey._groupId), + Collections.singletonMap(key.longValue(), _values.get(key))).areEqual()); + assertTrue( + Maps.difference((Float2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(2, groupKey._groupId), + Collections.singletonMap(key.floatValue(), _values.get(key))).areEqual()); + assertTrue( + Maps.difference((Double2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(3, groupKey._groupId), + Collections.singletonMap(key.doubleValue(), _values.get(key))).areEqual()); + } + assertEquals(numGroups, _values.size()); + + // Inter segments (expect 4 * inner segment result) + BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query); + Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0); + Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS); + // size of this array will be equal to number of aggregation functions since + // we return each aggregation function separately + List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults(); + int numAggregationColumns = aggregationResults.size(); + Assert.assertEquals(numAggregationColumns, 4); + for (AggregationResult aggregationResult : aggregationResults) { + Assert.assertNull(aggregationResult.getValue()); + List<GroupByResult> groupByResults = aggregationResult.getGroupByResult(); + numGroups = groupByResults.size(); + for (int i = 0; i < numGroups; i++) { + GroupByResult groupByResult = groupByResults.get(i); + List<String> group = groupByResult.getGroup(); + assertEquals(group.size(), 1); + assertTrue(_values.containsKey(Integer.parseInt(group.get(0)))); + assertEquals(Double.parseDouble(groupByResult.getValue().toString()), Double.parseDouble(group.get(0)), DELTA); + } + } + } + + @Test + public void testAggregationGroupBySvWithMultiModeReducerOptionAVG() { + String query = + "SELECT MODE(intColumn, 'AVG'), MODE(longColumn, 'AVG'), MODE(floatColumn, 'AVG'), MODE(doubleColumn, 'AVG') FROM testTable GROUP BY intColumn"; + + // Inner segment + Operator operator = getOperatorForPqlQuery(query); + assertTrue(operator instanceof AggregationGroupByOperator); + IntermediateResultsBlock resultsBlock = ((AggregationGroupByOperator) operator).nextBlock(); + QueriesTestUtils.testInnerSegmentExecutionStatistics(operator.getExecutionStatistics(), NUM_RECORDS, 0, + 4 * NUM_RECORDS, NUM_RECORDS); + AggregationGroupByResult aggregationGroupByResult = resultsBlock.getAggregationGroupByResult(); + assertNotNull(aggregationGroupByResult); + int numGroups = 0; + Iterator<GroupKeyGenerator.GroupKey> groupKeyIterator = aggregationGroupByResult.getGroupKeyIterator(); + while (groupKeyIterator.hasNext()) { + numGroups++; + GroupKeyGenerator.GroupKey groupKey = groupKeyIterator.next(); + Integer key = (Integer) groupKey._keys[0]; + assertTrue(_values.containsKey(key)); + assertTrue( + Maps.difference((Int2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(0, groupKey._groupId), + Collections.singletonMap(key, _values.get(key))).areEqual()); + assertTrue( + Maps.difference((Long2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(1, groupKey._groupId), + Collections.singletonMap(key.longValue(), _values.get(key))).areEqual()); + assertTrue( + Maps.difference((Float2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(2, groupKey._groupId), + Collections.singletonMap(key.floatValue(), _values.get(key))).areEqual()); + assertTrue( + Maps.difference((Double2LongOpenHashMap) aggregationGroupByResult.getResultForGroupId(3, groupKey._groupId), + Collections.singletonMap(key.doubleValue(), _values.get(key))).areEqual()); + } + assertEquals(numGroups, _values.size()); + + // Inter segments (expect 4 * inner segment result) + BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query); + Assert.assertEquals(brokerResponse.getNumDocsScanned(), 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getNumEntriesScannedInFilter(), 0); + Assert.assertEquals(brokerResponse.getNumEntriesScannedPostFilter(), 4 * 4 * NUM_RECORDS); + Assert.assertEquals(brokerResponse.getTotalDocs(), 4 * NUM_RECORDS); + // size of this array will be equal to number of aggregation functions since + // we return each aggregation function separately + List<AggregationResult> aggregationResults = brokerResponse.getAggregationResults(); + int numAggregationColumns = aggregationResults.size(); + Assert.assertEquals(numAggregationColumns, 4); + for (AggregationResult aggregationResult : aggregationResults) { + Assert.assertNull(aggregationResult.getValue()); + List<GroupByResult> groupByResults = aggregationResult.getGroupByResult(); + numGroups = groupByResults.size(); + for (int i = 0; i < numGroups; i++) { + GroupByResult groupByResult = groupByResults.get(i); + List<String> group = groupByResult.getGroup(); + assertEquals(group.size(), 1); + assertTrue(_values.containsKey(Integer.parseInt(group.get(0)))); + assertEquals(Double.parseDouble(groupByResult.getValue().toString()), Double.parseDouble(group.get(0)), DELTA); + } + } + } + + @AfterClass + public void tearDown() + throws IOException { + _indexSegment.destroy(); + FileUtils.deleteDirectory(INDEX_DIR); + } +} diff --git a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java index 1683787..9d10e8a 100644 --- a/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java +++ b/pinot-segment-spi/src/main/java/org/apache/pinot/segment/spi/AggregationFunctionType.java @@ -32,6 +32,7 @@ public enum AggregationFunctionType { SUM("sum"), SUMPRECISION("sumPrecision"), AVG("avg"), + MODE("mode"), MINMAXRANGE("minMaxRange"), DISTINCTCOUNT("distinctCount"), DISTINCTCOUNTBITMAP("distinctCountBitmap"), --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org