This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 045dd05f2aecf11dc9651edca5e0b3e1c212abc5 Author: 谢健 <jianx...@gmail.com> AuthorDate: Mon Apr 8 11:08:05 2024 +0800 [fix](Nereids): don't transpose agg and join if join is mark join (#33312) --- .../rules/exploration/TransposeAggSemiJoin.java | 4 ++-- .../rules/exploration/TransposeAggSemiJoinProject.java | 4 ++-- .../rules/exploration/TransposeAggSemiJoinTest.java | 18 ++++++++++++++++++ .../rules/rewrite/TransposeSemiJoinAggProjectTest.java | 15 +++++++++++++++ 4 files changed, 37 insertions(+), 4 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoin.java index e25e1c816a2..564fc07513e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoin.java @@ -31,8 +31,8 @@ public class TransposeAggSemiJoin extends OneExplorationRuleFactory { @Override public Rule build() { - return logicalAggregate(logicalJoin()) - .when(agg -> agg.child().getJoinType().isLeftSemiOrAntiJoin()) + return logicalAggregate( + logicalJoin().when(join -> join.getJoinType().isLeftSemiOrAntiJoin() && !join.isMarkJoin())) .then(agg -> { LogicalJoin<GroupPlan, GroupPlan> join = agg.child(); if (!TransposeSemiJoinAgg.canTranspose(agg, join)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinProject.java index f1a7355a195..9beb93b9654 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinProject.java @@ -32,8 +32,8 @@ public class TransposeAggSemiJoinProject extends OneExplorationRuleFactory { @Override public Rule build() { - return logicalAggregate(logicalProject(logicalJoin())) - .when(agg -> agg.child().child().getJoinType().isLeftSemiOrAntiJoin()) + return logicalAggregate(logicalProject( + logicalJoin().when(join -> join.getJoinType().isLeftSemiOrAntiJoin() && !join.isMarkJoin()))) .then(agg -> { LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = agg.child(); LogicalJoin<GroupPlan, GroupPlan> join = project.child(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinTest.java index 9c1e19282ad..68cac382bc1 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/TransposeAggSemiJoinTest.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; class TransposeAggSemiJoinTest implements MemoPatternMatchSupported { @@ -57,4 +58,21 @@ class TransposeAggSemiJoinTest implements MemoPatternMatchSupported { ) ); } + + @Test + void markJoin() { + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .markJoin(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) + .aggGroupUsingIndex(ImmutableList.of(0), + ImmutableList.of( + scan1.getOutput().get(0), + new Alias(new Sum(scan1.getOutput().get(1)), "sum") + ) + ) + .build(); + int size = PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyExploration(TransposeAggSemiJoin.INSTANCE.build()) + .getAllPlan().size(); + Assertions.assertEquals(1, size); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/TransposeSemiJoinAggProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/TransposeSemiJoinAggProjectTest.java index ae91e5074e6..810ab1e6295 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/TransposeSemiJoinAggProjectTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/TransposeSemiJoinAggProjectTest.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.exploration.TransposeAggSemiJoin; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; @@ -28,6 +29,7 @@ import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; class TransposeSemiJoinAggProjectTest implements MemoPatternMatchSupported { @@ -53,4 +55,17 @@ class TransposeSemiJoinAggProjectTest implements MemoPatternMatchSupported { ); } + @Test + void markJoin() { + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .aggAllUsingIndex(ImmutableList.of(0, 1), ImmutableList.of(0, 1)) + .project(ImmutableList.of(0)) + .markJoin(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) + .build(); + int size = PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyExploration(TransposeAggSemiJoin.INSTANCE.build()) + .getAllPlan().size(); + Assertions.assertEquals(1, size); + } + } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org