This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 7d3a3fee6571210ac3e02ac218d06b2fa27256af Author: morrySnow <101034200+morrys...@users.noreply.github.com> AuthorDate: Thu Jan 18 14:25:34 2024 +0800 [fix](Nereids) update assignment column name should case insensitive (#30071) --- .../doris/nereids/rules/analysis/SlotBinder.java | 10 +++---- .../trees/plans/commands/UpdateCommand.java | 33 +++++++++++++++++++++- .../nereids_p0/update/update_unique_table.groovy | 14 ++++++++- 3 files changed, 50 insertions(+), 7 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SlotBinder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SlotBinder.java index 6f5f11b0a77..e25aa202627 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SlotBinder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SlotBinder.java @@ -262,10 +262,6 @@ public class SlotBinder extends SubExprAnalyzer { return new BoundStar(slots); } - private boolean compareDbName(String unBoundDbName, String boundedDbName) { - return unBoundDbName.equalsIgnoreCase(boundedDbName); - } - private List<Slot> bindSlot(UnboundSlot unboundSlot, List<Slot> boundSlots) { return boundSlots.stream().distinct().filter(boundSlot -> { List<String> nameParts = unboundSlot.getNameParts(); @@ -305,7 +301,11 @@ public class SlotBinder extends SubExprAnalyzer { }).collect(Collectors.toList()); } - private boolean sameTableName(String boundSlot, String unboundSlot) { + public static boolean compareDbName(String boundedDbName, String unBoundDbName) { + return unBoundDbName.equalsIgnoreCase(boundedDbName); + } + + public static boolean sameTableName(String boundSlot, String unboundSlot) { if (GlobalVariable.lowerCaseTableNames != 1) { return boundSlot.equals(unboundSlot); } else { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/UpdateCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/UpdateCommand.java index 6949b124e5d..0023b4bd59c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/UpdateCommand.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/UpdateCommand.java @@ -27,6 +27,7 @@ import org.apache.doris.nereids.analyzer.UnboundSlot; import org.apache.doris.nereids.analyzer.UnboundTableSink; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.parser.NereidsParser; +import org.apache.doris.nereids.rules.analysis.SlotBinder; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -101,9 +102,10 @@ public class UpdateCommand extends Command implements ForwardWithSync, Explainab public LogicalPlan completeQueryPlan(ConnectContext ctx, LogicalPlan logicalQuery) { checkTable(ctx); - Map<String, Expression> colNameToExpression = Maps.newHashMap(); + Map<String, Expression> colNameToExpression = Maps.newTreeMap(String.CASE_INSENSITIVE_ORDER); for (EqualTo equalTo : assignments) { List<String> nameParts = ((UnboundSlot) equalTo.left()).getNameParts(); + checkAssignmentColumn(ctx, nameParts); colNameToExpression.put(nameParts.get(nameParts.size() - 1), equalTo.right()); } List<NamedExpression> selectItems = Lists.newArrayList(); @@ -118,6 +120,7 @@ public class UpdateCommand extends Command implements ForwardWithSync, Explainab selectItems.add(expr instanceof UnboundSlot ? ((NamedExpression) expr) : new UnboundAlias(expr)); + colNameToExpression.remove(column.getName()); } else { if (column.hasOnUpdateDefaultValue()) { Expression defualtValueExpression = @@ -129,6 +132,10 @@ public class UpdateCommand extends Command implements ForwardWithSync, Explainab } } } + if (!colNameToExpression.isEmpty()) { + throw new AnalysisException("unknown column in assignment list: " + + String.join(", ", colNameToExpression.keySet())); + } logicalQuery = new LogicalProject<>(selectItems, logicalQuery); if (cte.isPresent()) { @@ -143,6 +150,30 @@ public class UpdateCommand extends Command implements ForwardWithSync, Explainab false, ImmutableList.of(), isPartialUpdate, DMLCommandType.UPDATE, logicalQuery); } + private void checkAssignmentColumn(ConnectContext ctx, List<String> columnNameParts) { + if (columnNameParts.size() <= 1) { + return; + } + String dbName = null; + String tableName = null; + if (columnNameParts.size() == 3) { + dbName = columnNameParts.get(0); + tableName = columnNameParts.get(1); + } else if (columnNameParts.size() == 2) { + tableName = columnNameParts.get(0); + } else { + throw new AnalysisException("column in assignment list is invalid, " + String.join(".", columnNameParts)); + } + if (dbName != null && this.tableAlias != null) { + throw new AnalysisException("column in assignment list is invalid, " + String.join(".", columnNameParts)); + } + List<String> tableQualifier = RelationUtil.getQualifierName(ctx, nameParts); + if (!SlotBinder.sameTableName(tableAlias == null ? tableQualifier.get(2) : tableAlias, tableName) + || (dbName != null && SlotBinder.compareDbName(tableQualifier.get(1), dbName))) { + throw new AnalysisException("column in assignment list is invalid, " + String.join(".", columnNameParts)); + } + } + private void checkTable(ConnectContext ctx) { if (ctx.getSessionVariable().isInDebugMode()) { throw new AnalysisException("Update is forbidden since current session is in debug mode." diff --git a/regression-test/suites/nereids_p0/update/update_unique_table.groovy b/regression-test/suites/nereids_p0/update/update_unique_table.groovy index 59ea06b10b8..8689f16f9e8 100644 --- a/regression-test/suites/nereids_p0/update/update_unique_table.groovy +++ b/regression-test/suites/nereids_p0/update/update_unique_table.groovy @@ -95,10 +95,22 @@ suite('update_unique_table') { sql ''' update t1 - set t1.c1 = t2.c1, t1.c3 = t2.c3 * 100 + set t1.C1 = t2.c1, t1.c3 = t2.c3 * 100 from t2 inner join t3 on t2.id = t3.id where t1.id = t2.id; ''' qt_sql 'select * from t1 order by id' + + test { + sql '''update t1 set t.c1 = 1 where t1.c1 = 1;''' + exception "" + } + + test { + sql '''update t1 t set t1.c1 = 1 where t1.c1 = 1;''' + exception "" + } + + } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org