This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch fixing-java-client
in repository https://gitbox.apache.org/repos/asf/pinot.git

commit 05d71f27acb75e6b1ab045aa8ce8c380c4d17025
Author: Xiang Fu <[email protected]>
AuthorDate: Thu Jul 31 17:25:38 2025 -0700

    Fix CalciteSQLParser error for multi-stage queries with SET statements
    
    - Extract table name resolution logic to dedicated TableNameExtractor class
    - Implement proper handling of multi-stage queries using Calcite AST traversal
    - Support complex SQL features: CTEs, JOINs, subqueries, aliases, ORDER BY
    - Add comprehensive test suite with 20 test cases covering all scenarios
    - Resolve ClassCastException when parsing queries with SET statements
    
    Fixes #11823
---
 .../java/org/apache/pinot/client/Connection.java   |  28 +-
 .../apache/pinot/client/TableNameExtractor.java    | 410 +++++++++++++++++++++
 .../apache/pinot/client/grpc/GrpcConnection.java   |   4 +-
 .../pinot/client/TableNameExtractorTest.java       | 291 +++++++++++++++
 4 files changed, 705 insertions(+), 28 deletions(-)

diff --git 
a/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/Connection.java
 
b/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/Connection.java
index cc61f4591e..536fa909d4 100644
--- 
a/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/Connection.java
+++ 
b/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/Connection.java
@@ -21,12 +21,8 @@ package org.apache.pinot.client;
 import com.google.common.collect.Iterables;
 import java.util.List;
 import java.util.Properties;
