This is an automated email from the ASF dual-hosted git repository. jackie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git
The following commit(s) were added to refs/heads/master by this push: new c223dfc Support for exact distinct count for non int data types (#5872) c223dfc is described below commit c223dfcfbda1ebbfcab5588e44fdfee4edec450f Author: Kishore Gopalakrishna <g.kish...@gmail.com> AuthorDate: Thu Aug 20 13:07:05 2020 -0700 Support for exact distinct count for non int data types (#5872) ## Description Currently in `DistinctCount`, we use `IntOpenHashSet` to store the hash codes of the values even for non-INT data types. While this is efficient, the result is only approximate: distinct values that share a hash code are counted once, so the accuracy drops as the cardinality increases. This PR sets up a `HashSet` that matches the column data type (e.g. `LongOpenHashSet` for `LONG`, `ObjectOpenHashSet<String>` for `STRING`), so that the actual values are stored. ## Upgrade Notes Brokers should be upgraded before servers in order to keep backward compatibility: older brokers cannot deserialize the new set types returned by upgraded servers. ## Release Notes With this change, the `DistinctCount` aggregation function will always return the exact distinct count regardless of the column data type. It might add performance overhead for data types other than `INT`. For use cases that are performance sensitive and do not require the exact distinct count, use `DistinctCountBitmap`, which has the same behavior as the current `DistinctCount` and better performance. A new boolean Helix cluster config `enable.distinct.count.bitmap.override` is provided to automatically rewrite `DistinctCount` to `DistinctCountBitmap` on the broker. 
--- .../broker/broker/helix/HelixBrokerStarter.java | 10 +- .../requesthandler/BaseBrokerRequestHandler.java | 57 ++++- .../apache/pinot/common/utils/CommonConstants.java | 3 + .../apache/pinot/core/common/ObjectSerDeUtils.java | 230 ++++++++++++++++++- .../query/DictionaryBasedAggregationOperator.java | 32 ++- .../function/DistinctCountAggregationFunction.java | 249 ++++++++++++++++----- .../DistinctCountMVAggregationFunction.java | 82 ++++--- .../pinot/queries/DistinctCountQueriesTest.java | 56 +++-- 8 files changed, 594 insertions(+), 125 deletions(-) diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/broker/helix/HelixBrokerStarter.java b/pinot-broker/src/main/java/org/apache/pinot/broker/broker/helix/HelixBrokerStarter.java index 9ba8207..c0fdd8f 100644 --- a/pinot-broker/src/main/java/org/apache/pinot/broker/broker/helix/HelixBrokerStarter.java +++ b/pinot-broker/src/main/java/org/apache/pinot/broker/broker/helix/HelixBrokerStarter.java @@ -192,9 +192,12 @@ public class HelixBrokerStarter implements ServiceStartable { new HelixConfigScopeBuilder(HelixConfigScope.ConfigScopeProperty.CLUSTER).forCluster(_clusterName).build(); Map<String, String> configMap = _helixAdmin.getConfig(helixConfigScope, Arrays .asList(Helix.ENABLE_CASE_INSENSITIVE_KEY, Helix.DEPRECATED_ENABLE_CASE_INSENSITIVE_KEY, - Broker.CONFIG_OF_ENABLE_QUERY_LIMIT_OVERRIDE, Helix.DEFAULT_HYPERLOGLOG_LOG2M_KEY)); + Broker.CONFIG_OF_ENABLE_QUERY_LIMIT_OVERRIDE, Helix.DEFAULT_HYPERLOGLOG_LOG2M_KEY, + Helix.ENABLE_DISTINCT_COUNT_BITMAP_OVERRIDE_KEY)); + boolean caseInsensitive = Boolean.parseBoolean(configMap.get(Helix.ENABLE_CASE_INSENSITIVE_KEY)) || Boolean .parseBoolean(configMap.get(Helix.DEPRECATED_ENABLE_CASE_INSENSITIVE_KEY)); + String log2mStr = configMap.get(Helix.DEFAULT_HYPERLOGLOG_LOG2M_KEY); if (log2mStr != null) { try { @@ -204,10 +207,15 @@ public class HelixBrokerStarter implements ServiceStartable { Helix.DEFAULT_HYPERLOGLOG_LOG2M_KEY, log2mStr, 
CommonConstants.Helix.DEFAULT_HYPERLOGLOG_LOG2M); } } + if (Boolean.parseBoolean(configMap.get(Broker.CONFIG_OF_ENABLE_QUERY_LIMIT_OVERRIDE))) { _brokerConf.setProperty(Broker.CONFIG_OF_ENABLE_QUERY_LIMIT_OVERRIDE, true); } + if (Boolean.parseBoolean(configMap.get(Helix.ENABLE_DISTINCT_COUNT_BITMAP_OVERRIDE_KEY))) { + _brokerConf.setProperty(Helix.ENABLE_DISTINCT_COUNT_BITMAP_OVERRIDE_KEY, true); + } + LOGGER.info("Setting up broker request handler"); // Set up metric registry and broker metrics _metricsRegistry = new MetricsRegistry(); diff --git a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java index b357449..792534d 100644 --- a/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java +++ b/pinot-broker/src/main/java/org/apache/pinot/broker/requesthandler/BaseBrokerRequestHandler.java @@ -114,8 +114,9 @@ public abstract class BaseBrokerRequestHandler implements BrokerRequestHandler { private final RateLimiter _numDroppedLogRateLimiter; private final AtomicInteger _numDroppedLog; - private final boolean _enableQueryLimitOverride; private final int _defaultHllLog2m; + private final boolean _enableQueryLimitOverride; + private final boolean _enableDistinctCountBitmapOverride; public BaseBrokerRequestHandler(PinotConfiguration config, RoutingManager routingManager, AccessControlFactory accessControlFactory, QueryQuotaManager queryQuotaManager, TableCache tableCache, @@ -130,6 +131,8 @@ public abstract class BaseBrokerRequestHandler implements BrokerRequestHandler { _defaultHllLog2m = _config.getProperty(CommonConstants.Helix.DEFAULT_HYPERLOGLOG_LOG2M_KEY, CommonConstants.Helix.DEFAULT_HYPERLOGLOG_LOG2M); _enableQueryLimitOverride = _config.getProperty(Broker.CONFIG_OF_ENABLE_QUERY_LIMIT_OVERRIDE, false); + _enableDistinctCountBitmapOverride = + 
_config.getProperty(CommonConstants.Helix.ENABLE_DISTINCT_COUNT_BITMAP_OVERRIDE_KEY, false); _brokerId = config.getProperty(Broker.CONFIG_OF_BROKER_ID, getDefaultBrokerId()); _brokerTimeoutMs = config.getProperty(Broker.CONFIG_OF_BROKER_TIMEOUT_MS, Broker.DEFAULT_BROKER_TIMEOUT_MS); @@ -205,6 +208,9 @@ public abstract class BaseBrokerRequestHandler implements BrokerRequestHandler { if (_enableQueryLimitOverride) { handleQueryLimitOverride(brokerRequest, _queryResponseLimit); } + if (_enableDistinctCountBitmapOverride) { + handleDistinctCountBitmapOverride(brokerRequest); + } String tableName = brokerRequest.getQuerySource().getTableName(); String rawTableName = TableNameBuilder.extractRawTableName(tableName); requestStatistics.setTableName(rawTableName); @@ -569,6 +575,55 @@ public abstract class BaseBrokerRequestHandler implements BrokerRequestHandler { } /** + * Helper method to rewrite 'DistinctCount' with 'DistinctCountBitmap' for the given broker request. + */ + private static void handleDistinctCountBitmapOverride(BrokerRequest brokerRequest) { + List<AggregationInfo> aggregationsInfo = brokerRequest.getAggregationsInfo(); + if (aggregationsInfo != null) { + for (AggregationInfo aggregationInfo : aggregationsInfo) { + if (StringUtils.remove(aggregationInfo.getAggregationType(), '_') + .equalsIgnoreCase(AggregationFunctionType.DISTINCTCOUNT.name())) { + aggregationInfo.setAggregationType(AggregationFunctionType.DISTINCTCOUNTBITMAP.name()); + } + } + } + PinotQuery pinotQuery = brokerRequest.getPinotQuery(); + if (pinotQuery != null) { + for (Expression expression : pinotQuery.getSelectList()) { + handleDistinctCountBitmapOverride(expression); + } + List<Expression> orderByExpressions = pinotQuery.getOrderByList(); + if (orderByExpressions != null) { + for (Expression expression : orderByExpressions) { + handleDistinctCountBitmapOverride(expression); + } + } + Expression havingExpression = pinotQuery.getHavingExpression(); + if (havingExpression != null) { + 
handleDistinctCountBitmapOverride(havingExpression); + } + } + } + + /** + * Helper method to rewrite 'DistinctCount' with 'DistinctCountBitmap' for the given expression. + */ + private static void handleDistinctCountBitmapOverride(Expression expression) { + Function function = expression.getFunctionCall(); + if (function == null) { + return; + } + if (StringUtils.remove(function.getOperator(), '_') + .equalsIgnoreCase(AggregationFunctionType.DISTINCTCOUNT.name())) { + function.setOperator(AggregationFunctionType.DISTINCTCOUNTBITMAP.name()); + } else { + for (Expression operand : function.getOperands()) { + handleDistinctCountBitmapOverride(operand); + } + } + } + + /** * Check if a SQL parsed BrokerRequest is a literal only query. * @param brokerRequest * @return true if this query selects only Literals diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/CommonConstants.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/CommonConstants.java index 85113e7..e116849 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/utils/CommonConstants.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/CommonConstants.java @@ -57,6 +57,9 @@ public class CommonConstants { public static final String DEFAULT_HYPERLOGLOG_LOG2M_KEY = "default.hyperloglog.log2m"; public static final int DEFAULT_HYPERLOGLOG_LOG2M = 8; + // Whether to rewrite DistinctCount to DistinctCountBitmap + public static final String ENABLE_DISTINCT_COUNT_BITMAP_OVERRIDE_KEY = "enable.distinct.count.bitmap.override"; + // More information on why these numbers are set can be found in the following doc: // https://cwiki.apache.org/confluence/display/PINOT/Controller+Separation+between+Helix+and+Pinot public static final int NUMBER_OF_PARTITIONS_IN_LEAD_CONTROLLER_RESOURCE = 24; diff --git a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java index 
9c87921..52d4c67 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/common/ObjectSerDeUtils.java @@ -23,9 +23,20 @@ import com.google.common.primitives.Longs; import com.tdunning.math.stats.MergingDigest; import com.tdunning.math.stats.TDigest; import it.unimi.dsi.fastutil.doubles.DoubleArrayList; +import it.unimi.dsi.fastutil.doubles.DoubleIterator; +import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet; +import it.unimi.dsi.fastutil.doubles.DoubleSet; +import it.unimi.dsi.fastutil.floats.FloatIterator; +import it.unimi.dsi.fastutil.floats.FloatOpenHashSet; +import it.unimi.dsi.fastutil.floats.FloatSet; import it.unimi.dsi.fastutil.ints.IntIterator; import it.unimi.dsi.fastutil.ints.IntOpenHashSet; import it.unimi.dsi.fastutil.ints.IntSet; +import it.unimi.dsi.fastutil.longs.LongIterator; +import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import it.unimi.dsi.fastutil.longs.LongSet; +import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet; +import it.unimi.dsi.fastutil.objects.ObjectSet; import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; import java.io.IOException; @@ -33,6 +44,7 @@ import java.nio.ByteBuffer; import java.util.HashMap; import java.util.Iterator; import java.util.Map; +import java.util.Set; import org.apache.datasketches.memory.Memory; import org.apache.datasketches.theta.Sketch; import org.apache.pinot.common.utils.StringUtil; @@ -41,6 +53,8 @@ import org.apache.pinot.core.query.aggregation.function.customobject.AvgPair; import org.apache.pinot.core.query.aggregation.function.customobject.DistinctTable; import org.apache.pinot.core.query.aggregation.function.customobject.MinMaxRangePair; import org.apache.pinot.core.query.aggregation.function.customobject.QuantileDigest; +import org.apache.pinot.spi.utils.ByteArray; +import org.apache.pinot.spi.utils.StringUtils; import org.locationtech.jts.geom.Geometry; import 
org.roaringbitmap.RoaringBitmap; @@ -48,6 +62,7 @@ import org.roaringbitmap.RoaringBitmap; /** * The {@code ObjectSerDeUtils} class provides the utility methods to serialize/de-serialize objects. */ +@SuppressWarnings({"rawtypes", "unchecked"}) public class ObjectSerDeUtils { private ObjectSerDeUtils() { } @@ -68,7 +83,12 @@ public class ObjectSerDeUtils { DistinctTable(11), DataSketch(12), Geometry(13), - RoaringBitmap(14); + RoaringBitmap(14), + LongSet(15), + FloatSet(16), + DoubleSet(17), + StringSet(18), + BytesSet(19); private final int _value; @@ -111,6 +131,19 @@ public class ObjectSerDeUtils { return ObjectType.Geometry; } else if (value instanceof RoaringBitmap) { return ObjectType.RoaringBitmap; + } else if (value instanceof LongSet) { + return ObjectType.LongSet; + } else if (value instanceof FloatSet) { + return ObjectType.FloatSet; + } else if (value instanceof DoubleSet) { + return ObjectType.DoubleSet; + } else if (value instanceof ObjectSet) { + ObjectSet objectSet = (ObjectSet) value; + if (objectSet.isEmpty() || objectSet.iterator().next() instanceof String) { + return ObjectType.StringSet; + } else { + return ObjectType.BytesSet; + } } else { throw new IllegalArgumentException("Unsupported type of value: " + value.getClass().getSimpleName()); } @@ -390,14 +423,14 @@ public class ObjectSerDeUtils { } @Override - public Map<Object, Object> deserialize(byte[] bytes) { + public HashMap<Object, Object> deserialize(byte[] bytes) { return deserialize(ByteBuffer.wrap(bytes)); } @Override - public Map<Object, Object> deserialize(ByteBuffer byteBuffer) { + public HashMap<Object, Object> deserialize(ByteBuffer byteBuffer) { int size = byteBuffer.getInt(); - Map<Object, Object> map = new HashMap<>(size); + HashMap<Object, Object> map = new HashMap<>(size); if (size == 0) { return map; } @@ -437,14 +470,14 @@ public class ObjectSerDeUtils { } @Override - public IntSet deserialize(byte[] bytes) { + public IntOpenHashSet deserialize(byte[] bytes) { return 
deserialize(ByteBuffer.wrap(bytes)); } @Override - public IntSet deserialize(ByteBuffer byteBuffer) { + public IntOpenHashSet deserialize(ByteBuffer byteBuffer) { int size = byteBuffer.getInt(); - IntSet intSet = new IntOpenHashSet(size); + IntOpenHashSet intSet = new IntOpenHashSet(size); for (int i = 0; i < size; i++) { intSet.add(byteBuffer.getInt()); } @@ -452,6 +485,179 @@ public class ObjectSerDeUtils { } }; + public static final ObjectSerDe<LongSet> LONG_SET_SER_DE = new ObjectSerDe<LongSet>() { + + @Override + public byte[] serialize(LongSet longSet) { + int size = longSet.size(); + byte[] bytes = new byte[Integer.BYTES + size * Long.BYTES]; + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + byteBuffer.putInt(size); + LongIterator iterator = longSet.iterator(); + while (iterator.hasNext()) { + byteBuffer.putLong(iterator.nextLong()); + } + return bytes; + } + + @Override + public LongOpenHashSet deserialize(byte[] bytes) { + return deserialize(ByteBuffer.wrap(bytes)); + } + + @Override + public LongOpenHashSet deserialize(ByteBuffer byteBuffer) { + int size = byteBuffer.getInt(); + LongOpenHashSet longSet = new LongOpenHashSet(size); + for (int i = 0; i < size; i++) { + longSet.add(byteBuffer.getLong()); + } + return longSet; + } + }; + + public static final ObjectSerDe<FloatSet> FLOAT_SET_SER_DE = new ObjectSerDe<FloatSet>() { + + @Override + public byte[] serialize(FloatSet floatSet) { + int size = floatSet.size(); + byte[] bytes = new byte[Integer.BYTES + size * Float.BYTES]; + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + byteBuffer.putInt(size); + FloatIterator iterator = floatSet.iterator(); + while (iterator.hasNext()) { + byteBuffer.putFloat(iterator.nextFloat()); + } + return bytes; + } + + @Override + public FloatOpenHashSet deserialize(byte[] bytes) { + return deserialize(ByteBuffer.wrap(bytes)); + } + + @Override + public FloatOpenHashSet deserialize(ByteBuffer byteBuffer) { + int size = byteBuffer.getInt(); + FloatOpenHashSet floatSet = 
new FloatOpenHashSet(size); + for (int i = 0; i < size; i++) { + floatSet.add(byteBuffer.getFloat()); + } + return floatSet; + } + }; + + public static final ObjectSerDe<DoubleSet> DOUBLE_SET_SER_DE = new ObjectSerDe<DoubleSet>() { + + @Override + public byte[] serialize(DoubleSet doubleSet) { + int size = doubleSet.size(); + byte[] bytes = new byte[Integer.BYTES + size * Double.BYTES]; + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + byteBuffer.putInt(size); + DoubleIterator iterator = doubleSet.iterator(); + while (iterator.hasNext()) { + byteBuffer.putDouble(iterator.nextDouble()); + } + return bytes; + } + + @Override + public DoubleOpenHashSet deserialize(byte[] bytes) { + return deserialize(ByteBuffer.wrap(bytes)); + } + + @Override + public DoubleOpenHashSet deserialize(ByteBuffer byteBuffer) { + int size = byteBuffer.getInt(); + DoubleOpenHashSet doubleSet = new DoubleOpenHashSet(size); + for (int i = 0; i < size; i++) { + doubleSet.add(byteBuffer.getDouble()); + } + return doubleSet; + } + }; + + public static final ObjectSerDe<Set<String>> STRING_SET_SER_DE = new ObjectSerDe<Set<String>>() { + + @Override + public byte[] serialize(Set<String> stringSet) { + int size = stringSet.size(); + // NOTE: No need to close the ByteArrayOutputStream. 
+ ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream); + try { + dataOutputStream.writeInt(size); + for (String value : stringSet) { + byte[] bytes = StringUtils.encodeUtf8(value); + dataOutputStream.writeInt(bytes.length); + dataOutputStream.write(bytes); + } + } catch (IOException e) { + throw new RuntimeException("Caught exception while serializing Set<String>", e); + } + return byteArrayOutputStream.toByteArray(); + } + + @Override + public ObjectOpenHashSet<String> deserialize(byte[] bytes) { + return deserialize(ByteBuffer.wrap(bytes)); + } + + @Override + public ObjectOpenHashSet<String> deserialize(ByteBuffer byteBuffer) { + int size = byteBuffer.getInt(); + ObjectOpenHashSet<String> stringSet = new ObjectOpenHashSet<>(size); + for (int i = 0; i < size; i++) { + int length = byteBuffer.getInt(); + byte[] bytes = new byte[length]; + byteBuffer.get(bytes); + stringSet.add(StringUtils.decodeUtf8(bytes)); + } + return stringSet; + } + }; + + public static final ObjectSerDe<Set<ByteArray>> BYTES_SET_SER_DE = new ObjectSerDe<Set<ByteArray>>() { + + @Override + public byte[] serialize(Set<ByteArray> bytesSet) { + int size = bytesSet.size(); + // NOTE: No need to close the ByteArrayOutputStream. 
+ ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream); + try { + dataOutputStream.writeInt(size); + for (ByteArray value : bytesSet) { + byte[] bytes = value.getBytes(); + dataOutputStream.writeInt(bytes.length); + dataOutputStream.write(bytes); + } + } catch (IOException e) { + throw new RuntimeException("Caught exception while serializing Set<ByteArray>", e); + } + return byteArrayOutputStream.toByteArray(); + } + + @Override + public ObjectOpenHashSet<ByteArray> deserialize(byte[] bytes) { + return deserialize(ByteBuffer.wrap(bytes)); + } + + @Override + public ObjectOpenHashSet<ByteArray> deserialize(ByteBuffer byteBuffer) { + int size = byteBuffer.getInt(); + ObjectOpenHashSet<ByteArray> bytesSet = new ObjectOpenHashSet<>(size); + for (int i = 0; i < size; i++) { + int length = byteBuffer.getInt(); + byte[] bytes = new byte[length]; + byteBuffer.get(bytes); + bytesSet.add(new ByteArray(bytes)); + } + return bytesSet; + } + }; + public static final ObjectSerDe<TDigest> TDIGEST_SER_DE = new ObjectSerDe<TDigest>() { @Override @@ -553,7 +759,12 @@ public class ObjectSerDeUtils { DISTINCT_TABLE_SER_DE, DATA_SKETCH_SER_DE, GEOMETRY_SER_DE, - ROARING_BITMAP_SER_DE + ROARING_BITMAP_SER_DE, + LONG_SET_SER_DE, + FLOAT_SET_SER_DE, + DOUBLE_SET_SER_DE, + STRING_SET_SER_DE, + BYTES_SET_SER_DE }; //@formatter:on @@ -565,7 +776,6 @@ public class ObjectSerDeUtils { return serialize(value, objectType._value); } - @SuppressWarnings("unchecked") public static byte[] serialize(Object value, int objectTypeValue) { return SER_DES[objectTypeValue].serialize(value); } @@ -574,7 +784,6 @@ public class ObjectSerDeUtils { return deserialize(bytes, objectType._value); } - @SuppressWarnings("unchecked") public static <T> T deserialize(byte[] bytes, int objectTypeValue) { return (T) SER_DES[objectTypeValue].deserialize(bytes); } @@ -583,7 +792,6 @@ public class ObjectSerDeUtils { return 
deserialize(byteBuffer, objectType._value); } - @SuppressWarnings("unchecked") public static <T> T deserialize(ByteBuffer byteBuffer, int objectTypeValue) { return (T) SER_DES[objectTypeValue].deserialize(byteBuffer); } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java index 7fa6798..8dacb84 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/DictionaryBasedAggregationOperator.java @@ -18,9 +18,12 @@ */ package org.apache.pinot.core.operator.query; +import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet; +import it.unimi.dsi.fastutil.floats.FloatOpenHashSet; import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Map; import org.apache.pinot.core.operator.BaseOperator; @@ -30,6 +33,7 @@ import org.apache.pinot.core.query.aggregation.function.AggregationFunction; import org.apache.pinot.core.query.aggregation.function.customobject.MinMaxRangePair; import org.apache.pinot.core.query.request.context.ExpressionContext; import org.apache.pinot.core.segment.index.readers.Dictionary; +import org.apache.pinot.spi.utils.ByteArray; /** @@ -77,42 +81,52 @@ public class DictionaryBasedAggregationOperator extends BaseOperator<Intermediat .add(new MinMaxRangePair(dictionary.getDoubleValue(0), dictionary.getDoubleValue(dictionarySize - 1))); break; case DISTINCTCOUNT: - IntOpenHashSet set = new IntOpenHashSet(dictionarySize); switch (dictionary.getValueType()) { case INT: + IntOpenHashSet intSet = new IntOpenHashSet(dictionarySize); for (int dictId = 0; dictId < dictionarySize; dictId++) { - 
set.add(dictionary.getIntValue(dictId)); + intSet.add(dictionary.getIntValue(dictId)); } + aggregationResults.add(intSet); break; case LONG: + LongOpenHashSet longSet = new LongOpenHashSet(dictionarySize); for (int dictId = 0; dictId < dictionarySize; dictId++) { - set.add(Long.hashCode(dictionary.getLongValue(dictId))); + longSet.add(dictionary.getLongValue(dictId)); } + aggregationResults.add(longSet); break; case FLOAT: + FloatOpenHashSet floatSet = new FloatOpenHashSet(dictionarySize); for (int dictId = 0; dictId < dictionarySize; dictId++) { - set.add(Float.hashCode(dictionary.getFloatValue(dictId))); + floatSet.add(dictionary.getFloatValue(dictId)); } + aggregationResults.add(floatSet); break; case DOUBLE: + DoubleOpenHashSet doubleSet = new DoubleOpenHashSet(dictionarySize); for (int dictId = 0; dictId < dictionarySize; dictId++) { - set.add(Double.hashCode(dictionary.getDoubleValue(dictId))); + doubleSet.add(dictionary.getDoubleValue(dictId)); } + aggregationResults.add(doubleSet); break; case STRING: + ObjectOpenHashSet<String> stringSet = new ObjectOpenHashSet<>(dictionarySize); for (int dictId = 0; dictId < dictionarySize; dictId++) { - set.add(dictionary.getStringValue(dictId).hashCode()); + stringSet.add(dictionary.getStringValue(dictId)); } + aggregationResults.add(stringSet); break; case BYTES: + ObjectOpenHashSet<ByteArray> bytesSet = new ObjectOpenHashSet<>(dictionarySize); for (int dictId = 0; dictId < dictionarySize; dictId++) { - set.add(Arrays.hashCode(dictionary.getBytesValue(dictId))); + bytesSet.add(new ByteArray(dictionary.getBytesValue(dictId))); } + aggregationResults.add(bytesSet); break; default: throw new IllegalStateException(); } - aggregationResults.add(set); break; case SEGMENTPARTITIONEDDISTINCTCOUNT: aggregationResults.add((long) dictionarySize); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java 
b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java index e8e7e97..368e587 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountAggregationFunction.java @@ -18,9 +18,13 @@ */ package org.apache.pinot.core.query.aggregation.function; +import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet; +import it.unimi.dsi.fastutil.floats.FloatOpenHashSet; import it.unimi.dsi.fastutil.ints.IntOpenHashSet; -import java.util.Arrays; +import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet; import java.util.Map; +import java.util.Set; import org.apache.pinot.common.function.AggregationFunctionType; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.core.common.BlockValSet; @@ -31,11 +35,13 @@ import org.apache.pinot.core.query.aggregation.groupby.ObjectGroupByResultHolder import org.apache.pinot.core.query.request.context.ExpressionContext; import org.apache.pinot.core.segment.index.readers.Dictionary; import org.apache.pinot.spi.data.FieldSpec.DataType; +import org.apache.pinot.spi.utils.ByteArray; import org.roaringbitmap.PeekableIntIterator; import org.roaringbitmap.RoaringBitmap; -public class DistinctCountAggregationFunction extends BaseSingleInputAggregationFunction<IntOpenHashSet, Integer> { +@SuppressWarnings({"rawtypes", "unchecked"}) +public class DistinctCountAggregationFunction extends BaseSingleInputAggregationFunction<Set, Integer> { public DistinctCountAggregationFunction(ExpressionContext expression) { super(expression); @@ -69,44 +75,51 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation return; } - // For non-dictionary-encoded expression, store hash code of the values into the value set - IntOpenHashSet valueSet = 
getValueSet(aggregationResultHolder); + // For non-dictionary-encoded expression, store values into the value set DataType valueType = blockValSet.getValueType(); + Set valueSet = getValueSet(aggregationResultHolder, valueType); switch (valueType) { case INT: + IntOpenHashSet intSet = (IntOpenHashSet) valueSet; int[] intValues = blockValSet.getIntValuesSV(); for (int i = 0; i < length; i++) { - valueSet.add(intValues[i]); + intSet.add(intValues[i]); } break; case LONG: + LongOpenHashSet longSet = (LongOpenHashSet) valueSet; long[] longValues = blockValSet.getLongValuesSV(); for (int i = 0; i < length; i++) { - valueSet.add(Long.hashCode(longValues[i])); + longSet.add(longValues[i]); } break; case FLOAT: + FloatOpenHashSet floatSet = (FloatOpenHashSet) valueSet; float[] floatValues = blockValSet.getFloatValuesSV(); for (int i = 0; i < length; i++) { - valueSet.add(Float.hashCode(floatValues[i])); + floatSet.add(floatValues[i]); } break; case DOUBLE: + DoubleOpenHashSet doubleSet = (DoubleOpenHashSet) valueSet; double[] doubleValues = blockValSet.getDoubleValuesSV(); for (int i = 0; i < length; i++) { - valueSet.add(Double.hashCode(doubleValues[i])); + doubleSet.add(doubleValues[i]); } break; case STRING: + ObjectOpenHashSet<String> stringSet = (ObjectOpenHashSet<String>) valueSet; String[] stringValues = blockValSet.getStringValuesSV(); + //noinspection ManualArrayToCollectionCopy for (int i = 0; i < length; i++) { - valueSet.add(stringValues[i].hashCode()); + stringSet.add(stringValues[i]); } break; case BYTES: + ObjectOpenHashSet<ByteArray> bytesSet = (ObjectOpenHashSet<ByteArray>) valueSet; byte[][] bytesValues = blockValSet.getBytesValuesSV(); for (int i = 0; i < length; i++) { - valueSet.add(Arrays.hashCode(bytesValues[i])); + bytesSet.add(new ByteArray(bytesValues[i])); } break; default: @@ -129,43 +142,46 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation return; } - // For non-dictionary-encoded expression, store hash code of 
the values into the value set + // For non-dictionary-encoded expression, store values into the value set DataType valueType = blockValSet.getValueType(); switch (valueType) { case INT: int[] intValues = blockValSet.getIntValuesSV(); for (int i = 0; i < length; i++) { - getValueSet(groupByResultHolder, groupKeyArray[i]).add(intValues[i]); + ((IntOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.INT)).add(intValues[i]); } break; case LONG: long[] longValues = blockValSet.getLongValuesSV(); for (int i = 0; i < length; i++) { - getValueSet(groupByResultHolder, groupKeyArray[i]).add(Long.hashCode(longValues[i])); + ((LongOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.LONG)).add(longValues[i]); } break; case FLOAT: float[] floatValues = blockValSet.getFloatValuesSV(); for (int i = 0; i < length; i++) { - getValueSet(groupByResultHolder, groupKeyArray[i]).add(Float.hashCode(floatValues[i])); + ((FloatOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.FLOAT)).add(floatValues[i]); } break; case DOUBLE: double[] doubleValues = blockValSet.getDoubleValuesSV(); for (int i = 0; i < length; i++) { - getValueSet(groupByResultHolder, groupKeyArray[i]).add(Double.hashCode(doubleValues[i])); + ((DoubleOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.DOUBLE)) + .add(doubleValues[i]); } break; case STRING: String[] stringValues = blockValSet.getStringValuesSV(); for (int i = 0; i < length; i++) { - getValueSet(groupByResultHolder, groupKeyArray[i]).add(stringValues[i].hashCode()); + ((ObjectOpenHashSet<String>) getValueSet(groupByResultHolder, groupKeyArray[i], DataType.STRING)) + .add(stringValues[i]); } break; case BYTES: byte[][] bytesValues = blockValSet.getBytesValuesSV(); for (int i = 0; i < length; i++) { - getValueSet(groupByResultHolder, groupKeyArray[i]).add(Arrays.hashCode(bytesValues[i])); + ((ObjectOpenHashSet<ByteArray>) getValueSet(groupByResultHolder, groupKeyArray[i], 
DataType.BYTES)) + .add(new ByteArray(bytesValues[i])); } break; default: @@ -188,7 +204,7 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation return; } - // For non-dictionary-encoded expression, store hash code of the values into the value set + // For non-dictionary-encoded expression, store values into the value set DataType valueType = blockValSet.getValueType(); switch (valueType) { case INT: @@ -200,31 +216,31 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation case LONG: long[] longValues = blockValSet.getLongValuesSV(); for (int i = 0; i < length; i++) { - setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], Long.hashCode(longValues[i])); + setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], longValues[i]); } break; case FLOAT: float[] floatValues = blockValSet.getFloatValuesSV(); for (int i = 0; i < length; i++) { - setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], Float.hashCode(floatValues[i])); + setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], floatValues[i]); } break; case DOUBLE: double[] doubleValues = blockValSet.getDoubleValuesSV(); for (int i = 0; i < length; i++) { - setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], Double.hashCode(doubleValues[i])); + setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], doubleValues[i]); } break; case STRING: String[] stringValues = blockValSet.getStringValuesSV(); for (int i = 0; i < length; i++) { - setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], stringValues[i].hashCode()); + setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], stringValues[i]); } break; case BYTES: byte[][] bytesValues = blockValSet.getBytesValuesSV(); for (int i = 0; i < length; i++) { - setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], Arrays.hashCode(bytesValues[i])); + setValueForGroupKeys(groupByResultHolder, groupKeysArray[i], new ByteArray(bytesValues[i])); } break; default: @@ 
-233,41 +249,88 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation } @Override - public IntOpenHashSet extractAggregationResult(AggregationResultHolder aggregationResultHolder) { + public Set extractAggregationResult(AggregationResultHolder aggregationResultHolder) { Object result = aggregationResultHolder.getResult(); if (result == null) { + // Use empty IntOpenHashSet as a place holder for empty result return new IntOpenHashSet(); } if (result instanceof DictIdsWrapper) { - // For dictionary-encoded expression, convert dictionary ids to hash code of the values + // For dictionary-encoded expression, convert dictionary ids to values return convertToValueSet((DictIdsWrapper) result); } else { // For non-dictionary-encoded expression, directly return the value set - return (IntOpenHashSet) result; + return (Set) result; } } @Override - public IntOpenHashSet extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) { + public Set extractGroupByResult(GroupByResultHolder groupByResultHolder, int groupKey) { Object result = groupByResultHolder.getResult(groupKey); if (result == null) { + // NOTE: Return an empty IntOpenHashSet for empty result. 
return new IntOpenHashSet(); } if (result instanceof DictIdsWrapper) { - // For dictionary-encoded expression, convert dictionary ids to hash code of the values + // For dictionary-encoded expression, convert dictionary ids to values return convertToValueSet((DictIdsWrapper) result); } else { // For non-dictionary-encoded expression, directly return the value set - return (IntOpenHashSet) result; + return (Set) result; } } @Override - public IntOpenHashSet merge(IntOpenHashSet intermediateResult1, IntOpenHashSet intermediateResult2) { - intermediateResult1.addAll(intermediateResult2); - return intermediateResult1; + public Set merge(Set intermediateResult1, Set intermediateResult2) { + if (intermediateResult1.isEmpty()) { + return intermediateResult2; + } + if (intermediateResult2.isEmpty()) { + return intermediateResult1; + } + if (intermediateResult1.getClass() == intermediateResult2.getClass()) { + // Both results are of the same type, directly merge + intermediateResult1.addAll(intermediateResult2); + return intermediateResult1; + } else { + // TODO: Remove this part after releasing 0.5.0 + // The results are not of the same type. This can happen when servers are getting upgraded, and some servers are + // still running the old code and store hash codes in the set. For backward-compatibility, we convert the values + // into hash codes and insert them into the hash code set. 
+ IntOpenHashSet hashCodeSet; + Set valueSet; + if (intermediateResult1 instanceof IntOpenHashSet) { + hashCodeSet = (IntOpenHashSet) intermediateResult1; + valueSet = intermediateResult2; + } else { + hashCodeSet = (IntOpenHashSet) intermediateResult2; + valueSet = intermediateResult1; + } + if (valueSet instanceof LongOpenHashSet) { + LongOpenHashSet longSet = (LongOpenHashSet) valueSet; + for (long value : longSet) { + hashCodeSet.add(Long.hashCode(value)); + } + } else if (valueSet instanceof FloatOpenHashSet) { + FloatOpenHashSet floatSet = (FloatOpenHashSet) valueSet; + for (float value : floatSet) { + hashCodeSet.add(Float.hashCode(value)); + } + } else if (valueSet instanceof DoubleOpenHashSet) { + DoubleOpenHashSet doubleSet = (DoubleOpenHashSet) valueSet; + for (double value : doubleSet) { + hashCodeSet.add(Double.hashCode(value)); + } + } else { + // STRING and BYTES + for (Object value : valueSet) { + hashCodeSet.add(value.hashCode()); + } + } + return hashCodeSet; + } } @Override @@ -286,7 +349,7 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation } @Override - public Integer extractFinalResult(IntOpenHashSet intermediateResult) { + public Integer extractFinalResult(Set intermediateResult) { return intermediateResult.size(); } @@ -306,16 +369,37 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation /** * Returns the value set from the result holder or creates a new one if it does not exist. 
*/ - protected static IntOpenHashSet getValueSet(AggregationResultHolder aggregationResultHolder) { - IntOpenHashSet valueSet = aggregationResultHolder.getResult(); + protected static Set getValueSet(AggregationResultHolder aggregationResultHolder, DataType valueType) { + Set valueSet = aggregationResultHolder.getResult(); if (valueSet == null) { - valueSet = new IntOpenHashSet(); + valueSet = getValueSet(valueType); aggregationResultHolder.setValue(valueSet); } return valueSet; } /** + * Helper method to create a value set for the given value type. + */ + private static Set getValueSet(DataType valueType) { + switch (valueType) { + case INT: + return new IntOpenHashSet(); + case LONG: + return new LongOpenHashSet(); + case FLOAT: + return new FloatOpenHashSet(); + case DOUBLE: + return new DoubleOpenHashSet(); + case STRING: + case BYTES: + return new ObjectOpenHashSet(); + default: + throw new IllegalStateException("Illegal data type for DISTINCT_COUNT aggregation function: " + valueType); + } + } + + /** * Returns the dictionary id bitmap for the given group key or creates a new one if it does not exist. */ protected static RoaringBitmap getDictIdBitmap(GroupByResultHolder groupByResultHolder, int groupKey, @@ -331,10 +415,10 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation /** * Returns the value set for the given group key or creates a new one if it does not exist. 
*/ - protected static IntOpenHashSet getValueSet(GroupByResultHolder groupByResultHolder, int groupKey) { - IntOpenHashSet valueSet = groupByResultHolder.getResult(groupKey); + protected static Set getValueSet(GroupByResultHolder groupByResultHolder, int groupKey, DataType valueType) { + Set valueSet = groupByResultHolder.getResult(groupKey); if (valueSet == null) { - valueSet = new IntOpenHashSet(); + valueSet = getValueSet(valueType); groupByResultHolder.setValueForKey(groupKey, valueSet); } return valueSet; @@ -351,59 +435,108 @@ public class DistinctCountAggregationFunction extends BaseSingleInputAggregation } /** - * Helper method to set value for the given group keys into the result holder. + * Helper method to set INT value for the given group keys into the result holder. */ private static void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys, int value) { for (int groupKey : groupKeys) { - getValueSet(groupByResultHolder, groupKey).add(value); + ((IntOpenHashSet) getValueSet(groupByResultHolder, groupKey, DataType.INT)).add(value); } } /** - * Helper method to read dictionary and convert dictionary ids to hash code of the values for dictionary-encoded - * expression. + * Helper method to set LONG value for the given group keys into the result holder. */ - private static IntOpenHashSet convertToValueSet(DictIdsWrapper dictIdsWrapper) { + private static void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys, long value) { + for (int groupKey : groupKeys) { + ((LongOpenHashSet) getValueSet(groupByResultHolder, groupKey, DataType.LONG)).add(value); + } + } + + /** + * Helper method to set FLOAT value for the given group keys into the result holder. 
+ */ + private static void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys, float value) { + for (int groupKey : groupKeys) { + ((FloatOpenHashSet) getValueSet(groupByResultHolder, groupKey, DataType.FLOAT)).add(value); + } + } + + /** + * Helper method to set DOUBLE value for the given group keys into the result holder. + */ + private static void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys, double value) { + for (int groupKey : groupKeys) { + ((DoubleOpenHashSet) getValueSet(groupByResultHolder, groupKey, DataType.DOUBLE)).add(value); + } + } + + /** + * Helper method to set STRING value for the given group keys into the result holder. + */ + private static void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys, String value) { + for (int groupKey : groupKeys) { + ((ObjectOpenHashSet<String>) getValueSet(groupByResultHolder, groupKey, DataType.STRING)).add(value); + } + } + + /** + * Helper method to set BYTES value for the given group keys into the result holder. + */ + private static void setValueForGroupKeys(GroupByResultHolder groupByResultHolder, int[] groupKeys, ByteArray value) { + for (int groupKey : groupKeys) { + ((ObjectOpenHashSet<ByteArray>) getValueSet(groupByResultHolder, groupKey, DataType.BYTES)).add(value); + } + } + + /** + * Helper method to read dictionary and convert dictionary ids to values for dictionary-encoded expression. 
+ */ + private static Set convertToValueSet(DictIdsWrapper dictIdsWrapper) { Dictionary dictionary = dictIdsWrapper._dictionary; RoaringBitmap dictIdBitmap = dictIdsWrapper._dictIdBitmap; - IntOpenHashSet valueSet = new IntOpenHashSet(dictIdBitmap.getCardinality()); + int numValues = dictIdBitmap.getCardinality(); PeekableIntIterator iterator = dictIdBitmap.getIntIterator(); DataType valueType = dictionary.getValueType(); switch (valueType) { case INT: + IntOpenHashSet intSet = new IntOpenHashSet(numValues); while (iterator.hasNext()) { - valueSet.add(dictionary.getIntValue(iterator.next())); + intSet.add(dictionary.getIntValue(iterator.next())); } - break; + return intSet; case LONG: + LongOpenHashSet longSet = new LongOpenHashSet(numValues); while (iterator.hasNext()) { - valueSet.add(Long.hashCode(dictionary.getLongValue(iterator.next()))); + longSet.add(dictionary.getLongValue(iterator.next())); } - break; + return longSet; case FLOAT: + FloatOpenHashSet floatSet = new FloatOpenHashSet(numValues); while (iterator.hasNext()) { - valueSet.add(Float.hashCode(dictionary.getFloatValue(iterator.next()))); + floatSet.add(dictionary.getFloatValue(iterator.next())); } - break; + return floatSet; case DOUBLE: + DoubleOpenHashSet doubleSet = new DoubleOpenHashSet(numValues); while (iterator.hasNext()) { - valueSet.add(Double.hashCode(dictionary.getDoubleValue(iterator.next()))); + doubleSet.add(dictionary.getDoubleValue(iterator.next())); } - break; + return doubleSet; case STRING: + ObjectOpenHashSet<String> stringSet = new ObjectOpenHashSet<>(numValues); while (iterator.hasNext()) { - valueSet.add(dictionary.getStringValue(iterator.next()).hashCode()); + stringSet.add(dictionary.getStringValue(iterator.next())); } - break; + return stringSet; case BYTES: + ObjectOpenHashSet<ByteArray> bytesSet = new ObjectOpenHashSet<>(numValues); while (iterator.hasNext()) { - valueSet.add(Arrays.hashCode(dictionary.getBytesValue(iterator.next()))); + bytesSet.add(new 
ByteArray(dictionary.getBytesValue(iterator.next()))); } - break; + return bytesSet; default: throw new IllegalStateException("Illegal data type for DISTINCT_COUNT aggregation function: " + valueType); } - return valueSet; } private static final class DictIdsWrapper { diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java index fb6b2e3..cfd1ea1 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/aggregation/function/DistinctCountMVAggregationFunction.java @@ -18,8 +18,13 @@ */ package org.apache.pinot.core.query.aggregation.function; +import it.unimi.dsi.fastutil.doubles.DoubleOpenHashSet; +import it.unimi.dsi.fastutil.floats.FloatOpenHashSet; import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.longs.LongOpenHashSet; +import it.unimi.dsi.fastutil.objects.ObjectOpenHashSet; import java.util.Map; +import java.util.Set; import org.apache.pinot.common.function.AggregationFunctionType; import org.apache.pinot.core.common.BlockValSet; import org.apache.pinot.core.query.aggregation.AggregationResultHolder; @@ -30,6 +35,7 @@ import org.apache.pinot.spi.data.FieldSpec; import org.roaringbitmap.RoaringBitmap; +@SuppressWarnings({"rawtypes", "unchecked"}) public class DistinctCountMVAggregationFunction extends DistinctCountAggregationFunction { public DistinctCountMVAggregationFunction(ExpressionContext expression) { @@ -57,46 +63,53 @@ public class DistinctCountMVAggregationFunction extends DistinctCountAggregation return; } - // For non-dictionary-encoded expression, store hash code of the values into the value set - IntOpenHashSet valueSet = getValueSet(aggregationResultHolder); + // For non-dictionary-encoded expression, store 
values into the value set FieldSpec.DataType valueType = blockValSet.getValueType(); + Set valueSet = getValueSet(aggregationResultHolder, valueType); switch (valueType) { case INT: + IntOpenHashSet intSet = (IntOpenHashSet) valueSet; int[][] intValues = blockValSet.getIntValuesMV(); for (int i = 0; i < length; i++) { for (int value : intValues[i]) { - valueSet.add(value); + intSet.add(value); } } break; case LONG: + LongOpenHashSet longSet = (LongOpenHashSet) valueSet; long[][] longValues = blockValSet.getLongValuesMV(); for (int i = 0; i < length; i++) { for (long value : longValues[i]) { - valueSet.add(Long.hashCode(value)); + longSet.add(value); } } break; case FLOAT: + FloatOpenHashSet floatSet = (FloatOpenHashSet) valueSet; float[][] floatValues = blockValSet.getFloatValuesMV(); for (int i = 0; i < length; i++) { for (float value : floatValues[i]) { - valueSet.add(Float.hashCode(value)); + floatSet.add(value); } } case DOUBLE: + DoubleOpenHashSet doubleSet = (DoubleOpenHashSet) valueSet; double[][] doubleValues = blockValSet.getDoubleValuesMV(); for (int i = 0; i < length; i++) { for (double value : doubleValues[i]) { - valueSet.add(Double.hashCode(value)); + doubleSet.add(value); } } break; case STRING: + ObjectOpenHashSet<String> stringSet = (ObjectOpenHashSet<String>) valueSet; String[][] stringValues = blockValSet.getStringValuesMV(); for (int i = 0; i < length; i++) { + //noinspection ManualArrayToCollectionCopy for (String value : stringValues[i]) { - valueSet.add(value.hashCode()); + //noinspection UseBulkOperation + stringSet.add(value); } } break; @@ -120,51 +133,58 @@ public class DistinctCountMVAggregationFunction extends DistinctCountAggregation return; } - // For non-dictionary-encoded expression, store hash code of the values into the value set + // For non-dictionary-encoded expression, store values into the value set FieldSpec.DataType valueType = blockValSet.getValueType(); switch (valueType) { case INT: int[][] intValues = 
blockValSet.getIntValuesMV(); for (int i = 0; i < length; i++) { - IntOpenHashSet valueSet = getValueSet(groupByResultHolder, groupKeyArray[i]); + IntOpenHashSet intSet = + (IntOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], FieldSpec.DataType.INT); for (int value : intValues[i]) { - valueSet.add(value); + intSet.add(value); } } break; case LONG: long[][] longValues = blockValSet.getLongValuesMV(); for (int i = 0; i < length; i++) { - IntOpenHashSet valueSet = getValueSet(groupByResultHolder, groupKeyArray[i]); + LongOpenHashSet longSet = + (LongOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], FieldSpec.DataType.LONG); for (long value : longValues[i]) { - valueSet.add(Long.hashCode(value)); + longSet.add(value); } } break; case FLOAT: float[][] floatValues = blockValSet.getFloatValuesMV(); for (int i = 0; i < length; i++) { - IntOpenHashSet valueSet = getValueSet(groupByResultHolder, groupKeyArray[i]); + FloatOpenHashSet floatSet = + (FloatOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], FieldSpec.DataType.FLOAT); for (float value : floatValues[i]) { - valueSet.add(Float.hashCode(value)); + floatSet.add(value); } } break; case DOUBLE: double[][] doubleValues = blockValSet.getDoubleValuesMV(); for (int i = 0; i < length; i++) { - IntOpenHashSet valueSet = getValueSet(groupByResultHolder, groupKeyArray[i]); + DoubleOpenHashSet doubleSet = + (DoubleOpenHashSet) getValueSet(groupByResultHolder, groupKeyArray[i], FieldSpec.DataType.DOUBLE); for (double value : doubleValues[i]) { - valueSet.add(Double.hashCode(value)); + doubleSet.add(value); } } break; case STRING: String[][] stringValues = blockValSet.getStringValuesMV(); for (int i = 0; i < length; i++) { - IntOpenHashSet valueSet = getValueSet(groupByResultHolder, groupKeyArray[i]); + ObjectOpenHashSet<String> stringSet = + (ObjectOpenHashSet<String>) getValueSet(groupByResultHolder, groupKeyArray[i], FieldSpec.DataType.STRING); + //noinspection 
ManualArrayToCollectionCopy for (String value : stringValues[i]) { - valueSet.add(value.hashCode()); + //noinspection UseBulkOperation + stringSet.add(value); } } break; @@ -197,9 +217,9 @@ public class DistinctCountMVAggregationFunction extends DistinctCountAggregation int[][] intValues = blockValSet.getIntValuesMV(); for (int i = 0; i < length; i++) { for (int groupKey : groupKeysArray[i]) { - IntOpenHashSet valueSet = getValueSet(groupByResultHolder, groupKey); + IntOpenHashSet intSet = (IntOpenHashSet) getValueSet(groupByResultHolder, groupKey, FieldSpec.DataType.INT); for (int value : intValues[i]) { - valueSet.add(value); + intSet.add(value); } } } @@ -208,9 +228,10 @@ public class DistinctCountMVAggregationFunction extends DistinctCountAggregation long[][] longValues = blockValSet.getLongValuesMV(); for (int i = 0; i < length; i++) { for (int groupKey : groupKeysArray[i]) { - IntOpenHashSet valueSet = getValueSet(groupByResultHolder, groupKey); + LongOpenHashSet longSet = + (LongOpenHashSet) getValueSet(groupByResultHolder, groupKey, FieldSpec.DataType.LONG); for (long value : longValues[i]) { - valueSet.add(Long.hashCode(value)); + longSet.add(value); } } } @@ -219,9 +240,10 @@ public class DistinctCountMVAggregationFunction extends DistinctCountAggregation float[][] floatValues = blockValSet.getFloatValuesMV(); for (int i = 0; i < length; i++) { for (int groupKey : groupKeysArray[i]) { - IntOpenHashSet valueSet = getValueSet(groupByResultHolder, groupKey); + FloatOpenHashSet floatSet = + (FloatOpenHashSet) getValueSet(groupByResultHolder, groupKey, FieldSpec.DataType.FLOAT); for (float value : floatValues[i]) { - valueSet.add(Float.hashCode(value)); + floatSet.add(value); } } } @@ -230,9 +252,10 @@ public class DistinctCountMVAggregationFunction extends DistinctCountAggregation double[][] doubleValues = blockValSet.getDoubleValuesMV(); for (int i = 0; i < length; i++) { for (int groupKey : groupKeysArray[i]) { - IntOpenHashSet valueSet = 
getValueSet(groupByResultHolder, groupKey); + DoubleOpenHashSet doubleSet = + (DoubleOpenHashSet) getValueSet(groupByResultHolder, groupKey, FieldSpec.DataType.DOUBLE); for (double value : doubleValues[i]) { - valueSet.add(Double.hashCode(value)); + doubleSet.add(value); } } } @@ -241,9 +264,12 @@ public class DistinctCountMVAggregationFunction extends DistinctCountAggregation String[][] stringValues = blockValSet.getStringValuesMV(); for (int i = 0; i < length; i++) { for (int groupKey : groupKeysArray[i]) { - IntOpenHashSet valueSet = getValueSet(groupByResultHolder, groupKey); + ObjectOpenHashSet<String> stringSet = + (ObjectOpenHashSet<String>) getValueSet(groupByResultHolder, groupKey, FieldSpec.DataType.STRING); + //noinspection ManualArrayToCollectionCopy for (String value : stringValues[i]) { - valueSet.add(value.hashCode()); + //noinspection UseBulkOperation + stringSet.add(value); } } } diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/DistinctCountQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/DistinctCountQueriesTest.java index 0a9c105..3bd5d82 100644 --- a/pinot-core/src/test/java/org/apache/pinot/queries/DistinctCountQueriesTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/queries/DistinctCountQueriesTest.java @@ -91,7 +91,6 @@ public class DistinctCountQueriesTest extends BaseQueriesTest { new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build(); private Set<Integer> _values; - private int[] _expectedResults; private IndexSegment _indexSegment; private List<IndexSegment> _indexSegments; @@ -119,33 +118,56 @@ public class DistinctCountQueriesTest extends BaseQueriesTest { List<GenericRow> records = new ArrayList<>(NUM_RECORDS); int hashMapCapacity = HashUtil.getHashMapCapacity(MAX_VALUE); _values = new HashSet<>(hashMapCapacity); - Set<Integer> longResultSet = new HashSet<>(hashMapCapacity); - Set<Integer> floatResultSet = new HashSet<>(hashMapCapacity); - Set<Integer> doubleResultSet = 
new HashSet<>(hashMapCapacity); - Set<Integer> stringResultSet = new HashSet<>(hashMapCapacity); - Set<Integer> bytesResultSet = new HashSet<>(hashMapCapacity); - for (int i = 0; i < NUM_RECORDS; i++) { + for (int i = 0; i < NUM_RECORDS - 2; i++) { int value = RANDOM.nextInt(MAX_VALUE); + _values.add(value); GenericRow record = new GenericRow(); record.putValue(INT_COLUMN, value); - _values.add(Integer.hashCode(value)); record.putValue(LONG_COLUMN, (long) value); - longResultSet.add(Long.hashCode(value)); record.putValue(FLOAT_COLUMN, (float) value); - floatResultSet.add(Float.hashCode(value)); record.putValue(DOUBLE_COLUMN, (double) value); - doubleResultSet.add(Double.hashCode(value)); String stringValue = Integer.toString(value); record.putValue(STRING_COLUMN, stringValue); - stringResultSet.add(stringValue.hashCode()); // NOTE: Create fixed-length bytes so that dictionary can be generated byte[] bytesValue = StringUtil.encodeUtf8(StringUtils.leftPad(stringValue, 3, '0')); record.putValue(BYTES_COLUMN, bytesValue); - bytesResultSet.add(Arrays.hashCode(bytesValue)); records.add(record); } - _expectedResults = - new int[]{_values.size(), longResultSet.size(), floatResultSet.size(), doubleResultSet.size(), stringResultSet.size(), bytesResultSet.size()}; + + // Intentionally put 2 extra records with hash collision values (except for INT_COLUMN and FLOAT_COLUMN which are + // impossible for hash collision) + long long1 = 0xFFFFFFFFL; + long long2 = 0xF00000000FFFFFFFL; + assertEquals(Long.hashCode(long1), Long.hashCode(long2)); + double double1 = Double.longBitsToDouble(long1); + double double2 = Double.longBitsToDouble(long2); + assertEquals(Double.hashCode(double1), Double.hashCode(double2)); + String string1 = new String(new char[]{32}); + String string2 = new String(new char[]{1, 1}); + assertEquals(string1.hashCode(), string2.hashCode()); + byte[] bytes1 = {0, 1, 1}; + byte[] bytes2 = {0, 0, 32}; + assertEquals(Arrays.hashCode(bytes1), Arrays.hashCode(bytes2)); 
+ + _values.add(MAX_VALUE); + GenericRow record1 = new GenericRow(); + record1.putValue(INT_COLUMN, MAX_VALUE); + record1.putValue(LONG_COLUMN, long1); + record1.putValue(FLOAT_COLUMN, (float) MAX_VALUE); + record1.putValue(DOUBLE_COLUMN, double1); + record1.putValue(STRING_COLUMN, string1); + record1.putValue(BYTES_COLUMN, bytes1); + records.add(record1); + + _values.add(MAX_VALUE + 1); + GenericRow record2 = new GenericRow(); + record2.putValue(INT_COLUMN, MAX_VALUE + 1); + record2.putValue(LONG_COLUMN, 0xF00000000FFFFFFFL); + record2.putValue(FLOAT_COLUMN, (float) (MAX_VALUE + 1)); + record2.putValue(DOUBLE_COLUMN, double2); + record2.putValue(STRING_COLUMN, string2); + record2.putValue(BYTES_COLUMN, bytes2); + records.add(record2); SegmentGeneratorConfig segmentGeneratorConfig = new SegmentGeneratorConfig(TABLE_CONFIG, SCHEMA); segmentGeneratorConfig.setTableName(RAW_TABLE_NAME); @@ -186,13 +208,13 @@ public class DistinctCountQueriesTest extends BaseQueriesTest { assertNotNull(aggregationResultWithFilter); assertEquals(aggregationResult, aggregationResultWithFilter); for (int i = 0; i < 6; i++) { - assertEquals(((Set<Integer>) aggregationResult.get(i)).size(), _expectedResults[i]); + assertEquals(((Set) aggregationResult.get(i)).size(), _values.size()); } // Inter segments String[] expectedResults = new String[6]; for (int i = 0; i < 6; i++) { - expectedResults[i] = Integer.toString(_expectedResults[i]); + expectedResults[i] = Integer.toString(_values.size()); } BrokerResponseNative brokerResponse = getBrokerResponseForPqlQuery(query); QueriesTestUtils --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org