This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new d411cabd9 fix(rust/driver/datafusion): using datafusion driver in
async runtime (#3712)
d411cabd9 is described below
commit d411cabd9e34a534e13a06ea4f533ff9a326869e
Author: Pavel Agafonov <[email protected]>
AuthorDate: Fri Nov 14 09:29:08 2025 +0300
fix(rust/driver/datafusion): using datafusion driver in async runtime
(#3712)
Closes #3711
This is not an ideal solution; we still need to think about sync/async
ergonomics in the future.
Known limitation: this approach cannot be used from a single-threaded
(current-thread) runtime, because `tokio::task::block_in_place` panics there.
---------
Signed-off-by: if0ne <[email protected]>
Signed-off-by: Pavel Agafonov <[email protected]>
---
rust/driver/datafusion/src/lib.rs | 59 +++++++++++++++++++------
rust/driver/datafusion/tests/test_datafusion.rs | 26 ++++++++---
2 files changed, 65 insertions(+), 20 deletions(-)
diff --git a/rust/driver/datafusion/src/lib.rs
b/rust/driver/datafusion/src/lib.rs
index ef61f4d24..c08ba28c7 100644
--- a/rust/driver/datafusion/src/lib.rs
+++ b/rust/driver/datafusion/src/lib.rs
@@ -25,9 +25,9 @@ use
datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use datafusion_substrait::substrait::proto::Plan;
use prost::Message;
use std::fmt::Debug;
+use std::future::Future;
use std::sync::Arc;
use std::vec::IntoIter;
-use tokio::runtime::Runtime;
use arrow_array::builder::{
BooleanBuilder, Int32Builder, Int64Builder, ListBuilder, MapBuilder,
MapFieldNames,
@@ -48,6 +48,31 @@ use adbc_core::{
schemas, Connection, Database, Driver, Optionable, Statement,
};
+pub enum Runtime {
+ Handle(tokio::runtime::Handle),
+ Tokio(tokio::runtime::Runtime),
+}
+
+impl Runtime {
+ pub fn new(handle: Option<tokio::runtime::Handle>) ->
std::io::Result<Self> {
+ if let Some(handle) = handle {
+ Ok(Self::Handle(handle))
+ } else {
+ let runtime = tokio::runtime::Builder::new_multi_thread()
+ .enable_all()
+ .build()?;
+ Ok(Self::Tokio(runtime))
+ }
+ }
+
+ pub fn block_on<F: Future>(&self, future: F) -> F::Output {
+ match self {
+ Runtime::Handle(handle) => tokio::task::block_in_place(||
handle.block_on(future)),
+ Runtime::Tokio(runtime) => runtime.block_on(future),
+ }
+ }
+}
+
#[derive(Debug)]
pub struct SingleBatchReader {
batch: Option<RecordBatch>,
@@ -109,13 +134,23 @@ impl RecordBatchReader for DataFusionReader {
}
#[derive(Default)]
-pub struct DataFusionDriver {}
+pub struct DataFusionDriver {
+ handle: Option<tokio::runtime::Handle>,
+}
+
+impl DataFusionDriver {
+ pub fn new(handle: Option<tokio::runtime::Handle>) -> Self {
+ Self { handle }
+ }
+}
impl Driver for DataFusionDriver {
type DatabaseType = DataFusionDatabase;
fn new_database(&mut self) -> Result<Self::DatabaseType> {
- Ok(Self::DatabaseType {})
+ Ok(Self::DatabaseType {
+ handle: self.handle.clone(),
+ })
}
fn new_database_with_opts(
@@ -127,7 +162,9 @@ impl Driver for DataFusionDriver {
),
>,
) -> adbc_core::error::Result<Self::DatabaseType> {
- let mut database = Self::DatabaseType {};
+ let mut database = Self::DatabaseType {
+ handle: self.handle.clone(),
+ };
for (key, value) in opts {
database.set_option(key, value)?;
}
@@ -135,7 +172,9 @@ impl Driver for DataFusionDriver {
}
}
-pub struct DataFusionDatabase {}
+pub struct DataFusionDatabase {
+ handle: Option<tokio::runtime::Handle>,
+}
impl Optionable for DataFusionDatabase {
type Option = OptionDatabase;
@@ -186,10 +225,7 @@ impl Database for DataFusionDatabase {
fn new_connection(&self) -> Result<Self::ConnectionType> {
let ctx = SessionContext::new();
- let runtime = tokio::runtime::Builder::new_multi_thread()
- .enable_all()
- .build()
- .unwrap();
+ let runtime = Runtime::new(self.handle.clone()).unwrap();
Ok(DataFusionConnection {
runtime: Arc::new(runtime),
@@ -208,10 +244,7 @@ impl Database for DataFusionDatabase {
) -> adbc_core::error::Result<Self::ConnectionType> {
let ctx = SessionContext::new();
- let runtime = tokio::runtime::Builder::new_multi_thread()
- .enable_all()
- .build()
- .unwrap();
+ let runtime = Runtime::new(self.handle.clone()).unwrap();
let mut connection = DataFusionConnection {
runtime: Arc::new(runtime),
diff --git a/rust/driver/datafusion/tests/test_datafusion.rs
b/rust/driver/datafusion/tests/test_datafusion.rs
index 1ba496c9a..9123a480a 100644
--- a/rust/driver/datafusion/tests/test_datafusion.rs
+++ b/rust/driver/datafusion/tests/test_datafusion.rs
@@ -26,8 +26,8 @@ use
datafusion_substrait::logical_plan::producer::to_substrait_plan;
use datafusion_substrait::substrait::proto::Plan;
use prost::Message;
-fn get_connection() -> DataFusionConnection {
- let mut driver = DataFusionDriver::default();
+fn get_connection(handle: Option<tokio::runtime::Handle>) ->
DataFusionConnection {
+ let mut driver = DataFusionDriver::new(handle);
let database = driver.new_database().unwrap();
database.new_connection().unwrap()
}
@@ -80,7 +80,7 @@ fn execute_substrait(connection: &mut DataFusionConnection,
plan: Plan) -> Recor
#[test]
fn test_connection_options() {
- let mut connection = get_connection();
+ let mut connection = get_connection(None);
let current_catalog = connection
.get_option_string(OptionConnection::CurrentCatalog)
@@ -119,7 +119,7 @@ fn test_connection_options() {
#[test]
fn test_get_objects_database() {
- let mut connection = get_connection();
+ let mut connection = get_connection(None);
let objects = get_objects(&connection);
@@ -134,7 +134,7 @@ fn test_get_objects_database() {
#[test]
fn test_execute_sql() {
- let mut connection = get_connection();
+ let mut connection = get_connection(None);
execute_update(&mut connection, "CREATE TABLE IF NOT EXISTS
datafusion.public.example (c1 INT, c2 VARCHAR) AS
VALUES(1,'HELLO'),(2,'DATAFUSION'),(3,'!')");
@@ -146,7 +146,7 @@ fn test_execute_sql() {
#[test]
fn test_ingest() {
- let mut connection = get_connection();
+ let mut connection = get_connection(None);
execute_update(&mut connection, "CREATE TABLE IF NOT EXISTS
datafusion.public.example (c1 INT, c2 VARCHAR) AS
VALUES(1,'HELLO'),(2,'DATAFUSION'),(3,'!')");
@@ -172,7 +172,7 @@ fn test_ingest() {
#[test]
fn test_execute_substrait() {
- let mut connection = get_connection();
+ let mut connection = get_connection(None);
execute_update(&mut connection, "CREATE TABLE IF NOT EXISTS
datafusion.public.example (c1 INT, c2 VARCHAR) AS
VALUES(1,'HELLO'),(2,'DATAFUSION'),(3,'!')");
@@ -198,3 +198,15 @@ fn test_execute_substrait() {
assert_eq!(batch.num_rows(), 3);
assert_eq!(batch.num_columns(), 2);
}
+
+#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
+async fn test_running_in_async() {
+ let mut connection =
get_connection(Some(tokio::runtime::Handle::current()));
+
+ execute_update(&mut connection, "CREATE TABLE IF NOT EXISTS
datafusion.public.example (c1 INT, c2 VARCHAR) AS
VALUES(1,'HELLO'),(2,'DATAFUSION'),(3,'!')");
+
+ let batch = execute_sql_query(&mut connection, "SELECT * FROM
datafusion.public.example");
+
+ assert_eq!(batch.num_rows(), 3);
+ assert_eq!(batch.num_columns(), 2);
+}