-import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import javax.annotation.Nullable;
-import org.apache.pinot.common.utils.request.RequestUtils;
-import org.apache.pinot.sql.parsers.CalciteSqlParser;
-import org.apache.pinot.sql.parsers.SqlNodeAndOptions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -108,7 +104,7 @@ public class Connection {
   public ResultSetGroup execute(@Nullable Iterable<String> tableNames, String 
query)
       throws PinotClientException {
     String[] resultTableNames =
-        (tableNames == null) ? resolveTableName(query) : 
Iterables.toArray(tableNames, String.class);
+        (tableNames == null) ? TableNameExtractor.resolveTableName(query) : 
Iterables.toArray(tableNames, String.class);
     String brokerHostPort = _brokerSelector.selectBroker(resultTableNames);
     if (brokerHostPort == null) {
       throw new PinotClientException("Could not find broker to query " + 
((tableNames == null) ? "with no tables"
@@ -157,7 +153,7 @@ public class Connection {
   public CompletableFuture<ResultSetGroup> executeAsync(@Nullable 
Iterable<String> tableNames, String query)
       throws PinotClientException {
     String[] resultTableNames =
-        (tableNames == null) ? resolveTableName(query) : 
Iterables.toArray(tableNames, String.class);
+        (tableNames == null) ? TableNameExtractor.resolveTableName(query) : 
Iterables.toArray(tableNames, String.class);
     String brokerHostPort = _brokerSelector.selectBroker(resultTableNames);
     if (brokerHostPort == null) {
       throw new PinotClientException("Could not find broker to query for 
statement: " + query);
@@ -165,26 +161,18 @@ public class Connection {
     return _transport.executeQueryAsync(brokerHostPort, query).thenApply(ResultSetGroup::new);
   }
 
   /**
    * Returns the name of all the tables used in a sql query.
    *
    * @return name of all the tables used in a sql query.
+   * @deprecated Use {@link TableNameExtractor#resolveTableName(String)}, which also handles multi-stage queries.
    */
+  @Deprecated
   @Nullable
   public static String[] resolveTableName(String query) {
-    try {
-      SqlNodeAndOptions sqlNodeAndOptions = CalciteSqlParser.compileToSqlNodeAndOptions(query);
-      Set<String> tableNames =
-          RequestUtils.getTableNames(CalciteSqlParser.compileSqlNodeToPinotQuery(sqlNodeAndOptions.getSqlNode()));
-      if (tableNames != null) {
-        return tableNames.toArray(new String[0]);
-      }
-    } catch (Exception e) {
-      LOGGER.error("Cannot parse table name from query: {}. Fallback to broker selector default.", query, e);
-    }
-    return null;
+    return TableNameExtractor.resolveTableName(query);
   }
 
   /**
    * Returns the list of brokers to which this connection can connect to.
    *
diff --git 
a/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/TableNameExtractor.java
 
b/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/TableNameExtractor.java
new file mode 100644
index 0000000000..3987139bd3
--- /dev/null
+++ 
b/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/TableNameExtractor.java
@@ -0,0 +1,340 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.client;
+
+import java.lang.reflect.Field;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Locale;
+import java.util.Set;
+import javax.annotation.Nullable;
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlIdentifier;
+import org.apache.calcite.sql.SqlJoin;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlNodeList;
+import org.apache.calcite.sql.SqlOrderBy;
+import org.apache.calcite.sql.SqlSelect;
+import org.apache.calcite.sql.SqlWith;
+import org.apache.calcite.sql.SqlWithItem;
+import org.apache.pinot.sql.parsers.CalciteSqlParser;
+import org.apache.pinot.sql.parsers.SqlNodeAndOptions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Extracts the names of the tables referenced by a SQL query by walking its Calcite {@link SqlNode} tree.
+ *
+ * <p>The result is handed to the broker selector to pick a broker. Single-stage and multi-stage queries
+ * (joins, CTEs, subqueries, aliases) are supported, including queries prefixed with {@code SET} statements.
+ *
+ * <p>An instance keeps its traversal state in unsynchronized fields and must not be shared across threads;
+ * {@link #resolveTableName(String)} creates a fresh instance for every call.
+ */
+public class TableNameExtractor {
+  private static final Logger LOGGER = LoggerFactory.getLogger(TableNameExtractor.class);
+
+  // Parser tokens that are operators or punctuation rather than keywords. Declared before RESERVED_KEYWORDS
+  // because static initializers run in textual order and loadReservedKeywords() reads this set.
+  private static final Set<String> OPERATORS = Collections.unmodifiableSet(new HashSet<>(Arrays.asList(
+      "=", ">", "<", ">=", "<=", "<>", "!=", "+", "-", "*", "/", "%", "||", "->", "..", "(", ")",
+      "{", "}", "[", "]", ";", ".", ",", "?", ":", "|", "^", "$", "/*", "*/", "/*+")));
+
+  // Upper-case parser keywords (or a built-in fallback list). An identifier whose last name component matches
+  // one of them is never reported as a table.
+  private static final Set<String> RESERVED_KEYWORDS = loadReservedKeywords();
+
+  // Per-instance traversal state.
+  private final Set<String> _tableNames = new HashSet<>();
+  private final Set<String> _cteNames = new HashSet<>();
+  // True while visiting a position where a bare identifier names a table (FROM items and join inputs).
+  private boolean _inFromClause = false;
+
+  /**
+   * Returns the names of all the tables used in a SQL query.
+   *
+   * <p>Leading statements such as {@code SET key=value;} are split off by
+   * {@code CalciteSqlParser.compileToSqlNodeAndOptions}; only the remaining query statement is inspected.
+   *
+   * @param query The SQL query string to analyze
+   * @return names of all the tables used in the query (empty if it references none), or {@code null} if the
+   *     query cannot be parsed or analyzed, in which case the caller falls back to the broker selector default
+   */
+  @Nullable
+  public static String[] resolveTableName(String query) {
+    SqlNodeAndOptions sqlNodeAndOptions;
+    try {
+      sqlNodeAndOptions = CalciteSqlParser.compileToSqlNodeAndOptions(query);
+    } catch (Exception e) {
+      LOGGER.error("Cannot parse table name from query: {}. Fallback to broker selector default.", query, e);
+      return null;
+    }
+    try {
+      return extractTableNamesFromMultiStageQuery(sqlNodeAndOptions.getSqlNode()).toArray(new String[0]);
+    } catch (Exception e) {
+      // Return null (not an empty array) so the caller falls back to the broker selector default.
+      LOGGER.error("Cannot extract table name from query: {}. Fallback to broker selector default.", query, e);
+      return null;
+    }
+  }
+
+  /**
+   * Extracts table names from a (possibly multi-stage) query using Calcite SQL AST traversal.
+   *
+   * @param sqlNode The root SqlNode of the parsed query
+   * @return Set of table names found in the query; traversal exceptions propagate to the caller
+   */
+  private static Set<String> extractTableNamesFromMultiStageQuery(SqlNode sqlNode) {
+    TableNameExtractor extractor = new TableNameExtractor();
+    extractor.extractTableNames(sqlNode);
+    return extractor.getTableNames();
+  }
+
+  /**
+   * Returns the table names collected so far by {@link #extractTableNames(SqlNode)} (the live, mutable set).
+   */
+  public Set<String> getTableNames() {
+    return _tableNames;
+  }
+
+  /**
+   * Visits {@code node} and its descendants, collecting the names of referenced tables.
+   *
+   * @param node the node to visit; {@code null} is ignored
+   */
+  public void extractTableNames(SqlNode node) {
+    if (node == null) {
+      return;
+    }
+    // The specific SqlCall subclasses must be matched before the generic SqlCall branch.
+    if (node instanceof SqlWith) {
+      visitWith((SqlWith) node);
+    } else if (node instanceof SqlOrderBy) {
+      visitOrderBy((SqlOrderBy) node);
+    } else if (node instanceof SqlWithItem) {
+      visitWithItem((SqlWithItem) node);
+    } else if (node instanceof SqlSelect) {
+      visitSelect((SqlSelect) node);
+    } else if (node instanceof SqlJoin) {
+      visitJoin((SqlJoin) node);
+    } else if (node instanceof SqlCall) {
+      visitCall((SqlCall) node);
+    } else if (node instanceof SqlIdentifier) {
+      visitIdentifier((SqlIdentifier) node);
+    } else if (node instanceof SqlNodeList) {
+      visitNodeList((SqlNodeList) node);
+    }
+    // Any other node type (e.g. a literal) has no child nodes to inspect.
+  }
+
+  private void visitWith(SqlWith with) {
+    // Visit the CTE definitions first so their names are known before the body references them.
+    if (with.withList != null) {
+      visitNodeList(with.withList);
+    }
+    if (with.body != null) {
+      extractTableNames(with.body);
+    }
+  }
+
+  private void visitOrderBy(SqlOrderBy orderBy) {
+    if (orderBy.query != null) {
+      extractTableNames(orderBy.query);
+    }
+    // ORDER BY, OFFSET and FETCH hold expressions, not table references; only nested subqueries add tables.
+    // Leaving the FROM context matters for "(SELECT ... ORDER BY col) AS t": "col" must not become a table.
+    boolean wasInFromClause = _inFromClause;
+    _inFromClause = false;
+    if (orderBy.orderList != null) {
+      visitNodeList(orderBy.orderList);
+    }
+    if (orderBy.offset != null) {
+      extractTableNames(orderBy.offset);
+    }
+    if (orderBy.fetch != null) {
+      extractTableNames(orderBy.fetch);
+    }
+    _inFromClause = wasInFromClause;
+  }
+
+  private void visitWithItem(SqlWithItem withItem) {
+    // Track the CTE name so later references to it are not reported as tables.
+    if (withItem.name != null) {
+      _cteNames.add(withItem.name.getSimple());
+    }
+    // Tables come from the CTE's query definition, not from its alias.
+    if (withItem.query != null) {
+      extractTableNames(withItem.query);
+    }
+  }
+
+  private void visitSelect(SqlSelect select) {
+    boolean wasInFromClause = _inFromClause;
+    // The FROM clause is the only place where identifiers name tables directly.
+    if (select.getFrom() != null) {
+      _inFromClause = true;
+      extractTableNames(select.getFrom());
+    }
+    // Other clauses hold expressions (their identifiers are columns); tables there only appear in subqueries.
+    // Reset explicitly so a FROM-less subquery nested inside a FROM clause does not inherit the FROM context.
+    _inFromClause = false;
+    if (select.getWhere() != null) {
+      extractTableNames(select.getWhere());
+    }
+    if (select.getGroup() != null) {
+      visitNodeList(select.getGroup());
+    }
+    if (select.getHaving() != null) {
+      extractTableNames(select.getHaving());
+    }
+    if (select.getOrderList() != null) {
+      visitNodeList(select.getOrderList());
+    }
+    if (select.getSelectList() != null) {
+      visitNodeList(select.getSelectList());
+    }
+    _inFromClause = wasInFromClause;
+  }
+
+  private void visitJoin(SqlJoin join) {
+    boolean wasInFromClause = _inFromClause;
+    // Both join inputs are FROM items.
+    if (join.getLeft() != null) {
+      _inFromClause = true;
+      extractTableNames(join.getLeft());
+    }
+    if (join.getRight() != null) {
+      _inFromClause = true;
+      extractTableNames(join.getRight());
+    }
+    // The join condition can only reference tables through subqueries; its identifiers are columns.
+    if (join.getCondition() != null) {
+      _inFromClause = false;
+      extractTableNames(join.getCondition());
+    }
+    _inFromClause = wasInFromClause;
+  }
+
+  private void visitCall(SqlCall call) {
+    SqlKind kind = call.getKind();
+    if (kind == SqlKind.AS) {
+      // "tableA AS a" / "(subquery) AS t(c1, c2)": only the first operand can name a table; the rest are aliases.
+      if (!call.getOperandList().isEmpty()) {
+        extractTableNames(call.getOperandList().get(0));
+      }
+    } else if (kind != SqlKind.VALUES) {
+      // VALUES rows are skipped to avoid false positives. Every other call (set operations, function calls, CASE,
+      // ...) is searched through all of its operands for subqueries and FROM items.
+      for (SqlNode operand : call.getOperandList()) {
+        extractTableNames(operand);
+      }
+    }
+  }
+
+  private void visitIdentifier(SqlIdentifier identifier) {
+    // Only identifiers in a FROM position name tables; elsewhere they are columns or aliases.
+    if (!_inFromClause || identifier.names.isEmpty()) {
+      return;
+    }
+    String simpleName = identifier.names.get(identifier.names.size() - 1);
+    // Skip keywords, '$'-prefixed system identifiers and unqualified references to CTEs defined in this query.
+    if (isReservedKeyword(simpleName) || simpleName.startsWith("$")
+        || (identifier.isSimple() && _cteNames.contains(simpleName))) {
+      return;
+    }
+    // Keep qualifiers such as a database prefix ("db.myTable") instead of reporting only the last component.
+    _tableNames.add(String.join(".", identifier.names));
+  }
+
+  /**
+   * Visits each node in {@code nodeList}; {@code null} lists and entries are ignored.
+   */
+  private void visitNodeList(SqlNodeList nodeList) {
+    if (nodeList == null) {
+      return;
+    }
+    for (SqlNode node : nodeList) {
+      extractTableNames(node);
+    }
+  }
+
+  /**
+   * Returns whether {@code name} is one of {@link #RESERVED_KEYWORDS}, ignoring case. A {@code null} name counts
+   * as reserved so that it is never reported as a table.
+   */
+  private static boolean isReservedKeyword(String name) {
+    // Locale.ROOT: under e.g. a Turkish default locale, "join".toUpperCase() would not yield "JOIN".
+    return name == null || RESERVED_KEYWORDS.contains(name.toUpperCase(Locale.ROOT));
+  }
+
+  /**
+   * Loads keywords from the generated parser's {@code SqlParserImplConstants.tokenImage} table, falling back to a
+   * built-in list if it cannot be read.
+   */
+  private static Set<String> loadReservedKeywords() {
+    Set<String> reservedKeywords = new HashSet<>();
+    try {
+      // Read reflectively, so there is no compile-time reference to the generated parser class.
+      Class<?> constantsClass = Class.forName("org.apache.pinot.sql.parsers.parser.SqlParserImplConstants");
+      Field tokenImageField = constantsClass.getField("tokenImage");
+      String[] tokenImage = (String[]) tokenImageField.get(null);
+      for (String token : tokenImage) {
+        // Literal tokens are wrapped in double quotes; named tokens such as <IDENTIFIER> are not. The "\"<" check
+        // additionally drops the "<", "<=" and "<>" operators.
+        if (token.startsWith("\"") && token.endsWith("\"") && !token.startsWith("\"<")) {
+          String keyword = token.substring(1, token.length() - 1);
+          // Skip single-character tokens and operators.
+          if (keyword.length() > 1 && !isOperator(keyword)) {
+            reservedKeywords.add(keyword.toUpperCase(Locale.ROOT));
+          }
+        }
+      }
+      LOGGER.debug("Loaded {} reserved keywords from SqlParserImplConstants", reservedKeywords.size());
+    } catch (Exception e) {
+      LOGGER.warn("Failed to load reserved keywords from SqlParserImplConstants, using fallback set", e);
+      addFallbackReservedKeywords(reservedKeywords);
+    }
+    return Collections.unmodifiableSet(reservedKeywords);
+  }
+
+  /**
+   * Returns whether {@code token} is an operator or punctuation token rather than a keyword.
+   */
+  private static boolean isOperator(String token) {
+    return OPERATORS.contains(token);
+  }
+
+  /**
+   * Adds a minimal list of common SQL keywords, used when the parser's token table cannot be loaded.
+   */
+  private static void addFallbackReservedKeywords(Set<String> reservedKeywords) {
+    reservedKeywords.addAll(Arrays.asList(
+        "SELECT", "FROM", "WHERE", "GROUP", "ORDER", "BY", "HAVING", "JOIN", "INNER", "LEFT", "RIGHT", "OUTER",
+        "ON", "AS", "AND", "OR", "NOT", "IN", "LIMIT", "OFFSET", "UNION", "ALL", "DISTINCT", "COUNT",
+        "SUM", "AVG", "MIN", "MAX", "CASE", "WHEN", "THEN", "ELSE", "END", "CREATE", "DROP", "ALTER",
+        "INSERT", "UPDATE", "DELETE", "TABLE", "INDEX", "VIEW", "SCHEMA", "DATABASE", "CASCADE",
+        "RESTRICT", "PRIMARY", "FOREIGN", "KEY", "CONSTRAINT", "UNIQUE", "NULL", "DEFAULT",
+        "CHECK", "REFERENCES", "SET", "VALUES", "WITH", "RECURSIVE", "EXISTS", "BETWEEN",
+        "LIKE", "IS", "TRUE", "FALSE", "UNKNOWN", "CAST", "CONVERT", "TRIM", "SUBSTRING",
+        "UPPER", "LOWER", "LENGTH", "CHAR_LENGTH", "POSITION", "EXTRACT"));
+  }
+}
diff --git 
a/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/grpc/GrpcConnection.java
 
b/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/grpc/GrpcConnection.java
index e30b49979e..ee1ccf6b3f 100644
--- 
a/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/grpc/GrpcConnection.java
+++ 
b/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/grpc/GrpcConnection.java
@@ -32,10 +32,10 @@ import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
 import org.apache.pinot.client.BrokerResponse;
 import org.apache.pinot.client.BrokerSelector;
-import org.apache.pinot.client.Connection;
 import org.apache.pinot.client.PinotClientException;
 import org.apache.pinot.client.ResultSetGroup;
 import org.apache.pinot.client.SimpleBrokerSelector;
+import org.apache.pinot.client.TableNameExtractor;
 import org.apache.pinot.common.config.GrpcConfig;
 import org.apache.pinot.common.proto.Broker;
 import org.apache.pinot.common.utils.DataSchema;
@@ -233,7 +233,7 @@ public class GrpcConnection implements AutoCloseable {
    */
   public Iterator<Broker.BrokerResponse> executeWithIterator(String query, 
Map<String, String> metadata)
       throws PinotClientException {
-    String[] tableNames = Connection.resolveTableName(query);
+    String[] tableNames = TableNameExtractor.resolveTableName(query);
     String brokerHostPort = _brokerSelector.selectBroker(tableNames);
     if (brokerHostPort == null) {
       throw new PinotClientException("Could not find broker to query " + 
((tableNames == null) ? "with no tables"
diff --git 
a/pinot-clients/pinot-java-client/src/test/java/org/apache/pinot/client/TableNameExtractorTest.java
 
b/pinot-clients/pinot-java-client/src/test/java/org/apache/pinot/client/TableNameExtractorTest.java
new file mode 100644
index 0000000000..5319d78610
--- /dev/null
+++ 
b/pinot-clients/pinot-java-client/src/test/java/org/apache/pinot/client/TableNameExtractorTest.java
@@ -0,0 +1,291 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.client;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertNull;
+import static org.testng.Assert.assertTrue;
+
+
+/**
+ * Tests for the TableNameExtractor class.
+ */
+public class TableNameExtractorTest {
+
+  @Test
+  public void testResolveTableNameWithSingleQuery() {
+    // Test that single queries work correctly
+    String singleQuery = "SELECT * FROM myTable WHERE id > 100";
+
+    String[] tableNames = TableNameExtractor.resolveTableName(singleQuery);
+
+    assertNotNull(tableNames, "Table names should not be null");
+    assertEquals(tableNames.length, 1, "Should resolve exactly one table");
+    assertEquals(tableNames[0], "myTable", "Should resolve the correct table name");
+  }
+
+  @Test
+  public void testResolveTableNameWithSingleStatementAlias() {
+    String singleStatementQuery = "SELECT stats.* FROM airlineStats stats LIMIT 10";
+    String[] tableNames = TableNameExtractor.resolveTableName(singleStatementQuery);
+
+    assertNotNull(tableNames);
+    assertEquals(tableNames.length, 1);
+    assertEquals(tableNames[0], "airlineStats");
+  }
+
+  @Test
+  public void testResolveTableNameWithMultiStatementQuery() {
+    // Test the fix for issue #11823: CalciteSqlParser error with multi-statement queries
+    String multiStatementQuery = "SET useMultistageEngine=true;\nSELECT stats.* FROM airlineStats stats LIMIT 10";
+
+    // This should not throw a ClassCastException anymore
+    String[] tableNames = TableNameExtractor.resolveTableName(multiStatementQuery);
+
+    // Should successfully resolve the table name
+    assertNotNull(tableNames, "Table names should not be null");
+    assertEquals(tableNames.length, 1, "Should resolve exactly one table");
+    assertEquals(tableNames[0], "airlineStats", "Should resolve the correct table name");
+  }
+
+  @Test
+  public void testResolveTableNameWithMultipleSetStatements() {
+    // Test with multiple SET statements
+    String multiSetQuery = "SET useMultistageEngine=true;\nSET timeoutMs=10000;\nSELECT * FROM testTable";
+
+    String[] tableNames = TableNameExtractor.resolveTableName(multiSetQuery);
+
+    assertNotNull(tableNames, "Table names should not be null");
+    assertEquals(tableNames.length, 1, "Should resolve exactly one table");
+    assertEquals(tableNames[0], "testTable", "Should resolve the correct table name");
+  }
+
+  @Test
+  public void testResolveTableNameWithMultipleSetStatementsAndJoin() {
+    String multiStatementQuery = "SET useMultistageEngine=true;\nSET maxRowsInJoin=1000;\n"
+        + "SELECT s.* FROM airlineStats s JOIN carriers c ON s.carrierId = c.id LIMIT 10";
+    String[] tableNames = TableNameExtractor.resolveTableName(multiStatementQuery);
+
+    assertNotNull(tableNames, "Table names should be resolved for queries with multiple SET statements");
+    assertEquals(tableNames.length, 2);
+    assertEquals(new HashSet<>(Arrays.asList(tableNames)), new HashSet<>(Arrays.asList("airlineStats", "carriers")));
+  }
+
+  @Test
+  public void testResolveTableNameWithJoin() {
+    // Test with JOIN queries
+    String joinQuery = "SELECT * FROM table1 t1 JOIN table2 t2 ON t1.id = t2.id";
+
+    String[] tableNames = TableNameExtractor.resolveTableName(joinQuery);
+
+    assertNotNull(tableNames, "Table names should not be null");
+    assertEquals(tableNames.length, 2, "Should resolve two tables");
+    assertTrue(Arrays.asList(tableNames).contains("table1"), "Should contain table1");
+    assertTrue(Arrays.asList(tableNames).contains("table2"), "Should contain table2");
+  }
+
+  @Test
+  public void testResolveTableNameWithJoinQueryAndSetStatements() {
+    String joinQuery = "SET useMultistageEngine=true;\n"
+        + "SELECT a.col1, b.col2 FROM tableA a JOIN tableB b ON a.id = b.id";
+    String[] tableNames = TableNameExtractor.resolveTableName(joinQuery);
+
+    assertNotNull(tableNames, "Table names should be resolved for join queries with SET statements");
+    assertEquals(tableNames.length, 2);
+
+    Set<String> expectedTableNames = new HashSet<>(Arrays.asList("tableA", "tableB"));
+    Set<String> actualTableNames = new HashSet<>(Arrays.asList(tableNames));
+    assertEquals(actualTableNames, expectedTableNames);
+  }
+
+  @Test
+  public void testResolveTableNameWithExplicitAlias() {
+    // Test with explicit AS alias
+    String aliasQuery = "SELECT u.name FROM users AS u WHERE u.active = true";
+
+    String[] tableNames = TableNameExtractor.resolveTableName(aliasQuery);
+
+    assertNotNull(tableNames, "Table names should not be null");
+    assertEquals(tableNames.length, 1, "Should resolve exactly one table");
+    assertEquals(tableNames[0], "users", "Should resolve the actual table name, not the alias");
+  }
+
+  @Test
+  public void testResolveTableNameWithImplicitAlias() {
+    // Test with implicit alias (no AS keyword)
+    String implicitAliasQuery = "SELECT o.id, u.name FROM orders o JOIN users u ON o.user_id = u.id";
+
+    String[] tableNames = TableNameExtractor.resolveTableName(implicitAliasQuery);
+
+    assertNotNull(tableNames, "Table names should not be null");
+    assertEquals(tableNames.length, 2, "Should resolve two tables");
+    assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table");
+    assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table");
+  }
+
+  @Test
+  public void testResolveTableNameWithCTE() {
+    // Test with Common Table Expression (CTE)
+    String cteQuery = "WITH active_users AS (SELECT * FROM users WHERE active = true) "
+        + "SELECT au.name FROM active_users au JOIN orders o ON au.id = o.user_id";
+
+    String[] tableNames = TableNameExtractor.resolveTableName(cteQuery);
+
+    assertNotNull(tableNames, "Table names should not be null");
+    assertEquals(tableNames.length, 2, "Should resolve two tables");
+    assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table from CTE");
+    assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table");
+  }
+
+  @Test
+  public void testResolveTableNameWithNestedCTE() {
+    // Test with nested CTEs
+    String nestedCteQuery = "WITH user_orders AS ("
+        + "  SELECT u.id, u.name, o.order_date "
+        + "  FROM users u JOIN orders o ON u.id = o.user_id"
+        + "), recent_orders AS ("
+        + "  SELECT * FROM user_orders WHERE order_date > '2023-01-01'"
+        + ") "
+        + "SELECT ro.name FROM recent_orders ro JOIN products p ON ro.id = p.user_id";
+
+    String[] tableNames = TableNameExtractor.resolveTableName(nestedCteQuery);
+
+    assertNotNull(tableNames, "Table names should not be null");
+    assertEquals(tableNames.length, 3, "Should resolve three tables");
+    assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table");
+    assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table");
+    assertTrue(Arrays.asList(tableNames).contains("products"), "Should contain products table");
+  }
+
+  @Test
+  public void testResolveTableNameWithSubqueryAlias() {
+    // Test with subquery alias
+    String subqueryQuery = "SELECT t.name FROM (SELECT * FROM users WHERE active = true) AS t "
+        + "JOIN orders o ON t.id = o.user_id";
+
+    String[] tableNames = TableNameExtractor.resolveTableName(subqueryQuery);
+
+    assertNotNull(tableNames, "Table names should not be null");
+    assertEquals(tableNames.length, 2, "Should resolve two tables");
+    assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table from subquery");
+    assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table");
+  }
+
+  @Test
+  public void testResolveTableNameWithComplexJoinAndAliases() {
+    // Test with multiple JOINs and various alias styles
+    String complexQuery = "SELECT u.name, o.total, p.title "
+        + "FROM users AS u "
+        + "INNER JOIN orders o ON u.id = o.user_id "
+        + "LEFT JOIN order_items oi ON o.id = oi.order_id "
+        + "RIGHT JOIN products AS p ON oi.product_id = p.id "
+        + "WHERE u.active = true";
+
+    String[] tableNames = TableNameExtractor.resolveTableName(complexQuery);
+
+    assertNotNull(tableNames, "Table names should not be null");
+    assertEquals(tableNames.length, 4, "Should resolve four tables");
+    assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table");
+    assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table");
+    assertTrue(Arrays.asList(tableNames).contains("order_items"), "Should contain order_items table");
+    assertTrue(Arrays.asList(tableNames).contains("products"), "Should contain products table");
+  }
+
+  @Test
+  public void testResolveTableNameWithJoinConditionSubquery() {
+    // Test with subquery in join condition
+    String joinSubqueryQuery = "SELECT u.name, o.total "
+        + "FROM users u "
+        + "JOIN orders o ON u.id = o.user_id "
+        + "AND o.id IN (SELECT order_id FROM order_items WHERE quantity > 5)";
+
+    String[] tableNames = TableNameExtractor.resolveTableName(joinSubqueryQuery);
+
+    assertNotNull(tableNames, "Table names should not be null");
+    assertEquals(tableNames.length, 3, "Should resolve three tables");
+    assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table");
+    assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table");
+    assertTrue(Arrays.asList(tableNames).contains("order_items"),
+        "Should contain order_items table from subquery");
+  }
+
+  @Test
+  public void testResolveTableNameWithOrderBy() {
+    // Test with ORDER BY clause
+    String orderByQuery = "SELECT * FROM users ORDER BY name";
+    String[] tableNames = TableNameExtractor.resolveTableName(orderByQuery);
+
+    assertNotNull(tableNames, "Table names should not be null");
+    assertEquals(tableNames.length, 1, "Should resolve exactly one table");
+    assertEquals(tableNames[0], "users", "Should resolve the correct table name");
+  }
+
+  @Test
+  public void testResolveTableNameWithOrderBySubquery() {
+    // Test with subquery in ORDER BY clause (rare but possible)
+    String orderBySubqueryQuery = "SELECT * FROM users u ORDER BY "
+        + "(SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id)";
+    String[] tableNames = TableNameExtractor.resolveTableName(orderBySubqueryQuery);
+
+    assertNotNull(tableNames, "Table names should not be null");
+    assertEquals(tableNames.length, 2, "Should resolve two tables");
+    assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table");
+    assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table");
+  }
+
+  @Test
+  public void testResolveTableNameWithInvalidQuery() {
+    String invalidQuery = "INVALID SQL QUERY";
+    String[] tableNames = TableNameExtractor.resolveTableName(invalidQuery);
+
+    // Should return null when query cannot be parsed (fallback to default broker selector)
+    assertNull(tableNames);
+  }
+
+  @Test
+  public void testResolveTableNameWithOnlySetStatements() {
+    String onlySetQuery = "SET useMultistageEngine=true;";
+    String[] tableNames = TableNameExtractor.resolveTableName(onlySetQuery);
+
+    // Should return null when there's no actual query statement
+    assertNull(tableNames);
+  }
+
+  @Test
+  public void testResolveTableNameWithNullQuery() {
+    String[] tableNames = TableNameExtractor.resolveTableName(null);
+
+    // Should return null when query is null
+    assertNull(tableNames);
+  }
+
+  @Test
+  public void testResolveTableNameWithEmptyQuery() {
+    String[] tableNames = TableNameExtractor.resolveTableName("");
+
+    // Should return null when query is empty
+    assertNull(tableNames);
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to