paleolimbot commented on code in PR #586: URL: https://github.com/apache/sedona-db/pull/586#discussion_r2817584548
########## c/sedona-libgpuspatial/src/predicate.rs: ########## @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::os::raw::c_uint; + +#[derive(Debug, PartialEq, Copy, Clone)] +pub enum GpuSpatialRelationPredicate { + Equals, + Disjoint, + Touches, + Contains, + Covers, + Intersects, + Within, + CoveredBy, +} + +#[allow(dead_code)] // not used if the GPU feature is disabled Review Comment: ```suggestion #[cfg(not(gpu_available))] // not used if the GPU feature is disabled ``` ########## c/sedona-libgpuspatial/src/lib.rs: ########## @@ -15,6 +15,305 @@ // specific language governing permissions and limitations // under the License. -// Module declarations +use arrow_schema::DataType; +use geo_types::Rect; + +mod error; #[cfg(gpu_available)] +mod libgpuspatial; mod libgpuspatial_glue_bindgen; +mod options; +mod predicate; + +pub use error::GpuSpatialError; +pub use options::GpuSpatialOptions; +pub use predicate::GpuSpatialRelationPredicate; +pub use sys::{GpuSpatialIndex, GpuSpatialRefiner}; + +#[cfg(gpu_available)] +mod sys { + use super::libgpuspatial; + use super::*; + use libgpuspatial::GpuSpatialRuntimeWrapper; + use std::sync::{Arc, Mutex}; + + pub type Result<T> = std::result::Result<T, GpuSpatialError>; + + // Global Runtime State + unsafe impl Send for GpuSpatialRuntimeWrapper {} + unsafe impl Sync for GpuSpatialRuntimeWrapper {} Review Comment: These should go right below the struct definition for `GpuSpatialRuntimeWrapper` (makes it easier to spot members or functions of the struct that might not be Send or Sync) ########## c/sedona-libgpuspatial/libgpuspatial/src/gpuspatial_c.cc: ########## @@ -194,12 +192,25 @@ struct GpuSpatialIndexFloat2DExporter { } static int CProbe(self_t* self, SedonaSpatialIndexContext* context, const float* buf, - uint32_t n_rects) { - return SafeExecute(static_cast<context_t*>(context->private_data), [&] { + uint32_t n_rects, + void (*callback)(const uint32_t* build_indices, + const uint32_t* probe_indices, uint32_t length, + void* user_data), + void* user_data) { + // Do not use SafeExecute because this method is thread-safe and we don't want to set + // last_error for the whole index if one thread encounters an error + try { auto* rects = reinterpret_cast<const spatial_index_t::box_t*>(buf); auto& buff = static_cast<context_t*>(context->private_data)->payload; use_index(self).Probe(rects, n_rects, &buff.build_indices, &buff.probe_indices); - }); + callback(buff.build_indices.data(), buff.probe_indices.data(), + buff.build_indices.size(), user_data); + return 0; + } catch (const std::exception& e) { // user should call context_get_last_error + return EINVAL; + } catch (...) { + return EINVAL; + } Review Comment: You should be able to set the context error string here safely, correct? If `use_index(self).Probe(...)` already does this, you can add a comment here. If the process of probing never returns a meaningful error message, you can set the context error message to "Unknown error" or something. ########## c/sedona-libgpuspatial/src/lib.rs: ########## @@ -15,6 +15,305 @@ // specific language governing permissions and limitations // under the License. -// Module declarations +use arrow_schema::DataType; +use geo_types::Rect; + +mod error; #[cfg(gpu_available)] +mod libgpuspatial; mod libgpuspatial_glue_bindgen; +mod options; +mod predicate; + +pub use error::GpuSpatialError; +pub use options::GpuSpatialOptions; +pub use predicate::GpuSpatialRelationPredicate; +pub use sys::{GpuSpatialIndex, GpuSpatialRefiner}; + +#[cfg(gpu_available)] +mod sys { + use super::libgpuspatial; + use super::*; + use libgpuspatial::GpuSpatialRuntimeWrapper; + use std::sync::{Arc, Mutex}; + + pub type Result<T> = std::result::Result<T, GpuSpatialError>; + + // Global Runtime State + unsafe impl Send for GpuSpatialRuntimeWrapper {} + unsafe impl Sync for GpuSpatialRuntimeWrapper {} + + static GLOBAL_GPUSPATIAL_RUNTIME: Mutex<Option<Arc<GpuSpatialRuntimeWrapper>>> = + Mutex::new(None); + /// Handles initialization of the GPU runtime. + pub struct SpatialContext { + runtime: Arc<GpuSpatialRuntimeWrapper>, + } + + impl SpatialContext { + pub fn try_new(options: &GpuSpatialOptions) -> Result<Self> { + // Lock the mutex globally + let mut guard = GLOBAL_GPUSPATIAL_RUNTIME + .lock() + .map_err(|_| GpuSpatialError::Init("Global mutex poisoned".into()))?; + + // Check if it already exists + if let Some(existing_runtime) = guard.as_ref() { + if existing_runtime.device_id != options.device_id { + return Err(GpuSpatialError::Init(format!( + "Runtime conflict: Initialized on Device {}, requested Device {}.", + existing_runtime.device_id, options.device_id + ))); + } + // Return the existing one + return Ok(Self { + runtime: existing_runtime.clone(), + }); + } Review Comment: I think a `OnceLock<Arc<GpuSpatialRuntimeWrapper>>` is the Rusty way to do this. In action: https://github.com/apache/sedona-db/blob/7eac903e08c84ae9aa6396d2ff1e8101d1e6b8df/rust/sedona-functions/src/st_setsrid.rs#L434-L448 ########## c/sedona-libgpuspatial/src/lib.rs: ########## @@ -15,6 +15,305 @@ // specific language governing permissions and limitations // under the License. -// Module declarations +use arrow_schema::DataType; +use geo_types::Rect; + +mod error; #[cfg(gpu_available)] +mod libgpuspatial; mod libgpuspatial_glue_bindgen; +mod options; +mod predicate; + +pub use error::GpuSpatialError; +pub use options::GpuSpatialOptions; +pub use predicate::GpuSpatialRelationPredicate; +pub use sys::{GpuSpatialIndex, GpuSpatialRefiner}; + +#[cfg(gpu_available)] +mod sys { + use super::libgpuspatial; + use super::*; + use libgpuspatial::GpuSpatialRuntimeWrapper; + use std::sync::{Arc, Mutex}; + + pub type Result<T> = std::result::Result<T, GpuSpatialError>; + + // Global Runtime State + unsafe impl Send for GpuSpatialRuntimeWrapper {} + unsafe impl Sync for GpuSpatialRuntimeWrapper {} + + static GLOBAL_GPUSPATIAL_RUNTIME: Mutex<Option<Arc<GpuSpatialRuntimeWrapper>>> = + Mutex::new(None); + /// Handles initialization of the GPU runtime. + pub struct SpatialContext { + runtime: Arc<GpuSpatialRuntimeWrapper>, + } + + impl SpatialContext { + pub fn try_new(options: &GpuSpatialOptions) -> Result<Self> { + // Lock the mutex globally + let mut guard = GLOBAL_GPUSPATIAL_RUNTIME + .lock() + .map_err(|_| GpuSpatialError::Init("Global mutex poisoned".into()))?; + + // Check if it already exists + if let Some(existing_runtime) = guard.as_ref() { + if existing_runtime.device_id != options.device_id { + return Err(GpuSpatialError::Init(format!( + "Runtime conflict: Initialized on Device {}, requested Device {}.", + existing_runtime.device_id, options.device_id + ))); + } + // Return the existing one + return Ok(Self { + runtime: existing_runtime.clone(), + }); + } + + let out_path = std::path::PathBuf::from(env!("OUT_DIR")); + let ptx_root = out_path.join("share/gpuspatial/shaders"); + let ptx_root_str = ptx_root + .to_str() + .ok_or_else(|| GpuSpatialError::Init("Invalid PTX path".to_string()))?; + + let wrapper = libgpuspatial::GpuSpatialRuntimeWrapper::try_new( + options.device_id, + ptx_root_str, + options.cuda_use_memory_pool, + options.cuda_memory_pool_init_percent, + )?; + + let arc_wrapper = Arc::new(wrapper); + + *guard = Some(arc_wrapper.clone()); + + Ok(Self { + runtime: arc_wrapper, + }) + } + } + + pub struct GpuSpatialIndex { + inner: libgpuspatial::FloatIndex2D, + } + + impl GpuSpatialIndex { Review Comment: I think eventually we will want to collapse all of this wrapper code into the `FloatIndex2D` but no need to do that now. One good reason to do this is that it means you only have to document it once. If we keep this wrapper for this PR, we still need to document both (although you can link to the inner documentation from the documentation for this struct). ########## c/sedona-libgpuspatial/src/libgpuspatial.rs: ########## @@ -0,0 +1,502 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::GpuSpatialError; +#[cfg(gpu_available)] +use crate::libgpuspatial_glue_bindgen::*; +use crate::predicate::GpuSpatialRelationPredicate; +use arrow_array::{Array, ArrayRef}; +use arrow_schema::ffi::FFI_ArrowSchema; +use arrow_schema::DataType; +use std::cell::UnsafeCell; +use std::convert::TryFrom; +use std::ffi::{c_void, CStr, CString}; +use std::os::raw::c_char; +use std::sync::Arc; + +pub struct GpuSpatialRuntimeWrapper { + runtime: UnsafeCell<GpuSpatialRuntime>, + /// Store which device the runtime is created on + pub device_id: i32, +} + +impl GpuSpatialRuntimeWrapper { + pub fn try_new( + device_id: i32, + ptx_root: &str, + use_cuda_memory_pool: bool, + cuda_memory_pool_init_precent: i32, + ) -> Result<GpuSpatialRuntimeWrapper, GpuSpatialError> { + let mut runtime = GpuSpatialRuntime { + init: None, + release: None, + get_last_error: None, + private_data: std::ptr::null_mut(), + }; + + unsafe { + GpuSpatialRuntimeCreate(&mut runtime); + } + + if let Some(init_fn) = runtime.init { + let c_ptx_root = CString::new(ptx_root).map_err(|_| { + GpuSpatialError::Init("Failed to convert ptx_root to CString".into()) + })?; + + let mut config = GpuSpatialRuntimeConfig { + device_id, + ptx_root: c_ptx_root.as_ptr(), + use_cuda_memory_pool, + cuda_memory_pool_init_precent, + }; + + unsafe { + let get_last_error = runtime.get_last_error; + let runtime_ptr = &mut runtime as *mut GpuSpatialRuntime; + + check_ffi_call( + move || init_fn(runtime_ptr as *mut _, &mut config), + get_last_error, + runtime_ptr, + GpuSpatialError::Init, + )?; + } + } else { + return Err(GpuSpatialError::Init("init function is None".to_string())); + } + + Ok(GpuSpatialRuntimeWrapper { + runtime: UnsafeCell::new(runtime), + device_id, + }) + } +} + +impl Drop for GpuSpatialRuntimeWrapper { + fn drop(&mut self) { + let runtime = self.runtime.get_mut(); + let release_fn = runtime.release.expect("release function is None"); + unsafe { + release_fn(runtime as *mut _); + } + } +} + +/// Internal wrapper that manages the lifecycle of the C `SedonaFloatIndex2D` struct. +/// It is wrapped in an `Arc` by the public structs to ensure thread safety. +struct FloatIndex2DWrapper { + index: SedonaFloatIndex2D, + // Keep a reference to the RT engine to ensure it lives as long as the index + _runtime: Arc<GpuSpatialRuntimeWrapper>, +} + +impl Drop for FloatIndex2DWrapper { + fn drop(&mut self) { + let release_fn = self.index.release.expect("release function is None"); + unsafe { + release_fn(&mut self.index as *mut _); + } + } +} + +pub struct FloatIndex2D { + inner: FloatIndex2DWrapper, +} + Review Comment: ```suggestion unsafe impl Send for FloatIndex2D; unsafe impl Sync for FloatIndex2D; ``` ########## c/sedona-libgpuspatial/libgpuspatial/include/gpuspatial/gpuspatial_c.h: ########## @@ -84,30 +84,23 @@ struct SedonaFloatIndex2D { * @return 0 on success, non-zero on failure */ int (*finish_building)(struct SedonaFloatIndex2D* self); + /** * Probe the spatial index with the given rectangles, each rectangle is represented by 4 * floats: [min_x, min_y, max_x, max_y] Points can also be probed by providing [x, y, x, * y] but points and rectangles cannot be mixed in one Probe call. The results of the - * probe will be stored in the context. + * probe will be stored in the context. The callback function will be called for each + * batch of results, with the build and probe indices of the candidate pairs in the + * batch. The user_data pointer will be passed to the callback * * @return 0 on success, non-zero on failure */ int (*probe)(struct SedonaFloatIndex2D* self, struct SedonaSpatialIndexContext* context, - const float* buf, uint32_t n_rects); - /** Get the build indices buffer from the context - * - * @return A pointer to the buffer and its length - */ - void (*get_build_indices_buffer)(struct SedonaSpatialIndexContext* context, - uint32_t** build_indices, - uint32_t* build_indices_length); - /** Get the probe indices buffer from the context - * - * @return A pointer to the buffer and its length - */ - void (*get_probe_indices_buffer)(struct SedonaSpatialIndexContext* context, - uint32_t** probe_indices, - uint32_t* probe_indices_length); + const float* buf, uint32_t n_rects, + void (*callback)(const uint32_t* build_indices, + const uint32_t* probe_indices, uint32_t length, + void* user_data), Review Comment: We should let the callback return an error code. Right now you only visit exactly one pre-allocated buffer but hopefully we can split that up into multiple calls for huge probe results and returning an error code will allow that process to be cancelled. ```suggestion int (*callback)(const uint32_t* build_indices, const uint32_t* probe_indices, uint32_t length, void* user_data), ``` ########## c/sedona-libgpuspatial/src/libgpuspatial.rs: ########## @@ -0,0 +1,502 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::GpuSpatialError; +#[cfg(gpu_available)] +use crate::libgpuspatial_glue_bindgen::*; +use crate::predicate::GpuSpatialRelationPredicate; +use arrow_array::{Array, ArrayRef}; +use arrow_schema::ffi::FFI_ArrowSchema; +use arrow_schema::DataType; +use std::cell::UnsafeCell; +use std::convert::TryFrom; +use std::ffi::{c_void, CStr, CString}; +use std::os::raw::c_char; +use std::sync::Arc; + +pub struct GpuSpatialRuntimeWrapper { + runtime: UnsafeCell<GpuSpatialRuntime>, + /// Store which device the runtime is created on + pub device_id: i32, +} + +impl GpuSpatialRuntimeWrapper { + pub fn try_new( + device_id: i32, + ptx_root: &str, + use_cuda_memory_pool: bool, + cuda_memory_pool_init_precent: i32, + ) -> Result<GpuSpatialRuntimeWrapper, GpuSpatialError> { + let mut runtime = GpuSpatialRuntime { + init: None, + release: None, + get_last_error: None, + private_data: std::ptr::null_mut(), + }; + + unsafe { + GpuSpatialRuntimeCreate(&mut runtime); + } + + if let Some(init_fn) = runtime.init { + let c_ptx_root = CString::new(ptx_root).map_err(|_| { + GpuSpatialError::Init("Failed to convert ptx_root to CString".into()) + })?; + + let mut config = GpuSpatialRuntimeConfig { + device_id, + ptx_root: c_ptx_root.as_ptr(), + use_cuda_memory_pool, + cuda_memory_pool_init_precent, + }; + + unsafe { + let get_last_error = runtime.get_last_error; + let runtime_ptr = &mut runtime as *mut GpuSpatialRuntime; + + check_ffi_call( + move || init_fn(runtime_ptr as *mut _, &mut config), + get_last_error, + runtime_ptr, + GpuSpatialError::Init, + )?; + } + } else { + return Err(GpuSpatialError::Init("init function is None".to_string())); + } + + Ok(GpuSpatialRuntimeWrapper { + runtime: UnsafeCell::new(runtime), + device_id, + }) + } +} + +impl Drop for GpuSpatialRuntimeWrapper { + fn drop(&mut self) { + let runtime = self.runtime.get_mut(); + let release_fn = runtime.release.expect("release function is None"); + unsafe { + release_fn(runtime as *mut _); + } + } +} + +/// Internal wrapper that manages the lifecycle of the C `SedonaFloatIndex2D` struct. +/// It is wrapped in an `Arc` by the public structs to ensure thread safety. +struct FloatIndex2DWrapper { + index: SedonaFloatIndex2D, + // Keep a reference to the RT engine to ensure it lives as long as the index + _runtime: Arc<GpuSpatialRuntimeWrapper>, +} + +impl Drop for FloatIndex2DWrapper { + fn drop(&mut self) { + let release_fn = self.index.release.expect("release function is None"); + unsafe { + release_fn(&mut self.index as *mut _); + } + } +} + +pub struct FloatIndex2D { + inner: FloatIndex2DWrapper, +} + +impl FloatIndex2D { + pub fn try_new( + runtime: Arc<GpuSpatialRuntimeWrapper>, + concurrency: u32, + ) -> Result<Self, GpuSpatialError> { + let mut index = SedonaFloatIndex2D { + clear: None, + create_context: None, + destroy_context: None, + push_build: None, + finish_building: None, + probe: None, + get_last_error: None, + context_get_last_error: None, + release: None, + private_data: std::ptr::null_mut(), + }; + + let config = GpuSpatialIndexConfig { + runtime: runtime.runtime.get(), + concurrency, + }; + + unsafe { + if GpuSpatialIndexFloat2DCreate(&mut index, &config) != 0 { + return Err(GpuSpatialError::Init("Index Create failed".into())); + } + } + + Ok(Self { + inner: FloatIndex2DWrapper { + index, + _runtime: runtime.clone(), + }, + }) + } + + pub fn clear(&mut self) { + if let Some(clear_fn) = self.inner.index.clear { + unsafe { + clear_fn(&mut self.inner.index as *mut _); + } + } + } + + pub fn push_build(&mut self, buf: *const f32, n_rects: u32) -> Result<(), GpuSpatialError> { + let push_fn = + self.inner.index.push_build.ok_or_else(|| { + GpuSpatialError::PushBuild("push_build function is None".to_string()) + })?; + let get_last_error = self.inner.index.get_last_error; + let index_ptr = &mut self.inner.index as *mut SedonaFloatIndex2D; + + unsafe { + check_ffi_call( + move || push_fn(index_ptr, buf, n_rects), + get_last_error, + index_ptr, + GpuSpatialError::PushBuild, + ) + } + } + + pub fn finish_building(&mut self) -> Result<(), GpuSpatialError> { + let finish_fn = self + .inner + .index + .finish_building + .ok_or_else(|| GpuSpatialError::FinishBuild("finish_building missing".into()))?; + let get_last_error = self.inner.index.get_last_error; + let index_ptr = &mut self.inner.index as *mut SedonaFloatIndex2D; + + unsafe { + check_ffi_call( + move || finish_fn(&mut self.inner.index), + get_last_error, + index_ptr, + GpuSpatialError::FinishBuild, + ) + } + } + + pub fn probe( + &self, + buf: *const f32, + n_rects: u32, + ) -> Result<(Vec<u32>, Vec<u32>), GpuSpatialError> { + let probe_fn = self + .inner + .index + .probe + .ok_or_else(|| GpuSpatialError::Probe("probe function is None".into()))?; + let create_context_fn = self.inner.index.create_context; + let destroy_context_fn = self.inner.index.destroy_context; + let context_err_fn = self.inner.index.context_get_last_error; + let index_ptr = &self.inner.index as *const _ as *mut SedonaFloatIndex2D; + + let mut ctx = SedonaSpatialIndexContext { + private_data: std::ptr::null_mut(), + }; + let mut state = ProbeState { + results: (Vec::new(), Vec::new()), + error: None, + }; + + unsafe { + if let Some(create_ctx) = create_context_fn { + create_ctx(&mut ctx); + } + + let status = probe_fn( + index_ptr, + &mut ctx, + buf, + n_rects, + Some(probe_callback_wrapper), + &mut state as *mut _ as *mut c_void, + ); + + if status != 0 { + let error_string = if let Some(get_ctx_err) = context_err_fn { + CStr::from_ptr(get_ctx_err(&mut ctx)) + .to_string_lossy() + .into_owned() + } else { + "Unknown context error during probe".to_string() + }; + + if let Some(destroy_ctx) = destroy_context_fn { + destroy_ctx(&mut ctx); + } + return Err(GpuSpatialError::Probe(error_string)); + } + + if let Some(callback_error) = state.error { + if let Some(destroy_ctx) = destroy_context_fn { + destroy_ctx(&mut ctx); + } + return Err(callback_error); + } + + if let Some(destroy_ctx) = destroy_context_fn { + destroy_ctx(&mut ctx); + } + } + + Ok(state.results) + } +} + +struct RefinerWrapper { + refiner: SedonaSpatialRefiner, + _runtime: Arc<GpuSpatialRuntimeWrapper>, +} + +impl Drop for RefinerWrapper { + fn drop(&mut self) { + let release_fn = self.refiner.release.expect("release function is None"); + unsafe { + release_fn(&mut self.refiner as *mut _); + } + } +} +pub struct Refiner { + inner: RefinerWrapper, +} + Review Comment: ```suggestion unsafe impl Send for Refiner; unsafe impl Sync for Refiner; ``` ########## c/sedona-libgpuspatial/libgpuspatial/src/gpuspatial_c.cc: ########## @@ -194,12 +192,25 @@ struct GpuSpatialIndexFloat2DExporter { } static int CProbe(self_t* self, SedonaSpatialIndexContext* context, const float* buf, - uint32_t n_rects) { - return SafeExecute(static_cast<context_t*>(context->private_data), [&] { + uint32_t n_rects, + void (*callback)(const uint32_t* build_indices, + const uint32_t* probe_indices, uint32_t length, + void* user_data), Review Comment: This callback should return an error code (this will let long-running probe operations be cancelled if or when you split up the output of huge probe results). ```suggestion int (*callback)(const uint32_t* build_indices, const uint32_t* probe_indices, uint32_t length, void* user_data), ``` ########## c/sedona-libgpuspatial/libgpuspatial/src/gpuspatial_c.cc: ########## @@ -194,12 +192,25 @@ struct GpuSpatialIndexFloat2DExporter { } static int CProbe(self_t* self, SedonaSpatialIndexContext* context, const float* buf, - uint32_t n_rects) { - return SafeExecute(static_cast<context_t*>(context->private_data), [&] { + uint32_t n_rects, + void (*callback)(const uint32_t* build_indices, + const uint32_t* probe_indices, uint32_t length, + void* user_data), + void* user_data) { + // Do not use SafeExecute because this method is thread-safe and we don't want to set + // last_error for the whole index if one thread encounters an error + try { auto* rects = reinterpret_cast<const spatial_index_t::box_t*>(buf); auto& buff = static_cast<context_t*>(context->private_data)->payload; use_index(self).Probe(rects, n_rects, &buff.build_indices, &buff.probe_indices); - }); + callback(buff.build_indices.data(), buff.probe_indices.data(), + buff.build_indices.size(), user_data); + return 0; Review Comment: ```suggestion return callback(buff.build_indices.data(), buff.probe_indices.data(), buff.build_indices.size(), user_data); ``` ########## c/sedona-libgpuspatial/src/libgpuspatial.rs: ########## @@ -0,0 +1,502 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::GpuSpatialError; +#[cfg(gpu_available)] +use crate::libgpuspatial_glue_bindgen::*; +use crate::predicate::GpuSpatialRelationPredicate; +use arrow_array::{Array, ArrayRef}; +use arrow_schema::ffi::FFI_ArrowSchema; +use arrow_schema::DataType; +use std::cell::UnsafeCell; +use std::convert::TryFrom; +use std::ffi::{c_void, CStr, CString}; +use std::os::raw::c_char; +use std::sync::Arc; + +pub struct GpuSpatialRuntimeWrapper { + runtime: UnsafeCell<GpuSpatialRuntime>, + /// Store which device the runtime is created on + pub device_id: i32, +} + +impl GpuSpatialRuntimeWrapper { + pub fn try_new( + device_id: i32, + ptx_root: &str, + use_cuda_memory_pool: bool, + cuda_memory_pool_init_precent: i32, + ) -> Result<GpuSpatialRuntimeWrapper, GpuSpatialError> { + let mut runtime = GpuSpatialRuntime { + init: None, + release: None, + get_last_error: None, + private_data: std::ptr::null_mut(), + }; + + unsafe { + GpuSpatialRuntimeCreate(&mut runtime); + } + + if let Some(init_fn) = runtime.init { + let c_ptx_root = CString::new(ptx_root).map_err(|_| { + GpuSpatialError::Init("Failed to convert ptx_root to CString".into()) + })?; + + let mut config = GpuSpatialRuntimeConfig { + device_id, + ptx_root: c_ptx_root.as_ptr(), + use_cuda_memory_pool, + cuda_memory_pool_init_precent, + }; + + unsafe { + let get_last_error = runtime.get_last_error; + let runtime_ptr = &mut runtime as *mut GpuSpatialRuntime; + + check_ffi_call( + move || init_fn(runtime_ptr as *mut _, &mut config), + get_last_error, + runtime_ptr, + GpuSpatialError::Init, + )?; + } + } else { + return Err(GpuSpatialError::Init("init function is None".to_string())); + } + + Ok(GpuSpatialRuntimeWrapper { + runtime: UnsafeCell::new(runtime), + device_id, + }) + } +} + +impl Drop for GpuSpatialRuntimeWrapper { + fn drop(&mut self) { + let runtime = self.runtime.get_mut(); + let release_fn = runtime.release.expect("release function is None"); + unsafe { + release_fn(runtime as *mut _); + } + } +} + +/// Internal wrapper that manages the lifecycle of the C `SedonaFloatIndex2D` struct. +/// It is wrapped in an `Arc` by the public structs to ensure thread safety. +struct FloatIndex2DWrapper { + index: SedonaFloatIndex2D, + // Keep a reference to the RT engine to ensure it lives as long as the index + _runtime: Arc<GpuSpatialRuntimeWrapper>, +} + +impl Drop for FloatIndex2DWrapper { + fn drop(&mut self) { + let release_fn = self.index.release.expect("release function is None"); + unsafe { + release_fn(&mut self.index as *mut _); + } + } +} + +pub struct FloatIndex2D { + inner: FloatIndex2DWrapper, +} + +impl FloatIndex2D { + pub fn try_new( + runtime: Arc<GpuSpatialRuntimeWrapper>, + concurrency: u32, + ) -> Result<Self, GpuSpatialError> { + let mut index = SedonaFloatIndex2D { + clear: None, + create_context: None, + destroy_context: None, + push_build: None, + finish_building: None, + probe: None, + get_last_error: None, + context_get_last_error: None, + release: None, + private_data: std::ptr::null_mut(), + }; + + let config = GpuSpatialIndexConfig { + runtime: runtime.runtime.get(), + concurrency, + }; + + unsafe { + if GpuSpatialIndexFloat2DCreate(&mut index, &config) != 0 { + return Err(GpuSpatialError::Init("Index Create failed".into())); + } + } + + Ok(Self { + inner: FloatIndex2DWrapper { + index, + _runtime: runtime.clone(), + }, + }) + } + + pub fn clear(&mut self) { + if let Some(clear_fn) = self.inner.index.clear { + unsafe { + clear_fn(&mut self.inner.index as *mut _); + } + } + } + + pub fn push_build(&mut self, buf: *const f32, n_rects: u32) -> Result<(), GpuSpatialError> { + let push_fn = + self.inner.index.push_build.ok_or_else(|| { + GpuSpatialError::PushBuild("push_build function is None".to_string()) + })?; + let get_last_error = self.inner.index.get_last_error; + let index_ptr = &mut self.inner.index as *mut SedonaFloatIndex2D; + + unsafe { + check_ffi_call( + move || push_fn(index_ptr, buf, n_rects), + get_last_error, + index_ptr, + GpuSpatialError::PushBuild, + ) + } + } + + pub fn finish_building(&mut self) -> Result<(), GpuSpatialError> { + let finish_fn = self + .inner + .index + .finish_building + .ok_or_else(|| GpuSpatialError::FinishBuild("finish_building missing".into()))?; + let get_last_error = self.inner.index.get_last_error; + let index_ptr = &mut self.inner.index as *mut SedonaFloatIndex2D; + + unsafe { + check_ffi_call( + move || finish_fn(&mut self.inner.index), + get_last_error, + index_ptr, + GpuSpatialError::FinishBuild, + ) + } + } + + pub fn probe( + &self, + buf: *const f32, + n_rects: u32, + ) -> Result<(Vec<u32>, Vec<u32>), GpuSpatialError> { + let probe_fn = self + .inner + .index + .probe + .ok_or_else(|| GpuSpatialError::Probe("probe function is None".into()))?; + let create_context_fn = self.inner.index.create_context; + let destroy_context_fn = self.inner.index.destroy_context; + let context_err_fn = self.inner.index.context_get_last_error; + let index_ptr = &self.inner.index as *const _ as *mut SedonaFloatIndex2D; + + let mut ctx = SedonaSpatialIndexContext { + private_data: std::ptr::null_mut(), + }; + let mut state = ProbeState { + results: (Vec::new(), Vec::new()), + error: None, + }; + + unsafe { + if let Some(create_ctx) = create_context_fn { + create_ctx(&mut ctx); + } + + let status = probe_fn( + index_ptr, + &mut ctx, + buf, + n_rects, + Some(probe_callback_wrapper), + &mut state as *mut _ as *mut c_void, + ); + + if status != 0 { + let error_string = if let Some(get_ctx_err) = context_err_fn { + CStr::from_ptr(get_ctx_err(&mut ctx)) + .to_string_lossy() + .into_owned() + } else { + "Unknown context error during probe".to_string() + }; + + if let Some(destroy_ctx) = destroy_context_fn { + destroy_ctx(&mut ctx); + } + return Err(GpuSpatialError::Probe(error_string)); + } + + if let Some(callback_error) = state.error { + if let Some(destroy_ctx) = destroy_context_fn { + destroy_ctx(&mut ctx); + } + return Err(callback_error); + } + + if let Some(destroy_ctx) = destroy_context_fn { + destroy_ctx(&mut ctx); + } + } + + Ok(state.results) + } +} + +struct RefinerWrapper { + refiner: SedonaSpatialRefiner, + _runtime: Arc<GpuSpatialRuntimeWrapper>, +} + +impl Drop for RefinerWrapper { + fn drop(&mut self) { + let release_fn = self.refiner.release.expect("release function is None"); + unsafe { + release_fn(&mut self.refiner as *mut _); + } + } +} +pub struct Refiner { + inner: RefinerWrapper, +} + +impl Refiner { + pub fn try_new( + runtime: Arc<GpuSpatialRuntimeWrapper>, + concurrency: u32, + compress_bvh: bool, + pipeline_batches: u32, + ) -> Result<Self, GpuSpatialError> { + let mut refiner = SedonaSpatialRefiner { + clear: None, + init_schema: None, + push_build: None, + finish_building: None, + refine: None, + get_last_error: None, + release: None, + private_data: std::ptr::null_mut(), + }; + + let config = GpuSpatialRefinerConfig { + runtime: runtime.runtime.get(), + concurrency, + compress_bvh, + pipeline_batches, + }; + + unsafe { + GpuSpatialRefinerCreate(&mut refiner, &config); + } + + Ok(Self { + inner: RefinerWrapper { + refiner, + _runtime: runtime.clone(), + }, + }) + } + + pub fn init_schema( + &mut self, + build_dt: &DataType, + probe_dt: &DataType, + ) -> Result<(), GpuSpatialError> { + let build_ffi = FFI_ArrowSchema::try_from(build_dt)?; + let probe_ffi = FFI_ArrowSchema::try_from(probe_dt)?; + let init_fn = self.inner.refiner.init_schema.unwrap(); + let get_last_error = self.inner.refiner.get_last_error; + let refiner_ptr = &mut self.inner.refiner as *mut SedonaSpatialRefiner; + + unsafe { + check_ffi_call( + || { + init_fn( + &mut self.inner.refiner, + &build_ffi as *const _ as *const _, + &probe_ffi as *const _ as *const _, + ) + }, + get_last_error, + refiner_ptr, + GpuSpatialError::Init, + ) + } + } + + pub fn push_build(&mut self, array: &ArrayRef) -> Result<(), GpuSpatialError> { + let (ffi_array, _) = arrow_array::ffi::to_ffi(&array.to_data())?; + let push_fn = self.inner.refiner.push_build.unwrap(); + let get_last_error = self.inner.refiner.get_last_error; + let refiner_ptr = &mut self.inner.refiner as *mut SedonaSpatialRefiner; + + unsafe { + check_ffi_call( + || push_fn(&mut self.inner.refiner, &ffi_array as *const _ as *const _), + get_last_error, + refiner_ptr, + GpuSpatialError::PushBuild, + ) + } + } + + pub fn clear(&mut self) { + if let Some(clear_fn) = self.inner.refiner.clear { + unsafe { + clear_fn(&mut self.inner.refiner as *mut _); + } + } + } + + pub fn finish_building(&mut self) -> Result<(), GpuSpatialError> { + let finish_fn = self.inner.refiner.finish_building.unwrap(); + let get_last_error = self.inner.refiner.get_last_error; + let refiner_ptr = &mut self.inner.refiner as *mut SedonaSpatialRefiner; + + unsafe { + check_ffi_call( + || finish_fn(&mut self.inner.refiner), + get_last_error, + refiner_ptr, + GpuSpatialError::FinishBuild, + ) + } + } + + pub fn refine( + &self, + array: &ArrayRef, + predicate: GpuSpatialRelationPredicate, + build_indices: &mut Vec<u32>, + probe_indices: &mut Vec<u32>, + ) -> Result<(), GpuSpatialError> { + let (ffi_array, _) = arrow_array::ffi::to_ffi(&array.to_data())?; + let refine_fn = self.inner.refiner.refine.unwrap(); + let mut new_len: u32 = 0; + + unsafe { + check_ffi_call( + || { + refine_fn( + &self.inner.refiner as *const _ as *mut _, + &ffi_array as *const _ as *mut _, + predicate.as_c_uint(), + build_indices.as_mut_ptr(), + probe_indices.as_mut_ptr(), + build_indices.len() as u32, + &mut new_len, + ) + }, + self.inner.refiner.get_last_error, + &self.inner.refiner as *const _ as *mut _, + GpuSpatialError::Refine, + )?; + } + build_indices.truncate(new_len as usize); + probe_indices.truncate(new_len as usize); + Ok(()) + } +} + +// ---------------------------------------------------------------------- +// Helper Functions +// ---------------------------------------------------------------------- + +// Define the exact signature of the C error-getting function +type ErrorFn<T> = unsafe extern "C" fn(*mut T) -> *const c_char; +struct ProbeState { + results: (Vec<u32>, Vec<u32>), + error: Option<GpuSpatialError>, +} +/// Helper to handle the common pattern of calling a C function returning an int status, +/// checking if it failed, and retrieving the error message if so. +unsafe fn check_ffi_call<T, F, ErrMap>( + call_fn: F, + get_error_fn: Option<ErrorFn<T>>, + obj_ptr: *mut T, + err_mapper: ErrMap, +) -> Result<(), GpuSpatialError> +where + F: FnOnce() -> i32, + ErrMap: FnOnce(String) -> GpuSpatialError, +{ + if call_fn() != 0 { + let error_string = if let Some(get_err) = get_error_fn { + let err_ptr = get_err(obj_ptr); + if !err_ptr.is_null() { + CStr::from_ptr(err_ptr).to_string_lossy().into_owned() + } else { + "Unknown error (null error message)".to_string() + } + } else { + "Unknown error (get_last_error not available)".to_string() + }; + + return Err(err_mapper(error_string)); + } + Ok(()) +} + +unsafe extern "C" fn probe_callback_wrapper( + build_indices: *const u32, + probe_indices: *const u32, + length: u32, + user_data: *mut c_void, +) { Review Comment: Not for this PR, but we should expose the visitor interface via Rust because it is so much more memory efficient if there are any polygons whose bbox is not selective. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
