This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new c2fd239e4 fix(c/driver/postgresql): avoid crash if closing invalidated result (#2653)
c2fd239e4 is described below
commit c2fd239e4396e1a7013e74d23be7f8b84c0d5fe7
Author: David Li <[email protected]>
AuthorDate: Thu Mar 27 23:44:50 2025 -0400
fix(c/driver/postgresql): avoid crash if closing invalidated result (#2653)
The driver does not prevent you from closing a statement while a result set
is still open, and closing that result set afterwards would then crash. Avoid
this by having the result set hold only a weak pointer to the reader state
owned by the statement.
Fixes #2629.
---
c/driver/postgresql/statement.cc | 91 +++++++++++++++-------
c/driver/postgresql/statement.h | 12 ++-
c/validation/adbc_validation.h | 2 +
c/validation/adbc_validation_statement.cc | 38 ++++++++-
.../adbc_driver_manager/dbapi.py | 15 ++--
5 files changed, 117 insertions(+), 41 deletions(-)
diff --git a/c/driver/postgresql/statement.cc b/c/driver/postgresql/statement.cc
index 11a01d5be..a83bca13f 100644
--- a/c/driver/postgresql/statement.cc
+++ b/c/driver/postgresql/statement.cc
@@ -30,6 +30,7 @@
#include <limits>
#include <memory>
#include <string>
+#include <string_view>
#include <utility>
#include <vector>
@@ -219,55 +220,89 @@ void TupleReader::Release() {
row_id_ = -1;
}
+// Instead of directly exporting the TupleReader, which is tied to the
+// lifetime of the Statement, we export a weak_ptr reference instead. That
+// way if the user accidentally closes the Statement before the
+// ArrowArrayStream, we can avoid a crash.
+// See https://github.com/apache/arrow-adbc/issues/2629
+struct ExportedTupleReader {
+ std::weak_ptr<TupleReader> self;
+};
+
void TupleReader::ExportTo(struct ArrowArrayStream* stream) {
stream->get_schema = &GetSchemaTrampoline;
stream->get_next = &GetNextTrampoline;
stream->get_last_error = &GetLastErrorTrampoline;
stream->release = &ReleaseTrampoline;
- stream->private_data = this;
+ stream->private_data = new ExportedTupleReader{weak_from_this()};
}
-const struct AdbcError* TupleReader::ErrorFromArrayStream(struct ArrowArrayStream* stream,
+const struct AdbcError* TupleReader::ErrorFromArrayStream(struct ArrowArrayStream* self,
AdbcStatusCode* status) {
- if (!stream->private_data || stream->release != &ReleaseTrampoline) {
+ if (!self->private_data || self->release != &ReleaseTrampoline) {
return nullptr;
}
- TupleReader* reader = static_cast<TupleReader*>(stream->private_data);
- if (status) {
- *status = reader->status_;
+ auto* wrapper = static_cast<ExportedTupleReader*>(self->private_data);
+ auto maybe_reader = wrapper->self.lock();
+ if (maybe_reader) {
+ if (status) {
+ *status = maybe_reader->status_;
+ }
+ return &maybe_reader->error_;
}
- return &reader->error_;
+ return nullptr;
}
int TupleReader::GetSchemaTrampoline(struct ArrowArrayStream* self,
struct ArrowSchema* out) {
if (!self || !self->private_data) return EINVAL;
- TupleReader* reader = static_cast<TupleReader*>(self->private_data);
- return reader->GetSchema(out);
+ auto* wrapper = static_cast<ExportedTupleReader*>(self->private_data);
+ auto maybe_reader = wrapper->self.lock();
+ if (maybe_reader) {
+ return maybe_reader->GetSchema(out);
+ }
+ // statement was closed or reader was otherwise invalidated
+ return EINVAL;
}
int TupleReader::GetNextTrampoline(struct ArrowArrayStream* self,
struct ArrowArray* out) {
if (!self || !self->private_data) return EINVAL;
- TupleReader* reader = static_cast<TupleReader*>(self->private_data);
- return reader->GetNext(out);
+ auto* wrapper = static_cast<ExportedTupleReader*>(self->private_data);
+ auto maybe_reader = wrapper->self.lock();
+ if (maybe_reader) {
+ return maybe_reader->GetNext(out);
+ }
+ // statement was closed or reader was otherwise invalidated
+ return EINVAL;
}
const char* TupleReader::GetLastErrorTrampoline(struct ArrowArrayStream* self) {
if (!self || !self->private_data) return nullptr;
+ constexpr std::string_view kReaderInvalidated =
+ "[libpq] Reader invalidated (statement or reader was closed)";
- TupleReader* reader = static_cast<TupleReader*>(self->private_data);
- return reader->last_error();
+ auto* wrapper = static_cast<ExportedTupleReader*>(self->private_data);
+ auto maybe_reader = wrapper->self.lock();
+ if (maybe_reader) {
+ return maybe_reader->last_error();
+ }
+ // statement was closed or reader was otherwise invalidated
+ return kReaderInvalidated.data();
}
void TupleReader::ReleaseTrampoline(struct ArrowArrayStream* self) {
if (!self || !self->private_data) return;
- TupleReader* reader = static_cast<TupleReader*>(self->private_data);
- reader->Release();
+ auto* wrapper = static_cast<ExportedTupleReader*>(self->private_data);
+ auto maybe_reader = wrapper->self.lock();
+ if (maybe_reader) {
+ maybe_reader->Release();
+ }
+ delete wrapper;
self->private_data = nullptr;
self->release = nullptr;
}
@@ -281,7 +316,7 @@ AdbcStatusCode PostgresStatement::New(struct AdbcConnection* connection,
connection_ =
*reinterpret_cast<std::shared_ptr<PostgresConnection>*>(connection->private_data);
type_resolver_ = connection_->type_resolver();
- reader_.conn_ = connection_->conn();
+ ClearResult();
return ADBC_STATUS_OK;
}
@@ -514,24 +549,24 @@ AdbcStatusCode PostgresStatement::ExecuteQuery(struct ArrowArrayStream* stream,
}
struct ArrowError na_error;
- reader_.copy_reader_ = std::make_unique<PostgresCopyStreamReader>();
- CHECK_NA(INTERNAL, reader_.copy_reader_->Init(root_type), error);
+ reader_->copy_reader_ = std::make_unique<PostgresCopyStreamReader>();
+ CHECK_NA(INTERNAL, reader_->copy_reader_->Init(root_type), error);
CHECK_NA_DETAIL(INTERNAL,
- reader_.copy_reader_->InferOutputSchema(
+ reader_->copy_reader_->InferOutputSchema(
std::string(connection_->VendorName()), &na_error),
&na_error, error);
- CHECK_NA_DETAIL(INTERNAL, reader_.copy_reader_->InitFieldReaders(&na_error), &na_error,
+ CHECK_NA_DETAIL(INTERNAL, reader_->copy_reader_->InitFieldReaders(&na_error), &na_error,
error);
// Execute the COPY query
RAISE_STATUS(error, helper.ExecuteCopy());
// We need the PQresult back for the reader
- reader_.result_ = helper.ReleaseResult();
+ reader_->result_ = helper.ReleaseResult();
// Export to stream
- reader_.ExportTo(stream);
+ reader_->ExportTo(stream);
if (rows_affected) *rows_affected = -1;
return ADBC_STATUS_OK;
}
@@ -674,7 +709,7 @@ AdbcStatusCode PostgresStatement::GetOption(const char* key, char* value, size_t
break;
}
} else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) {
- result = std::to_string(reader_.batch_size_hint_bytes_);
+ result = std::to_string(reader_->batch_size_hint_bytes_);
} else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_USE_COPY) == 0) {
if (UseCopy()) {
result = "true";
@@ -710,7 +745,7 @@ AdbcStatusCode PostgresStatement::GetOptionInt(const char* key, int64_t* value,
struct AdbcError* error) {
std::string result;
if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_BATCH_SIZE_HINT_BYTES) == 0) {
- *value = reader_.batch_size_hint_bytes_;
+ *value = reader_->batch_size_hint_bytes_;
return ADBC_STATUS_OK;
}
SetError(error, "[libpq] Unknown statement option '%s'", key);
@@ -799,7 +834,7 @@ AdbcStatusCode PostgresStatement::SetOption(const char* key, const char* value,
return ADBC_STATUS_INVALID_ARGUMENT;
}
- this->reader_.batch_size_hint_bytes_ = int_value;
+ this->batch_size_hint_bytes_ = this->reader_->batch_size_hint_bytes_ = int_value;
} else if (std::strcmp(key, ADBC_POSTGRESQL_OPTION_USE_COPY) == 0) {
if (std::strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) {
use_copy_ = true;
@@ -836,7 +871,7 @@ AdbcStatusCode PostgresStatement::SetOptionInt(const char* key, int64_t value,
return ADBC_STATUS_INVALID_ARGUMENT;
}
- this->reader_.batch_size_hint_bytes_ = value;
+ this->batch_size_hint_bytes_ = this->reader_->batch_size_hint_bytes_ = value;
return ADBC_STATUS_OK;
}
SetError(error, "[libpq] Unknown statement option '%s'", key);
@@ -845,7 +880,9 @@ AdbcStatusCode PostgresStatement::SetOptionInt(const char* key, int64_t value,
void PostgresStatement::ClearResult() {
// TODO: we may want to synchronize here for safety
- reader_.Release();
+ if (reader_) reader_->Release();
+ reader_ = std::make_shared<TupleReader>(connection_->conn());
+ reader_->batch_size_hint_bytes_ = batch_size_hint_bytes_;
}
int PostgresStatement::UseCopy() {
diff --git a/c/driver/postgresql/statement.h b/c/driver/postgresql/statement.h
index 60ada992b..a2c3f5e88 100644
--- a/c/driver/postgresql/statement.h
+++ b/c/driver/postgresql/statement.h
@@ -39,8 +39,10 @@ namespace adbcpq {
class PostgresConnection;
class PostgresStatement;
+constexpr static int64_t kDefaultBatchSizeHintBytes = 16777216;
+
/// \brief An ArrowArrayStream that reads tuples from a PGresult.
-class TupleReader final {
+class TupleReader final : public std::enable_shared_from_this<TupleReader> {
public:
TupleReader(PGconn* conn)
: status_(ADBC_STATUS_OK),
@@ -50,7 +52,7 @@ class TupleReader final {
pgbuf_(nullptr),
copy_reader_(nullptr),
row_id_(-1),
- batch_size_hint_bytes_(16777216),
+ batch_size_hint_bytes_(kDefaultBatchSizeHintBytes),
is_finished_(false) {
ArrowErrorInit(&na_error_);
data_.data.as_char = nullptr;
@@ -98,7 +100,8 @@ class PostgresStatement {
query_(),
prepared_(false),
use_copy_(-1),
- reader_(nullptr) {
+ reader_(nullptr),
+ batch_size_hint_bytes_(kDefaultBatchSizeHintBytes) {
std::memset(&bind_, 0, sizeof(bind_));
}
@@ -170,7 +173,8 @@ class PostgresStatement {
bool temporary = false;
} ingest_;
- TupleReader reader_;
+ std::shared_ptr<TupleReader> reader_;
+ int64_t batch_size_hint_bytes_;
int UseCopy();
};
diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h
index 427e39b2e..fad84137e 100644
--- a/c/validation/adbc_validation.h
+++ b/c/validation/adbc_validation.h
@@ -463,6 +463,7 @@ class StatementTest {
void TestConcurrentStatements();
void TestErrorCompatibility();
+ void TestResultIndependence();
void TestResultInvalidation();
protected:
@@ -579,6 +580,7 @@ void StatementTest::TestSqlIngestType(ArrowType type,
TEST_F(FIXTURE, Transactions) { TestTransactions(); } \
TEST_F(FIXTURE, ConcurrentStatements) { TestConcurrentStatements(); } \
TEST_F(FIXTURE, ErrorCompatibility) { TestErrorCompatibility(); } \
+ TEST_F(FIXTURE, ResultIndependence) { TestResultIndependence(); } \
TEST_F(FIXTURE, ResultInvalidation) { TestResultInvalidation(); }
} // namespace adbc_validation
diff --git a/c/validation/adbc_validation_statement.cc b/c/validation/adbc_validation_statement.cc
index adb7aacb2..24765f163 100644
--- a/c/validation/adbc_validation_statement.cc
+++ b/c/validation/adbc_validation_statement.cc
@@ -51,8 +51,12 @@ void StatementTest::TearDownTest() {
if (statement.private_data) {
EXPECT_THAT(AdbcStatementRelease(&statement, &error), IsOkStatus(&error));
}
- EXPECT_THAT(AdbcConnectionRelease(&connection, &error), IsOkStatus(&error));
- EXPECT_THAT(AdbcDatabaseRelease(&database, &error), IsOkStatus(&error));
+ if (connection.private_data) {
+ EXPECT_THAT(AdbcConnectionRelease(&connection, &error), IsOkStatus(&error));
+ }
+ if (database.private_data) {
+ EXPECT_THAT(AdbcDatabaseRelease(&database, &error), IsOkStatus(&error));
+ }
if (error.release) {
error.release(&error);
}
@@ -2839,6 +2843,35 @@ void StatementTest::TestErrorCompatibility() {
error.release(&error);
}
+void StatementTest::TestResultIndependence() {
+ // If we have a result reader, and we close the statement (and other
+ // resources), either the statement should error, or the reader should be
+ // closeable and should error on other operations
+
+ ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
+ ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT 42", &error),
+ IsOkStatus(&error));
+
+ StreamReader reader;
+ ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
+ &reader.rows_affected, &error),
+ IsOkStatus(&error));
+ ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
+
+ auto status = AdbcStatementRelease(&statement, &error);
+ if (status != ADBC_STATUS_OK) {
+ // That's ok, this driver prevents closing the statement while readers are open
+ return;
+ }
+ ASSERT_THAT(AdbcConnectionRelease(&connection, &error), IsOkStatus(&error));
+ ASSERT_THAT(AdbcDatabaseRelease(&database, &error), IsOkStatus(&error));
+
+ // Must not crash (but it's up to the driver whether it errors or succeeds)
+ std::ignore = reader.MaybeNext();
+ // Implicitly StreamReader calls release() on destruction, that should not
+ // crash either
+}
+
void StatementTest::TestResultInvalidation() {
// Start reading from a statement, then overwrite it
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
@@ -2860,4 +2893,5 @@ void StatementTest::TestResultInvalidation() {
// First reader may fail, or may succeed but give no data
reader1.MaybeNext();
}
+
} // namespace adbc_validation
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 679cc1bff..acb0136cc 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -415,14 +415,13 @@ class Connection(_Closeable):
reader = pyarrow.RecordBatchReader._import_from_c(handle.address)
table = _blocking_call(reader.read_all, (), {}, self._conn.cancel)
info = table.to_pylist()
- return dict(
- {
- _KNOWN_INFO_VALUES.get(row["info_name"], row["info_name"]): row[
- "info_value"
- ]
- for row in info
- }
- )
+ # try to help the type checker a bit here
+ result: Dict[Union[str, int], Any] = {}
+ for row in info:
+ info_name: int = row["info_name"]
+ key: Union[str, int] = _KNOWN_INFO_VALUES.get(info_name, info_name)
+ result[key] = row["info_value"]
+ return result
def adbc_get_objects(
self,