This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
     new fcaebab69a  [Multi-stage] Support null in aggregate and filter (#10799)
fcaebab69a is described below

commit fcaebab69a8a9d129cfae0130ab67c88815630bf
Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com>
AuthorDate: Thu May 25 10:17:48 2023 -0700

    [Multi-stage] Support null in aggregate and filter (#10799)
---
 .../pinot/query/planner/logical/RexExpression.java |  18 +++--
 .../partitioning/FieldSelectionKeySelector.java    |   5 +-
 .../query/runtime/operator/AggregateOperator.java  |  91 +++++++++++++++-------
 .../runtime/operator/WindowAggregateOperator.java  |  10 +--
 .../runtime/operator/operands/FilterOperand.java   |  45 ++++++++---
 .../operator/operands/TransformOperand.java        |  77 ++----------------
 .../runtime/operator/utils/AggregationUtils.java   |  84 ++++++++++++++------
 .../operator/utils/FunctionInvokeUtils.java        |   6 +-
 .../runtime/operator/AggregateOperatorTest.java    |   9 ++-
 .../operator/WindowAggregateOperatorTest.java      |   4 +-
 .../src/test/resources/queries/NullHandling.json   |  51 ++++++++++++
 .../test/resources/queries/WindowFunctions.json    |  16 ++--
 12 files changed, 255 insertions(+), 161 deletions(-)

diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
index 9b879ab779..ab78924548 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
@@ -82,19 +82,23 @@ public interface RexExpression {
         operands);
   }
-  static Object toRexValue(FieldSpec.DataType dataType, Comparable value) {
+  @Nullable
+  static Object toRexValue(FieldSpec.DataType dataType, @Nullable Comparable<?> value) {
+    if (value == null) {
+      return null;
+    }
     switch (dataType) {
       case INT:
-        return value == null ? 0 : ((BigDecimal) value).intValue();
+        return ((BigDecimal) value).intValue();
       case LONG:
-        return value == null ? 0L : ((BigDecimal) value).longValue();
+        return ((BigDecimal) value).longValue();
       case FLOAT:
-        return value == null ? 0f : ((BigDecimal) value).floatValue();
-      case BIG_DECIMAL:
+        return ((BigDecimal) value).floatValue();
       case DOUBLE:
-        return value == null ? 0d : ((BigDecimal) value).doubleValue();
+      case BIG_DECIMAL:
+        return ((BigDecimal) value).doubleValue();
       case STRING:
-        return value == null ? "" : ((NlsString) value).getValue();
+        return ((NlsString) value).getValue();
       default:
         return value;
     }
   }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/FieldSelectionKeySelector.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/FieldSelectionKeySelector.java
index 235c5bd491..b23b34433b 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/FieldSelectionKeySelector.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/FieldSelectionKeySelector.java
@@ -85,7 +85,10 @@ public class FieldSelectionKeySelector implements KeySelector<Object[], Object[]
     // TODO: consider better hashing algorithms than hashCode sum, such as XOR'ing
     int hashCode = 0;
     for (int columnIndex : _columnIndices) {
-      hashCode += input[columnIndex].hashCode();
+      Object value = input[columnIndex];
+      if (value != null) {
+        hashCode += value.hashCode();
+      }
     }

     // return a positive number because this is used directly to modulo-index
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index 8835980ace..d445a6a2ec 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -40,12 +40,9 @@ import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
 import org.apache.pinot.query.runtime.operator.utils.AggregationUtils;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
 import org.apache.pinot.segment.local.customobject.PinotFourthMoment;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;

 /**
- *
  * AggregateOperator is used to aggregate values over a set of group by keys.
  * Output data will be in the format of [group by key, aggregate result1, ... aggregate resultN]
  * Currently, we only support SUM/COUNT/MIN/MAX aggregation.
@@ -60,7 +57,6 @@ import org.slf4j.LoggerFactory; */ public class AggregateOperator extends MultiStageOperator { private static final String EXPLAIN_NAME = "AGGREGATE_OPERATOR"; - private static final Logger LOGGER = LoggerFactory.getLogger(AggregateOperator.class); private final MultiStageOperator _inputOperator; @@ -177,7 +173,7 @@ public class AggregateOperator extends MultiStageOperator { private TransferableBlock constructEmptyAggResultBlock() { Object[] row = new Object[_aggCalls.size()]; for (int i = 0; i < _accumulators.length; i++) { - row[i] = _accumulators[i].getMerger().initialize(null, _accumulators[i].getDataType()); + row[i] = _accumulators[i].getMerger().init(null, _accumulators[i].getDataType()); } return new TransferableBlock(Collections.singletonList(row), _resultSchema, DataBlock.Type.ROW); } @@ -220,43 +216,76 @@ public class AggregateOperator extends MultiStageOperator { private static class MergeFourthMomentNumeric implements AggregationUtils.Merger { + @Nullable @Override - public Object merge(Object left, Object right) { - ((PinotFourthMoment) left).increment(((Number) right).doubleValue()); - return left; + public PinotFourthMoment init(@Nullable Object value, DataSchema.ColumnDataType dataType) { + if (value == null) { + return null; + } + PinotFourthMoment moment = new PinotFourthMoment(); + moment.increment(((Number) value).doubleValue()); + return moment; } + @Nullable @Override - public Object initialize(Object other, DataSchema.ColumnDataType dataType) { - PinotFourthMoment moment = new PinotFourthMoment(); - moment.increment(((Number) other).doubleValue()); + public PinotFourthMoment merge(@Nullable Object agg, @Nullable Object value) { + PinotFourthMoment moment = (PinotFourthMoment) agg; + if (value == null) { + return moment; + } + if (moment == null) { + moment = new PinotFourthMoment(); + } + moment.increment(((Number) value).doubleValue()); return moment; } } private static class MergeFourthMomentObject implements AggregationUtils.Merger { + @Nullable @Override - public Object merge(Object left, Object right) { - PinotFourthMoment agg = (PinotFourthMoment) left; - agg.combine((PinotFourthMoment) right); - return agg; + public PinotFourthMoment merge(@Nullable Object agg, @Nullable Object value) { + PinotFourthMoment moment1 = (PinotFourthMoment) agg; + PinotFourthMoment moment2 = (PinotFourthMoment) value; + if (moment1 == null) { + return moment2; + } + if (moment2 == null) { + return moment1; + } + moment1.combine(moment2); + return moment1; } } + // TODO: this casts everything to `Set<?>` instead of using the primitive version (e.g. IntSet) private static class MergeCountDistinctScalars implements AggregationUtils.Merger { - @SuppressWarnings("unchecked") + + @Nullable @Override - public Object merge(Object agg, Object value) { - // TODO: this casts everything to `Set<?>` instead of using the primitive version (e.g. 
IntSet) - ((Set<Object>) agg).add(value); - return agg; + public Set<Object> init(@Nullable Object value, DataSchema.ColumnDataType dataType) { + if (value == null) { + return null; + } + Set<Object> set = new ObjectOpenHashSet<>(); + set.add(value); + return set; } + @SuppressWarnings("unchecked") + @Nullable @Override - public Object initialize(Object other, DataSchema.ColumnDataType dataType) { - ObjectOpenHashSet<Object> set = new ObjectOpenHashSet<>(); - set.add(other); + public Set<Object> merge(@Nullable Object agg, @Nullable Object value) { + Set<Object> set = (Set<Object>) agg; + if (value == null) { + return set; + } + if (set == null) { + set = new ObjectOpenHashSet<>(); + } + set.add(value); return set; } } @@ -264,11 +293,19 @@ public class AggregateOperator extends MultiStageOperator { private static class MergeCountDistinctSets implements AggregationUtils.Merger { @SuppressWarnings("unchecked") + @Nullable @Override - public Object merge(Object agg, Object value) { - // TODO: this casts everything to `Set<?>` instead of using the primitive version (e.g. IntSet) - ((Set<Object>) agg).addAll((Set<Object>) value); - return agg; + public Set<Object> merge(@Nullable Object agg, @Nullable Object value) { + Set<Object> set1 = (Set<Object>) agg; + Set<Object> set2 = (Set<Object>) value; + if (set1 == null) { + return set2; + } + if (set2 == null) { + return set1; + } + set1.addAll(set2); + return set1; } } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java index 6b29cdc82a..9791a72d42 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java @@ -402,13 +402,13 @@ public class WindowAggregateOperator extends MultiStageOperator { private static class MergeRowNumber implements AggregationUtils.Merger { @Override - public Object initialize(Object other, DataSchema.ColumnDataType dataType) { + public Long init(@Nullable Object value, DataSchema.ColumnDataType dataType) { return 1L; } @Override - public Object merge(Object left, Object right) { - return ((Number) left).longValue() + 1L; + public Long merge(Object agg, @Nullable Object value) { + return (long) agg + 1; } } @@ -440,7 +440,7 @@ public class WindowAggregateOperator extends MultiStageOperator { Object previousRowOutputValue) { Object value = _inputRef == -1 ? 
_literal : row[_inputRef]; if (previousPartitionKey == null || !currentPartitionKey.equals(previousPartitionKey)) { - return _merger.initialize(currentPartitionKey, _dataType); + return _merger.init(currentPartitionKey, _dataType); } else { return _merger.merge(previousRowOutputValue, value); } @@ -466,7 +466,7 @@ public class WindowAggregateOperator extends MultiStageOperator { _orderByResults.putIfAbsent(key, new OrderKeyResult()); if (currentRes == null) { - _orderByResults.get(key).addOrderByResult(orderKey, _merger.initialize(value, _dataType)); + _orderByResults.get(key).addOrderByResult(orderKey, _merger.init(value, _dataType)); } else { Object mergedResult; if (orderKey.equals(previousOrderKeyIfPresent)) { diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java index fbd8f95bd2..ab69ce67b2 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java @@ -18,12 +18,13 @@ */ package org.apache.pinot.query.runtime.operator.operands; - import com.google.common.base.Preconditions; import java.util.ArrayList; import java.util.List; +import java.util.function.IntPredicate; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.query.runtime.operator.utils.FunctionInvokeUtils; import org.apache.pinot.spi.utils.BooleanUtils; @@ -100,15 +101,16 @@ public abstract class FilterOperand extends TransformOperand { } } - public static abstract class Predicate extends FilterOperand { - protected final TransformOperand _lhs; - protected final TransformOperand _rhs; - protected final boolean _requireCasting; - protected final DataSchema.ColumnDataType _commonCastType; + public static class Predicate extends FilterOperand { + private final TransformOperand _lhs; + private final TransformOperand _rhs; + private final IntPredicate _comparisonResultPredicate; + private final boolean _requireCasting; + private final DataSchema.ColumnDataType _commonCastType; /** * Predicate constructor also resolve data type, - * since we don't have a exhausted list of filter function signatures. we rely on type casting. + * since we don't have an exhausted list of filter function signatures. we rely on type casting. * * <ul> * <li>if both RHS and LHS has null data type, exception occurs.</li> @@ -116,22 +118,22 @@ public abstract class FilterOperand extends TransformOperand { * <li>if either side supertype of the other, we use the super type.</li> * <li>if we can't resolve a common data type, exception occurs.</li> * </ul> - * - * */ - public Predicate(List<RexExpression> functionOperands, DataSchema inputDataSchema) { + public Predicate(List<RexExpression> functionOperands, DataSchema inputDataSchema, + IntPredicate comparisonResultPredicate) { Preconditions.checkState(functionOperands.size() == 2, "Expected 2 function ops for Predicate but got:" + functionOperands.size()); _lhs = TransformOperand.toTransformOperand(functionOperands.get(0), inputDataSchema); _rhs = TransformOperand.toTransformOperand(functionOperands.get(1), inputDataSchema); + _comparisonResultPredicate = comparisonResultPredicate; // TODO: Correctly throw exception instead of returning null. 
// Currently exception thrown during constructor is not piped back to query dispatcher, thus in order to // avoid silent failure, we deliberately set to null here, make the exception thrown during data processing. // TODO: right now all the numeric columns are still doing conversion b/c even if the inputDataSchema asked for // one of the number type, it might not contain the exact type in the payload. - if (_lhs._resultType == null || _lhs._resultType == DataSchema.ColumnDataType.OBJECT - || _rhs._resultType == null || _rhs._resultType == DataSchema.ColumnDataType.OBJECT) { + if (_lhs._resultType == null || _lhs._resultType == DataSchema.ColumnDataType.OBJECT || _rhs._resultType == null + || _rhs._resultType == DataSchema.ColumnDataType.OBJECT) { _requireCasting = false; _commonCastType = null; } else if (_lhs._resultType.isSuperTypeOf(_rhs._resultType)) { @@ -145,5 +147,24 @@ public abstract class FilterOperand extends TransformOperand { _commonCastType = null; } } + + @SuppressWarnings({"rawtypes", "unchecked"}) + @Override + public Boolean apply(Object[] row) { + Comparable v1 = (Comparable) _lhs.apply(row); + if (v1 == null) { + return false; + } + Comparable v2 = (Comparable) _rhs.apply(row); + if (v2 == null) { + return false; + } + if (_requireCasting) { + v1 = (Comparable) FunctionInvokeUtils.convert(v1, _commonCastType); + v2 = (Comparable) FunctionInvokeUtils.convert(v2, _commonCastType); + assert v1 != null && v2 != null; + } + return _comparisonResultPredicate.test(v1.compareTo(v2)); + } } } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperand.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperand.java index 7c34c8e46c..88eb4f37de 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperand.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperand.java @@ -18,12 +18,11 @@ */ package org.apache.pinot.query.runtime.operator.operands; - import com.google.common.base.Preconditions; import java.util.List; +import javax.annotation.Nullable; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.query.planner.logical.RexExpression; -import org.apache.pinot.query.runtime.operator.utils.FunctionInvokeUtils; import org.apache.pinot.query.runtime.operator.utils.OperatorUtils; @@ -43,7 +42,6 @@ public abstract class TransformOperand { } } - @SuppressWarnings({"ConstantConditions", "rawtypes", "unchecked"}) private static TransformOperand toTransformOperand(RexExpression.FunctionCall functionCall, DataSchema inputDataSchema) { final List<RexExpression> functionOperands = functionCall.getFunctionOperands(); @@ -65,77 +63,17 @@ public abstract class TransformOperand { "BOOL / IS_TRUE takes one argument, passed in argument size:" + operandSize); return new FilterOperand.True(functionOperands.get(0), inputDataSchema); case "equals": - return new FilterOperand.Predicate(functionOperands, inputDataSchema) { - @Override - public Boolean apply(Object[] row) { - if (_requireCasting) { - return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo( - FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) == 0; - } else { - return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) == 0; - } - } - }; + return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v == 0); case "notEquals": - return new 
FilterOperand.Predicate(functionOperands, inputDataSchema) { - @Override - public Boolean apply(Object[] row) { - if (_requireCasting) { - return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo( - FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) != 0; - } else { - return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) != 0; - } - } - }; + return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v != 0); case "greaterThan": - return new FilterOperand.Predicate(functionOperands, inputDataSchema) { - @Override - public Boolean apply(Object[] row) { - if (_requireCasting) { - return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo( - FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) > 0; - } else { - return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) > 0; - } - } - }; + return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v > 0); case "greaterThanOrEqual": - return new FilterOperand.Predicate(functionOperands, inputDataSchema) { - @Override - public Boolean apply(Object[] row) { - if (_requireCasting) { - return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo( - FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) >= 0; - } else { - return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) >= 0; - } - } - }; + return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v >= 0); case "lessThan": - return new FilterOperand.Predicate(functionOperands, inputDataSchema) { - @Override - public Boolean apply(Object[] row) { - if (_requireCasting) { - return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo( - FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) < 0; - } else { - return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) < 0; - } - } - }; + return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v < 0); case "lessThanOrEqual": - return new FilterOperand.Predicate(functionOperands, inputDataSchema) { - @Override - public Boolean apply(Object[] row) { - if (_requireCasting) { - return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo( - FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) <= 0; - } else { - return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) <= 0; - } - } - }; + return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v <= 0); default: return new FunctionOperand(functionCall, inputDataSchema); } @@ -149,5 +87,6 @@ public abstract class TransformOperand { return _resultType; } + @Nullable public abstract Object apply(Object[] row); } diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java index 81bd7dea0c..e3e466d7d2 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java @@ -24,6 +24,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Function; +import javax.annotation.Nullable; import org.apache.pinot.common.utils.DataSchema; import org.apache.pinot.core.data.table.Key; import 
org.apache.pinot.query.planner.logical.RexExpression; @@ -38,7 +39,6 @@ import org.apache.pinot.spi.data.FieldSpec; * <p>Accumulation is used by {@code WindowAggregateOperator} and {@code AggregateOperator}. */ public class AggregationUtils { - private AggregationUtils() { } @@ -54,54 +54,92 @@ public class AggregationUtils { return new Key(new Object[0]); } - private static Object mergeSum(Object left, Object right) { - return ((Number) left).doubleValue() + ((Number) right).doubleValue(); + // TODO: Use the correct type for SUM/MIN/MAX instead of always using double + + @Nullable + private static Object mergeSum(@Nullable Object agg, @Nullable Object value) { + if (agg == null) { + return value; + } + if (value == null) { + return agg; + } + return ((Number) agg).doubleValue() + ((Number) value).doubleValue(); } - private static Object mergeMin(Object left, Object right) { - return Math.min(((Number) left).doubleValue(), ((Number) right).doubleValue()); + @Nullable + private static Object mergeMin(@Nullable Object agg, @Nullable Object value) { + if (agg == null) { + return value; + } + if (value == null) { + return agg; + } + return Math.min(((Number) agg).doubleValue(), ((Number) value).doubleValue()); } - private static Object mergeMax(Object left, Object right) { - return Math.max(((Number) left).doubleValue(), ((Number) right).doubleValue()); + @Nullable + private static Object mergeMax(@Nullable Object agg, @Nullable Object value) { + if (agg == null) { + return value; + } + if (value == null) { + return agg; + } + return Math.max(((Number) agg).doubleValue(), ((Number) value).doubleValue()); } - private static Boolean mergeBoolAnd(Object left, Object right) { - return ((Boolean) left) && ((Boolean) right); + @Nullable + private static Boolean mergeBoolAnd(@Nullable Object agg, @Nullable Object value) { + if (agg == null) { + return (Boolean) value; + } + if (value == null) { + return (Boolean) agg; + } + return ((Boolean) agg) & ((Boolean) value); } - private static Boolean mergeBoolOr(Object left, Object right) { - return ((Boolean) left) || ((Boolean) right); + @Nullable + private static Boolean mergeBoolOr(@Nullable Object agg, @Nullable Object value) { + if (agg == null) { + return (Boolean) value; + } + if (value == null) { + return (Boolean) agg; + } + return ((Boolean) agg) | ((Boolean) value); } private static class MergeCounts implements AggregationUtils.Merger { @Override - public Object initialize(Object other, DataSchema.ColumnDataType dataType) { - return other == null ? 0 : 1; + public Long init(@Nullable Object value, DataSchema.ColumnDataType dataType) { + return value == null ? 0L : 1L; } @Override - public Object merge(Object left, Object right) { - return ((Number) left).doubleValue() + (right == null ? 0 : 1); + public Long merge(Object agg, @Nullable Object value) { + return value == null ? (long) agg : (long) agg + 1; } } public interface Merger { + /** - * Initializes the merger based on the first input + * Initializes the merger based on the column data type and first value. */ - default Object initialize(Object other, DataSchema.ColumnDataType dataType) { - // TODO: Initialize as a double so that if only one row is returned it matches the type when many rows are - // returned - return other == null ? 
dataType.getNullPlaceholder() : other; + @Nullable + default Object init(@Nullable Object value, DataSchema.ColumnDataType dataType) { + return value; } /** - * Merges the existing aggregate (the result of {@link #initialize(Object, DataSchema.ColumnDataType)}) with + * Merges the existing aggregate (the result of {@link #init(Object, DataSchema.ColumnDataType)}) with * the new value coming in (which may be an aggregate in and of itself). */ - Object merge(Object agg, Object value); + @Nullable + Object merge(@Nullable Object agg, @Nullable Object value); } /** @@ -169,7 +207,7 @@ public class AggregationUtils { Object value = _inputRef == -1 ? _literal : row[_inputRef]; if (currentRes == null) { - _results.put(key, _merger.initialize(value, _dataType)); + _results.put(key, _merger.init(value, _dataType)); } else { Object mergedResult = _merger.merge(currentRes, value); _results.put(key, mergedResult); diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java index de26cdfab4..851748d2b2 100644 --- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java +++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java @@ -18,13 +18,12 @@ */ package org.apache.pinot.query.runtime.operator.utils; +import javax.annotation.Nullable; import org.apache.pinot.common.utils.DataSchema; public class FunctionInvokeUtils { - private FunctionInvokeUtils() { - // do not instantiate. } /** @@ -35,7 +34,8 @@ public class FunctionInvokeUtils { * @param columnDataType desired column data type * @return converted entry */ - public static Object convert(Object inputObj, DataSchema.ColumnDataType columnDataType) { + @Nullable + public static Object convert(@Nullable Object inputObj, DataSchema.ColumnDataType columnDataType) { if (columnDataType.isNumber() && columnDataType != DataSchema.ColumnDataType.BIG_DECIMAL) { return inputObj == null ? 
null : columnDataType.convert(inputObj); } else { diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java index fce25e7e3f..bda5086a40 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java @@ -21,6 +21,7 @@ package org.apache.pinot.query.runtime.operator; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Arrays; +import java.util.Collections; import java.util.List; import org.apache.calcite.sql.SqlKind; import org.apache.pinot.common.utils.DataSchema; @@ -201,7 +202,7 @@ public class AggregateOperatorTest { AggregationUtils.Merger merger = Mockito.mock(AggregationUtils.Merger.class); Mockito.when(merger.merge(Mockito.any(), Mockito.any())).thenReturn(12d); - Mockito.when(merger.initialize(Mockito.any(), Mockito.any())).thenReturn(1d); + Mockito.when(merger.init(Mockito.any(), Mockito.any())).thenReturn(1d); DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE}); AggregateOperator operator = new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, calls, group, inSchema, @@ -213,7 +214,7 @@ public class AggregateOperatorTest { // Then: // should call merger twice, one from second row in first block and two from the first row // in second block - Mockito.verify(merger, Mockito.times(1)).initialize(Mockito.any(), Mockito.any()); + Mockito.verify(merger, Mockito.times(1)).init(Mockito.any(), Mockito.any()); Mockito.verify(merger, Mockito.times(2)).merge(Mockito.any(), Mockito.any()); Assert.assertEquals(resultBlock.getContainer().get(0), new Object[]{1, 12d}, "Expected two columns (group by key, agg value)"); @@ -226,8 +227,8 @@ public class AggregateOperatorTest { RexExpression.FunctionCall agg = getSum(new RexExpression.InputRef(0)); DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, INT}); AggregateOperator sum0GroupBy1 = new AggregateOperator(OperatorTestUtil.getDefaultContext(), upstreamOperator, - OperatorTestUtil.getDataSchema(OperatorTestUtil.OP_1), Arrays.asList(agg), - Arrays.asList(new RexExpression.InputRef(1)), inSchema); + OperatorTestUtil.getDataSchema(OperatorTestUtil.OP_1), Collections.singletonList(agg), + Collections.singletonList(new RexExpression.InputRef(1)), inSchema); TransferableBlock result = sum0GroupBy1.getNextBlock(); while (result.isNoOpBlock()) { result = sum0GroupBy1.getNextBlock(); diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java index 817c7239d6..fc54b8941e 100644 --- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java +++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java @@ -319,7 +319,7 @@ public class WindowAggregateOperatorTest { AggregationUtils.Merger merger = Mockito.mock(AggregationUtils.Merger.class); Mockito.when(merger.merge(Mockito.any(), Mockito.any())).thenReturn(12d); - Mockito.when(merger.initialize(Mockito.any(), Mockito.any())).thenReturn(1d); + 
Mockito.when(merger.init(Mockito.any(), Mockito.any())).thenReturn(1d); DataSchema outSchema = new DataSchema(new String[]{"group", "arg", "sum"}, new DataSchema.ColumnDataType[]{INT, INT, DOUBLE}); WindowAggregateOperator operator = @@ -334,7 +334,7 @@ public class WindowAggregateOperatorTest { // Then: // should call merger twice, one from second row in first block and two from the first row // in second block - Mockito.verify(merger, Mockito.times(1)).initialize(Mockito.any(), Mockito.any()); + Mockito.verify(merger, Mockito.times(1)).init(Mockito.any(), Mockito.any()); Mockito.verify(merger, Mockito.times(2)).merge(Mockito.any(), Mockito.any()); Assert.assertEquals(resultBlock.getContainer().get(0), new Object[]{1, 1, 12d}, "Expected three columns (original two columns, agg literal value)"); diff --git a/pinot-query-runtime/src/test/resources/queries/NullHandling.json b/pinot-query-runtime/src/test/resources/queries/NullHandling.json new file mode 100644 index 0000000000..f51701317c --- /dev/null +++ b/pinot-query-runtime/src/test/resources/queries/NullHandling.json @@ -0,0 +1,51 @@ +{ + "null_on_intermediate": { + "tables": { + "tbl1" : { + "schema": [ + {"name": "strCol1", "type": "STRING"}, + {"name": "intCol1", "type": "INT"}, + {"name": "strCol2", "type": "STRING"} + ], + "inputs": [ + ["foo", 1, "foo"], + ["bar", 2, "alice"] + ] + }, + "tbl2" : { + "schema": [ + {"name": "strCol1", "type": "STRING"}, + {"name": "strCol2", "type": "STRING"}, + {"name": "intCol1", "type": "INT"}, + {"name": "doubleCol1", "type": "DOUBLE"} + ], + "inputs": [ + ["foo", "bob", 3, 3.1416], + ["alice", "alice", 4, 2.7183] + ] + } + }, + "queries": [ + { + "description": "LEFT JOIN and FILTER", + "sql": "SELECT {tbl1}.strCol2, {tbl2}.doubleCol1 IS NULL OR {tbl1}.intCol1 > 3 AS boolFlag FROM {tbl1} LEFT OUTER JOIN {tbl2} ON {tbl1}.strCol1 = {tbl2}.strCol1" + }, + { + "description": "LEFT JOIN and TRANSFORM", + "sql": "SELECT {tbl1}.strCol2, {tbl1}.intCol1 * {tbl2}.doubleCol1 + {tbl2}.intCol1 FROM {tbl1} LEFT OUTER JOIN {tbl2} ON {tbl1}.strCol1 = {tbl2}.strCol1" + }, + { + "description": "LEFT JOIN and AGGREGATE", + "sql": "SELECT COUNT({tbl2}.intCol1), MIN({tbl2}.intCol1), MAX({tbl2}.doubleCol1), SUM({tbl2}.doubleCol1) FROM {tbl1} LEFT OUTER JOIN {tbl2} ON {tbl1}.strCol1 = {tbl2}.strCol1" + }, + { + "description": "LEFT JOIN and GROUP BY", + "sql": "SELECT {tbl1}.strCol2, {tbl2}.intCol1, COUNT(*) FROM {tbl1} LEFT OUTER JOIN {tbl2} ON {tbl1}.strCol1 = {tbl2}.strCol1 GROUP BY {tbl1}.strCol2, {tbl2}.intCol1" + }, + { + "description": "LEFT JOIN and GROUP BY with AGGREGATE", + "sql": "SELECT {tbl1}.strCol2, COUNT({tbl2}.intCol1), MIN({tbl2}.intCol1), MAX({tbl2}.doubleCol1), SUM({tbl2}.doubleCol1) FROM {tbl1} LEFT OUTER JOIN {tbl2} ON {tbl1}.strCol1 = {tbl2}.strCol1 GROUP BY {tbl1}.strCol2" + } + ] + } +} \ No newline at end of file diff --git a/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json b/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json index 5caa982bec..afc989fded 100644 --- a/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json +++ b/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json @@ -572,7 +572,7 @@ "description": "Single empty OVER() with select col and filter which matches no rows in a sub-query and outer query with aggregation on that column", "sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER() as count FROM {tbl} WHERE string_col = 'a' AND bool_col = false AND int_col > 200)", "outputs": [ - [0] + 
[null] ] }, { @@ -580,7 +580,7 @@ "sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER(ORDER BY string_col) as count FROM {tbl} WHERE string_col = 'a' AND bool_col = false AND int_col > 200)", "keepOutputRowOrder": true, "outputs": [ - [0] + [null] ] }, { @@ -1335,7 +1335,7 @@ "description": "Multiple empty OVER()s with select col and filter which matches no rows in a sub-query and outer query with aggregation on that column", "sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER() as count, MIN(double_col) OVER() as min FROM {tbl} WHERE string_col = 'a' AND bool_col != false AND int_col > 200)", "outputs": [ - [0] + [null] ] }, { @@ -1343,7 +1343,7 @@ "sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER(ORDER BY string_col) as count, MIN(double_col) OVER(ORDER BY string_col) as min FROM {tbl} WHERE string_col = 'a' AND bool_col != false AND int_col > 200)", "keepOutputRowOrder": true, "outputs": [ - [0] + [null] ] }, { @@ -2477,7 +2477,7 @@ "description": "Single OVER(PARTITION BY) with select col and filter which matches no rows in a sub-query and outer query with aggregation on that column", "sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER(PARTITION BY string_col) as count FROM {tbl} WHERE string_col = 'a' AND bool_col = false AND int_col > 200)", "outputs": [ - [0] + [null] ] }, { @@ -2486,7 +2486,7 @@ "comments": "Cannot enforce a global ordering as partitions aren't ordered, just keys within a partition are", "keepOutputRowOrder": false, "outputs": [ - [0] + [null] ] }, { @@ -3445,7 +3445,7 @@ "description": "Multiple OVER(PARTITION BY)s with select col and filter which matches no rows in a sub-query and outer query with aggregation on that column", "sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER(PARTITION BY string_col) as count, AVG(int_col) OVER(PARTITION BY string_col) as avg FROM {tbl} WHERE string_col = 'a' AND bool_col = false AND int_col > 200)", "outputs": [ - [0] + [null] ] }, { @@ -3454,7 +3454,7 @@ "comments": "Cannot enforce a global ordering as partitions aren't ordered, just keys within a partition are", "keepOutputRowOrder": false, "outputs": [ - [0] + [null] ] }, { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org
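
Illustrative note (not code from this commit, and not the actual Pinot classes): the reworked
AggregationUtils.Merger contract in the diff above makes both init and merge null-aware, so the
first non-null value seeds the aggregate and nulls on either side of a merge are skipped instead
of being coerced to defaults such as 0 or "". A minimal standalone sketch of that SUM behavior,
with the types simplified (the real Merger#init also receives a DataSchema.ColumnDataType):

    // Standalone sketch of the null-skipping SUM semantics introduced by this commit.
    final class NullAwareSumMerger {
      // The first value seen seeds the aggregate; a null first value keeps the aggregate null.
      Object init(Object value) {
        return value;
      }

      // Nulls on either side are ignored instead of contributing a default value.
      Object merge(Object agg, Object value) {
        if (agg == null) {
          return value;
        }
        if (value == null) {
          return agg;
        }
        return ((Number) agg).doubleValue() + ((Number) value).doubleValue();
      }
    }

Under this contract, SUM over no rows or over an all-null column (for example after the LEFT JOIN
cases in the new NullHandling.json tests) yields null rather than 0, which is why the expected
outputs in WindowFunctions.json change from [0] to [null].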