qzsee commented on code in PR #12151: URL: https://github.com/apache/doris/pull/12151#discussion_r963205692
########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/FoldConstantRule.java: ########## @@ -0,0 +1,440 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rewrite.rules; + +import org.apache.doris.analysis.BetweenPredicate; +import org.apache.doris.analysis.CastExpr; +import org.apache.doris.analysis.Expr; +import org.apache.doris.analysis.ExprId; +import org.apache.doris.analysis.LiteralExpr; +import org.apache.doris.catalog.Env; +import org.apache.doris.catalog.PrimitiveType; +import org.apache.doris.catalog.Type; +import org.apache.doris.common.IdGenerator; +import org.apache.doris.common.LoadException; +import org.apache.doris.common.util.TimeUtils; +import org.apache.doris.common.util.VectorizedUtil; +import org.apache.doris.nereids.glue.translator.ExpressionTranslator; +import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext; +import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; +import org.apache.doris.nereids.trees.expressions.CaseWhen; 
+import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; +import org.apache.doris.nereids.trees.expressions.CompoundPredicate; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.ExpressionEvaluator; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; +import org.apache.doris.nereids.trees.expressions.InPredicate; +import org.apache.doris.nereids.trees.expressions.IsNull; +import org.apache.doris.nereids.trees.expressions.LessThan; +import org.apache.doris.nereids.trees.expressions.LessThanEqual; +import org.apache.doris.nereids.trees.expressions.Like; +import org.apache.doris.nereids.trees.expressions.Not; +import org.apache.doris.nereids.trees.expressions.NullSafeEqual; +import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.TimestampArithmetic; +import org.apache.doris.nereids.trees.expressions.WhenClause; +import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.proto.InternalService; +import org.apache.doris.proto.InternalService.PConstantExprResult; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.rpc.BackendServiceProxy; +import org.apache.doris.system.Backend; +import org.apache.doris.thrift.TExpr; +import org.apache.doris.thrift.TFoldConstantParams; +import org.apache.doris.thrift.TNetworkAddress; +import org.apache.doris.thrift.TPrimitiveType; +import org.apache.doris.thrift.TQueryGlobals; + +import 
org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; + +import java.text.DateFormat; +import java.text.SimpleDateFormat; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Date; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +/** + * Constant evaluation of an expression. + */ +public class FoldConstantRule extends AbstractExpressionRewriteRule { + public static final FoldConstantRule INSTANCE = new FoldConstantRule(); + private static final Logger LOG = LogManager.getLogger(FoldConstantRule.class); + private static final DateFormat DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); + private final IdGenerator<ExprId> idGenerator = ExprId.createGenerator(); + + private static boolean isAllLiteral(Expression... children) { + return Arrays.stream(children).allMatch(c -> c instanceof Literal); + } + + private static boolean isAllLiteral(List<Expression> children) { + return children.stream().allMatch(c -> c instanceof Literal); + } + + private static boolean hasNull(List<Expression> children) { + return children.stream().anyMatch(c -> c instanceof NullLiteral); + } + + private static boolean hasNull(Expression... 
children) { + return Arrays.stream(children).anyMatch(c -> c instanceof NullLiteral); + } + + @Override + public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) { + Expression expression = expr.accept(this, ctx); + if (ctx.connectContext != null && ctx.connectContext.getSessionVariable().isEnableFoldConstantByBe()) { + return foldByBe(expression, ctx); + } + return expression; + } + + @Override + public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) { + Expression left = rewrite(cp.left(), context); + Expression right = rewrite(cp.right(), context); + + if (!(cp instanceof NullSafeEqual) && hasNull(left, right)) { + return Literal.of(null); + } + + if (isAllLiteral(left, right)) { + Literal l = (Literal) left; + Literal r = (Literal) right; + if (cp instanceof EqualTo) { + return BooleanLiteral.of(l.compareTo(r) == 0); + } + if (cp instanceof GreaterThan) { + return BooleanLiteral.of(l.compareTo(r) > 0); + } + if (cp instanceof GreaterThanEqual) { + return BooleanLiteral.of(l.compareTo(r) >= 0); + } + if (cp instanceof LessThan) { + return BooleanLiteral.of(l.compareTo(r) < 0); + } + if (cp instanceof LessThanEqual) { + return BooleanLiteral.of(l.compareTo(r) <= 0); + } + if (cp instanceof NullSafeEqual) { + if (l.isNull() && r.isNull()) { + return BooleanLiteral.TRUE; + } else if (!l.isNull() && !r.isNull()) { + return BooleanLiteral.of(l.compareTo(r) == 0); + } else { + return BooleanLiteral.FALSE; + } + } + } + return cp.withChildren(left, right); + } + + @Override + public Expression visitNot(Not not, ExpressionRewriteContext context) { + Expression child = rewrite(not.child(), context); + if (child instanceof NullLiteral) { + return Literal.of(null); + } + if (child.isLiteral()) { + if (child instanceof BooleanLiteral) { + BooleanLiteral c = (BooleanLiteral) child; + return BooleanLiteral.of(!c.getValue()); + } + } + return not.withChildren(child); + } + + @Override + public Expression 
visitCompoundPredicate(CompoundPredicate compoundPredicate, ExpressionRewriteContext context) { + Expression left = rewrite(compoundPredicate.left(), context); + Expression right = rewrite(compoundPredicate.right(), context); + if (left instanceof NullLiteral && right instanceof NullLiteral) { + return Literal.of(null); + } + if (left instanceof Literal || right instanceof Literal) { + if (compoundPredicate instanceof And) { + if (left.equals(BooleanLiteral.FALSE) || right.equals(BooleanLiteral.FALSE)) { + return BooleanLiteral.FALSE; + } + if (left.equals(BooleanLiteral.TRUE)) { + return right; + } + if (right.equals(BooleanLiteral.TRUE)) { + return left; + } + } + if (compoundPredicate instanceof Or) { + if (left.equals(BooleanLiteral.TRUE) || right.equals(BooleanLiteral.TRUE)) { + return BooleanLiteral.TRUE; + } + if (left.equals(BooleanLiteral.FALSE)) { + return right; + } + if (right.equals(BooleanLiteral.FALSE)) { + return left; + } + } + } + return compoundPredicate.withChildren(left, right); + } + + @Override + public Expression visitLike(Like like, ExpressionRewriteContext context) { + Expression left = rewrite(like.left(), context); + if (left instanceof NullLiteral) { + return Literal.of(null); + } + return like.withChildren(left, like.right()); + } + + @Override + public Expression visitCast(Cast cast, ExpressionRewriteContext context) { + if (hasNull(cast.children())) { + return Literal.of(null); + } + Expression child = rewrite(cast.child(), context); + if (child.isLiteral()) { + return child.castTo(cast.getDataType()); + } + return cast; + } + + @Override + public Expression visitBoundFunction(BoundFunction boundFunction, ExpressionRewriteContext context) { + List<Expression> newArgs = boundFunction.getArguments().stream().map(arg -> rewrite(arg, context)) + .collect(Collectors.toList()); + if (isAllLiteral(newArgs)) { + return ExpressionEvaluator.INSTANCE.eval(boundFunction.withChildren(newArgs)); + } + return boundFunction.withChildren(newArgs); + 
} + + @Override + public Expression visitBinaryArithmetic(BinaryArithmetic binaryArithmetic, ExpressionRewriteContext context) { + Expression left = rewrite(binaryArithmetic.left(), context); + Expression right = rewrite(binaryArithmetic.right(), context); + if (left instanceof NullLiteral || right instanceof NullLiteral) { + return Literal.of(null); + } + if (isAllLiteral(left, right)) { + return ExpressionEvaluator.INSTANCE.eval(binaryArithmetic.withChildren(left, right)); + } + return binaryArithmetic.withChildren(left, right); + } + + @Override + public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) { + Expression newDefault = null; + boolean foundNewDefault = false; + + List<WhenClause> whenClauses = new ArrayList<>(); + for (WhenClause whenClause : caseWhen.getWhenClauses()) { + Expression whenOperand = rewrite(whenClause.getOperand(), context); + + if (!(whenOperand instanceof Literal)) { + whenClauses.add(new WhenClause(whenOperand, rewrite(whenClause.getResult(), context))); + } else if (BooleanLiteral.TRUE.equals(whenOperand)) { + foundNewDefault = true; + newDefault = rewrite(whenClause.getResult(), context); + break; + } + } + + Expression defaultResult; + if (foundNewDefault) { + defaultResult = newDefault; + } else { + defaultResult = rewrite(caseWhen.getDefaultValue().orElse(Literal.of(null)), context); + } + + if (whenClauses.isEmpty()) { + return defaultResult; + } + return new CaseWhen(whenClauses, defaultResult); + } + + @Override + public Expression visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) { + List<Expression> newChildren = inPredicate.children().stream().map(c -> rewrite(c, context)) + .collect(Collectors.toList()); + if (newChildren.get(0) instanceof NullLiteral) { + return Literal.of(null); + } + if (isAllLiteral(newChildren)) { + Literal c0 = (Literal) newChildren.get(0); + for (int i = 1; i < newChildren.size(); i++) { + if (c0.compareTo((Literal) newChildren.get(i)) == 0) { 
+ return Literal.of(true); + } + } + return Literal.of(false); + } + return inPredicate.withChildren(newChildren); + } + + @Override + public Expression visitIsNull(IsNull isNull, ExpressionRewriteContext context) { + Expression child = rewrite(isNull.child(), context); + if (child instanceof NullLiteral) { + return Literal.of(true); + } else if (child instanceof Literal) { + return Literal.of(false); + } + return isNull.withChildren(child); + } + + @Override + public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, ExpressionRewriteContext context) { + Expression left = rewrite(arithmetic.child(0), context); + Expression right = rewrite(arithmetic.child(1), context); + if (isAllLiteral(left, right)) { + return ExpressionEvaluator.INSTANCE.eval(arithmetic.withChildren(left, right)); + } + return arithmetic.withChildren(left, right); + } + + private Expression foldByBe(Expression root, ExpressionRewriteContext context) { + if (root.isConstant() && !root.isLiteral()) { + Expr expr = ExpressionTranslator.INSTANCE.translate(root, null); + assignId(expr); + Map<String, Expr> ori = new HashMap<>(); + ori.put(expr.getId().toString(), expr); + + Map<String, Map<String, TExpr>> paramMap = new HashMap<>(); + Map<String, TExpr> constMap = new HashMap<>(); + + collectConst(expr, constMap); + paramMap.put("0", constMap); Review Comment: This key is only a placeholder token, so its value has no real meaning, but a hard-coded "0" is confusing to readers; use the ID of the current expression as the key instead. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org