This is an automated email from the ASF dual-hosted git repository. huajianlan pushed a commit to branch nested_column_prune in repository https://gitbox.apache.org/repos/asf/doris.git
commit e2f2c3e091e48cf8752a161c8568d4a03f1383be Author: 924060929 <[email protected]> AuthorDate: Tue Oct 14 17:21:24 2025 +0800 support prune nested column in fe --- .../org/apache/doris/datasource/FileScanNode.java | 2 + .../org/apache/doris/nereids/StatementContext.java | 11 - .../glue/translator/PhysicalPlanTranslator.java | 19 - .../glue/translator/PlanTranslatorContext.java | 5 + .../doris/nereids/jobs/executor/Rewriter.java | 4 +- .../doris/nereids/rules/analysis/BindRelation.java | 9 +- .../nereids/rules/analysis/ExpressionAnalyzer.java | 7 + .../rewrite/AccessPathExpressionCollector.java | 444 +++++++++++++++++ .../rules/rewrite/AccessPathPlanCollector.java | 161 ++++++ ...lumnCollector.java => NestedColumnPruning.java} | 285 ++--------- .../nereids/rules/rewrite/SlotTypeReplacer.java | 541 +++++++++++++++++++++ .../rules/rewrite/VariantSubPathPruning.java | 3 +- .../nereids/trees/expressions/SlotReference.java | 30 +- .../expressions/functions/scalar/ArrayFirst.java | 6 + .../expressions/functions/scalar/ArrayLast.java | 6 + .../expressions/visitor/ScalarFunctionVisitor.java | 10 + .../trees/plans/logical/LogicalFileScan.java | 38 +- .../trees/plans/logical/LogicalHudiScan.java | 22 +- .../trees/plans/logical/LogicalOlapScan.java | 16 + .../apache/doris/nereids/types/VariantType.java | 2 +- .../org/apache/doris/planner/OlapScanNode.java | 3 + .../java/org/apache/doris/planner/ScanNode.java | 50 ++ .../java/org/apache/doris/qe/SessionVariable.java | 9 + ...estedColumn.java => PruneNestedColumnTest.java} | 142 +++++- gensrc/thrift/Descriptors.thrift | 92 ++-- 25 files changed, 1568 insertions(+), 349 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/datasource/FileScanNode.java b/fe/fe-core/src/main/java/org/apache/doris/datasource/FileScanNode.java index 8cfe183ebf0..8694ebcbc6c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/datasource/FileScanNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/datasource/FileScanNode.java 
@@ -171,6 +171,8 @@ public abstract class FileScanNode extends ExternalScanNode { } output.append(String.format("numNodes=%s", numNodes)).append("\n"); + printNestedColumns(output, prefix); + // pushdown agg output.append(prefix).append(String.format("pushdown agg=%s", pushDownAggNoGroupingOp)); if (pushDownAggNoGroupingOp.equals(TPushAggOp.COUNT)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java index e8d4b28ea83..2539206925c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java @@ -17,7 +17,6 @@ package org.apache.doris.nereids; -import org.apache.doris.analysis.AccessPathInfo; import org.apache.doris.analysis.StatementBase; import org.apache.doris.analysis.TableScanParams; import org.apache.doris.analysis.TableSnapshot; @@ -278,8 +277,6 @@ public class StatementContext implements Closeable { private boolean hasNestedColumns; - private Map<Integer, AccessPathInfo> slotIdToAcessPathInfo = new HashMap<>(); - public StatementContext() { this(ConnectContext.get(), null, 0); } @@ -1007,12 +1004,4 @@ public class StatementContext implements Closeable { public void setHasNestedColumns(boolean hasNestedColumns) { this.hasNestedColumns = hasNestedColumns; } - - public void setSlotIdToAccessPathInfo(int slotId, AccessPathInfo accessPathInfo) { - this.slotIdToAcessPathInfo.put(slotId, accessPathInfo); - } - - public Optional<AccessPathInfo> getAccessPathInfo(Slot slot) { - return Optional.ofNullable(this.slotIdToAcessPathInfo.get(slot.getExprId().asInt())); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 2e8b26b41a5..de8448f4880 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -17,7 +17,6 @@ package org.apache.doris.nereids.glue.translator; -import org.apache.doris.analysis.AccessPathInfo; import org.apache.doris.analysis.AggregateInfo; import org.apache.doris.analysis.AnalyticWindow; import org.apache.doris.analysis.BinaryPredicate; @@ -75,7 +74,6 @@ import org.apache.doris.fs.TransactionScopeCachingDirectoryListerFactory; import org.apache.doris.info.BaseTableRefInfo; import org.apache.doris.info.TableNameInfo; import org.apache.doris.info.TableRefInfo; -import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.processor.post.runtimefilterv2.RuntimeFilterV2; import org.apache.doris.nereids.properties.DistributionSpec; @@ -174,7 +172,6 @@ import org.apache.doris.nereids.types.ArrayType; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.JsonType; import org.apache.doris.nereids.types.MapType; -import org.apache.doris.nereids.types.NestedColumnPrunable; import org.apache.doris.nereids.types.StructType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.JoinUtils; @@ -249,7 +246,6 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Optional; import java.util.Set; import java.util.TreeMap; import java.util.concurrent.atomic.AtomicBoolean; @@ -830,22 +826,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla OlapTable olapTable = olapScan.getTable(); // generate real output tuple TupleDescriptor tupleDescriptor = generateTupleDesc(slots, olapTable, context); - StatementContext statementContext = context.getStatementContext(); - List<SlotDescriptor> slotDescriptors = 
tupleDescriptor.getSlots(); - for (int i = 0; i < slots.size(); i++) { - Slot slot = slots.get(i); - if (slot.getDataType() instanceof NestedColumnPrunable) { - Optional<AccessPathInfo> accessPathInfo = statementContext.getAccessPathInfo(slot); - if (accessPathInfo.isPresent()) { - SlotDescriptor slotDescriptor = slotDescriptors.get(i); - AccessPathInfo accessPath = accessPathInfo.get(); - slotDescriptor.setType(accessPath.getPrunedType().toCatalogDataType()); - slotDescriptor.setAllAccessPaths(accessPath.getAllAccessPaths()); - slotDescriptor.setPredicateAccessPaths(accessPath.getPredicateAccessPaths()); - } - } - } // put virtual column expr into slot desc Map<ExprId, Expression> slotToVirtualColumnMap = olapScan.getSlotToVirtualColumnMap(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java index 1dd79a033cd..0f303b88b55 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PlanTranslatorContext.java @@ -336,6 +336,11 @@ public class PlanTranslatorContext { } this.addExprIdSlotRefPair(slotReference.getExprId(), slotRef); slotDescriptor.setIsNullable(slotReference.nullable()); + + if (slotReference.getAllAccessPaths().isPresent()) { + slotDescriptor.setAllAccessPaths(slotReference.getAllAccessPaths().get()); + slotDescriptor.setPredicateAccessPaths(slotReference.getPredicateAccessPaths().get()); + } return slotDescriptor; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 820b6ec0453..eac528bc152 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java 
@@ -109,7 +109,7 @@ import org.apache.doris.nereids.rules.rewrite.MergeProjectable; import org.apache.doris.nereids.rules.rewrite.MergeSetOperations; import org.apache.doris.nereids.rules.rewrite.MergeSetOperationsExcept; import org.apache.doris.nereids.rules.rewrite.MergeTopNs; -import org.apache.doris.nereids.rules.rewrite.NestedColumnCollector; +import org.apache.doris.nereids.rules.rewrite.NestedColumnPruning; import org.apache.doris.nereids.rules.rewrite.NormalizeSort; import org.apache.doris.nereids.rules.rewrite.OperativeColumnDerive; import org.apache.doris.nereids.rules.rewrite.OrExpansion; @@ -909,7 +909,7 @@ public class Rewriter extends AbstractBatchJobExecutor { } rewriteJobs.add( topic("nested column prune", - custom(RuleType.NESTED_COLUMN_PRUNING, NestedColumnCollector::new) + custom(RuleType.NESTED_COLUMN_PRUNING, NestedColumnPruning::new) ) ); rewriteJobs.addAll(jobs( diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java index 84652b9eed0..36b7f0064ef 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java @@ -424,7 +424,8 @@ public class BindRelation extends OneAnalysisRuleFactory { if (hmsTable.getDlaType() == DLAType.HUDI) { LogicalHudiScan hudiScan = new LogicalHudiScan(unboundRelation.getRelationId(), hmsTable, qualifierWithoutTableName, ImmutableList.of(), Optional.empty(), - unboundRelation.getTableSample(), unboundRelation.getTableSnapshot()); + unboundRelation.getTableSample(), unboundRelation.getTableSnapshot(), + Optional.empty()); hudiScan = hudiScan.withScanParams( hmsTable, Optional.ofNullable(unboundRelation.getScanParams())); return hudiScan; @@ -434,7 +435,7 @@ public class BindRelation extends OneAnalysisRuleFactory { ImmutableList.of(), unboundRelation.getTableSample(), 
unboundRelation.getTableSnapshot(), - Optional.ofNullable(unboundRelation.getScanParams())); + Optional.ofNullable(unboundRelation.getScanParams()), Optional.empty()); } case ICEBERG_EXTERNAL_TABLE: IcebergExternalTable icebergExternalTable = (IcebergExternalTable) table; @@ -464,7 +465,7 @@ public class BindRelation extends OneAnalysisRuleFactory { qualifierWithoutTableName, ImmutableList.of(), unboundRelation.getTableSample(), unboundRelation.getTableSnapshot(), - Optional.ofNullable(unboundRelation.getScanParams())); + Optional.ofNullable(unboundRelation.getScanParams()), Optional.empty()); case PAIMON_EXTERNAL_TABLE: case MAX_COMPUTE_EXTERNAL_TABLE: case TRINO_CONNECTOR_EXTERNAL_TABLE: @@ -473,7 +474,7 @@ public class BindRelation extends OneAnalysisRuleFactory { qualifierWithoutTableName, ImmutableList.of(), unboundRelation.getTableSample(), unboundRelation.getTableSnapshot(), - Optional.ofNullable(unboundRelation.getScanParams())); + Optional.ofNullable(unboundRelation.getScanParams()), Optional.empty()); case SCHEMA: LogicalSchemaScan schemaScan = new LogicalSchemaScan(unboundRelation.getRelationId(), table, qualifierWithoutTableName); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index 4b6e9b6e09b..43a1e76751e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -93,6 +93,7 @@ import org.apache.doris.nereids.types.ArrayType; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.NestedColumnPrunable; import org.apache.doris.nereids.types.TinyIntType; import org.apache.doris.nereids.util.ExpressionUtils; import 
org.apache.doris.nereids.util.TypeCoercionUtils; @@ -275,6 +276,9 @@ public class ExpressionAnalyzer extends SubExprAnalyzer<ExpressionRewriteContext } outerScope.get().getCorrelatedSlots().add((Slot) firstBound); } + if (firstBound.getDataType() instanceof NestedColumnPrunable) { + context.cascadesContext.getStatementContext().setHasNestedColumns(true); + } return firstBound; default: if (enableExactMatch) { @@ -294,6 +298,9 @@ public class ExpressionAnalyzer extends SubExprAnalyzer<ExpressionRewriteContext .filter(bound -> unboundSlot.getNameParts().size() == bound.getQualifier().size() + 1) .collect(Collectors.toList()); if (exactMatch.size() == 1) { + if (exactMatch.get(0).getDataType() instanceof NestedColumnPrunable) { + context.cascadesContext.getStatementContext().setHasNestedColumns(true); + } return exactMatch.get(0); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathExpressionCollector.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathExpressionCollector.java new file mode 100644 index 00000000000..7caa3e59c45 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathExpressionCollector.java @@ -0,0 +1,444 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.StatementContext; +import org.apache.doris.nereids.rules.rewrite.AccessPathExpressionCollector.CollectorContext; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.ArrayItemReference; +import org.apache.doris.nereids.trees.expressions.ArrayItemReference.ArrayItemSlot; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayCount; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExists; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFilter; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFirst; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFirstIndex; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayLast; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayLastIndex; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMap; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMatchAll; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMatchAny; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayReverseSplit; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySortBy; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySplit; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda; +import 
org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsKey; +import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsValue; +import org.apache.doris.nereids.trees.expressions.functions.scalar.MapKeys; +import org.apache.doris.nereids.trees.expressions.functions.scalar.MapValues; +import org.apache.doris.nereids.trees.expressions.functions.scalar.StructElement; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.NestedColumnPrunable; +import org.apache.doris.nereids.util.Utils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Maps; +import com.google.common.collect.Multimap; + +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Stack; + +/** + * collect the access path, for example: `select struct_element(s, 'data')` has access path: ['s', 'data'] + */ +public class AccessPathExpressionCollector extends DefaultExpressionVisitor<Void, CollectorContext> { + private StatementContext statementContext; + private boolean bottomPredicate; + private Multimap<Integer, CollectAccessPathResult> slotToAccessPaths; + private Stack<Map<String, Expression>> nameToLambdaArguments = new Stack<>(); + + public AccessPathExpressionCollector( + StatementContext statementContext, Multimap<Integer, CollectAccessPathResult> slotToAccessPaths, + boolean bottomPredicate) { + this.statementContext = statementContext; + this.slotToAccessPaths = slotToAccessPaths; + this.bottomPredicate = bottomPredicate; + } + + public void collect(Expression expression) { + expression.accept(this, new CollectorContext(statementContext, bottomPredicate)); + } + + private Void continueCollectAccessPath(Expression expr, CollectorContext context) { + return expr.accept(this, context); + 
} + + @Override + public Void visit(Expression expr, CollectorContext context) { + for (Expression child : expr.children()) { + child.accept(this, new CollectorContext(context.statementContext, context.bottomFilter)); + } + return null; + } + + @Override + public Void visitSlotReference(SlotReference slotReference, CollectorContext context) { + DataType dataType = slotReference.getDataType(); + if (dataType instanceof NestedColumnPrunable) { + context.accessPathBuilder.addPrefix(slotReference.getName()); + ImmutableList<String> path = Utils.fastToImmutableList(context.accessPathBuilder.accessPath); + int slotId = slotReference.getExprId().asInt(); + slotToAccessPaths.put(slotId, new CollectAccessPathResult(path, context.bottomFilter)); + } + return null; + } + + @Override + public Void visitArrayItemSlot(ArrayItemSlot arrayItemSlot, CollectorContext context) { + if (nameToLambdaArguments.isEmpty()) { + return null; + } + context.accessPathBuilder.addPrefix("*"); + Expression argument = nameToLambdaArguments.peek().get(arrayItemSlot.getName()); + if (argument == null) { + return null; + } + return continueCollectAccessPath(argument, context); + } + + @Override + public Void visitAlias(Alias alias, CollectorContext context) { + return alias.child(0).accept(this, context); + } + + @Override + public Void visitCast(Cast cast, CollectorContext context) { + return cast.child(0).accept(this, context); + } + + // array element at + @Override + public Void visitElementAt(ElementAt elementAt, CollectorContext context) { + List<Expression> arguments = elementAt.getArguments(); + Expression first = arguments.get(0); + if (first.getDataType().isArrayType() || first.getDataType().isMapType()) { + context.accessPathBuilder.addPrefix("*"); + continueCollectAccessPath(first, context); + + for (int i = 1; i < arguments.size(); i++) { + visit(arguments.get(i), context); + } + return null; + } else { + return visit(elementAt, context); + } + } + + // struct element_at + @Override + 
public Void visitStructElement(StructElement structElement, CollectorContext context) { + List<Expression> arguments = structElement.getArguments(); + Expression struct = arguments.get(0); + Expression fieldName = arguments.get(1); + DataType fieldType = fieldName.getDataType(); + + if (fieldName.isLiteral() && (fieldType.isIntegerLikeType() || fieldType.isStringLikeType())) { + context.accessPathBuilder.addPrefix(((Literal) fieldName).getStringValue()); + return continueCollectAccessPath(struct, context); + } + + for (Expression argument : arguments) { + visit(argument, context); + } + return null; + } + + @Override + public Void visitMapKeys(MapKeys mapKeys, CollectorContext context) { + context.accessPathBuilder.addPrefix("KEYS"); + return continueCollectAccessPath(mapKeys.getArgument(0), context); + } + + @Override + public Void visitMapValues(MapValues mapValues, CollectorContext context) { + LinkedList<String> suffixPath = context.accessPathBuilder.accessPath; + if (!suffixPath.isEmpty() && suffixPath.get(0).equals("*")) { + CollectorContext removeStarContext + = new CollectorContext(context.statementContext, context.bottomFilter); + removeStarContext.accessPathBuilder.accessPath.addAll(suffixPath.subList(1, suffixPath.size())); + removeStarContext.accessPathBuilder.addPrefix("VALUES"); + return continueCollectAccessPath(mapValues.getArgument(0), removeStarContext); + } + context.accessPathBuilder.addPrefix("VALUES"); + return continueCollectAccessPath(mapValues.getArgument(0), context); + } + + @Override + public Void visitMapContainsKey(MapContainsKey mapContainsKey, CollectorContext context) { + context.accessPathBuilder.addPrefix("KEYS"); + return continueCollectAccessPath(mapContainsKey.getArgument(0), context); + } + + @Override + public Void visitMapContainsValue(MapContainsValue mapContainsValue, CollectorContext context) { + context.accessPathBuilder.addPrefix("VALUES"); + return continueCollectAccessPath(mapContainsValue.getArgument(0), context); + 
} + + @Override + public Void visitArrayMap(ArrayMap arrayMap, CollectorContext context) { + // ARRAY_MAP(lambda, <arr> [ , <arr> ... ] ) + + Expression argument = arrayMap.getArgument(0); + if ((argument instanceof Lambda)) { + return collectArrayPathInLambda((Lambda) argument, context); + } + return visit(arrayMap, context); + } + + @Override + public Void visitArrayCount(ArrayCount arrayCount, CollectorContext context) { + // ARRAY_COUNT(<lambda>, <arr>[, ... ]) + + Expression argument = arrayCount.getArgument(0); + if ((argument instanceof Lambda)) { + return collectArrayPathInLambda((Lambda) argument, context); + } + return visit(arrayCount, context); + } + + @Override + public Void visitArrayExists(ArrayExists arrayExists, CollectorContext context) { + // ARRAY_EXISTS([ <lambda>, ] <arr1> [, <arr2> , ...] ) + + Expression argument = arrayExists.getArgument(0); + if ((argument instanceof Lambda)) { + return collectArrayPathInLambda((Lambda) argument, context); + } + return visit(arrayExists, context); + } + + @Override + public Void visitArrayFilter(ArrayFilter arrayFilter, CollectorContext context) { + // ARRAY_FILTER(<lambda>, <arr>) + + Expression argument = arrayFilter.getArgument(0); + if ((argument instanceof Lambda)) { + collectArrayPathInLambda((Lambda) argument, context); + } + return visit(arrayFilter, context); + } + + @Override + public Void visitArrayFirst(ArrayFirst arrayFirst, CollectorContext context) { + // ARRAY_FIRST(<lambda>, <arr>) + + Expression argument = arrayFirst.getArgument(0); + if ((argument instanceof Lambda)) { + collectArrayPathInLambda((Lambda) argument, context); + } + return visit(arrayFirst, context); + } + + @Override + public Void visitArrayFirstIndex(ArrayFirstIndex arrayFirstIndex, CollectorContext context) { + // ARRAY_FIRST_INDEX(<lambda>, <arr> [, ...]) + + Expression argument = arrayFirstIndex.getArgument(0); + if ((argument instanceof Lambda)) { + collectArrayPathInLambda((Lambda) argument, context); + } + return 
visit(arrayFirstIndex, context); + } + + @Override + public Void visitArrayLast(ArrayLast arrayLast, CollectorContext context) { + // ARRAY_LAST(<lambda>, <arr>) + + Expression argument = arrayLast.getArgument(0); + if ((argument instanceof Lambda)) { + collectArrayPathInLambda((Lambda) argument, context); + } + return visit(arrayLast, context); + } + + @Override + public Void visitArrayLastIndex(ArrayLastIndex arrayLastIndex, CollectorContext context) { + // ARRAY_LAST_INDEX(<lambda>, <arr> [, ...]) + + Expression argument = arrayLastIndex.getArgument(0); + if ((argument instanceof Lambda)) { + collectArrayPathInLambda((Lambda) argument, context); + } + return visit(arrayLastIndex, context); + } + + @Override + public Void visitArrayMatchAny(ArrayMatchAny arrayMatchAny, CollectorContext context) { + // array_match_any(lambda, <arr> [, <arr> ...]) + + Expression argument = arrayMatchAny.getArgument(0); + if ((argument instanceof Lambda)) { + collectArrayPathInLambda((Lambda) argument, context); + } + return visit(arrayMatchAny, context); + } + + @Override + public Void visitArrayMatchAll(ArrayMatchAll arrayMatchAll, CollectorContext context) { + // array_match_all(lambda, <arr> [, <arr> ...]) + + Expression argument = arrayMatchAll.getArgument(0); + if ((argument instanceof Lambda)) { + collectArrayPathInLambda((Lambda) argument, context); + } + return visit(arrayMatchAll, context); + } + + @Override + public Void visitArrayReverseSplit(ArrayReverseSplit arrayReverseSplit, CollectorContext context) { + // ARRAY_REVERSE_SPLIT(<lambda>, <arr> [, ...]) + + Expression argument = arrayReverseSplit.getArgument(0); + if ((argument instanceof Lambda)) { + collectArrayPathInLambda((Lambda) argument, context); + } + return visit(arrayReverseSplit, context); + } + + @Override + public Void visitArraySplit(ArraySplit arraySplit, CollectorContext context) { + // ARRAY_SPLIT(<lambda>, arr [, ...]) + + Expression argument = arraySplit.getArgument(0); + if ((argument instanceof 
Lambda)) { + collectArrayPathInLambda((Lambda) argument, context); + } + return visit(arraySplit, context); + } + + @Override + public Void visitArraySortBy(ArraySortBy arraySortBy, CollectorContext context) { + // ARRAY_SORTBY(<lambda>, <arr> [, ...]) + + Expression argument = arraySortBy.getArgument(0); + if ((argument instanceof Lambda)) { + collectArrayPathInLambda((Lambda) argument, context); + } + return visit(arraySortBy, context); + } + + private Void collectArrayPathInLambda(Lambda lambda, CollectorContext context) { + List<Expression> arguments = lambda.getArguments(); + Map<String, Expression> nameToArray = Maps.newLinkedHashMap(); + for (Expression argument : arguments) { + if (argument instanceof ArrayItemReference) { + nameToArray.put(((ArrayItemReference) argument).getName(), argument.child(0)); + } + } + + List<String> path = context.accessPathBuilder.getPathList(); + if (!path.isEmpty() && path.get(0).equals("*")) { + context.accessPathBuilder.removePrefix(); + } + + nameToLambdaArguments.push(nameToArray); + try { + continueCollectAccessPath(arguments.get(0), context); + } finally { + nameToLambdaArguments.pop(); + } + return null; + } + + /** CollectorContext */ + public static class CollectorContext { + private StatementContext statementContext; + private AccessPathBuilder accessPathBuilder; + private boolean bottomFilter; + + public CollectorContext(StatementContext statementContext, boolean bottomFilter) { + this.statementContext = statementContext; + this.accessPathBuilder = new AccessPathBuilder(); + this.bottomFilter = bottomFilter; + } + } + + private static class AccessPathBuilder { + private LinkedList<String> accessPath; + + public AccessPathBuilder() { + accessPath = new LinkedList<>(); + } + + public AccessPathBuilder addPrefix(String prefix) { + accessPath.addFirst(prefix); + return this; + } + + public AccessPathBuilder removePrefix() { + accessPath.removeFirst(); + return this; + } + + public List<String> getPathList() { + return 
accessPath; + } + + @Override + public String toString() { + return String.join(".", accessPath); + } + } + + /** AccessPathIsPredicate */ + public static class CollectAccessPathResult { + private final List<String> path; + private final boolean isPredicate; + + public CollectAccessPathResult(List<String> path, boolean isPredicate) { + this.path = path; + this.isPredicate = isPredicate; + } + + public List<String> getPath() { + return path; + } + + public boolean isPredicate() { + return isPredicate; + } + + @Override + public String toString() { + return String.join(".", path) + ", " + isPredicate; + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) { + return false; + } + CollectAccessPathResult that = (CollectAccessPathResult) o; + return isPredicate == that.isPredicate && Objects.equals(path, that.path); + } + + @Override + public int hashCode() { + return path.hashCode(); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java new file mode 100644 index 00000000000..21406c2ea13 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AccessPathPlanCollector.java @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.StatementContext; +import org.apache.doris.nereids.rules.rewrite.AccessPathExpressionCollector.CollectAccessPathResult; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; +import org.apache.doris.nereids.trees.plans.logical.LogicalFileScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor; +import org.apache.doris.nereids.types.NestedColumnPrunable; + +import com.google.common.collect.LinkedHashMultimap; +import com.google.common.collect.Multimap; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +/** AccessPathPlanCollector */ +public class AccessPathPlanCollector extends DefaultPlanVisitor<Void, StatementContext> { + private Multimap<Integer, CollectAccessPathResult> allSlotToAccessPaths = LinkedHashMultimap.create(); + private Map<Slot, List<CollectAccessPathResult>> 
scanSlotToAccessPaths = new LinkedHashMap<>(); + + public Map<Slot, List<CollectAccessPathResult>> collect(Plan root, StatementContext context) { + root.accept(this, context); + return scanSlotToAccessPaths; + } + + @Override + public Void visitLogicalFilter(LogicalFilter<? extends Plan> filter, StatementContext context) { + boolean bottomFilter = filter.child().arity() == 0; + collectByExpressions(filter, context, bottomFilter); + return filter.child().accept(this, context); + } + + @Override + public Void visitLogicalCTEAnchor( + LogicalCTEAnchor<? extends Plan, ? extends Plan> cteAnchor, StatementContext context) { + // first, collect access paths in the outer slots, and propagate outer slot's access path to inner slots + cteAnchor.right().accept(this, context); + + // second, push down access path in the inner slots + cteAnchor.left().accept(this, context); + return null; + } + + @Override + public Void visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, StatementContext context) { + // propagate outer slot's access path to inner slots + for (Entry<Slot, Slot> slots : cteConsumer.getConsumerToProducerOutputMap().entrySet()) { + Slot outerSlot = slots.getKey(); + + if (outerSlot.getDataType() instanceof NestedColumnPrunable) { + int outerSlotId = outerSlot.getExprId().asInt(); + int innerSlotId = slots.getValue().getExprId().asInt(); + allSlotToAccessPaths.putAll(innerSlotId, allSlotToAccessPaths.get(outerSlotId)); + } + } + return null; + } + + @Override + public Void visitLogicalCTEProducer(LogicalCTEProducer<? 
extends Plan> cteProducer, StatementContext context) { + return cteProducer.child().accept(this, context); + } + + @Override + public Void visitLogicalUnion(LogicalUnion union, StatementContext context) { + // now we will not prune complex type through union, because we can not prune the complex type's literal, + // for example, we can not prune the literal now: array(map(1, named_struct('a', 100, 'b', 100))), + // so we can not prune this sql: + // select struct_element(map_values(s[0]), 'a') + // from ( + // select s from tbl + // union all + // select array(map(1, named_struct('a', 100, 'b', 100))) + // ) tbl; + // + // so we will not propagate access paths through the union + for (Plan child : union.children()) { + child.accept(this, context); + } + return null; + } + + @Override + public Void visitLogicalOlapScan(LogicalOlapScan olapScan, StatementContext context) { + for (Slot slot : olapScan.getOutput()) { + if (!(slot.getDataType() instanceof NestedColumnPrunable)) { + continue; + } + Collection<CollectAccessPathResult> accessPaths = allSlotToAccessPaths.get(slot.getExprId().asInt()); + if (!accessPaths.isEmpty()) { + scanSlotToAccessPaths.put(slot, new ArrayList<>(accessPaths)); + } + } + return null; + } + + @Override + public Void visitLogicalFileScan(LogicalFileScan fileScan, StatementContext context) { + for (Slot slot : fileScan.getOutput()) { + if (!(slot.getDataType() instanceof NestedColumnPrunable)) { + continue; + } + Collection<CollectAccessPathResult> accessPaths = allSlotToAccessPaths.get(slot.getExprId().asInt()); + if (!accessPaths.isEmpty()) { + scanSlotToAccessPaths.put(slot, new ArrayList<>(accessPaths)); + } + } + return null; + } + + @Override + public Void visit(Plan plan, StatementContext context) { + collectByExpressions(plan, context); + + for (Plan child : plan.children()) { + child.accept(this, context); + } + return null; + } + + private void collectByExpressions(Plan plan, StatementContext context) { + collectByExpressions(plan, 
context, false); + } + + private void collectByExpressions(Plan plan, StatementContext context, boolean bottomPredicate) { + AccessPathExpressionCollector exprCollector + = new AccessPathExpressionCollector(context, allSlotToAccessPaths, bottomPredicate); + for (Expression expression : plan.getExpressions()) { + exprCollector.collect(expression); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NestedColumnCollector.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NestedColumnPruning.java similarity index 56% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NestedColumnCollector.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NestedColumnPruning.java index d95ace059b4..609b577c6a2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NestedColumnCollector.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NestedColumnPruning.java @@ -22,36 +22,23 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.jobs.JobContext; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.Cast; -import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.rules.rewrite.AccessPathExpressionCollector.CollectAccessPathResult; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMap; -import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; -import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsKey; -import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsValue; -import 
org.apache.doris.nereids.trees.expressions.functions.scalar.MapKeys; -import org.apache.doris.nereids.trees.expressions.functions.scalar.MapValues; -import org.apache.doris.nereids.trees.expressions.functions.scalar.StructElement; -import org.apache.doris.nereids.trees.expressions.literal.Literal; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor; import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; import org.apache.doris.nereids.types.ArrayType; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.MapType; -import org.apache.doris.nereids.types.NestedColumnPrunable; import org.apache.doris.nereids.types.NullType; import org.apache.doris.nereids.types.StructField; import org.apache.doris.nereids.types.StructType; -import org.apache.doris.nereids.util.Utils; +import org.apache.doris.qe.SessionVariable; import org.apache.doris.thrift.TAccessPathType; import org.apache.doris.thrift.TColumnAccessPaths; import org.apache.doris.thrift.TColumnNameAccessPath; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Maps; import com.google.common.collect.Multimap; import com.google.common.collect.TreeMultimap; import org.apache.commons.lang3.StringUtils; @@ -59,32 +46,53 @@ import org.apache.commons.lang3.StringUtils; import java.util.ArrayList; import java.util.Comparator; import java.util.LinkedHashMap; -import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; -/** NestedColumnCollector */ -public class NestedColumnCollector implements CustomRewriter { +/** + * <li> 1. 
prune the data type of struct/map + * + * <p> for example, column s is a struct<id: int, value: double>, + * and `select struct_element(s, 'id') from tbl` will prune the data type of s to: `struct<id: int>` + * </p> + * </li> + * + * <li> 2. collect the access paths + * <p> for example, `select struct_element(s, 'id'), struct_element(s, 'value') from tbl` will collect + * the access paths: [s.id, s.value] + * </p> + * </li> + */ +public class NestedColumnPruning implements CustomRewriter { @Override public Plan rewriteRoot(Plan plan, JobContext jobContext) { StatementContext statementContext = jobContext.getCascadesContext().getStatementContext(); - // if (!statementContext.hasNestedColumns()) { - // return plan; - // } + SessionVariable sessionVariable = statementContext.getConnectContext().getSessionVariable(); + if (!sessionVariable.enablePruneNestedColumns || !statementContext.hasNestedColumns()) { + return plan; + } - AccessPathCollector collector = new AccessPathCollector(); - List<AccessPathIsPredicate> slotToAccessPaths = collector.collectInPlan(plan, statementContext); + AccessPathPlanCollector collector = new AccessPathPlanCollector(); + Map<Slot, List<CollectAccessPathResult>> slotToAccessPaths = collector.collect(plan, statementContext); Map<Integer, AccessPathInfo> slotToResult = pruneDataType(slotToAccessPaths); - for (Entry<Integer, AccessPathInfo> kv : slotToResult.entrySet()) { - Integer slotId = kv.getKey(); - statementContext.setSlotIdToAccessPathInfo(slotId, kv.getValue()); + + if (!slotToResult.isEmpty()) { + Map<Integer, AccessPathInfo> slotIdToPruneType = Maps.newLinkedHashMap(); + for (Entry<Integer, AccessPathInfo> kv : slotToResult.entrySet()) { + Integer slotId = kv.getKey(); + AccessPathInfo accessPathInfo = kv.getValue(); + slotIdToPruneType.put(slotId, accessPathInfo); + } + SlotTypeReplacer typeReplacer = new SlotTypeReplacer(slotIdToPruneType); + return plan.accept(typeReplacer, null); } return plan; } - private static Map<Integer, AccessPathInfo> 
pruneDataType(List<AccessPathIsPredicate> slotToAccessPaths) { + private static Map<Integer, AccessPathInfo> pruneDataType( + Map<Slot, List<CollectAccessPathResult>> slotToAccessPaths) { Map<Integer, AccessPathInfo> result = new LinkedHashMap<>(); Map<Slot, DataTypeAccessTree> slotIdToAllAccessTree = new LinkedHashMap<>(); Map<Slot, DataTypeAccessTree> slotIdToPredicateAccessTree = new LinkedHashMap<>(); @@ -98,22 +106,24 @@ public class NestedColumnCollector implements CustomRewriter { Comparator.naturalOrder(), pathComparator); // first: build access data type tree - for (AccessPathIsPredicate accessPathIsPredicate : slotToAccessPaths) { - Slot slot = accessPathIsPredicate.slot; - List<String> path = accessPathIsPredicate.path; - - DataTypeAccessTree allAccessTree = slotIdToAllAccessTree.computeIfAbsent( - slot, i -> DataTypeAccessTree.ofRoot(slot) - ); - allAccessTree.setAccessByPath(path, 0); - allAccessPaths.put(slot.getExprId().asInt(), path); - - if (accessPathIsPredicate.isPredicate()) { - DataTypeAccessTree predicateAccessTree = slotIdToPredicateAccessTree.computeIfAbsent( + for (Entry<Slot, List<CollectAccessPathResult>> kv : slotToAccessPaths.entrySet()) { + Slot slot = kv.getKey(); + List<CollectAccessPathResult> collectAccessPathResults = kv.getValue(); + for (CollectAccessPathResult collectAccessPathResult : collectAccessPathResults) { + List<String> path = collectAccessPathResult.getPath(); + DataTypeAccessTree allAccessTree = slotIdToAllAccessTree.computeIfAbsent( slot, i -> DataTypeAccessTree.ofRoot(slot) ); - predicateAccessTree.setAccessByPath(path, 0); - predicateAccessPaths.put(slot.getExprId().asInt(), path); + allAccessTree.setAccessByPath(path, 0); + allAccessPaths.put(slot.getExprId().asInt(), path); + + if (collectAccessPathResult.isPredicate()) { + DataTypeAccessTree predicateAccessTree = slotIdToPredicateAccessTree.computeIfAbsent( + slot, i -> DataTypeAccessTree.ofRoot(slot) + ); + predicateAccessTree.setAccessByPath(path, 0); + 
predicateAccessPaths.put(slot.getExprId().asInt(), path); + } } } @@ -173,168 +183,6 @@ public class NestedColumnCollector implements CustomRewriter { return result; } - private class AccessPathCollector extends DefaultExpressionVisitor<Void, CollectorContext> { - private List<AccessPathIsPredicate> slotToAccessPaths = new ArrayList<>(); - - public List<AccessPathIsPredicate> collectInPlan( - Plan plan, StatementContext statementContext) { - boolean bottomFilter = plan instanceof LogicalFilter && plan.child(0).arity() == 0; - for (Expression expression : plan.getExpressions()) { - expression.accept(this, new CollectorContext(statementContext, bottomFilter)); - } - for (Plan child : plan.children()) { - collectInPlan(child, statementContext); - } - return slotToAccessPaths; - } - - private Void continueCollectAccessPath(Expression expr, CollectorContext context) { - return expr.accept(this, context); - } - - @Override - public Void visit(Expression expr, CollectorContext context) { - for (Expression child : expr.children()) { - child.accept(this, new CollectorContext(context.statementContext, context.bottomFilter)); - } - return null; - } - - @Override - public Void visitSlotReference(SlotReference slotReference, CollectorContext context) { - DataType dataType = slotReference.getDataType(); - if (dataType instanceof NestedColumnPrunable) { - context.accessPathBuilder.addPrefix(slotReference.getName()); - ImmutableList<String> path = Utils.fastToImmutableList(context.accessPathBuilder.accessPath); - slotToAccessPaths.add(new AccessPathIsPredicate(slotReference, path, context.bottomFilter)); - } - return null; - } - - @Override - public Void visitAlias(Alias alias, CollectorContext context) { - return alias.child(0).accept(this, context); - } - - @Override - public Void visitCast(Cast cast, CollectorContext context) { - return cast.child(0).accept(this, context); - } - - // array element at - @Override - public Void visitElementAt(ElementAt elementAt, CollectorContext 
context) { - List<Expression> arguments = elementAt.getArguments(); - Expression first = arguments.get(0); - if (first.getDataType().isArrayType() || first.getDataType().isMapType() - || first.getDataType().isVariantType()) { - context.accessPathBuilder.addPrefix("*"); - continueCollectAccessPath(first, context); - - for (int i = 1; i < arguments.size(); i++) { - visit(arguments.get(i), context); - } - return null; - } else { - return visit(elementAt, context); - } - } - - // struct element_at - @Override - public Void visitStructElement(StructElement structElement, CollectorContext context) { - List<Expression> arguments = structElement.getArguments(); - Expression struct = arguments.get(0); - Expression fieldName = arguments.get(1); - DataType fieldType = fieldName.getDataType(); - - if (fieldName.isLiteral() && (fieldType.isIntegerLikeType() || fieldType.isStringLikeType())) { - context.accessPathBuilder.addPrefix(((Literal) fieldName).getStringValue()); - return continueCollectAccessPath(struct, context); - } - - for (Expression argument : arguments) { - visit(argument, context); - } - return null; - } - - @Override - public Void visitMapKeys(MapKeys mapKeys, CollectorContext context) { - context.accessPathBuilder.addPrefix("KEYS"); - return continueCollectAccessPath(mapKeys.getArgument(0), context); - } - - @Override - public Void visitMapValues(MapValues mapValues, CollectorContext context) { - LinkedList<String> suffixPath = context.accessPathBuilder.accessPath; - if (!suffixPath.isEmpty() && suffixPath.get(0).equals("*")) { - CollectorContext removeStarContext - = new CollectorContext(context.statementContext, context.bottomFilter); - removeStarContext.accessPathBuilder.accessPath.addAll(suffixPath.subList(1, suffixPath.size())); - removeStarContext.accessPathBuilder.addPrefix("VALUES"); - return continueCollectAccessPath(mapValues.getArgument(0), removeStarContext); - } - context.accessPathBuilder.addPrefix("VALUES"); - return 
continueCollectAccessPath(mapValues.getArgument(0), context); - } - - @Override - public Void visitMapContainsKey(MapContainsKey mapContainsKey, CollectorContext context) { - context.accessPathBuilder.addPrefix("KEYS"); - return continueCollectAccessPath(mapContainsKey.getArgument(0), context); - } - - @Override - public Void visitMapContainsValue(MapContainsValue mapContainsValue, CollectorContext context) { - context.accessPathBuilder.addPrefix("VALUES"); - return continueCollectAccessPath(mapContainsValue.getArgument(0), context); - } - - @Override - public Void visitArrayMap(ArrayMap arrayMap, CollectorContext context) { - // Lambda lambda = (Lambda) arrayMap.getArgument(0); - // Expression array = arrayMap.getArgument(1); - - // String arrayName = lambda.getLambdaArgumentName(0); - return super.visitArrayMap(arrayMap, context); - } - } - - private static class CollectorContext { - private StatementContext statementContext; - private AccessPathBuilder accessPathBuilder; - private boolean bottomFilter; - - public CollectorContext(StatementContext statementContext, boolean bottomFilter) { - this.statementContext = statementContext; - this.accessPathBuilder = new AccessPathBuilder(); - this.bottomFilter = bottomFilter; - } - } - - private static class AccessPathBuilder { - private LinkedList<String> accessPath; - - public AccessPathBuilder() { - accessPath = new LinkedList<>(); - } - - public AccessPathBuilder addPrefix(String prefix) { - accessPath.addFirst(prefix); - return this; - } - - public List<String> toStringList() { - return new ArrayList<>(accessPath); - } - - @Override - public String toString() { - return String.join(".", accessPath); - } - } - private static class DataTypeAccessTree { private DataType type; private boolean isRoot; @@ -488,33 +336,4 @@ public class NestedColumnCollector implements CustomRewriter { } } } - - private static class AccessPathIsPredicate { - private final Slot slot; - private final List<String> path; - private final 
boolean isPredicate; - - public AccessPathIsPredicate(Slot slot, List<String> path, boolean isPredicate) { - this.slot = slot; - this.path = path; - this.isPredicate = isPredicate; - } - - public Slot getSlot() { - return slot; - } - - public List<String> getPath() { - return path; - } - - public boolean isPredicate() { - return isPredicate; - } - - @Override - public String toString() { - return slot.getName() + ": " + String.join(".", path) + ", " + isPredicate; - } - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SlotTypeReplacer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SlotTypeReplacer.java new file mode 100644 index 00000000000..ae61878f036 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SlotTypeReplacer.java @@ -0,0 +1,541 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.analysis.AccessPathInfo; +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.properties.OrderKey; +import org.apache.doris.nereids.trees.expressions.ArrayItemReference; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.Function; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; +import org.apache.doris.nereids.trees.plans.logical.LogicalDeferMaterializeOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalDeferMaterializeTopN; +import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalExcept; +import org.apache.doris.nereids.trees.plans.logical.LogicalFileScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate; +import org.apache.doris.nereids.trees.plans.logical.LogicalIntersect; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; +import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import 
org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; +import org.apache.doris.nereids.trees.plans.logical.LogicalResultSink; +import org.apache.doris.nereids.trees.plans.logical.LogicalSink; +import org.apache.doris.nereids.trees.plans.logical.LogicalSort; +import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; +import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; +import org.apache.doris.nereids.trees.plans.logical.LogicalWindow; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.util.MoreFieldsThread; + +import com.google.common.collect.ImmutableCollection; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.ImmutableMultimap.Builder; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Maps; +import com.google.common.collect.Multimap; + +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + +/** SlotTypeReplacer */ +public class SlotTypeReplacer extends DefaultPlanRewriter<Void> { + private Map<Integer, AccessPathInfo> replacedDataTypes; + + public SlotTypeReplacer(Map<Integer, AccessPathInfo> bottomReplacedDataTypes) { + this.replacedDataTypes = Maps.newLinkedHashMap(bottomReplacedDataTypes); + } + + @Override + public Plan visitLogicalCTEProducer(LogicalCTEProducer<? extends Plan> cteProducer, Void context) { + return super.visitLogicalCTEProducer(cteProducer, context); + } + + @Override + public Plan visitLogicalWindow(LogicalWindow<? extends Plan> window, Void context) { + window = visitChildren(this, window, context); + Pair<Boolean, ? extends List<? 
extends Expression>> replaced = replaceExpressions( + window.getExpressions(), false, false); + if (replaced.first) { + return window.withExpressionsAndChild((List) replaced.second, window.child()); + } + return window; + } + + @Override + public Plan visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, Void context) { + Map<Slot, Slot> consumerToProducerOutputMap = cteConsumer.getConsumerToProducerOutputMap(); + Multimap<Slot, Slot> producerToConsumerOutputMap = cteConsumer.getProducerToConsumerOutputMap(); + + Map<Slot, Slot> replacedConsumerToProducerOutputMap = new LinkedHashMap<>(); + Builder<Slot, Slot> replacedProducerToConsumerOutputMap = ImmutableMultimap.builder(); + + boolean changed = false; + for (Entry<Slot, Slot> kv : consumerToProducerOutputMap.entrySet()) { + Slot consumerSlot = kv.getKey(); + Slot producerSlot = kv.getValue(); + AccessPathInfo accessPathInfo = replacedDataTypes.get(producerSlot.getExprId().asInt()); + if (accessPathInfo != null) { + DataType prunedType = accessPathInfo.getPrunedType(); + if (!prunedType.equals(producerSlot.getDataType())) { + replacedDataTypes.put(consumerSlot.getExprId().asInt(), accessPathInfo); + changed = true; + producerSlot = producerSlot.withNullableAndDataType(producerSlot.nullable(), prunedType); + consumerSlot = consumerSlot.withNullableAndDataType(consumerSlot.nullable(), prunedType); + } + } + replacedConsumerToProducerOutputMap.put(consumerSlot, producerSlot); + } + + for (Entry<Slot, Collection<Slot>> kv : producerToConsumerOutputMap.asMap().entrySet()) { + Slot producerSlot = kv.getKey(); + Collection<Slot> consumerSlots = kv.getValue(); + AccessPathInfo accessPathInfo = replacedDataTypes.get(producerSlot.getExprId().asInt()); + if (accessPathInfo != null && !accessPathInfo.getPrunedType().equals(producerSlot.getDataType())) { + DataType replacedDataType = accessPathInfo.getPrunedType(); + changed = true; + producerSlot = producerSlot.withNullableAndDataType(producerSlot.nullable(), 
replacedDataType); + for (Slot consumerSlot : consumerSlots) { + consumerSlot = consumerSlot.withNullableAndDataType(consumerSlot.nullable(), replacedDataType); + replacedProducerToConsumerOutputMap.put(producerSlot, consumerSlot); + } + } else { + replacedProducerToConsumerOutputMap.putAll(producerSlot, consumerSlots); + } + } + + if (changed) { + return new LogicalCTEConsumer( + cteConsumer.getRelationId(), cteConsumer.getCteId(), cteConsumer.getName(), + replacedConsumerToProducerOutputMap, replacedProducerToConsumerOutputMap.build() + ); + } + return cteConsumer; + } + + @Override + public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Void context) { + join = visitChildren(this, join, context); + Pair<Boolean, List<Expression>> replacedHashJoinConjuncts + = replaceExpressions(join.getHashJoinConjuncts(), false, false); + Pair<Boolean, List<Expression>> replacedOtherJoinConjuncts + = replaceExpressions(join.getOtherJoinConjuncts(), false, false); + + if (replacedHashJoinConjuncts.first || replacedOtherJoinConjuncts.first) { + return join.withJoinConjuncts( + replacedHashJoinConjuncts.second, + replacedOtherJoinConjuncts.second, + join.getJoinReorderContext()); + } + return join; + } + + @Override + public Plan visitLogicalProject(LogicalProject<? extends Plan> project, Void context) { + project = visitChildren(this, project, context); + + Pair<Boolean, List<NamedExpression>> projects = replaceExpressions(project.getProjects(), true, false); + if (projects.first) { + return project.withProjects(projects.second); + } + return project; + } + + @Override + public Plan visitLogicalPartitionTopN(LogicalPartitionTopN<? 
extends Plan> partitionTopN, Void context) { + partitionTopN = visitChildren(this, partitionTopN, context); + + Pair<Boolean, List<Expression>> replacedPartitionKeys = replaceExpressions( + partitionTopN.getPartitionKeys(), false, false); + Pair<Boolean, List<OrderExpression>> replacedOrderExpressions + = replaceOrderExpressions(partitionTopN.getOrderKeys()); + if (replacedPartitionKeys.first || replacedOrderExpressions.first) { + return partitionTopN.withPartitionKeysAndOrderKeys( + replacedPartitionKeys.second, replacedOrderExpressions.second); + } + return partitionTopN; + } + + @Override + public Plan visitLogicalDeferMaterializeTopN(LogicalDeferMaterializeTopN<? extends Plan> topN, Void context) { + topN = visitChildren(this, topN, context); + + LogicalTopN logicalTopN = (LogicalTopN) topN.getLogicalTopN().accept(this, context); + if (logicalTopN != topN.getLogicalTopN()) { + SlotReference replacedColumnIdSlot = replaceExpressions( + ImmutableList.of(topN.getColumnIdSlot()), false, false).second.get(0); + return new LogicalDeferMaterializeTopN( + logicalTopN, topN.getDeferMaterializeSlotIds(), replacedColumnIdSlot); + } + + return topN; + } + + @Override + public Plan visitLogicalExcept(LogicalExcept except, Void context) { + except = visitChildren(this, except, context); + + Pair<Boolean, List<List<SlotReference>>> replacedRegularChildrenOutputs = replaceMultiExpressions( + except.getRegularChildrenOutputs()); + + Pair<Boolean, List<NamedExpression>> replacedOutputs + = replaceExpressions(except.getOutputs(), true, false); + + if (replacedRegularChildrenOutputs.first || replacedOutputs.first) { + return new LogicalExcept(except.getQualifier(), replacedOutputs.second, + replacedRegularChildrenOutputs.second, except.children()); /* type-replaced slots, as in visitLogicalUnion */ + } + + return except; + } + + @Override + public Plan visitLogicalIntersect(LogicalIntersect intersect, Void context) { + intersect = visitChildren(this, intersect, context); + + Pair<Boolean, List<List<SlotReference>>> 
replacedRegularChildrenOutputs = replaceMultiExpressions( + intersect.getRegularChildrenOutputs()); + + Pair<Boolean, List<NamedExpression>> replacedOutputs + = replaceExpressions(intersect.getOutputs(), true, false); + + if (replacedRegularChildrenOutputs.first || replacedOutputs.first) { + return new LogicalIntersect(intersect.getQualifier(), replacedOutputs.second, + replacedRegularChildrenOutputs.second, intersect.children()); /* type-replaced slots, as in visitLogicalUnion */ + } + return intersect; + } + + @Override + public Plan visitLogicalUnion(LogicalUnion union, Void context) { + union = visitChildren(this, union, context); + + Pair<Boolean, List<List<SlotReference>>> replacedRegularChildrenOutputs = replaceMultiExpressions( + union.getRegularChildrenOutputs()); + + Pair<Boolean, List<NamedExpression>> replacedOutputs + = replaceExpressions(union.getOutputs(), true, false); + + if (replacedRegularChildrenOutputs.first || replacedOutputs.first) { + return new LogicalUnion( + union.getQualifier(), + replacedOutputs.second, + replacedRegularChildrenOutputs.second, + union.getConstantExprsList(), + union.hasPushedFilter(), + union.children() + ); + } + + return union; + } + + @Override + public Plan visitLogicalRepeat(LogicalRepeat<? extends Plan> repeat, Void context) { + repeat = visitChildren(this, repeat, context); + + Pair<Boolean, List<List<Expression>>> replacedGroupingSets + = replaceMultiExpressions(repeat.getGroupingSets()); + Pair<Boolean, List<NamedExpression>> replacedOutputs + = replaceExpressions(repeat.getOutputExpressions(), true, false); + + if (replacedGroupingSets.first || replacedOutputs.first) { + return repeat.withGroupSetsAndOutput(replacedGroupingSets.second, replacedOutputs.second); + } + return repeat; + } + + @Override + public Plan visitLogicalGenerate(LogicalGenerate<? 
extends Plan> generate, Void context) { + generate = visitChildren(this, generate, context); + + Pair<Boolean, List<Function>> replacedGenerators + = replaceExpressions(generate.getGenerators(), false, false); + Pair<Boolean, List<Slot>> replacedGeneratorOutput + = replaceExpressions(generate.getGeneratorOutput(), false, false); + if (replacedGenerators.first || replacedGeneratorOutput.first) { + return new LogicalGenerate<>(replacedGenerators.second, replacedGeneratorOutput.second, + generate.getExpandColumnAlias(), generate.child()); + } + return generate; + } + + @Override + public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Void context) { + aggregate = visitChildren(this, aggregate, context); + + Pair<Boolean, List<Expression>> replacedGroupBy = replaceExpressions( + aggregate.getGroupByExpressions(), false, false); + Pair<Boolean, List<NamedExpression>> replacedOutput = replaceExpressions( + aggregate.getOutputExpressions(), true, false); + + if (replacedGroupBy.first || replacedOutput.first) { + return aggregate.withGroupByAndOutput(replacedGroupBy.second, replacedOutput.second); + } + return aggregate; + } + + @Override + public Plan visitLogicalSort(LogicalSort<? extends Plan> sort, Void context) { + sort = visitChildren(this, sort, context); + + Pair<Boolean, List<OrderKey>> replaced = replaceOrderKeys(sort.getOrderKeys()); + if (replaced.first) { + return sort.withOrderKeys(replaced.second); + } + return sort; + } + + @Override + public Plan visitLogicalTopN(LogicalTopN<? 
extends Plan> topN, Void context) { + topN = visitChildren(this, topN, context); + + Pair<Boolean, List<OrderKey>> replaced = replaceOrderKeys(topN.getOrderKeys()); + if (replaced.first) { + return topN.withOrderKeys(replaced.second); + } + return topN; + } + + @Override + public Plan visitLogicalDeferMaterializeOlapScan( + LogicalDeferMaterializeOlapScan deferMaterializeOlapScan, Void context) { + + LogicalOlapScan logicalOlapScan + = (LogicalOlapScan) deferMaterializeOlapScan.getLogicalOlapScan().accept(this, context); + + if (logicalOlapScan != deferMaterializeOlapScan.getLogicalOlapScan()) { + SlotReference replacedColumnIdSlot = replaceExpressions( + ImmutableList.of(deferMaterializeOlapScan.getColumnIdSlot()), false, false).second.get(0); + return new LogicalDeferMaterializeOlapScan( + logicalOlapScan, deferMaterializeOlapScan.getDeferMaterializeSlotIds(), replacedColumnIdSlot + ); + } + return deferMaterializeOlapScan; + } + + @Override + public Plan visitLogicalFilter(LogicalFilter<? 
extends Plan> filter, Void context) { + filter = visitChildren(this, filter, context); + + Pair<Boolean, Set<Expression>> replaced = replaceExpressions(filter.getConjuncts(), false, false); + if (replaced.first) { + return filter.withConjuncts(replaced.second); + } + return filter; + } + + @Override + public Plan visitLogicalFileScan(LogicalFileScan fileScan, Void context) { + Pair<Boolean, List<Slot>> replaced = replaceExpressions(fileScan.getOutput(), false, true); + if (replaced.first) { + return fileScan.withCachedOutput(replaced.second); + } + return fileScan; + } + + @Override + public Plan visitLogicalOlapScan(LogicalOlapScan olapScan, Void context) { + Pair<Boolean, List<Slot>> replaced = replaceExpressions(olapScan.getOutput(), false, true); + if (replaced.first) { + return olapScan.withPrunedTypeSlots(replaced.second); + } + return olapScan; + } + + @Override + public Plan visitLogicalEmptyRelation(LogicalEmptyRelation emptyRelation, Void context) { + Pair<Boolean, List<NamedExpression>> replacedProjects + = replaceExpressions(emptyRelation.getProjects(), true, false); + + if (replacedProjects.first) { + return emptyRelation.withProjects(replacedProjects.second); + } + return emptyRelation; + } + + @Override + public Plan visitLogicalOneRowRelation(LogicalOneRowRelation oneRowRelation, Void context) { + Pair<Boolean, List<NamedExpression>> replacedProjects + = replaceExpressions(oneRowRelation.getProjects(), true, false); + + if (replacedProjects.first) { + return oneRowRelation.withProjects(replacedProjects.second); + } + return oneRowRelation; + } + + @Override + public Plan visitLogicalResultSink(LogicalResultSink<? 
extends Plan> logicalResultSink, Void context) { + logicalResultSink = visitChildren(this, logicalResultSink, context); + + Pair<Boolean, List<NamedExpression>> replacedOutput = replaceExpressions(logicalResultSink.getOutputExprs(), + false, false); + if (replacedOutput.first) { + return logicalResultSink.withOutputExprs(replacedOutput.second); + } + return logicalResultSink; + } + + @Override + public Plan visitLogicalSink(LogicalSink<? extends Plan> logicalSink, Void context) { + // do nothing + return logicalSink; + } + + private Pair<Boolean, List<OrderExpression>> replaceOrderExpressions(List<OrderExpression> orderExpressions) { + ImmutableList.Builder<OrderExpression> newOrderKeys + = ImmutableList.builderWithExpectedSize(orderExpressions.size()); + boolean changed = false; + for (OrderExpression orderExpression : orderExpressions) { + Expression newOrderKeyExpr = replaceSlot(orderExpression.getOrderKey().getExpr(), false); + if (newOrderKeyExpr != orderExpression.getOrderKey().getExpr()) { + newOrderKeys.add(new OrderExpression(orderExpression.getOrderKey().withExpression(newOrderKeyExpr))); + changed = true; + } else { + newOrderKeys.add(orderExpression); + } + } + return Pair.of(changed, newOrderKeys.build()); + } + + private Pair<Boolean, List<OrderKey>> replaceOrderKeys(List<OrderKey> orderKeys) { + ImmutableList.Builder<OrderKey> newOrderKeys = ImmutableList.builderWithExpectedSize(orderKeys.size()); + boolean changed = false; + for (OrderKey orderKey : orderKeys) { + Expression newOrderKeyExpr = replaceSlot(orderKey.getExpr(), false); + if (newOrderKeyExpr != orderKey.getExpr()) { + newOrderKeys.add(orderKey.withExpression(newOrderKeyExpr)); + changed = true; + } else { + newOrderKeys.add(orderKey); + } + } + return Pair.of(changed, newOrderKeys.build()); + } + + private <C extends Collection<E>, E extends Expression> + Pair<Boolean, List<C>> replaceMultiExpressions(List<C> expressionsList) { + ImmutableList.Builder<C> result = 
ImmutableList.builderWithExpectedSize(expressionsList.size()); + boolean changed = false; + for (C expressions : expressionsList) { + Pair<Boolean, C> replaced = replaceExpressions(expressions, false, false); + changed |= replaced.first; + result.add(replaced.second); + } + return Pair.of(changed, result.build()); + } + + private <C extends Collection<E>, E extends Expression> Pair<Boolean, C> replaceExpressions( + C expressions, boolean propagateType, boolean fillAccessPaths) { + ImmutableCollection.Builder<E> newExprs; + if (expressions instanceof List) { + newExprs = ImmutableList.builder(); + } else { + newExprs = ImmutableSet.builder(); + } + + boolean changed = false; + for (Expression oldExpr : expressions) { + Expression newExpr = replaceSlot(oldExpr, fillAccessPaths); + if (newExpr != oldExpr) { + newExprs.add((E) newExpr); + changed = true; + + if (propagateType && oldExpr instanceof NamedExpression + && !oldExpr.getDataType().equals(newExpr.getDataType())) { + replacedDataTypes.put( + ((NamedExpression) oldExpr).getExprId().asInt(), + // no need for access paths in the upper slots + new AccessPathInfo(newExpr.getDataType(), null, null) + ); + } + } else { + newExprs.add((E) oldExpr); + } + } + return Pair.of(changed, (C) newExprs.build()); + } + + private Expression replaceSlot(Expression expr, boolean fillAccessPath) { + return MoreFieldsThread.keepFunctionSignature(false, () -> { + return expr.rewriteUp(e -> { + if (e instanceof Lambda) { + return rewriteLambda((Lambda) e, fillAccessPath); + } else if (e instanceof SlotReference) { + AccessPathInfo accessPathInfo = replacedDataTypes.get(((SlotReference) e).getExprId().asInt()); + if (accessPathInfo != null) { + SlotReference newSlot + = (SlotReference) ((SlotReference) e).withNullableAndDataType( + e.nullable(), accessPathInfo.getPrunedType()); + if (fillAccessPath) { + newSlot = newSlot.withAccessPaths( + accessPathInfo.getAllAccessPaths(), accessPathInfo.getPredicateAccessPaths() + ); + } + return 
newSlot; + } + } + return e; + }); + }); + } + + private Expression rewriteLambda(Lambda e, boolean fillAccessPath) { + // we should rewrite ArrayItemReference first, then we can replace the ArrayItemSlot in the lambda + Expression[] newChildren = new Expression[e.arity()]; + for (int i = 0; i < e.arity(); i++) { + Expression child = e.child(i); + if (child instanceof ArrayItemReference) { + Expression newRef = child.withChildren(replaceSlot(child.child(0), fillAccessPath)); + replacedDataTypes.put(((ArrayItemReference) child).getExprId().asInt(), + new AccessPathInfo(newRef.getDataType(), null, null)); + newChildren[i] = newRef; + } else { + newChildren[i] = child; + } + } + + for (int i = 0; i < newChildren.length; i++) { + Expression child = newChildren[i]; + if (!(child instanceof ArrayItemReference)) { + newChildren[i] = replaceSlot(child, fillAccessPath); + } + } + + return e.withChildren(newChildren); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSubPathPruning.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSubPathPruning.java index 1ed8d151270..abb65a283ed 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSubPathPruning.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSubPathPruning.java @@ -22,7 +22,6 @@ import org.apache.doris.common.util.DebugUtil; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.properties.OrderKey; -import org.apache.doris.nereids.rules.rewrite.ColumnPruning.PruneContext; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -87,7 +86,7 @@ import java.util.Set; * generating the slots for the required sub path on scan, union, and cte consumer. 
* Then, it replaces the element_at with the corresponding slot. */ -public class VariantSubPathPruning extends DefaultPlanRewriter<PruneContext> implements CustomRewriter { +public class VariantSubPathPruning implements CustomRewriter { public static final Logger LOG = LogManager.getLogger(VariantSubPathPruning.class); @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java index ce2d3dbb3e5..ad09b4f870a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.exceptions.UnboundException; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.util.Utils; +import org.apache.doris.thrift.TColumnAccessPaths; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -56,6 +57,8 @@ public class SlotReference extends Slot { // that need return original table and name for view not its original table if u query a view private final TableIf oneLevelTable; private final Column oneLevelColumn; + private final Optional<TColumnAccessPaths> allAccessPaths; + private final Optional<TColumnAccessPaths> predicateAccessPaths; public SlotReference(String name, DataType dataType) { this(StatementScopeIdGenerator.newExprId(), name, dataType, true, ImmutableList.of(), @@ -92,6 +95,14 @@ public class SlotReference extends Slot { subPath, Optional.empty()); } + public SlotReference(ExprId exprId, Supplier<String> name, DataType dataType, boolean nullable, + List<String> qualifier, @Nullable TableIf originalTable, @Nullable Column originalColumn, + @Nullable TableIf oneLevelTable, Column oneLevelColumn, + List<String> subPath, 
Optional<Pair<Integer, Integer>> indexInSql) { + this(exprId, name, dataType, nullable, qualifier, originalTable, originalColumn, oneLevelTable, + oneLevelColumn, subPath, indexInSql, Optional.empty(), Optional.empty()); + } + /** * Constructor for SlotReference. * @@ -106,7 +117,8 @@ public class SlotReference extends Slot { public SlotReference(ExprId exprId, Supplier<String> name, DataType dataType, boolean nullable, List<String> qualifier, @Nullable TableIf originalTable, @Nullable Column originalColumn, @Nullable TableIf oneLevelTable, Column oneLevelColumn, - List<String> subPath, Optional<Pair<Integer, Integer>> indexInSql) { + List<String> subPath, Optional<Pair<Integer, Integer>> indexInSql, + Optional<TColumnAccessPaths> allAccessPaths, Optional<TColumnAccessPaths> predicateAccessPaths) { super(indexInSql); this.exprId = exprId; this.name = name; @@ -119,6 +131,8 @@ public class SlotReference extends Slot { this.oneLevelTable = oneLevelTable; this.oneLevelColumn = oneLevelColumn; this.subPath = Objects.requireNonNull(subPath, "subPath can not be null"); + this.allAccessPaths = allAccessPaths; + this.predicateAccessPaths = predicateAccessPaths; } public static SlotReference of(String name, DataType type) { @@ -342,4 +356,18 @@ public class SlotReference extends Slot { public boolean hasAutoInc() { return originalColumn != null ? 
originalColumn.isAutoInc() : false; } + + public SlotReference withAccessPaths(TColumnAccessPaths allAccessPaths, TColumnAccessPaths predicateAccessPaths) { + return new SlotReference(exprId, name, dataType, nullable, qualifier, + originalTable, originalColumn, oneLevelTable, oneLevelColumn, + subPath, indexInSqlString, Optional.of(allAccessPaths), Optional.of(predicateAccessPaths)); + } + + public Optional<TColumnAccessPaths> getAllAccessPaths() { + return allAccessPaths; + } + + public Optional<TColumnAccessPaths> getPredicateAccessPaths() { + return predicateAccessPaths; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirst.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirst.java index 8c0babc39ca..5410de371a7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirst.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirst.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import java.util.List; @@ -51,4 +52,9 @@ public class ArrayFirst extends ElementAt public List<FunctionSignature> getImplSignature() { return SIGNATURES; } + + @Override + public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) { + return visitor.visitArrayFirst(this, context); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayLast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayLast.java index e1ed4f4d27b..b9f5650156f 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayLast.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayLast.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import java.util.List; @@ -51,4 +52,9 @@ public class ArrayLast extends ElementAt public ElementAt withChildren(List<Expression> children) { return new ArrayLast(getFunctionParams(children)); } + + @Override + public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) { + return visitor.visitArrayLast(this, context); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java index 37b6b6233bd..f32e91edd23 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java @@ -54,10 +54,12 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayEnumerat import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExcept; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExists; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFilter; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFirst; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFirstIndex; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFlatten; import 
org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayIntersect; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayJoin; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayLast; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayLastIndex; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMap; import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMatchAll; @@ -624,6 +626,10 @@ public interface ScalarFunctionVisitor<R, C> { return visitScalarFunction(arrayFilter, context); } + default R visitArrayFirst(ArrayFirst arrayFirst, C context) { + return visitElementAt(arrayFirst, context); + } + default R visitArrayFirstIndex(ArrayFirstIndex arrayFirstIndex, C context) { return visitScalarFunction(arrayFirstIndex, context); } @@ -636,6 +642,10 @@ public interface ScalarFunctionVisitor<R, C> { return visitScalarFunction(arrayJoin, context); } + default R visitArrayLast(ArrayLast arrayLast, C context) { + return visitElementAt(arrayLast, context); + } + default R visitArrayLastIndex(ArrayLastIndex arrayLastIndex, C context) { return visitScalarFunction(arrayLastIndex, context); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFileScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFileScan.java index 4167e68856d..5d0f41f3818 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFileScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFileScan.java @@ -51,16 +51,17 @@ public class LogicalFileScan extends LogicalCatalogRelation { protected final Optional<TableSample> tableSample; protected final Optional<TableSnapshot> tableSnapshot; protected final Optional<TableScanParams> scanParams; + protected final Optional<List<Slot>> cachedOutputs; public LogicalFileScan(RelationId id, ExternalTable table, List<String> 
qualifier, Collection<Slot> operativeSlots, Optional<TableSample> tableSample, Optional<TableSnapshot> tableSnapshot, - Optional<TableScanParams> scanParams) { + Optional<TableScanParams> scanParams, Optional<List<Slot>> cachedOutputs) { this(id, table, qualifier, table.initSelectedPartitions(MvccUtil.getSnapshotFromContext(table)), operativeSlots, ImmutableList.of(), tableSample, tableSnapshot, - scanParams, Optional.empty(), Optional.empty()); + scanParams, Optional.empty(), Optional.empty(), cachedOutputs); } /** @@ -70,13 +71,15 @@ public class LogicalFileScan extends LogicalCatalogRelation { SelectedPartitions selectedPartitions, Collection<Slot> operativeSlots, List<NamedExpression> virtualColumns, Optional<TableSample> tableSample, Optional<TableSnapshot> tableSnapshot, Optional<TableScanParams> scanParams, - Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties) { + Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties, + Optional<List<Slot>> cachedSlots) { super(id, PlanType.LOGICAL_FILE_SCAN, table, qualifier, operativeSlots, virtualColumns, groupExpression, logicalProperties); this.selectedPartitions = selectedPartitions; this.tableSample = tableSample; this.tableSnapshot = tableSnapshot; this.scanParams = scanParams; + this.cachedOutputs = cachedSlots; } public SelectedPartitions getSelectedPartitions() { @@ -116,7 +119,7 @@ public class LogicalFileScan extends LogicalCatalogRelation { public LogicalFileScan withGroupExpression(Optional<GroupExpression> groupExpression) { return new LogicalFileScan(relationId, (ExternalTable) table, qualifier, selectedPartitions, operativeSlots, virtualColumns, tableSample, tableSnapshot, - scanParams, groupExpression, Optional.of(getLogicalProperties())); + scanParams, groupExpression, Optional.of(getLogicalProperties()), cachedOutputs); } @Override @@ -124,20 +127,20 @@ public class LogicalFileScan extends LogicalCatalogRelation { 
Optional<LogicalProperties> logicalProperties, List<Plan> children) { return new LogicalFileScan(relationId, (ExternalTable) table, qualifier, selectedPartitions, operativeSlots, virtualColumns, tableSample, tableSnapshot, - scanParams, groupExpression, logicalProperties); + scanParams, groupExpression, logicalProperties, cachedOutputs); } public LogicalFileScan withSelectedPartitions(SelectedPartitions selectedPartitions) { return new LogicalFileScan(relationId, (ExternalTable) table, qualifier, selectedPartitions, operativeSlots, virtualColumns, tableSample, tableSnapshot, - scanParams, Optional.empty(), Optional.of(getLogicalProperties())); + scanParams, Optional.empty(), Optional.of(getLogicalProperties()), cachedOutputs); } @Override public LogicalFileScan withRelationId(RelationId relationId) { return new LogicalFileScan(relationId, (ExternalTable) table, qualifier, selectedPartitions, operativeSlots, virtualColumns, tableSample, tableSnapshot, - scanParams, Optional.empty(), Optional.empty()); + scanParams, Optional.empty(), Optional.empty(), cachedOutputs); } @Override @@ -150,6 +153,19 @@ public class LogicalFileScan extends LogicalCatalogRelation { return super.equals(o) && Objects.equals(selectedPartitions, ((LogicalFileScan) o).selectedPartitions); } + @Override + public List<Slot> computeOutput() { + if (cachedOutputs.isPresent()) { + return cachedOutputs.get(); + } + return super.computeOutput(); + } + + @Override + public List<Slot> computeAsteriskOutput() { + return super.computeAsteriskOutput(); + } + /** * SelectedPartitions contains the selected partitions and the total partition number. * Mainly for hive table partition pruning. 
@@ -207,7 +223,13 @@ public class LogicalFileScan extends LogicalCatalogRelation { public LogicalFileScan withOperativeSlots(Collection<Slot> operativeSlots) { return new LogicalFileScan(relationId, (ExternalTable) table, qualifier, selectedPartitions, operativeSlots, virtualColumns, tableSample, tableSnapshot, - scanParams, groupExpression, Optional.of(getLogicalProperties())); + scanParams, groupExpression, Optional.of(getLogicalProperties()), cachedOutputs); + } + + public LogicalFileScan withCachedOutput(List<Slot> cachedOutputs) { + return new LogicalFileScan(relationId, (ExternalTable) table, qualifier, + selectedPartitions, operativeSlots, virtualColumns, tableSample, tableSnapshot, + scanParams, groupExpression, Optional.empty(), Optional.of(cachedOutputs)); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHudiScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHudiScan.java index 9a123a04d2b..134f48309d7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHudiScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHudiScan.java @@ -77,9 +77,10 @@ public class LogicalHudiScan extends LogicalFileScan { Collection<Slot> operativeSlots, List<NamedExpression> virtualColumns, Optional<GroupExpression> groupExpression, - Optional<LogicalProperties> logicalProperties) { + Optional<LogicalProperties> logicalProperties, + Optional<List<Slot>> cachedOutputs) { super(id, table, qualifier, selectedPartitions, operativeSlots, virtualColumns, - tableSample, tableSnapshot, scanParams, groupExpression, logicalProperties); + tableSample, tableSnapshot, scanParams, groupExpression, logicalProperties, cachedOutputs); Objects.requireNonNull(scanParams, "scanParams should not null"); Objects.requireNonNull(incrementalRelation, "incrementalRelation should not null"); this.incrementalRelation = incrementalRelation; @@ 
-87,10 +88,11 @@ public class LogicalHudiScan extends LogicalFileScan { public LogicalHudiScan(RelationId id, ExternalTable table, List<String> qualifier, Collection<Slot> operativeSlots, Optional<TableScanParams> scanParams, - Optional<TableSample> tableSample, Optional<TableSnapshot> tableSnapshot) { + Optional<TableSample> tableSample, Optional<TableSnapshot> tableSnapshot, + Optional<List<Slot>> cachedOutputs) { this(id, table, qualifier, ((HMSExternalTable) table).initHudiSelectedPartitions(tableSnapshot), tableSample, tableSnapshot, scanParams, Optional.empty(), operativeSlots, ImmutableList.of(), - Optional.empty(), Optional.empty()); + Optional.empty(), Optional.empty(), cachedOutputs); } public Optional<TableScanParams> getScanParams() { @@ -140,7 +142,7 @@ public class LogicalHudiScan extends LogicalFileScan { public LogicalHudiScan withGroupExpression(Optional<GroupExpression> groupExpression) { return new LogicalHudiScan(relationId, (ExternalTable) table, qualifier, selectedPartitions, tableSample, tableSnapshot, scanParams, incrementalRelation, - operativeSlots, virtualColumns, groupExpression, Optional.of(getLogicalProperties())); + operativeSlots, virtualColumns, groupExpression, Optional.of(getLogicalProperties()), cachedOutputs); } @Override @@ -148,20 +150,20 @@ public class LogicalHudiScan extends LogicalFileScan { Optional<LogicalProperties> logicalProperties, List<Plan> children) { return new LogicalHudiScan(relationId, (ExternalTable) table, qualifier, selectedPartitions, tableSample, tableSnapshot, scanParams, incrementalRelation, - operativeSlots, virtualColumns, groupExpression, logicalProperties); + operativeSlots, virtualColumns, groupExpression, logicalProperties, cachedOutputs); } public LogicalHudiScan withSelectedPartitions(SelectedPartitions selectedPartitions) { return new LogicalHudiScan(relationId, (ExternalTable) table, qualifier, selectedPartitions, tableSample, tableSnapshot, scanParams, incrementalRelation, - operativeSlots, 
virtualColumns, groupExpression, Optional.of(getLogicalProperties())); + operativeSlots, virtualColumns, groupExpression, Optional.of(getLogicalProperties()), cachedOutputs); } @Override public LogicalHudiScan withRelationId(RelationId relationId) { return new LogicalHudiScan(relationId, (ExternalTable) table, qualifier, selectedPartitions, tableSample, tableSnapshot, scanParams, incrementalRelation, - operativeSlots, virtualColumns, groupExpression, Optional.of(getLogicalProperties())); + operativeSlots, virtualColumns, groupExpression, Optional.of(getLogicalProperties()), cachedOutputs); } @Override @@ -173,7 +175,7 @@ public class LogicalHudiScan extends LogicalFileScan { public LogicalFileScan withOperativeSlots(Collection<Slot> operativeSlots) { return new LogicalHudiScan(relationId, (ExternalTable) table, qualifier, selectedPartitions, tableSample, tableSnapshot, scanParams, incrementalRelation, - operativeSlots, virtualColumns, groupExpression, Optional.of(getLogicalProperties())); + operativeSlots, virtualColumns, groupExpression, Optional.of(getLogicalProperties()), cachedOutputs); } /** @@ -226,6 +228,6 @@ public class LogicalHudiScan extends LogicalFileScan { } return new LogicalHudiScan(relationId, (ExternalTable) table, qualifier, selectedPartitions, tableSample, tableSnapshot, scanParams, newIncrementalRelation, - operativeSlots, virtualColumns, groupExpression, Optional.of(getLogicalProperties())); + operativeSlots, virtualColumns, groupExpression, Optional.of(getLogicalProperties()), cachedOutputs); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java index 8df1cd2f6a4..33d8571026e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalOlapScan.java @@ -840,4 +840,20 @@ 
public class LogicalOlapScan extends LogicalCatalogRelation implements OlapScan } return replaceMap; } + + /** withPrunedTypeSlots */ + public LogicalOlapScan withPrunedTypeSlots(List<Slot> outputSlots) { + Map<Pair<Long, String>, Slot> replaceSlotMap = new HashMap<>(); + for (Slot outputSlot : outputSlots) { + Pair<Long, String> key = Pair.of(selectedIndexId, outputSlot.getName()); + replaceSlotMap.put(key, outputSlot); + } + + return new LogicalOlapScan(relationId, (Table) table, qualifier, + Optional.empty(), Optional.empty(), + selectedPartitionIds, false, selectedTabletIds, + selectedIndexId, indexSelected, preAggStatus, manuallySpecifiedPartitions, + hints, replaceSlotMap, tableSample, directMvScan, colToSubPathsMap, selectedTabletIds, + operativeSlots, virtualColumns, scoreOrderKeys, scoreLimit, annOrderKeys, annLimit); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java index 4cc3bfccf3e..7a5ff6d72d3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java @@ -36,7 +36,7 @@ import java.util.stream.Collectors; * Example: VARIANT <`a.b`:INT, `a.c`:DATETIMEV2> * */ -public class VariantType extends PrimitiveType implements NestedColumnPrunable { +public class VariantType extends PrimitiveType { public static final VariantType INSTANCE = new VariantType(0); diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java index ca3edc13c50..bd7b006f3ea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/OlapScanNode.java @@ -1088,6 +1088,9 @@ public class OlapScanNode extends ScanNode { output.append(prefix).append("rewrittenProjectList: ").append( 
getExplainString(rewrittenProjectList)).append("\n"); } + + printNestedColumns(output, prefix); + return output.toString(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/ScanNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/ScanNode.java index 4a9a1945d9b..e438b0984d1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/planner/ScanNode.java +++ b/fe/fe-core/src/main/java/org/apache/doris/planner/ScanNode.java @@ -66,6 +66,7 @@ import com.google.common.collect.RangeSet; import com.google.common.collect.Sets; import com.google.common.collect.TreeRangeSet; import org.apache.commons.collections.map.CaseInsensitiveMap; +import org.apache.commons.lang3.StringUtils; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; @@ -557,6 +558,55 @@ public abstract class ScanNode extends PlanNode implements SplitGenerator { .addValue(super.debugString()).toString(); } + protected void printNestedColumns(StringBuilder output, String prefix) { + boolean printNestedColumnsHeader = true; + for (SlotDescriptor slot : getTupleDesc().getSlots()) { + String prunedType = null; + if (!slot.getType().equals(slot.getColumn().getType())) { + prunedType = slot.getType().toString(); + } + + String allAccessPathsString = null; + if (slot.getAllAccessPaths() != null + && slot.getAllAccessPaths().name_access_paths != null + && !slot.getAllAccessPaths().name_access_paths.isEmpty()) { + allAccessPathsString = slot.getAllAccessPaths().name_access_paths + .stream() + .map(a -> StringUtils.join(a.path, ".")) + .collect(Collectors.joining(", ")); + } + String predicateAccessPathsString = null; + if (slot.getPredicateAccessPaths() != null + && slot.getPredicateAccessPaths().name_access_paths != null + && !slot.getPredicateAccessPaths().name_access_paths.isEmpty()) { + predicateAccessPathsString = slot.getPredicateAccessPaths().name_access_paths + .stream() + .map(a -> StringUtils.join(a.path, ".")) + .collect(Collectors.joining(", ")); + } + 
if (prunedType == null && allAccessPathsString == null && predicateAccessPathsString == null) { + continue; + } + + if (printNestedColumnsHeader) { + output.append(prefix).append("nested columns:\n"); + printNestedColumnsHeader = false; + } + output.append(prefix).append(" ").append(slot.getColumn().getName()).append(":\n"); + output.append(prefix).append(" origin type: ").append(slot.getColumn().getType()).append("\n"); + if (prunedType != null) { + output.append(prefix).append(" pruned type: ").append(prunedType).append("\n"); + } + if (allAccessPathsString != null) { + output.append(prefix).append(" all access paths: [").append(allAccessPathsString).append("]\n"); + } + if (predicateAccessPathsString != null) { + output.append(prefix).append(" predicate access paths: [") + .append(predicateAccessPathsString).append("]\n"); + } + } + } + public List<TupleId> getOutputTupleIds() { if (outputTupleDesc != null) { return Lists.newArrayList(outputTupleDesc.getId()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index 6381773d09f..26085fa5361 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -384,6 +384,8 @@ public class SessionVariable implements Serializable, Writable { public static final String ENABLE_RUNTIME_FILTER_PARTITION_PRUNE = "enable_runtime_filter_partition_prune"; + public static final String ENABLE_PRUNE_NESTED_COLUMN = "enable_prune_nested_column"; + static final String SESSION_CONTEXT = "session_context"; public static final String DEFAULT_ORDER_BY_LIMIT = "default_order_by_limit"; @@ -1511,6 +1513,13 @@ public class SessionVariable implements Serializable, Writable { varType = VariableAnnotation.EXPERIMENTAL) public int topNLazyMaterializationThreshold = 1024; + @VariableMgr.VarAttr(name = ENABLE_PRUNE_NESTED_COLUMN, needForward = true, + fuzzy = 
false, + varType = VariableAnnotation.EXPERIMENTAL, + description = {"是否裁剪map/struct类型", "Whether to prune the type of map/struct"} + ) + public boolean enablePruneNestedColumns = true; + public boolean enableTopnLazyMaterialization() { return ConnectContext.get() != null && ConnectContext.get().getSessionVariable().topNLazyMaterializationThreshold > 0; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumn.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java similarity index 64% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumn.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java index 7cccbf18496..feec310b4f7 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumn.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PruneNestedColumnTest.java @@ -19,8 +19,19 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.analysis.SlotDescriptor; import org.apache.doris.catalog.Type; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.NereidsPlanner; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.ArrayItemReference; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; +import org.apache.doris.nereids.trees.plans.physical.PhysicalUnion; +import org.apache.doris.nereids.types.DataType; import org.apache.doris.planner.OlapScanNode; import 
org.apache.doris.planner.PlanFragment; import org.apache.doris.thrift.TAccessPathType; @@ -33,10 +44,15 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import java.util.ArrayList; +import java.util.Collection; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; +import java.util.Map.Entry; import java.util.TreeSet; +import java.util.function.Consumer; -public class PruneNestedColumn extends TestWithFeService { +public class PruneNestedColumnTest extends TestWithFeService { @BeforeAll public void createTable() throws Exception { createDatabase("test"); @@ -54,6 +70,23 @@ public class PruneNestedColumn extends TestWithFeService { + "properties ('replication_num'='1')"); connectContext.getSessionVariable().setDisableNereidsRules(RuleType.PRUNE_EMPTY_PARTITION.name()); + connectContext.getSessionVariable().enableNereidsTimeout = false; + } + + @Test + public void testPruneArrayLambda() throws Exception { + // map_values(struct_element(s, 'data').*)[0].a + assertColumn("select struct_element(array_map(x -> map_values(x)[0], struct_element(s, 'data'))[0], 'a') from tbl", + "struct<data:array<map<int,struct<a:int>>>>", + ImmutableList.of(path("s", "data", "*", "VALUES", "a")), + ImmutableList.of() + ); + + assertColumn("select array_map((x, y) -> struct_element(map_values(x)[0], 'a') + struct_element(map_values(y)[0], 'b'), struct_element(s, 'data'), struct_element(s, 'data')) from tbl", + "struct<data:array<map<int,struct<a:int,b:double>>>>", + ImmutableList.of(path("s", "data", "*", "VALUES", "a"), path("s", "data", "*", "VALUES", "b")), + ImmutableList.of() + ); } @Test @@ -114,14 +147,6 @@ public class PruneNestedColumn extends TestWithFeService { ImmutableList.of(path("s", "data", "*", "*", "b")), ImmutableList.of() ); - // assertColumn("select struct_element(struct_element(s, 'data')[1][1], 'b') from tbl where struct_element(s, 'city')='beijing", - // "struct<data:array<map<int,struct<b:double>>>>", - // 
predicatePath("city"), - // path("data", "*", "*", "b") - // ); - - // assertColumn("select array_map(x -> x[2], struct_element(s, 'data')) from tbl", "struct<data:array<map<int,struct<a:int,b:double>>>>", path("data")); - // assertColumn("select array_map(x -> struct_element(x[2], 'b'), struct_element(s, 'data')) from tbl", "struct<data:array<map<int,struct<b:double>>>>", path("data", "*", "*", "b")); } @Test @@ -207,10 +232,57 @@ public class PruneNestedColumn extends TestWithFeService { ); } + @Test + public void testCte() throws Throwable { + assertColumn("with t as (select id, s from tbl) select struct_element(t1.s, 'city') from t t1 join t t2 on t1.id = t2.id", + "struct<city:text>", + ImmutableList.of(path("s", "city")), + ImmutableList.of() + ); + + assertColumn("with t as (select id, struct_element(s, 'city') as c from tbl) select t1.c from t t1 join t t2 on t1.id = t2.id", + "struct<city:text>", + ImmutableList.of(path("s", "city")), + ImmutableList.of() + ); + } + + @Test + public void testUnion() throws Throwable { + assertColumn("select struct_element(s, 'city') from (select s from tbl union all select null)a", + "struct<city:text,data:array<map<int,struct<a:int,b:double>>>>", + ImmutableList.of(path("s")), + ImmutableList.of() + ); + + assertColumn("select * from (select struct_element(s, 'city') from tbl union all select null)a", + "struct<city:text>", + ImmutableList.of(path("s", "city")), + ImmutableList.of() + ); + } + + @Test + public void testCteAndUnion() throws Throwable { + assertColumn("with t as (select id, s from tbl) select struct_element(s, 'city') from (select * from t union all select 1, null) tmp", + "struct<city:text,data:array<map<int,struct<a:int,b:double>>>>", + ImmutableList.of(path("s")), + ImmutableList.of() + ); + + assertColumn("with t as (select id, s from tbl) select * from (select struct_element(s, 'city') from t union all select null) tmp", + "struct<city:text>", + ImmutableList.of(path("s", "city")), + ImmutableList.of() 
+ ); + } + private void assertColumn(String sql, String expectType, List<TColumnNameAccessPath> expectAllAccessPaths, List<TColumnNameAccessPath> expectPredicateAccessPaths) throws Exception { - List<SlotDescriptor> slotDescriptors = collectComplexSlots(sql); + Pair<PhysicalPlan, List<SlotDescriptor>> result = collectComplexSlots(sql); + PhysicalPlan physicalPlan = result.first; + List<SlotDescriptor> slotDescriptors = result.second; if (expectType == null) { Assertions.assertEquals(0, slotDescriptors.size()); return; @@ -230,11 +302,57 @@ public class PruneNestedColumn extends TestWithFeService { TreeSet<TColumnNameAccessPath> actualPredicateAccessPaths = new TreeSet<>(slotDescriptors.get(0).getPredicateAccessPaths().name_access_paths); Assertions.assertEquals(expectPredicateAccessPathSet, actualPredicateAccessPaths); + + Map<Integer, DataType> slotIdToDataTypes = new LinkedHashMap<>(); + Consumer<Expression> assertHasSameType = e -> { + if (e instanceof NamedExpression) { + DataType dataType = slotIdToDataTypes.get(((NamedExpression) e).getExprId().asInt()); + if (dataType != null) { + Assertions.assertEquals(dataType, e.getDataType()); + } else { + slotIdToDataTypes.put(((NamedExpression) e).getExprId().asInt(), e.getDataType()); + } + } + }; + + // assert same slot id has same type + physicalPlan.foreachUp(plan -> { + List<? 
extends Expression> expressions = ((PhysicalPlan) plan).getExpressions(); + for (Expression expression : expressions) { + expression.foreach(e -> { + assertHasSameType.accept((Expression) e); + if (e instanceof Alias && e.child(0) instanceof Slot) { + assertHasSameType.accept((Alias) e); + } else if (e instanceof ArrayItemReference) { + assertHasSameType.accept((ArrayItemReference) e); + } + }); + } + + if (plan instanceof PhysicalCTEConsumer) { + for (Entry<Slot, Collection<Slot>> kv : ((PhysicalCTEConsumer) plan).getProducerToConsumerSlotMap() + .asMap().entrySet()) { + Slot producerSlot = kv.getKey(); + for (Slot consumerSlot : kv.getValue()) { + Assertions.assertEquals(producerSlot.getDataType(), consumerSlot.getDataType()); + } + } + } else if (plan instanceof PhysicalUnion) { + List<Slot> output = ((PhysicalUnion) plan).getOutput(); + for (List<SlotReference> regularChildrenOutput : ((PhysicalUnion) plan).getRegularChildrenOutputs()) { + Assertions.assertEquals(output.size(), regularChildrenOutput.size()); + for (int i = 0; i < output.size(); i++) { + Assertions.assertEquals(output.get(i).getDataType(), regularChildrenOutput.get(i).getDataType()); + } + } + } + }); } - private List<SlotDescriptor> collectComplexSlots(String sql) throws Exception { + private Pair<PhysicalPlan, List<SlotDescriptor>> collectComplexSlots(String sql) throws Exception { NereidsPlanner planner = (NereidsPlanner) getSqlStmtExecutor(sql).planner(); List<SlotDescriptor> complexSlots = new ArrayList<>(); + PhysicalPlan physicalPlan = planner.getPhysicalPlan(); for (PlanFragment fragment : planner.getFragments()) { List<OlapScanNode> olapScanNodes = fragment.getPlanRoot().collectInCurrentFragment(OlapScanNode.class::isInstance); for (OlapScanNode olapScanNode : olapScanNodes) { @@ -247,7 +365,7 @@ public class PruneNestedColumn extends TestWithFeService { } } } - return complexSlots; + return Pair.of(physicalPlan, complexSlots); } private TColumnNameAccessPath path(String... 
path) { diff --git a/gensrc/thrift/Descriptors.thrift b/gensrc/thrift/Descriptors.thrift index 5276f9df4cb..4eee5d9de50 100644 --- a/gensrc/thrift/Descriptors.thrift +++ b/gensrc/thrift/Descriptors.thrift @@ -27,6 +27,52 @@ enum TPatternType { MATCH_NAME_GLOB = 2 } +enum TAccessPathType { + NAME = 1, + // ICEBERG = 2 // implement in the future +} + +struct TColumnNameAccessPath { + // the specification of special path: + // <empty>: access the whole complex column + // *: + // 1. access every item when the type is array + // 2. access key and value when the type is map + // KEYS: only access the keys of map + // VALUES: only access the values of map + // + // example: + // s: struct< + // data: array< + // map< + // int, + // struct< + // a: int + // b: double + // > + // > + // > + // > + // if we want to access `map_keys(s.data[0])`, the path will be: ['s', 'data', '*', 'KEYS'], + // if we want to access `map_values(s.data[0])[0].b`, the path will be: ['s', 'data', '*', 'VALUES', 'b'], + // if we want to access `s.data[0]['k'].b`, the path will be ['s', 'data', '*', '*', 'b'] + // if we want to access the whole struct of s, the path will be: ['s'], + 1: required list<string> path +} + +/* +// implement in the future +struct TIcebergColumnAccessPath { + 1: required list<i64> path +} +*/ + +struct TColumnAccessPaths { + 1: required TAccessPathType type + 2: optional list<TColumnNameAccessPath> name_access_paths + // 3: optional list<TIcebergColumnAccessPath> iceberg_column_access_paths // implement in the future +} + struct TColumn { 1: required string column_name 2: required Types.TColumnType column_type @@ -81,52 +127,6 @@ struct TSlotDescriptor { 20: optional TColumnAccessPaths predicate_access_paths } -enum TAccessPathType { - NAME = 1, - // ICEBERG = 2 // implement in the future -} - -struct TColumnNameAccessPath { - // the specification of special path: - // <empty>: access the whole complex column - // *: - // 1. 
access every items when the type is array - // 2. access key and value when the type is map - // KEYS: only access the keys of map - // VALUES: only access the keys of map - // - // example: - // s: struct< - // data: array< - // map< - // int, - // struct< - // a: id - // b: double - // > - // > - // > - // > - // if we want to access `map_keys(s.data[0])`, the path will be: ['s', 'data', '*', 'KEYS'], - // if we want to access `map_values(s.data[0])[0].b`, the path will be: ['s', 'data', '*', 'VALUES', 'b'], - // if we want to access `s.data[0]['k'].b`, the path will be ['s', 'data', '*', '*', 'b'] - // if we want to access the whole struct of s, the path will be: ['s'], - 1: required list<string> path -} - -/* -// implement in the future -struct TIcebergColumnAccessPath { - 1: required list<i64> path -} -*/ - -struct TColumnAccessPaths { - 1: required TAccessPathType type - 2: optional list<TColumnNameAccessPath> name_access_paths - // 3: optional list<TIcebergColumnAccessPath> iceberg_column_access_paths // implement in the future -} - struct TTupleDescriptor { 1: required Types.TTupleId id 2: required i32 byteSize // deprecated --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
