This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new db7f955a70 [improve](Nereids): split otherJoinCondition with List. (#13216) db7f955a70 is described below commit db7f955a70812c2b40dd7431200b01276f3f3197 Author: jakevin <jakevin...@gmail.com> AuthorDate: Thu Oct 13 13:49:46 2022 +0800 [improve](Nereids): split otherJoinCondition with List. (#13216) * split otherJoinCondition with List. --- .../glue/translator/PhysicalPlanTranslator.java | 27 +++---- .../doris/nereids/parser/LogicalPlanBuilder.java | 23 +++--- .../nereids/rules/analysis/BindSlotReference.java | 6 +- .../rules/exploration/join/InnerJoinLAsscom.java | 12 ++- .../exploration/join/InnerJoinLAsscomProject.java | 11 +-- .../rules/exploration/join/JoinCommute.java | 2 +- .../rules/exploration/join/JoinLAsscomHelper.java | 5 +- .../rules/exploration/join/OuterJoinLAsscom.java | 12 ++- .../exploration/join/OuterJoinLAsscomProject.java | 11 +-- .../join/SemiJoinLogicalJoinTranspose.java | 8 +- .../join/SemiJoinLogicalJoinTransposeProject.java | 8 +- .../join/SemiJoinSemiJoinTranspose.java | 4 +- .../expression/rewrite/ExpressionRewrite.java | 22 ++++-- .../implementation/LogicalJoinToHashJoin.java | 2 +- .../LogicalJoinToNestedLoopJoin.java | 2 +- .../rules/rewrite/logical/EliminateOuter.java | 2 +- .../rules/rewrite/logical/ExistsApplyToJoin.java | 18 ++++- .../rewrite/logical/FindHashConditionForJoin.java | 11 ++- .../rules/rewrite/logical/InApplyToJoin.java | 10 +-- .../nereids/rules/rewrite/logical/MultiJoin.java | 13 ++-- .../rewrite/logical/PushFilterInsideJoin.java | 8 +- .../rewrite/logical/PushdownFilterThroughJoin.java | 4 +- .../logical/PushdownJoinOtherCondition.java | 55 +++++++------- .../rules/rewrite/logical/ScalarApplyToJoin.java | 13 +++- .../doris/nereids/trees/plans/algebra/Join.java | 4 +- .../nereids/trees/plans/logical/LogicalApply.java | 2 +- .../nereids/trees/plans/logical/LogicalJoin.java | 85 ++++++++++------------ .../trees/plans/physical/AbstractPhysicalJoin.java | 49 +++++-------- .../trees/plans/physical/PhysicalHashJoin.java | 25 +++---- .../plans/physical/PhysicalNestedLoopJoin.java | 29 ++++---- .../apache/doris/nereids/util/ExpressionUtils.java | 17 +++++ .../org/apache/doris/nereids/util/JoinUtils.java | 19 ----- .../apache/doris/nereids/memo/MemoRewriteTest.java | 1 - .../properties/ChildOutputPropertyDeriverTest.java | 40 ++++++---- .../properties/RequestPropertyDeriverTest.java | 11 ++- .../exploration/join/InnerJoinLAsscomTest.java | 8 +- .../logical/FindHashConditionForJoinTest.java | 24 +++--- .../rules/rewrite/logical/LimitPushDownTest.java | 2 - .../rewrite/logical/PushFilterInsideJoinTest.java | 2 +- .../logical/PushdownFilterThroughJoinTest.java | 2 +- .../logical/PushdownJoinOtherConditionTest.java | 19 ++--- .../org/apache/doris/nereids/sqltest/SqlTest.java | 84 +++++++++++++++++++++ .../doris/nereids/trees/plans/PlanEqualsTest.java | 37 ++++++---- .../nereids/trees/plans/PlanToStringTest.java | 6 +- .../doris/nereids/util/AnalyzeSubQueryTest.java | 11 ++- .../nereids/util/AnalyzeWhereSubqueryTest.java | 82 ++++++++++++--------- .../doris/nereids/util/LogicalPlanBuilder.java | 11 ++- 47 files changed, 469 insertions(+), 390 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 63fb58beae..818fa16f62 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -136,7 +136,6 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla } - /** * Translate Agg. * todo: support DISTINCT @@ -567,9 +566,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla Map<ExprId, SlotReference> hashOutputSlotReferenceMap = Maps.newHashMap(outputSlotReferenceMap); - hashJoin.getOtherJoinCondition() - .map(ExpressionUtils::extractConjunction) - .orElseGet(Lists::newArrayList) + hashJoin.getOtherJoinConjuncts() .stream() .filter(e -> !(e.equals(BooleanLiteral.TRUE))) .flatMap(e -> e.getInputSlots().stream()) @@ -590,19 +587,19 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla .map(SlotReference.class::cast) .forEach(s -> rightChildOutputMap.put(s.getExprId(), s)); - //make intermediate tuple + // make intermediate tuple List<SlotDescriptor> leftIntermediateSlotDescriptor = Lists.newArrayList(); List<SlotDescriptor> rightIntermediateSlotDescriptor = Lists.newArrayList(); TupleDescriptor intermediateDescriptor = context.generateTupleDesc(); - if (!hashJoin.getOtherJoinCondition().isPresent() + if (hashJoin.getOtherJoinConjuncts().isEmpty() && (joinType == JoinType.LEFT_ANTI_JOIN || joinType == JoinType.LEFT_SEMI_JOIN)) { for (SlotDescriptor leftSlotDescriptor : leftSlotDescriptors) { SlotReference sf = leftChildOutputMap.get(context.findExprId(leftSlotDescriptor.getId())); SlotDescriptor sd = context.createSlotDesc(intermediateDescriptor, sf); leftIntermediateSlotDescriptor.add(sd); } - } else if (!hashJoin.getOtherJoinCondition().isPresent() + } else if (hashJoin.getOtherJoinConjuncts().isEmpty() && (joinType == JoinType.RIGHT_ANTI_JOIN || joinType == JoinType.RIGHT_SEMI_JOIN)) { for (SlotDescriptor rightSlotDescriptor : rightSlotDescriptors) { SlotReference sf = rightChildOutputMap.get(context.findExprId(rightSlotDescriptor.getId())); @@ -628,7 +625,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla } } - //set slots as nullable for outer join + // set slots as nullable for outer join if (joinType == JoinType.LEFT_OUTER_JOIN || joinType == JoinType.FULL_OUTER_JOIN) { rightIntermediateSlotDescriptor.forEach(sd -> sd.setIsNullable(true)); } @@ -636,9 +633,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla leftIntermediateSlotDescriptor.forEach(sd -> sd.setIsNullable(true)); } - List<Expr> otherJoinConjuncts = hashJoin.getOtherJoinCondition() - .map(ExpressionUtils::extractConjunction) - .orElseGet(Lists::newArrayList) + List<Expr> otherJoinConjuncts = hashJoin.getOtherJoinConjuncts() .stream() // TODO add constant expr will cause be crash, currently we only handle true literal. // remove it after Nereids could ensure no constant expr in other join condition @@ -660,7 +655,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla hashJoinNode.setvIntermediateTupleDescList(Lists.newArrayList(intermediateDescriptor)); if (hashJoin.isShouldTranslateOutput()) { - //translate output expr on intermediate tuple + // translate output expr on intermediate tuple List<Expr> srcToOutput = outputSlotReferences.stream() .map(e -> ExpressionTranslator.translate(e, context)) .collect(Collectors.toList()); @@ -699,10 +694,8 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla crossJoinNode.setChild(0, leftFragment.getPlanRoot()); connectChildFragment(crossJoinNode, 1, leftFragment, rightFragment, context); leftFragment.setPlanRoot(crossJoinNode); - if (nestedLoopJoin.getOtherJoinCondition().isPresent()) { - ExpressionUtils.extractConjunction(nestedLoopJoin.getOtherJoinCondition().get()).stream() - .map(e -> ExpressionTranslator.translate(e, context)).forEach(crossJoinNode::addConjunct); - } + nestedLoopJoin.getOtherJoinConjuncts().stream() + .map(e -> ExpressionTranslator.translate(e, context)).forEach(crossJoinNode::addConjunct); return leftFragment; } else { @@ -849,7 +842,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla public PlanFragment visitPhysicalAssertNumRows(PhysicalAssertNumRows<? extends Plan> assertNumRows, PlanTranslatorContext context) { PlanFragment currentFragment = assertNumRows.child(0).accept(this, context); - //create assertNode + // create assertNode AssertNumRowsNode assertNumRowsNode = new AssertNumRowsNode(context.nextPlanNodeId(), currentFragment.getPlanRoot(), ExpressionTranslator.translateAssert(assertNumRows.getAssertNumRowsElement())); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java index b84e0328bc..0a9f7ebefc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java @@ -139,6 +139,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalSelectHint; import org.apache.doris.nereids.trees.plans.logical.LogicalSort; import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; @@ -648,8 +649,8 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> { left = (left == null) ? right : new LogicalJoin<>( JoinType.CROSS_JOIN, - ImmutableList.of(), - Optional.empty(), + ExpressionUtils.EMPTY_CONDITION, + ExpressionUtils.EMPTY_CONDITION, left, right); left = withJoinRelations(left, relation); @@ -707,9 +708,9 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> { private <T> List<T> visit(List<? extends ParserRuleContext> contexts, Class<T> clazz) { return contexts.stream() - .map(this::visit) - .map(clazz::cast) - .collect(ImmutableList.toImmutableList()); + .map(this::visit) + .map(clazz::cast) + .collect(ImmutableList.toImmutableList()); } private LogicalPlan plan(ParserRuleContext tree) { @@ -829,15 +830,17 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> { // TODO: natural join, lateral join, using join, union join JoinCriteriaContext joinCriteria = join.joinCriteria(); - Expression condition; + Optional<Expression> condition; if (joinCriteria == null) { - condition = null; + condition = Optional.empty(); } else { - condition = getExpression(joinCriteria.booleanExpression()); + condition = Optional.ofNullable(getExpression(joinCriteria.booleanExpression())); } - last = new LogicalJoin<>(joinType, new ArrayList<Expression>(), - Optional.ofNullable(condition), last, plan(join.relationPrimary())); + last = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, + condition.map(ExpressionUtils::extractConjunction) + .orElse(ExpressionUtils.EMPTY_CONDITION), + last, plan(join.relationPrimary())); } return last; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java index 21e40ee86a..7417d27d3a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java @@ -108,9 +108,9 @@ public class BindSlotReference implements AnalysisRuleFactory { RuleType.BINDING_JOIN_SLOT.build( logicalJoin().when(Plan::canBind).thenApply(ctx -> { LogicalJoin<GroupPlan, GroupPlan> join = ctx.root; - Optional<Expression> cond = join.getOtherJoinCondition() - .map(expr -> bind(expr, join.children(), join, ctx.cascadesContext)); - + List<Expression> cond = join.getOtherJoinConjuncts().stream() + .map(expr -> bind(expr, join.children(), join, ctx.cascadesContext)) + .collect(Collectors.toList()); List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts().stream() .map(expr -> bind(expr, join.children(), join, ctx.cascadesContext)) .collect(Collectors.toList()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscom.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscom.java index 3bb81e6a66..c83d3e80b9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscom.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscom.java @@ -32,7 +32,6 @@ import com.google.common.base.Preconditions; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -54,8 +53,8 @@ public class InnerJoinLAsscom extends OneExplorationRuleFactory { return innerLogicalJoin(innerLogicalJoin(), group()) .when(topJoin -> checkReorder(topJoin, topJoin.left())) // TODO: handle otherJoinCondition - .whenNot(topJoin -> topJoin.getOtherJoinCondition().isPresent()) - .whenNot(topJoin -> topJoin.left().getOtherJoinCondition().isPresent()) + .when(topJoin -> topJoin.getOtherJoinConjuncts().isEmpty()) + .when(topJoin -> topJoin.left().getOtherJoinConjuncts().isEmpty()) .then(topJoin -> { LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left(); GroupPlan a = bottomJoin.left(); @@ -74,14 +73,13 @@ public class InnerJoinLAsscom extends OneExplorationRuleFactory { // TODO: split otherCondition. LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN, - newBottomHashConjuncts, Optional.empty(), - a, c, bottomJoin.getJoinReorderContext()); + newBottomHashConjuncts, a, c, bottomJoin.getJoinReorderContext()); newBottomJoin.getJoinReorderContext().setHasLAsscom(false); newBottomJoin.getJoinReorderContext().setHasCommute(false); LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> newTopJoin = new LogicalJoin<>( - JoinType.INNER_JOIN, newTopHashConjuncts, Optional.empty(), - newBottomJoin, b, topJoin.getJoinReorderContext()); + JoinType.INNER_JOIN, newTopHashConjuncts, newBottomJoin, b, + topJoin.getJoinReorderContext()); newTopJoin.getJoinReorderContext().setHasLAsscom(true); return newTopJoin; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java index 2d600573af..92ef509813 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java @@ -39,7 +39,6 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -64,8 +63,8 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { return innerLogicalJoin(logicalProject(innerLogicalJoin()), group()) .when(topJoin -> InnerJoinLAsscom.checkReorder(topJoin, topJoin.left().child())) // TODO: handle otherJoinCondition - .whenNot(topJoin -> topJoin.getOtherJoinCondition().isPresent()) - .whenNot(topJoin -> topJoin.left().child().getOtherJoinCondition().isPresent()) + .when(topJoin -> topJoin.getOtherJoinConjuncts().isEmpty()) + .when(topJoin -> topJoin.left().child().getOtherJoinConjuncts().isEmpty()) .then(topJoin -> { /* ********** init ********** */ @@ -139,8 +138,7 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { /* ********** new Plan ********** */ LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(topJoin.getJoinType(), - newBottomHashJoinConjuncts, Optional.empty(), - a, c, bottomJoin.getJoinReorderContext()); + newBottomHashJoinConjuncts, a, c, bottomJoin.getJoinReorderContext()); newBottomJoin.getJoinReorderContext().setHasLAsscom(false); newBottomJoin.getJoinReorderContext().setHasCommute(false); @@ -156,8 +154,7 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { } LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(), - newTopHashJoinConjuncts, Optional.empty(), - left, right, topJoin.getJoinReorderContext()); + newTopHashJoinConjuncts, left, right, topJoin.getJoinReorderContext()); newTopJoin.getJoinReorderContext().setHasLAsscom(true); if (topJoin.getLogicalProperties().equals(newTopJoin.getLogicalProperties())) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java index d8f7b30087..6705c74fbc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java @@ -49,7 +49,7 @@ public class JoinCommute extends OneExplorationRuleFactory { LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>( join.getJoinType().swap(), join.getHashJoinConjuncts(), - join.getOtherJoinCondition(), + join.getOtherJoinConjuncts(), join.right(), join.left(), join.getJoinReorderContext()); newJoin.getJoinReorderContext().setHasCommute(true); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java index 3477c050a2..c030a03d87 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java @@ -89,7 +89,7 @@ class JoinLAsscomHelper extends ThreeJoinHelper { newLeftProjects.addAll(cOutputSet); } LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(topJoin.getJoinType(), - newBottomHashJoinConjuncts, ExpressionUtils.optionalAnd(newBottomNonHashJoinConjuncts), a, c, + newBottomHashJoinConjuncts, newBottomNonHashJoinConjuncts, a, c, bottomJoin.getJoinReorderContext()); newBottomJoin.getJoinReorderContext().setHasLAsscom(false); newBottomJoin.getJoinReorderContext().setHasCommute(false); @@ -106,8 +106,7 @@ class JoinLAsscomHelper extends ThreeJoinHelper { } LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(), - newTopHashJoinConjuncts, - ExpressionUtils.optionalAnd(newTopNonHashJoinConjuncts), left, right, + newTopHashJoinConjuncts, newTopNonHashJoinConjuncts, left, right, topJoin.getJoinReorderContext()); newTopJoin.getJoinReorderContext().setHasLAsscom(true); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java index 6e2bd098b9..47b7a3ffce 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java @@ -32,7 +32,6 @@ import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableSet; -import java.util.Optional; import java.util.Set; /** @@ -64,8 +63,8 @@ public class OuterJoinLAsscom extends OneExplorationRuleFactory { .when(topJoin -> checkReorder(topJoin, topJoin.left())) .when(topJoin -> checkCondition(topJoin, topJoin.left().right().getOutputSet())) // TODO: handle otherJoinCondition - .whenNot(topJoin -> topJoin.getOtherJoinCondition().isPresent()) - .whenNot(topJoin -> topJoin.left().getOtherJoinCondition().isPresent()) + .when(topJoin -> topJoin.getOtherJoinConjuncts().isEmpty()) + .when(topJoin -> topJoin.left().getOtherJoinConjuncts().isEmpty()) .then(topJoin -> { LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left(); GroupPlan a = bottomJoin.left(); @@ -73,14 +72,13 @@ public class OuterJoinLAsscom extends OneExplorationRuleFactory { GroupPlan c = topJoin.right(); LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(topJoin.getJoinType(), - topJoin.getHashJoinConjuncts(), Optional.empty(), - a, c, bottomJoin.getJoinReorderContext()); + topJoin.getHashJoinConjuncts(), a, c, bottomJoin.getJoinReorderContext()); newBottomJoin.getJoinReorderContext().setHasLAsscom(false); newBottomJoin.getJoinReorderContext().setHasCommute(false); LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> newTopJoin = new LogicalJoin<>( - bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(), Optional.empty(), - newBottomJoin, b, topJoin.getJoinReorderContext()); + bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(), newBottomJoin, b, + topJoin.getJoinReorderContext()); newTopJoin.getJoinReorderContext().setHasLAsscom(true); return newTopJoin; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java index 0901290c3a..b761317a60 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java @@ -39,7 +39,6 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -65,8 +64,8 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory { Pair.of(join.left().child().getJoinType(), join.getJoinType()))) .when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left().child())) // TODO: handle otherJoinCondition - .whenNot(topJoin -> topJoin.getOtherJoinCondition().isPresent()) - .whenNot(topJoin -> topJoin.left().child().getOtherJoinCondition().isPresent()) + .when(topJoin -> topJoin.getOtherJoinConjuncts().isEmpty()) + .when(topJoin -> topJoin.left().child().getOtherJoinConjuncts().isEmpty()) .then(topJoin -> { /* ********** init ********** */ @@ -145,8 +144,7 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory { /* ********** new Plan ********** */ LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(topJoin.getJoinType(), - newBottomHashJoinConjuncts, Optional.empty(), - a, c, bottomJoin.getJoinReorderContext()); + newBottomHashJoinConjuncts, a, c, bottomJoin.getJoinReorderContext()); newBottomJoin.getJoinReorderContext().setHasLAsscom(false); newBottomJoin.getJoinReorderContext().setHasCommute(false); @@ -162,8 +160,7 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory { } LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(), - newTopHashJoinConjuncts, Optional.empty(), - left, right, topJoin.getJoinReorderContext()); + newTopHashJoinConjuncts, left, right, topJoin.getJoinReorderContext()); newTopJoin.getJoinReorderContext().setHasLAsscom(true); if (topJoin.getLogicalProperties().equals(newTopJoin.getLogicalProperties())) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java index c524cae763..4a12d5d836 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java @@ -82,9 +82,9 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory { */ LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>( topSemiJoin.getJoinType(), - topSemiJoin.getHashJoinConjuncts(), topSemiJoin.getOtherJoinCondition(), a, c); + topSemiJoin.getHashJoinConjuncts(), topSemiJoin.getOtherJoinConjuncts(), a, c); return new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(), - bottomJoin.getOtherJoinCondition(), newBottomSemiJoin, b); + bottomJoin.getOtherJoinConjuncts(), newBottomSemiJoin, b); } else { /* * topSemiJoin newTopJoin @@ -95,9 +95,9 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory { */ LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>( topSemiJoin.getJoinType(), - topSemiJoin.getHashJoinConjuncts(), topSemiJoin.getOtherJoinCondition(), b, c); + topSemiJoin.getHashJoinConjuncts(), topSemiJoin.getOtherJoinConjuncts(), b, c); return new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(), - bottomJoin.getOtherJoinCondition(), a, newBottomSemiJoin); + bottomJoin.getOtherJoinConjuncts(), a, newBottomSemiJoin); } }).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java index 6698c773b9..f16a1c2818 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java @@ -88,10 +88,10 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto */ LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>( topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(), - topSemiJoin.getOtherJoinCondition(), a, c); + topSemiJoin.getOtherJoinConjuncts(), a, c); LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(), - bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinCondition(), + bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(), newBottomSemiJoin, b); return new LogicalProject<>(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin); @@ -107,10 +107,10 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto */ LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>( topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(), - topSemiJoin.getOtherJoinCondition(), b, c); + topSemiJoin.getOtherJoinConjuncts(), b, c); LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(), - bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinCondition(), + bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(), a, newBottomSemiJoin); return new LogicalProject<>(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java index a9b34bb2f1..61f86f6314 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTranspose.java @@ -63,10 +63,10 @@ public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory { GroupPlan c = topJoin.right(); LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(topJoin.getJoinType(), - topJoin.getHashJoinConjuncts(), topJoin.getOtherJoinCondition(), a, c); + topJoin.getHashJoinConjuncts(), topJoin.getOtherJoinConjuncts(), a, c); LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> newTopJoin = new LogicalJoin<>( bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(), - bottomJoin.getOtherJoinCondition(), + bottomJoin.getOtherJoinConjuncts(), newBottomJoin, b); return newTopJoin; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java index f660ee0ec0..8bf60e96c8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java @@ -34,7 +34,6 @@ import com.google.common.collect.Lists; import java.util.List; import java.util.Objects; -import java.util.Optional; import java.util.stream.Collectors; /** @@ -136,24 +135,31 @@ public class ExpressionRewrite implements RewriteRuleFactory { public Rule build() { return logicalJoin().then(join -> { List<Expression> hashJoinConjuncts = join.getHashJoinConjuncts(); - Optional<Expression> otherJoinCondition = join.getOtherJoinCondition(); - if (!otherJoinCondition.isPresent() && hashJoinConjuncts.isEmpty()) { + List<Expression> otherJoinConjuncts = join.getOtherJoinConjuncts(); + if (otherJoinConjuncts.isEmpty() && hashJoinConjuncts.isEmpty()) { return join; } List<Expression> rewriteHashJoinConjuncts = Lists.newArrayList(); - boolean joinConjunctsChanged = false; + boolean hashJoinConjunctsChanged = false; for (Expression expr : hashJoinConjuncts) { Expression newExpr = rewriter.rewrite(expr); - joinConjunctsChanged = joinConjunctsChanged || !newExpr.equals(expr); + hashJoinConjunctsChanged = hashJoinConjunctsChanged || !newExpr.equals(expr); rewriteHashJoinConjuncts.add(newExpr); } - Optional<Expression> newOtherJoinCondition = rewriter.rewrite(otherJoinCondition); - if (!joinConjunctsChanged && newOtherJoinCondition.equals(otherJoinCondition)) { + List<Expression> rewriteOtherJoinConjuncts = Lists.newArrayList(); + boolean otherJoinConjunctsChanged = false; + for (Expression expr : otherJoinConjuncts) { + Expression newExpr = rewriter.rewrite(expr); + otherJoinConjunctsChanged = otherJoinConjunctsChanged || !newExpr.equals(expr); + rewriteOtherJoinConjuncts.add(newExpr); + } + + if (!hashJoinConjunctsChanged && !otherJoinConjunctsChanged) { return join; } return new LogicalJoin<>(join.getJoinType(), rewriteHashJoinConjuncts, - newOtherJoinCondition, join.left(), join.right()); + rewriteOtherJoinConjuncts, join.left(), join.right()); }).toRule(RuleType.REWRITE_JOIN_EXPRESSION); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalJoinToHashJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalJoinToHashJoin.java index 047b68cbc7..161ebe52b7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalJoinToHashJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalJoinToHashJoin.java @@ -33,7 +33,7 @@ public class LogicalJoinToHashJoin extends OneImplementationRuleFactory { .then(join -> new PhysicalHashJoin<>( join.getJoinType(), join.getHashJoinConjuncts(), - join.getOtherJoinCondition(), + join.getOtherJoinConjuncts(), join.getLogicalProperties(), join.left(), join.right()) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalJoinToNestedLoopJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalJoinToNestedLoopJoin.java index 759e4a9592..7de8da2ec4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalJoinToNestedLoopJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalJoinToNestedLoopJoin.java @@ -33,7 +33,7 @@ public class LogicalJoinToNestedLoopJoin extends OneImplementationRuleFactory { .then(join -> new PhysicalNestedLoopJoin<>( join.getJoinType(), join.getHashJoinConjuncts(), - join.getOtherJoinCondition(), + join.getOtherJoinConjuncts(), join.getLogicalProperties(), join.left(), join.right()) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateOuter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateOuter.java index 0d1de04411..10a03df464 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateOuter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateOuter.java @@ -84,7 +84,7 @@ public class EliminateOuter extends OneRewriteRuleFactory { return new LogicalFilter<>(filter.getPredicates(), new LogicalJoin<>(joinType, - join.getHashJoinConjuncts(), join.getOtherJoinCondition(), + join.getHashJoinConjuncts(), join.getOtherJoinConjuncts(), join.left(), join.right())); }).toRule(RuleType.ELIMINATE_OUTER); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExistsApplyToJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExistsApplyToJoin.java index f315cdafeb..2dd33b524b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExistsApplyToJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExistsApplyToJoin.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Exists; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.plans.JoinType; @@ -33,15 +34,16 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; import java.util.ArrayList; +import java.util.Optional; /** * Convert Existsapply to LogicalJoin. - * + * <p> * Exists * Correlated -> LEFT_SEMI_JOIN * apply LEFT_SEMI_JOIN(Correlated Predicate) @@ -85,11 +87,19 @@ public class ExistsApplyToJoin extends OneRewriteRuleFactory { } private Plan correlatedToJoin(LogicalApply apply) { + Optional<Expression> correlationFilter = apply.getCorrelationFilter(); + if (((Exists) apply.getSubqueryExpr()).isNot()) { - return new LogicalJoin<>(JoinType.LEFT_ANTI_JOIN, Lists.newArrayList(), apply.getCorrelationFilter(), + return new LogicalJoin<>(JoinType.LEFT_ANTI_JOIN, ExpressionUtils.EMPTY_CONDITION, + correlationFilter + .map(ExpressionUtils::extractConjunction) + .orElse(ExpressionUtils.EMPTY_CONDITION), (LogicalPlan) apply.left(), (LogicalPlan) apply.right()); } else { - return new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(), apply.getCorrelationFilter(), + return new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN, ExpressionUtils.EMPTY_CONDITION, + correlationFilter + .map(ExpressionUtils::extractConjunction) + .orElse(ExpressionUtils.EMPTY_CONDITION), (LogicalPlan) apply.left(), (LogicalPlan) apply.right()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoin.java index 2256e6e727..412c3c2308 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoin.java @@ -22,14 +22,13 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.JoinUtils; import com.google.common.collect.ImmutableList; import java.util.List; -import java.util.Optional; /** * this rule aims to find a conjunct list from on clause expression, which could @@ -47,9 +46,13 @@ public class FindHashConditionForJoin extends OneRewriteRuleFactory { @Override public Rule build() { return logicalJoin().then(join -> { - Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable(join); + List<Slot> leftSlots = join.left().getOutput(); + List<Slot> rightSlots = join.right().getOutput(); + Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable(leftSlots, + rightSlots, join.getOtherJoinConjuncts()); + List<Expression> extractedHashJoinConjuncts = pair.first; - Optional<Expression> remainedNonHashJoinConjuncts = ExpressionUtils.optionalAnd(pair.second); + List<Expression> remainedNonHashJoinConjuncts = pair.second; if (extractedHashJoinConjuncts.isEmpty()) { return join; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java index 6a354d0b4b..8daab33f46 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/InApplyToJoin.java @@ -30,11 +30,9 @@ import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.Lists; -import java.util.Optional; - /** * Convert InApply to LogicalJoin. - * + * <p> * Not In -> LEFT_ANTI_JOIN * In -> LEFT_SEMI_JOIN */ @@ -54,10 +52,12 @@ public class InApplyToJoin extends OneRewriteRuleFactory { } if (((InSubquery) apply.getSubqueryExpr()).isNot()) { - return new LogicalJoin<>(JoinType.LEFT_ANTI_JOIN, Lists.newArrayList(), Optional.of(predicate), + return new LogicalJoin<>(JoinType.LEFT_ANTI_JOIN, Lists.newArrayList(), + ExpressionUtils.extractConjunction(predicate), apply.left(), apply.right()); } else { - return new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(), Optional.of(predicate), + return new LogicalJoin<>(JoinType.LEFT_SEMI_JOIN, Lists.newArrayList(), + ExpressionUtils.extractConjunction(predicate), apply.left(), apply.right()); } }).toRule(RuleType.IN_APPLY_TO_JOIN); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java index af20344205..20a9550af3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java @@ -93,8 +93,8 @@ public class MultiJoin extends PlanVisitor<Void, Void> { conjunctsKeepInFilter = pair.second; return new LogicalJoin<>(JoinType.INNER_JOIN, - new ArrayList<>(), - ExpressionUtils.optionalAnd(joinConditions), + ExpressionUtils.EMPTY_CONDITION, + joinConditions, joinInputs.get(0), joinInputs.get(1)); } // input size >= 3; @@ -133,9 +133,8 @@ public class MultiJoin extends PlanVisitor<Void, Void> { conjuncts); List<Expression> joinConditions = pair.first; List<Expression> nonJoinConditions = pair.second; - LogicalJoin join = new LogicalJoin<>(JoinType.INNER_JOIN, new ArrayList<>(), - ExpressionUtils.optionalAnd(joinConditions), - left, right); + LogicalJoin join = new LogicalJoin<>(JoinType.INNER_JOIN, ExpressionUtils.EMPTY_CONDITION, + joinConditions, left, right); List<Plan> newInputs = new ArrayList<>(); newInputs.add(join); @@ -185,9 +184,7 @@ public class MultiJoin extends PlanVisitor<Void, Void> { join.right().accept(this, context); conjunctsForAllHashJoins.addAll(join.getHashJoinConjuncts()); - if (join.getOtherJoinCondition().isPresent()) { - conjunctsForAllHashJoins.addAll(ExpressionUtils.extractConjunction(join.getOtherJoinCondition().get())); - } + conjunctsForAllHashJoins.addAll(join.getOtherJoinConjuncts()); Plan leftChild = join.left(); if (join.left() instanceof LogicalFilter) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushFilterInsideJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushFilterInsideJoin.java index dc84e663c7..eeb1699f3f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushFilterInsideJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushFilterInsideJoin.java @@ -25,8 +25,6 @@ import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.util.ExpressionUtils; -import com.google.common.collect.Lists; - import java.util.List; /** @@ -42,11 +40,11 @@ public class PushFilterInsideJoin extends OneRewriteRuleFactory { .when(filter -> filter.child().getJoinType().isCrossJoin() || filter.child().getJoinType().isInnerJoin()) .then(filter -> { - List<Expression> otherConditions = Lists.newArrayList(filter.getPredicates()); + List<Expression> otherConditions = ExpressionUtils.extractConjunction(filter.getPredicates()); LogicalJoin<GroupPlan, GroupPlan> join = filter.child(); - join.getOtherJoinCondition().map(otherConditions::add); + otherConditions.addAll(join.getOtherJoinConjuncts()); return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(), - ExpressionUtils.optionalAnd(otherConditions), join.left(), join.right()); + otherConditions, join.left(), join.right()); }).toRule(RuleType.PUSH_FILTER_INSIDE_JOIN); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughJoin.java index 2b5cf9065b..4d2fa019e2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughJoin.java @@ -131,12 +131,12 @@ public class PushdownFilterThroughJoin extends OneRewriteRuleFactory { } } - join.getOtherJoinCondition().map(joinConditions::add); + joinConditions.addAll(join.getOtherJoinConjuncts()); return PlanUtils.filterOrSelf(remainingPredicates, new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(), - ExpressionUtils.optionalAnd(joinConditions), + joinConditions, PlanUtils.filterOrSelf(leftPredicates, join.left()), PlanUtils.filterOrSelf(rightPredicates, join.right()))); }).toRule(RuleType.PUSHDOWN_FILTER_THROUGH_JOIN); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherCondition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherCondition.java index 52e60aa102..3d6c3b3143 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherCondition.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherCondition.java @@ -25,7 +25,6 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanUtils; import com.google.common.collect.ImmutableList; @@ -58,39 +57,37 @@ public class PushdownJoinOtherCondition extends OneRewriteRuleFactory { @Override public Rule build() { - return logicalJoin().then(join -> { - if (!join.getOtherJoinCondition().isPresent()) { - return null; - } - List<Expression> otherConjuncts = ExpressionUtils.extractConjunction(join.getOtherJoinCondition().get()); - List<Expression> leftConjuncts = Lists.newArrayList(); - List<Expression> rightConjuncts = Lists.newArrayList(); + return logicalJoin() + .whenNot(join -> join.getOtherJoinConjuncts().isEmpty()) + .then(join -> { + List<Expression> otherJoinConjuncts = join.getOtherJoinConjuncts(); + List<Expression> remainingOther = Lists.newArrayList(); + List<Expression> leftConjuncts = Lists.newArrayList(); + List<Expression> rightConjuncts = Lists.newArrayList(); - for (Expression otherConjunct : otherConjuncts) { - if (PUSH_DOWN_LEFT_VALID_TYPE.contains(join.getJoinType()) - && allCoveredBy(otherConjunct, join.left().getOutputSet())) { - leftConjuncts.add(otherConjunct); - } - if (PUSH_DOWN_RIGHT_VALID_TYPE.contains(join.getJoinType()) - && allCoveredBy(otherConjunct, join.right().getOutputSet())) { - rightConjuncts.add(otherConjunct); - } - } + for (Expression otherConjunct : otherJoinConjuncts) { + if (PUSH_DOWN_LEFT_VALID_TYPE.contains(join.getJoinType()) + && allCoveredBy(otherConjunct, join.left().getOutputSet())) { + leftConjuncts.add(otherConjunct); + } else if (PUSH_DOWN_RIGHT_VALID_TYPE.contains(join.getJoinType()) + && allCoveredBy(otherConjunct, join.right().getOutputSet())) { + rightConjuncts.add(otherConjunct); + } else { + remainingOther.add(otherConjunct); + } + } - if (leftConjuncts.isEmpty() && rightConjuncts.isEmpty()) { - return null; - } + if (leftConjuncts.isEmpty() && rightConjuncts.isEmpty()) { + return null; + } - otherConjuncts.removeAll(leftConjuncts); - otherConjuncts.removeAll(rightConjuncts); + Plan left = PlanUtils.filterOrSelf(leftConjuncts, join.left()); + Plan right = PlanUtils.filterOrSelf(rightConjuncts, join.right()); - Plan left = PlanUtils.filterOrSelf(leftConjuncts, join.left()); - Plan right = PlanUtils.filterOrSelf(rightConjuncts, join.right()); + return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(), + remainingOther, left, right); - return new LogicalJoin<>(join.getJoinType(), join.getHashJoinConjuncts(), - ExpressionUtils.optionalAnd(otherConjuncts), left, right); - - }).toRule(RuleType.PUSHDOWN_JOIN_OTHER_CONDITION); + }).toRule(RuleType.PUSHDOWN_JOIN_OTHER_CONDITION); } private boolean allCoveredBy(Expression predicate, Set<Slot> inputSlotSet) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ScalarApplyToJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ScalarApplyToJoin.java index 775628cc7b..4110c15e3c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ScalarApplyToJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ScalarApplyToJoin.java @@ -21,18 +21,20 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalApply; import org.apache.doris.nereids.trees.plans.logical.LogicalAssertNumRows; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.ExpressionUtils; -import com.google.common.collect.Lists; +import java.util.Optional; /** * Convert scalarApply to LogicalJoin. - * + * <p> * UnCorrelated -> CROSS_JOIN * Correlated -> LEFT_OUTER_JOIN */ @@ -59,9 +61,12 @@ public class ScalarApplyToJoin extends OneRewriteRuleFactory { } private Plan correlatedToJoin(LogicalApply apply) { + Optional<Expression> correlationFilter = apply.getCorrelationFilter(); return new LogicalJoin<>(JoinType.LEFT_OUTER_JOIN, - Lists.newArrayList(), - apply.getCorrelationFilter(), + ExpressionUtils.EMPTY_CONDITION, + correlationFilter + .map(ExpressionUtils::extractConjunction) + .orElse(ExpressionUtils.EMPTY_CONDITION), (LogicalPlan) apply.left(), (LogicalPlan) apply.right()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Join.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Join.java index 48fb08250c..6bf8810344 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Join.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Join.java @@ -31,7 +31,7 @@ public interface Join { List<Expression> getHashJoinConjuncts(); - Optional<Expression> getOtherJoinCondition(); + List<Expression> getOtherJoinConjuncts(); Optional<Expression> getOnClauseCondition(); @@ -39,6 +39,6 @@ public interface Join { * The join plan has join condition or not. */ default boolean hasJoinCondition() { - return !getHashJoinConjuncts().isEmpty() || getOtherJoinCondition().isPresent(); + return !getHashJoinConjuncts().isEmpty() || !getOtherJoinConjuncts().isEmpty(); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java index db7afb481d..4bfd3f00de 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java @@ -136,7 +136,7 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends @Override public <R, C> R accept(PlanVisitor<R, C> visitor, C context) { - return visitor.visitLogicalApply((LogicalApply<Plan, Plan>) this, context); + return visitor.visitLogicalApply(this, context); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java index 87e6d4e03e..2d844689f1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java @@ -44,9 +44,8 @@ import java.util.stream.Collectors; */ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends Plan> extends LogicalBinary<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> implements Join { - private final JoinType joinType; - private final Optional<Expression> otherJoinCondition; + private final List<Expression> otherJoinConjuncts; private final List<Expression> hashJoinConjuncts; // Use for top-to-down join reorder @@ -58,28 +57,40 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends * @param joinType logical type for join */ public LogicalJoin(JoinType joinType, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) { - this(joinType, ImmutableList.of(), - Optional.empty(), Optional.empty(), + this(joinType, ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, Optional.empty(), Optional.empty(), leftChild, rightChild); } - public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, Optional<Expression> otherJoinCondition, + public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, LEFT_CHILD_TYPE leftChild, + RIGHT_CHILD_TYPE rightChild) { + this(joinType, hashJoinConjuncts, ExpressionUtils.EMPTY_CONDITION, Optional.empty(), Optional.empty(), + leftChild, rightChild); + } + + public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, List<Expression> otherJoinConjuncts, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) { this(joinType, hashJoinConjuncts, - otherJoinCondition, Optional.empty(), Optional.empty(), leftChild, rightChild); + otherJoinConjuncts, Optional.empty(), Optional.empty(), leftChild, rightChild); + } + + public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, LEFT_CHILD_TYPE leftChild, + RIGHT_CHILD_TYPE rightChild, JoinReorderContext joinReorderContext) { + this(joinType, hashJoinConjuncts, ExpressionUtils.EMPTY_CONDITION, + Optional.empty(), Optional.empty(), leftChild, rightChild); + this.joinReorderContext.copyFrom(joinReorderContext); } - public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, Optional<Expression> otherJoinCondition, + public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, List<Expression> otherJoinConjuncts, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild, JoinReorderContext joinReorderContext) { - this(joinType, hashJoinConjuncts, otherJoinCondition, + this(joinType, hashJoinConjuncts, otherJoinConjuncts, Optional.empty(), Optional.empty(), leftChild, rightChild); this.joinReorderContext.copyFrom(joinReorderContext); } - public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, Optional<Expression> otherJoinCondition, + public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, List<Expression> otherJoinConjuncts, Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild, JoinReorderContext joinReorderContext) { - this(joinType, hashJoinConjuncts, otherJoinCondition, groupExpression, logicalProperties, leftChild, + this(joinType, hashJoinConjuncts, otherJoinConjuncts, groupExpression, logicalProperties, leftChild, rightChild); this.joinReorderContext.copyFrom(joinReorderContext); } @@ -88,24 +99,18 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends * Constructor for LogicalJoinPlan. * * @param joinType logical type for join - * @param condition on clause for join node */ - public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, Optional<Expression> condition, + public LogicalJoin(JoinType joinType, List<Expression> hashJoinConjuncts, List<Expression> otherJoinConjuncts, Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) { super(PlanType.LOGICAL_JOIN, groupExpression, logicalProperties, leftChild, rightChild); this.joinType = Objects.requireNonNull(joinType, "joinType can not be null"); this.hashJoinConjuncts = hashJoinConjuncts; - this.otherJoinCondition = Objects.requireNonNull(condition, "condition can not be null"); + this.otherJoinConjuncts = Objects.requireNonNull(otherJoinConjuncts, "condition can not be null"); } - /** - * get combination of hashJoinConjuncts and condition - * - * @return combine hashJoinConjuncts and condition by AND - */ - public Optional<Expression> getOtherJoinCondition() { - return otherJoinCondition; + public List<Expression> getOtherJoinConjuncts() { + return otherJoinConjuncts; } @Override @@ -113,20 +118,8 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends return hashJoinConjuncts; } - /** - * hashJoinConjuncts and otherJoinCondition - * - * @return the combination of hashJoinConjuncts and otherJoinCondition - */ public Optional<Expression> getOnClauseCondition() { - Optional<Expression> hashJoinCondition = ExpressionUtils.optionalAnd(hashJoinConjuncts); - - if (hashJoinCondition.isPresent() && otherJoinCondition.isPresent()) { - return ExpressionUtils.optionalAnd(hashJoinCondition.get(), otherJoinCondition.get()); - } - - return hashJoinCondition.map(Optional::of) - .orElse(otherJoinCondition); + return ExpressionUtils.optionalAnd(hashJoinConjuncts, otherJoinConjuncts); } public JoinType getJoinType() { @@ -177,12 +170,12 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends return Utils.toSqlString("LogicalJoin", "type", joinType, "hashJoinCondition", hashJoinConjuncts, - "otherJoinCondition", otherJoinCondition + "otherJoinCondition", otherJoinConjuncts ); } // TODO: - // 1. consider the order of conjucts in otherJoinCondition and hashJoinConditions + // 1. consider the order of conjucts in otherJoinConjuncts and hashJoinConditions @Override public boolean equals(Object o) { if (this == o) { @@ -197,12 +190,12 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends // TODO: why use containsAll? && that.getHashJoinConjuncts().containsAll(hashJoinConjuncts) && hashJoinConjuncts.containsAll(that.getHashJoinConjuncts()) - && Objects.equals(otherJoinCondition, that.otherJoinCondition); + && Objects.equals(otherJoinConjuncts, that.otherJoinConjuncts); } @Override public int hashCode() { - return Objects.hash(joinType, otherJoinCondition); + return Objects.hash(joinType, hashJoinConjuncts, otherJoinConjuncts); } @Override @@ -212,10 +205,10 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends @Override public List<? extends Expression> getExpressions() { - Builder<Expression> builder = new Builder<Expression>() - .addAll(hashJoinConjuncts); - otherJoinCondition.ifPresent(builder::add); - return builder.build(); + return new Builder<Expression>() + .addAll(hashJoinConjuncts) + .addAll(otherJoinConjuncts) + .build(); } public JoinReorderContext getJoinReorderContext() { @@ -225,19 +218,19 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends @Override public LogicalBinary<Plan, Plan> withChildren(List<Plan> children) { Preconditions.checkArgument(children.size() == 2); - return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinCondition, children.get(0), children.get(1), + return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, children.get(0), children.get(1), joinReorderContext); } @Override public Plan withGroupExpression(Optional<GroupExpression> groupExpression) { - return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinCondition, groupExpression, + return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, groupExpression, Optional.of(getLogicalProperties()), left(), right(), joinReorderContext); } @Override public Plan withLogicalProperties(Optional<LogicalProperties> logicalProperties) { - return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinCondition, + return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, Optional.empty(), logicalProperties, left(), right(), joinReorderContext); } @@ -253,12 +246,12 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends public LogicalJoin withHashJoinConjuncts(List<Expression> hashJoinConjuncts) { return new LogicalJoin<>( - joinType, hashJoinConjuncts, this.otherJoinCondition, left(), right(), joinReorderContext); + joinType, hashJoinConjuncts, this.otherJoinConjuncts, left(), right(), joinReorderContext); } public LogicalJoin withhashJoinConjunctsAndChildren(List<Expression> hashJoinConjuncts, List<Plan> children) { Preconditions.checkArgument(children.size() == 2); - return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinCondition, children.get(0), children.get(1), + return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, children.get(0), children.get(1), joinReorderContext); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalJoin.java index 9d6d567946..1aa4834d28 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalJoin.java @@ -45,38 +45,31 @@ public abstract class AbstractPhysicalJoin< protected final List<Expression> hashJoinConjuncts; - protected final Optional<Expression> otherJoinCondition; + protected final List<Expression> otherJoinConjuncts; /** * Constructor of PhysicalJoin. - * - * @param joinType Which join type, left semi join, inner join... - * @param condition join condition. */ public AbstractPhysicalJoin(PlanType type, JoinType joinType, List<Expression> hashJoinConjuncts, - Optional<Expression> condition, - Optional<GroupExpression> groupExpression, LogicalProperties logicalProperties, - LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) { + List<Expression> otherJoinConjuncts, Optional<GroupExpression> groupExpression, + LogicalProperties logicalProperties, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) { super(type, groupExpression, logicalProperties, leftChild, rightChild); this.joinType = Objects.requireNonNull(joinType, "joinType can not be null"); this.hashJoinConjuncts = hashJoinConjuncts; - this.otherJoinCondition = Objects.requireNonNull(condition, "condition can not be null"); + this.otherJoinConjuncts = Objects.requireNonNull(otherJoinConjuncts, "condition can not be null"); } /** * Constructor of PhysicalJoin. - * - * @param joinType Which join type, left semi join, inner join... - * @param condition join condition. */ public AbstractPhysicalJoin(PlanType type, JoinType joinType, List<Expression> hashJoinConjuncts, - Optional<Expression> condition, Optional<GroupExpression> groupExpression, + List<Expression> otherJoinConjuncts, Optional<GroupExpression> groupExpression, LogicalProperties logicalProperties, PhysicalProperties physicalProperties, StatsDeriveResult statsDeriveResult, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) { super(type, groupExpression, logicalProperties, physicalProperties, statsDeriveResult, leftChild, rightChild); this.joinType = Objects.requireNonNull(joinType, "joinType can not be null"); this.hashJoinConjuncts = hashJoinConjuncts; - this.otherJoinCondition = Objects.requireNonNull(condition, "condition can not be null"); + this.otherJoinConjuncts = Objects.requireNonNull(otherJoinConjuncts, "condition can not be null"); } public List<Expression> getHashJoinConjuncts() { @@ -87,20 +80,19 @@ public abstract class AbstractPhysicalJoin< return joinType; } - public Optional<Expression> getOtherJoinCondition() { - return otherJoinCondition; + public List<Expression> getOtherJoinConjuncts() { + return otherJoinConjuncts; } @Override public List<? extends Expression> getExpressions() { - Builder<Expression> builder = new Builder<Expression>() - .addAll(hashJoinConjuncts); - otherJoinCondition.ifPresent(builder::add); - return builder.build(); + return new Builder<Expression>() + .addAll(hashJoinConjuncts) + .addAll(otherJoinConjuncts).build(); } // TODO: - // 1. consider the order of conjucts in otherJoinCondition and hashJoinConditions + // 1. consider the order of conjucts in otherJoinConjuncts and hashJoinConditions @Override public boolean equals(Object o) { if (this == o) { @@ -115,27 +107,20 @@ public abstract class AbstractPhysicalJoin< AbstractPhysicalJoin<?, ?> that = (AbstractPhysicalJoin<?, ?>) o; return joinType == that.joinType && hashJoinConjuncts.equals(that.hashJoinConjuncts) - && otherJoinCondition.equals(that.otherJoinCondition); + && otherJoinConjuncts.equals(that.otherJoinConjuncts); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), joinType, hashJoinConjuncts, otherJoinCondition); + return Objects.hash(super.hashCode(), joinType, hashJoinConjuncts, otherJoinConjuncts); } /** - * hashJoinConjuncts and otherJoinCondition + * hashJoinConjuncts and otherJoinConjuncts * - * @return the combination of hashJoinConjuncts and otherJoinCondition + * @return the combination of hashJoinConjuncts and otherJoinConjuncts */ public Optional<Expression> getOnClauseCondition() { - Optional<Expression> hashJoinCondition = ExpressionUtils.optionalAnd(hashJoinConjuncts); - - if (hashJoinCondition.isPresent() && otherJoinCondition.isPresent()) { - return ExpressionUtils.optionalAnd(hashJoinCondition.get(), otherJoinCondition.get()); - } - - return hashJoinCondition.map(Optional::of) - .orElse(otherJoinCondition); + return ExpressionUtils.optionalAnd(hashJoinConjuncts, otherJoinConjuncts); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java index 949d33a11d..11a27566e9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java @@ -47,9 +47,10 @@ public class PhysicalHashJoin< private final List<Expression> filterConjuncts = Lists.newArrayList(); public PhysicalHashJoin(JoinType joinType, List<Expression> hashJoinConjuncts, - Optional<Expression> condition, LogicalProperties logicalProperties, + List<Expression> otherJoinConjuncts, LogicalProperties logicalProperties, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) { - this(joinType, hashJoinConjuncts, condition, Optional.empty(), logicalProperties, leftChild, rightChild); + this(joinType, hashJoinConjuncts, otherJoinConjuncts, Optional.empty(), logicalProperties, leftChild, + rightChild); } /** @@ -57,12 +58,11 @@ public class PhysicalHashJoin< * * @param joinType Which join type, left semi join, inner join... * @param hashJoinConjuncts conjunct list could use for build hash table in hash join - * @param condition join condition except hash join conjuncts */ - public PhysicalHashJoin(JoinType joinType, List<Expression> hashJoinConjuncts, Optional<Expression> condition, + public PhysicalHashJoin(JoinType joinType, List<Expression> hashJoinConjuncts, List<Expression> otherJoinConjuncts, Optional<GroupExpression> groupExpression, LogicalProperties logicalProperties, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) { - super(PlanType.PHYSICAL_HASH_JOIN, joinType, hashJoinConjuncts, condition, + super(PlanType.PHYSICAL_HASH_JOIN, joinType, hashJoinConjuncts, otherJoinConjuncts, groupExpression, logicalProperties, leftChild, rightChild); } @@ -71,13 +71,12 @@ public class PhysicalHashJoin< * * @param joinType Which join type, left semi join, inner join... * @param hashJoinConjuncts conjunct list could use for build hash table in hash join - * @param condition join condition except hash join conjuncts */ - public PhysicalHashJoin(JoinType joinType, List<Expression> hashJoinConjuncts, Optional<Expression> condition, + public PhysicalHashJoin(JoinType joinType, List<Expression> hashJoinConjuncts, List<Expression> otherJoinConjuncts, Optional<GroupExpression> groupExpression, LogicalProperties logicalProperties, PhysicalProperties physicalProperties, StatsDeriveResult statsDeriveResult, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) { - super(PlanType.PHYSICAL_HASH_JOIN, joinType, hashJoinConjuncts, condition, + super(PlanType.PHYSICAL_HASH_JOIN, joinType, hashJoinConjuncts, otherJoinConjuncts, groupExpression, logicalProperties, physicalProperties, statsDeriveResult, leftChild, rightChild); } @@ -91,35 +90,35 @@ public class PhysicalHashJoin< return Utils.toSqlString("PhysicalHashJoin", "type", joinType, "hashJoinCondition", hashJoinConjuncts, - "otherJoinCondition", otherJoinCondition, + "otherJoinCondition", otherJoinConjuncts, "stats", statsDeriveResult); } @Override public PhysicalHashJoin<Plan, Plan> withChildren(List<Plan> children) { Preconditions.checkArgument(children.size() == 2); - return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinCondition, + return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, getLogicalProperties(), children.get(0), children.get(1)); } @Override public PhysicalHashJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> withGroupExpression( Optional<GroupExpression> groupExpression) { - return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinCondition, + return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, groupExpression, getLogicalProperties(), left(), right()); } @Override public PhysicalHashJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> withLogicalProperties( Optional<LogicalProperties> logicalProperties) { - return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinCondition, + return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, Optional.empty(), logicalProperties.get(), left(), right()); } @Override public PhysicalHashJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> withPhysicalPropertiesAndStats( PhysicalProperties physicalProperties, StatsDeriveResult statsDeriveResult) { - return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinCondition, + return new PhysicalHashJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, Optional.empty(), getLogicalProperties(), physicalProperties, statsDeriveResult, left(), right()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalNestedLoopJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalNestedLoopJoin.java index 44fa04b859..e9b54e65b8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalNestedLoopJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalNestedLoopJoin.java @@ -42,9 +42,9 @@ public class PhysicalNestedLoopJoin< extends AbstractPhysicalJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> { public PhysicalNestedLoopJoin(JoinType joinType, - List<Expression> hashJoinConjuncts, Optional<Expression> condition, + List<Expression> hashJoinConjuncts, List<Expression> otherJoinConjuncts, LogicalProperties logicalProperties, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) { - this(joinType, hashJoinConjuncts, condition, + this(joinType, hashJoinConjuncts, otherJoinConjuncts, Optional.empty(), logicalProperties, leftChild, rightChild); } @@ -53,13 +53,12 @@ public class PhysicalNestedLoopJoin< * * @param joinType Which join type, left semi join, inner join... * @param hashJoinConjuncts conjunct list could use for build hash table in hash join - * @param condition join condition except hash join conjuncts */ public PhysicalNestedLoopJoin(JoinType joinType, - List<Expression> hashJoinConjuncts, Optional<Expression> condition, + List<Expression> hashJoinConjuncts, List<Expression> otherJoinConjuncts, Optional<GroupExpression> groupExpression, LogicalProperties logicalProperties, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) { - super(PlanType.PHYSICAL_NESTED_LOOP_JOIN, joinType, hashJoinConjuncts, condition, + super(PlanType.PHYSICAL_NESTED_LOOP_JOIN, joinType, hashJoinConjuncts, otherJoinConjuncts, groupExpression, logicalProperties, leftChild, rightChild); } @@ -68,13 +67,13 @@ public class PhysicalNestedLoopJoin< * * @param joinType Which join type, left semi join, inner join... * @param hashJoinConjuncts conjunct list could use for build hash table in hash join - * @param condition join condition except hash join conjuncts */ - public PhysicalNestedLoopJoin(JoinType joinType, List<Expression> hashJoinConjuncts, Optional<Expression> condition, - Optional<GroupExpression> groupExpression, LogicalProperties logicalProperties, - PhysicalProperties physicalProperties, StatsDeriveResult statsDeriveResult, LEFT_CHILD_TYPE leftChild, + public PhysicalNestedLoopJoin(JoinType joinType, List<Expression> hashJoinConjuncts, + List<Expression> otherJoinConjuncts, Optional<GroupExpression> groupExpression, + LogicalProperties logicalProperties, PhysicalProperties physicalProperties, + StatsDeriveResult statsDeriveResult, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) { - super(PlanType.PHYSICAL_NESTED_LOOP_JOIN, joinType, hashJoinConjuncts, condition, + super(PlanType.PHYSICAL_NESTED_LOOP_JOIN, joinType, hashJoinConjuncts, otherJoinConjuncts, groupExpression, logicalProperties, physicalProperties, statsDeriveResult, leftChild, rightChild); } @@ -88,7 +87,7 @@ public class PhysicalNestedLoopJoin< // TODO: Maybe we could pull up this to the abstract class in the future. return Utils.toSqlString("PhysicalNestedLoopJoin", "type", joinType, - "otherJoinCondition", otherJoinCondition + "otherJoinCondition", otherJoinConjuncts ); } @@ -96,21 +95,21 @@ public class PhysicalNestedLoopJoin< public PhysicalNestedLoopJoin<Plan, Plan> withChildren(List<Plan> children) { Preconditions.checkArgument(children.size() == 2); return new PhysicalNestedLoopJoin<>(joinType, - hashJoinConjuncts, otherJoinCondition, getLogicalProperties(), children.get(0), children.get(1)); + hashJoinConjuncts, otherJoinConjuncts, getLogicalProperties(), children.get(0), children.get(1)); } @Override public PhysicalNestedLoopJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> withGroupExpression( Optional<GroupExpression> groupExpression) { return new PhysicalNestedLoopJoin<>(joinType, - hashJoinConjuncts, otherJoinCondition, groupExpression, getLogicalProperties(), left(), right()); + hashJoinConjuncts, otherJoinConjuncts, groupExpression, getLogicalProperties(), left(), right()); } @Override public PhysicalNestedLoopJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> withLogicalProperties( Optional<LogicalProperties> logicalProperties) { return new PhysicalNestedLoopJoin<>(joinType, - hashJoinConjuncts, otherJoinCondition, Optional.empty(), + hashJoinConjuncts, otherJoinConjuncts, Optional.empty(), logicalProperties.get(), left(), right()); } @@ -118,7 +117,7 @@ public class PhysicalNestedLoopJoin< public PhysicalNestedLoopJoin<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> withPhysicalPropertiesAndStats( PhysicalProperties physicalProperties, StatsDeriveResult statsDeriveResult) { return new PhysicalNestedLoopJoin<>(joinType, - hashJoinConjuncts, otherJoinCondition, Optional.empty(), + hashJoinConjuncts, otherJoinConjuncts, Optional.empty(), getLogicalProperties(), physicalProperties, statsDeriveResult, left(), right()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index aa0f83c32c..42b73e573e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -48,6 +48,8 @@ import java.util.Set; */ public class ExpressionUtils { + public static final List<Expression> EMPTY_CONDITION = ImmutableList.of(); + public static List<Expression> extractConjunction(Expression expr) { return extract(And.class, expr); } @@ -96,6 +98,21 @@ public class ExpressionUtils { } } + /** + * And two list. + */ + public static Optional<Expression> optionalAnd(List<Expression> left, List<Expression> right) { + if (left.isEmpty() && right.isEmpty()) { + return Optional.empty(); + } else if (left.isEmpty()) { + return optionalAnd(right); + } else if (right.isEmpty()) { + return optionalAnd(left); + } else { + return Optional.of(new And(optionalAnd(left).get(), optionalAnd(right).get())); + } + } + public static Optional<Expression> optionalAnd(Expression... expressions) { return optionalAnd(Lists.newArrayList(expressions)); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java index b65843d457..8e6eb1ae38 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java @@ -29,11 +29,9 @@ import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.algebra.Join; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; @@ -117,23 +115,6 @@ public class JoinUtils { } } - /** - * collect expressions from on clause, which could be used to build hash table - * @param join join node - * @return pair of expressions, for hash table or not. - */ - public static Pair<List<Expression>, List<Expression>> extractExpressionForHashTable( - LogicalJoin<GroupPlan, GroupPlan> join) { - if (join.getOtherJoinCondition().isPresent()) { - List<Expression> onExprs = ExpressionUtils.extractConjunction( - join.getOtherJoinCondition().get()); - List<Slot> leftSlots = join.left().getOutput(); - List<Slot> rightSlots = join.right().getOutput(); - return extractExpressionForHashTable(leftSlots, rightSlots, onExprs); - } - return Pair.of(Lists.newArrayList(), Lists.newArrayList()); - } - /** * extract expression * @param leftSlots left child output slots diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoRewriteTest.java index 1b992ffaec..7a82c5dcd2 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoRewriteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoRewriteTest.java @@ -796,7 +796,6 @@ public class MemoRewriteTest implements PatternMatchSupported { .analyze(new LogicalLimit<>(10, 0, new LogicalJoin<>(JoinType.LEFT_OUTER_JOIN, ImmutableList.of(new EqualTo(new UnboundSlot("sid"), new UnboundSlot("id"))), - Optional.empty(), new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.score), new LogicalOlapScan(RelationId.createGenerator().getNextId(), PlanConstructor.student) ) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java index 48e7d043dd..ec7347ebbd 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/ChildOutputPropertyDeriverTest.java @@ -42,6 +42,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort; import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.JoinUtils; import org.apache.doris.qe.ConnectContext; @@ -61,7 +62,6 @@ import org.junit.jupiter.api.Test; import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.stream.Collectors; @SuppressWarnings("unused") @@ -110,7 +110,7 @@ public class ChildOutputPropertyDeriverTest { }; PhysicalHashJoin<GroupPlan, GroupPlan> join = new PhysicalHashJoin<>(JoinType.RIGHT_OUTER_JOIN, - Collections.emptyList(), Optional.empty(), logicalProperties, groupPlan, groupPlan); + ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, logicalProperties, groupPlan, groupPlan); GroupExpression groupExpression = new GroupExpression(join); PhysicalProperties left = new PhysicalProperties( @@ -157,8 +157,9 @@ public class ChildOutputPropertyDeriverTest { PhysicalHashJoin<GroupPlan, GroupPlan> join = new PhysicalHashJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(new EqualTo( new SlotReference(new ExprId(0), "left", IntegerType.INSTANCE, false, Collections.emptyList()), - new SlotReference(new ExprId(2), "right", IntegerType.INSTANCE, false, Collections.emptyList()))), - Optional.empty(), logicalProperties, groupPlan, groupPlan); + new SlotReference(new ExprId(2), "right", IntegerType.INSTANCE, false, + Collections.emptyList()))), + ExpressionUtils.EMPTY_CONDITION, logicalProperties, groupPlan, groupPlan); GroupExpression groupExpression = new GroupExpression(join); Map<ExprId, Integer> leftMap = Maps.newHashMap(); @@ -174,7 +175,8 @@ public class ChildOutputPropertyDeriverTest { )); PhysicalProperties right = new PhysicalProperties(DistributionSpecReplicated.INSTANCE, - new OrderSpec(Lists.newArrayList(new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); + new OrderSpec(Lists.newArrayList( + new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); List<PhysicalProperties> childrenOutputProperties = Lists.newArrayList(left, right); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(childrenOutputProperties); @@ -201,8 +203,9 @@ public class ChildOutputPropertyDeriverTest { PhysicalHashJoin<GroupPlan, GroupPlan> join = new PhysicalHashJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(new EqualTo( new SlotReference(new ExprId(0), "left", IntegerType.INSTANCE, false, Collections.emptyList()), - new SlotReference(new ExprId(2), "right", IntegerType.INSTANCE, false, Collections.emptyList()))), - Optional.empty(), logicalProperties, groupPlan, groupPlan); + new SlotReference(new ExprId(2), "right", IntegerType.INSTANCE, false, + Collections.emptyList()))), + ExpressionUtils.EMPTY_CONDITION, logicalProperties, groupPlan, groupPlan); GroupExpression groupExpression = new GroupExpression(join); Map<ExprId, Integer> leftMap = Maps.newHashMap(); @@ -239,7 +242,7 @@ public class ChildOutputPropertyDeriverTest { @Test public void testNestedLoopJoin() { PhysicalNestedLoopJoin<GroupPlan, GroupPlan> join = new PhysicalNestedLoopJoin<>(JoinType.CROSS_JOIN, - Collections.emptyList(), Optional.empty(), logicalProperties, groupPlan, groupPlan); + ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, logicalProperties, groupPlan, groupPlan); GroupExpression groupExpression = new GroupExpression(join); Map<ExprId, Integer> leftMap = Maps.newHashMap(); @@ -280,7 +283,8 @@ public class ChildOutputPropertyDeriverTest { ); GroupExpression groupExpression = new GroupExpression(aggregate); PhysicalProperties child = new PhysicalProperties(DistributionSpecReplicated.INSTANCE, - new OrderSpec(Lists.newArrayList(new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); + new OrderSpec(Lists.newArrayList( + new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(Lists.newArrayList(child)); PhysicalProperties result = deriver.getOutputProperties(groupExpression); @@ -303,9 +307,11 @@ public class ChildOutputPropertyDeriverTest { groupPlan ); GroupExpression groupExpression = new GroupExpression(aggregate); - DistributionSpecHash childHash = new DistributionSpecHash(Lists.newArrayList(partition.getExprId()), ShuffleType.BUCKETED); + DistributionSpecHash childHash = new DistributionSpecHash(Lists.newArrayList(partition.getExprId()), + ShuffleType.BUCKETED); PhysicalProperties child = new PhysicalProperties(childHash, - new OrderSpec(Lists.newArrayList(new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); + new OrderSpec(Lists.newArrayList( + new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(Lists.newArrayList(child)); PhysicalProperties result = deriver.getOutputProperties(groupExpression); @@ -333,7 +339,8 @@ public class ChildOutputPropertyDeriverTest { GroupExpression groupExpression = new GroupExpression(aggregate); PhysicalProperties child = new PhysicalProperties(DistributionSpecGather.INSTANCE, - new OrderSpec(Lists.newArrayList(new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); + new OrderSpec(Lists.newArrayList( + new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(Lists.newArrayList(child)); PhysicalProperties result = deriver.getOutputProperties(groupExpression); @@ -347,7 +354,8 @@ public class ChildOutputPropertyDeriverTest { PhysicalLocalQuickSort<GroupPlan> sort = new PhysicalLocalQuickSort(orderKeys, logicalProperties, groupPlan); GroupExpression groupExpression = new GroupExpression(sort); PhysicalProperties child = new PhysicalProperties(DistributionSpecReplicated.INSTANCE, - new OrderSpec(Lists.newArrayList(new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); + new OrderSpec(Lists.newArrayList( + new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(Lists.newArrayList(child)); PhysicalProperties result = deriver.getOutputProperties(groupExpression); @@ -362,7 +370,8 @@ public class ChildOutputPropertyDeriverTest { PhysicalQuickSort<GroupPlan> sort = new PhysicalQuickSort<>(orderKeys, logicalProperties, groupPlan); GroupExpression groupExpression = new GroupExpression(sort); PhysicalProperties child = new PhysicalProperties(DistributionSpecReplicated.INSTANCE, - new OrderSpec(Lists.newArrayList(new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); + new OrderSpec(Lists.newArrayList( + new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(Lists.newArrayList(child)); PhysicalProperties result = deriver.getOutputProperties(groupExpression); @@ -377,7 +386,8 @@ public class ChildOutputPropertyDeriverTest { PhysicalTopN<GroupPlan> sort = new PhysicalTopN<>(orderKeys, 10, 10, logicalProperties, groupPlan); GroupExpression groupExpression = new GroupExpression(sort); PhysicalProperties child = new PhysicalProperties(DistributionSpecReplicated.INSTANCE, - new OrderSpec(Lists.newArrayList(new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); + new OrderSpec(Lists.newArrayList( + new OrderKey(new SlotReference("ignored", IntegerType.INSTANCE), true, true)))); ChildOutputPropertyDeriver deriver = new ChildOutputPropertyDeriver(Lists.newArrayList(child)); PhysicalProperties result = deriver.getOutputProperties(groupExpression); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java index c133bf8ed3..90d3010446 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java @@ -34,6 +34,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin; import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin; import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.JoinUtils; import com.google.common.collect.Lists; @@ -46,9 +47,7 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import java.util.Collections; import java.util.List; -import java.util.Optional; @SuppressWarnings("unused") public class RequestPropertyDeriverTest { @@ -75,8 +74,8 @@ public class RequestPropertyDeriverTest { @Test public void testNestedLoopJoin() { - PhysicalNestedLoopJoin<GroupPlan, GroupPlan> join = new PhysicalNestedLoopJoin<>(JoinType.CROSS_JOIN, Collections.emptyList(), - Optional.empty(), logicalProperties, groupPlan, groupPlan); + PhysicalNestedLoopJoin<GroupPlan, GroupPlan> join = new PhysicalNestedLoopJoin<>(JoinType.CROSS_JOIN, + ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, logicalProperties, groupPlan, groupPlan); GroupExpression groupExpression = new GroupExpression(join); RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(jobContext); @@ -99,7 +98,7 @@ public class RequestPropertyDeriverTest { }; PhysicalHashJoin<GroupPlan, GroupPlan> join = new PhysicalHashJoin<>(JoinType.RIGHT_OUTER_JOIN, - Collections.emptyList(), Optional.empty(), logicalProperties, groupPlan, groupPlan); + ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, logicalProperties, groupPlan, groupPlan); GroupExpression groupExpression = new GroupExpression(join); RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(jobContext); @@ -125,7 +124,7 @@ public class RequestPropertyDeriverTest { }; PhysicalHashJoin<GroupPlan, GroupPlan> join = new PhysicalHashJoin<>(JoinType.INNER_JOIN, - Collections.emptyList(), Optional.empty(), logicalProperties, groupPlan, groupPlan); + ExpressionUtils.EMPTY_CONDITION, ExpressionUtils.EMPTY_CONDITION, logicalProperties, groupPlan, groupPlan); GroupExpression groupExpression = new GroupExpression(join); RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(jobContext); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomTest.java index 04cc86e9bb..8b8e7b155b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomTest.java @@ -35,8 +35,6 @@ import com.google.common.collect.Lists; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.util.Optional; - public class InnerJoinLAsscomTest implements PatternMatchSupported { private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); @@ -96,11 +94,9 @@ public class InnerJoinLAsscomTest implements PatternMatchSupported { Expression bottomJoinOnCondition = new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0)); Expression topJoinOnCondition = new EqualTo(scan2.getOutput().get(0), scan3.getOutput().get(0)); LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN, - Lists.newArrayList(bottomJoinOnCondition), - Optional.empty(), scan1, scan2); + Lists.newArrayList(bottomJoinOnCondition), scan1, scan2); LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> topJoin = new LogicalJoin<>( - JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition), - Optional.empty(), bottomJoin, scan3); + JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition), bottomJoin, scan3); PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) .applyExploration(InnerJoinLAsscom.INSTANCE.build()) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java index 025f2a39eb..88f5e0570f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoinTest.java @@ -30,7 +30,6 @@ import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanConstructor; @@ -41,7 +40,6 @@ import org.junit.jupiter.api.Test; import java.util.ArrayList; import java.util.List; -import java.util.Optional; /** * initial plan: @@ -57,26 +55,26 @@ import java.util.Optional; class FindHashConditionForJoinTest { @Test public void testFindHashCondition() { - Plan student = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of("")); - Plan score = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of("")); + Plan student = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, + ImmutableList.of("")); + Plan score = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, + ImmutableList.of("")); Slot studentId = student.getOutput().get(0); Slot gender = student.getOutput().get(1); Slot scoreId = score.getOutput().get(0); Slot cid = score.getOutput().get(1); - Expression eq1 = new EqualTo(studentId, scoreId); //a=b - Expression eq2 = new EqualTo(studentId, new IntegerLiteral(1)); //a=1 - Expression eq3 = new EqualTo( - new Add(studentId, new IntegerLiteral(1)), - cid); + Expression eq1 = new EqualTo(studentId, scoreId); // a=b + Expression eq2 = new EqualTo(studentId, new IntegerLiteral(1)); // a=1 + Expression eq3 = new EqualTo(new Add(studentId, new IntegerLiteral(1)), cid); Expression or = new Or( new EqualTo(scoreId, studentId), new EqualTo(gender, cid)); Expression less = new LessThan(scoreId, studentId); - Expression expr = ExpressionUtils.and(eq1, eq2, eq3, or, less); - LogicalJoin join = new LogicalJoin(JoinType.INNER_JOIN, new ArrayList<>(), - Optional.of(expr), student, score); + List<Expression> expr = ImmutableList.of(eq1, eq2, eq3, or, less); + LogicalJoin join = new LogicalJoin<>(JoinType.INNER_JOIN, new ArrayList<>(), + expr, student, score); CascadesContext context = MemoTestUtils.createCascadesContext(join); List<Rule> rules = Lists.newArrayList(new FindHashConditionForJoin().build()); @@ -86,7 +84,7 @@ class FindHashConditionForJoinTest { Assertions.assertEquals(after.getHashJoinConjuncts().size(), 2); Assertions.assertTrue(after.getHashJoinConjuncts().contains(eq1)); Assertions.assertTrue(after.getHashJoinConjuncts().contains(eq3)); - List<Expression> others = ExpressionUtils.extractConjunction((Expression) after.getOtherJoinCondition().get()); + List<Expression> others = after.getOtherJoinConjuncts(); Assertions.assertEquals(others.size(), 3); Assertions.assertTrue(others.contains(less)); Assertions.assertTrue(others.contains(eq2)); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/LimitPushDownTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/LimitPushDownTest.java index f71417e469..1d953700a1 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/LimitPushDownTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/LimitPushDownTest.java @@ -43,7 +43,6 @@ import org.junit.jupiter.api.Test; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; @@ -213,7 +212,6 @@ class LimitPushDownTest extends TestWithFeService implements PatternMatchSupport LogicalJoin<? extends Plan, ? extends Plan> join = new LogicalJoin<>( joinType, joinConditions, - Optional.empty(), new LogicalOlapScan(new RelationId(0), PlanConstructor.score), new LogicalOlapScan(new RelationId(1), PlanConstructor.student) ); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushFilterInsideJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushFilterInsideJoinTest.java index 842523ef4e..9b23a7312b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushFilterInsideJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushFilterInsideJoinTest.java @@ -48,7 +48,7 @@ class PushFilterInsideJoinTest implements PatternMatchSupported { .applyTopDown(PushFilterInsideJoin.INSTANCE) .printlnTree() .matchesFromRoot( - logicalJoin().when(join -> join.getOtherJoinCondition().get().equals(predicates)) + logicalJoin().when(join -> join.getOtherJoinConjuncts().get(0).equals(predicates)) ); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughJoinTest.java index 58e9760e9c..c10b0b05b1 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownFilterThroughJoinTest.java @@ -140,7 +140,7 @@ public class PushdownFilterThroughJoinTest implements PatternMatchSupported { .when(filter -> filter.getPredicates().equals(leftSide)), logicalFilter(logicalOlapScan()) .when(filter -> filter.getPredicates().equals(rightSide)) - ).when(join -> join.getOtherJoinCondition().get().equals(bothSideEqualTo)) + ).when(join -> join.getOtherJoinConjuncts().get(0).equals(bothSideEqualTo)) ); } if (joinType.isCrossJoin()) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java index 0d7d711c42..387edc747c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushdownJoinOtherConditionTest.java @@ -40,7 +40,7 @@ import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; -import java.util.Optional; +import java.util.List; @TestInstance(TestInstance.Lifecycle.PER_CLASS) public class PushdownJoinOtherConditionTest { @@ -53,7 +53,8 @@ public class PushdownJoinOtherConditionTest { */ @BeforeAll public final void beforeAll() { - rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of("")); + rStudent = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, + ImmutableList.of("")); rScore = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.score, ImmutableList.of("")); } @@ -73,7 +74,7 @@ public class PushdownJoinOtherConditionTest { Expression pushSide1 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); Expression pushSide2 = new GreaterThan(rStudent.getOutput().get(1), Literal.of(50)); - Expression condition = ExpressionUtils.and(pushSide1, pushSide2); + List<Expression> condition = ImmutableList.of(pushSide1, pushSide2); Plan left = rStudent; Plan right = rScore; @@ -82,7 +83,7 @@ public class PushdownJoinOtherConditionTest { right = rStudent; } - Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.of(condition), left, right); + Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, left, right); Plan root = new LogicalProject<>(Lists.newArrayList(), join); Memo memo = rewrite(root); @@ -105,7 +106,7 @@ public class PushdownJoinOtherConditionTest { Assertions.assertTrue(shouldScan instanceof LogicalOlapScan); LogicalFilter<Plan> actualFilter = (LogicalFilter<Plan>) shouldFilter; - Assertions.assertEquals(condition, actualFilter.getPredicates()); + Assertions.assertEquals(ExpressionUtils.and(condition), actualFilter.getPredicates()); } @Test @@ -120,9 +121,9 @@ public class PushdownJoinOtherConditionTest { Expression leftSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); Expression rightSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60)); - Expression condition = ExpressionUtils.and(leftSide, rightSide); + List<Expression> condition = ImmutableList.of(leftSide, rightSide); - Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.of(condition), rStudent, rScore); + Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, rStudent, rScore); Plan root = new LogicalProject<>(Lists.newArrayList(), join); Memo memo = rewrite(root); @@ -155,7 +156,7 @@ public class PushdownJoinOtherConditionTest { Expression pushSide = new GreaterThan(rStudent.getOutput().get(1), Literal.of(18)); Expression reserveSide = new GreaterThan(rScore.getOutput().get(2), Literal.of(60)); - Expression condition = ExpressionUtils.and(pushSide, reserveSide); + List<Expression> condition = ImmutableList.of(pushSide, reserveSide); Plan left = rStudent; Plan right = rScore; @@ -164,7 +165,7 @@ public class PushdownJoinOtherConditionTest { right = rStudent; } - Plan join = new LogicalJoin<>(joinType, Lists.newArrayList(), Optional.of(condition), left, right); + Plan join = new LogicalJoin<>(joinType, ExpressionUtils.EMPTY_CONDITION, condition, left, right); Plan root = new LogicalProject<>(Lists.newArrayList(), join); Memo memo = rewrite(root); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/SqlTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/SqlTest.java new file mode 100644 index 0000000000..2d958a613b --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/SqlTest.java @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.sqltest; + +import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin; +import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.utframe.TestWithFeService; + +import org.junit.jupiter.api.Test; + +public class SqlTest extends TestWithFeService implements PatternMatchSupported { + @Override + protected void runBeforeAll() throws Exception { + createDatabase("test"); + connectContext.setDatabase("default_cluster:test"); + + createTables( + "CREATE TABLE IF NOT EXISTS T1 (\n" + + " id bigint,\n" + + " score bigint\n" + + ")\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1\n" + + "PROPERTIES (\n" + + " \"replication_num\" = \"1\"\n" + + ")\n", + "CREATE TABLE IF NOT EXISTS T2 (\n" + + " id bigint,\n" + + " score bigint\n" + + ")\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1\n" + + "PROPERTIES (\n" + + " \"replication_num\" = \"1\"\n" + + ")\n", + "CREATE TABLE IF NOT EXISTS T3 (\n" + + " id bigint,\n" + + " score bigint\n" + + ")\n" + + "DUPLICATE KEY(id)\n" + + "DISTRIBUTED BY HASH(id) BUCKETS 1\n" + + "PROPERTIES (\n" + + " \"replication_num\" = \"1\"\n" + + ")\n" + ); + } + + @Override + protected void runBeforeEach() throws Exception { + NamedExpressionUtil.clear(); + } + + @Test + void testSql() { + // String sql = "SELECT *" + // + " FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id" + // + " WHERE T1.id = T2.id"; + String sql = "SELECT *" + + " FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1" + + " WHERE T1.id = T2.id"; + PlanChecker.from(connectContext) + .analyze(sql) + .applyTopDown(new ReorderJoin()) + .implement() + .printlnTree(); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java index 2c56e6900b..03e166b0f1 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java @@ -40,6 +40,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; import org.apache.doris.nereids.trees.plans.physical.PhysicalQuickSort; import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.PlanConstructor; import com.google.common.collect.ImmutableList; @@ -101,18 +102,18 @@ public class PlanEqualsTest { LogicalJoin<Plan, Plan> actual = new LogicalJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(new EqualTo( new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList()), new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList()))), - Optional.empty(), left, right); + left, right); LogicalJoin<Plan, Plan> expected = new LogicalJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(new EqualTo( new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList()), new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList()))), - Optional.empty(), left, right); + left, right); Assertions.assertEquals(expected, actual); LogicalJoin<Plan, Plan> unexpected = new LogicalJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(new EqualTo( new SlotReference(new ExprId(2), "a", BigIntType.INSTANCE, false, Lists.newArrayList()), new SlotReference(new ExprId(3), "b", BigIntType.INSTANCE, true, Lists.newArrayList()))), - Optional.empty(), left, right); + left, right); Assertions.assertNotEquals(unexpected, actual); } @@ -130,21 +131,25 @@ public class PlanEqualsTest { @Test public void testLogicalProject(@Mocked Plan child) { LogicalProject<Plan> actual = new LogicalProject<>( - ImmutableList.of(new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())), + ImmutableList.of( + new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())), child); LogicalProject<Plan> expected = new LogicalProject<>( - ImmutableList.of(new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())), + ImmutableList.of( + new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())), child); Assertions.assertEquals(expected, actual); LogicalProject<Plan> unexpected1 = new LogicalProject<>( - ImmutableList.of(new SlotReference(new ExprId(1), "a", BigIntType.INSTANCE, true, Lists.newArrayList())), + ImmutableList.of( + new SlotReference(new ExprId(1), "a", BigIntType.INSTANCE, true, Lists.newArrayList())), child); Assertions.assertNotEquals(unexpected1, actual); LogicalProject<Plan> unexpected2 = new LogicalProject<>( - ImmutableList.of(new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())), + ImmutableList.of( + new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())), child); Assertions.assertNotEquals(unexpected2, actual); } @@ -215,20 +220,20 @@ public class PlanEqualsTest { Lists.newArrayList(new EqualTo( new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList()), new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList()))), - Optional.empty(), logicalProperties, left, right); + ExpressionUtils.EMPTY_CONDITION, logicalProperties, left, right); PhysicalHashJoin<Plan, Plan> expected = new PhysicalHashJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(new EqualTo( new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList()), new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList()))), - Optional.empty(), logicalProperties, left, right); + ExpressionUtils.EMPTY_CONDITION, logicalProperties, left, right); Assertions.assertEquals(expected, actual); PhysicalHashJoin<Plan, Plan> unexpected = new PhysicalHashJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(new EqualTo( new SlotReference(new ExprId(2), "a", BigIntType.INSTANCE, false, Lists.newArrayList()), new SlotReference(new ExprId(3), "b", BigIntType.INSTANCE, true, Lists.newArrayList()))), - Optional.empty(), logicalProperties, left, right); + ExpressionUtils.EMPTY_CONDITION, logicalProperties, left, right); Assertions.assertNotEquals(unexpected, actual); } @@ -262,24 +267,28 @@ public class PlanEqualsTest { @Test public void testPhysicalProject(@Mocked Plan child, @Mocked LogicalProperties logicalProperties) { PhysicalProject<Plan> actual = new PhysicalProject<>( - ImmutableList.of(new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())), + ImmutableList.of( + new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())), logicalProperties, child); PhysicalProject<Plan> expected = new PhysicalProject<>( - ImmutableList.of(new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())), + ImmutableList.of( + new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList())), logicalProperties, child); Assertions.assertEquals(expected, actual); PhysicalProject<Plan> unexpected1 = new PhysicalProject<>( - ImmutableList.of(new SlotReference(new ExprId(1), "a", BigIntType.INSTANCE, true, Lists.newArrayList())), + ImmutableList.of( + new SlotReference(new ExprId(1), "a", BigIntType.INSTANCE, true, Lists.newArrayList())), logicalProperties, child); Assertions.assertNotEquals(unexpected1, actual); PhysicalProject<Plan> unexpected2 = new PhysicalProject<>( - ImmutableList.of(new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())), + ImmutableList.of( + new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList())), logicalProperties, child); Assertions.assertNotEquals(unexpected2, actual); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java index 803523eb0c..f39bba3331 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanToStringTest.java @@ -40,7 +40,6 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.util.List; -import java.util.Optional; public class PlanToStringTest { @@ -72,10 +71,9 @@ public class PlanToStringTest { LogicalJoin<Plan, Plan> plan = new LogicalJoin<>(JoinType.INNER_JOIN, Lists.newArrayList( new EqualTo(new SlotReference(new ExprId(0), "a", BigIntType.INSTANCE, true, Lists.newArrayList()), new SlotReference(new ExprId(1), "b", BigIntType.INSTANCE, true, Lists.newArrayList()))), - Optional.empty(), left, right); - + left, right); Assertions.assertTrue(plan.toString().matches( - "LogicalJoin \\( type=INNER_JOIN, hashJoinCondition=\\[\\(a#\\d+ = b#\\d+\\)], otherJoinCondition=Optional.empty \\)")); + "LogicalJoin \\( type=INNER_JOIN, hashJoinCondition=\\[\\(a#\\d+ = b#\\d+\\)], otherJoinCondition=\\[] \\)")); } @Test diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeSubQueryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeSubQueryTest.java index 4b8960d75c..43ca065228 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeSubQueryTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeSubQueryTest.java @@ -37,7 +37,6 @@ import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Test; import java.util.List; -import java.util.Optional; public class AnalyzeSubQueryTest extends TestWithFeService implements PatternMatchSupported { @@ -136,10 +135,10 @@ public class AnalyzeSubQueryTest extends TestWithFeService implements PatternMat ) ) .when(FieldChecker.check("joinType", JoinType.INNER_JOIN)) - .when(FieldChecker.check("otherJoinCondition", - Optional.of(new EqualTo( - new SlotReference(new ExprId(2), "id", BigIntType.INSTANCE, true, ImmutableList.of("TT1")), - new SlotReference(new ExprId(0), "id", BigIntType.INSTANCE, true, ImmutableList.of("T"))))) + .when(FieldChecker.check("otherJoinConjuncts", + ImmutableList.of(new EqualTo( + new SlotReference(new ExprId(2), "id", BigIntType.INSTANCE, true, ImmutableList.of("TT1")), + new SlotReference(new ExprId(0), "id", BigIntType.INSTANCE, true, ImmutableList.of("T"))))) ) ).when(FieldChecker.check("projects", ImmutableList.of( new SlotReference(new ExprId(2), "id", BigIntType.INSTANCE, true, ImmutableList.of("TT1")), @@ -162,7 +161,7 @@ public class AnalyzeSubQueryTest extends TestWithFeService implements PatternMat logicalOlapScan() ) .when(FieldChecker.check("joinType", JoinType.INNER_JOIN)) - .when(FieldChecker.check("otherJoinCondition", Optional.of(new EqualTo( + .when(FieldChecker.check("otherJoinConjuncts", ImmutableList.of(new EqualTo( new SlotReference(new ExprId(0), "id", BigIntType.INSTANCE, true, ImmutableList.of("default_cluster:test", "T1")), new SlotReference(new ExprId(2), "id", BigIntType.INSTANCE, true, ImmutableList.of("T2"))))) ) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java index 60bba6ba38..94f5149bd3 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java @@ -34,7 +34,6 @@ import org.apache.doris.nereids.rules.rewrite.logical.PushApplyUnderFilter; import org.apache.doris.nereids.rules.rewrite.logical.PushApplyUnderProject; import org.apache.doris.nereids.rules.rewrite.logical.ScalarApplyToJoin; import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil; @@ -75,9 +74,11 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte + "t7.k3 in (select t8.k1 from t8 where t8.k1 = 3) " + "and t7.v2 > (select sum(t9.k2) from t9 where t9.k2 = t7.v1))"; // exists and not exists - private final String sql9 = "select * from t6 where exists (select t7.k3 from t7 where t6.k2 = t7.v2) and not exists (select t8.k2 from t8 where t6.k2 = t8.k2)"; + private final String sql9 + = "select * from t6 where exists (select t7.k3 from t7 where t6.k2 = t7.v2) and not exists (select t8.k2 from t8 where t6.k2 = t8.k2)"; // with subquery alias - private final String sql10 = "select * from t6 where t6.k1 < (select max(aa) from (select v1 as aa from t7 where t6.k2=t7.v2) t2 )"; + private final String sql10 + = "select * from t6 where t6.k1 < (select max(aa) from (select v1 as aa from t7 where t6.k2=t7.v2) t2 )"; private final List<String> testSql = ImmutableList.of( sql1, sql2, sql3, sql4, sql5, sql6, sql7, sql8, sql9, sql10 @@ -139,7 +140,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte @Test public void testWhereSql2AfterAnalyzed() { - //after analyze + // after analyze PlanChecker.from(connectContext) .analyze(sql2) .matches( @@ -163,7 +164,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte @Test public void testWhereSql2AfterAggFilterRule() { - //after aggFilter rule + // after aggFilter rule PlanChecker.from(connectContext) .analyze(sql2) .applyBottomUp(new ApplyPullFilterOnAgg()) @@ -197,7 +198,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte @Test public void testWhereSql2AfterScalarToJoin() { - //after Scalar CorrelatedJoin to join + // after Scalar CorrelatedJoin to join PlanChecker.from(connectContext) .analyze(sql2) .applyBottomUp(new ApplyPullFilterOnAgg()) @@ -207,9 +208,10 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte any(), logicalAggregate() ).when(FieldChecker.check("joinType", JoinType.LEFT_OUTER_JOIN)) - .when(FieldChecker.check("otherJoinCondition", - Optional.of(new EqualTo(new SlotReference(new ExprId(6), "v2", BigIntType.INSTANCE, true, - ImmutableList.of("default_cluster:test", "t7")), + .when(FieldChecker.check("otherJoinConjuncts", + ImmutableList.of(new EqualTo( + new SlotReference(new ExprId(6), "v2", BigIntType.INSTANCE, true, + ImmutableList.of("default_cluster:test", "t7")), new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true, ImmutableList.of("default_cluster:test", "t6")))))) ); @@ -276,15 +278,15 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte .applyBottomUp(new InApplyToJoin()) .matches( logicalJoin().when(FieldChecker.check("joinType", JoinType.LEFT_SEMI_JOIN)) - .when(FieldChecker.check("otherJoinCondition", Optional.of( - new And(new EqualTo(new SlotReference(new ExprId(0), "k1", BigIntType.INSTANCE, true, + .when(FieldChecker.check("otherJoinConjuncts", ImmutableList.of( + new EqualTo(new SlotReference(new ExprId(0), "k1", BigIntType.INSTANCE, true, ImmutableList.of("default_cluster:test", "t6")), new SlotReference(new ExprId(2), "k1", BigIntType.INSTANCE, false, ImmutableList.of("default_cluster:test", "t7"))), - new EqualTo(new SlotReference(new ExprId(6), "v2", BigIntType.INSTANCE, true, - ImmutableList.of("default_cluster:test", "t7")), - new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true, - ImmutableList.of("default_cluster:test", "t6")))) + new EqualTo(new SlotReference(new ExprId(6), "v2", BigIntType.INSTANCE, true, + ImmutableList.of("default_cluster:test", "t7")), + new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true, + ImmutableList.of("default_cluster:test", "t6"))) ))) ); } @@ -350,7 +352,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte .applyBottomUp(new ExistsApplyToJoin()) .matches( logicalJoin().when(FieldChecker.check("joinType", JoinType.LEFT_SEMI_JOIN)) - .when(FieldChecker.check("otherJoinCondition", Optional.of( + .when(FieldChecker.check("otherJoinConjuncts", ImmutableList.of( new EqualTo(new SlotReference(new ExprId(1), "k2", BigIntType.INSTANCE, true, ImmutableList.of("default_cluster:test", "t6")), new SlotReference(new ExprId(6), "v2", BigIntType.INSTANCE, true, @@ -361,7 +363,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte @Test public void testSql10AfterAnalyze() { - //select * from t6 where t6.k1 < (select max(aa) from (select v1 as aa from t7 where t6.k2=t7.v2) t2 ) + // select * from t6 where t6.k1 < (select max(aa) from (select v1 as aa from t7 where t6.k2=t7.v2) t2 ) PlanChecker.from(connectContext) .analyze(sql10) .matches( @@ -371,13 +373,17 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte logicalProject( logicalFilter() ).when(FieldChecker.check("projects", ImmutableList.of( - new Alias(new ExprId(0), new SlotReference(new ExprId(6), "v1", BigIntType.INSTANCE, true, - ImmutableList.of("default_cluster:test", "t7")), "aa") + new Alias(new ExprId(0), + new SlotReference(new ExprId(6), "v1", BigIntType.INSTANCE, + true, + ImmutableList.of("default_cluster:test", "t7")), "aa") ))) ).when(FieldChecker.check("outputExpressions", ImmutableList.of( - new Alias(new ExprId(8), new Max(new SlotReference(new ExprId(0), "aa", BigIntType.INSTANCE, true, - ImmutableList.of("t2"))), "max(aa)") - ))) + new Alias(new ExprId(8), + new Max(new SlotReference(new ExprId(0), "aa", BigIntType.INSTANCE, + true, + ImmutableList.of("t2"))), "max(aa)") + ))) .when(FieldChecker.check("groupByExpressions", ImmutableList.of())) ).when(FieldChecker.check("correlationSlot", ImmutableList.of( new SlotReference(new ExprId(2), "k2", BigIntType.INSTANCE, true, @@ -396,17 +402,23 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte logicalAggregate( logicalFilter( logicalProject().when(FieldChecker.check("projects", ImmutableList.of( - new Alias(new ExprId(0), new SlotReference(new ExprId(6), "v1", BigIntType.INSTANCE, true, + new Alias(new ExprId(0), new SlotReference(new ExprId(6), "v1", + BigIntType.INSTANCE, true, ImmutableList.of("default_cluster:test", "t7")), "aa"), - new SlotReference(new ExprId(3), "k1", BigIntType.INSTANCE, false, + new SlotReference(new ExprId(3), "k1", BigIntType.INSTANCE, + false, + ImmutableList.of("default_cluster:test", "t7")), + new SlotReference(new ExprId(4), "k2", new VarcharType(128), + true, ImmutableList.of("default_cluster:test", "t7")), - new SlotReference(new ExprId(4), "k2", new VarcharType(128), true, + new SlotReference(new ExprId(5), "k3", BigIntType.INSTANCE, + true, ImmutableList.of("default_cluster:test", "t7")), - new SlotReference(new ExprId(5), "k3", BigIntType.INSTANCE, true, + new SlotReference(new ExprId(6), "v1", BigIntType.INSTANCE, + true, ImmutableList.of("default_cluster:test", "t7")), - new SlotReference(new ExprId(6), "v1", BigIntType.INSTANCE, true, - ImmutableList.of("default_cluster:test", "t7")), - new SlotReference(new ExprId(7), "v2", BigIntType.INSTANCE, true, + new SlotReference(new ExprId(7), "v2", BigIntType.INSTANCE, + true, ImmutableList.of("default_cluster:test", "t7")) ))) ) @@ -427,10 +439,12 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte logicalAggregate( logicalProject() ).when(FieldChecker.check("outputExpressions", ImmutableList.of( - new Alias(new ExprId(8), new Max(new SlotReference(new ExprId(0), "aa", BigIntType.INSTANCE, true, - ImmutableList.of("t2"))), "max(aa)"), - new SlotReference(new ExprId(7), "v2", BigIntType.INSTANCE, true, - ImmutableList.of("default_cluster:test", "t7"))))) + new Alias(new ExprId(8), + new Max(new SlotReference(new ExprId(0), "aa", BigIntType.INSTANCE, + true, + ImmutableList.of("t2"))), "max(aa)"), + new SlotReference(new ExprId(7), "v2", BigIntType.INSTANCE, true, + ImmutableList.of("default_cluster:test", "t7"))))) .when(FieldChecker.check("groupByExpressions", ImmutableList.of( new SlotReference(new ExprId(7), "v2", BigIntType.INSTANCE, true, ImmutableList.of("default_cluster:test", "t7")) @@ -453,7 +467,7 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte logicalProject() ) ).when(FieldChecker.check("joinType", JoinType.LEFT_OUTER_JOIN)) - .when(FieldChecker.check("otherJoinCondition", Optional.of( + .when(FieldChecker.check("otherJoinConjuncts", ImmutableList.of( new EqualTo(new SlotReference(new ExprId(2), "k2", BigIntType.INSTANCE, true, ImmutableList.of("default_cluster:test", "t6")), new SlotReference(new ExprId(7), "v2", BigIntType.INSTANCE, true, diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java index 13573fc0d2..f0bfa95d73 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/LogicalPlanBuilder.java @@ -36,7 +36,6 @@ import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.List; -import java.util.Optional; import java.util.stream.Collectors; public class LogicalPlanBuilder { @@ -89,23 +88,23 @@ public class LogicalPlanBuilder { new EqualTo(this.plan.getOutput().get(hashOnSlots.first), right.getOutput().get(hashOnSlots.second))); LogicalJoin<LogicalPlan, LogicalPlan> join = new LogicalJoin<>(joinType, new ArrayList<>(hashConjunts), - Optional.empty(), this.plan, right); + this.plan, right); return from(join); } - public LogicalPlanBuilder hashJoinUsing(LogicalPlan right, JoinType joinType, List<Pair<Integer, Integer>> hashOnSlots) { + public LogicalPlanBuilder hashJoinUsing(LogicalPlan right, JoinType joinType, + List<Pair<Integer, Integer>> hashOnSlots) { List<EqualTo> hashConjunts = hashOnSlots.stream() .map(pair -> new EqualTo(this.plan.getOutput().get(pair.first), right.getOutput().get(pair.second))) .collect(Collectors.toList()); LogicalJoin<LogicalPlan, LogicalPlan> join = new LogicalJoin<>(joinType, new ArrayList<>(hashConjunts), - Optional.empty(), this.plan, right); + this.plan, right); return from(join); } public LogicalPlanBuilder hashJoinEmptyOn(LogicalPlan right, JoinType joinType) { - LogicalJoin<LogicalPlan, LogicalPlan> join = new LogicalJoin<>(joinType, new ArrayList<>(), - Optional.empty(), this.plan, right); + LogicalJoin<LogicalPlan, LogicalPlan> join = new LogicalJoin<>(joinType, new ArrayList<>(), this.plan, right); return from(join); } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org