github-actions[bot] commented on code in PR #63690:
URL: https://github.com/apache/doris/pull/63690#discussion_r3425217253


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/EagerAggRewriter.java:
##########
@@ -70,77 +84,181 @@
  *         ->T2(D)
  */
 public class EagerAggRewriter extends DefaultPlanRewriter<PushDownAggContext> {
+    public static final int BIG_JOIN_BUILD_SIZE = 400_000;
     private static final double LOWER_AGGREGATE_EFFECT_COEFFICIENT = 10000;
     private static final double LOW_AGGREGATE_EFFECT_COEFFICIENT = 1000;
     private static final double MEDIUM_AGGREGATE_EFFECT_COEFFICIENT = 100;
+    private static final String JOIN_CNT = "joinCnt";
     private final StatsDerive derive = new StatsDerive(false);
 
     @Override
     public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> 
join, PushDownAggContext context) {
-        boolean toLeft = false;
-        boolean toRight = false;
-        boolean pushHere = false;
-        if (join.getJoinType().isAsofJoin()) {
-            // do nothing for asof join
-            return join;
+        Pair<Boolean, Boolean> pushSide = decideJoinPushSide(join, context);
+        boolean toLeft = pushSide.first;
+        boolean toRight = pushSide.second;
+        if (!toLeft && !toRight) {
+            if (SessionVariable.isEagerAggregationOnJoin()) {
+                return genAggregate(join, context);
+            } else {
+                return join;
+            }
         }
-        if (context.getAggFunctions().isEmpty()) {
-            // select t1.v from t1 join t2 on t1.id = t2.id group by t1.v, t2.v
-            // if no agg function, try to push agg to the child which contains 
all group keys
-            // TODO: consider t1.rows/(t1.id, t1.v).ndv and t2.rows/(t2.id, 
t2.v).ndv to determine push target
-            if 
(join.left().getOutputSet().containsAll(context.getGroupKeys())) {
-                toLeft = true;
-            } else if 
(join.right().getOutputSet().containsAll(context.getGroupKeys())) {
-                toRight = true;
+
+        // construct left and right group by keys
+        List<SlotReference> leftChildGroupByKeys = new ArrayList<>();
+        List<SlotReference> rightChildGroupByKeys = new ArrayList<>();
+        if (toLeft) {
+            fillGroupByKeys(join, join.left(), context, leftChildGroupByKeys);
+        }
+        if (toRight) {
+            fillGroupByKeys(join, join.right(), context, 
rightChildGroupByKeys);
+        }
+        // construct left and right aggFuncs and aliasMap
+        List<AggregateFunction> leftFuncs = new ArrayList<>();
+        List<AggregateFunction> rightFuncs = new ArrayList<>();
+        Map<AggregateFunction, Alias> leftAliasMap = new IdentityHashMap<>();
+        Map<AggregateFunction, Alias> rightAliasMap = new IdentityHashMap<>();
+        for (AggregateFunction f : context.getAggFunctions()) {
+            Set<Slot> inputs = f.getInputSlots();
+            Alias a = context.getAliasMap().get(f);
+            if (join.left().getOutputSet().containsAll(inputs)) {
+                leftFuncs.add(f);
+                leftAliasMap.put(f, a);
+            } else if (join.right().getOutputSet().containsAll(inputs)) {
+                rightFuncs.add(f);
+                rightAliasMap.put(f, a);
             } else {
-                pushHere = true;
+                return join;
             }
+        }
+
+        boolean passThroughBigJoin = isPassThroughBigJoin(join, context);
+        boolean leftNeedOutputCount = needOutputCountForJoinChild(join, 
toLeft, toRight,
+                context.needOutputCount(), rightFuncs);
+        boolean rightNeedOutputCount = needOutputCountForJoinChild(join, 
toRight, toLeft,
+                context.needOutputCount(), leftFuncs);
+        Optional<PushDownAggContext> leftChildContext = toLeft ? 
Optional.of(context.forOneBranch(leftFuncs,
+                leftAliasMap, leftChildGroupByKeys, passThroughBigJoin, 
leftNeedOutputCount)) : Optional.empty();
+        Optional<PushDownAggContext> rightChildContext = toRight ? 
Optional.of(context.forOneBranch(rightFuncs,
+                rightAliasMap, rightChildGroupByKeys, passThroughBigJoin, 
rightNeedOutputCount)) : Optional.empty();
+
+        Plan newLeft = join.left();
+        Plan newRight = join.right();
+        if (leftChildContext.isPresent()) {
+            newLeft = join.left().accept(this, leftChildContext.get());
+        }
+        if (rightChildContext.isPresent()) {
+            newRight = join.right().accept(this, rightChildContext.get());
+        }
+
+        if (newLeft == join.left() && newRight == join.right()) {
+            context.getBilateralState().registerNoCountSlot(join);
+            return join;
+        }
+        Optional<Slot> leftChildCountSlot = 
context.getBilateralState().getCountSlot(newLeft);
+        Optional<Slot> rightChildCountSlot = 
context.getBilateralState().getCountSlot(newRight);
+        LogicalJoin<? extends Plan, ? extends Plan> newJoin = (LogicalJoin<? 
extends Plan, ? extends Plan>)
+                join.withChildren(newLeft, newRight);
+
+        if (leftChildCountSlot.isPresent() || rightChildCountSlot.isPresent()) 
{
+            return buildCanonicalJoinProject(newJoin, context, 
leftChildContext, rightChildContext,
+                    leftChildCountSlot, rightChildCountSlot);
+        }
+        context.getBilateralState().registerNoCountSlot(newJoin);
+        return newJoin;
+    }
+
+    private Pair<Boolean, Boolean> decideJoinPushSide(
+            LogicalJoin<? extends Plan, ? extends Plan> join, 
PushDownAggContext context) {
+        if (join.getJoinType().isAsofJoin() || join.isMarkJoin()) {
+            // do nothing for asof join and mark join
+            return Pair.of(false, false);
+        }
+
+        boolean deduplicateOnly = context.getAggFunctions().isEmpty();
+        boolean toLeft = false;
+        boolean toRight = false;
+        if (deduplicateOnly) {
+            toLeft = true;
+            toRight = true;
         } else {

Review Comment:
   For aggregate functions with no input slots (e.g. `count(*)`), 
`getInputSlots()` is empty, and `containsAll(...)` on an empty set is always 
true, so the `join.left().getOutputSet().containsAll(inputs)` check classifies 
them as left-side aggregates before the right side is ever considered. That 
breaks right semi/right anti joins, whose output 
is only the right side. For example, `SELECT r.k, count(*) FROM l RIGHT SEMI 
JOIN r ON l.k = r.k GROUP BY r.k` pushes the `count(*)` under the left child, 
but the rewritten parent later rolls up `sum0(<left count slot>)` above a 
right-semi join that does not output that slot. The same applies to literal 
aggregates such as `sum(1)`. Please either choose the preserved/right side for 
empty-input aggregates on right semi/anti joins, or decline the push for those 
join types, and add regression coverage for right semi/right anti 
count-star/literal aggregates.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to