This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git
The following commit(s) were added to refs/heads/master by this push:
     new 89cd958  Support post-aggregation in SELECT (#5867)
89cd958 is described below

commit 89cd958dbabce88befa7561d3d4a25a9f97afaaa
Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com>
AuthorDate: Tue Aug 18 14:41:59 2020 -0700

    Support post-aggregation in SELECT (#5867)

    Add `PostAggregationHandler` to handle the post-aggregation calculation and column re-ordering for the aggregation result
    Enhance `AggregationDataTableReducer` and `GroupByDataTableReducer` to support post-aggregation in SELECT
---
 .../query/reduce/AggregationDataTableReducer.java  |  79 +++----
 .../core/query/reduce/GroupByDataTableReducer.java | 122 +++--------
 .../core/query/reduce/PostAggregationHandler.java  | 243 +++++++++++++++++++++
 .../core/query/reduce/ResultReducerFactory.java    |   4 +-
 .../query/reduce/PostAggregationHandlerTest.java   | 120 ++++++++++
 .../tests/BaseClusterIntegrationTestSet.java       |  10 +
 6 files changed, 450 insertions(+), 128 deletions(-)

diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java index 855fbe0..3234a30 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/AggregationDataTableReducer.java @@ -28,6 +28,7 @@ import org.apache.pinot.common.response.broker.AggregationResult; import org.apache.pinot.common.response.broker.BrokerResponseNative; import org.apache.pinot.common.response.broker.ResultTable; import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.common.utils.DataTable; import org.apache.pinot.core.query.aggregation.function.AggregationFunction; import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils; @@ -41,12 +42,14 @@ import org.apache.pinot.core.util.QueryOptions; */ @SuppressWarnings({"rawtypes", "unchecked"}) public class AggregationDataTableReducer implements DataTableReducer { + private final QueryContext _queryContext; private final AggregationFunction[] _aggregationFunctions; private final boolean _preserveType; private final boolean _responseFormatSql; - AggregationDataTableReducer(QueryContext queryContext, AggregationFunction[] aggregationFunctions) { - _aggregationFunctions = aggregationFunctions; + AggregationDataTableReducer(QueryContext queryContext) { + _queryContext = queryContext; + _aggregationFunctions = queryContext.getAggregationFunctions(); QueryOptions queryOptions = new QueryOptions(queryContext.getQueryOptions()); _preserveType = queryOptions.isPreserveType(); _responseFormatSql = queryOptions.isResponseFormatSQL(); @@ -63,8 +66,9 @@ public class AggregationDataTableReducer implements DataTableReducer { BrokerMetrics brokerMetrics) { if (dataTableMap.isEmpty()) { if (_responseFormatSql) { - DataSchema finalDataSchema = getResultTableDataSchema(); - brokerResponseNative.setResultTable(new ResultTable(finalDataSchema, Collections.emptyList())); + DataSchema resultTableSchema = + new PostAggregationHandler(_queryContext, getPrePostAggregationDataSchema()).getResultDataSchema(); + brokerResponseNative.setResultTable(new ResultTable(resultTableSchema, Collections.emptyList())); } return; } @@ -75,7 +79,7 @@ public class AggregationDataTableReducer implements DataTableReducer { for (DataTable dataTable :
dataTableMap.values()) { for (int i = 0; i < numAggregationFunctions; i++) { Object intermediateResultToMerge; - DataSchema.ColumnDataType columnDataType = dataSchema.getColumnDataType(i); + ColumnDataType columnDataType = dataSchema.getColumnDataType(i); switch (columnDataType) { case LONG: intermediateResultToMerge = dataTable.getLong(0, i); @@ -97,63 +101,62 @@ public class AggregationDataTableReducer implements DataTableReducer { } } } + Serializable[] finalResults = new Serializable[numAggregationFunctions]; + for (int i = 0; i < numAggregationFunctions; i++) { + finalResults[i] = AggregationFunctionUtils + .getSerializableValue(_aggregationFunctions[i].extractFinalResult(intermediateResults[i])); + } if (_responseFormatSql) { - brokerResponseNative.setResultTable(reduceToResultTable(intermediateResults)); + brokerResponseNative.setResultTable(reduceToResultTable(finalResults)); } else { - brokerResponseNative.setAggregationResults(reduceToAggregationResult(intermediateResults, dataSchema)); + brokerResponseNative.setAggregationResults(reduceToAggregationResults(finalResults, dataSchema)); } } /** * Sets aggregation results into ResultsTable */ - private ResultTable reduceToResultTable(Object[] intermediateResults) { - List<Object[]> rows = new ArrayList<>(1); - int numAggregationFunctions = _aggregationFunctions.length; - Object[] row = new Object[numAggregationFunctions]; - for (int i = 0; i < numAggregationFunctions; i++) { - row[i] = AggregationFunctionUtils - .getSerializableValue(_aggregationFunctions[i].extractFinalResult(intermediateResults[i])); - } - rows.add(row); - - DataSchema finalDataSchema = getResultTableDataSchema(); - return new ResultTable(finalDataSchema, rows); + private ResultTable reduceToResultTable(Object[] finalResults) { + PostAggregationHandler postAggregationHandler = + new PostAggregationHandler(_queryContext, getPrePostAggregationDataSchema()); + DataSchema resultTableSchema = postAggregationHandler.getResultDataSchema(); + Object[] resultRow = postAggregationHandler.getResult(finalResults); + return new ResultTable(resultTableSchema, Collections.singletonList(resultRow)); } /** * Sets aggregation results into AggregationResults */ - private List<AggregationResult> reduceToAggregationResult(Object[] intermediateResults, DataSchema dataSchema) { - // Extract final results and set them into the broker response + private List<AggregationResult> reduceToAggregationResults(Serializable[] finalResults, DataSchema dataSchema) { int numAggregationFunctions = _aggregationFunctions.length; - List<AggregationResult> reducedAggregationResults = new ArrayList<>(numAggregationFunctions); - for (int i = 0; i < numAggregationFunctions; i++) { - Serializable resultValue = AggregationFunctionUtils - .getSerializableValue(_aggregationFunctions[i].extractFinalResult(intermediateResults[i])); - - // Format the value into string if required - if (!_preserveType) { - resultValue = AggregationFunctionUtils.formatValue(resultValue); + List<AggregationResult> aggregationResults = new ArrayList<>(numAggregationFunctions); + if (_preserveType) { + for (int i = 0; i < numAggregationFunctions; i++) { + aggregationResults.add(new AggregationResult(dataSchema.getColumnName(i), finalResults[i])); + } + } else { + // Format the values into strings + for (int i = 0; i < numAggregationFunctions; i++) { + aggregationResults.add( + new AggregationResult(dataSchema.getColumnName(i), AggregationFunctionUtils.formatValue(finalResults[i]))); } - reducedAggregationResults.add(new 
AggregationResult(dataSchema.getColumnName(i), resultValue)); } - return reducedAggregationResults; + return aggregationResults; } /** - * Constructs the data schema for the final results table + * Constructs the DataSchema for the rows before the post-aggregation (SQL mode). */ - private DataSchema getResultTableDataSchema() { + private DataSchema getPrePostAggregationDataSchema() { int numAggregationFunctions = _aggregationFunctions.length; - String[] finalColumnNames = new String[numAggregationFunctions]; - DataSchema.ColumnDataType[] finalColumnDataTypes = new DataSchema.ColumnDataType[numAggregationFunctions]; + String[] columnNames = new String[numAggregationFunctions]; + ColumnDataType[] columnDataTypes = new ColumnDataType[numAggregationFunctions]; for (int i = 0; i < numAggregationFunctions; i++) { AggregationFunction aggregationFunction = _aggregationFunctions[i]; - finalColumnNames[i] = aggregationFunction.getResultColumnName(); - finalColumnDataTypes[i] = aggregationFunction.getFinalResultColumnType(); + columnNames[i] = aggregationFunction.getResultColumnName(); + columnDataTypes[i] = aggregationFunction.getFinalResultColumnType(); } - return new DataSchema(finalColumnNames, finalColumnDataTypes); + return new DataSchema(columnNames, columnDataTypes); } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java index 066b755..ea8e9c7 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/GroupByDataTableReducer.java @@ -18,7 +18,6 @@ */ package org.apache.pinot.core.query.reduce; -import com.google.common.base.Preconditions; import java.io.Serializable; import java.util.ArrayList; import java.util.Collection; @@ -34,6 +33,7 @@ import org.apache.pinot.common.response.broker.BrokerResponseNative; import org.apache.pinot.common.response.broker.GroupByResult; import org.apache.pinot.common.response.broker.ResultTable; import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.common.utils.DataTable; import org.apache.pinot.common.utils.HashUtil; import org.apache.pinot.core.data.table.ConcurrentIndexedTable; @@ -44,7 +44,6 @@ import org.apache.pinot.core.query.aggregation.function.AggregationFunctionUtils import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByTrimmingService; import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator; import org.apache.pinot.core.query.request.context.ExpressionContext; -import org.apache.pinot.core.query.request.context.FunctionContext; import org.apache.pinot.core.query.request.context.QueryContext; import org.apache.pinot.core.transport.ServerRoutingInstance; import org.apache.pinot.core.util.GroupByUtils; @@ -67,10 +66,11 @@ public class GroupByDataTableReducer implements DataTableReducer { private final boolean _responseFormatSql; private final boolean _sqlQuery; - GroupByDataTableReducer(QueryContext queryContext, AggregationFunction[] aggregationFunctions) { + GroupByDataTableReducer(QueryContext queryContext) { _queryContext = queryContext; - _aggregationFunctions = aggregationFunctions; - _numAggregationFunctions = aggregationFunctions.length; + _aggregationFunctions = queryContext.getAggregationFunctions(); + assert _aggregationFunctions != null; + _numAggregationFunctions = 
_aggregationFunctions.length; _groupByExpressions = queryContext.getGroupByExpressions(); assert _groupByExpressions != null; _numGroupByExpressions = _groupByExpressions.size(); @@ -162,103 +162,49 @@ public class GroupByDataTableReducer implements DataTableReducer { */ private void setSQLGroupByInResultTable(BrokerResponseNative brokerResponseNative, DataSchema dataSchema, Collection<DataTable> dataTables) { - DataSchema resultTableSchema = getSQLResultTableSchema(dataSchema); IndexedTable indexedTable = getIndexedTable(dataSchema, dataTables); Iterator<Record> sortedIterator = indexedTable.iterator(); + DataSchema prePostAggregationDataSchema = getPrePostAggregationDataSchema(dataSchema); int limit = _queryContext.getLimit(); List<Object[]> rows = new ArrayList<>(limit); + for (int i = 0; i < limit && sortedIterator.hasNext(); i++) { + Object[] row = sortedIterator.next().getValues(); + for (int j = 0; j < _numAggregationFunctions; j++) { + int valueIndex = j + _numGroupByExpressions; + row[valueIndex] = + AggregationFunctionUtils.getSerializableValue(_aggregationFunctions[j].extractFinalResult(row[valueIndex])); + } + rows.add(row); + } if (_sqlQuery) { // SQL query with SQL group-by mode and response format - // NOTE: For SQL query, need to reorder the columns in the data table based on the select expressions. - - int[] selectExpressionIndexMap = getSelectExpressionIndexMap(); - int numSelectExpressions = selectExpressionIndexMap.length; - String[] columnNames = resultTableSchema.getColumnNames(); - DataSchema.ColumnDataType[] columnDataTypes = resultTableSchema.getColumnDataTypes(); - String[] reorderedColumnNames = new String[numSelectExpressions]; - DataSchema.ColumnDataType[] reorderedColumnDataTypes = new DataSchema.ColumnDataType[numSelectExpressions]; - resultTableSchema = new DataSchema(reorderedColumnNames, reorderedColumnDataTypes); - for (int i = 0; i < numSelectExpressions; i++) { - reorderedColumnNames[i] = columnNames[selectExpressionIndexMap[i]]; - reorderedColumnDataTypes[i] = columnDataTypes[selectExpressionIndexMap[i]]; - } - while (rows.size() < limit && sortedIterator.hasNext()) { - Record nextRecord = sortedIterator.next(); - Object[] values = nextRecord.getValues(); - for (int i = 0; i < _numAggregationFunctions; i++) { - int valueIndex = i + _numGroupByExpressions; - values[valueIndex] = AggregationFunctionUtils - .getSerializableValue(_aggregationFunctions[i].extractFinalResult(values[valueIndex])); - } - Object[] reorderedValues = new Object[numSelectExpressions]; - for (int i = 0; i < numSelectExpressions; i++) { - reorderedValues[i] = values[selectExpressionIndexMap[i]]; - } - rows.add(reorderedValues); - } + PostAggregationHandler postAggregationHandler = + new PostAggregationHandler(_queryContext, prePostAggregationDataSchema); + DataSchema resultTableSchema = postAggregationHandler.getResultDataSchema(); + rows.replaceAll(postAggregationHandler::getResult); + brokerResponseNative.setResultTable(new ResultTable(resultTableSchema, rows)); } else { // PQL query with SQL group-by mode and response format + // NOTE: For PQL query, keep the order of columns as is (group-by expressions followed by aggregations), no need + // to perform post-aggregation. 
- while (rows.size() < limit && sortedIterator.hasNext()) { - Record nextRecord = sortedIterator.next(); - Object[] values = nextRecord.getValues(); - for (int i = 0; i < _numAggregationFunctions; i++) { - int valueIndex = i + _numGroupByExpressions; - values[valueIndex] = AggregationFunctionUtils - .getSerializableValue(_aggregationFunctions[i].extractFinalResult(values[valueIndex])); - } - rows.add(values); - } - } - - brokerResponseNative.setResultTable(new ResultTable(resultTableSchema, rows)); - } - - /** - * Helper method to generate a map from the expression index in the select expressions to the column index in the data - * schema. This map is used to reorder the expressions according to the select expressions. - */ - private int[] getSelectExpressionIndexMap() { - List<ExpressionContext> selectExpressions = _queryContext.getSelectExpressions(); - List<ExpressionContext> groupByExpressions = _queryContext.getGroupByExpressions(); - assert groupByExpressions != null; - int numSelectExpressions = selectExpressions.size(); - int[] selectExpressionIndexMap = new int[numSelectExpressions]; - int aggregationExpressionIndex = _numGroupByExpressions; - for (int i = 0; i < numSelectExpressions; i++) { - ExpressionContext selectExpression = selectExpressions.get(i); - if (selectExpression.getType() == ExpressionContext.Type.FUNCTION - && selectExpression.getFunction().getType() == FunctionContext.Type.AGGREGATION) { - selectExpressionIndexMap[i] = aggregationExpressionIndex++; - } else { - int indexInGroupByExpressions = groupByExpressions.indexOf(selectExpression); - Preconditions.checkState(indexInGroupByExpressions >= 0, - "Select expression: %s is not an aggregation expression and not contained in the group-by expressions"); - selectExpressionIndexMap[i] = indexInGroupByExpressions; - } + brokerResponseNative.setResultTable(new ResultTable(prePostAggregationDataSchema, rows)); } - return selectExpressionIndexMap; } /** - * Constructs the final result table schema for sql mode execution - * The data type for the aggregations needs to be taken from the final result's data type + * Constructs the DataSchema for the rows before the post-aggregation (SQL mode). 
*/ - private DataSchema getSQLResultTableSchema(DataSchema dataSchema) { - String[] columns = dataSchema.getColumnNames(); - DataSchema.ColumnDataType[] finalColumnDataTypes = new DataSchema.ColumnDataType[_numColumns]; - int aggIdx = 0; - for (int i = 0; i < _numColumns; i++) { - if (i < _numGroupByExpressions) { - finalColumnDataTypes[i] = dataSchema.getColumnDataType(i); - } else { - finalColumnDataTypes[i] = _aggregationFunctions[aggIdx].getFinalResultColumnType(); - aggIdx++; - } + private DataSchema getPrePostAggregationDataSchema(DataSchema dataSchema) { + String[] columnNames = dataSchema.getColumnNames(); + ColumnDataType[] columnDataTypes = new ColumnDataType[_numColumns]; + System.arraycopy(dataSchema.getColumnDataTypes(), 0, columnDataTypes, 0, _numGroupByExpressions); + for (int i = 0; i < _numAggregationFunctions; i++) { + columnDataTypes[i + _numGroupByExpressions] = _aggregationFunctions[i].getFinalResultColumnType(); } - return new DataSchema(columns, finalColumnDataTypes); + return new DataSchema(columnNames, columnDataTypes); } private IndexedTable getIndexedTable(DataSchema dataSchema, Collection<DataTable> dataTables) { @@ -268,7 +214,7 @@ public class GroupByDataTableReducer implements DataTableReducer { for (DataTable dataTable : dataTables) { BiFunction[] functions = new BiFunction[_numColumns]; for (int i = 0; i < _numColumns; i++) { - DataSchema.ColumnDataType columnDataType = dataSchema.getColumnDataType(i); + ColumnDataType columnDataType = dataSchema.getColumnDataType(i); BiFunction<Integer, Integer, Object> function; switch (columnDataType) { case INT: @@ -394,10 +340,10 @@ public class GroupByDataTableReducer implements DataTableReducer { */ private DataSchema getPQLResultTableSchema(AggregationFunction aggregationFunction) { String[] columnNames = new String[_numColumns]; - DataSchema.ColumnDataType[] columnDataTypes = new DataSchema.ColumnDataType[_numColumns]; + ColumnDataType[] columnDataTypes = new ColumnDataType[_numColumns]; for (int i = 0; i < _numGroupByExpressions; i++) { columnNames[i] = _groupByExpressions.get(i).toString(); - columnDataTypes[i] = DataSchema.ColumnDataType.STRING; + columnDataTypes[i] = ColumnDataType.STRING; } columnNames[_numGroupByExpressions] = aggregationFunction.getResultColumnName(); columnDataTypes[_numGroupByExpressions] = aggregationFunction.getFinalResultColumnType(); diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/PostAggregationHandler.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/PostAggregationHandler.java new file mode 100644 index 0000000..fdf3cd9 --- /dev/null +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/PostAggregationHandler.java @@ -0,0 +1,243 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.query.reduce; + +import com.google.common.base.Preconditions; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.core.query.postaggregation.PostAggregationFunction; +import org.apache.pinot.core.query.request.context.ExpressionContext; +import org.apache.pinot.core.query.request.context.FunctionContext; +import org.apache.pinot.core.query.request.context.QueryContext; + + +/** + * The {@code PostAggregationHandler} handles the post-aggregation calculation as well as the column re-ordering for the + * aggregation result. + */ +public class PostAggregationHandler { + private final Map<FunctionContext, Integer> _aggregationFunctionIndexMap; + private final int _numGroupByExpressions; + private final Map<ExpressionContext, Integer> _groupByExpressionIndexMap; + private final DataSchema _dataSchema; + private final ValueExtractor[] _valueExtractors; + private final DataSchema _resultDataSchema; + + public PostAggregationHandler(QueryContext queryContext, DataSchema dataSchema) { + _aggregationFunctionIndexMap = queryContext.getAggregationFunctionIndexMap(); + assert _aggregationFunctionIndexMap != null; + List<ExpressionContext> groupByExpressions = queryContext.getGroupByExpressions(); + if (groupByExpressions != null) { + _numGroupByExpressions = groupByExpressions.size(); + _groupByExpressionIndexMap = new HashMap<>(); + for (int i = 0; i < _numGroupByExpressions; i++) { + _groupByExpressionIndexMap.put(groupByExpressions.get(i), i); + } + } else { + _numGroupByExpressions = 0; + _groupByExpressionIndexMap = null; + } + + // NOTE: The data schema will always have group-by expressions in the front, followed by aggregation functions of + // the same order as in the query context. This is handled in AggregationGroupByOrderByOperator. + _dataSchema = dataSchema; + + List<ExpressionContext> selectExpressions = queryContext.getSelectExpressions(); + int numSelectExpressions = selectExpressions.size(); + _valueExtractors = new ValueExtractor[numSelectExpressions]; + String[] columnNames = new String[numSelectExpressions]; + ColumnDataType[] columnDataTypes = new ColumnDataType[numSelectExpressions]; + for (int i = 0; i < numSelectExpressions; i++) { + ValueExtractor valueExtractor = getValueExtractor(selectExpressions.get(i)); + _valueExtractors[i] = valueExtractor; + columnNames[i] = valueExtractor.getColumnName(); + columnDataTypes[i] = valueExtractor.getColumnDataType(); + } + _resultDataSchema = new DataSchema(columnNames, columnDataTypes); + } + + /** + * Returns the DataSchema of the post-aggregation result. + */ + public DataSchema getResultDataSchema() { + return _resultDataSchema; + } + + /** + * Returns the post-aggregation result for the given row. + */ + public Object[] getResult(Object[] row) { + int numValues = _valueExtractors.length; + Object[] result = new Object[numValues]; + for (int i = 0; i < numValues; i++) { + result[i] = _valueExtractors[i].extract(row); + } + return result; + } + + /** + * Returns a ValueExtractor based on the given expression. 
+ */ + public ValueExtractor getValueExtractor(ExpressionContext expression) { + if (expression.getType() == ExpressionContext.Type.LITERAL) { + // Literal + return new LiteralValueExtractor(expression.getLiteral()); + } + if (_numGroupByExpressions > 0) { + Integer groupByExpressionIndex = _groupByExpressionIndexMap.get(expression); + if (groupByExpressionIndex != null) { + // Group-by expression + return new ColumnValueExtractor(groupByExpressionIndex); + } + } + FunctionContext function = expression.getFunction(); + Preconditions + .checkState(function != null, "Failed to find SELECT expression: %s in the GROUP-BY clause", expression); + if (function.getType() == FunctionContext.Type.AGGREGATION) { + // Aggregation function + return new ColumnValueExtractor(_aggregationFunctionIndexMap.get(function) + _numGroupByExpressions); + } else { + // Post-aggregation function + return new PostAggregationValueExtractor(function); + } + } + + /** + * Value extractor for the post-aggregation function. + */ + public interface ValueExtractor { + + /** + * Returns the column name for the value extracted. + */ + String getColumnName(); + + /** + * Returns the ColumnDataType of the value extracted. + */ + ColumnDataType getColumnDataType(); + + /** + * Extracts the value from the given row. + */ + Object extract(Object[] row); + } + + /** + * Value extractor for a literal. + */ + private static class LiteralValueExtractor implements ValueExtractor { + final String _literal; + + LiteralValueExtractor(String literal) { + _literal = literal; + } + + @Override + public String getColumnName() { + return '\'' + _literal + '\''; + } + + @Override + public ColumnDataType getColumnDataType() { + return ColumnDataType.STRING; + } + + @Override + public Object extract(Object[] row) { + return _literal; + } + } + + /** + * Value extractor for a non-post-aggregation column (group-by expression or aggregation). + */ + private class ColumnValueExtractor implements ValueExtractor { + final int _index; + + ColumnValueExtractor(int index) { + _index = index; + } + + @Override + public String getColumnName() { + return _dataSchema.getColumnName(_index); + } + + @Override + public ColumnDataType getColumnDataType() { + return _dataSchema.getColumnDataType(_index); + } + + @Override + public Object extract(Object[] row) { + return row[_index]; + } + } + + /** + * Value extractor for a post-aggregation column. 
+ */ + private class PostAggregationValueExtractor implements ValueExtractor { + final FunctionContext _function; + final Object[] _arguments; + final ValueExtractor[] _argumentExtractors; + final PostAggregationFunction _postAggregationFunction; + + PostAggregationValueExtractor(FunctionContext function) { + assert function.getType() == FunctionContext.Type.TRANSFORM; + + _function = function; + List<ExpressionContext> arguments = function.getArguments(); + int numArguments = arguments.size(); + _arguments = new Object[numArguments]; + _argumentExtractors = new ValueExtractor[numArguments]; + ColumnDataType[] argumentTypes = new ColumnDataType[numArguments]; + for (int i = 0; i < numArguments; i++) { + ExpressionContext argument = arguments.get(i); + ValueExtractor argumentExtractor = getValueExtractor(argument); + _argumentExtractors[i] = argumentExtractor; + argumentTypes[i] = argumentExtractor.getColumnDataType(); + } + _postAggregationFunction = new PostAggregationFunction(function.getFunctionName(), argumentTypes); + } + + @Override + public String getColumnName() { + return _function.toString(); + } + + @Override + public ColumnDataType getColumnDataType() { + return _postAggregationFunction.getResultType(); + } + + @Override + public Object extract(Object[] row) { + int numArguments = _arguments.length; + for (int i = 0; i < numArguments; i++) { + _arguments[i] = _argumentExtractors[i].extract(row); + } + return _postAggregationFunction.invoke(_arguments); + } + } +} diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ResultReducerFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ResultReducerFactory.java index 66c6313..04f97d6 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ResultReducerFactory.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/reduce/ResultReducerFactory.java @@ -46,11 +46,11 @@ public final class ResultReducerFactory { // Distinct query return new DistinctDataTableReducer(queryContext, (DistinctAggregationFunction) aggregationFunctions[0]); } else { - return new AggregationDataTableReducer(queryContext, aggregationFunctions); + return new AggregationDataTableReducer(queryContext); } } else { // Aggregation group-by query - return new GroupByDataTableReducer(queryContext, aggregationFunctions); + return new GroupByDataTableReducer(queryContext); } } } diff --git a/pinot-core/src/test/java/org/apache/pinot/core/query/reduce/PostAggregationHandlerTest.java b/pinot-core/src/test/java/org/apache/pinot/core/query/reduce/PostAggregationHandlerTest.java new file mode 100644 index 0000000..38da116 --- /dev/null +++ b/pinot-core/src/test/java/org/apache/pinot/core/query/reduce/PostAggregationHandlerTest.java @@ -0,0 +1,120 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.query.reduce; + +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.core.query.request.context.QueryContext; +import org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; + + +public class PostAggregationHandlerTest { + + @Test + public void testPostAggregation() { + // Regular aggregation only + { + QueryContext queryContext = QueryContextConverterUtils.getQueryContextFromSQL("SELECT SUM(m1) FROM testTable"); + DataSchema dataSchema = new DataSchema(new String[]{"sum(m1)"}, new ColumnDataType[]{ColumnDataType.DOUBLE}); + PostAggregationHandler handler = new PostAggregationHandler(queryContext, dataSchema); + assertEquals(handler.getResultDataSchema(), dataSchema); + assertEquals(handler.getResult(new Object[]{1.0}), new Object[]{1.0}); + assertEquals(handler.getResult(new Object[]{2.0}), new Object[]{2.0}); + } + + // Regular aggregation group-by + { + QueryContext queryContext = + QueryContextConverterUtils.getQueryContextFromSQL("SELECT d1, SUM(m1) FROM testTable GROUP BY d1"); + DataSchema dataSchema = new DataSchema(new String[]{"d1", "sum(m1)"}, + new ColumnDataType[]{ColumnDataType.INT, ColumnDataType.DOUBLE}); + PostAggregationHandler handler = new PostAggregationHandler(queryContext, dataSchema); + assertEquals(handler.getResultDataSchema(), dataSchema); + assertEquals(handler.getResult(new Object[]{1, 2.0}), new Object[]{1, 2.0}); + assertEquals(handler.getResult(new Object[]{3, 4.0}), new Object[]{3, 4.0}); + } + + // Aggregation group-by with partial columns selected + { + QueryContext queryContext = + QueryContextConverterUtils.getQueryContextFromSQL("SELECT SUM(m1), d2 FROM testTable GROUP BY d1, d2"); + DataSchema dataSchema = new DataSchema(new String[]{"d1", "d2", "sum(m1)"}, + new ColumnDataType[]{ColumnDataType.INT, ColumnDataType.LONG, ColumnDataType.DOUBLE}); + PostAggregationHandler handler = new PostAggregationHandler(queryContext, dataSchema); + DataSchema resultDataSchema = handler.getResultDataSchema(); + assertEquals(resultDataSchema.size(), 2); + assertEquals(resultDataSchema.getColumnNames(), new String[]{"sum(m1)", "d2"}); + assertEquals(resultDataSchema.getColumnDataTypes(), + new ColumnDataType[]{ColumnDataType.DOUBLE, ColumnDataType.LONG}); + assertEquals(handler.getResult(new Object[]{1, 2L, 3.0}), new Object[]{3.0, 2L}); + assertEquals(handler.getResult(new Object[]{4, 5L, 6.0}), new Object[]{6.0, 5L}); + } + + // Aggregation group-by with order-by + { + QueryContext queryContext = QueryContextConverterUtils + .getQueryContextFromSQL("SELECT SUM(m1), d2 FROM testTable GROUP BY d1, d2 ORDER BY MAX(m1)"); + DataSchema dataSchema = new DataSchema(new String[]{"d1", "d2", "sum(m1)", "max(m1)"}, + new ColumnDataType[]{ColumnDataType.INT, ColumnDataType.LONG, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE}); + PostAggregationHandler handler = new PostAggregationHandler(queryContext, dataSchema); + DataSchema resultDataSchema = handler.getResultDataSchema(); + assertEquals(resultDataSchema.size(), 2); + assertEquals(resultDataSchema.getColumnNames(), new String[]{"sum(m1)", "d2"}); + assertEquals(resultDataSchema.getColumnDataTypes(), + new ColumnDataType[]{ColumnDataType.DOUBLE, ColumnDataType.LONG}); + 
assertEquals(handler.getResult(new Object[]{1, 2L, 3.0, 4.0}), new Object[]{3.0, 2L}); + assertEquals(handler.getResult(new Object[]{5, 6L, 7.0, 8.0}), new Object[]{7.0, 6L}); + } + + // Post aggregation + { + QueryContext queryContext = + QueryContextConverterUtils.getQueryContextFromSQL("SELECT SUM(m1) + MAX(m2) FROM testTable"); + DataSchema dataSchema = new DataSchema(new String[]{"sum(m1)", "max(m2)"}, + new ColumnDataType[]{ColumnDataType.DOUBLE, ColumnDataType.DOUBLE}); + PostAggregationHandler handler = new PostAggregationHandler(queryContext, dataSchema); + DataSchema resultDataSchema = handler.getResultDataSchema(); + assertEquals(resultDataSchema.size(), 1); + assertEquals(resultDataSchema.getColumnName(0), "plus(sum(m1),max(m2))"); + assertEquals(resultDataSchema.getColumnDataType(0), ColumnDataType.DOUBLE); + assertEquals(handler.getResult(new Object[]{1.0, 2.0}), new Object[]{3.0}); + assertEquals(handler.getResult(new Object[]{3.0, 4.0}), new Object[]{7.0}); + } + + // Post aggregation with group-by and order-by + { + QueryContext queryContext = QueryContextConverterUtils.getQueryContextFromSQL( + "SELECT (SUM(m1) + MAX(m2) - d1) / 2, d2 FROM testTable GROUP BY d1, d2 ORDER BY MAX(m1)"); + DataSchema dataSchema = new DataSchema(new String[]{"d1", "d2", "sum(m1)", "max(m2)", "max(m1)"}, + new ColumnDataType[]{ColumnDataType.INT, ColumnDataType.LONG, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE, ColumnDataType.DOUBLE}); + PostAggregationHandler handler = new PostAggregationHandler(queryContext, dataSchema); + DataSchema resultDataSchema = handler.getResultDataSchema(); + assertEquals(resultDataSchema.size(), 2); + assertEquals(resultDataSchema.getColumnNames(), + new String[]{"divide(minus(plus(sum(m1),max(m2)),d1),'2')", "d2"}); + assertEquals(resultDataSchema.getColumnDataTypes(), + new ColumnDataType[]{ColumnDataType.DOUBLE, ColumnDataType.LONG}); + assertEquals(handler.getResult(new Object[]{1, 2L, 3.0, 4.0, 5.0}), new Object[]{3.0, 2L}); + assertEquals(handler.getResult(new Object[]{6, 7L, 8.0, 9.0, 10.0}), new Object[]{5.5, 7L}); + } + } +} diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java index d02e19d..53e85ce 100644 --- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java +++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/BaseClusterIntegrationTestSet.java @@ -231,6 +231,16 @@ public abstract class BaseClusterIntegrationTestSet extends BaseClusterIntegrati testSqlQuery(query, Collections.singletonList(query)); query = "SELECT MAX(ArrDelay), Month FROM mytable GROUP BY Month ORDER BY ABS(Month - 6) + MAX(ArrDelay)"; testSqlQuery(query, Collections.singletonList(query)); + + // Post-aggregation in SELECT + query = "SELECT MAX(ArrDelay) + MAX(AirTime) FROM mytable"; + testSqlQuery(query, Collections.singletonList(query)); + query = + "SELECT MAX(ArrDelay) - MAX(AirTime), DaysSinceEpoch FROM mytable GROUP BY DaysSinceEpoch ORDER BY MAX(ArrDelay) - MIN(AirTime) DESC"; + testSqlQuery(query, Collections.singletonList(query)); + query = + "SELECT DaysSinceEpoch, MAX(ArrDelay) * 2 - MAX(AirTime) - 3 FROM mytable GROUP BY DaysSinceEpoch ORDER BY MAX(ArrDelay) - MIN(AirTime) DESC"; + testSqlQuery(query, Collections.singletonList(query)); } /** --------------------------------------------------------------------- To 
unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org
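
Editor's note: the following is a minimal, self-contained sketch of how the new PostAggregationHandler is driven, distilled from PostAggregationHandlerTest in this commit; the wrapper class name PostAggregationSketch and the hard-coded data schema are illustrative assumptions, not part of the commit.

import java.util.Arrays;

import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.common.utils.DataSchema.ColumnDataType;
import org.apache.pinot.core.query.reduce.PostAggregationHandler;
import org.apache.pinot.core.query.request.context.QueryContext;
import org.apache.pinot.core.query.request.context.utils.QueryContextConverterUtils;

public class PostAggregationSketch {
  public static void main(String[] args) {
    // Query with a post-aggregation expression in SELECT.
    QueryContext queryContext =
        QueryContextConverterUtils.getQueryContextFromSQL("SELECT SUM(m1) + MAX(m2) FROM testTable");

    // Schema of the reduced rows before post-aggregation: one column per aggregation, with final result types.
    DataSchema dataSchema = new DataSchema(new String[]{"sum(m1)", "max(m2)"},
        new ColumnDataType[]{ColumnDataType.DOUBLE, ColumnDataType.DOUBLE});

    PostAggregationHandler handler = new PostAggregationHandler(queryContext, dataSchema);

    // Result schema has a single DOUBLE column named "plus(sum(m1),max(m2))".
    System.out.println(Arrays.toString(handler.getResultDataSchema().getColumnNames()));

    // For final aggregation values {1.0, 2.0}, the post-aggregated row is {3.0}.
    System.out.println(Arrays.toString(handler.getResult(new Object[]{1.0, 2.0})));
  }
}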