This is an automated email from the ASF dual-hosted git repository.

morrySnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 18677371380 [fix](fe) Fix assert row join pushdown alias handling (#63892)
18677371380 is described below

commit 18677371380da5b4b4b290cd65e2a7cefdccb794
Author: morrySnow <[email protected]>
AuthorDate: Mon Jun 1 12:24:36 2026 +0800

    [fix](fe) Fix assert row join pushdown alias handling (#63892)
    
    ### What problem does this PR solve?
    
    Related PR: #57414
    
    Problem Summary: A scalar subquery comparison can reference a projected
    alias from the right side of an inner join. PushDownJoinOnAssertNumRows
    previously matched the condition slots against the project output, but
    only after the condition had already been rewritten through the project.
    An alias expanded into right-child slots therefore matched none of the
    project output slots. Because an empty slot list is trivially contained
    in any child's output, the condition and its alias projection were
    attached to the left child. The rewritten plan then referenced slots that
    were absent from that child.

    This change:
    - determines slot ownership from the bottom join output after project
      pushdown;
    - skips the rewrite when the condition references no bottom-join slots;
    - keeps the original pushdown child order when assembling the new join;
    - adds a unit test for the right-child alias case.
    
    ### Release note
    
    Fix query planning failure for scalar subquery comparisons on projected
    join expressions.
---
 .../rules/rewrite/PushDownJoinOnAssertNumRows.java | 102 ++++++++++-----------
 .../rewrite/PushDownJoinOnAssertNumRowsTest.java   |  66 +++++++++++++
 2 files changed, 112 insertions(+), 56 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRows.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRows.java
index e52def0723c..d45cc5676fe 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRows.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRows.java
@@ -80,28 +80,28 @@ public class PushDownJoinOnAssertNumRows extends 
OneRewriteRuleFactory {
     @Override
     public Rule build() {
         return logicalJoin()
-                .when(topJoin -> pattenCheck(topJoin))
-                .then(topJoin -> pushDownAssertNumRowsJoin(topJoin))
+                .when(this::pattenCheck)
+                .then(this::pushDownAssertNumRowsJoin)
                 .toRule(RuleType.PUSH_DOWN_JOIN_ON_ASSERT_NUM_ROWS);
     }
 
-    private boolean pattenCheck(LogicalJoin topJoin) {
+    private boolean pattenCheck(LogicalJoin<?, ?> topJoin) {
         // 1. right is LogicalAssertNumRows or 
LogicalProject->LogicalAssertNumRows
         // 2. left is join or project->join
         // 3. only one join condition.
         if (!topJoin.getJoinType().isInnerOrCrossJoin()) {
             return false;
         }
-        LogicalJoin bottomJoin;
+        LogicalJoin<?, ?> bottomJoin;
         Plan left = topJoin.left();
         Plan right = topJoin.right();
         if (!isAssertOneRowEqOrProjectAssertOneRowEq(right)) {
             return false;
         }
         if (left instanceof LogicalJoin) {
-            bottomJoin = (LogicalJoin) left;
+            bottomJoin = (LogicalJoin<?, ?>) left;
         } else if (left instanceof LogicalProject && left.child(0) instanceof 
LogicalJoin) {
-            bottomJoin = (LogicalJoin) left.child(0);
+            bottomJoin = (LogicalJoin<?, ?>) left.child(0);
         } else {
             return false;
         }
@@ -125,7 +125,7 @@ public class PushDownJoinOnAssertNumRows extends 
OneRewriteRuleFactory {
             plan = plan.child(0);
         }
         if (plan instanceof LogicalAssertNumRows) {
-            AssertNumRowsElement assertNumRowsElement = 
((LogicalAssertNumRows) plan).getAssertNumRowsElement();
+            AssertNumRowsElement assertNumRowsElement = 
((LogicalAssertNumRows<?>) plan).getAssertNumRowsElement();
             if (assertNumRowsElement.getAssertion() == 
AssertNumRowsElement.Assertion.EQ
                     || assertNumRowsElement.getDesiredNumOfRows() == 1L) {
                 return true;
@@ -134,14 +134,14 @@ public class PushDownJoinOnAssertNumRows extends 
OneRewriteRuleFactory {
         return false;
     }
 
-    private boolean joinOnAssertOneRowEq(LogicalJoin join) {
+    private boolean joinOnAssertOneRowEq(LogicalJoin<?, ?> join) {
         return isAssertOneRowEqOrProjectAssertOneRowEq(join.right())
                 || isAssertOneRowEqOrProjectAssertOneRowEq(join.left());
     }
 
-    private Plan pushDownAssertNumRowsJoin(LogicalJoin topJoin) {
+    private Plan pushDownAssertNumRowsJoin(LogicalJoin<?, ?> topJoin) {
         Plan assertBranch = topJoin.right();
-        Expression condition = (Expression) 
topJoin.getOtherJoinConjuncts().get(0);
+        Expression condition = topJoin.getOtherJoinConjuncts().get(0);
         List<Alias> aliasUsedInConditionFromLeftProject = new ArrayList<>();
         LogicalJoin<? extends Plan, ? extends Plan> bottomJoin;
         if (topJoin.left() instanceof LogicalProject) {
@@ -160,59 +160,49 @@ public class PushDownJoinOnAssertNumRows extends 
OneRewriteRuleFactory {
         Plan bottomRight = bottomJoin.right();
 
         List<Slot> conditionSlotsFromTopLeft = 
condition.getInputSlots().stream()
-                .filter(slot -> topJoin.left().getOutputSet().contains(slot))
+                .filter(slot -> bottomJoin.getOutputSet().contains(slot))
                 .collect(Collectors.toList());
+        // Nothing from the bottom join participates in this scalar-subquery 
condition.
+        if (conditionSlotsFromTopLeft.isEmpty()) {
+            return null;
+        }
         if (bottomLeft.getOutputSet().containsAll(conditionSlotsFromTopLeft)) {
-            // push to bottomLeft
-            Plan newBottomLeft;
-            if (aliasUsedInConditionFromLeftProject.isEmpty()) {
-                newBottomLeft = bottomLeft;
-            } else {
-                newBottomLeft = 
projectAliasOnPlan(aliasUsedInConditionFromLeftProject, bottomLeft);
-            }
-            LogicalJoin<? extends Plan, ? extends Plan> newBottomJoin = new 
LogicalJoin<>(
-                    topJoin.getJoinType(),
-                    topJoin.getHashJoinConjuncts(),
-                    topJoin.getOtherJoinConjuncts(),
-                    newBottomLeft,
-                    assertBranch,
-                    topJoin.getJoinReorderContext());
-            LogicalJoin<? extends Plan, ? extends Plan> newTopJoin = 
(LogicalJoin<? extends Plan, ? extends Plan>)
-                    bottomJoin.withChildren(newBottomJoin, bottomRight);
-            if (topJoin.left() instanceof LogicalProject) {
-                LogicalProject<? extends Plan> upperProject = 
projectAliasOnPlan(
-                        aliasUsedInConditionFromLeftProject, topJoin.left());
-                return upperProject.withChildren(newTopJoin);
-            } else {
-                return newTopJoin;
-            }
+            return assembleNewJoin(bottomLeft, topJoin, bottomJoin, 
bottomRight,
+                    assertBranch, aliasUsedInConditionFromLeftProject, true);
         } else if 
(bottomRight.getOutputSet().containsAll(conditionSlotsFromTopLeft)) {
-            Plan newBottomRight;
-            if (aliasUsedInConditionFromLeftProject.isEmpty()) {
-                newBottomRight = bottomRight;
-            } else {
-                newBottomRight = 
projectAliasOnPlan(aliasUsedInConditionFromLeftProject, bottomRight);
-            }
-            LogicalJoin<? extends Plan, ? extends Plan> newBottomJoin = new 
LogicalJoin<>(
-                    topJoin.getJoinType(),
-                    topJoin.getHashJoinConjuncts(),
-                    topJoin.getOtherJoinConjuncts(),
-                    newBottomRight,
-                    assertBranch,
-                    topJoin.getJoinReorderContext());
-            LogicalJoin<? extends Plan, ? extends Plan> newTopJoin = 
(LogicalJoin<? extends Plan, ? extends Plan>)
-                    bottomJoin.withChildren(bottomLeft, newBottomJoin);
-            if (topJoin.left() instanceof LogicalProject) {
-                LogicalProject<? extends Plan> upperProject = 
projectAliasOnPlan(
-                        aliasUsedInConditionFromLeftProject, topJoin.left());
-                return upperProject.withChildren(newTopJoin);
-            } else {
-                return newTopJoin;
-            }
+            return assembleNewJoin(bottomRight, topJoin, bottomJoin, 
bottomLeft,
+                    assertBranch, aliasUsedInConditionFromLeftProject, false);
         }
         return null;
     }
 
+    private Plan assembleNewJoin(Plan bottom, LogicalJoin<?, ?> topJoin, 
LogicalJoin<?, ?> bottomJoin, Plan newTopChild,
+            Plan assertBranch, List<Alias> 
aliasUsedInConditionFromLeftProject, boolean pushLeft) {
+        Plan newBottomChild;
+        if (aliasUsedInConditionFromLeftProject.isEmpty()) {
+            newBottomChild = bottom;
+        } else {
+            newBottomChild = 
projectAliasOnPlan(aliasUsedInConditionFromLeftProject, bottom);
+        }
+        LogicalJoin<? extends Plan, ? extends Plan> newBottomJoin = new 
LogicalJoin<>(
+                topJoin.getJoinType(),
+                topJoin.getHashJoinConjuncts(),
+                topJoin.getOtherJoinConjuncts(),
+                newBottomChild,
+                assertBranch,
+                topJoin.getJoinReorderContext());
+        LogicalJoin<? extends Plan, ? extends Plan> newTopJoin = 
(LogicalJoin<? extends Plan, ? extends Plan>)
+                (pushLeft ? bottomJoin.withChildren(newBottomJoin, newTopChild)
+                        : bottomJoin.withChildren(newTopChild, newBottomJoin));
+        if (topJoin.left() instanceof LogicalProject) {
+            LogicalProject<? extends Plan> upperProject = projectAliasOnPlan(
+                    aliasUsedInConditionFromLeftProject, topJoin.left());
+            return upperProject.withChildren(newTopJoin);
+        } else {
+            return newTopJoin;
+        }
+    }
+
     @VisibleForTesting
     LogicalProject<? extends Plan> projectAliasOnPlan(List<Alias> projections, 
Plan child) {
         if (child instanceof LogicalProject) {
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRowsTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRowsTest.java
index aded31bd18f..d241433a219 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRowsTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownJoinOnAssertNumRowsTest.java
@@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.GreaterThan;
+import org.apache.doris.nereids.trees.expressions.LessThan;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
@@ -253,6 +254,71 @@ class PushDownJoinOnAssertNumRowsTest implements 
MemoPatternMatchSupported {
                                                         logicalOlapScan())));
     }
 
+    /**
+     * Test push down when the top join condition uses an alias from the right 
child
+     * of the bottom join. This covers the following shape:
+     *
+     * Before:
+     * topJoin(rhs_score < x)
+     * |-- Project(T1.id, T2.cid + 1 as rhs_score, ...)
+     * |   `-- bottomJoin(T1.id = T2.sid)
+     * |       |-- Scan(T1)
+     * |       `-- Scan(T2)
+     * `-- LogicalAssertNumRows(output=(x, ...))
+     *
+     * After:
+     * Project(...)
+     * `-- bottomJoin(T1.id = T2.sid)
+     *     |-- Scan(T1)
+     *     `-- topJoin(rhs_score < x)
+     *         |-- Project(T2.cid + 1 as rhs_score, ...)
+     *         |   `-- Scan(T2)
+     *         `-- LogicalAssertNumRows(output=(x, ...))
+     */
+    @Test
+    void testPushDownWithProjectAliasFromRightChild() {
+        Plan oneRowRelation = new LogicalPlanBuilder(t3)
+                .limit(1)
+                .build();
+
+        AssertNumRowsElement assertElement = new AssertNumRowsElement(1, "", 
Assertion.EQ);
+        LogicalAssertNumRows<Plan> assertNumRows = new 
LogicalAssertNumRows<>(assertElement, oneRowRelation);
+
+        Expression bottomJoinCondition = new EqualTo(t1Slots.get(0), 
t2Slots.get(0));
+
+        LogicalPlan bottomJoin = new LogicalPlanBuilder(t1)
+                .join(t2, JoinType.INNER_JOIN, 
ImmutableList.of(bottomJoinCondition),
+                                ImmutableList.of())
+                .build();
+
+        Expression addExpr = new Add(t2Slots.get(1), Literal.of(1));
+        Alias rhsScore = new Alias(addExpr, "rhs_score");
+
+        ImmutableList.Builder<NamedExpression> projectListBuilder = 
ImmutableList.builder();
+        projectListBuilder.add(t1Slots.get(0));
+        projectListBuilder.add(t1Slots.get(1));
+        projectListBuilder.add(t2Slots.get(0));
+        projectListBuilder.add(rhsScore);
+
+        LogicalProject<Plan> project = new 
LogicalProject<>(projectListBuilder.build(), bottomJoin);
+
+        Expression topJoinCondition = new LessThan(rhsScore.toSlot(), 
t3Slots.get(0));
+
+        LogicalPlan root = new LogicalPlanBuilder(project)
+                .join(assertNumRows, JoinType.INNER_JOIN, ImmutableList.of(),
+                                ImmutableList.of(topJoinCondition))
+                .build();
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+                .applyTopDown(new PushDownJoinOnAssertNumRows())
+                .matches(logicalProject(
+                                logicalJoin(
+                                                logicalOlapScan(),
+                                                logicalJoin(
+                                                                
logicalProject(logicalOlapScan()),
+                                                                
logicalAssertNumRows()))));
+    }
+
     /**
      * Test with CROSS JOIN type.
      */


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to