This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new dca6b8175d7 [fix](nereids) build agg for random distributed agg table in bindRelation phase (#40181)
dca6b8175d7 is described below

commit dca6b8175d7557b0b91c7a5c97d2656b02dced6f
Author: starocean999 <40539150+starocean...@users.noreply.github.com>
AuthorDate: Wed Sep 11 18:16:03 2024 +0800

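    [fix](nereids) build agg for random distributed agg table in bindRelation phase (#40181)
    
    It is better to build the agg for a random distributed agg table in the
    bindRelation phase instead of in the BuildAggForRandomDistributedTable rule:
    binding the relation directly as Aggregate(OlapScan) means every plan above it
    sees pre-aggregated rows, whereas the rule only matched a Project, Filter or
    Aggregate sitting directly on top of the scan.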
---
 .../doris/nereids/jobs/executor/Analyzer.java      |   3 -
 .../org/apache/doris/nereids/rules/RuleType.java   |   4 -
 .../doris/nereids/rules/analysis/BindRelation.java | 125 +++++++++-
 .../BuildAggForRandomDistributedTable.java         | 271 ---------------------
 .../doris/nereids/rules/analysis/CheckPolicy.java  |  21 +-
 .../nereids/rules/analysis/BindRelationTest.java   |  23 ++
 .../nereids/rules/analysis/CheckRowPolicyTest.java |  97 ++++++++
 .../aggregate/select_random_distributed_tbl.out    |  14 +-
 .../aggregate/select_random_distributed_tbl.groovy |  19 +-
 9 files changed, 285 insertions(+), 292 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
index 605a848181c..1ffbac97d74 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
@@ -26,7 +26,6 @@ import org.apache.doris.nereids.rules.analysis.BindExpression;
 import org.apache.doris.nereids.rules.analysis.BindRelation;
 import 
org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver;
 import org.apache.doris.nereids.rules.analysis.BindSink;
-import 
org.apache.doris.nereids.rules.analysis.BuildAggForRandomDistributedTable;
 import org.apache.doris.nereids.rules.analysis.CheckAfterBind;
 import org.apache.doris.nereids.rules.analysis.CheckAnalysis;
 import org.apache.doris.nereids.rules.analysis.CheckPolicy;
@@ -163,8 +162,6 @@ public class Analyzer extends AbstractBatchJobExecutor {
             topDown(new EliminateGroupByConstant()),
 
             topDown(new SimplifyAggGroupBy()),
-            // run BuildAggForRandomDistributedTable before NormalizeAggregate 
in order to optimize the agg plan
-            topDown(new BuildAggForRandomDistributedTable()),
             topDown(new NormalizeAggregate()),
             topDown(new HavingToFilter()),
             bottomUp(new SemiJoinCommute()),
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index b6ab6e2dac2..d345d9057e9 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -342,10 +342,6 @@ public enum RuleType {
 
     // topn opts
     DEFER_MATERIALIZE_TOP_N_RESULT(RuleTypeClass.REWRITE),
-    // pre agg for random distributed table
-    BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_PROJECT_SCAN(RuleTypeClass.REWRITE),
-    BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_FILTER_SCAN(RuleTypeClass.REWRITE),
-    BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_AGG_SCAN(RuleTypeClass.REWRITE),
     // short circuit rule
     SHOR_CIRCUIT_POINT_QUERY(RuleTypeClass.REWRITE),
     // exploration rules
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java
index cedc4e92ff1..c81fcc25b83 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java
@@ -17,10 +17,17 @@
 
 package org.apache.doris.nereids.rules.analysis;
 
+import org.apache.doris.catalog.AggStateType;
+import org.apache.doris.catalog.AggregateType;
 import org.apache.doris.catalog.Column;
+import org.apache.doris.catalog.DistributionInfo;
+import org.apache.doris.catalog.Env;
+import org.apache.doris.catalog.FunctionRegistry;
+import org.apache.doris.catalog.KeysType;
 import org.apache.doris.catalog.OlapTable;
 import org.apache.doris.catalog.Partition;
 import org.apache.doris.catalog.TableIf;
+import org.apache.doris.catalog.Type;
 import org.apache.doris.catalog.View;
 import org.apache.doris.common.Config;
 import org.apache.doris.common.Pair;
@@ -43,13 +50,26 @@ import 
org.apache.doris.nereids.properties.LogicalProperties;
 import org.apache.doris.nereids.properties.PhysicalProperties;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import 
org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
+import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
+import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
+import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
 import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.PreAggStatus;
 import org.apache.doris.nereids.trees.plans.algebra.Relation;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
 import org.apache.doris.nereids.trees.plans.logical.LogicalEsScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFileScan;
@@ -73,6 +93,7 @@ import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 import org.apache.commons.collections.CollectionUtils;
 
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Optional;
 import java.util.function.Function;
@@ -214,7 +235,109 @@ public class BindRelation extends OneAnalysisRuleFactory {
                     unboundRelation.getTableSample());
             }
         }
-        return checkAndAddDeleteSignFilter(scan, ConnectContext.get(), 
(OlapTable) table);
+        if (needGenerateLogicalAggForRandomDistAggTable(scan)) {
+            // it's a random distribution agg table
+            // add agg on olap scan
+            return preAggForRandomDistribution(scan);
+        } else {
+            // it's a duplicate, unique or hash distribution agg table
+            // add delete sign filter on olap scan if needed
+            return checkAndAddDeleteSignFilter(scan, ConnectContext.get(), 
(OlapTable) table);
+        }
+    }
+
+    /** Whether to wrap the scan in an aggregate: only for queries over RANDOM distributed AGG_KEYS tables. */
+    private boolean needGenerateLogicalAggForRandomDistAggTable(LogicalOlapScan olapScan) {
+        // Non-query statements (e.g. DELETE on a random distributed table) must not get the agg node;
+        // see https://github.com/apache/doris/pull/37985 for more info.
+        ConnectContext context = ConnectContext.get();
+        if (context == null || context.getState() == null || !context.getState().isQuery()) {
+            return false;
+        }
+        OlapTable olapTable = olapScan.getTable();
+        KeysType keysType = olapTable.getKeysType();
+        DistributionInfo distributionInfo = olapTable.getDefaultDistributionInfo();
+        return keysType == KeysType.AGG_KEYS
+                && distributionInfo.getType() == DistributionInfo.DistributionInfoType.RANDOM;
+    }
+
+    /**
+     * Wrap a scan of a random distributed AGG_KEYS table in a LogicalAggregate that groups by the key
+     * columns and re-aggregates every value column with its declared aggregation type.
+     *
+     * @param olapScan olap scan plan
+     * @return the aggregate over olapScan, or olapScan itself if some value column cannot be aggregated
+     */
+    private LogicalPlan preAggForRandomDistribution(LogicalOlapScan olapScan) {
+        OlapTable olapTable = olapScan.getTable();
+        List<Slot> childOutputSlots = olapScan.computeOutput();
+        List<Expression> groupByExpressions = new ArrayList<>();
+        List<NamedExpression> outputExpressions = new ArrayList<>();
+
+        for (Column col : olapTable.getBaseSchema()) {
+            // reuse the ExprId of the scan slot with the same column name (compared with equals(), not ==)
+            SlotReference slot = SlotReference.fromColumn(olapTable, col, col.getName(), olapScan.qualified());
+            ExprId exprId = slot.getExprId();
+            for (Slot childSlot : childOutputSlots) {
+                if (childSlot instanceof SlotReference
+                        && ((SlotReference) childSlot).getName().equals(col.getName())) {
+                    exprId = childSlot.getExprId();
+                    slot = slot.withExprId(exprId);
+                    break;
+                }
+            }
+            if (col.isKey()) {
+                groupByExpressions.add(slot);
+                outputExpressions.add(slot);
+                continue;
+            }
+            Expression function = generateAggFunction(slot, col);
+            if (function == null) {
+                // no supported aggregation for this value column: DO NOT rewrite
+                return olapScan;
+            }
+            // reuse exprId so the aggregated value keeps its column's id (the scan slot's id if matched)
+            outputExpressions.add(new Alias(exprId, ImmutableList.of(function), col.getName(),
+                    olapScan.qualified(), true));
+        }
+        return new LogicalAggregate<>(groupByExpressions, outputExpressions, olapScan);
+    }
+
+    /**
+     * Pick the merge aggregate for a value column from its aggregation type (GENERIC: agg_state columns only).
+     * @param slot slot of column
+     * @param column the value column whose aggregation type selects the function
+     * @return the aggregate function, or null when the type has no supported merge function
+     */
+    private Expression generateAggFunction(SlotReference slot, Column column) {
+        AggregateType aggregateType = column.getAggregationType();
+        switch (aggregateType) {
+            case SUM:
+                return new Sum(slot);
+            case MAX:
+                return new Max(slot);
+            case MIN:
+                return new Min(slot);
+            case HLL_UNION:
+                return new HllUnion(slot);
+            case BITMAP_UNION:
+                return new BitmapUnion(slot);
+            case QUANTILE_UNION:
+                return new QuantileUnion(slot);
+            case GENERIC:
+                break;
+            default:
+                return null;
+        }
+        Type type = column.getType();
+        if (!type.isAggStateType()) {
+            return null;
+        }
+        // merge several agg_state values into one via the <function> + UNION_SUFFIX combinator
+        String funcName = ((AggStateType) type).getFunctionName() + AggCombinerFunctionBuilder.UNION_SUFFIX;
+        FunctionRegistry functionRegistry = Env.getCurrentEnv().getFunctionRegistry();
+        FunctionBuilder builder = functionRegistry.findFunctionBuilder(funcName, slot);
+        return builder.build(funcName, ImmutableList.of(slot)).first;
     }
 
     /**
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BuildAggForRandomDistributedTable.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BuildAggForRandomDistributedTable.java
deleted file mode 100644
index e547a55f9e3..00000000000
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BuildAggForRandomDistributedTable.java
+++ /dev/null
@@ -1,271 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.analysis;
-
-import org.apache.doris.catalog.AggStateType;
-import org.apache.doris.catalog.AggregateType;
-import org.apache.doris.catalog.Column;
-import org.apache.doris.catalog.DistributionInfo;
-import org.apache.doris.catalog.DistributionInfo.DistributionInfoType;
-import org.apache.doris.catalog.Env;
-import org.apache.doris.catalog.FunctionRegistry;
-import org.apache.doris.catalog.KeysType;
-import org.apache.doris.catalog.OlapTable;
-import org.apache.doris.catalog.Type;
-import org.apache.doris.nereids.rules.Rule;
-import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.trees.expressions.Alias;
-import org.apache.doris.nereids.trees.expressions.ExprId;
-import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.Slot;
-import org.apache.doris.nereids.trees.expressions.SlotReference;
-import 
org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
-import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
-import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
-import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapFunction;
-import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
-import org.apache.doris.nereids.trees.expressions.functions.agg.HllFunction;
-import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
-import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion;
-import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
-import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
-import org.apache.doris.qe.ConnectContext;
-
-import com.google.common.collect.ImmutableList;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Set;
-
-/**
- * build agg plan for querying random distributed table
- */
-public class BuildAggForRandomDistributedTable implements AnalysisRuleFactory {
-
-    @Override
-    public List<Rule> buildRules() {
-        return ImmutableList.of(
-                // Project(Scan) -> project(agg(scan))
-                logicalProject(logicalOlapScan())
-                        .when(this::isQuery)
-                        .when(project -> 
isRandomDistributedTbl(project.child()))
-                        .then(project -> preAggForRandomDistribution(project, 
project.child()))
-                        
.toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_PROJECT_SCAN),
-                // agg(scan) -> agg(agg(scan)), agg(agg) may optimized by 
MergeAggregate
-                logicalAggregate(logicalOlapScan())
-                        .when(this::isQuery)
-                        .when(agg -> isRandomDistributedTbl(agg.child()))
-                        .whenNot(agg -> {
-                            Set<AggregateFunction> functions = 
agg.getAggregateFunctions();
-                            List<Expression> groupByExprs = 
agg.getGroupByExpressions();
-                            // check if need generate an inner agg plan or not
-                            // should not rewrite twice if we had rewritten 
olapScan to aggregate(olapScan)
-                            return 
functions.stream().allMatch(this::aggTypeMatch) && groupByExprs.stream()
-                                    .allMatch(this::isKeyOrConstantExpr);
-                        })
-                        .then(agg -> preAggForRandomDistribution(agg, 
agg.child()))
-                        
.toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_AGG_SCAN),
-                // filter(scan) -> filter(agg(scan))
-                logicalFilter(logicalOlapScan())
-                        .when(this::isQuery)
-                        .when(filter -> isRandomDistributedTbl(filter.child()))
-                        .then(filter -> preAggForRandomDistribution(filter, 
filter.child()))
-                        
.toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_FILTER_SCAN));
-
-    }
-
-    /**
-     * check the olapTable of olapScan is randomDistributed table
-     *
-     * @param olapScan olap scan plan
-     * @return true if olapTable is randomDistributed table
-     */
-    private boolean isRandomDistributedTbl(LogicalOlapScan olapScan) {
-        OlapTable olapTable = olapScan.getTable();
-        KeysType keysType = olapTable.getKeysType();
-        DistributionInfo distributionInfo = 
olapTable.getDefaultDistributionInfo();
-        return keysType == KeysType.AGG_KEYS && distributionInfo.getType() == 
DistributionInfoType.RANDOM;
-    }
-
-    private boolean isQuery(LogicalPlan plan) {
-        return ConnectContext.get() != null
-                && ConnectContext.get().getState() != null
-                && ConnectContext.get().getState().isQuery();
-    }
-
-    /**
-     * add LogicalAggregate above olapScan for preAgg
-     *
-     * @param logicalPlan parent plan of olapScan
-     * @param olapScan olap scan plan, it may be LogicalProject, 
LogicalFilter, LogicalAggregate
-     * @return rewritten plan
-     */
-    private Plan preAggForRandomDistribution(LogicalPlan logicalPlan, 
LogicalOlapScan olapScan) {
-        OlapTable olapTable = olapScan.getTable();
-        List<Slot> childOutputSlots = olapScan.computeOutput();
-        List<Expression> groupByExpressions = new ArrayList<>();
-        List<NamedExpression> outputExpressions = new ArrayList<>();
-        List<Column> columns = olapTable.getBaseSchema();
-
-        for (Column col : columns) {
-            // use exist slot in the plan
-            SlotReference slot = SlotReference.fromColumn(olapTable, col, 
col.getName(), olapScan.getQualifier());
-            ExprId exprId = slot.getExprId();
-            for (Slot childSlot : childOutputSlots) {
-                if (childSlot instanceof SlotReference && ((SlotReference) 
childSlot).getName() == col.getName()) {
-                    exprId = childSlot.getExprId();
-                    slot = slot.withExprId(exprId);
-                    break;
-                }
-            }
-            if (col.isKey()) {
-                groupByExpressions.add(slot);
-                outputExpressions.add(slot);
-            } else {
-                Expression function = generateAggFunction(slot, col);
-                // DO NOT rewrite
-                if (function == null) {
-                    return logicalPlan;
-                }
-                Alias alias = new Alias(exprId, function, col.getName());
-                outputExpressions.add(alias);
-            }
-        }
-        LogicalAggregate<LogicalOlapScan> aggregate = new 
LogicalAggregate<>(groupByExpressions, outputExpressions,
-                olapScan);
-        return logicalPlan.withChildren(aggregate);
-    }
-
-    /**
-     * generate aggregation function according to the aggType of column
-     *
-     * @param slot slot of column
-     * @return aggFunction generated
-     */
-    private Expression generateAggFunction(SlotReference slot, Column column) {
-        AggregateType aggregateType = column.getAggregationType();
-        switch (aggregateType) {
-            case SUM:
-                return new Sum(slot);
-            case MAX:
-                return new Max(slot);
-            case MIN:
-                return new Min(slot);
-            case HLL_UNION:
-                return new HllUnion(slot);
-            case BITMAP_UNION:
-                return new BitmapUnion(slot);
-            case QUANTILE_UNION:
-                return new QuantileUnion(slot);
-            case GENERIC:
-                Type type = column.getType();
-                if (!type.isAggStateType()) {
-                    return null;
-                }
-                AggStateType aggState = (AggStateType) type;
-                // use AGGREGATE_FUNCTION_UNION to aggregate multiple 
agg_state into one
-                String funcName = aggState.getFunctionName() + 
AggCombinerFunctionBuilder.UNION_SUFFIX;
-                FunctionRegistry functionRegistry = 
Env.getCurrentEnv().getFunctionRegistry();
-                FunctionBuilder builder = 
functionRegistry.findFunctionBuilder(funcName, slot);
-                return builder.build(funcName, ImmutableList.of(slot)).first;
-            default:
-                return null;
-        }
-    }
-
-    /**
-     * if the agg type of AggregateFunction is as same as the agg type of 
column, DO NOT need to rewrite
-     *
-     * @param function agg function to check
-     * @return true if agg type match
-     */
-    private boolean aggTypeMatch(AggregateFunction function) {
-        List<Expression> children = function.children();
-        if (function.getName().equalsIgnoreCase("count")) {
-            Count count = (Count) function;
-            // do not rewrite for count distinct for key column
-            if (count.isDistinct()) {
-                return children.stream().allMatch(this::isKeyOrConstantExpr);
-            }
-            if (count.isStar()) {
-                return false;
-            }
-        }
-        return children.stream().allMatch(child -> aggTypeMatch(function, 
child));
-    }
-
-    /**
-     * check if the agg type of functionCall match the agg type of column
-     *
-     * @param function the functionCall
-     * @param expression expr to check
-     * @return true if agg type match
-     */
-    private boolean aggTypeMatch(AggregateFunction function, Expression 
expression) {
-        if (expression.children().isEmpty()) {
-            if (expression instanceof SlotReference && ((SlotReference) 
expression).getColumn().isPresent()) {
-                Column col = ((SlotReference) expression).getColumn().get();
-                String functionName = function.getName();
-                if (col.isKey()) {
-                    return functionName.equalsIgnoreCase("max") || 
functionName.equalsIgnoreCase("min");
-                }
-                if (col.isAggregated()) {
-                    AggregateType aggType = col.getAggregationType();
-                    // agg type not mach
-                    if (aggType == AggregateType.GENERIC) {
-                        return col.getType().isAggStateType();
-                    }
-                    if (aggType == AggregateType.HLL_UNION) {
-                        return function instanceof HllFunction;
-                    }
-                    if (aggType == AggregateType.BITMAP_UNION) {
-                        return function instanceof BitmapFunction;
-                    }
-                    return functionName.equalsIgnoreCase(aggType.name());
-                }
-            }
-            return false;
-        }
-        List<Expression> children = expression.children();
-        return children.stream().allMatch(child -> aggTypeMatch(function, 
child));
-    }
-
-    /**
-     * check if the columns in expr is key column or constant, if group by 
clause contains value column, need rewrite
-     *
-     * @param expr expr to check
-     * @return true if all columns is key column or constant
-     */
-    private boolean isKeyOrConstantExpr(Expression expr) {
-        if (expr instanceof SlotReference && ((SlotReference) 
expr).getColumn().isPresent()) {
-            Column col = ((SlotReference) expr).getColumn().get();
-            return col.isKey();
-        } else if (expr.isConstant()) {
-            return true;
-        }
-        List<Expression> children = expr.children();
-        return children.stream().allMatch(this::isKeyOrConstantExpr);
-    }
-}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java
index 94f7c36b108..4beed413d09 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java
@@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy;
 import 
org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy.RelatedPolicy;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
@@ -49,12 +50,23 @@ public class CheckPolicy implements AnalysisRuleFactory {
                         logicalCheckPolicy(any().when(child -> !(child 
instanceof UnboundRelation))).thenApply(ctx -> {
                             LogicalCheckPolicy<Plan> checkPolicy = ctx.root;
                             LogicalFilter<Plan> upperFilter = null;
+                            Plan upAgg = null;
 
                             Plan child = checkPolicy.child();
                             // Because the unique table will automatically 
include a filter condition
-                            if (child instanceof LogicalFilter && 
child.bound() && child
-                                    .child(0) instanceof LogicalRelation) {
+                            if ((child instanceof LogicalFilter) && 
child.bound()) {
                                 upperFilter = (LogicalFilter) child;
+                                if (child.child(0) instanceof LogicalRelation) 
{
+                                    child = child.child(0);
+                                } else if (child.child(0) instanceof 
LogicalAggregate
+                                        && child.child(0).child(0) instanceof 
LogicalRelation) {
+                                    upAgg = child.child(0);
+                                    child = child.child(0).child(0);
+                                }
+                            }
+                            if ((child instanceof LogicalAggregate)
+                                    && child.bound() && child.child(0) 
instanceof LogicalRelation) {
+                                upAgg = child;
                                 child = child.child(0);
                             }
                             if (!(child instanceof LogicalRelation)
@@ -76,16 +88,17 @@ public class CheckPolicy implements AnalysisRuleFactory {
                             RelatedPolicy relatedPolicy = 
checkPolicy.findPolicy(relation, ctx.cascadesContext);
                             relatedPolicy.rowPolicyFilter.ifPresent(expression 
-> combineFilter.addAll(
                                             
ExpressionUtils.extractConjunctionToSet(expression)));
-                            Plan result = relation;
+                            Plan result = upAgg != null ? 
upAgg.withChildren(relation) : relation;
                             if (upperFilter != null) {
                                 
combineFilter.addAll(upperFilter.getConjuncts());
                             }
                             if (!combineFilter.isEmpty()) {
-                                result = new LogicalFilter<>(combineFilter, 
relation);
+                                result = new LogicalFilter<>(combineFilter, 
result);
                             }
                             if (relatedPolicy.dataMaskProjects.isPresent()) {
                                 result = new 
LogicalProject<>(relatedPolicy.dataMaskProjects.get(), result);
                             }
+
                             return result;
                         })
                 )
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindRelationTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindRelationTest.java
index b14834fd321..67115e67687 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindRelationTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindRelationTest.java
@@ -29,6 +29,7 @@ import org.apache.doris.nereids.rules.RulePromise;
 import 
org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver;
 import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
 import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanRewriter;
@@ -54,6 +55,12 @@ class BindRelationTest extends TestWithFeService implements 
GeneratedPlanPattern
                 + ")ENGINE=OLAP\n"
                 + "DISTRIBUTED BY HASH(`a`) BUCKETS 3\n"
                 + "PROPERTIES (\"replication_num\"= \"1\");");
+        createTable("CREATE TABLE db1.tagg ( \n"
+                + " \ta INT,\n"
+                + " \tb INT SUM\n"
+                + ")ENGINE=OLAP AGGREGATE KEY(a)\n "
+                + "DISTRIBUTED BY random BUCKETS 3\n"
+                + "PROPERTIES (\"replication_num\"= \"1\");");
         
connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
     }
 
@@ -125,6 +132,22 @@ class BindRelationTest extends TestWithFeService 
implements GeneratedPlanPattern
                 );
     }
 
+    @Test
+    void bindRandomAggTable() {
+        connectContext.setDatabase(DEFAULT_CLUSTER_PREFIX + DB1);
+        connectContext.getState().setIsQuery(true);
+        UnboundRelation relation = new UnboundRelation(StatementScopeIdGenerator.newRelationId(),
+                ImmutableList.of("tagg"));
+        Plan plan = PlanRewriter.bottomUpRewrite(relation, connectContext, new BindRelation());
+
+        // a random distributed AGG_KEYS table is bound as an aggregate over the olap scan
+        Assertions.assertTrue(plan instanceof LogicalAggregate);
+        // both the key slot and the aggregated value slot keep the fully qualified table name
+        ImmutableList<String> expectedQualifier = ImmutableList.of("internal", DEFAULT_CLUSTER_PREFIX + DB1, "tagg");
+        Assertions.assertEquals(expectedQualifier, plan.getOutput().get(0).getQualifier());
+        Assertions.assertEquals(expectedQualifier, plan.getOutput().get(1).getQualifier());
+    }
+
     @Override
     public RulePromise defaultPromise() {
         return RulePromise.REWRITE;
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckRowPolicyTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckRowPolicyTest.java
index 196d99037e2..b807bbbbc7a 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckRowPolicyTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckRowPolicyTest.java
@@ -34,6 +34,9 @@ import org.apache.doris.catalog.OlapTable;
 import org.apache.doris.catalog.PartitionInfo;
 import org.apache.doris.catalog.Type;
 import org.apache.doris.common.FeConstants;
+import org.apache.doris.mysql.privilege.AccessControllerManager;
+import org.apache.doris.mysql.privilege.DataMaskPolicy;
+import org.apache.doris.nereids.analyzer.UnboundRelation;
 import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
@@ -41,6 +44,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy;
 import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
 import org.apache.doris.nereids.util.PlanRewriter;
 import org.apache.doris.thrift.TStorageType;
@@ -48,17 +52,22 @@ import org.apache.doris.utframe.TestWithFeService;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
+import mockit.Mock;
+import mockit.MockUp;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 import java.util.Arrays;
 import java.util.List;
+import java.util.Optional;
 
 public class CheckRowPolicyTest extends TestWithFeService {
 
     private static String dbName = "check_row_policy";
     private static String fullDbName = "" + dbName;
     private static String tableName = "table1";
+
+    private static String tableNameRanddomDist = "tableRandomDist";
     private static String userName = "user1";
     private static String policyName = "policy1";
 
@@ -76,6 +85,10 @@ public class CheckRowPolicyTest extends TestWithFeService {
                 + tableName
                 + " (k1 int, k2 int) distributed by hash(k1) buckets 1"
                 + " properties(\"replication_num\" = \"1\");");
+        createTable("create table "
+                + tableNameRanddomDist
+                + " (k1 int, k2 int) AGGREGATE KEY(k1, k2) distributed by 
random buckets 1"
+                + " properties(\"replication_num\" = \"1\");");
         Database db = 
Env.getCurrentInternalCatalog().getDbOrMetaException(fullDbName);
         long tableId = db.getTableOrMetaException("table1").getId();
         olapTable.setId(tableId);
@@ -85,6 +98,7 @@ public class CheckRowPolicyTest extends TestWithFeService {
                 0, 0, (short) 0,
                 TStorageType.COLUMN,
                 KeysType.PRIMARY_KEYS);
+
         // create user
         UserIdentity user = new UserIdentity(userName, "%");
         user.analyze();
@@ -98,6 +112,27 @@ public class CheckRowPolicyTest extends TestWithFeService {
         Analyzer analyzer = new Analyzer(connectContext.getEnv(), 
connectContext);
         grantStmt.analyze(analyzer);
         Env.getCurrentEnv().getAuth().grant(grantStmt);
+
+        new MockUp<AccessControllerManager>() {
+            @Mock
+            public Optional<DataMaskPolicy> evalDataMaskPolicy(UserIdentity 
currentUser, String ctl,
+                    String db, String tbl, String col) {
+                return tbl.equalsIgnoreCase(tableNameRanddomDist)
+                        ? Optional.of(new DataMaskPolicy() {
+                            @Override
+                            public String getMaskTypeDef() {
+                                return String.format("concat(%s, '_****_', 
%s)", col, col);
+                            }
+
+                            @Override
+                            public String getPolicyIdent() {
+                                return String.format("custom policy: 
concat(%s, '_****_', %s)", col,
+                                        col);
+                            }
+                        })
+                        : Optional.empty();
+            }
+        };
     }
 
     @Test
@@ -115,6 +150,24 @@ public class CheckRowPolicyTest extends TestWithFeService {
         Assertions.assertEquals(plan, relation);
     }
 
+    @Test
+    public void checkUserRandomDist() throws AnalysisException, 
org.apache.doris.common.AnalysisException {
+        connectContext.getState().setIsQuery(true);
+        Plan plan = PlanRewriter.bottomUpRewrite(new 
UnboundRelation(StatementScopeIdGenerator.newRelationId(),
+                        ImmutableList.of(tableNameRanddomDist)), 
connectContext, new BindRelation());
+        LogicalCheckPolicy checkPolicy = new LogicalCheckPolicy(plan);
+
+        useUser("root");
+        Plan rewrittenPlan = PlanRewriter.bottomUpRewrite(checkPolicy, 
connectContext, new CheckPolicy(),
+                new BindExpression());
+        Assertions.assertEquals(plan, rewrittenPlan);
+
+        useUser("notFound");
+        rewrittenPlan = PlanRewriter.bottomUpRewrite(checkPolicy, 
connectContext, new CheckPolicy(),
+                new BindExpression());
+        Assertions.assertEquals(plan, rewrittenPlan.child(0));
+    }
+
     @Test
     public void checkNoPolicy() throws 
org.apache.doris.common.AnalysisException {
         useUser(userName);
@@ -125,6 +178,18 @@ public class CheckRowPolicyTest extends TestWithFeService {
         Assertions.assertEquals(plan, relation);
     }
 
+    @Test
+    public void checkNoPolicyRandomDist() throws 
org.apache.doris.common.AnalysisException {
+        useUser(userName);
+        connectContext.getState().setIsQuery(true);
+        Plan plan = PlanRewriter.bottomUpRewrite(new 
UnboundRelation(StatementScopeIdGenerator.newRelationId(),
+                ImmutableList.of(tableNameRanddomDist)), connectContext, new 
BindRelation());
+        LogicalCheckPolicy checkPolicy = new LogicalCheckPolicy(plan);
+        Plan rewrittenPlan = PlanRewriter.bottomUpRewrite(checkPolicy, 
connectContext, new CheckPolicy(),
+                new BindExpression());
+        Assertions.assertEquals(plan, rewrittenPlan.child(0));
+    }
+
     @Test
     public void checkOnePolicy() throws Exception {
         useUser(userName);
@@ -152,4 +217,36 @@ public class CheckRowPolicyTest extends TestWithFeService {
                 + " ON "
                 + tableName);
     }
+
+    @Test
+    public void checkOnePolicyRandomDist() throws Exception {
+        useUser(userName);
+        connectContext.getState().setIsQuery(true);
+        Plan plan = PlanRewriter.bottomUpRewrite(new 
UnboundRelation(StatementScopeIdGenerator.newRelationId(),
+                ImmutableList.of(tableNameRanddomDist)), connectContext, new 
BindRelation());
+
+        LogicalCheckPolicy checkPolicy = new LogicalCheckPolicy(plan);
+        connectContext.getSessionVariable().setEnableNereidsPlanner(true);
+        createPolicy("CREATE ROW POLICY "
+                + policyName
+                + " ON "
+                + tableNameRanddomDist
+                + " AS PERMISSIVE TO "
+                + userName
+                + " USING (k1 = 1)");
+        Plan rewrittenPlan = PlanRewriter.bottomUpRewrite(checkPolicy, 
connectContext, new CheckPolicy(),
+                new BindExpression());
+
+        Assertions.assertTrue(rewrittenPlan instanceof LogicalProject
+                && rewrittenPlan.child(0) instanceof LogicalFilter);
+        LogicalFilter filter = (LogicalFilter) rewrittenPlan.child(0);
+        Assertions.assertEquals(filter.child(), plan);
+        
Assertions.assertTrue(ImmutableList.copyOf(filter.getConjuncts()).get(0) 
instanceof EqualTo);
+        Assertions.assertTrue(filter.getConjuncts().toString().contains("k1#0 
= 1"));
+
+        dropPolicy("DROP ROW POLICY "
+                + policyName
+                + " ON "
+                + tableNameRanddomDist);
+    }
 }
diff --git a/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out b/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out
index c03e72c8f9e..eb099225960 100644
--- a/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out
+++ b/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out
@@ -217,13 +217,25 @@
 
 -- !sql_17 --
 1
+3
 
 -- !sql_18 --
 1
+3
 
 -- !sql_19 --
-1
+999999999999999.99
+1999999999999999.98
 
 -- !sql_20 --
 1
+3
+
+-- !sql_21 --
+1
+3
+
+-- !sql_22 --
+999999999999999.99
+1999999999999999.98
 
diff --git a/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy b/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy
index c818454c261..5c99a0a4aa0 100644
--- a/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy
+++ b/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy
@@ -123,7 +123,8 @@ suite("select_random_distributed_tbl") {
     // test all keys are NOT NULL for AGG table
     sql "drop table if exists random_distributed_tbl_test_2;"
     sql """ CREATE TABLE random_distributed_tbl_test_2 (
-        `k1` LARGEINT NOT NULL
+        `k1` LARGEINT NOT NULL,
+        `k2` DECIMAL(18, 2) SUM NOT NULL
     ) ENGINE=OLAP
     AGGREGATE KEY(`k1`)
     COMMENT 'OLAP'
@@ -133,17 +134,19 @@ suite("select_random_distributed_tbl") {
     );
     """
 
-    sql """ insert into random_distributed_tbl_test_2 values(1); """
-    sql """ insert into random_distributed_tbl_test_2 values(1); """
-    sql """ insert into random_distributed_tbl_test_2 values(1); """
+    sql """ insert into random_distributed_tbl_test_2 values(1, 
999999999999999.99); """
+    sql """ insert into random_distributed_tbl_test_2 values(1, 
999999999999999.99); """
+    sql """ insert into random_distributed_tbl_test_2 values(3, 
999999999999999.99); """
 
     sql "set enable_nereids_planner = false;"
-    qt_sql_17 "select k1 from random_distributed_tbl_test_2;"
-    qt_sql_18 "select distinct k1 from random_distributed_tbl_test_2;"
+    qt_sql_17 "select k1 from random_distributed_tbl_test_2 order by k1;"
+    qt_sql_18 "select distinct k1 from random_distributed_tbl_test_2 order by 
k1;"
+    qt_sql_19 "select k2 from random_distributed_tbl_test_2 order by k2;"
 
     sql "set enable_nereids_planner = true;"
-    qt_sql_19 "select k1 from random_distributed_tbl_test_2;"
-    qt_sql_20 "select distinct k1 from random_distributed_tbl_test_2;"
+    qt_sql_20 "select k1 from random_distributed_tbl_test_2 order by k1;"
+    qt_sql_21 "select distinct k1 from random_distributed_tbl_test_2 order by 
k1;"
+    qt_sql_22 "select k2 from random_distributed_tbl_test_2 order by k2;"
 
     sql "drop table random_distributed_tbl_test_2;"
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org


Reply via email to