This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 8d1ba38 [fix] Move container stream scanning from Cython to C++ for
efficiency (#521)
8d1ba38 is described below
commit 8d1ba380c613b30b5910340ec009837729131db1
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Thu Apr 2 11:28:42 2026 -0700
[fix] Move container stream scanning from Cython to C++ for efficiency
(#521)
When a packed FFI function receives a container (Array, List, Map, or
Dict) that holds tensors, we scan it for the first non-CPU tensor so that
the stream context can be captured for that tensor's device. Previously
this scan was done in Cython via repeated packed-function calls (size,
getitem, and iterator create/advance per element). This PR registers a
single C++ function, ffi.ContainerFindFirstNonCPUDevice, that performs the
recursive scan natively (for Map and Dict, only the values are scanned).
---
python/tvm_ffi/cython/function.pxi | 203 ++++---------------------------------
src/ffi/container.cc | 59 +++++++++--
2 files changed, 74 insertions(+), 188 deletions(-)
diff --git a/python/tvm_ffi/cython/function.pxi
b/python/tvm_ffi/cython/function.pxi
index 65ba378..913c698 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -62,178 +62,6 @@ cdef inline object make_ret_small_bytes(TVMFFIAny result):
return bytearray_to_bytes(&bytes)
-cdef inline bint _check_elem_for_stream(
- TVMFFIAny* elem_result,
- const DLPackExchangeAPI* api,
- TVMFFIPyCallContext* ctx
-) noexcept:
- """Check a single element for non-CPU tensor; set stream if found.
-
- Returns True if a non-CPU tensor was found and stream was set.
- Releases the element ref (for object types) in all cases.
- """
- cdef DLTensor* dltensor
- cdef void* stream = NULL
- cdef int32_t ti = elem_result.type_index
-
- if ti == kTVMFFITensor:
- dltensor =
TVMFFITensorGetDLTensorPtr(<TVMFFIObjectHandle>elem_result.v_obj)
- if dltensor.device.device_type != kDLCPU:
- ctx.device_type = dltensor.device.device_type
- ctx.device_id = dltensor.device.device_id
- api.current_work_stream(
- dltensor.device.device_type,
- dltensor.device.device_id,
- &stream)
- ctx.stream = <TVMFFIStreamHandle>stream
- TVMFFIObjectDecRef(<TVMFFIObjectHandle>elem_result.v_obj)
- return True
- TVMFFIObjectDecRef(<TVMFFIObjectHandle>elem_result.v_obj)
- elif (ti == kTVMFFIArray or ti == kTVMFFIList
- or ti == kTVMFFIMap or ti == kTVMFFIDict):
- _scan_container_for_stream(
- <TVMFFIObjectHandle>elem_result.v_obj, ti, api, ctx)
- TVMFFIObjectDecRef(<TVMFFIObjectHandle>elem_result.v_obj)
- if ctx.device_type != -1:
- return True
- elif ti >= kTVMFFIStaticObjectBegin:
- TVMFFIObjectDecRef(<TVMFFIObjectHandle>elem_result.v_obj)
- return False
-
-
-cdef inline void _scan_seq_for_stream(
- TVMFFIObjectHandle chandle,
- int32_t type_index,
- const DLPackExchangeAPI* api,
- TVMFFIPyCallContext* ctx
-) noexcept:
- """Scan an Array or List for the first non-CPU tensor."""
- cdef TVMFFIObjectHandle size_func_handle
- cdef TVMFFIObjectHandle getitem_func_handle
- cdef TVMFFIAny size_args[1]
- cdef TVMFFIAny size_result
- cdef TVMFFIAny getitem_args[2]
- cdef TVMFFIAny elem_result
- cdef int64_t n, i
-
- if type_index == kTVMFFIArray:
- size_func_handle = (<CObject>_FFI_ARRAY_SIZE).chandle
- getitem_func_handle = (<CObject>_FFI_ARRAY_GET_ITEM).chandle
- else:
- size_func_handle = (<CObject>_FFI_LIST_SIZE).chandle
- getitem_func_handle = (<CObject>_FFI_LIST_GET_ITEM).chandle
-
- size_args[0].type_index = type_index
- size_args[0].v_obj = <TVMFFIObject*>chandle
- size_result.type_index = kTVMFFINone
- size_result.v_int64 = 0
- if TVMFFIFunctionCall(size_func_handle, size_args, 1, &size_result) != 0:
- return
-
- n = size_result.v_int64
- if n == 0:
- return
-
- getitem_args[0].type_index = type_index
- getitem_args[0].v_obj = <TVMFFIObject*>chandle
-
- for i in range(n):
- getitem_args[1].type_index = kTVMFFIInt
- getitem_args[1].v_int64 = i
- elem_result.type_index = kTVMFFINone
- elem_result.v_int64 = 0
- if TVMFFIFunctionCall(getitem_func_handle, getitem_args, 2,
&elem_result) != 0:
- return
- if _check_elem_for_stream(&elem_result, api, ctx):
- return
-
-
-cdef inline void _scan_map_for_stream(
- TVMFFIObjectHandle chandle,
- int32_t type_index,
- const DLPackExchangeAPI* api,
- TVMFFIPyCallContext* ctx
-) noexcept:
- """Scan a Map or Dict's values for the first non-CPU tensor."""
- cdef TVMFFIObjectHandle size_func_handle
- cdef TVMFFIObjectHandle iter_func_handle
- cdef TVMFFIAny size_args[1]
- cdef TVMFFIAny size_result
- cdef TVMFFIAny iter_args[1]
- cdef TVMFFIAny iter_result
- cdef TVMFFIObjectHandle iter_handle = NULL
- cdef TVMFFIAny cmd[1]
- cdef TVMFFIAny val_result
- cdef TVMFFIAny advance_result
- cdef int64_t n, i
-
- if type_index == kTVMFFIMap:
- size_func_handle = (<CObject>_FFI_MAP_SIZE).chandle
- iter_func_handle = (<CObject>_FFI_MAP_FORWARD_ITER).chandle
- else:
- size_func_handle = (<CObject>_FFI_DICT_SIZE).chandle
- iter_func_handle = (<CObject>_FFI_DICT_FORWARD_ITER).chandle
-
- size_args[0].type_index = type_index
- size_args[0].v_obj = <TVMFFIObject*>chandle
- size_result.type_index = kTVMFFINone
- size_result.v_int64 = 0
- if TVMFFIFunctionCall(size_func_handle, size_args, 1, &size_result) != 0:
- return
-
- n = size_result.v_int64
- if n == 0:
- return
-
- # Get forward iterator
- iter_args[0].type_index = type_index
- iter_args[0].v_obj = <TVMFFIObject*>chandle
- iter_result.type_index = kTVMFFINone
- iter_result.v_int64 = 0
- if TVMFFIFunctionCall(iter_func_handle, iter_args, 1, &iter_result) != 0:
- return
- iter_handle = <TVMFFIObjectHandle>iter_result.v_obj
-
- for i in range(n):
- # Get value (command=1)
- cmd[0].type_index = kTVMFFIInt
- cmd[0].v_int64 = 1
- val_result.type_index = kTVMFFINone
- val_result.v_int64 = 0
- if TVMFFIFunctionCall(iter_handle, cmd, 1, &val_result) != 0:
- TVMFFIObjectDecRef(iter_handle)
- return
- if _check_elem_for_stream(&val_result, api, ctx):
- TVMFFIObjectDecRef(iter_handle)
- return
- # Advance (command=2), skip after last entry
- if i < n - 1:
- cmd[0].v_int64 = 2
- advance_result.type_index = kTVMFFINone
- advance_result.v_int64 = 0
- if TVMFFIFunctionCall(iter_handle, cmd, 1, &advance_result) != 0:
- TVMFFIObjectDecRef(iter_handle)
- return
-
- TVMFFIObjectDecRef(iter_handle)
-
-
-cdef inline void _scan_container_for_stream(
- TVMFFIObjectHandle chandle,
- int32_t type_index,
- const DLPackExchangeAPI* api,
- TVMFFIPyCallContext* ctx
-) noexcept:
- """Scan a container for the first non-CPU tensor to set stream context.
-
- Best-effort: silently returns on any FFI error (equivalent to no stream
set).
- """
- if type_index == kTVMFFIArray or type_index == kTVMFFIList:
- _scan_seq_for_stream(chandle, type_index, api, ctx)
- elif type_index == kTVMFFIMap or type_index == kTVMFFIDict:
- _scan_map_for_stream(chandle, type_index, api, ctx)
-
-
cdef inline object make_ret(TVMFFIAny result, const DLPackExchangeAPI*
c_ctx_dlpack_api = NULL):
"""convert result to return value."""
cdef int32_t type_index
@@ -332,6 +160,9 @@ cdef int TVMFFIPyArgSetterContainerObject_(
Propagates DLPack exchange API tag and scans for stream context.
"""
+ cdef TVMFFIAny scan_args[1]
+ cdef TVMFFIAny scan_result
+ cdef void* stream = NULL
out.type_index = TVMFFIObjectGetTypeIndex((<CObject>arg).chandle)
out.v_ptr = (<CObject>arg).chandle
cdef const DLPackExchangeAPI* api =
(<CContainerBase>arg)._dlpack_exchange_api
@@ -339,8 +170,22 @@ cdef int TVMFFIPyArgSetterContainerObject_(
if ctx.dlpack_c_exchange_api == NULL:
ctx.dlpack_c_exchange_api = api
if ctx.device_type == -1 and api.current_work_stream != NULL:
- _scan_container_for_stream(
- (<CObject>arg).chandle, out.type_index, api, ctx)
+ # Call C++ to find the first non-CPU tensor device in one shot.
+ scan_args[0].type_index = out.type_index
+ scan_args[0].v_obj = <TVMFFIObject*>(<CObject>arg).chandle
+ scan_result.type_index = kTVMFFINone
+ scan_result.v_int64 = 0
+ CHECK_CALL(TVMFFIFunctionCall(
+ (<CObject>_FFI_CONTAINER_FIND_FIRST_NON_CPU_DEVICE).chandle,
+ scan_args, 1, &scan_result))
+ if scan_result.type_index == kTVMFFIDevice and
scan_result.v_device.device_type != kDLCPU:
+ ctx.device_type = scan_result.v_device.device_type
+ ctx.device_id = scan_result.v_device.device_id
+ api.current_work_stream(
+ scan_result.v_device.device_type,
+ scan_result.v_device.device_id,
+ &stream)
+ ctx.stream = <TVMFFIStreamHandle>stream
return 0
@@ -1345,11 +1190,5 @@ cdef Function _OBJECT_FROM_JSON_GRAPH_STR =
_get_global_func("ffi.FromJSONGraphS
cdef Function _OBJECT_TO_JSON_GRAPH_STR =
_get_global_func("ffi.ToJSONGraphString", True)
cdef Function _CONSTRUCTOR_ARRAY = _get_global_func("ffi.Array", True)
cdef Function _CONSTRUCTOR_MAP = _get_global_func("ffi.Map", True)
-cdef Function _FFI_ARRAY_GET_ITEM = _get_global_func("ffi.ArrayGetItem", True)
-cdef Function _FFI_ARRAY_SIZE = _get_global_func("ffi.ArraySize", True)
-cdef Function _FFI_LIST_GET_ITEM = _get_global_func("ffi.ListGetItem", True)
-cdef Function _FFI_LIST_SIZE = _get_global_func("ffi.ListSize", True)
-cdef Function _FFI_MAP_SIZE = _get_global_func("ffi.MapSize", True)
-cdef Function _FFI_MAP_FORWARD_ITER =
_get_global_func("ffi.MapForwardIterFunctor", True)
-cdef Function _FFI_DICT_SIZE = _get_global_func("ffi.DictSize", True)
-cdef Function _FFI_DICT_FORWARD_ITER =
_get_global_func("ffi.DictForwardIterFunctor", True)
+cdef Function _FFI_CONTAINER_FIND_FIRST_NON_CPU_DEVICE = _get_global_func(
+ "ffi.ContainerFindFirstNonCPUDevice", True)
diff --git a/src/ffi/container.cc b/src/ffi/container.cc
index 354dcb8..fbb0e0d 100644
--- a/src/ffi/container.cc
+++ b/src/ffi/container.cc
@@ -23,6 +23,7 @@
#include <tvm/ffi/container/dict.h>
#include <tvm/ffi/container/list.h>
#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/container/tensor.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/registry.h>
@@ -31,6 +32,46 @@
namespace tvm {
namespace ffi {
+namespace {
+/*!
+ * \brief Recursively scan an Any element for the first non-CPU tensor device.
+ * \param elem The element to inspect.
+ * \param out Output device; written only when a non-CPU tensor is found.
+ * \return true if a non-CPU tensor was found.
+ */
+bool FindFirstNonCPUDevice(const Any& elem, DLDevice* out) {
+ switch (elem.type_index()) {
+ case TypeIndex::kTVMFFITensor: {
+ const auto* tensor = elem.as<TensorObj>();
+ if (tensor->device.device_type != kDLCPU) {
+ *out = tensor->device;
+ return true;
+ }
+ break;
+ }
+ case TypeIndex::kTVMFFIArray:
+ case TypeIndex::kTVMFFIList: {
+ const auto* seq = elem.as<SeqBaseObj>();
+ for (const auto& it : *seq) {
+ if (FindFirstNonCPUDevice(it, out)) return true;
+ }
+ break;
+ }
+ case TypeIndex::kTVMFFIMap:
+ case TypeIndex::kTVMFFIDict: {
+ const auto* map = elem.as<MapBaseObj>();
+ for (const auto& it : *map) {
+ if (FindFirstNonCPUDevice(it.second, out)) return true;
+ }
+ break;
+ }
+ default:
+ break;
+ }
+ return false;
+}
+} // namespace
+
// Favor struct outside function scope as MSVC may have bug for in fn scope
struct.
class MapForwardIterFunctor {
public:
@@ -184,12 +225,18 @@ TVM_FFI_STATIC_INIT_BLOCK() {
[](const ffi::DictObj* n) -> ffi::Function {
return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(),
n->end()));
})
- .def("ffi.DictGetItemOrMissing", [](const ffi::DictObj* n, const Any& k)
-> Any {
- try {
- return n->at(k);
- } catch (const tvm::ffi::Error& e) {
- return GetMissingObject();
- }
+ .def("ffi.DictGetItemOrMissing",
+ [](const ffi::DictObj* n, const Any& k) -> Any {
+ try {
+ return n->at(k);
+ } catch (const tvm::ffi::Error& e) {
+ return GetMissingObject();
+ }
+ })
+ .def("ffi.ContainerFindFirstNonCPUDevice", [](const Any& container) ->
DLDevice {
+ DLDevice result{kDLCPU, 0};
+ FindFirstNonCPUDevice(container, &result);
+ return result;
});
}
} // namespace ffi