This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new b2e8f2505b GH-47711: [C++][FlightRPC] Enable ODBC query execution (#48032)
b2e8f2505b is described below

commit b2e8f2505ba3eafe65a78ece6ae87fa7d0c1c133
Author: Alina (Xi) Li <[email protected]>
AuthorDate: Thu Dec 4 01:29:34 2025 -0800

    GH-47711: [C++][FlightRPC] Enable ODBC query execution (#48032)
    
    ### Rationale for this change
    Enable query execution in the Arrow Flight SQL ODBC driver.
    
    ### What changes are included in this PR?
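    - Extract the SQLExecDirect, SQLExecute, and SQLPrepare implementations and their tests
    - Close any open prepared statement when a FlightSqlStatement is destroyed

    ### Are these changes tested?
    - Yes: new tests in statement_test.cc, run locally with an MSVC build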
    
    ### Are there any user-facing changes?
    N/A
    
    * GitHub Issue: #47711
    
    Authored-by: Alina (Xi) Li <[email protected]>
    Signed-off-by: David Li <[email protected]>
---
 cpp/src/arrow/flight/sql/odbc/odbc_api.cc          | 39 ++++++++--
 .../sql/odbc/odbc_impl/flight_sql_statement.cc     |  5 +-
 .../sql/odbc/odbc_impl/flight_sql_statement.h      |  1 +
 .../arrow/flight/sql/odbc/tests/statement_test.cc  | 85 +++++++++++++++++++++-
 4 files changed, 119 insertions(+), 11 deletions(-)

diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc 
b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc
index dee5f934fb..76d0024680 100644
--- a/cpp/src/arrow/flight/sql/odbc/odbc_api.cc
+++ b/cpp/src/arrow/flight/sql/odbc/odbc_api.cc
@@ -1005,22 +1005,49 @@ SQLRETURN SQLExecDirect(SQLHSTMT stmt, SQLWCHAR* 
query_text, SQLINTEGER text_len
   ARROW_LOG(DEBUG) << "SQLExecDirectW called with stmt: " << stmt
                    << ", query_text: " << static_cast<const void*>(query_text)
                    << ", text_length: " << text_length;
-  // GH-47711 TODO: Implement SQLExecDirect
-  return SQL_INVALID_HANDLE;
+
+  using ODBC::ODBCStatement;
+  // The driver is built to handle SELECT statements only.
+  return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
+    ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
+    std::string query = ODBC::SqlWcharToString(query_text, text_length);
+
+    statement->Prepare(query);
+    statement->ExecutePrepared();
+
+    return SQL_SUCCESS;
+  });
 }
 
 SQLRETURN SQLPrepare(SQLHSTMT stmt, SQLWCHAR* query_text, SQLINTEGER 
text_length) {
   ARROW_LOG(DEBUG) << "SQLPrepareW called with stmt: " << stmt
                    << ", query_text: " << static_cast<const void*>(query_text)
                    << ", text_length: " << text_length;
-  // GH-47712 TODO: Implement SQLPrepare
-  return SQL_INVALID_HANDLE;
+
+  using ODBC::ODBCStatement;
+  // The driver is built to handle SELECT statements only.
+  return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
+    ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
+    std::string query = ODBC::SqlWcharToString(query_text, text_length);
+
+    statement->Prepare(query);
+
+    return SQL_SUCCESS;
+  });
 }
 
 SQLRETURN SQLExecute(SQLHSTMT stmt) {
   ARROW_LOG(DEBUG) << "SQLExecute called with stmt: " << stmt;
-  // GH-47712 TODO: Implement SQLExecute
-  return SQL_INVALID_HANDLE;
+
+  using ODBC::ODBCStatement;
+  // The driver is built to handle SELECT statements only.
+  return ODBCStatement::ExecuteWithDiagnostics(stmt, SQL_ERROR, [=]() {
+    ODBCStatement* statement = reinterpret_cast<ODBCStatement*>(stmt);
+
+    statement->ExecutePrepared();
+
+    return SQL_SUCCESS;
+  });
 }
 
 SQLRETURN SQLFetch(SQLHSTMT stmt) {
diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc 
b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc
index 785a04c7b0..f6c6da860d 100644
--- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc
+++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.cc
@@ -69,6 +69,10 @@ FlightSqlStatement::FlightSqlStatement(const Diagnostics& 
diagnostics,
   call_options_.timeout = TimeoutDuration{-1};
 }
 
+FlightSqlStatement::~FlightSqlStatement() {
+  ClosePreparedStatementIfAny(prepared_statement_, call_options_);
+}
+
 bool FlightSqlStatement::SetAttribute(StatementAttributeId attribute,
                                       const Attribute& value) {
   switch (attribute) {
@@ -119,7 +123,6 @@ bool FlightSqlStatement::ExecutePrepared() {
 
   Result<std::shared_ptr<FlightInfo>> result =
       prepared_statement_->Execute(call_options_);
-
   ThrowIfNotOK(result.status());
 
   current_result_set_ = std::make_shared<FlightSqlResultSet>(
diff --git a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h 
b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h
index 3593b2f774..d61f8ef378 100644
--- a/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h
+++ b/cpp/src/arrow/flight/sql/odbc/odbc_impl/flight_sql_statement.h
@@ -49,6 +49,7 @@ class FlightSqlStatement : public Statement {
   FlightSqlStatement(const Diagnostics& diagnostics, FlightSqlClient& 
sql_client,
                      FlightClientOptions client_options, FlightCallOptions 
call_options,
                      const MetadataSettings& metadata_settings);
+  ~FlightSqlStatement();
 
   bool SetAttribute(StatementAttributeId attribute, const Attribute& value) 
override;
 
diff --git a/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc 
b/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc
index 9d6d42c4a1..a83855c218 100644
--- a/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc
+++ b/cpp/src/arrow/flight/sql/odbc/tests/statement_test.cc
@@ -37,9 +37,86 @@ class StatementRemoteTest : public 
FlightSQLODBCRemoteTestBase {};
 using TestTypes = ::testing::Types<StatementMockTest, StatementRemoteTest>;
 TYPED_TEST_SUITE(StatementTest, TestTypes);
 
+TYPED_TEST(StatementTest, TestSQLExecDirectSimpleQuery) {
+  std::wstring wsql = L"SELECT 1;";
+  std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
+
+  ASSERT_EQ(SQL_SUCCESS,
+            SQLExecDirect(this->stmt, &sql0[0], 
static_cast<SQLINTEGER>(sql0.size())));
+
+  // GH-47713 TODO: Uncomment call to SQLFetch SQLGetData after implementation
+  /*
+  ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));
+
+  SQLINTEGER val;
+
+  ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
+  // Verify 1 is returned
+  EXPECT_EQ(1, val);
+
+  ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt));
+
+  ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
+  // Invalid cursor state
+  VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000);
+  */
+}
+
+TYPED_TEST(StatementTest, TestSQLExecDirectInvalidQuery) {
+  std::wstring wsql = L"SELECT;";
+  std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
+
+  ASSERT_EQ(SQL_ERROR,
+            SQLExecDirect(this->stmt, &sql0[0], 
static_cast<SQLINTEGER>(sql0.size())));
+  // ODBC provides generic error code HY000 to all statement errors
+  VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY000);
+}
+
+TYPED_TEST(StatementTest, TestSQLExecuteSimpleQuery) {
+  std::wstring wsql = L"SELECT 1;";
+  std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
+
+  ASSERT_EQ(SQL_SUCCESS,
+            SQLPrepare(this->stmt, &sql0[0], 
static_cast<SQLINTEGER>(sql0.size())));
+
+  ASSERT_EQ(SQL_SUCCESS, SQLExecute(this->stmt));
+
+  // GH-47713 TODO: Uncomment call to SQLFetch SQLGetData after implementation
+  /*
+  // Fetch data
+  ASSERT_EQ(SQL_SUCCESS, SQLFetch(this->stmt));
+
+  SQLINTEGER val;
+  ASSERT_EQ(SQL_SUCCESS, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
+
+  // Verify 1 is returned
+  EXPECT_EQ(1, val);
+
+  ASSERT_EQ(SQL_NO_DATA, SQLFetch(this->stmt));
+
+  ASSERT_EQ(SQL_ERROR, SQLGetData(this->stmt, 1, SQL_C_LONG, &val, 0, 0));
+  // Invalid cursor state
+  VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorState24000);
+  */
+}
+
+TYPED_TEST(StatementTest, TestSQLPrepareInvalidQuery) {
+  std::wstring wsql = L"SELECT;";
+  std::vector<SQLWCHAR> sql0(wsql.begin(), wsql.end());
+
+  ASSERT_EQ(SQL_ERROR,
+            SQLPrepare(this->stmt, &sql0[0], 
static_cast<SQLINTEGER>(sql0.size())));
+  // ODBC provides generic error code HY000 to all statement errors
+  VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY000);
+
+  ASSERT_EQ(SQL_ERROR, SQLExecute(this->stmt));
+  // Verify function sequence error state is returned
+  VerifyOdbcErrorState(SQL_HANDLE_STMT, this->stmt, kErrorStateHY010);
+}
+
 TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsInputString) {
   SQLWCHAR buf[1024];
-  SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize();
+  SQLINTEGER buf_char_len = sizeof(buf) / GetSqlWCharSize();
   SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1";
   SQLINTEGER input_char_len = static_cast<SQLINTEGER>(wcslen(input_str));
   SQLINTEGER output_char_len = 0;
@@ -58,7 +135,7 @@ TYPED_TEST(StatementTest, 
TestSQLNativeSqlReturnsInputString) {
 
 TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsNTSInputString) {
   SQLWCHAR buf[1024];
-  SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize();
+  SQLINTEGER buf_char_len = sizeof(buf) / GetSqlWCharSize();
   SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1";
   SQLINTEGER input_char_len = static_cast<SQLINTEGER>(wcslen(input_str));
   SQLINTEGER output_char_len = 0;
@@ -95,7 +172,7 @@ TYPED_TEST(StatementTest, 
TestSQLNativeSqlReturnsInputStringLength) {
 TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsTruncatedString) {
   const SQLINTEGER small_buf_size_in_char = 11;
   SQLWCHAR small_buf[small_buf_size_in_char];
-  SQLINTEGER small_buf_char_len = sizeof(small_buf) / ODBC::GetSqlWCharSize();
+  SQLINTEGER small_buf_char_len = sizeof(small_buf) / GetSqlWCharSize();
   SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1";
   SQLINTEGER input_char_len = static_cast<SQLINTEGER>(wcslen(input_str));
   SQLINTEGER output_char_len = 0;
@@ -122,7 +199,7 @@ TYPED_TEST(StatementTest, 
TestSQLNativeSqlReturnsTruncatedString) {
 
 TYPED_TEST(StatementTest, TestSQLNativeSqlReturnsErrorOnBadInputs) {
   SQLWCHAR buf[1024];
-  SQLINTEGER buf_char_len = sizeof(buf) / ODBC::GetSqlWCharSize();
+  SQLINTEGER buf_char_len = sizeof(buf) / GetSqlWCharSize();
   SQLWCHAR input_str[] = L"SELECT * FROM mytable WHERE id == 1";
   SQLINTEGER input_char_len = static_cast<SQLINTEGER>(wcslen(input_str));
   SQLINTEGER output_char_len = 0;

Reply via email to