jackwener commented on code in PR #13353:
URL: https://github.com/apache/doris/pull/13353#discussion_r994715649


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java:
##########
@@ -37,21 +56,296 @@
  * SELECT * FROM t1 JOIN t3 ON t1.id=t3.id JOIN t2 ON t2.id=t3.id
  * </pre>
  * </p>
- * TODO: This is tested by SSB queries currently, add more `unit` test for this rule
- * when we have a plan building and comparing framework.
+ * Using the {@link MultiJoin} to complete this task.
+ * {Join cluster}: contain multiple join with filter inside.
+ * <ul>
+ * <li> {Join cluster} to MultiJoin</li>
+ * <li> MultiJoin to {Join cluster}</li>
+ * </ul>
  */
 public class ReorderJoin extends OneRewriteRuleFactory {
     @Override
     public Rule build() {
         return logicalFilter(subTree(LogicalJoin.class, 
LogicalFilter.class)).thenApply(ctx -> {
             LogicalFilter<Plan> filter = ctx.root;
-            if (!ctx.cascadesContext.getConnectContext().getSessionVariable()
-                    .isEnableNereidsReorderToEliminateCrossJoin()) {
-                return filter;
-            }
-            MultiJoin multiJoin = new MultiJoin();
-            filter.accept(multiJoin, null);
-            return 
multiJoin.reorderJoinsAccordingToConditions().orElse(filter);
+
+            MultiJoin multiJoin = (MultiJoin) joinToMultiJoin(filter);
+            Plan plan = multiJoinToJoin(multiJoin);
+            return plan;
         }).toRule(RuleType.REORDER_JOIN);
     }
+
+    /**
+     * Recursively convert to
+     * {@link LogicalJoin} or
+     * {@link LogicalFilter}--{@link LogicalJoin}
+     * --> {@link MultiJoin}
+     */
+    public Plan joinToMultiJoin(Plan plan) {
+        // subtree can't specify the end of the pattern, so the end can be a GroupPlan or a Filter
+        if (plan instanceof GroupPlan
+                || (plan instanceof LogicalFilter && plan.child(0) instanceof 
GroupPlan)) {
+            return plan;
+        }
+
+        List<Plan> inputs = Lists.newArrayList();
+        List<Expression> joinFilter = Lists.newArrayList();
+        List<Expression> notInnerJoinConditions = Lists.newArrayList();
+
+        LogicalJoin<?, ?> join;
+        if (plan instanceof LogicalFilter) {
+            LogicalFilter<?> filter = (LogicalFilter<?>) plan;
+            
joinFilter.addAll(ExpressionUtils.extractConjunction(filter.getPredicates()));
+            join = (LogicalJoin<?, ?>) filter.child();
+        } else {
+            join = (LogicalJoin<?, ?>) plan;
+        }
+
+        if (join.getJoinType().isInnerOrCrossJoin()) {
+            joinFilter.addAll(join.getHashJoinConjuncts());
+            joinFilter.addAll(join.getOtherJoinConjuncts());
+        } else {
+            notInnerJoinConditions.addAll(join.getHashJoinConjuncts());
+            notInnerJoinConditions.addAll(join.getOtherJoinConjuncts());
+        }
+
+        // recursively convert children.
+        Plan left = joinToMultiJoin(join.left());
+        Plan right = joinToMultiJoin(join.right());
+
+        boolean changeLeft = join.getJoinType().isRightJoin()
+                || join.getJoinType().isFullOuterJoin();
+        if (canCombine(left, changeLeft)) {
+            MultiJoin leftMultiJoin = (MultiJoin) left;
+            inputs.addAll(leftMultiJoin.children());
+            joinFilter.addAll(leftMultiJoin.getJoinFilter());
+        } else {
+            inputs.add(left);
+        }
+
+        boolean changeRight = join.getJoinType().isLeftJoin()
+                || join.getJoinType().isFullOuterJoin();
+        if (canCombine(right, changeRight)) {
+            MultiJoin rightMultiJoin = (MultiJoin) right;
+            inputs.addAll(rightMultiJoin.children());
+            joinFilter.addAll(rightMultiJoin.getJoinFilter());
+        } else {
+            inputs.add(right);
+        }
+
+        Optional<JoinType> joinType;
+        if (join.getJoinType().isInnerOrCrossJoin()) {
+            joinType = Optional.empty();
+        } else {
+            joinType = Optional.of(join.getJoinType());
+        }
+        return new MultiJoin(
+                inputs,
+                joinFilter,
+                joinType,
+                notInnerJoinConditions);
+    }
+
+    /**
+     * Recursively convert to
+     * {@link MultiJoin}
+     * -->
+     * {@link LogicalJoin} or
+     * {@link LogicalFilter}--{@link LogicalJoin}
+     * <p>
+     * When all inputs are CROSS/INNER joins, all joins will be flattened.
+     * Otherwise, we will split {join cluster} into multiple {@link MultiJoin}.
+     * <p>
+     * Here are examples of the {@link MultiJoin}s constructed after this rule has been applied.
+     * <p>
+     * simple example:
+     * <ul>
+     * <li>A JOIN B --> MJ(A, B)
+     * <li>A JOIN B JOIN C JOIN D --> MJ(A, B, C, D)
+     * <li>A LEFT (OUTER/SEMI/ANTI) JOIN B --> MJ([LOJ/LSJ/LAJ]A, B)
+     * <li>A RIGHT (OUTER/SEMI/ANTI) JOIN B --> MJ([ROJ/RSJ/RAJ]A, B)
+     * <li>A FULL JOIN B --> MJ[FOJ](A, B)
+     * </ul>
+     * </p>
+     * <p>
+     * complex example:
+     * <ul>
+     * <li>A LEFT OUTER JOIN (B JOIN C) --> MJ([LOJ]A, MJ(B, C))
+     * <li>(A JOIN B) LEFT JOIN C --> MJ(A, B, C)
+     * <li>(A LEFT OUTER JOIN B) JOIN C --> MJ(MJ(A, B), C)
+     * <li>A LEFT JOIN (B FULL JOIN C) --> MJ(A, MJ[full](B, C))
+     * <li>(A LEFT JOIN B) FULL JOIN (C RIGHT JOIN D) --> MJ[full](MJ(A, B), 
MJ(C, D))
+     * </ul>
+     * </p>
+     * more complex example:
+     * <ul>
+     * <li> A JOIN B JOIN C LEFT JOIN D --> MJ([LOJ]A, B, C, D)
+     * <li> A JOIN B JOIN C LEFT JOIN (D JOIN F) --> MJ([LOJ]A, B, C, MJ(D, F))
+     * <li> A RIGHT JOIN (B JOIN C JOIN D)--> MJ([ROJ]A, B, C, D)
+     * <li> A JOIN B RIGHT JOIN (C JOIN D) --> MJ(A, B, MJ([ROJ]C, D))
+     * </ul>
+     * </p>
+     * <p>
+     * Graphic presentation:
+     * A JOIN B JOIN C LEFT JOIN D JOIN F
+     *      left                  left│
+     * A  B  C  D  F   ──►   A  B  C  │ D  F   ──►  MJ(LOJ A,B,C,MJ(D,F))
+     * <p>
+     * A JOIN B RIGHT JOIN C JOIN D JOIN F
+     *     right                  │right
+     * A  B  C  D  F   ──►   A  B │  C  D  F   ──►  MJ(A,B,MJ(ROJ C,D,F))
+     * <p>
+     * (A JOIN B JOIN C) FULL JOIN (D JOIN F)
+     *       full                    │
+     * A  B  C  D  F   ──►   A  B  C │ D  F    ──►  MJ(FOJ MJ(A,B,C) MJ(D,F))
+     * </p>
+     */
+    public Plan multiJoinToJoin(MultiJoin multiJoin) {
+        if (multiJoin.arity() == 1) {
+            return PlanUtils.filterOrSelf(multiJoin.getJoinFilter(), 
multiJoin.child(0));
+        }
+
+        Builder<Plan> builder = ImmutableList.builder();
+        // recursively handle multiJoin children.
+        for (Plan child : multiJoin.children()) {
+            if (child instanceof MultiJoin) {
+                MultiJoin childMultiJoin = (MultiJoin) child;
+                builder.add(multiJoinToJoin(childMultiJoin));
+            } else {
+                builder.add(child);
+            }
+        }
+        MultiJoin multiJoinHandleChildren = 
multiJoin.withChildren(builder.build());
+
+        if (multiJoinHandleChildren.getOnlyJoinType().isPresent()) {
+            List<Expression> leftFilter = Lists.newArrayList();
+            List<Expression> rightFilter = Lists.newArrayList();
+            List<Expression> remainingFilter = Lists.newArrayList();
+            Plan left = multiJoinToJoin(new MultiJoin(
+                    multiJoinHandleChildren.children().subList(0, 
multiJoinHandleChildren.arity() - 1),
+                    leftFilter,
+                    Optional.empty(),
+                    ExpressionUtils.EMPTY_CONDITION));
+            Plan right = multiJoinToJoin(new MultiJoin(
+                    multiJoinHandleChildren.children().subList(1, 
multiJoinHandleChildren.arity()),
+                    rightFilter,
+                    Optional.empty(),
+                    ExpressionUtils.EMPTY_CONDITION));
+            if (multiJoinHandleChildren.getOnlyJoinType().get().isLeftJoin()) {
+                right = 
multiJoinHandleChildren.child(multiJoinHandleChildren.arity() - 1);
+            } else if 
(multiJoinHandleChildren.getOnlyJoinType().get().isRightJoin()) {
+                left = multiJoinHandleChildren.child(0);
+            }
+
+            // split filter
+            for (Expression expr : multiJoinHandleChildren.getJoinFilter()) {
+                Set<Slot> exprInputSlots = expr.getInputSlots();
+                Preconditions.checkState(!exprInputSlots.isEmpty());
+
+                if (left.getOutputSet().containsAll(exprInputSlots)) {
+                    leftFilter.add(expr);
+                } else if (right.getOutputSet().containsAll(exprInputSlots)) {
+                    rightFilter.add(expr);
+                } else if 
(multiJoin.getOutputSet().containsAll(exprInputSlots)) {
+                    remainingFilter.add(expr);
+                } else {
+                    throw new RuntimeException("invalid expression");
+                }
+            }
+
+            return PlanUtils.filterOrSelf(remainingFilter, new LogicalJoin<>(
+                    multiJoinHandleChildren.getOnlyJoinType().get(),
+                    ExpressionUtils.EMPTY_CONDITION,
+                    multiJoinHandleChildren.getNotInnerJoinConditions(),
+                    PlanUtils.filterOrSelf(leftFilter, left), 
PlanUtils.filterOrSelf(rightFilter, right)));
+        }
+
+        // following this multiJoin just contain INNER/CROSS.
+        List<Expression> joinFilter = multiJoinHandleChildren.getJoinFilter();
+
+        Plan left = multiJoinHandleChildren.child(0);
+        List<Plan> candidates = multiJoinHandleChildren.children().subList(1, 
multiJoinHandleChildren.arity());
+
+        LogicalJoin<? extends Plan, ? extends Plan> join = findInnerJoin(left, 
candidates, joinFilter);
+        List<Plan> newInputs = Lists.newArrayList();
+        newInputs.add(join);
+        newInputs.addAll(candidates.stream().filter(plan -> 
!join.right().equals(plan)).collect(Collectors.toList()));
+
+        joinFilter.removeAll(join.getHashJoinConjuncts());
+        joinFilter.removeAll(join.getOtherJoinConjuncts());
+        return multiJoinToJoin(new MultiJoin(
+                newInputs,
+                joinFilter,
+                Optional.empty(),
+                ExpressionUtils.EMPTY_CONDITION));
+    }
+
+    /**
+     * Returns whether an input can be merged without changing semantics.
+     *
+     * @param input input into a MultiJoin or (GroupPlan|LogicalFilter)
+     * @param changeLeft generate nullable or left not exist.
+     * @return true if the input can be combined into a parent MultiJoin
+     */
+    private static boolean canCombine(Plan input, boolean changeLeft) {

Review Comment:
   It can be SEMI/ANTI 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to