This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new d23646793c [fix](nereids) binding group by key on agg.output if output is slot (#15623)
d23646793c is described below

commit d23646793c71ea44e8e9f6b5f28acd9b5b51e237
Author: minghong <engle...@gmail.com>
AuthorDate: Thu Jan 12 16:34:56 2023 +0800

    [fix](nereids) binding group by key on agg.output if output is slot (#15623)
    
    Case 1:
    `select count(1) from t1 join t2 on t1.a = t2.a group by a`
    `group by a` is ambiguous: `a` could refer to either t1.a or t2.a.
    
    Case 2:
    `select t1.a from t1 join t2 on t1.a = t2.a group by a`
    `group by a` is bound to t1.a, because t1.a appears in the aggregate output.
---
 .../nereids/rules/analysis/BindSlotReference.java  | 40 +++++++++++++++-
 .../rules/analysis/BindSlotReferenceTest.java      | 55 ++++++++++++++++++++++
 2 files changed, 93 insertions(+), 2 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java
index bc0c9325ae..7fbcb9fde0 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java
@@ -74,6 +74,7 @@ import com.google.common.base.Preconditions;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
 import org.apache.commons.lang.StringUtils;
 
 import java.util.ArrayList;
@@ -227,8 +228,20 @@ public class BindSlotReference implements 
AnalysisRuleFactory {
                      group by key cannot bind with agg func
                      plan:
                         agg(group_by v, output sum(k) as v)
-
                      throw AnalysisException
+
+                    CASE 4
+                     sql:
+                        `select count(1) from t1 join t2 group by a`
+                     we cannot bind `group by a` because it is ambiguous (t1.a or t2.a)
+
+                    CASE 5
+                     following case 4: if t1.a is in agg.output, we can bind `group by a` to t1.a
+                     sql
+                        select t1.a
+                        from t1 join t2 on t1.a = t2.a
+                        group by a
+                     group_by_key is bound on t1.a
                     */
                     duplicatedSlotNames.stream().forEach(dup -> 
childOutputsToExpr.remove(dup));
                     Map<String, Expression> aliasNameToExpr = output.stream()
@@ -261,8 +274,31 @@ public class BindSlotReference implements 
AnalysisRuleFactory {
                                 }
                                 return groupBy;
                             }).collect(Collectors.toList());
+                    /*
+                    according to case 4 and case 5, we construct boundSlots
+                    */
+                    Set<String> outputSlotNames = Sets.newHashSet();
+                    Set<Slot> outputSlots = output.stream()
+                            .filter(SlotReference.class::isInstance)
+                            .peek(slot -> outputSlotNames.add(slot.getName()))
+                            .map(NamedExpression::toSlot).collect(
+                                    Collectors.toSet());
+                    // Suppose the group-by key is `a`.
+                    // If both t1.a and t2.a are in agg.child's output and t1.a is in agg.output,
+                    // bind the group-by key `a` to t1.a.
+                    // `.filter(slot -> !outputSlotNames.contains(slot.getName()))`
+                    // prevents t2.a from being added to boundSlots.
+                    Set<Slot> boundSlots = agg.child().getOutputSet().stream()
+                            .filter(slot -> 
!outputSlotNames.contains(slot.getName()))
+                            .collect(Collectors.toSet());
+
+                    boundSlots.addAll(outputSlots);
+                    SlotBinder binder = new 
SlotBinder(toScope(Lists.newArrayList(boundSlots)), ctx.cascadesContext);
+
+                    List<Expression> groupBy = replacedGroupBy.stream()
+                            .map(expression -> binder.bind(expression))
+                            .collect(Collectors.toList());
 
-                    List<Expression> groupBy = bind(replacedGroupBy, 
agg.children(), agg, ctx.cascadesContext);
                     List<Expression> unboundGroupBys = Lists.newArrayList();
                     boolean hasUnbound = groupBy.stream().anyMatch(
                             expression -> {
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java
index e63618cdd7..abc18a70c7 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindSlotReferenceTest.java
@@ -19,17 +19,24 @@ package org.apache.doris.nereids.rules.analysis;
 
 import org.apache.doris.nereids.analyzer.UnboundSlot;
 import org.apache.doris.nereids.exceptions.AnalysisException;
+import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.NamedExpressionUtil;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
 import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
 import org.apache.doris.nereids.trees.plans.logical.RelationUtil;
 import org.apache.doris.nereids.util.MemoTestUtils;
 import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -65,4 +72,52 @@ class BindSlotReferenceTest {
         Assertions.assertTrue(exception.getMessage().contains("id#4"));
         Assertions.assertTrue(exception.getMessage().contains("id#0"));
     }
+
+    /*
+    select t1.id from student t1 join student t2 on t1.id = t2.id group by id;
+    the group-by key is bound to t1.id, not t2.id
+     */
+    @Test
+    public void testGroupByOnJoin() {
+        LogicalOlapScan scan1 = new 
LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
+        LogicalSubQueryAlias sub1 = new LogicalSubQueryAlias("t1", scan1);
+        LogicalOlapScan scan2 = new 
LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
+        LogicalSubQueryAlias sub2 = new LogicalSubQueryAlias("t2", scan2);
+        LogicalJoin<LogicalSubQueryAlias<LogicalOlapScan>, 
LogicalSubQueryAlias<LogicalOlapScan>> join =
+                new LogicalJoin<>(JoinType.CROSS_JOIN, sub1, sub2);
+        LogicalAggregate<LogicalJoin> aggregate = new LogicalAggregate<>(
+                Lists.newArrayList(new UnboundSlot("id")), //group by
+                Lists.newArrayList(new UnboundSlot("t1", "id")), //output
+                join
+        );
+        PlanChecker checker = 
PlanChecker.from(MemoTestUtils.createConnectContext()).analyze(aggregate);
+        LogicalAggregate plan = (LogicalAggregate) 
checker.getCascadesContext().getMemo().copyOut();
+        SlotReference groupByKey = (SlotReference) 
plan.getGroupByExpressions().get(0);
+        SlotReference t1id = (SlotReference) ((LogicalJoin) 
plan.child()).left().getOutput().get(0);
+        SlotReference t2id = (SlotReference) ((LogicalJoin) 
plan.child()).right().getOutput().get(0);
+        Assertions.assertEquals(groupByKey.getExprId(), t1id.getExprId());
+        Assertions.assertNotEquals(t1id.getExprId(), t2id.getExprId());
+    }
+
+    /*
+    select count(1) from student t1 join student t2 on t1.id = t2.id group by id;
+    the group-by key `id` is ambiguous (t1.id or t2.id)
+     */
+    @Test
+    public void testGroupByOnJoinAmbiguous() {
+        LogicalOlapScan scan1 = new 
LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
+        LogicalSubQueryAlias sub1 = new LogicalSubQueryAlias("t1", scan1);
+        LogicalOlapScan scan2 = new 
LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
+        LogicalSubQueryAlias sub2 = new LogicalSubQueryAlias("t2", scan2);
+        LogicalJoin<LogicalSubQueryAlias<LogicalOlapScan>, 
LogicalSubQueryAlias<LogicalOlapScan>> join =
+                new LogicalJoin<>(JoinType.CROSS_JOIN, sub1, sub2);
+        LogicalAggregate<LogicalJoin> aggregate = new LogicalAggregate<>(
+                Lists.newArrayList(new UnboundSlot("id")), //group by
+                Lists.newArrayList(new Alias(new Count(new IntegerLiteral(1)), 
"count(1)")), //output
+                join
+        );
+        AnalysisException exception = 
Assertions.assertThrows(AnalysisException.class,
+                () -> 
PlanChecker.from(MemoTestUtils.createConnectContext()).analyze(aggregate));
+        Assertions.assertTrue(exception.getMessage().contains("id is 
ambiguous: "));
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to