This is an automated email from the ASF dual-hosted git repository. morrysnow pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new dca6b8175d7 [fix](nereids) build agg for random distributed agg table in bindRelation phase (#40181) dca6b8175d7 is described below commit dca6b8175d7557b0b91c7a5c97d2656b02dced6f Author: starocean999 <40539150+starocean...@users.noreply.github.com> AuthorDate: Wed Sep 11 18:16:03 2024 +0800 [fix](nereids) build agg for random distributed agg table in bindRelation phase (#40181) it's better to build agg for random distributed agg table in bindRelation phase instead of in BuildAggForRandomDistributedTable RULE --- .../doris/nereids/jobs/executor/Analyzer.java | 3 - .../org/apache/doris/nereids/rules/RuleType.java | 4 - .../doris/nereids/rules/analysis/BindRelation.java | 125 +++++++++- .../BuildAggForRandomDistributedTable.java | 271 --------------------- .../doris/nereids/rules/analysis/CheckPolicy.java | 21 +- .../nereids/rules/analysis/BindRelationTest.java | 23 ++ .../nereids/rules/analysis/CheckRowPolicyTest.java | 97 ++++++++ .../aggregate/select_random_distributed_tbl.out | 14 +- .../aggregate/select_random_distributed_tbl.groovy | 19 +- 9 files changed, 285 insertions(+), 292 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java index 605a848181c..1ffbac97d74 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java @@ -26,7 +26,6 @@ import org.apache.doris.nereids.rules.analysis.BindExpression; import org.apache.doris.nereids.rules.analysis.BindRelation; import org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver; import org.apache.doris.nereids.rules.analysis.BindSink; -import org.apache.doris.nereids.rules.analysis.BuildAggForRandomDistributedTable; import org.apache.doris.nereids.rules.analysis.CheckAfterBind; import org.apache.doris.nereids.rules.analysis.CheckAnalysis; import org.apache.doris.nereids.rules.analysis.CheckPolicy; @@ -163,8 +162,6 @@ public class Analyzer extends AbstractBatchJobExecutor { topDown(new EliminateGroupByConstant()), topDown(new SimplifyAggGroupBy()), - // run BuildAggForRandomDistributedTable before NormalizeAggregate in order to optimize the agg plan - topDown(new BuildAggForRandomDistributedTable()), topDown(new NormalizeAggregate()), topDown(new HavingToFilter()), bottomUp(new SemiJoinCommute()), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index b6ab6e2dac2..d345d9057e9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -342,10 +342,6 @@ public enum RuleType { // topn opts DEFER_MATERIALIZE_TOP_N_RESULT(RuleTypeClass.REWRITE), - // pre agg for random distributed table - BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_PROJECT_SCAN(RuleTypeClass.REWRITE), - BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_FILTER_SCAN(RuleTypeClass.REWRITE), - BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_AGG_SCAN(RuleTypeClass.REWRITE), // short circuit rule SHOR_CIRCUIT_POINT_QUERY(RuleTypeClass.REWRITE), // exploration rules diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java index cedc4e92ff1..c81fcc25b83 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java @@ -17,10 +17,17 @@ package org.apache.doris.nereids.rules.analysis; +import org.apache.doris.catalog.AggStateType; +import org.apache.doris.catalog.AggregateType; import org.apache.doris.catalog.Column; +import org.apache.doris.catalog.DistributionInfo; +import org.apache.doris.catalog.Env; +import org.apache.doris.catalog.FunctionRegistry; +import org.apache.doris.catalog.KeysType; import org.apache.doris.catalog.OlapTable; import org.apache.doris.catalog.Partition; import org.apache.doris.catalog.TableIf; +import org.apache.doris.catalog.Type; import org.apache.doris.catalog.View; import org.apache.doris.common.Config; import org.apache.doris.common.Pair; @@ -43,13 +50,26 @@ import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder; +import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder; +import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion; +import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.PreAggStatus; import org.apache.doris.nereids.trees.plans.algebra.Relation; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; import org.apache.doris.nereids.trees.plans.logical.LogicalEsScan; import org.apache.doris.nereids.trees.plans.logical.LogicalFileScan; @@ -73,6 +93,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Sets; import org.apache.commons.collections.CollectionUtils; +import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.function.Function; @@ -214,7 +235,109 @@ public class BindRelation extends OneAnalysisRuleFactory { unboundRelation.getTableSample()); } } - return checkAndAddDeleteSignFilter(scan, ConnectContext.get(), (OlapTable) table); + if (needGenerateLogicalAggForRandomDistAggTable(scan)) { + // it's a random distribution agg table + // add agg on olap scan + return preAggForRandomDistribution(scan); + } else { + // it's a duplicate, unique or hash distribution agg table + // add delete sign filter on olap scan if needed + return checkAndAddDeleteSignFilter(scan, ConnectContext.get(), (OlapTable) table); + } + } + + private boolean needGenerateLogicalAggForRandomDistAggTable(LogicalOlapScan olapScan) { + if (ConnectContext.get() != null && ConnectContext.get().getState() != null + && ConnectContext.get().getState().isQuery()) { + // we only need to add an agg node for query, and should not do it for deleting + // from random distributed table. see https://github.com/apache/doris/pull/37985 for more info + OlapTable olapTable = olapScan.getTable(); + KeysType keysType = olapTable.getKeysType(); + DistributionInfo distributionInfo = olapTable.getDefaultDistributionInfo(); + return keysType == KeysType.AGG_KEYS + && distributionInfo.getType() == DistributionInfo.DistributionInfoType.RANDOM; + } else { + return false; + } + } + + /** + * add LogicalAggregate above olapScan for preAgg + * @param olapScan olap scan plan + * @return rewritten plan + */ + private LogicalPlan preAggForRandomDistribution(LogicalOlapScan olapScan) { + OlapTable olapTable = olapScan.getTable(); + List<Slot> childOutputSlots = olapScan.computeOutput(); + List<Expression> groupByExpressions = new ArrayList<>(); + List<NamedExpression> outputExpressions = new ArrayList<>(); + List<Column> columns = olapTable.getBaseSchema(); + + for (Column col : columns) { + // use exist slot in the plan + SlotReference slot = SlotReference.fromColumn(olapTable, col, col.getName(), olapScan.qualified()); + ExprId exprId = slot.getExprId(); + for (Slot childSlot : childOutputSlots) { + if (childSlot instanceof SlotReference && ((SlotReference) childSlot).getName() == col.getName()) { + exprId = childSlot.getExprId(); + slot = slot.withExprId(exprId); + break; + } + } + if (col.isKey()) { + groupByExpressions.add(slot); + outputExpressions.add(slot); + } else { + Expression function = generateAggFunction(slot, col); + // DO NOT rewrite + if (function == null) { + return olapScan; + } + Alias alias = new Alias(exprId, ImmutableList.of(function), col.getName(), + olapScan.qualified(), true); + outputExpressions.add(alias); + } + } + LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(groupByExpressions, outputExpressions, + olapScan); + return aggregate; + } + + /** + * generate aggregation function according to the aggType of column + * + * @param slot slot of column + * @return aggFunction generated + */ + private Expression generateAggFunction(SlotReference slot, Column column) { + AggregateType aggregateType = column.getAggregationType(); + switch (aggregateType) { + case SUM: + return new Sum(slot); + case MAX: + return new Max(slot); + case MIN: + return new Min(slot); + case HLL_UNION: + return new HllUnion(slot); + case BITMAP_UNION: + return new BitmapUnion(slot); + case QUANTILE_UNION: + return new QuantileUnion(slot); + case GENERIC: + Type type = column.getType(); + if (!type.isAggStateType()) { + return null; + } + AggStateType aggState = (AggStateType) type; + // use AGGREGATE_FUNCTION_UNION to aggregate multiple agg_state into one + String funcName = aggState.getFunctionName() + AggCombinerFunctionBuilder.UNION_SUFFIX; + FunctionRegistry functionRegistry = Env.getCurrentEnv().getFunctionRegistry(); + FunctionBuilder builder = functionRegistry.findFunctionBuilder(funcName, slot); + return builder.build(funcName, ImmutableList.of(slot)).first; + default: + return null; + } } /** diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BuildAggForRandomDistributedTable.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BuildAggForRandomDistributedTable.java deleted file mode 100644 index e547a55f9e3..00000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BuildAggForRandomDistributedTable.java +++ /dev/null @@ -1,271 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.analysis; - -import org.apache.doris.catalog.AggStateType; -import org.apache.doris.catalog.AggregateType; -import org.apache.doris.catalog.Column; -import org.apache.doris.catalog.DistributionInfo; -import org.apache.doris.catalog.DistributionInfo.DistributionInfoType; -import org.apache.doris.catalog.Env; -import org.apache.doris.catalog.FunctionRegistry; -import org.apache.doris.catalog.KeysType; -import org.apache.doris.catalog.OlapTable; -import org.apache.doris.catalog.Type; -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.ExprId; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder; -import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder; -import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; -import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapFunction; -import org.apache.doris.nereids.trees.expressions.functions.agg.BitmapUnion; -import org.apache.doris.nereids.trees.expressions.functions.agg.Count; -import org.apache.doris.nereids.trees.expressions.functions.agg.HllFunction; -import org.apache.doris.nereids.trees.expressions.functions.agg.HllUnion; -import org.apache.doris.nereids.trees.expressions.functions.agg.Max; -import org.apache.doris.nereids.trees.expressions.functions.agg.Min; -import org.apache.doris.nereids.trees.expressions.functions.agg.QuantileUnion; -import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; -import org.apache.doris.qe.ConnectContext; - -import com.google.common.collect.ImmutableList; - -import java.util.ArrayList; -import java.util.List; -import java.util.Set; - -/** - * build agg plan for querying random distributed table - */ -public class BuildAggForRandomDistributedTable implements AnalysisRuleFactory { - - @Override - public List<Rule> buildRules() { - return ImmutableList.of( - // Project(Scan) -> project(agg(scan)) - logicalProject(logicalOlapScan()) - .when(this::isQuery) - .when(project -> isRandomDistributedTbl(project.child())) - .then(project -> preAggForRandomDistribution(project, project.child())) - .toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_PROJECT_SCAN), - // agg(scan) -> agg(agg(scan)), agg(agg) may optimized by MergeAggregate - logicalAggregate(logicalOlapScan()) - .when(this::isQuery) - .when(agg -> isRandomDistributedTbl(agg.child())) - .whenNot(agg -> { - Set<AggregateFunction> functions = agg.getAggregateFunctions(); - List<Expression> groupByExprs = agg.getGroupByExpressions(); - // check if need generate an inner agg plan or not - // should not rewrite twice if we had rewritten olapScan to aggregate(olapScan) - return functions.stream().allMatch(this::aggTypeMatch) && groupByExprs.stream() - .allMatch(this::isKeyOrConstantExpr); - }) - .then(agg -> preAggForRandomDistribution(agg, agg.child())) - .toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_AGG_SCAN), - // filter(scan) -> filter(agg(scan)) - logicalFilter(logicalOlapScan()) - .when(this::isQuery) - .when(filter -> isRandomDistributedTbl(filter.child())) - .then(filter -> preAggForRandomDistribution(filter, filter.child())) - .toRule(RuleType.BUILD_AGG_FOR_RANDOM_DISTRIBUTED_TABLE_FILTER_SCAN)); - - } - - /** - * check the olapTable of olapScan is randomDistributed table - * - * @param olapScan olap scan plan - * @return true if olapTable is randomDistributed table - */ - private boolean isRandomDistributedTbl(LogicalOlapScan olapScan) { - OlapTable olapTable = olapScan.getTable(); - KeysType keysType = olapTable.getKeysType(); - DistributionInfo distributionInfo = olapTable.getDefaultDistributionInfo(); - return keysType == KeysType.AGG_KEYS && distributionInfo.getType() == DistributionInfoType.RANDOM; - } - - private boolean isQuery(LogicalPlan plan) { - return ConnectContext.get() != null - && ConnectContext.get().getState() != null - && ConnectContext.get().getState().isQuery(); - } - - /** - * add LogicalAggregate above olapScan for preAgg - * - * @param logicalPlan parent plan of olapScan - * @param olapScan olap scan plan, it may be LogicalProject, LogicalFilter, LogicalAggregate - * @return rewritten plan - */ - private Plan preAggForRandomDistribution(LogicalPlan logicalPlan, LogicalOlapScan olapScan) { - OlapTable olapTable = olapScan.getTable(); - List<Slot> childOutputSlots = olapScan.computeOutput(); - List<Expression> groupByExpressions = new ArrayList<>(); - List<NamedExpression> outputExpressions = new ArrayList<>(); - List<Column> columns = olapTable.getBaseSchema(); - - for (Column col : columns) { - // use exist slot in the plan - SlotReference slot = SlotReference.fromColumn(olapTable, col, col.getName(), olapScan.getQualifier()); - ExprId exprId = slot.getExprId(); - for (Slot childSlot : childOutputSlots) { - if (childSlot instanceof SlotReference && ((SlotReference) childSlot).getName() == col.getName()) { - exprId = childSlot.getExprId(); - slot = slot.withExprId(exprId); - break; - } - } - if (col.isKey()) { - groupByExpressions.add(slot); - outputExpressions.add(slot); - } else { - Expression function = generateAggFunction(slot, col); - // DO NOT rewrite - if (function == null) { - return logicalPlan; - } - Alias alias = new Alias(exprId, function, col.getName()); - outputExpressions.add(alias); - } - } - LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(groupByExpressions, outputExpressions, - olapScan); - return logicalPlan.withChildren(aggregate); - } - - /** - * generate aggregation function according to the aggType of column - * - * @param slot slot of column - * @return aggFunction generated - */ - private Expression generateAggFunction(SlotReference slot, Column column) { - AggregateType aggregateType = column.getAggregationType(); - switch (aggregateType) { - case SUM: - return new Sum(slot); - case MAX: - return new Max(slot); - case MIN: - return new Min(slot); - case HLL_UNION: - return new HllUnion(slot); - case BITMAP_UNION: - return new BitmapUnion(slot); - case QUANTILE_UNION: - return new QuantileUnion(slot); - case GENERIC: - Type type = column.getType(); - if (!type.isAggStateType()) { - return null; - } - AggStateType aggState = (AggStateType) type; - // use AGGREGATE_FUNCTION_UNION to aggregate multiple agg_state into one - String funcName = aggState.getFunctionName() + AggCombinerFunctionBuilder.UNION_SUFFIX; - FunctionRegistry functionRegistry = Env.getCurrentEnv().getFunctionRegistry(); - FunctionBuilder builder = functionRegistry.findFunctionBuilder(funcName, slot); - return builder.build(funcName, ImmutableList.of(slot)).first; - default: - return null; - } - } - - /** - * if the agg type of AggregateFunction is as same as the agg type of column, DO NOT need to rewrite - * - * @param function agg function to check - * @return true if agg type match - */ - private boolean aggTypeMatch(AggregateFunction function) { - List<Expression> children = function.children(); - if (function.getName().equalsIgnoreCase("count")) { - Count count = (Count) function; - // do not rewrite for count distinct for key column - if (count.isDistinct()) { - return children.stream().allMatch(this::isKeyOrConstantExpr); - } - if (count.isStar()) { - return false; - } - } - return children.stream().allMatch(child -> aggTypeMatch(function, child)); - } - - /** - * check if the agg type of functionCall match the agg type of column - * - * @param function the functionCall - * @param expression expr to check - * @return true if agg type match - */ - private boolean aggTypeMatch(AggregateFunction function, Expression expression) { - if (expression.children().isEmpty()) { - if (expression instanceof SlotReference && ((SlotReference) expression).getColumn().isPresent()) { - Column col = ((SlotReference) expression).getColumn().get(); - String functionName = function.getName(); - if (col.isKey()) { - return functionName.equalsIgnoreCase("max") || functionName.equalsIgnoreCase("min"); - } - if (col.isAggregated()) { - AggregateType aggType = col.getAggregationType(); - // agg type not mach - if (aggType == AggregateType.GENERIC) { - return col.getType().isAggStateType(); - } - if (aggType == AggregateType.HLL_UNION) { - return function instanceof HllFunction; - } - if (aggType == AggregateType.BITMAP_UNION) { - return function instanceof BitmapFunction; - } - return functionName.equalsIgnoreCase(aggType.name()); - } - } - return false; - } - List<Expression> children = expression.children(); - return children.stream().allMatch(child -> aggTypeMatch(function, child)); - } - - /** - * check if the columns in expr is key column or constant, if group by clause contains value column, need rewrite - * - * @param expr expr to check - * @return true if all columns is key column or constant - */ - private boolean isKeyOrConstantExpr(Expression expr) { - if (expr instanceof SlotReference && ((SlotReference) expr).getColumn().isPresent()) { - Column col = ((SlotReference) expr).getColumn().get(); - return col.isKey(); - } else if (expr.isConstant()) { - return true; - } - List<Expression> children = expr.children(); - return children.stream().allMatch(this::isKeyOrConstantExpr); - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java index 94f7c36b108..4beed413d09 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy; import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy.RelatedPolicy; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; @@ -49,12 +50,23 @@ public class CheckPolicy implements AnalysisRuleFactory { logicalCheckPolicy(any().when(child -> !(child instanceof UnboundRelation))).thenApply(ctx -> { LogicalCheckPolicy<Plan> checkPolicy = ctx.root; LogicalFilter<Plan> upperFilter = null; + Plan upAgg = null; Plan child = checkPolicy.child(); // Because the unique table will automatically include a filter condition - if (child instanceof LogicalFilter && child.bound() && child - .child(0) instanceof LogicalRelation) { + if ((child instanceof LogicalFilter) && child.bound()) { upperFilter = (LogicalFilter) child; + if (child.child(0) instanceof LogicalRelation) { + child = child.child(0); + } else if (child.child(0) instanceof LogicalAggregate + && child.child(0).child(0) instanceof LogicalRelation) { + upAgg = child.child(0); + child = child.child(0).child(0); + } + } + if ((child instanceof LogicalAggregate) + && child.bound() && child.child(0) instanceof LogicalRelation) { + upAgg = child; child = child.child(0); } if (!(child instanceof LogicalRelation) @@ -76,16 +88,17 @@ public class CheckPolicy implements AnalysisRuleFactory { RelatedPolicy relatedPolicy = checkPolicy.findPolicy(relation, ctx.cascadesContext); relatedPolicy.rowPolicyFilter.ifPresent(expression -> combineFilter.addAll( ExpressionUtils.extractConjunctionToSet(expression))); - Plan result = relation; + Plan result = upAgg != null ? upAgg.withChildren(relation) : relation; if (upperFilter != null) { combineFilter.addAll(upperFilter.getConjuncts()); } if (!combineFilter.isEmpty()) { - result = new LogicalFilter<>(combineFilter, relation); + result = new LogicalFilter<>(combineFilter, result); } if (relatedPolicy.dataMaskProjects.isPresent()) { result = new LogicalProject<>(relatedPolicy.dataMaskProjects.get(), result); } + return result; }) ) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindRelationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindRelationTest.java index b14834fd321..67115e67687 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindRelationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/BindRelationTest.java @@ -29,6 +29,7 @@ import org.apache.doris.nereids.rules.RulePromise; import org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanRewriter; @@ -54,6 +55,12 @@ class BindRelationTest extends TestWithFeService implements GeneratedPlanPattern + ")ENGINE=OLAP\n" + "DISTRIBUTED BY HASH(`a`) BUCKETS 3\n" + "PROPERTIES (\"replication_num\"= \"1\");"); + createTable("CREATE TABLE db1.tagg ( \n" + + " \ta INT,\n" + + " \tb INT SUM\n" + + ")ENGINE=OLAP AGGREGATE KEY(a)\n " + + "DISTRIBUTED BY random BUCKETS 3\n" + + "PROPERTIES (\"replication_num\"= \"1\");"); connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); } @@ -125,6 +132,22 @@ class BindRelationTest extends TestWithFeService implements GeneratedPlanPattern ); } + @Test + void bindRandomAggTable() { + connectContext.setDatabase(DEFAULT_CLUSTER_PREFIX + DB1); + connectContext.getState().setIsQuery(true); + Plan plan = PlanRewriter.bottomUpRewrite(new UnboundRelation(StatementScopeIdGenerator.newRelationId(), ImmutableList.of("tagg")), + connectContext, new BindRelation()); + + Assertions.assertTrue(plan instanceof LogicalAggregate); + Assertions.assertEquals( + ImmutableList.of("internal", DEFAULT_CLUSTER_PREFIX + DB1, "tagg"), + plan.getOutput().get(0).getQualifier()); + Assertions.assertEquals( + ImmutableList.of("internal", DEFAULT_CLUSTER_PREFIX + DB1, "tagg"), + plan.getOutput().get(1).getQualifier()); + } + @Override public RulePromise defaultPromise() { return RulePromise.REWRITE; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckRowPolicyTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckRowPolicyTest.java index 196d99037e2..b807bbbbc7a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckRowPolicyTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/CheckRowPolicyTest.java @@ -34,6 +34,9 @@ import org.apache.doris.catalog.OlapTable; import org.apache.doris.catalog.PartitionInfo; import org.apache.doris.catalog.Type; import org.apache.doris.common.FeConstants; +import org.apache.doris.mysql.privilege.AccessControllerManager; +import org.apache.doris.mysql.privilege.DataMaskPolicy; +import org.apache.doris.nereids.analyzer.UnboundRelation; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; @@ -41,6 +44,7 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; import org.apache.doris.nereids.util.PlanRewriter; import org.apache.doris.thrift.TStorageType; @@ -48,17 +52,22 @@ import org.apache.doris.utframe.TestWithFeService; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import mockit.Mock; +import mockit.MockUp; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.util.Arrays; import java.util.List; +import java.util.Optional; public class CheckRowPolicyTest extends TestWithFeService { private static String dbName = "check_row_policy"; private static String fullDbName = "" + dbName; private static String tableName = "table1"; + + private static String tableNameRanddomDist = "tableRandomDist"; private static String userName = "user1"; private static String policyName = "policy1"; @@ -76,6 +85,10 @@ public class CheckRowPolicyTest extends TestWithFeService { + tableName + " (k1 int, k2 int) distributed by hash(k1) buckets 1" + " properties(\"replication_num\" = \"1\");"); + createTable("create table " + + tableNameRanddomDist + + " (k1 int, k2 int) AGGREGATE KEY(k1, k2) distributed by random buckets 1" + + " properties(\"replication_num\" = \"1\");"); Database db = Env.getCurrentInternalCatalog().getDbOrMetaException(fullDbName); long tableId = db.getTableOrMetaException("table1").getId(); olapTable.setId(tableId); @@ -85,6 +98,7 @@ public class CheckRowPolicyTest extends TestWithFeService { 0, 0, (short) 0, TStorageType.COLUMN, KeysType.PRIMARY_KEYS); + // create user UserIdentity user = new UserIdentity(userName, "%"); user.analyze(); @@ -98,6 +112,27 @@ public class CheckRowPolicyTest extends TestWithFeService { Analyzer analyzer = new Analyzer(connectContext.getEnv(), connectContext); grantStmt.analyze(analyzer); Env.getCurrentEnv().getAuth().grant(grantStmt); + + new MockUp<AccessControllerManager>() { + @Mock + public Optional<DataMaskPolicy> evalDataMaskPolicy(UserIdentity currentUser, String ctl, + String db, String tbl, String col) { + return tbl.equalsIgnoreCase(tableNameRanddomDist) + ? Optional.of(new DataMaskPolicy() { + @Override + public String getMaskTypeDef() { + return String.format("concat(%s, '_****_', %s)", col, col); + } + + @Override + public String getPolicyIdent() { + return String.format("custom policy: concat(%s, '_****_', %s)", col, + col); + } + }) + : Optional.empty(); + } + }; } @Test @@ -115,6 +150,24 @@ public class CheckRowPolicyTest extends TestWithFeService { Assertions.assertEquals(plan, relation); } + @Test + public void checkUserRandomDist() throws AnalysisException, org.apache.doris.common.AnalysisException { + connectContext.getState().setIsQuery(true); + Plan plan = PlanRewriter.bottomUpRewrite(new UnboundRelation(StatementScopeIdGenerator.newRelationId(), + ImmutableList.of(tableNameRanddomDist)), connectContext, new BindRelation()); + LogicalCheckPolicy checkPolicy = new LogicalCheckPolicy(plan); + + useUser("root"); + Plan rewrittenPlan = PlanRewriter.bottomUpRewrite(checkPolicy, connectContext, new CheckPolicy(), + new BindExpression()); + Assertions.assertEquals(plan, rewrittenPlan); + + useUser("notFound"); + rewrittenPlan = PlanRewriter.bottomUpRewrite(checkPolicy, connectContext, new CheckPolicy(), + new BindExpression()); + Assertions.assertEquals(plan, rewrittenPlan.child(0)); + } + @Test public void checkNoPolicy() throws org.apache.doris.common.AnalysisException { useUser(userName); @@ -125,6 +178,18 @@ public class CheckRowPolicyTest extends TestWithFeService { Assertions.assertEquals(plan, relation); } + @Test + public void checkNoPolicyRandomDist() throws org.apache.doris.common.AnalysisException { + useUser(userName); + connectContext.getState().setIsQuery(true); + Plan plan = PlanRewriter.bottomUpRewrite(new UnboundRelation(StatementScopeIdGenerator.newRelationId(), + ImmutableList.of(tableNameRanddomDist)), connectContext, new BindRelation()); + LogicalCheckPolicy checkPolicy = new LogicalCheckPolicy(plan); + Plan rewrittenPlan = PlanRewriter.bottomUpRewrite(checkPolicy, connectContext, new CheckPolicy(), + new BindExpression()); + Assertions.assertEquals(plan, rewrittenPlan.child(0)); + } + @Test public void checkOnePolicy() throws Exception { useUser(userName); @@ -152,4 +217,36 @@ public class CheckRowPolicyTest extends TestWithFeService { + " ON " + tableName); } + + @Test + public void checkOnePolicyRandomDist() throws Exception { + useUser(userName); + connectContext.getState().setIsQuery(true); + Plan plan = PlanRewriter.bottomUpRewrite(new UnboundRelation(StatementScopeIdGenerator.newRelationId(), + ImmutableList.of(tableNameRanddomDist)), connectContext, new BindRelation()); + + LogicalCheckPolicy checkPolicy = new LogicalCheckPolicy(plan); + connectContext.getSessionVariable().setEnableNereidsPlanner(true); + createPolicy("CREATE ROW POLICY " + + policyName + + " ON " + + tableNameRanddomDist + + " AS PERMISSIVE TO " + + userName + + " USING (k1 = 1)"); + Plan rewrittenPlan = PlanRewriter.bottomUpRewrite(checkPolicy, connectContext, new CheckPolicy(), + new BindExpression()); + + Assertions.assertTrue(rewrittenPlan instanceof LogicalProject + && rewrittenPlan.child(0) instanceof LogicalFilter); + LogicalFilter filter = (LogicalFilter) rewrittenPlan.child(0); + Assertions.assertEquals(filter.child(), plan); + Assertions.assertTrue(ImmutableList.copyOf(filter.getConjuncts()).get(0) instanceof EqualTo); + Assertions.assertTrue(filter.getConjuncts().toString().contains("k1#0 = 1")); + + dropPolicy("DROP ROW POLICY " + + policyName + + " ON " + + tableNameRanddomDist); + } } diff --git a/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out b/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out index c03e72c8f9e..eb099225960 100644 --- a/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out +++ b/regression-test/data/query_p0/aggregate/select_random_distributed_tbl.out @@ -217,13 +217,25 @@ -- !sql_17 -- 1 +3 -- !sql_18 -- 1 +3 -- !sql_19 -- -1 +999999999999999.99 +1999999999999999.98 -- !sql_20 -- 1 +3 + +-- !sql_21 -- +1 +3 + +-- !sql_22 -- +999999999999999.99 +1999999999999999.98 diff --git a/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy b/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy index c818454c261..5c99a0a4aa0 100644 --- a/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy +++ b/regression-test/suites/query_p0/aggregate/select_random_distributed_tbl.groovy @@ -123,7 +123,8 @@ suite("select_random_distributed_tbl") { // test all keys are NOT NULL for AGG table sql "drop table if exists random_distributed_tbl_test_2;" sql """ CREATE TABLE random_distributed_tbl_test_2 ( - `k1` LARGEINT NOT NULL + `k1` LARGEINT NOT NULL, + `k2` DECIMAL(18, 2) SUM NOT NULL ) ENGINE=OLAP AGGREGATE KEY(`k1`) COMMENT 'OLAP' @@ -133,17 +134,19 @@ suite("select_random_distributed_tbl") { ); """ - sql """ insert into random_distributed_tbl_test_2 values(1); """ - sql """ insert into random_distributed_tbl_test_2 values(1); """ - sql """ insert into random_distributed_tbl_test_2 values(1); """ + sql """ insert into random_distributed_tbl_test_2 values(1, 999999999999999.99); """ + sql """ insert into random_distributed_tbl_test_2 values(1, 999999999999999.99); """ + sql """ insert into random_distributed_tbl_test_2 values(3, 999999999999999.99); """ sql "set enable_nereids_planner = false;" - qt_sql_17 "select k1 from random_distributed_tbl_test_2;" - qt_sql_18 "select distinct k1 from random_distributed_tbl_test_2;" + qt_sql_17 "select k1 from random_distributed_tbl_test_2 order by k1;" + qt_sql_18 "select distinct k1 from random_distributed_tbl_test_2 order by k1;" + qt_sql_19 "select k2 from random_distributed_tbl_test_2 order by k2;" sql "set enable_nereids_planner = true;" - qt_sql_19 "select k1 from random_distributed_tbl_test_2;" - qt_sql_20 "select distinct k1 from random_distributed_tbl_test_2;" + qt_sql_20 "select k1 from random_distributed_tbl_test_2 order by k1;" + qt_sql_21 "select distinct k1 from random_distributed_tbl_test_2 order by k1;" + qt_sql_22 "select k2 from random_distributed_tbl_test_2 order by k2;" sql "drop table random_distributed_tbl_test_2;" } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org