This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.1 by this push:
     new 16fcdcd4b7d [fix](Nereids) not do distinct when aggregate with distinct project (#36057)
16fcdcd4b7d is described below

commit 16fcdcd4b7db81a2290d53bfc21fcf09d2a11356
Author: morrySnow <101034200+morrys...@users.noreply.github.com>
AuthorDate: Sat Jun 8 09:04:56 2024 +0800

    [fix](Nereids) not do distinct when aggregate with distinct project (#36057)
    
    pick from master #35899
---
 .../glue/translator/PhysicalPlanTranslator.java    |  2 +-
 .../doris/nereids/parser/LogicalPlanBuilder.java   |  9 ++-
 .../org/apache/doris/nereids/rules/RuleType.java   |  2 +
 .../nereids/rules/analysis/BindExpression.java     | 13 +++-
 .../nereids/rules/analysis/FillUpMissingSlots.java | 64 ++++++++++++++++-
 .../nereids/rules/analysis/NormalizeRepeat.java    |  4 +-
 .../rules/exploration/join/OuterJoinLAsscom.java   |  2 +-
 .../nereids/rules/rewrite/AdjustPreAggStatus.java  |  2 +-
 .../nereids/rules/rewrite/MergeAggregate.java      |  2 +-
 .../mv/SelectMaterializedIndexWithAggregate.java   |  2 +-
 .../org/apache/doris/nereids/trees/TreeNode.java   |  4 +-
 .../functions/ComputeSignatureHelper.java          |  1 -
 .../expressions/functions/udf/AliasUdfBuilder.java |  5 +-
 .../doris/nereids/trees/plans/algebra/Repeat.java  | 32 +++++++++
 .../trees/plans/commands/DeleteFromCommand.java    |  8 +--
 .../insert/BatchInsertIntoTableCommand.java        |  7 +-
 .../commands/insert/InsertIntoTableCommand.java    |  3 +-
 .../insert/InsertOverwriteTableCommand.java        |  3 +-
 .../trees/plans/physical/PhysicalRepeat.java       |  7 ++
 .../apache/doris/nereids/util/ExpressionUtils.java |  5 +-
 .../org/apache/doris/nereids/util/PlanUtils.java   |  6 +-
 .../rules/rewrite/mv/SelectMvIndexTest.java        |  4 +-
 .../aggregate/agg_with_distinct_project.out        | 30 ++++++++
 .../aggregate/agg_with_distinct_project.groovy     | 82 ++++++++++++++++++++++
 24 files changed, 259 insertions(+), 40 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index 9902921f39a..983d00a27be 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -951,7 +951,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla
         List<AggregateExpression> aggregateExpressionList = 
outputExpressions.stream()
                 .filter(o -> o.anyMatch(AggregateExpression.class::isInstance))
                 .peek(o -> aggFunctionOutput.add(o.toSlot()))
-                .map(o -> 
o.<Set<AggregateExpression>>collect(AggregateExpression.class::isInstance))
+                .map(o -> 
o.<AggregateExpression>collect(AggregateExpression.class::isInstance))
                 .flatMap(Set::stream)
                 .collect(Collectors.toList());
         ArrayList<FunctionCallExpr> execAggregateFunctions = 
aggregateExpressionList.stream()
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
index ea62df6dc18..a499e0e00ce 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
@@ -3110,10 +3110,15 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
     }
 
     private LogicalPlan withProjection(LogicalPlan input, 
SelectColumnClauseContext selectCtx,
-                                       Optional<AggClauseContext> aggCtx, 
boolean isDistinct) {
+            Optional<AggClauseContext> aggCtx, boolean isDistinct) {
         return ParserUtils.withOrigin(selectCtx, () -> {
             if (aggCtx.isPresent()) {
-                return input;
+                if (isDistinct) {
+                    return new LogicalProject<>(ImmutableList.of(new 
UnboundStar(ImmutableList.of())),
+                            Collections.emptyList(), isDistinct, input);
+                } else {
+                    return input;
+                }
             } else {
                 if (selectCtx.EXCEPT() != null) {
                     List<NamedExpression> expressions = 
getNamedExpressions(selectCtx.namedExpressionSeq());
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 24ba21c06a3..864d8fd6bd3 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -60,6 +60,8 @@ public enum RuleType {
     FILL_UP_HAVING_AGGREGATE(RuleTypeClass.REWRITE),
     FILL_UP_HAVING_PROJECT(RuleTypeClass.REWRITE),
     FILL_UP_SORT_AGGREGATE(RuleTypeClass.REWRITE),
+    FILL_UP_SORT_AGGREGATE_AGGREGATE(RuleTypeClass.REWRITE),
+    FILL_UP_SORT_AGGREGATE_HAVING_AGGREGATE(RuleTypeClass.REWRITE),
     FILL_UP_SORT_HAVING_PROJECT(RuleTypeClass.REWRITE),
     FILL_UP_SORT_HAVING_AGGREGATE(RuleTypeClass.REWRITE),
     FILL_UP_SORT_PROJECT(RuleTypeClass.REWRITE),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
index 34af516c45b..c1080adf3b7 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
@@ -740,16 +740,25 @@ public class BindExpression implements AnalysisRuleFactory {
     }
 
     private Plan 
bindSortWithoutSetOperation(MatchingContext<LogicalSort<Plan>> ctx) {
+        CascadesContext cascadesContext = ctx.cascadesContext;
         LogicalSort<Plan> sort = ctx.root;
         Plan input = sort.child();
-
         List<Slot> childOutput = input.getOutput();
 
+        // we should skip distinct project to bind slot in LogicalSort;
+        // check input.child(0) to avoid process SELECT DISTINCT a FROM t 
ORDER BY b by mistake
+        // NOTICE: SELECT a FROM (SELECT sum(a) AS a FROM t GROUP BY b) v 
ORDER BY b will not raise error result
+        //   because input.child(0) is LogicalSubqueryAlias
+        if (input instanceof LogicalProject && ((LogicalProject<?>) 
input).isDistinct()
+                && (input.child(0) instanceof LogicalHaving
+                || input.child(0) instanceof LogicalAggregate
+                || input.child(0) instanceof LogicalRepeat)) {
+            input = input.child(0);
+        }
         // we should skip LogicalHaving to bind slot in LogicalSort;
         if (input instanceof LogicalHaving) {
             input = input.child(0);
         }
-        CascadesContext cascadesContext = ctx.cascadesContext;
 
         // 1. We should deduplicate the slots, otherwise the binding process 
will fail due to the
         //    ambiguous slots exist.
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java
index 1cab3614302..f78beb130e5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java
@@ -39,6 +39,7 @@ import org.apache.doris.nereids.util.ExpressionUtils;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
 import com.google.common.collect.Streams;
 
 import java.util.List;
@@ -76,6 +77,22 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
                                 sort.withChildren(new 
LogicalProject<>(projects, project.child())));
                     })
             ),
+            RuleType.FILL_UP_SORT_AGGREGATE_HAVING_AGGREGATE.build(
+                logicalSort(
+                    aggregate(logicalHaving(aggregate()))
+                        .when(a -> 
a.getOutputExpressions().stream().allMatch(SlotReference.class::isInstance))
+                ).when(this::checkSort)
+                    .then(sort -> processDistinctProjectWithAggregate(sort, 
sort.child(), sort.child().child().child()))
+            ),
+            // ATTN: process aggregate with distinct project, must run this 
rule before FILL_UP_SORT_AGGREGATE
+            //   because this pattern will always fail in 
FILL_UP_SORT_AGGREGATE
+            RuleType.FILL_UP_SORT_AGGREGATE_AGGREGATE.build(
+                logicalSort(
+                    aggregate(aggregate())
+                        .when(a -> 
a.getOutputExpressions().stream().allMatch(SlotReference.class::isInstance))
+                ).when(this::checkSort)
+                    .then(sort -> processDistinctProjectWithAggregate(sort, 
sort.child(), sort.child().child()))
+            ),
             RuleType.FILL_UP_SORT_AGGREGATE.build(
                 logicalSort(aggregate())
                     .when(this::checkSort)
@@ -334,7 +351,7 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
     }
 
     interface PlanGenerator {
-        Plan apply(Resolver resolver, Aggregate aggregate);
+        Plan apply(Resolver resolver, Aggregate<?> aggregate);
     }
 
     private Plan createPlan(Resolver resolver, Aggregate<? extends Plan> 
aggregate, PlanGenerator planGenerator) {
@@ -371,4 +388,49 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
         }
         return false;
     }
+
+    /**
+     * for sql like SELECT DISTINCT a FROM t GROUP BY a HAVING b > 0 ORDER BY 
a.
+     * there order by need to bind with bottom aggregate's output and bottom 
aggregate's child's output.
+     * this function used to fill up missing slot for these situations 
correctly.
+     *
+     * @param sort top sort
+     * @param upperAggregate upper aggregate used to check slot in order by 
should be in select list
+     * @param bottomAggregate bottom aggregate used to bind with its and its 
child's output
+     *
+     * @return filled up plan
+     */
+    private Plan processDistinctProjectWithAggregate(LogicalSort<?> sort,
+            Aggregate<?> upperAggregate, Aggregate<Plan> bottomAggregate) {
+        Resolver resolver = new Resolver(bottomAggregate);
+        sort.getExpressions().forEach(resolver::resolve);
+        return createPlan(resolver, bottomAggregate, (r, a) -> {
+            List<OrderKey> newOrderKeys = sort.getOrderKeys().stream()
+                    .map(ok -> new OrderKey(
+                            ExpressionUtils.replace(ok.getExpr(), 
r.getSubstitution()),
+                            ok.isAsc(),
+                            ok.isNullFirst()))
+                    .collect(ImmutableList.toImmutableList());
+            boolean sortNotChanged = newOrderKeys.equals(sort.getOrderKeys());
+            boolean aggNotChanged = a.equals(bottomAggregate);
+            if (sortNotChanged && aggNotChanged) {
+                return null;
+            }
+            if (aggNotChanged) {
+                // since sort expr must in select list, we should not change 
agg at all.
+                return new LogicalSort<>(newOrderKeys, sort.child());
+            } else {
+                Set<NamedExpression> upperAggOutputs = 
Sets.newHashSet(upperAggregate.getOutputExpressions());
+                for (int i = 0; i < newOrderKeys.size(); i++) {
+                    OrderKey orderKey = newOrderKeys.get(i);
+                    Expression expression = orderKey.getExpr();
+                    if 
(!upperAggOutputs.containsAll(expression.getInputSlots())) {
+                        throw new 
AnalysisException(sort.getOrderKeys().get(i).getExpr().toSql()
+                                + " of ORDER BY clause is not in SELECT list");
+                    }
+                }
+                throw new AnalysisException("Expression of ORDER BY clause is 
not in SELECT list");
+            }
+        });
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
index 2d39852dd18..c355effcffc 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java
@@ -372,14 +372,14 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory {
                 
CollectNonWindowedAggFuncs.collect(aggregate.getOutputExpressions());
         ImmutableSet.Builder<Slot> aggUsedSlotBuilder = ImmutableSet.builder();
         for (AggregateFunction function : aggregateFunctions) {
-            
aggUsedSlotBuilder.addAll(function.<Set<SlotReference>>collect(SlotReference.class::isInstance));
+            
aggUsedSlotBuilder.addAll(function.<SlotReference>collect(SlotReference.class::isInstance));
         }
         ImmutableSet<Slot> aggUsedSlots = aggUsedSlotBuilder.build();
 
         ImmutableSet.Builder<Slot> groupingSetsUsedSlotBuilder = 
ImmutableSet.builder();
         for (List<Expression> groupingSet : repeat.getGroupingSets()) {
             for (Expression expr : groupingSet) {
-                
groupingSetsUsedSlotBuilder.addAll(expr.<Set<SlotReference>>collect(SlotReference.class::isInstance));
+                
groupingSetsUsedSlotBuilder.addAll(expr.<SlotReference>collect(SlotReference.class::isInstance));
             }
         }
         ImmutableSet<Slot> groupingSetsUsedSlot = 
groupingSetsUsedSlotBuilder.build();
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java
index 40ef9b96229..f10daeadcd9 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java
@@ -96,7 +96,7 @@ public class OuterJoinLAsscom extends OneExplorationRuleFactory {
                         topJoin.getHashJoinConjuncts().stream(),
                         topJoin.getOtherJoinConjuncts().stream())
                 .allMatch(expr -> {
-                    Set<ExprId> usedExprIdSet = 
expr.<Set<SlotReference>>collect(SlotReference.class::isInstance)
+                    Set<ExprId> usedExprIdSet = 
expr.<SlotReference>collect(SlotReference.class::isInstance)
                             .stream()
                             .map(SlotReference::getExprId)
                             .collect(Collectors.toSet());
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java
index 8b90e4cdedc..495a06870f5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustPreAggStatus.java
@@ -383,7 +383,7 @@ public class AdjustPreAggStatus implements RewriteRuleFactory {
                 project.map(Project::getAliasToProducer);
         return agg.getOutputExpressions().stream()
                 // extract aggregate functions.
-                .flatMap(e -> 
e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance)
+                .flatMap(e -> 
e.<AggregateFunction>collect(AggregateFunction.class::isInstance)
                         .stream())
                 // replace aggregate function's input slot by its producing 
expression.
                 .map(expr -> slotToProducerOpt
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java
index 889adfb69f5..8b4a724d073 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/MergeAggregate.java
@@ -150,7 +150,7 @@ public class MergeAggregate implements RewriteRuleFactory {
         });
     }
 
-    boolean commonCheck(LogicalAggregate<? extends Plan> outerAgg, 
LogicalAggregate<Plan> innerAgg,
+    private boolean commonCheck(LogicalAggregate<? extends Plan> outerAgg, 
LogicalAggregate<Plan> innerAgg,
             boolean sameGroupBy, Optional<LogicalProject> projectOptional) {
         innerAggExprIdToAggFunc = innerAgg.getOutputExpressions().stream()
                 .filter(expr -> (expr instanceof Alias) && (expr.child(0) 
instanceof AggregateFunction))
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java
index cb03a0c5840..f710372d6eb 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java
@@ -674,7 +674,7 @@ public class SelectMaterializedIndexWithAggregate extends AbstractSelectMaterial
         Optional<Map<Slot, Expression>> slotToProducerOpt = 
project.map(Project::getAliasToProducer);
         return agg.getOutputExpressions().stream()
                 // extract aggregate functions.
-                .flatMap(e -> 
e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance).stream())
+                .flatMap(e -> 
e.<AggregateFunction>collect(AggregateFunction.class::isInstance).stream())
                 // replace aggregate function's input slot by its producing 
expression.
                 .map(expr -> slotToProducerOpt.map(slotToExpressions
                                 -> (AggregateFunction) 
ExpressionUtils.replace(expr, slotToExpressions))
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
index 6d1a298eb79..a4bfab08890 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
@@ -232,14 +232,14 @@ public interface TreeNode<NODE_TYPE extends TreeNode<NODE_TYPE>> {
     /**
      * Collect the nodes that satisfied the predicate.
      */
-    default <T> T collect(Predicate<TreeNode<NODE_TYPE>> predicate) {
+    default <T> Set<T> collect(Predicate<TreeNode<NODE_TYPE>> predicate) {
         ImmutableSet.Builder<TreeNode<NODE_TYPE>> result = 
ImmutableSet.builder();
         foreach(node -> {
             if (predicate.test(node)) {
                 result.add(node);
             }
         });
-        return (T) result.build();
+        return (Set<T>) result.build();
     }
 
     /**
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelper.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelper.java
index 2cdbe43c12e..166f1c9db7f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelper.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelper.java
@@ -39,7 +39,6 @@ import org.apache.doris.nereids.util.TypeCoercionUtils;
 
 import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableList.Builder;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/AliasUdfBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/AliasUdfBuilder.java
index 1f15b7e6049..9ddb8ea25e5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/AliasUdfBuilder.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/AliasUdfBuilder.java
@@ -33,7 +33,6 @@ import com.google.common.collect.Maps;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
-import java.util.Set;
 import java.util.stream.Collectors;
 
 /**
@@ -85,8 +84,8 @@ public class AliasUdfBuilder extends UdfBuilder {
 
         // replace the placeholder slot to the input expressions.
         // adjust input, parameter and replaceMap to be corresponding.
-        Map<String, SlotReference> slots = ((Set<SlotReference>) boundFunction
-                .collect(SlotReference.class::isInstance))
+        Map<String, SlotReference> slots = (boundFunction
+                .<SlotReference>collect(SlotReference.class::isInstance))
                 .stream().collect(Collectors.toMap(SlotReference::getName, k 
-> k, (v1, v2) -> v2));
 
         Map<SlotReference, Expression> replaceMap = Maps.newHashMap();
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Repeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Repeat.java
index e729f2a7cb3..8925e597850 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Repeat.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Repeat.java
@@ -58,6 +58,38 @@ public interface Repeat<CHILD_PLAN extends Plan> extends Aggregate<CHILD_PLAN> {
         return ExpressionUtils.flatExpressions(getGroupingSets());
     }
 
+    @Override
+    default Aggregate<CHILD_PLAN> pruneOutputs(List<NamedExpression> 
prunedOutputs) {
+        // just output reserved outputs and COL_GROUPING_ID for repeat 
correctly.
+        ImmutableList.Builder<NamedExpression> outputBuilder
+                = ImmutableList.builderWithExpectedSize(prunedOutputs.size() + 
1);
+        outputBuilder.addAll(prunedOutputs);
+        for (NamedExpression output : getOutputExpressions()) {
+            Set<VirtualSlotReference> v = 
output.collect(VirtualSlotReference.class::isInstance);
+            if (v.stream().anyMatch(slot -> 
slot.getName().equals(COL_GROUPING_ID))) {
+                outputBuilder.add(output);
+            }
+        }
+        // prune groupingSets, if parent operator do not need some exprs in 
grouping sets, we removed it.
+        // this could not lead to wrong result because be repeat other columns 
by normal.
+        ImmutableList.Builder<List<Expression>> groupingSetsBuilder
+                = 
ImmutableList.builderWithExpectedSize(getGroupingSets().size());
+        for (List<Expression> groupingSet : getGroupingSets()) {
+            ImmutableList.Builder<Expression> groupingSetBuilder
+                    = 
ImmutableList.builderWithExpectedSize(groupingSet.size());
+            for (Expression expr : groupingSet) {
+                if (prunedOutputs.contains(expr)) {
+                    groupingSetBuilder.add(expr);
+                }
+            }
+            groupingSetsBuilder.add(groupingSetBuilder.build());
+        }
+        return withGroupSetsAndOutput(groupingSetsBuilder.build(), 
outputBuilder.build());
+    }
+
+    Repeat<CHILD_PLAN> withGroupSetsAndOutput(List<List<Expression>> 
groupingSets,
+            List<NamedExpression> outputExpressions);
+
     static VirtualSlotReference generateVirtualGroupingIdSlot() {
         return new VirtualSlotReference(COL_GROUPING_ID, BigIntType.INSTANCE, 
Optional.empty(),
                 GroupingSetShapes::computeVirtualGroupingIdValue);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DeleteFromCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DeleteFromCommand.java
index 62feee1c43f..6563d815382 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DeleteFromCommand.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/DeleteFromCommand.java
@@ -109,13 +109,13 @@ public class DeleteFromCommand extends Command implements ForwardWithSync {
             return;
         }
         Optional<PhysicalFilter<?>> optFilter = (planner.getPhysicalPlan()
-                
.<Set<PhysicalFilter<?>>>collect(PhysicalFilter.class::isInstance)).stream()
+                
.<PhysicalFilter<?>>collect(PhysicalFilter.class::isInstance)).stream()
                 .findAny();
         Optional<PhysicalOlapScan> optScan = (planner.getPhysicalPlan()
-                
.<Set<PhysicalOlapScan>>collect(PhysicalOlapScan.class::isInstance)).stream()
+                
.<PhysicalOlapScan>collect(PhysicalOlapScan.class::isInstance)).stream()
                 .findAny();
         Optional<UnboundRelation> optRelation = (logicalQuery
-                
.<Set<UnboundRelation>>collect(UnboundRelation.class::isInstance)).stream()
+                
.<UnboundRelation>collect(UnboundRelation.class::isInstance)).stream()
                 .findAny();
         Preconditions.checkArgument(optFilter.isPresent(), "delete command 
must contain filter");
         Preconditions.checkArgument(optScan.isPresent(), "delete command could 
be only used on olap table");
@@ -141,7 +141,7 @@ public class DeleteFromCommand extends Command implements ForwardWithSync {
             Plan plan = planner.getPhysicalPlan();
             checkSubQuery(plan);
             for (Expression conjunct : filter.getConjuncts()) {
-                
conjunct.<Set<SlotReference>>collect(SlotReference.class::isInstance)
+                
conjunct.<SlotReference>collect(SlotReference.class::isInstance)
                         .forEach(s -> checkColumn(columns, s, olapTable));
                 checkPredicate(conjunct);
             }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/insert/BatchInsertIntoTableCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/insert/BatchInsertIntoTableCommand.java
index 4399cd57db4..4b7afb1f6a8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/insert/BatchInsertIntoTableCommand.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/insert/BatchInsertIntoTableCommand.java
@@ -53,7 +53,6 @@ import org.apache.logging.log4j.Logger;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
-import java.util.Set;
 import java.util.stream.Collectors;
 
 /**
@@ -111,7 +110,7 @@ public class BatchInsertIntoTableCommand extends Command implements NoForward, E
             }
 
             Optional<TreeNode<?>> plan = planner.getPhysicalPlan()
-                    
.<Set<TreeNode<?>>>collect(PhysicalOlapTableSink.class::isInstance).stream().findAny();
+                    
.<TreeNode<?>>collect(PhysicalOlapTableSink.class::isInstance).stream().findAny();
             Preconditions.checkArgument(plan.isPresent(), "insert into command 
must contain OlapTableSinkNode");
             sink = ((PhysicalOlapTableSink<?>) plan.get());
             Table targetTable = sink.getTargetTable();
@@ -141,14 +140,14 @@ public class BatchInsertIntoTableCommand extends Command implements NoForward, E
             }
 
             Optional<PhysicalUnion> union = planner.getPhysicalPlan()
-                    
.<Set<PhysicalUnion>>collect(PhysicalUnion.class::isInstance).stream().findAny();
+                    
.<PhysicalUnion>collect(PhysicalUnion.class::isInstance).stream().findAny();
             if (union.isPresent()) {
                 InsertUtils.executeBatchInsertTransaction(ctx, 
targetTable.getQualifiedDbName(),
                         targetTable.getName(), targetSchema, 
union.get().getConstantExprsList());
                 return;
             }
             Optional<PhysicalOneRowRelation> oneRowRelation = 
planner.getPhysicalPlan()
-                    
.<Set<PhysicalOneRowRelation>>collect(PhysicalOneRowRelation.class::isInstance).stream().findAny();
+                    
.<PhysicalOneRowRelation>collect(PhysicalOneRowRelation.class::isInstance).stream().findAny();
             if (oneRowRelation.isPresent()) {
                 InsertUtils.executeBatchInsertTransaction(ctx, 
targetTable.getQualifiedDbName(),
                         targetTable.getName(), targetSchema, 
ImmutableList.of(oneRowRelation.get().getProjects()));
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/insert/InsertIntoTableCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/insert/InsertIntoTableCommand.java
index 4118cd1ddaa..a88b8cc2e05 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/insert/InsertIntoTableCommand.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/insert/InsertIntoTableCommand.java
@@ -54,7 +54,6 @@ import org.apache.logging.log4j.Logger;
 
 import java.util.Objects;
 import java.util.Optional;
-import java.util.Set;
 
 /**
  * insert into select command implementation
@@ -152,7 +151,7 @@ public class InsertIntoTableCommand extends Command implements ForwardWithSync,
                 ctx.getMysqlChannel().reset();
             }
             Optional<PhysicalSink<?>> plan = (planner.getPhysicalPlan()
-                    
.<Set<PhysicalSink<?>>>collect(PhysicalSink.class::isInstance)).stream()
+                    
.<PhysicalSink<?>>collect(PhysicalSink.class::isInstance)).stream()
                     .findAny();
             Preconditions.checkArgument(plan.isPresent(), "insert into command 
must contain target table");
             PhysicalSink physicalSink = plan.get();
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/insert/InsertOverwriteTableCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/insert/InsertOverwriteTableCommand.java
index d9047dcfc7d..34d9c093718 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/insert/InsertOverwriteTableCommand.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/insert/InsertOverwriteTableCommand.java
@@ -60,7 +60,6 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
-import java.util.Set;
 
 /**
  * insert into select command implementation
@@ -122,7 +121,7 @@ public class InsertOverwriteTableCommand extends Command implements ForwardWithS
         }
 
         Optional<TreeNode<?>> plan = (planner.getPhysicalPlan()
-                .<Set<TreeNode<?>>>collect(node -> node instanceof 
PhysicalTableSink)).stream().findAny();
+                .<TreeNode<?>>collect(node -> node instanceof 
PhysicalTableSink)).stream().findAny();
         Preconditions.checkArgument(plan.isPresent(), "insert into command 
must contain OlapTableSinkNode");
         PhysicalTableSink<?> physicalTableSink = ((PhysicalTableSink<?>) 
plan.get());
         TableIf targetTable = physicalTableSink.getTargetTable();
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalRepeat.java
index 3cb6f730069..1ef2ff77d5a 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalRepeat.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalRepeat.java
@@ -178,6 +178,13 @@ public class PhysicalRepeat<CHILD_TYPE extends Plan> extends PhysicalUnary<CHILD
                 getLogicalProperties(), physicalProperties, statistics, 
child());
     }
 
+    @Override
+    public PhysicalRepeat<CHILD_TYPE> 
withGroupSetsAndOutput(List<List<Expression>> groupingSets,
+            List<NamedExpression> outputExpressionList) {
+        return new PhysicalRepeat<>(groupingSets, outputExpressionList, 
Optional.empty(),
+                getLogicalProperties(), physicalProperties, statistics, 
child());
+    }
+
     @Override
     public PhysicalRepeat<CHILD_TYPE> resetLogicalProperties() {
         return new PhysicalRepeat<>(groupingSets, outputExpressions, 
groupExpression,
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index a6a4d999a92..b19d4b096e2 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -65,6 +65,7 @@ import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.BitSet;
 import java.util.Collection;
@@ -522,9 +523,7 @@ public class ExpressionUtils {
         ImmutableList<Literal> literals =
                 ImmutableList.of(new NullLiteral(BooleanType.INSTANCE), 
BooleanLiteral.FALSE);
         List<MarkJoinSlotReference> markJoinSlotReferenceList =
-                ((Set<MarkJoinSlotReference>) predicate
-                        
.collect(MarkJoinSlotReference.class::isInstance)).stream()
-                                .collect(Collectors.toList());
+                new 
ArrayList<>((predicate.collect(MarkJoinSlotReference.class::isInstance)));
         int markSlotSize = markJoinSlotReferenceList.size();
         int maxMarkSlotCount = 4;
         // if the conjunct has mark slot, and maximum 4 mark slots(for 
performance)
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
index 9c5e6b318e8..7cfa7b7709e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
@@ -40,7 +40,6 @@ import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 
 import java.util.Collection;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
@@ -140,10 +139,7 @@ public class PlanUtils {
     }
 
     public static Set<LogicalCatalogRelation> 
getLogicalScanFromRootPlan(LogicalPlan rootPlan) {
-        Set<LogicalCatalogRelation> tableSet = new HashSet<>();
-        tableSet.addAll((Collection<? extends LogicalCatalogRelation>) rootPlan
-                .collect(LogicalCatalogRelation.class::isInstance));
-        return tableSet;
+        return rootPlan.collect(LogicalCatalogRelation.class::isInstance);
     }
 
     /**
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMvIndexTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMvIndexTest.java
index 83b969d9f12..bc8df2ab0f7 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMvIndexTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMvIndexTest.java
@@ -1224,7 +1224,7 @@ class SelectMvIndexTest extends 
BaseMaterializedIndexSelectTest implements MemoP
     private void assertOneAggFuncType(LogicalAggregate<? extends Plan> agg, 
Class<?> aggFuncType) {
         Set<AggregateFunction> aggFuncs = agg.getOutputExpressions()
                 .stream()
-                .flatMap(e -> 
e.<Set<AggregateFunction>>collect(AggregateFunction.class::isInstance)
+                .flatMap(e -> 
e.<AggregateFunction>collect(AggregateFunction.class::isInstance)
                         .stream())
                 .collect(Collectors.toSet());
         Assertions.assertEquals(1, aggFuncs.size());
@@ -1239,7 +1239,7 @@ class SelectMvIndexTest extends 
BaseMaterializedIndexSelectTest implements MemoP
             Assertions.assertEquals(2, scans.size());
 
             ScanNode scanNode0 = scans.get(0);
-            Assertions.assertTrue(scanNode0 instanceof OlapScanNode);
+            Assertions.assertInstanceOf(OlapScanNode.class, scanNode0);
             OlapScanNode scan0 = (OlapScanNode) scanNode0;
             Assertions.assertTrue(scan0.isPreAggregation());
             Assertions.assertEquals(firstTableIndexName, 
scan0.getSelectedIndexName());
diff --git 
a/regression-test/data/nereids_p0/aggregate/agg_with_distinct_project.out 
b/regression-test/data/nereids_p0/aggregate/agg_with_distinct_project.out
new file mode 100644
index 00000000000..ac5a1851bcb
--- /dev/null
+++ b/regression-test/data/nereids_p0/aggregate/agg_with_distinct_project.out
@@ -0,0 +1,30 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !base_case --
+0
+81
+
+-- !with_order --
+1
+82
+
+-- !with_having --
+1
+82
+
+-- !with_having_with_order --
+1
+82
+
+-- !with_order_with_grouping_sets --
+\N
+1
+82
+
+-- !with_having_with_grouping_sets --
+1
+82
+
+-- !with_having_with_order_with_grouping_sets --
+1
+82
+
diff --git 
a/regression-test/suites/nereids_p0/aggregate/agg_with_distinct_project.groovy 
b/regression-test/suites/nereids_p0/aggregate/agg_with_distinct_project.groovy
new file mode 100644
index 00000000000..56cce71c8fc
--- /dev/null
+++ 
b/regression-test/suites/nereids_p0/aggregate/agg_with_distinct_project.groovy
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+suite("agg_with_distinct_project") {
+
+    sql "set enable_fallback_to_original_planner=false"
+    sql "DROP TABLE IF EXISTS agg_with_distinct_project;"
+    sql """
+        CREATE TABLE agg_with_distinct_project (
+          id int NOT NULL,
+          a int DEFAULT NULL,
+          b int DEFAULT NULL
+        )
+        PROPERTIES (
+          "replication_allocation" = "tag.location.default: 1"
+        );
+    """
+
+    sql """INSERT INTO agg_with_distinct_project 
VALUES(83,0,38),(26,0,79),(43,81,24)"""
+
+    order_qt_base_case """
+        SELECT DISTINCT a as c1 FROM agg_with_distinct_project GROUP BY b, a;
+    """
+
+    qt_with_order """
+        select distinct  a + 1 from agg_with_distinct_project group by a + 1, 
b order by a + 1;
+    """
+
+    order_qt_with_having """
+        select distinct a + 1 from agg_with_distinct_project group by a + 1, b 
having b > 1;
+    """
+
+    qt_with_having_with_order """
+        select distinct a + 1 from agg_with_distinct_project group by a + 1, b 
having b > 1 order by a + 1;
+    """
+
+    qt_with_order_with_grouping_sets """
+         select distinct  a + 1 from agg_with_distinct_project group by 
grouping sets(( a + 1, b ), (b + 1)) order by a + 1;
+    """
+
+    order_qt_with_having_with_grouping_sets """
+         select distinct a + 1 from agg_with_distinct_project group by 
grouping sets(( a + 1, b ), (b + 1)) having b > 1;
+    """
+
+    qt_with_having_with_order_with_grouping_sets """
+         select distinct a + 1 from agg_with_distinct_project group by 
grouping sets(( a + 1, b ), (b + 1)) having b > 1 order by a + 1;
+    """
+
+    // order by column not in select list
+    test {
+        sql """
+             select distinct  a + 1 from agg_with_distinct_project group by a 
+ 1, b order by b;
+        """
+        exception "b of ORDER BY clause is not in SELECT list"
+    }
+
+    // order by column not in select list
+    test {
+        sql """
+             select distinct  a + 1 from agg_with_distinct_project group by 
grouping sets(( a + 1, b ), (b + 1)) order by b;
+        """
+        exception "b of ORDER BY clause is not in SELECT list"
+    }
+
+    sql "DROP TABLE IF EXISTS agg_with_distinct_project;"
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org


Reply via email to