This is an automated email from the ASF dual-hosted git repository. xiangfu pushed a commit to branch fixing-java-client in repository https://gitbox.apache.org/repos/asf/pinot.git
commit 05d71f27acb75e6b1ab045aa8ce8c380c4d17025 Author: Xiang Fu <[email protected]> AuthorDate: Thu Jul 31 17:25:38 2025 -0700 Fix CalciteSQLParser error for multi-stage queries with SET statements - Extract table name resolution logic to dedicated TableNameExtractor class - Implement proper handling of multi-stage queries using Calcite AST traversal - Support complex SQL features: CTEs, JOINs, subqueries, aliases, ORDER BY - Add comprehensive test suite with 20 test cases covering all scenarios - Resolve ClassCastException when parsing queries with SET statements Fixes #11823 --- .../java/org/apache/pinot/client/Connection.java | 28 +- .../apache/pinot/client/TableNameExtractor.java | 410 +++++++++++++++++++++ .../apache/pinot/client/grpc/GrpcConnection.java | 4 +- .../pinot/client/TableNameExtractorTest.java | 291 +++++++++++++++ 4 files changed, 705 insertions(+), 28 deletions(-) diff --git a/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/Connection.java b/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/Connection.java index cc61f4591e..536fa909d4 100644 --- a/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/Connection.java +++ b/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/Connection.java @@ -21,12 +21,8 @@ package org.apache.pinot.client; import com.google.common.collect.Iterables; import java.util.List; import java.util.Properties; -import java.util.Set; import java.util.concurrent.CompletableFuture; import javax.annotation.Nullable; -import org.apache.pinot.common.utils.request.RequestUtils; -import org.apache.pinot.sql.parsers.CalciteSqlParser; -import org.apache.pinot.sql.parsers.SqlNodeAndOptions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -108,7 +104,7 @@ public class Connection { public ResultSetGroup execute(@Nullable Iterable<String> tableNames, String query) throws PinotClientException { String[] resultTableNames = - (tableNames == null) ? resolveTableName(query) : Iterables.toArray(tableNames, String.class); + (tableNames == null) ? TableNameExtractor.resolveTableName(query) : Iterables.toArray(tableNames, String.class); String brokerHostPort = _brokerSelector.selectBroker(resultTableNames); if (brokerHostPort == null) { throw new PinotClientException("Could not find broker to query " + ((tableNames == null) ? "with no tables" @@ -157,7 +153,7 @@ public class Connection { public CompletableFuture<ResultSetGroup> executeAsync(@Nullable Iterable<String> tableNames, String query) throws PinotClientException { String[] resultTableNames = - (tableNames == null) ? resolveTableName(query) : Iterables.toArray(tableNames, String.class); + (tableNames == null) ? TableNameExtractor.resolveTableName(query) : Iterables.toArray(tableNames, String.class); String brokerHostPort = _brokerSelector.selectBroker(resultTableNames); if (brokerHostPort == null) { throw new PinotClientException("Could not find broker to query for statement: " + query); @@ -165,26 +161,6 @@ public class Connection { return _transport.executeQueryAsync(brokerHostPort, query).thenApply(ResultSetGroup::new); } - /** - * Returns the name of all the tables used in a sql query. - * - * @return name of all the tables used in a sql query. - */ - @Nullable - public static String[] resolveTableName(String query) { - try { - SqlNodeAndOptions sqlNodeAndOptions = CalciteSqlParser.compileToSqlNodeAndOptions(query); - Set<String> tableNames = - RequestUtils.getTableNames(CalciteSqlParser.compileSqlNodeToPinotQuery(sqlNodeAndOptions.getSqlNode())); - if (tableNames != null) { - return tableNames.toArray(new String[0]); - } - } catch (Exception e) { - LOGGER.error("Cannot parse table name from query: {}. Fallback to broker selector default.", query, e); - } - return null; - } - /** * Returns the list of brokers to which this connection can connect to. * diff --git a/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/TableNameExtractor.java b/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/TableNameExtractor.java new file mode 100644 index 0000000000..3987139bd3 --- /dev/null +++ b/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/TableNameExtractor.java @@ -0,0 +1,410 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.client; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import javax.annotation.Nullable; +import org.apache.calcite.sql.SqlBasicCall; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlJoin; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlOrderBy; +import org.apache.calcite.sql.SqlSelect; +import org.apache.calcite.sql.SqlWith; +import org.apache.calcite.sql.SqlWithItem; +import org.apache.pinot.sql.parsers.CalciteSqlParser; +import org.apache.pinot.sql.parsers.SqlNodeAndOptions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import java.util.HashMap; +import java.util.Map; +import java.lang.reflect.Field; +/** + * Helper class to extract table names from Calcite SqlNode tree. + */ +public class TableNameExtractor { + private static final Logger LOGGER = LoggerFactory.getLogger(TableNameExtractor.class); + // Static map of reserved SQL keywords loaded from config file + private static final Map<String, Boolean> RESERVED_KEYWORDS = loadReservedKeywords(); + /** + * Returns the name of all the tables used in a sql query. + * + * @param query The SQL query string to analyze + * @return name of all the tables used in a sql query, or null if parsing fails + */ + @Nullable + public static String[] resolveTableName(String query) { + SqlNodeAndOptions sqlNodeAndOptions; + try { + sqlNodeAndOptions = CalciteSqlParser.compileToSqlNodeAndOptions(query); + } catch (Exception e) { + LOGGER.error("Cannot parse table name from query: {}. Fallback to broker selector default.", query, e); + return null; + } + try { + Set<String> tableNames = extractTableNamesFromMultiStageQuery(sqlNodeAndOptions.getSqlNode()); + if (tableNames != null) { + return tableNames.toArray(new String[0]); + } + } catch (Exception e) { + LOGGER.error("Cannot extract table name from query: {}. Fallback to broker selector default.", query, e); + } + return null; + } + /** + * Extracts table names from a multi-stage query using Calcite SQL AST traversal. + * + * @param sqlNode The root SqlNode of the parsed query + * @return Set of table names found in the query + */ + private static Set<String> extractTableNamesFromMultiStageQuery(SqlNode sqlNode) { + TableNameExtractor extractor = new TableNameExtractor(); + try { + extractor.extractTableNames(sqlNode); + return extractor.getTableNames(); + } catch (Exception e) { + LOGGER.debug("Failed to extract table names from multi-stage query", e); + return Collections.emptySet(); + } + } + private final Set<String> _tableNames = new HashSet<>(); + private final Set<String> _cteNames = new HashSet<>(); + private boolean _inFromClause = false; + public Set<String> getTableNames() { + return _tableNames; + } + public void extractTableNames(SqlNode node) { + if (node == null) { + return; + } + if (node instanceof SqlWith) { + visitWith((SqlWith) node); + } else if (node instanceof SqlOrderBy) { + visitOrderBy((SqlOrderBy) node); + } else if (node instanceof SqlWithItem) { + visitWithItem((SqlWithItem) node); + } else if (node instanceof SqlSelect) { + visitSelect((SqlSelect) node); + } else if (node instanceof SqlJoin) { + visitJoin((SqlJoin) node); + } else if (node instanceof SqlBasicCall) { + visitBasicCall((SqlBasicCall) node); + } else if (node instanceof SqlIdentifier) { + visitIdentifier((SqlIdentifier) node); + } else if (node instanceof SqlNodeList) { + visitNodeList((SqlNodeList) node); + } else { + // Handle unknown node types by trying to access operands + visitUnknownNode(node); + } + } + private void visitWith(SqlWith with) { + // Visit the WITH list (CTE definitions) + if (with.withList != null) { + visitNodeList(with.withList); + } + // Visit the main query body + if (with.body != null) { + extractTableNames(with.body); + } + } + private void visitOrderBy(SqlOrderBy orderBy) { + // Visit the main query - this is the most important part + if (orderBy.query != null) { + extractTableNames(orderBy.query); + } + // Visit ORDER BY expressions for potential subqueries + if (orderBy.orderList != null) { + // Don't set inFromClause=true for ORDER BY expressions + // as they typically contain column references, not table names + visitNodeList(orderBy.orderList); + } + // Visit OFFSET clause if it contains subqueries (rare but possible) + if (orderBy.offset != null) { + extractTableNames(orderBy.offset); + } + // Visit FETCH/LIMIT clause if it contains subqueries (rare but possible) + if (orderBy.fetch != null) { + extractTableNames(orderBy.fetch); + } + } + private void visitWithItem(SqlWithItem withItem) { + // Track the CTE name so we don't treat it as a table later + if (withItem.name != null) { + String cteName = withItem.name.getSimple(); + _cteNames.add(cteName); + } + // Extract table names from the CTE query definition, not the CTE alias + if (withItem.query != null) { + extractTableNames(withItem.query); + } + } + private void visitSelect(SqlSelect select) { + // Visit FROM clause - this is where we expect to find table names + if (select.getFrom() != null) { + _inFromClause = true; + extractTableNames(select.getFrom()); + _inFromClause = false; + } + // Visit other clauses for subqueries + if (select.getWhere() != null) { + extractTableNames(select.getWhere()); + } + if (select.getGroup() != null) { + visitNodeList(select.getGroup()); + } + if (select.getHaving() != null) { + extractTableNames(select.getHaving()); + } + if (select.getOrderList() != null) { + visitNodeList(select.getOrderList()); + } + if (select.getSelectList() != null) { + visitNodeList(select.getSelectList()); + } + } + private void visitJoin(SqlJoin join) { + // Visit both sides of the join - ensure they're processed as FROM clause items + boolean wasInFromClause = _inFromClause; + if (join.getLeft() != null) { + _inFromClause = true; + extractTableNames(join.getLeft()); + } + if (join.getRight() != null) { + _inFromClause = true; + extractTableNames(join.getRight()); + } + // Visit join condition but not as part of FROM clause context + // This handles potential subqueries in join conditions while avoiding + // incorrectly extracting column references as table names + if (join.getCondition() != null) { + _inFromClause = false; + extractTableNames(join.getCondition()); + } + // Restore original context + _inFromClause = wasInFromClause; + } + private void visitBasicCall(SqlBasicCall call) { + String operatorName = call.getOperator().getName().toUpperCase(); + if (operatorName.equals("AS")) { + // Handle table aliases like "tableA AS a" + // For AS operations, the first operand is the actual table name + if (call.getOperandList().size() > 0 && call.getOperandList().get(0) != null) { + extractTableNames(call.getOperandList().get(0)); + } + } else if (operatorName.equals("WITH")) { + // Handle CTE (Common Table Expression) + visitWithClause(call); + } else if (operatorName.equals("VALUES")) { + // Handle VALUES clause - usually doesn't contain table references + // Skip this to avoid false positives + } else { + // For other basic calls, visit all operands + for (SqlNode operand : call.getOperandList()) { + if (operand != null) { + extractTableNames(operand); + } + } + } + } + private void visitIdentifier(SqlIdentifier identifier) { + // Only extract table names when we're in a FROM clause + if (_inFromClause && identifier.names.size() >= 1) { + String tableName = identifier.names.get(identifier.names.size() - 1); + // Filter out SQL keywords, system identifiers, and CTE names + if (!isReservedKeyword(tableName) && !tableName.startsWith("$") && !_cteNames.contains(tableName)) { + _tableNames.add(tableName); + } + } + } + /** + * Visit a SqlNodeList by visiting each node in the list. + */ + private void visitNodeList(SqlNodeList nodeList) { + if (nodeList != null) { + for (SqlNode node : nodeList) { + if (node != null) { + extractTableNames(node); + } + } + } + } + /** + * Handle unknown node types by attempting to visit their operands. + */ + private void visitUnknownNode(SqlNode node) { + try { + // Try to get operands list using reflection or common methods + if (node.getKind() != null) { + switch (node.getKind().name()) { + case "WITH": + visitWithClause(node); + break; + case "ORDER_BY": + visitOrderByCall(node); + break; + default: + // For other unknown nodes, try to visit operands if they exist + visitNodeOperands(node); + break; + } + } else { + visitNodeOperands(node); + } + } catch (Exception e) { + // Ignore reflection errors and continue + } + } + /** + * Handle WITH clause (CTE - Common Table Expression). + */ + private void visitWithClause(SqlNode node) { + try { + // WITH clause typically has operands: [with_list, query] + if (node instanceof SqlBasicCall) { + SqlBasicCall withCall = (SqlBasicCall) node; + for (SqlNode operand : withCall.getOperandList()) { + if (operand != null) { + extractTableNames(operand); + } + } + } + } catch (Exception e) { + // Fallback to generic operand handling + visitNodeOperands(node); + } + } + /** + * Handle ORDER BY clause - this method is now replaced by visitOrderBy(SqlOrderBy). + * Keeping for backward compatibility with visitUnknownNode. + */ + private void visitOrderByCall(SqlNode node) { + try { + if (node instanceof SqlBasicCall) { + SqlBasicCall orderByCall = (SqlBasicCall) node; + // ORDER BY typically has [query, order_list] + for (SqlNode operand : orderByCall.getOperandList()) { + if (operand != null) { + extractTableNames(operand); + } + } + } + } catch (Exception e) { + visitNodeOperands(node); + } + } + /** + * Generic method to visit node operands when specific handling is not available. + */ + private void visitNodeOperands(SqlNode node) { + try { + // Try to access operands through common interface + if (node instanceof SqlBasicCall) { + SqlBasicCall call = (SqlBasicCall) node; + for (SqlNode operand : call.getOperandList()) { + if (operand != null) { + extractTableNames(operand); + } + } + } + } catch (Exception e) { + // Nothing more we can do + } + } + /** + * Check if the given name is a reserved SQL keyword that shouldn't be treated as a table name. + */ + private boolean isReservedKeyword(String name) { + if (name == null) { + return true; + } + String upperName = name.toUpperCase(); + return RESERVED_KEYWORDS.containsKey(upperName); + } + /** + * Load reserved SQL keywords from the SqlParserImplConstants. + * This method uses the generated constants from the parser to get all reserved keywords. + */ + private static Map<String, Boolean> loadReservedKeywords() { + Map<String, Boolean> reservedKeywords = new HashMap<>(); + try { + // Use reflection to access SqlParserImplConstants.tokenImage + Class<?> constantsClass = Class.forName("org.apache.pinot.sql.parsers.parser.SqlParserImplConstants"); + Field tokenImageField = constantsClass.getField("tokenImage"); + String[] tokenImage = (String[]) tokenImageField.get(null); + + // Process each token to extract reserved keywords + for (String token : tokenImage) { + // Skip tokens that are not keywords (like literals, operators, etc.) + if (token.startsWith("\"") && token.endsWith("\"") && !token.startsWith("\"<")) { + // Extract the keyword without quotes + String keyword = token.substring(1, token.length() - 1); + // Skip single character tokens and operators + if (keyword.length() > 1 && !isOperator(keyword)) { + reservedKeywords.put(keyword, true); + } + } + } + LOGGER.debug("Loaded {} reserved keywords from SqlParserImplConstants", reservedKeywords.size()); + } catch (Exception e) { + LOGGER.warn("Failed to load reserved keywords from SqlParserImplConstants, using fallback set", e); + // Fall back to essential reserved keywords + addFallbackReservedKeywords(reservedKeywords); + } + return Collections.unmodifiableMap(reservedKeywords); + } + + /** + * Check if a token is an operator (not a keyword). + */ + private static boolean isOperator(String token) { + // Common SQL operators that should not be treated as reserved keywords + String[] operators = { + "=", ">", "<", ">=", "<=", "<>", "!=", "+", "-", "*", "/", "%", "||", "->", "..", "(", ")", + "{", "}", "[", "]", ";", ".", ",", "?", ":", "|", "^", "$", "/*", "*/", "/*+" + }; + for (String op : operators) { + if (op.equals(token)) { + return true; + } + } + return false; + } + + /** + * Add fallback reserved keywords in case the constants loading fails. + */ + private static void addFallbackReservedKeywords(Map<String, Boolean> reservedKeywords) { + String[] fallbackKeywords = { + "SELECT", "FROM", "WHERE", "GROUP", "ORDER", "BY", "HAVING", "JOIN", "INNER", "LEFT", "RIGHT", "OUTER", + "ON", "AS", "AND", "OR", "NOT", "IN", "LIMIT", "OFFSET", "UNION", "ALL", "DISTINCT", "COUNT", + "SUM", "AVG", "MIN", "MAX", "CASE", "WHEN", "THEN", "ELSE", "END", "CREATE", "DROP", "ALTER", + "INSERT", "UPDATE", "DELETE", "TABLE", "INDEX", "VIEW", "SCHEMA", "DATABASE", "CASCADE", + "RESTRICT", "PRIMARY", "FOREIGN", "KEY", "CONSTRAINT", "UNIQUE", "NULL", "DEFAULT", + "CHECK", "REFERENCES", "SET", "VALUES", "WITH", "RECURSIVE", "EXISTS", "BETWEEN", + "LIKE", "IS", "TRUE", "FALSE", "UNKNOWN", "CAST", "CONVERT", "TRIM", "SUBSTRING", + "UPPER", "LOWER", "LENGTH", "CHAR_LENGTH", "POSITION", "EXTRACT" + }; + for (String keyword : fallbackKeywords) { + reservedKeywords.put(keyword, true); + } + } +} diff --git a/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/grpc/GrpcConnection.java b/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/grpc/GrpcConnection.java index e30b49979e..ee1ccf6b3f 100644 --- a/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/grpc/GrpcConnection.java +++ b/pinot-clients/pinot-java-client/src/main/java/org/apache/pinot/client/grpc/GrpcConnection.java @@ -32,10 +32,10 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import org.apache.pinot.client.BrokerResponse; import org.apache.pinot.client.BrokerSelector; -import org.apache.pinot.client.Connection; import org.apache.pinot.client.PinotClientException; import org.apache.pinot.client.ResultSetGroup; import org.apache.pinot.client.SimpleBrokerSelector; +import org.apache.pinot.client.TableNameExtractor; import org.apache.pinot.common.config.GrpcConfig; import org.apache.pinot.common.proto.Broker; import org.apache.pinot.common.utils.DataSchema; @@ -233,7 +233,7 @@ public class GrpcConnection implements AutoCloseable { */ public Iterator<Broker.BrokerResponse> executeWithIterator(String query, Map<String, String> metadata) throws PinotClientException { - String[] tableNames = Connection.resolveTableName(query); + String[] tableNames = TableNameExtractor.resolveTableName(query); String brokerHostPort = _brokerSelector.selectBroker(tableNames); if (brokerHostPort == null) { throw new PinotClientException("Could not find broker to query " + ((tableNames == null) ? "with no tables" diff --git a/pinot-clients/pinot-java-client/src/test/java/org/apache/pinot/client/TableNameExtractorTest.java b/pinot-clients/pinot-java-client/src/test/java/org/apache/pinot/client/TableNameExtractorTest.java new file mode 100644 index 0000000000..5319d78610 --- /dev/null +++ b/pinot-clients/pinot-java-client/src/test/java/org/apache/pinot/client/TableNameExtractorTest.java @@ -0,0 +1,291 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.client; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertTrue; + + +/** + * Tests for the TableNameExtractor class. + */ +public class TableNameExtractorTest { + + @Test + public void testResolveTableNameWithSingleQuery() { + // Test that single queries work correctly + String singleQuery = "SELECT * FROM myTable WHERE id > 100"; + + String[] tableNames = TableNameExtractor.resolveTableName(singleQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 1, "Should resolve exactly one table"); + assertEquals(tableNames[0], "myTable", "Should resolve the correct table name"); + } + + @Test + public void testResolveTableNameWithSingleStatementAlias() { + String singleStatementQuery = "SELECT stats.* FROM airlineStats stats LIMIT 10"; + String[] tableNames = TableNameExtractor.resolveTableName(singleStatementQuery); + + assertNotNull(tableNames); + assertEquals(tableNames.length, 1); + assertEquals(tableNames[0], "airlineStats"); + } + + @Test + public void testResolveTableNameWithMultiStatementQuery() { + // Test the fix for issue #11823: CalciteSQLParser error with multi-statement queries + String multiStatementQuery = "SET useMultistageEngine=true;\nSELECT stats.* FROM airlineStats stats LIMIT 10"; + + // This should not throw a ClassCastException anymore + String[] tableNames = TableNameExtractor.resolveTableName(multiStatementQuery); + + // Should successfully resolve the table name + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 1, "Should resolve exactly one table"); + assertEquals(tableNames[0], "airlineStats", "Should resolve the correct table name"); + } + + @Test + public void testResolveTableNameWithMultipleSetStatements() { + // Test with multiple SET statements + String multiSetQuery = "SET useMultistageEngine=true;\nSET timeoutMs=10000;\nSELECT * FROM testTable"; + + String[] tableNames = TableNameExtractor.resolveTableName(multiSetQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 1, "Should resolve exactly one table"); + assertEquals(tableNames[0], "testTable", "Should resolve the correct table name"); + } + + @Test + public void testResolveTableNameWithMultipleSetStatementsAndJoin() { + String multiStatementQuery = "SET useMultistageEngine=true;\nSET maxRowsInJoin=1000;\n" + + "SELECT stats.* FROM airlineStats stats LIMIT 10"; + String[] tableNames = TableNameExtractor.resolveTableName(multiStatementQuery); + + assertNotNull(tableNames, "Table names should be resolved for queries with multiple SET statements"); + assertEquals(tableNames.length, 1); + assertEquals(tableNames[0], "airlineStats"); + } + + @Test + public void testResolveTableNameWithJoin() { + // Test with JOIN queries + String joinQuery = "SELECT * FROM table1 t1 JOIN table2 t2 ON t1.id = t2.id"; + + String[] tableNames = TableNameExtractor.resolveTableName(joinQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 2, "Should resolve two tables"); + assertTrue(Arrays.asList(tableNames).contains("table1"), "Should contain table1"); + assertTrue(Arrays.asList(tableNames).contains("table2"), "Should contain table2"); + } + + @Test + public void testResolveTableNameWithJoinQueryAndSetStatements() { + String joinQuery = "SET useMultistageEngine=true;\n" + + "SELECT a.col1, b.col2 FROM tableA a JOIN tableB b ON a.id = b.id"; + String[] tableNames = TableNameExtractor.resolveTableName(joinQuery); + + assertNotNull(tableNames, "Table names should be resolved for join queries with SET statements"); + assertEquals(tableNames.length, 2); + + Set<String> expectedTableNames = new HashSet<>(Arrays.asList("tableA", "tableB")); + Set<String> actualTableNames = new HashSet<>(Arrays.asList(tableNames)); + assertEquals(actualTableNames, expectedTableNames); + } + + @Test + public void testResolveTableNameWithExplicitAlias() { + // Test with explicit AS alias + String aliasQuery = "SELECT u.name FROM users AS u WHERE u.active = true"; + + String[] tableNames = TableNameExtractor.resolveTableName(aliasQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 1, "Should resolve exactly one table"); + assertEquals(tableNames[0], "users", "Should resolve the actual table name, not the alias"); + } + + @Test + public void testResolveTableNameWithImplicitAlias() { + // Test with implicit alias (no AS keyword) + String implicitAliasQuery = "SELECT o.id, u.name FROM orders o JOIN users u ON o.user_id = u.id"; + + String[] tableNames = TableNameExtractor.resolveTableName(implicitAliasQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 2, "Should resolve two tables"); + assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table"); + assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table"); + } + + @Test + public void testResolveTableNameWithCTE() { + // Test with Common Table Expression (CTE) + String cteQuery = "WITH active_users AS (SELECT * FROM users WHERE active = true) " + + "SELECT au.name FROM active_users au JOIN orders o ON au.id = o.user_id"; + + String[] tableNames = TableNameExtractor.resolveTableName(cteQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 2, "Should resolve two tables"); + assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table from CTE"); + assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table"); + } + + @Test + public void testResolveTableNameWithNestedCTE() { + // Test with nested CTEs + String nestedCteQuery = "WITH user_orders AS (" + + " SELECT u.id, u.name, o.order_date " + + " FROM users u JOIN orders o ON u.id = o.user_id" + + "), recent_orders AS (" + + " SELECT * FROM user_orders WHERE order_date > '2023-01-01'" + + ") " + + "SELECT ro.name FROM recent_orders ro JOIN products p ON ro.id = p.user_id"; + + String[] tableNames = TableNameExtractor.resolveTableName(nestedCteQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 3, "Should resolve three tables"); + assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table"); + assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table"); + assertTrue(Arrays.asList(tableNames).contains("products"), "Should contain products table"); + } + + @Test + public void testResolveTableNameWithSubqueryAlias() { + // Test with subquery alias + String subqueryQuery = "SELECT t.name FROM (SELECT * FROM users WHERE active = true) AS t " + + "JOIN orders o ON t.id = o.user_id"; + + String[] tableNames = TableNameExtractor.resolveTableName(subqueryQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 2, "Should resolve two tables"); + assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table from subquery"); + assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table"); + } + + @Test + public void testResolveTableNameWithComplexJoinAndAliases() { + // Test with multiple JOINs and various alias styles + String complexQuery = "SELECT u.name, o.total, p.title " + + "FROM users AS u " + + "INNER JOIN orders o ON u.id = o.user_id " + + "LEFT JOIN order_items oi ON o.id = oi.order_id " + + "RIGHT JOIN products AS p ON oi.product_id = p.id " + + "WHERE u.active = true"; + + String[] tableNames = TableNameExtractor.resolveTableName(complexQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 4, "Should resolve four tables"); + assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table"); + assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table"); + assertTrue(Arrays.asList(tableNames).contains("order_items"), "Should contain order_items table"); + assertTrue(Arrays.asList(tableNames).contains("products"), "Should contain products table"); + } + + @Test + public void testResolveTableNameWithJoinConditionSubquery() { + // Test with subquery in join condition + String joinSubqueryQuery = "SELECT u.name, o.total " + + "FROM users u " + + "JOIN orders o ON u.id = o.user_id " + + "AND o.id IN (SELECT order_id FROM order_items WHERE quantity > 5)"; + + String[] tableNames = TableNameExtractor.resolveTableName(joinSubqueryQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 3, "Should resolve three tables"); + assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table"); + assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table"); + assertTrue(Arrays.asList(tableNames).contains("order_items"), + "Should contain order_items table from subquery"); + } + + @Test + public void testResolveTableNameWithOrderBy() { + // Test with ORDER BY clause + String orderByQuery = "SELECT * FROM users ORDER BY name"; + String[] tableNames = TableNameExtractor.resolveTableName(orderByQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 1, "Should resolve exactly one table"); + assertEquals(tableNames[0], "users", "Should resolve the correct table name"); + } + + @Test + public void testResolveTableNameWithOrderBySubquery() { + // Test with subquery in ORDER BY clause (rare but possible) + String orderBySubqueryQuery = "SELECT * FROM users u ORDER BY " + + "(SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id)"; + String[] tableNames = TableNameExtractor.resolveTableName(orderBySubqueryQuery); + + assertNotNull(tableNames, "Table names should not be null"); + assertEquals(tableNames.length, 2, "Should resolve two tables"); + assertTrue(Arrays.asList(tableNames).contains("users"), "Should contain users table"); + assertTrue(Arrays.asList(tableNames).contains("orders"), "Should contain orders table"); + } + + @Test + public void testResolveTableNameWithInvalidQuery() { + String invalidQuery = "INVALID SQL QUERY"; + String[] tableNames = TableNameExtractor.resolveTableName(invalidQuery); + + // Should return null when query cannot be parsed (fallback to default broker selector) + assertNull(tableNames); + } + + @Test + public void testResolveTableNameWithOnlySetStatements() { + String onlySetQuery = "SET useMultistageEngine=true;"; + String[] tableNames = TableNameExtractor.resolveTableName(onlySetQuery); + + // Should return null when there's no actual query statement + assertNull(tableNames); + } + + @Test + public void testResolveTableNameWithNullQuery() { + String[] tableNames = TableNameExtractor.resolveTableName(null); + + // Should return null when query is null + assertNull(tableNames); + } + + @Test + public void testResolveTableNameWithEmptyQuery() { + String[] tableNames = TableNameExtractor.resolveTableName(""); + + // Should return null when query is empty + assertNull(tableNames); + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
