This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch pick_3.1_58964
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 0cc40aab25ec8009cee68177ae84a4e09857b4ce
Author: morrySnow <[email protected]>
AuthorDate: Mon Dec 15 18:12:07 2025 +0800

    branch-3.1: [fix](enforcer) shuffle if has continuous project or filter on cte consumer #58964
    
    picked from #58964
---
 .../java/org/apache/doris/nereids/cost/CostV1.java |   5 +-
 .../properties/ChildrenPropertiesRegulator.java    |  40 +++++
 .../ChildrenPropertiesRegulatorTest.java           | 170 +++++++++++++++++++++
 .../suites/nereids_syntax_p0/cte.groovy            |   5 +
 4 files changed, 219 insertions(+), 1 deletion(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java
index 75624068bc7..96fb96adf9b 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/cost/CostV1.java
@@ -19,7 +19,10 @@ package org.apache.doris.nereids.cost;
 
 import org.apache.doris.qe.SessionVariable;
 
-class CostV1 implements Cost {
+/**
+ * Cost V1.
+ */
+public class CostV1 implements Cost {
     private static final CostV1 INFINITE = new 
CostV1(Double.POSITIVE_INFINITY, Double.POSITIVE_INFINITY,
             Double.POSITIVE_INFINITY,
             Double.POSITIVE_INFINITY);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
index f6868a819cc..050addafe0a 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
@@ -39,6 +39,7 @@ import 
org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalPartitionTopN;
@@ -48,6 +49,7 @@ import 
org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalUnion;
 import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
 import org.apache.doris.nereids.util.JoinUtils;
+import org.apache.doris.nereids.util.PlanUtils;
 import org.apache.doris.qe.ConnectContext;
 import org.apache.doris.qe.SessionVariable;
 
@@ -202,6 +204,17 @@ public class ChildrenPropertiesRegulator extends 
PlanVisitor<Boolean, Void> {
         if (children.get(0).getPlan() instanceof PhysicalDistribute) {
             return false;
         }
+        DistributionSpec distributionSpec = 
childrenProperties.get(0).getDistributionSpec();
+        // process must shuffle
+        if (distributionSpec instanceof DistributionSpecMustShuffle) {
+            Plan child = filter.child();
+            Plan realChild = getChildPhysicalPlan(child);
+            if (realChild instanceof PhysicalProject
+                    || realChild instanceof PhysicalFilter
+                    || realChild instanceof PhysicalLimit) {
+                visit(filter, context);
+            }
+        }
         return true;
     }
 
@@ -234,6 +247,19 @@ public class ChildrenPropertiesRegulator extends 
PlanVisitor<Boolean, Void> {
         }
     }
 
+    /** First physical plan in {@code plan}'s group, or null if plan is not a GroupPlan or has none. */
+    private Plan getChildPhysicalPlan(Plan plan) {
+        if (!(plan instanceof GroupPlan)) {
+            return null;
+        }
+        GroupPlan groupPlan = (GroupPlan) plan;
+        if (groupPlan.getGroup() == null
+                || groupPlan.getGroup().getPhysicalExpressions().isEmpty()) {
+            return null;
+        }
+        return groupPlan.getGroup().getPhysicalExpressions().get(0).getPlan();
+    }
+
     private PhysicalOlapScan findDownGradeBucketShuffleCandidate(GroupPlan 
groupPlan) {
         if (groupPlan == null || groupPlan.getGroup() == null
                 || groupPlan.getGroup().getPhysicalExpressions().isEmpty()) {
@@ -467,6 +493,20 @@ public class ChildrenPropertiesRegulator extends 
PlanVisitor<Boolean, Void> {
         if (children.get(0).getPlan() instanceof PhysicalDistribute) {
             return false;
         }
+        DistributionSpec distributionSpec = 
childrenProperties.get(0).getDistributionSpec();
+        // process must shuffle
+        if (distributionSpec instanceof DistributionSpecMustShuffle) {
+            Plan child = project.child();
+            Plan realChild = getChildPhysicalPlan(child);
+            if (realChild instanceof PhysicalLimit) {
+                visit(project, context);
+            } else if (realChild instanceof PhysicalProject) {
+                PhysicalProject physicalProject = (PhysicalProject) realChild;
+                if (!PlanUtils.tryMergeProjections(project.getProjects(), 
physicalProject.getProjects()).isPresent()) {
+                    visit(project, context);
+                }
+            }
+        }
         return true;
     }
 
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulatorTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulatorTest.java
new file mode 100644
index 00000000000..61517ac5785
--- /dev/null
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulatorTest.java
@@ -0,0 +1,170 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.properties;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.CascadesContext;
+import org.apache.doris.nereids.cost.Cost;
+import org.apache.doris.nereids.cost.CostCalculator;
+import org.apache.doris.nereids.cost.CostV1;
+import org.apache.doris.nereids.jobs.JobContext;
+import org.apache.doris.nereids.memo.Group;
+import org.apache.doris.nereids.memo.GroupExpression;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit;
+import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
+import org.apache.doris.nereids.util.PlanUtils;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+public class ChildrenPropertiesRegulatorTest {
+
+    private List<GroupExpression> children;
+    private JobContext mockedJobContext;
+    private List<PhysicalProperties> childrenOutputProperties = Lists.newArrayList(PhysicalProperties.MUST_SHUFFLE);
+
+    @BeforeEach
+    public void setUp() {
+        Group childGroup = Mockito.mock(Group.class);
+        Mockito.when(childGroup.getLogicalProperties()).thenReturn(Mockito.mock(LogicalProperties.class));
+        GroupExpression child = Mockito.mock(GroupExpression.class);
+        Mockito.when(child.getOutputProperties(Mockito.any())).thenReturn(PhysicalProperties.MUST_SHUFFLE);
+        Mockito.when(child.getOwnerGroup()).thenReturn(childGroup);
+        Map<PhysicalProperties, Pair<Cost, List<PhysicalProperties>>> lct = Maps.newHashMap();
+        lct.put(PhysicalProperties.MUST_SHUFFLE, Pair.of(CostV1.zero(), Lists.newArrayList()));
+        Mockito.when(child.getLowestCostTable()).thenReturn(lct);
+        children = Lists.newArrayList(child);
+
+        mockedJobContext = Mockito.mock(JobContext.class);
+        Mockito.when(mockedJobContext.getCascadesContext()).thenReturn(Mockito.mock(CascadesContext.class));
+
+    }
+
+    @Test
+    public void testMustShuffleProjectProjectCanNotMerge() {
+        // projections cannot be merged, so MUST_SHUFFLE is expected to be rewritten to ExecutionAny
+        testMustShuffleProject(PhysicalProject.class, DistributionSpecExecutionAny.class, false);
+    }
+
+    @Test
+    public void testMustShuffleProjectProjectCanMerge() {
+        // projections can be merged, so MUST_SHUFFLE is expected to be kept
+        testMustShuffleProject(PhysicalProject.class, DistributionSpecMustShuffle.class, true);
+    }
+
+    @Test
+    public void testMustShuffleProjectFilter() {
+        // a filter child does not force a shuffle under a project, so MUST_SHUFFLE is expected to be kept
+        testMustShuffleProject(PhysicalFilter.class, DistributionSpecMustShuffle.class, true);
+    }
+
+    @Test
+    public void testMustShuffleProjectLimit() {
+        testMustShuffleProject(PhysicalLimit.class, DistributionSpecExecutionAny.class, true);
+    }
+
+    private void testMustShuffleProject(Class<? extends Plan> childClazz,
+            Class<? extends DistributionSpec> distributeClazz,
+            boolean canMergeChildProject) {
+        try (MockedStatic<CostCalculator> mockedCostCalculator = Mockito.mockStatic(CostCalculator.class);
+                MockedStatic<PlanUtils> mockedPlanUtils = Mockito.mockStatic(PlanUtils.class)) {
+            mockedCostCalculator.when(() -> CostCalculator.calculateCost(Mockito.any(), Mockito.any(),
+                    Mockito.anyList())).thenReturn(CostV1.zero());
+            mockedCostCalculator.when(() -> CostCalculator.addChildCost(Mockito.any(), Mockito.any(), Mockito.any(),
+                    Mockito.any(), Mockito.anyInt())).thenReturn(CostV1.zero());
+            if (canMergeChildProject) {
+                mockedPlanUtils.when(() -> PlanUtils.tryMergeProjections(Mockito.any(), Mockito.any()))
+                        .thenReturn(Optional.of(Lists.newArrayList()));
+            } else {
+                mockedPlanUtils.when(() -> PlanUtils.tryMergeProjections(Mockito.any(), Mockito.any()))
+                        .thenReturn(Optional.empty());
+            }
+
+            // the mocked child is the only physical expression of the parent's child group
+            Plan mockedChild = Mockito.mock(childClazz);
+            Mockito.when(mockedChild.withGroupExpression(Mockito.any())).thenReturn(mockedChild);
+            Group mockedGroup = Mockito.mock(Group.class);
+            List<GroupExpression> physicalExpressions = Lists.newArrayList(new GroupExpression(mockedChild));
+            Mockito.when(mockedGroup.getPhysicalExpressions()).thenReturn(physicalExpressions);
+            GroupPlan mockedGroupPlan = Mockito.mock(GroupPlan.class);
+            Mockito.when(mockedGroupPlan.getGroup()).thenReturn(mockedGroup);
+            PhysicalProject parentPlan = new PhysicalProject<>(Lists.newArrayList(), null, mockedGroupPlan);
+            GroupExpression parent = new GroupExpression(parentPlan);
+            List<PhysicalProperties> childrenProperties = new ArrayList<>(childrenOutputProperties);
+            ChildrenPropertiesRegulator regulator = new ChildrenPropertiesRegulator(parent, children,
+                    childrenProperties, null, mockedJobContext);
+            regulator.adjustChildrenProperties();
+            PhysicalProperties result = childrenProperties.get(0);
+            Assertions.assertInstanceOf(distributeClazz, result.getDistributionSpec());
+        }
+    }
+
+    @Test
+    public void testMustShuffleFilterProject() {
+        testMustShuffleFilter(PhysicalProject.class);
+    }
+
+    @Test
+    public void testMustShuffleFilterFilter() {
+        testMustShuffleFilter(PhysicalFilter.class);
+    }
+
+    @Test
+    public void testMustShuffleFilterLimit() {
+        testMustShuffleFilter(PhysicalLimit.class);
+    }
+
+    private void testMustShuffleFilter(Class<? extends Plan> childClazz) {
+        try (MockedStatic<CostCalculator> mockedCostCalculator = Mockito.mockStatic(CostCalculator.class)) {
+            mockedCostCalculator.when(() -> CostCalculator.calculateCost(Mockito.any(), Mockito.any(),
+                    Mockito.anyList())).thenReturn(CostV1.zero());
+            mockedCostCalculator.when(() -> CostCalculator.addChildCost(Mockito.any(), Mockito.any(), Mockito.any(),
+                    Mockito.any(), Mockito.anyInt())).thenReturn(CostV1.zero());
+
+            // the mocked child is the only physical expression of the parent's child group
+            Plan mockedChild = Mockito.mock(childClazz);
+            Mockito.when(mockedChild.withGroupExpression(Mockito.any())).thenReturn(mockedChild);
+            Group mockedGroup = Mockito.mock(Group.class);
+            List<GroupExpression> physicalExpressions = Lists.newArrayList(new GroupExpression(mockedChild));
+            Mockito.when(mockedGroup.getPhysicalExpressions()).thenReturn(physicalExpressions);
+            GroupPlan mockedGroupPlan = Mockito.mock(GroupPlan.class);
+            Mockito.when(mockedGroupPlan.getGroup()).thenReturn(mockedGroup);
+            GroupExpression parent = new GroupExpression(new PhysicalFilter<>(Sets.newHashSet(), null, mockedGroupPlan));
+            List<PhysicalProperties> childrenProperties = new ArrayList<>(childrenOutputProperties);
+            ChildrenPropertiesRegulator regulator = new ChildrenPropertiesRegulator(parent, children,
+                    childrenProperties, null, mockedJobContext);
+            regulator.adjustChildrenProperties();
+            PhysicalProperties result = childrenProperties.get(0);
+            Assertions.assertInstanceOf(DistributionSpecExecutionAny.class, result.getDistributionSpec());
+        }
+    }
+}
diff --git a/regression-test/suites/nereids_syntax_p0/cte.groovy 
b/regression-test/suites/nereids_syntax_p0/cte.groovy
index 5402ffb8e21..f6b990b4f4d 100644
--- a/regression-test/suites/nereids_syntax_p0/cte.groovy
+++ b/regression-test/suites/nereids_syntax_p0/cte.groovy
@@ -334,5 +334,10 @@ suite("cte") {
     sql """
         WITH cte_0 AS ( SELECT 1 AS a ), cte_1 AS ( SELECT 1 AS a ) select * 
from cte_0, cte_1 union select * from cte_0, cte_1
     """
+
+    // test more than one project on cte consumer
+    sql """
+        with a as (select 1 c1) select *, uuid() from a union all select c2, 
c2 from (select c1 + 1, uuid() c2 from a) x ;
+    """
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to