This is an automated email from the ASF dual-hosted git repository.

lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new c2fd239e4 fix(c/driver/postgresql): avoid crash if closing invalidated result (#2653)
c2fd239e4 is described below

commit c2fd239e4396e1a7013e74d23be7f8b84c0d5fe7
Author: David Li <[email protected]>
AuthorDate: Thu Mar 27 23:44:50 2025 -0400

    fix(c/driver/postgresql): avoid crash if closing invalidated result (#2653)
    
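    The driver does not prevent you from closing a statement while a result
    set (ArrowArrayStream) is still open, and closing that result set
    afterwards would crash because it still pointed at the freed statement
    state. Avoid this by having the statement own its TupleReader through a
    shared_ptr and having the exported stream hold only a weak_ptr to it:
    once the statement is gone, get_schema/get_next fail with EINVAL and
    release() only frees the wrapper.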
    
    Fixes #2629.
---
 c/driver/postgresql/statement.cc                   | 91 +++++++++++++++-------
 c/driver/postgresql/statement.h                    | 12 ++-
 c/validation/adbc_validation.h                     |  2 +
 c/validation/adbc_validation_statement.cc          | 38 ++++++++-
 .../adbc_driver_manager/dbapi.py                   | 15 ++--
 5 files changed, 117 insertions(+), 41 deletions(-)

diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 11a01d5be..a83bca13f 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -30,6 +30,7 @@
 #include <limits>
 #include <memory>
 #include <string>
+#include <string_view>
 #include <utility>
 #include <vector>
 
@@ -219,55 +220,89 @@ void TupleReader::Release() {
   row_id_ = -1;
 }
 
+// Instead of directly exporting the TupleReader, which is tied to the
+// lifetime of the Statement, we export a weak_ptr reference instead.  That
+// way if the user accidentally closes the Statement before the
+// ArrowArrayStream, we can avoid a crash.
+// See https://github.com/apache/arrow-adbc/issues/2629
+struct ExportedTupleReader {
+  std::weak_ptr<TupleReader> self;
+};
+
 void TupleReader::ExportTo(struct ArrowArrayStream* stream) {
   stream->get_schema = &GetSchemaTrampoline;
   stream->get_next = &GetNextTrampoline;
   stream->get_last_error = &GetLastErrorTrampoline;
   stream->release = &ReleaseTrampoline;
-  stream->private_data = this;
+  stream->private_data = new ExportedTupleReader{weak_from_this()};
 }
 
-const struct AdbcError* TupleReader::ErrorFromArrayStream(struct ArrowArrayStream* stream,
+const struct AdbcError* TupleReader::ErrorFromArrayStream(struct ArrowArrayStream* self,
                                                           AdbcStatusCode* status) {
-  if (!stream->private_data || stream->release != &ReleaseTrampoline) {
+  if (!self->private_data || self->release != &ReleaseTrampoline) {
     return nullptr;
   }
 
-  TupleReader* reader = static_cast<TupleReader*>(stream->private_data);
-  if (status) {
-    *status = reader->status_;
+  auto* wrapper = static_cast<ExportedTupleReader*>(self->private_data);
+  auto maybe_reader = wrapper->self.lock();
+  if (maybe_reader) {
+    if (status) {
+      *status = maybe_reader->status_;
+    }
+    return &maybe_reader->error_;
   }
-  return &reader->error_;
+  return nullptr;
 }
 
 int TupleReader::GetSchemaTrampoline(struct ArrowArrayStream* self,
                                      struct ArrowSchema* out) {
   if (!self || !self->private_data) return EINVAL;
 
-  TupleReader* reader = static_cast<TupleReader*>(self->private_data);
-  return reader->GetSchema(out);
+  auto* wrapper = static_cast<ExportedTupleReader*>(self->private_data);
+  auto maybe_reader = wrapper->self.lock();
+  if (maybe_reader) {
+    return maybe_reader->GetSchema(out);
+  }
+  // statement was closed or reader was otherwise invalidated
+  return EINVAL;
 }
 
 int TupleReader::GetNextTrampoline(struct ArrowArrayStream* self,
                                    struct ArrowArray* out) {
   if (!self || !self->private_data) return EINVAL;
 
-  TupleReader* reader = static_cast<TupleReader*>(self->private_data);
-  return reader->GetNext(out);
+  auto* wrapper = static_cast<ExportedTupleReader*>(self->private_data);
+  auto maybe_reader = wrapper->self.lock();
+  if (maybe_reader) {
+    return maybe_reader->GetNext(out);
+  }
+  // statement was closed or reader was otherwise invalidated
+  return EINVAL;
 }
 
 const char* TupleReader::GetLastErrorTrampoline(struct ArrowArrayStream* self) {
   if (!self || !self->private_data) return nullptr;
+  constexpr std::string_view kReaderInvalidated =
+      "[libpq] Reader invalidated (statement or reader was closed)";
 
-  TupleReader* reader = static_cast<TupleReader*>(self->private_data);
-  return reader->last_error();
+  auto* wrapper = static_cast<ExportedTupleReader*>(self->private_data);
+  auto maybe_reader = wrapper->self.lock();
+  if (maybe_reader) {
+    return maybe_reader->last_error();
+  }
+  // statement was closed or reader was otherwise invalidated
+  return kReaderInvalidated.data();
 }
 
 void TupleReader::ReleaseTrampoline(struct ArrowArrayStream* self) {
   if (!self || !self->private_data) return;
 
-  TupleReader* reader = static_cast<TupleReader*>(self->private_data);
-  reader->Release();
+  auto* wrapper = static_cast<ExportedTupleReader*>(self->private_data);
+  auto maybe_reader = wrapper->self.lock();
+  if (maybe_reader) {
+    maybe_reader->Release();
+  }
+  delete wrapper;
   self->private_data = nullptr;
   self->release = nullptr;
 }
@@ -281,7 +316,7 @@ AdbcStatusCode PostgresStatement::New(struct AdbcConnection* connection,
   connection_ =
       *reinterpret_cast<std::shared_ptr<PostgresConnection>*>(connection->private_data);
   type_resolver_ = connection_->type_resolver();
-  reader_.conn_ = connection_->conn();
+  ClearResult();
   return ADBC_STATUS_OK;
 }
 
@@ -514,24 +549,24 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream,
   }
 
   struct ArrowError na_error;
-  reader_.copy_reader_ = std::make_unique<PostgresCopyStreamReader>();
-  CHECK_NA(INTERNAL, reader_.copy_reader_->Init(root_type), error);
+  reader_->copy_reader_ = std::make_unique<PostgresCopyStreamReader>();
+  CHECK_NA(INTERNAL, reader_->copy_reader_->Init(root_type), error);
   CHECK_NA_DETAIL(INTERNAL,
-                  reader_.copy_reader_->InferOutputSchema(
+                  reader_->copy_reader_->InferOutputSchema(
                       std::string(connection_->VendorName()), &na_error),
                   &na_error, error);
 
-  CHECK_NA_DETAIL(INTERNAL, reader_.copy_reader_->InitFieldReaders(&na_error), &na_error,
+  CHECK_NA_DETAIL(INTERNAL, reader_->copy_reader_->InitFieldReaders(&na_error), &na_error,
                   error);
 
   // Execute the COPY query
   RAISE_STATUS(error, helper.ExecuteCopy());
 
   // We need the PQresult back for the reader
-  reader_.result_ = helper.ReleaseResult();
+  reader_->result_ = helper.ReleaseResult();
 
   // Export to stream
-  reader_.ExportTo(stream);
+  reader_->ExportTo(stream);
   if (rows_affected) *rows_affected = -1;
   return ADBC_STATUS_OK;
 }
@@ -674,7 +709,7 @@ AdbcStatusCode PostgresStatement::GetOption(const char* key, char* value, size_t
         break;
     }
   } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) {
-    result = std::to_string(reader_.batch_size_hint_bytes_);
+    result = std::to_string(reader_->batch_size_hint_bytes_);
   } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_USE_COPY) == 0) {
     if (UseCopy()) {
       result = "true";
@@ -710,7 +745,7 @@ AdbcStatusCode PostgresStatement::GetOptionInt(const char* key, int64_t* value,
                                                struct AdbcError* error) {
   std::string result;
   if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) {
-    *value = reader_.batch_size_hint_bytes_;
+    *value = reader_->batch_size_hint_bytes_;
     return ADBC_STATUS_OK;
   }
   SetError(error, "[libpq] Unknown statement option '%s'", key);
@@ -799,7 +834,7 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value,
       return ADBC_STATUS_INVALID_ARGUMENT;
     }
 
-    this->reader_.batch_size_hint_bytes_ = int_value;
+    this->batch_size_hint_bytes_ = this->reader_->batch_size_hint_bytes_ = int_value;
   } else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_USE_COPY) == 0) {
     if (std::strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) {
       use_copy_ = true;
@@ -836,7 +871,7 @@ AdbcStatusCode PostgresStatement::SetOptionInt(const char* key, int64_t value,
       return ADBC_STATUS_INVALID_ARGUMENT;
     }
 
-    this->reader_.batch_size_hint_bytes_ = value;
+    this->batch_size_hint_bytes_ = this->reader_->batch_size_hint_bytes_ = value;
     return ADBC_STATUS_OK;
   }
   SetError(error, "[libpq] Unknown statement option '%s'", key);
@@ -845,7 +880,9 @@ AdbcStatusCode PostgresStatement::SetOptionInt(const char* key, int64_t value,
 
 void PostgresStatement::ClearResult() {
   // TODO: we may want to synchronize here for safety
-  reader_.Release();
+  if (reader_) reader_->Release();
+  reader_ = std::make_shared<TupleReader>(connection_->conn());
+  reader_->batch_size_hint_bytes_ = batch_size_hint_bytes_;
 }
 
 int PostgresStatement::UseCopy() {
diff --git a/c/driver/postgresql/statement.h b/c/driver/postgresql/statement.h
index 60ada992b..a2c3f5e88 100644
--- a/c/driver/postgresql/statement.h
+++ b/c/driver/postgresql/statement.h
@@ -39,8 +39,10 @@ namespace adbcpq {
 class PostgresConnection;
 class PostgresStatement;
 
+constexpr static int64_t kDefaultBatchSizeHintBytes = 16777216;
+
 /// \brief An ArrowArrayStream that reads tuples from a PGresult.
-class TupleReader final {
+class TupleReader final : public std::enable_shared_from_this<TupleReader> {
  public:
   TupleReader(PGconn* conn)
       : status_(ADBC_STATUS_OK),
@@ -50,7 +52,7 @@ class TupleReader final {
         pgbuf_(nullptr),
         copy_reader_(nullptr),
         row_id_(-1),
-        batch_size_hint_bytes_(16777216),
+        batch_size_hint_bytes_(kDefaultBatchSizeHintBytes),
         is_finished_(false) {
     ArrowErrorInit(&na_error_);
     data_.data.as_char = nullptr;
@@ -98,7 +100,8 @@ class PostgresStatement {
         query_(),
         prepared_(false),
         use_copy_(-1),
-        reader_(nullptr) {
+        reader_(nullptr),
+        batch_size_hint_bytes_(kDefaultBatchSizeHintBytes) {
     std::memset(&bind_, 0, sizeof(bind_));
   }
 
@@ -170,7 +173,8 @@ class PostgresStatement {
     bool temporary = false;
   } ingest_;
 
-  TupleReader reader_;
+  std::shared_ptr<TupleReader> reader_;
+  int64_t batch_size_hint_bytes_;
 
   int UseCopy();
 };
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 427e39b2e..fad84137e 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -463,6 +463,7 @@ class StatementTest {
 
   void TestConcurrentStatements();
   void TestErrorCompatibility();
+  void TestResultIndependence();
   void TestResultInvalidation();
 
  protected:
@@ -579,6 +580,7 @@ void StatementTest::TestSqlIngestType(ArrowType type,
   TEST_F(FIXTURE, Transactions) { TestTransactions(); }                        
         \
   TEST_F(FIXTURE, ConcurrentStatements) { TestConcurrentStatements(); }        
         \
   TEST_F(FIXTURE, ErrorCompatibility) { TestErrorCompatibility(); }            
         \
+  TEST_F(FIXTURE, ResultIndependence) { TestResultIndependence(); }            
         \
   TEST_F(FIXTURE, ResultInvalidation) { TestResultInvalidation(); }
 
 }  // namespace adbc_validation
diff --git a/c/validation/adbc_validation_statement.cc b/c/validation/adbc_validation_statement.cc
index adb7aacb2..24765f163 100644
--- a/c/validation/adbc_validation_statement.cc
+++ b/c/validation/adbc_validation_statement.cc
@@ -51,8 +51,12 @@ void StatementTest::TearDownTest() {
   if (statement.private_data) {
     EXPECT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
   }
-  EXPECT_THAT(AdbcConnectionRelease(&connection, &error), IsOkStatus(&error));
-  EXPECT_THAT(AdbcDatabaseRelease(&database, &error), IsOkStatus(&error));
+  if (connection.private_data) {
+    EXPECT_THAT(AdbcConnectionRelease(&connection, &error), IsOkStatus(&error));
+  }
+  if (database.private_data) {
+    EXPECT_THAT(AdbcDatabaseRelease(&database, &error), IsOkStatus(&error));
+  }
   if (error.release) {
     error.release(&error);
   }
@@ -2839,6 +2843,35 @@ void StatementTest::TestErrorCompatibility() {
   error.release(&error);
 }
 
+void StatementTest::TestResultIndependence() {
+  // If we have a result reader, and we close the statement (and other
+  // resources), either the statement should error, or the reader should be
+  // closeable and should error on other operations
+
+  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+  ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error),
+              IsOkStatus(&error));
+
+  StreamReader reader;
+  ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+                                        &reader.rows_affected, &error),
+              IsOkStatus(&error));
+  ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+
+  auto status = AdbcStatementRelease(&statement, &error);
+  if (status != ADBC_STATUS_OK) {
+    // That's ok, this driver prevents closing the statement while readers are open
+    return;
+  }
+  ASSERT_THAT(AdbcConnectionRelease(&connection, &error), IsOkStatus(&error));
+  ASSERT_THAT(AdbcDatabaseRelease(&database, &error), IsOkStatus(&error));
+
+  // Must not crash (but it's up to the driver whether it errors or succeeds)
+  std::ignore = reader.MaybeNext();
+  // Implicitly StreamReader calls release() on destruction, that should not
+  // crash either
+}
+
 void StatementTest::TestResultInvalidation() {
   // Start reading from a statement, then overwrite it
   ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
@@ -2860,4 +2893,5 @@ void StatementTest::TestResultInvalidation() {
   // First reader may fail, or may succeed but give no data
   reader1.MaybeNext();
 }
+
 }  // namespace adbc_validation
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 679cc1bff..acb0136cc 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -415,14 +415,13 @@ class Connection(_Closeable):
         reader = pyarrow.RecordBatchReader._import_from_c(handle.address)
         table = _blocking_call(reader.read_all, (), {}, self._conn.cancel)
         info = table.to_pylist()
-        return dict(
-            {
-                _KNOWN_INFO_VALUES.get(row["info_name"], row["info_name"]): row[
-                    "info_value"
-                ]
-                for row in info
-            }
-        )
+        # try to help the type checker a bit here
+        result: Dict[Union[str, int], Any] = {}
+        for row in info:
+            info_name: int = row["info_name"]
+            key: Union[str, int] = _KNOWN_INFO_VALUES.get(info_name, info_name)
+            result[key] = row["info_value"]
+        return result
 
     def adbc_get_objects(
         self,

Reply via email to