This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 8d1ba38  [fix] Move container stream scanning from Cython to C++ for 
efficiency (#521)
8d1ba38 is described below

commit 8d1ba380c613b30b5910340ec009837729131db1
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Thu Apr 2 11:28:42 2026 -0700

    [fix] Move container stream scanning from Cython to C++ for efficiency 
(#521)
    
    When a packed FFI function receives containers (Array, List, Map, Dict)
    containing tensors, we scan for the first non-CPU tensor to capture
    stream context. Previously this scanning was done in Cython via repeated
    packed-function calls (size, getitem, iterator create/advance per
    element). This PR registers a single C++ function
    ffi.ContainerFindFirstNonCPUDevice that does the recursive scan
    natively.
---
 python/tvm_ffi/cython/function.pxi | 203 ++++---------------------------------
 src/ffi/container.cc               |  59 +++++++++--
 2 files changed, 74 insertions(+), 188 deletions(-)

diff --git a/python/tvm_ffi/cython/function.pxi 
b/python/tvm_ffi/cython/function.pxi
index 65ba378..913c698 100644
--- a/python/tvm_ffi/cython/function.pxi
+++ b/python/tvm_ffi/cython/function.pxi
@@ -62,178 +62,6 @@ cdef inline object make_ret_small_bytes(TVMFFIAny result):
     return bytearray_to_bytes(&bytes)
 
 
-cdef inline bint _check_elem_for_stream(
-    TVMFFIAny* elem_result,
-    const DLPackExchangeAPI* api,
-    TVMFFIPyCallContext* ctx
-) noexcept:
-    """Check a single element for non-CPU tensor; set stream if found.
-
-    Returns True if a non-CPU tensor was found and stream was set.
-    Releases the element ref (for object types) in all cases.
-    """
-    cdef DLTensor* dltensor
-    cdef void* stream = NULL
-    cdef int32_t ti = elem_result.type_index
-
-    if ti == kTVMFFITensor:
-        dltensor = 
TVMFFITensorGetDLTensorPtr(<TVMFFIObjectHandle>elem_result.v_obj)
-        if dltensor.device.device_type != kDLCPU:
-            ctx.device_type = dltensor.device.device_type
-            ctx.device_id = dltensor.device.device_id
-            api.current_work_stream(
-                dltensor.device.device_type,
-                dltensor.device.device_id,
-                &stream)
-            ctx.stream = <TVMFFIStreamHandle>stream
-            TVMFFIObjectDecRef(<TVMFFIObjectHandle>elem_result.v_obj)
-            return True
-        TVMFFIObjectDecRef(<TVMFFIObjectHandle>elem_result.v_obj)
-    elif (ti == kTVMFFIArray or ti == kTVMFFIList
-          or ti == kTVMFFIMap or ti == kTVMFFIDict):
-        _scan_container_for_stream(
-            <TVMFFIObjectHandle>elem_result.v_obj, ti, api, ctx)
-        TVMFFIObjectDecRef(<TVMFFIObjectHandle>elem_result.v_obj)
-        if ctx.device_type != -1:
-            return True
-    elif ti >= kTVMFFIStaticObjectBegin:
-        TVMFFIObjectDecRef(<TVMFFIObjectHandle>elem_result.v_obj)
-    return False
-
-
-cdef inline void _scan_seq_for_stream(
-    TVMFFIObjectHandle chandle,
-    int32_t type_index,
-    const DLPackExchangeAPI* api,
-    TVMFFIPyCallContext* ctx
-) noexcept:
-    """Scan an Array or List for the first non-CPU tensor."""
-    cdef TVMFFIObjectHandle size_func_handle
-    cdef TVMFFIObjectHandle getitem_func_handle
-    cdef TVMFFIAny size_args[1]
-    cdef TVMFFIAny size_result
-    cdef TVMFFIAny getitem_args[2]
-    cdef TVMFFIAny elem_result
-    cdef int64_t n, i
-
-    if type_index == kTVMFFIArray:
-        size_func_handle = (<CObject>_FFI_ARRAY_SIZE).chandle
-        getitem_func_handle = (<CObject>_FFI_ARRAY_GET_ITEM).chandle
-    else:
-        size_func_handle = (<CObject>_FFI_LIST_SIZE).chandle
-        getitem_func_handle = (<CObject>_FFI_LIST_GET_ITEM).chandle
-
-    size_args[0].type_index = type_index
-    size_args[0].v_obj = <TVMFFIObject*>chandle
-    size_result.type_index = kTVMFFINone
-    size_result.v_int64 = 0
-    if TVMFFIFunctionCall(size_func_handle, size_args, 1, &size_result) != 0:
-        return
-
-    n = size_result.v_int64
-    if n == 0:
-        return
-
-    getitem_args[0].type_index = type_index
-    getitem_args[0].v_obj = <TVMFFIObject*>chandle
-
-    for i in range(n):
-        getitem_args[1].type_index = kTVMFFIInt
-        getitem_args[1].v_int64 = i
-        elem_result.type_index = kTVMFFINone
-        elem_result.v_int64 = 0
-        if TVMFFIFunctionCall(getitem_func_handle, getitem_args, 2, 
&elem_result) != 0:
-            return
-        if _check_elem_for_stream(&elem_result, api, ctx):
-            return
-
-
-cdef inline void _scan_map_for_stream(
-    TVMFFIObjectHandle chandle,
-    int32_t type_index,
-    const DLPackExchangeAPI* api,
-    TVMFFIPyCallContext* ctx
-) noexcept:
-    """Scan a Map or Dict's values for the first non-CPU tensor."""
-    cdef TVMFFIObjectHandle size_func_handle
-    cdef TVMFFIObjectHandle iter_func_handle
-    cdef TVMFFIAny size_args[1]
-    cdef TVMFFIAny size_result
-    cdef TVMFFIAny iter_args[1]
-    cdef TVMFFIAny iter_result
-    cdef TVMFFIObjectHandle iter_handle = NULL
-    cdef TVMFFIAny cmd[1]
-    cdef TVMFFIAny val_result
-    cdef TVMFFIAny advance_result
-    cdef int64_t n, i
-
-    if type_index == kTVMFFIMap:
-        size_func_handle = (<CObject>_FFI_MAP_SIZE).chandle
-        iter_func_handle = (<CObject>_FFI_MAP_FORWARD_ITER).chandle
-    else:
-        size_func_handle = (<CObject>_FFI_DICT_SIZE).chandle
-        iter_func_handle = (<CObject>_FFI_DICT_FORWARD_ITER).chandle
-
-    size_args[0].type_index = type_index
-    size_args[0].v_obj = <TVMFFIObject*>chandle
-    size_result.type_index = kTVMFFINone
-    size_result.v_int64 = 0
-    if TVMFFIFunctionCall(size_func_handle, size_args, 1, &size_result) != 0:
-        return
-
-    n = size_result.v_int64
-    if n == 0:
-        return
-
-    # Get forward iterator
-    iter_args[0].type_index = type_index
-    iter_args[0].v_obj = <TVMFFIObject*>chandle
-    iter_result.type_index = kTVMFFINone
-    iter_result.v_int64 = 0
-    if TVMFFIFunctionCall(iter_func_handle, iter_args, 1, &iter_result) != 0:
-        return
-    iter_handle = <TVMFFIObjectHandle>iter_result.v_obj
-
-    for i in range(n):
-        # Get value (command=1)
-        cmd[0].type_index = kTVMFFIInt
-        cmd[0].v_int64 = 1
-        val_result.type_index = kTVMFFINone
-        val_result.v_int64 = 0
-        if TVMFFIFunctionCall(iter_handle, cmd, 1, &val_result) != 0:
-            TVMFFIObjectDecRef(iter_handle)
-            return
-        if _check_elem_for_stream(&val_result, api, ctx):
-            TVMFFIObjectDecRef(iter_handle)
-            return
-        # Advance (command=2), skip after last entry
-        if i < n - 1:
-            cmd[0].v_int64 = 2
-            advance_result.type_index = kTVMFFINone
-            advance_result.v_int64 = 0
-            if TVMFFIFunctionCall(iter_handle, cmd, 1, &advance_result) != 0:
-                TVMFFIObjectDecRef(iter_handle)
-                return
-
-    TVMFFIObjectDecRef(iter_handle)
-
-
-cdef inline void _scan_container_for_stream(
-    TVMFFIObjectHandle chandle,
-    int32_t type_index,
-    const DLPackExchangeAPI* api,
-    TVMFFIPyCallContext* ctx
-) noexcept:
-    """Scan a container for the first non-CPU tensor to set stream context.
-
-    Best-effort: silently returns on any FFI error (equivalent to no stream 
set).
-    """
-    if type_index == kTVMFFIArray or type_index == kTVMFFIList:
-        _scan_seq_for_stream(chandle, type_index, api, ctx)
-    elif type_index == kTVMFFIMap or type_index == kTVMFFIDict:
-        _scan_map_for_stream(chandle, type_index, api, ctx)
-
-
 cdef inline object make_ret(TVMFFIAny result, const DLPackExchangeAPI* 
c_ctx_dlpack_api = NULL):
     """convert result to return value."""
     cdef int32_t type_index
@@ -332,6 +160,9 @@ cdef int TVMFFIPyArgSetterContainerObject_(
 
     Propagates DLPack exchange API tag and scans for stream context.
     """
+    cdef TVMFFIAny scan_args[1]
+    cdef TVMFFIAny scan_result
+    cdef void* stream = NULL
     out.type_index = TVMFFIObjectGetTypeIndex((<CObject>arg).chandle)
     out.v_ptr = (<CObject>arg).chandle
     cdef const DLPackExchangeAPI* api = 
(<CContainerBase>arg)._dlpack_exchange_api
@@ -339,8 +170,22 @@ cdef int TVMFFIPyArgSetterContainerObject_(
         if ctx.dlpack_c_exchange_api == NULL:
             ctx.dlpack_c_exchange_api = api
         if ctx.device_type == -1 and api.current_work_stream != NULL:
-            _scan_container_for_stream(
-                (<CObject>arg).chandle, out.type_index, api, ctx)
+            # Call C++ to find the first non-CPU tensor device in one shot.
+            scan_args[0].type_index = out.type_index
+            scan_args[0].v_obj = <TVMFFIObject*>(<CObject>arg).chandle
+            scan_result.type_index = kTVMFFINone
+            scan_result.v_int64 = 0
+            CHECK_CALL(TVMFFIFunctionCall(
+                (<CObject>_FFI_CONTAINER_FIND_FIRST_NON_CPU_DEVICE).chandle,
+                scan_args, 1, &scan_result))
+            if scan_result.type_index == kTVMFFIDevice and 
scan_result.v_device.device_type != kDLCPU:
+                ctx.device_type = scan_result.v_device.device_type
+                ctx.device_id = scan_result.v_device.device_id
+                api.current_work_stream(
+                    scan_result.v_device.device_type,
+                    scan_result.v_device.device_id,
+                    &stream)
+                ctx.stream = <TVMFFIStreamHandle>stream
     return 0
 
 
@@ -1345,11 +1190,5 @@ cdef Function _OBJECT_FROM_JSON_GRAPH_STR = 
_get_global_func("ffi.FromJSONGraphS
 cdef Function _OBJECT_TO_JSON_GRAPH_STR = 
_get_global_func("ffi.ToJSONGraphString", True)
 cdef Function _CONSTRUCTOR_ARRAY = _get_global_func("ffi.Array", True)
 cdef Function _CONSTRUCTOR_MAP = _get_global_func("ffi.Map", True)
-cdef Function _FFI_ARRAY_GET_ITEM = _get_global_func("ffi.ArrayGetItem", True)
-cdef Function _FFI_ARRAY_SIZE = _get_global_func("ffi.ArraySize", True)
-cdef Function _FFI_LIST_GET_ITEM = _get_global_func("ffi.ListGetItem", True)
-cdef Function _FFI_LIST_SIZE = _get_global_func("ffi.ListSize", True)
-cdef Function _FFI_MAP_SIZE = _get_global_func("ffi.MapSize", True)
-cdef Function _FFI_MAP_FORWARD_ITER = 
_get_global_func("ffi.MapForwardIterFunctor", True)
-cdef Function _FFI_DICT_SIZE = _get_global_func("ffi.DictSize", True)
-cdef Function _FFI_DICT_FORWARD_ITER = 
_get_global_func("ffi.DictForwardIterFunctor", True)
+cdef Function _FFI_CONTAINER_FIND_FIRST_NON_CPU_DEVICE = _get_global_func(
+    "ffi.ContainerFindFirstNonCPUDevice", True)
diff --git a/src/ffi/container.cc b/src/ffi/container.cc
index 354dcb8..fbb0e0d 100644
--- a/src/ffi/container.cc
+++ b/src/ffi/container.cc
@@ -23,6 +23,7 @@
 #include <tvm/ffi/container/dict.h>
 #include <tvm/ffi/container/list.h>
 #include <tvm/ffi/container/map.h>
+#include <tvm/ffi/container/tensor.h>
 #include <tvm/ffi/function.h>
 #include <tvm/ffi/reflection/registry.h>
 
@@ -31,6 +32,46 @@
 namespace tvm {
 namespace ffi {
 
+namespace {
+/*!
+ * \brief Depth-first scan of an Any for the first tensor whose device type is not kDLCPU.
+ * \param elem Value to inspect; Array/List elements and Map/Dict values are scanned recursively.
+ * \param out Output device; written only when a non-CPU tensor is found, else left untouched.
+ * \return true if a non-CPU tensor was found (the scan stops at the first hit), else false.
+ */
+bool FindFirstNonCPUDevice(const Any& elem, DLDevice* out) {
+  switch (elem.type_index()) {  // dispatch on the runtime type index of the value
+    case TypeIndex::kTVMFFITensor: {
+      const auto* tensor = elem.as<TensorObj>();  // presumably non-null: type index matched above
+      if (tensor->device.device_type != kDLCPU) {
+        *out = tensor->device;  // report the tensor's device to the caller
+        return true;
+      }
+      break;  // CPU tensor: not a hit, fall through to `return false`
+    }
+    case TypeIndex::kTVMFFIArray:
+    case TypeIndex::kTVMFFIList: {  // both sequence kinds are viewed through SeqBaseObj
+      const auto* seq = elem.as<SeqBaseObj>();
+      for (const auto& it : *seq) {
+        if (FindFirstNonCPUDevice(it, out)) return true;  // NOTE(review): no depth/cycle guard
+      }
+      break;
+    }
+    case TypeIndex::kTVMFFIMap:
+    case TypeIndex::kTVMFFIDict: {  // both mapping kinds are viewed through MapBaseObj
+      const auto* map = elem.as<MapBaseObj>();
+      for (const auto& it : *map) {
+        if (FindFirstNonCPUDevice(it.second, out)) return true;  // only `.second` is scanned
+      }
+      break;
+    }
+    default:  // any other type index: nothing to scan
+      break;
+  }
+  return false;
+}
+}  // namespace
+
 // Favor struct outside function scope as MSVC may have bug for in fn scope 
struct.
 class MapForwardIterFunctor {
  public:
@@ -184,12 +225,18 @@ TVM_FFI_STATIC_INIT_BLOCK() {
            [](const ffi::DictObj* n) -> ffi::Function {
              return ffi::Function::FromTyped(MapForwardIterFunctor(n->begin(), 
n->end()));
            })
-      .def("ffi.DictGetItemOrMissing", [](const ffi::DictObj* n, const Any& k) 
-> Any {
-        try {
-          return n->at(k);
-        } catch (const tvm::ffi::Error& e) {
-          return GetMissingObject();
-        }
+      .def("ffi.DictGetItemOrMissing",
+           [](const ffi::DictObj* n, const Any& k) -> Any {
+             try {
+               return n->at(k);
+             } catch (const tvm::ffi::Error& e) {
+               return GetMissingObject();
+             }
+           })
+      .def("ffi.ContainerFindFirstNonCPUDevice", [](const Any& container) -> 
DLDevice {
+        DLDevice result{kDLCPU, 0};
+        FindFirstNonCPUDevice(container, &result);
+        return result;
       });
 }
 }  // namespace ffi

Reply via email to