jasperjiaguo commented on code in PR #10636: URL: https://github.com/apache/pinot/pull/10636#discussion_r1184412511
########## pinot-core/src/main/java/org/apache/pinot/core/query/utils/rewriter/ParentAggregationResultRewriter.java: ########## @@ -0,0 +1,206 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.core.query.utils.rewriter; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.pinot.common.utils.DataSchema; +import org.apache.pinot.core.query.aggregation.utils.ParentAggregationFunctionResultObject; +import org.apache.pinot.spi.utils.CommonConstants; + + +/** + * Use the result of parent aggregation functions to populate the result of child aggregation functions. + * This implementation is based on the column names of the result schema. + * The result column name of a parent aggregation function has the following format: + * CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX + aggregationFunctionType + FunctionID + * The result column name of corresponding child aggregation function has the following format: + * aggregationFunctionType + FunctionID + CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX + * + childFunctionKey + * This approach will not work with `AS` clauses as they alter the column names. + * TODO: Add support for `AS` clauses. + */ +public class ParentAggregationResultRewriter implements ResultRewriter { + public ParentAggregationResultRewriter() { + } + + public static Map<String, ChildFunctionMapping> createChildFunctionMapping(DataSchema schema, Object[] row) { + Map<String, ChildFunctionMapping> childFunctionMapping = new HashMap<>(); + for (int i = 0; i < schema.size(); i++) { + String columnName = schema.getColumnName(i); + if (columnName.startsWith(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX)) { + ParentAggregationFunctionResultObject parent = (ParentAggregationFunctionResultObject) row[i]; + + DataSchema nestedSchema = parent.getSchema(); + for (int j = 0; j < nestedSchema.size(); j++) { + String childColumnKey = nestedSchema.getColumnName(j); + String originalChildFunctionKey = + columnName.substring(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX.length()) + + CommonConstants.RewriterConstants.CHILD_KEY_SEPERATOR + childColumnKey; + // aggregationFunctionType + childFunctionID + CHILD_KEY_SEPERATOR + childFunctionKeyInParent + childFunctionMapping.put(originalChildFunctionKey, new ChildFunctionMapping(parent, j, i)); + } + } + } + return childFunctionMapping; + } + + public RewriterResult rewrite(DataSchema dataSchema, List<Object[]> rows) { + int numParentAggregationFunctions = 0; + // Count the number of parent aggregation functions + for (int i = 0; i < dataSchema.size(); i++) { + if (dataSchema.getColumnName(i).startsWith(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX)) { + numParentAggregationFunctions++; + } + } + + if (numParentAggregationFunctions == 0 || rows.isEmpty()) { + // no change to the result + return new RewriterResult(dataSchema, rows); + } + + // Create a mapping from the child aggregation function name to the child aggregation function + Map<String, ChildFunctionMapping> childFunctionMapping = createChildFunctionMapping(dataSchema, rows.get(0)); + + String[] newColumnNames = new String[dataSchema.size() - numParentAggregationFunctions]; + DataSchema.ColumnDataType[] newColumnDataTypes + = new DataSchema.ColumnDataType[dataSchema.size() - numParentAggregationFunctions]; + + // Create a mapping from the function offset in the final aggregation result + // to its own/parent function offset in the original aggregation result + Map<Integer, Integer> aggregationFunctionIndexMapping = new HashMap<>(); + // Create a set of the result indices of the child aggregation functions + Set<Integer> childAggregationFunctionIndices = new HashSet<>(); + // Create a mapping from the result aggregation function index to the nested index of the + // child aggregation function in the parent aggregation function + Map<Integer, Integer> childAggregationFunctionNestedIndexMapping = new HashMap<>(); + // Create a set of the result indices of the parent aggregation functions + Set<Integer> parentAggregationFunctionIndices = new HashSet<>(); + + for (int i = 0, j = 0; i < dataSchema.size(); i++) { + String columnName = dataSchema.getColumnName(i); + // Skip the parent aggregation functions + if (columnName.startsWith(CommonConstants.RewriterConstants.PARENT_AGGREGATION_NAME_PREFIX)) { + parentAggregationFunctionIndices.add(i); + continue; + } + + // for child aggregation functions and regular columns in the result + // create a new schema and populate the new column names and data types + // also populate the offset mappings used to rewrite the result + if (columnName.startsWith(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX)) { + // This is a child column of a parent aggregation function + String childAggregationFunctionNameWithKey = + columnName.substring(CommonConstants.RewriterConstants.CHILD_AGGREGATION_NAME_PREFIX.length()); + String[] s = childAggregationFunctionNameWithKey + .split(CommonConstants.RewriterConstants.CHILD_AGGREGATION_SEPERATOR); + newColumnNames[j] = s[0]; + ChildFunctionMapping childFunction = childFunctionMapping.get(s[1]); + newColumnDataTypes[j] = childFunction.getParent().getSchema() + .getColumnDataType(childFunction.getNestedOffset()); + + childAggregationFunctionNestedIndexMapping.put(j, childFunction.getNestedOffset()); + childAggregationFunctionIndices.add(j); + aggregationFunctionIndexMapping.put(j, childFunction.getOffset()); + } else { + // This is a regular column + newColumnNames[j] = columnName; + newColumnDataTypes[j] = dataSchema.getColumnDataType(i); + + aggregationFunctionIndexMapping.put(j, i); + } + j++; + } + + DataSchema newDataSchema = new DataSchema(newColumnNames, newColumnDataTypes); + List<Object[]> newRows = new ArrayList<>(); + + for (Object[] row : rows) { + int maxRows = parentAggregationFunctionIndices.stream().map(k -> { + ParentAggregationFunctionResultObject parentAggregationFunctionResultObject = + (ParentAggregationFunctionResultObject) row[k]; + return parentAggregationFunctionResultObject.getNumberOfRows(); + }).max(Integer::compareTo).orElse(0); + maxRows = maxRows == 0 ? 1 : maxRows; + + List<Object[]> newRowsBuffer = new ArrayList<>(); + for (int rowIter = 0; rowIter < maxRows; rowIter++) { + Object[] newRow = new Object[newDataSchema.size()]; + for (int fieldIter = 0; fieldIter < newDataSchema.size(); fieldIter++) { + // If the field is a child aggregation function, extract the value from the parent result + if (childAggregationFunctionIndices.contains(fieldIter)) { + int offset = aggregationFunctionIndexMapping.get(fieldIter); + int nestedOffset = childAggregationFunctionNestedIndexMapping.get(fieldIter); + ParentAggregationFunctionResultObject parentAggregationFunctionResultObject = + (ParentAggregationFunctionResultObject) row[offset]; + // If the parent result has more rows than the current row, extract the value from the row + if (rowIter < parentAggregationFunctionResultObject.getNumberOfRows()) { + newRow[fieldIter] = parentAggregationFunctionResultObject.getField(rowIter, nestedOffset); + } else { + newRow[fieldIter] = null; + } + } else { // If the field is a regular column, extract the value from the row, only the first row has value + if (rowIter == 0) { + newRow[fieldIter] = row[aggregationFunctionIndexMapping.get(fieldIter)]; + } else { + newRow[fieldIter] = null; + } + } + } + newRowsBuffer.add(newRow); + } + newRows.addAll(newRowsBuffer); + } + return new RewriterResult(newDataSchema, newRows); + } + + /** + * Mapping from child function key to the + * parent result object, + * offset of the parent result column in original result row, + * and the nested offset of the child function result in the parent data block + */ Review Comment: added one example -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org