This is an automated email from the ASF dual-hosted git repository.

eldenmoon pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 820b300300d branch-3.0: [Fix](ShortCircuit) fix prepared statement with partial arguments prepared #45371 (#45465)
820b300300d is described below

commit 820b300300d5f63213fd6aa926b107afa3cb68c1
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Tue Dec 17 09:46:24 2024 +0800

    branch-3.0: [Fix](ShortCircuit) fix prepared statement with partial arguments prepared #45371 (#45465)
    
    Cherry-picked from #45371
    
    Co-authored-by: lihangyu <lihan...@selectdb.com>
---
 .../org/apache/doris/nereids/StatementContext.java |  6 ++
 .../nereids/rules/analysis/ExpressionAnalyzer.java | 21 +++++
 .../org/apache/doris/qe/PointQueryExecutor.java    | 40 +++++++---
 .../data/point_query_p0/test_point_query.out       | 30 +++++++
 .../suites/point_query_p0/test_point_query.groovy  | 92 +++++++++++++++++-----
 5 files changed, 157 insertions(+), 32 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java
index 175b623467a..cd11b3228b9 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java
@@ -129,6 +129,8 @@ public class StatementContext implements Closeable {
     private final IdGenerator<PlaceholderId> placeHolderIdGenerator = 
PlaceholderId.createGenerator();
     // relation id to placeholders for prepared statement, ordered by 
placeholder id
     private final Map<PlaceholderId, Expression> idToPlaceholderRealExpr = new 
TreeMap<>();
+    // map placeholder id to comparison slot, which will used to replace 
conjuncts directly
+    private final Map<PlaceholderId, SlotReference> idToComparisonSlot = new 
TreeMap<>();
 
     // collect all hash join conditions to compute node connectivity in join 
graph
     private final List<Expression> joinFilters = new ArrayList<>();
@@ -367,6 +369,10 @@ public class StatementContext implements Closeable {
         return idToPlaceholderRealExpr;
     }
 
+    public Map<PlaceholderId, SlotReference> getIdToComparisonSlot() {
+        return idToComparisonSlot;
+    }
+
     public Map<CTEId, List<Pair<Multimap<Slot, Slot>, Group>>> 
getCteIdToConsumerGroup() {
         return cteIdToConsumerGroup;
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java
index 49789aa66e1..5ef3d0fbff3 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java
@@ -24,6 +24,7 @@ import org.apache.doris.catalog.FunctionRegistry;
 import org.apache.doris.common.DdlException;
 import org.apache.doris.common.Pair;
 import org.apache.doris.common.util.Util;
+import org.apache.doris.mysql.MysqlCommand;
 import org.apache.doris.nereids.CascadesContext;
 import org.apache.doris.nereids.SqlCacheContext;
 import org.apache.doris.nereids.StatementContext;
@@ -75,6 +76,7 @@ import 
org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
 import 
org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
+import org.apache.doris.nereids.trees.plans.PlaceholderId;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.types.ArrayType;
@@ -531,10 +533,29 @@ public class ExpressionAnalyzer extends 
SubExprAnalyzer<ExpressionRewriteContext
         return visit(realExpr, context);
     }
 
+    // Register prepared statement placeholder id to related slot in 
comparison predicate.
+    // Used to replace expression in ShortCircuit plan
+    private void registerPlaceholderIdToSlot(ComparisonPredicate cp,
+                    ExpressionRewriteContext context, Expression left, 
Expression right) {
+        if (ConnectContext.get() != null
+                    && ConnectContext.get().getCommand() == 
MysqlCommand.COM_STMT_EXECUTE) {
+            // Used to replace expression in ShortCircuit plan
+            if (cp.right() instanceof Placeholder && left instanceof 
SlotReference) {
+                PlaceholderId id = ((Placeholder) 
cp.right()).getPlaceholderId();
+                
context.cascadesContext.getStatementContext().getIdToComparisonSlot().put(id, 
(SlotReference) left);
+            } else if (cp.left() instanceof Placeholder && right instanceof 
SlotReference) {
+                PlaceholderId id = ((Placeholder) 
cp.left()).getPlaceholderId();
+                
context.cascadesContext.getStatementContext().getIdToComparisonSlot().put(id, 
(SlotReference) right);
+            }
+        }
+    }
+
     @Override
     public Expression visitComparisonPredicate(ComparisonPredicate cp, 
ExpressionRewriteContext context) {
         Expression left = cp.left().accept(this, context);
         Expression right = cp.right().accept(this, context);
+        // Used to replace expression in ShortCircuit plan
+        registerPlaceholderIdToSlot(cp, context, left, right);
         cp = (ComparisonPredicate) cp.withChildren(left, right);
         return TypeCoercionUtils.processComparisonPredicate(cp);
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/qe/PointQueryExecutor.java 
b/fe/fe-core/src/main/java/org/apache/doris/qe/PointQueryExecutor.java
index 9e4030b768b..b1bf3e227f0 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/PointQueryExecutor.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/PointQueryExecutor.java
@@ -31,7 +31,9 @@ import org.apache.doris.common.UserException;
 import org.apache.doris.mysql.MysqlCommand;
 import org.apache.doris.nereids.StatementContext;
 import org.apache.doris.nereids.exceptions.AnalysisException;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.plans.PlaceholderId;
 import org.apache.doris.planner.OlapScanNode;
 import org.apache.doris.proto.InternalService;
 import org.apache.doris.proto.InternalService.KeyTuple;
@@ -59,12 +61,12 @@ import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Map.Entry;
 import java.util.Set;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
-import java.util.stream.Collectors;
 
 public class PointQueryExecutor implements CoordInterface {
     private static final Logger LOG = 
LogManager.getLogger(PointQueryExecutor.class);
@@ -142,33 +144,45 @@ public class PointQueryExecutor implements CoordInterface 
{
         Preconditions.checkNotNull(preparedStmtCtx.shortCircuitQueryContext);
         ShortCircuitQueryContext shortCircuitQueryContext = 
preparedStmtCtx.shortCircuitQueryContext.get();
         // update conjuncts
-        List<Expr> conjunctVals = 
statementContext.getIdToPlaceholderRealExpr().values().stream().map(
-                        expression -> (
-                                (Literal) expression).toLegacyLiteral())
-                .collect(Collectors.toList());
-        if (conjunctVals.size() != preparedStmtCtx.command.placeholderCount()) 
{
+        Map<String, Expr> colNameToConjunct = Maps.newHashMap();
+        for (Entry<PlaceholderId, SlotReference> entry : 
statementContext.getIdToComparisonSlot().entrySet()) {
+            String colName = entry.getValue().getColumn().get().getName();
+            Expr conjunctVal = ((Literal)  
statementContext.getIdToPlaceholderRealExpr()
+                    .get(entry.getKey())).toLegacyLiteral();
+            colNameToConjunct.put(colName, conjunctVal);
+        }
+        if (colNameToConjunct.size() != 
preparedStmtCtx.command.placeholderCount()) {
             throw new AnalysisException("Mismatched conjuncts values size with 
prepared"
                     + "statement parameters size, expected "
                     + preparedStmtCtx.command.placeholderCount()
-                    + ", but meet " + conjunctVals.size());
+                    + ", but meet " + colNameToConjunct.size());
         }
-        updateScanNodeConjuncts(shortCircuitQueryContext.scanNode, 
conjunctVals);
+        updateScanNodeConjuncts(shortCircuitQueryContext.scanNode, 
colNameToConjunct);
         // short circuit plan and execution
         executor.executeAndSendResult(false, false,
                 shortCircuitQueryContext.analzyedQuery, executor.getContext()
                         .getMysqlChannel(), null, null);
     }
 
-    private static void updateScanNodeConjuncts(OlapScanNode scanNode, 
List<Expr> conjunctVals) {
-        for (int i = 0; i < conjunctVals.size(); ++i) {
-            BinaryPredicate binaryPredicate = (BinaryPredicate) 
scanNode.getConjuncts().get(i);
+    private static void updateScanNodeConjuncts(OlapScanNode scanNode,
+                Map<String, Expr> colNameToConjunct) {
+        for (Expr conjunct : scanNode.getConjuncts()) {
+            BinaryPredicate binaryPredicate = (BinaryPredicate) conjunct;
+            SlotRef slot = null;
+            int updateChildIdx = 0;
             if (binaryPredicate.getChild(0) instanceof LiteralExpr) {
-                binaryPredicate.setChild(0, conjunctVals.get(i));
+                slot = (SlotRef) binaryPredicate.getChildWithoutCast(1);
             } else if (binaryPredicate.getChild(1) instanceof LiteralExpr) {
-                binaryPredicate.setChild(1, conjunctVals.get(i));
+                slot = (SlotRef) binaryPredicate.getChildWithoutCast(0);
+                updateChildIdx = 1;
             } else {
                 Preconditions.checkState(false, "Should contains literal in " 
+ binaryPredicate.toSqlImpl());
             }
+            // not a placeholder to replace
+            if (!colNameToConjunct.containsKey(slot.getColumnName())) {
+                continue;
+            }
+            binaryPredicate.setChild(updateChildIdx, 
colNameToConjunct.get(slot.getColumnName()));
         }
     }
 
diff --git a/regression-test/data/point_query_p0/test_point_query.out 
b/regression-test/data/point_query_p0/test_point_query.out
index 1cc4142e39f..55c79757820 100644
--- a/regression-test/data/point_query_p0/test_point_query.out
+++ b/regression-test/data/point_query_p0/test_point_query.out
@@ -160,3 +160,33 @@
 -- !sql --
 -10    20      aabc    update val
 
+-- !point_select --
+user_guid      feature sk      feature_value   2021-01-01T00:00
+
+-- !point_select --
+user_guid      feature sk      feature_value   2021-01-01T00:00
+
+-- !point_select --
+user_guid      feature sk      feature_value   2021-01-01T00:00
+
+-- !point_select --
+user_guid      feature sk      feature_value   2021-01-01T00:00
+
+-- !point_select --
+user_guid      feature sk      feature_value   2021-01-01T00:00
+
+-- !point_select --
+user_guid      feature sk      feature_value   2021-01-01T00:00
+
+-- !point_select --
+user_guid      feature sk      feature_value   2021-01-01T00:00
+
+-- !point_select --
+user_guid      feature sk      feature_value   2021-01-01T00:00
+
+-- !point_select --
+user_guid      feature sk      feature_value   2021-01-01T00:00
+
+-- !point_select --
+user_guid      feature sk      feature_value   2021-01-01T00:00
+
diff --git a/regression-test/suites/point_query_p0/test_point_query.groovy 
b/regression-test/suites/point_query_p0/test_point_query.groovy
index f84012a8fd7..0ea879956e3 100644
--- a/regression-test/suites/point_query_p0/test_point_query.groovy
+++ b/regression-test/suites/point_query_p0/test_point_query.groovy
@@ -27,32 +27,30 @@ suite("test_point_query", "nonConcurrent") {
             logger.info("update config: code=" + code + ", out=" + out + ", 
err=" + err)
         }
     }
+    def user = context.config.jdbcUser
+    def password = context.config.jdbcPassword
+    def realDb = "regression_test_serving_p0"
+    // Parse url
+    String jdbcUrl = context.config.jdbcUrl
+    String urlWithoutSchema = jdbcUrl.substring(jdbcUrl.indexOf("://") + 3)
+    def sql_ip = urlWithoutSchema.substring(0, urlWithoutSchema.indexOf(":"))
+    def sql_port
+    if (urlWithoutSchema.indexOf("/") >= 0) {
+        // e.g: jdbc:mysql://locahost:8080/?a=b
+        sql_port = urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 
1, urlWithoutSchema.indexOf("/"))
+    } else {
+        // e.g: jdbc:mysql://locahost:8080
+        sql_port = urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 
1)
+    }
+    // set server side prepared statement url
+    def prepare_url = "jdbc:mysql://" + sql_ip + ":" + sql_port + "/" + realDb 
+ "?&useServerPrepStmts=true"
     try {
         set_be_config.call("disable_storage_row_cache", "false")
-        // nereids do not support point query now
         sql "set global enable_fallback_to_original_planner = false"
         sql """set global enable_nereids_planner=true"""
-        def user = context.config.jdbcUser
-        def password = context.config.jdbcPassword
-        def realDb = "regression_test_serving_p0"
         def tableName = realDb + ".tbl_point_query"
         sql "CREATE DATABASE IF NOT EXISTS ${realDb}"
 
-        // Parse url
-        String jdbcUrl = context.config.jdbcUrl
-        String urlWithoutSchema = jdbcUrl.substring(jdbcUrl.indexOf("://") + 3)
-        def sql_ip = urlWithoutSchema.substring(0, 
urlWithoutSchema.indexOf(":"))
-        def sql_port
-        if (urlWithoutSchema.indexOf("/") >= 0) {
-            // e.g: jdbc:mysql://locahost:8080/?a=b
-            sql_port = 
urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 1, 
urlWithoutSchema.indexOf("/"))
-        } else {
-            // e.g: jdbc:mysql://locahost:8080
-            sql_port = 
urlWithoutSchema.substring(urlWithoutSchema.indexOf(":") + 1)
-        }
-        // set server side prepared statement url
-        def prepare_url = "jdbc:mysql://" + sql_ip + ":" + sql_port + "/" + 
realDb + "?&useServerPrepStmts=true"
-
         def generateString = {len ->
             def str = ""
             for (int i = 0; i < len; i++) {
@@ -330,4 +328,60 @@ suite("test_point_query", "nonConcurrent") {
     qt_sql "select * from table_3821461 where col1 = 10 and col2 = 20 and loc3 
= 'aabc';"
     sql "update table_3821461 set value = 'update value' where col1 = -10 or 
col1 = 20;"
     qt_sql """select * from table_3821461 where col1 = -10 and col2 = 20 and 
loc3 = 'aabc'"""
+
+    sql "DROP TABLE IF EXISTS test_partial_prepared_statement"
+    sql """
+        CREATE TABLE `test_partial_prepared_statement` (
+          `user_guid` varchar(64) NOT NULL,
+          `feature` varchar(256) NOT NULL,
+          `sk` varchar(256) NOT NULL,
+          `feature_value` text NULL,
+          `data_time` datetime NOT NULL
+        ) ENGINE=OLAP
+        UNIQUE KEY(`user_guid`, `feature`, `sk`)
+        DISTRIBUTED BY HASH(`user_guid`) BUCKETS 32
+        PROPERTIES (
+        "enable_unique_key_merge_on_write" = "true",
+        "light_schema_change" = "true",
+        "function_column.sequence_col" = "data_time",
+        "store_row_column" = "true",
+        "replication_num" = "1",
+        "row_store_page_size" = "16384"
+        );
+    """
+    sql "insert into test_partial_prepared_statement values ('user_guid', 
'feature', 'sk','feature_value', '2021-01-01 00:00:00')"
+    def result2 = connect(user, password, prepare_url) {
+        def partial_prepared_stmt = prepareStatement "select /*+ 
SET_VAR(enable_nereids_planner=true) */ * from 
regression_test_point_query_p0.test_partial_prepared_statement where sk = 'sk' 
and user_guid = 'user_guid' and  feature = ? "
+        assertEquals(partial_prepared_stmt.class, 
com.mysql.cj.jdbc.ServerPreparedStatement);
+        partial_prepared_stmt.setString(1, "feature")
+        qe_point_select partial_prepared_stmt
+        qe_point_select partial_prepared_stmt
+
+        partial_prepared_stmt = prepareStatement "select /*+ 
SET_VAR(enable_nereids_planner=true) */ * from 
regression_test_point_query_p0.test_partial_prepared_statement where user_guid 
= ? and  feature = 'feature' and sk = ?"
+        assertEquals(partial_prepared_stmt.class, 
com.mysql.cj.jdbc.ServerPreparedStatement);
+        partial_prepared_stmt.setString(1, "user_guid")
+        partial_prepared_stmt.setString(2, "sk")
+        qe_point_select partial_prepared_stmt
+        qe_point_select partial_prepared_stmt
+
+        partial_prepared_stmt = prepareStatement "select /*+ 
SET_VAR(enable_nereids_planner=true) */ * from 
regression_test_point_query_p0.test_partial_prepared_statement where ? = 
user_guid and sk = 'sk'  and  feature = 'feature' "
+        assertEquals(partial_prepared_stmt.class, 
com.mysql.cj.jdbc.ServerPreparedStatement);
+        partial_prepared_stmt.setString(1, "user_guid")
+        qe_point_select partial_prepared_stmt
+        qe_point_select partial_prepared_stmt
+
+        partial_prepared_stmt = prepareStatement "select /*+ 
SET_VAR(enable_nereids_planner=true) */ * from 
regression_test_point_query_p0.test_partial_prepared_statement where ? = 
user_guid and sk = 'sk'  and  feature = ? "
+        assertEquals(partial_prepared_stmt.class, 
com.mysql.cj.jdbc.ServerPreparedStatement);
+        partial_prepared_stmt.setString(1, "user_guid")
+        partial_prepared_stmt.setString(2, "feature")
+        qe_point_select partial_prepared_stmt
+        qe_point_select partial_prepared_stmt
+
+        partial_prepared_stmt = prepareStatement "select /*+ 
SET_VAR(enable_nereids_planner=true) */ * from 
regression_test_point_query_p0.test_partial_prepared_statement where  sk = ? 
and  feature = ? and 'user_guid' = user_guid"
+        assertEquals(partial_prepared_stmt.class, 
com.mysql.cj.jdbc.ServerPreparedStatement);
+        partial_prepared_stmt.setString(1, "sk")
+        partial_prepared_stmt.setString(2, "feature")
+        qe_point_select partial_prepared_stmt
+        qe_point_select partial_prepared_stmt
+    }
 } 
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to