This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 045dd05f2aecf11dc9651edca5e0b3e1c212abc5
Author: 谢健 <jianx...@gmail.com>
AuthorDate: Mon Apr 8 11:08:05 2024 +0800

    [fix](Nereids): don't transpose agg and join if join is mark join (#33312)
---
 .../rules/exploration/TransposeAggSemiJoin.java        |  4 ++--
 .../rules/exploration/TransposeAggSemiJoinProject.java |  4 ++--
 .../rules/exploration/TransposeAggSemiJoinTest.java    | 18 ++++++++++++++++++
 .../rules/rewrite/TransposeSemiJoinAggProjectTest.java | 15 +++++++++++++++
 4 files changed, 37 insertions(+), 4 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoin.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoin.java
index e25e1c816a2..564fc07513e 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoin.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoin.java
@@ -31,8 +31,8 @@ public class TransposeAggSemiJoin extends 
OneExplorationRuleFactory {
 
     @Override
     public Rule build() {
-        return logicalAggregate(logicalJoin())
-                .when(agg -> agg.child().getJoinType().isLeftSemiOrAntiJoin())
+        return logicalAggregate(
+                logicalJoin().when(join -> 
join.getJoinType().isLeftSemiOrAntiJoin() && !join.isMarkJoin()))
                 .then(agg -> {
                     LogicalJoin<GroupPlan, GroupPlan> join = agg.child();
                     if (!TransposeSemiJoinAgg.canTranspose(agg, join)) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinProject.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinProject.java
index f1a7355a195..9beb93b9654 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinProject.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinProject.java
@@ -32,8 +32,8 @@ public class TransposeAggSemiJoinProject extends 
OneExplorationRuleFactory {
 
     @Override
     public Rule build() {
-        return logicalAggregate(logicalProject(logicalJoin()))
-                .when(agg -> 
agg.child().child().getJoinType().isLeftSemiOrAntiJoin())
+        return logicalAggregate(logicalProject(
+                logicalJoin().when(join -> 
join.getJoinType().isLeftSemiOrAntiJoin() && !join.isMarkJoin())))
                 .then(agg -> {
                     LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project 
= agg.child();
                     LogicalJoin<GroupPlan, GroupPlan> join = project.child();
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinTest.java
index 9c1e19282ad..68cac382bc1 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinTest.java
@@ -30,6 +30,7 @@ import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
 
 import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 class TransposeAggSemiJoinTest implements MemoPatternMatchSupported {
@@ -57,4 +58,21 @@ class TransposeAggSemiJoinTest implements 
MemoPatternMatchSupported {
                         )
                 );
     }
+
+    @Test
+    void markJoin() {
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .markJoin(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
+                .aggGroupUsingIndex(ImmutableList.of(0),
+                        ImmutableList.of(
+                                scan1.getOutput().get(0),
+                                new Alias(new Sum(scan1.getOutput().get(1)), 
"sum")
+                        )
+                )
+                .build();
+        int size = PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyExploration(TransposeAggSemiJoin.INSTANCE.build())
+                .getAllPlan().size();
+        Assertions.assertEquals(1, size);
+    }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/TransposeSemiJoinAggProjectTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/TransposeSemiJoinAggProjectTest.java
index ae91e5074e6..810ab1e6295 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/TransposeSemiJoinAggProjectTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/TransposeSemiJoinAggProjectTest.java
@@ -18,6 +18,7 @@
 package org.apache.doris.nereids.rules.rewrite;
 
 import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.rules.exploration.TransposeAggSemiJoin;
 import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
@@ -28,6 +29,7 @@ import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
 
 import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 class TransposeSemiJoinAggProjectTest implements MemoPatternMatchSupported {
@@ -53,4 +55,17 @@ class TransposeSemiJoinAggProjectTest implements 
MemoPatternMatchSupported {
                 );
     }
 
+    @Test
+    void markJoin() {
+        LogicalPlan plan = new LogicalPlanBuilder(scan1)
+                .aggAllUsingIndex(ImmutableList.of(0, 1), ImmutableList.of(0, 
1))
+                .project(ImmutableList.of(0))
+                .markJoin(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
+                .build();
+        int size = PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+                .applyExploration(TransposeAggSemiJoin.INSTANCE.build())
+                .getAllPlan().size();
+        Assertions.assertEquals(1, size);
+    }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to