This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new dbf64e88 [FEAT] Add map_dataclass_to_tuple to make_kwargs_wrapper (#515)
dbf64e88 is described below

commit dbf64e88c5f171da493ea887acc7ae5887101676
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Mar 26 20:47:21 2026 -0400

    [FEAT] Add map_dataclass_to_tuple to make_kwargs_wrapper (#515)
    
    Add a `map_dataclass_to_tuple` parameter to `make_kwargs_wrapper` and
    `make_kwargs_wrapper_from_signature` that accepts a list of argument
    names whose dataclass values should be converted to tuples (via
    `dataclasses.astuple`) before being passed to the target function.
    
    This enables callers to pass dataclass instances to FFI functions that
    expect flattened tuple arguments, matching the FFI calling convention. The
    conversion is injected into the generated wrapper code, so there is no
    overhead when the feature is not used.
---
 pyproject.toml                            |   4 +-
 python/tvm_ffi/utils/kwargs_wrapper.py    | 188 +++++++++++++++++-------------
 tests/python/utils/test_kwargs_wrapper.py |  82 +++++++++++++
 3 files changed, 193 insertions(+), 81 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index d5da7860..a5d65521 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -75,7 +75,9 @@ docs = [
   "nbconvert",
   "nbsphinx",
   "nbstripout",
-  "sphinx",
+  # TODO: unpin when sphinx-toolbox fixes `from sphinx.ext.autodoc import logger` (removed in Sphinx 9)
+  # Blocked by: sphinx-toolbox/utils.py imports logger from sphinx.ext.autodoc (broken as of v4.1.2)
+  "sphinx<9",
   "sphinx-autobuild",
   "sphinx-book-theme",
   "sphinx-copybutton",
diff --git a/python/tvm_ffi/utils/kwargs_wrapper.py b/python/tvm_ffi/utils/kwargs_wrapper.py
index 13561e6f..df1a8f90 100644
--- a/python/tvm_ffi/utils/kwargs_wrapper.py
+++ b/python/tvm_ffi/utils/kwargs_wrapper.py
@@ -22,6 +22,7 @@ keyword argument support using code generation techniques.
 
 from __future__ import annotations
 
+import dataclasses
 import functools
 import inspect
 import keyword
@@ -30,6 +31,18 @@ from typing import Any, Callable, Iterable
 # Sentinel object for missing arguments
 MISSING = object()
 
+# Internal variable names used in generated code to avoid user argument conflicts
+_INTERNAL_TARGET_FUNC = "__i_target_func"
+_INTERNAL_MISSING = "__i_MISSING"
+_INTERNAL_DEFAULTS_DICT = "__i_arg_defaults"
+_INTERNAL_ASTUPLE = "__i_astuple"
+_INTERNAL_NAMES = {
+    _INTERNAL_TARGET_FUNC,
+    _INTERNAL_MISSING,
+    _INTERNAL_DEFAULTS_DICT,
+    _INTERNAL_ASTUPLE,
+}
+
 
 def _validate_argument_names(names: list[str], arg_type: str) -> None:
     """Validate that argument names are valid Python identifiers and unique.
@@ -126,6 +139,77 @@ def _validate_wrapper_args(
         )
 
 
+def _build_wrapper_code(
+    arg_names: list[str],
+    arg_defaults: tuple,
+    kwonly_names: list[str],
+    kwonly_defaults: dict[str, Any],
+    dc_to_tuple_set: set[str],
+) -> tuple[str, dict[str, Any]]:
+    """Build the generated wrapper code string and runtime defaults dict.
+
+    Returns
+    -------
+        A tuple of (code_str, runtime_defaults) where code_str is the generated
+        wrapper function code and runtime_defaults maps arg names to their default values.
+
+    """
+    # Build positional defaults dictionary (right-aligned)
+    arg_defaults_dict = (
+        dict(zip(arg_names[-len(arg_defaults) :], arg_defaults)) if 
arg_defaults else {}
+    )
+
+    arg_parts: list[str] = []
+    call_parts: list[str] = []
+    runtime_defaults: dict[str, Any] = {}
+
+    def _wrap_astuple(name: str, expr: str) -> str:
+        if name in dc_to_tuple_set:
+            return f"{_INTERNAL_ASTUPLE}({expr})"
+        return expr
+
+    def _add_param_with_default(name: str, default_value: Any) -> None:
+        # Directly embed None and bool defaults; use MISSING sentinel for others.
+        if default_value is None:
+            arg_parts.append(f"{name}=None")
+            call_parts.append(_wrap_astuple(name, name))
+        elif type(default_value) is bool:
+            default_value_str = "True" if default_value else "False"
+            arg_parts.append(f"{name}={default_value_str}")
+            call_parts.append(_wrap_astuple(name, name))
+        else:
+            arg_parts.append(f"{name}={_INTERNAL_MISSING}")
+            runtime_defaults[name] = default_value
+            base_expr = (
+                f'{_INTERNAL_DEFAULTS_DICT}["{name}"] if {name} is 
{_INTERNAL_MISSING} else {name}'
+            )
+            call_parts.append(_wrap_astuple(name, base_expr))
+
+    for name in arg_names:
+        if name in arg_defaults_dict:
+            _add_param_with_default(name, arg_defaults_dict[name])
+        else:
+            arg_parts.append(name)
+            call_parts.append(_wrap_astuple(name, name))
+
+    if kwonly_names:
+        arg_parts.append("*")
+        for name in kwonly_names:
+            if name in kwonly_defaults:
+                _add_param_with_default(name, kwonly_defaults[name])
+            else:
+                arg_parts.append(name)
+                call_parts.append(_wrap_astuple(name, name))
+
+    arg_list = ", ".join(arg_parts)
+    call_list = ", ".join(call_parts)
+    code_str = f"""
+def wrapper({arg_list}):
+    return {_INTERNAL_TARGET_FUNC}({call_list})
+"""
+    return code_str, runtime_defaults
+
+
 def make_kwargs_wrapper(
     target_func: Callable,
     arg_names: list[str],
@@ -133,6 +217,7 @@ def make_kwargs_wrapper(
     kwonly_names: list[str] | None = None,
     kwonly_defaults: dict[str, Any] | None = None,
     prototype: Callable | None = None,
+    map_dataclass_to_tuple: list[str] | None = None,
 ) -> Callable:
     """Create a wrapper with kwargs support for a function that only accepts positional arguments.
 
@@ -166,6 +251,11 @@ def make_kwargs_wrapper(
     prototype
         Optional prototype function to copy metadata (__name__, __doc__, __module__,
         __qualname__, __annotations__) from. If None, no metadata is copied.
+    map_dataclass_to_tuple
+        Optional list of argument names whose values should be converted from dataclass
+        instances to tuples (via ``dataclasses.astuple``) before being passed to the
+        target function. This is useful when the target function expects flattened tuple
+        arguments but callers pass dataclass instances.
 
     Returns
     -------
@@ -184,98 +274,30 @@ def make_kwargs_wrapper(
         kwonly_names = []
     if kwonly_defaults is None:
         kwonly_defaults = {}
-
-    # Internal variable names used in generated code to avoid user argument conflicts
-    _INTERNAL_TARGET_FUNC = "__i_target_func"
-    _INTERNAL_MISSING = "__i_MISSING"
-    _INTERNAL_DEFAULTS_DICT = "__i_arg_defaults"
-    _INTERNAL_NAMES = {_INTERNAL_TARGET_FUNC, _INTERNAL_MISSING, 
_INTERNAL_DEFAULTS_DICT}
+    dc_to_tuple_set = set(map_dataclass_to_tuple) if map_dataclass_to_tuple 
else set()
 
     # Validate all input arguments
     _validate_wrapper_args(arg_names, arg_defaults, kwonly_names, 
kwonly_defaults, _INTERNAL_NAMES)
 
-    # Build positional defaults dictionary (right-aligned)
-    # Example: arg_names=["a","b","c","d"], arg_defaults=(10,20) -> {"c":10, "d":20}
-    arg_defaults_dict = (
-        dict(zip(arg_names[-len(arg_defaults) :], arg_defaults)) if 
arg_defaults else {}
+    # Build the generated wrapper code
+    code_str, runtime_defaults = _build_wrapper_code(
+        arg_names, arg_defaults, kwonly_names, kwonly_defaults, dc_to_tuple_set
     )
 
-    # Build wrapper signature and call arguments
-    # Note: this code must be in this function so all code generation and exec are self-contained
-    # We construct runtime_defaults dict for only non-safe defaults that need MISSING sentinel
-    arg_parts = []
-    call_parts = []
-    runtime_defaults = {}
-
-    def _add_param_with_default(name: str, default_value: Any) -> None:
-        """Add a parameter with a default value to arg_parts and call_parts."""
-        # Rationale: we directly embed default values for None and bool
-        # since they are common case and safe to be directly included in generated code.
-        #
-        # For other cases (including int/str), we use the MISSING sentinel to ensure
-        # generated code do not contain unexpected str repr and instead they are passed
-        # through runtime_defaults[name].
-        #
-        # we deliberately skip int/str since bring their string representation
-        # may involve __str__ / __repr__ that could be updated by subclasses.
-        # The missing check is generally fast enough and more controllable.
-        if default_value is None:
-            # Safe to use the default value None directly in the signature
-            arg_parts.append(f"{name}=None")
-            call_parts.append(name)
-        elif type(default_value) is bool:
-            # we deliberately not use isinstance to avoid subclasses of bool
-            # we also explicitly avoid repr for safety
-            default_value_str = "True" if default_value else "False"
-            arg_parts.append(f"{name}={default_value_str}")
-            call_parts.append(name)
-        else:
-            # For all other cases, we use the MISSING sentinel
-            arg_parts.append(f"{name}={_INTERNAL_MISSING}")
-            runtime_defaults[name] = default_value
-            # The conditional check runs
-            call_parts.append(
-                f'{_INTERNAL_DEFAULTS_DICT}["{name}"] if {name} is 
{_INTERNAL_MISSING} else {name}'
-            )
-
-    # Handle positional arguments
-    for name in arg_names:
-        if name in arg_defaults_dict:
-            _add_param_with_default(name, arg_defaults_dict[name])
-        else:
-            arg_parts.append(name)
-            call_parts.append(name)
-
-    # Handle keyword-only arguments
-    if kwonly_names:
-        arg_parts.append("*")  # Separator for keyword-only args
-        for name in kwonly_names:
-            if name in kwonly_defaults:
-                _add_param_with_default(name, kwonly_defaults[name])
-            else:
-                # Required keyword-only arg (no default)
-                arg_parts.append(name)
-                call_parts.append(name)
-
-    arg_list = ", ".join(arg_parts)
-    call_list = ", ".join(call_parts)
-
-    code_str = f"""
-def wrapper({arg_list}):
-    return {_INTERNAL_TARGET_FUNC}({call_list})
-"""
     # Execute the generated code
-    exec_globals = {
-        _INTERNAL_TARGET_FUNC: target_func,
-        _INTERNAL_MISSING: MISSING,
-        _INTERNAL_DEFAULTS_DICT: runtime_defaults,
-    }
-    namespace: dict[str, Any] = {}
     # Note: this is a limited use of exec that is safe.
     # We ensure generated code does not contain any untrusted input.
     # The argument names are validated and the default values are not part of generated code.
     # Instead default values are set to MISSING sentinel object and explicitly passed as exec_globals.
     # This is a practice adopted by `dataclasses` and `pydantic`
+    exec_globals: dict[str, Any] = {
+        _INTERNAL_TARGET_FUNC: target_func,
+        _INTERNAL_MISSING: MISSING,
+        _INTERNAL_DEFAULTS_DICT: runtime_defaults,
+    }
+    if dc_to_tuple_set:
+        exec_globals[_INTERNAL_ASTUPLE] = dataclasses.astuple
+    namespace: dict[str, Any] = {}
     exec(code_str, exec_globals, namespace)
     new_func = namespace["wrapper"]
 
@@ -291,6 +313,7 @@ def make_kwargs_wrapper_from_signature(
     signature: inspect.Signature,
     prototype: Callable | None = None,
     exclude_arg_names: Iterable[str] | None = None,
+    map_dataclass_to_tuple: list[str] | None = None,
 ) -> Callable:
     """Create a wrapper with kwargs support for a function that only accepts positional arguments.
 
@@ -311,6 +334,10 @@ def make_kwargs_wrapper_from_signature(
         Optional iterable of argument names to ignore when extracting parameters from the signature.
         These arguments will not be included in the generated wrapper. If a name in this iterable
         does not exist in the signature, it is silently ignored.
+    map_dataclass_to_tuple
+        Optional list of argument names whose values should be converted from dataclass
+        instances to tuples (via ``dataclasses.astuple``) before being passed to the
+        target function.
 
     Returns
     -------
@@ -372,4 +399,5 @@ def make_kwargs_wrapper_from_signature(
         kwonly_names,
         kwonly_defaults,
         prototype,
+        map_dataclass_to_tuple,
     )
diff --git a/tests/python/utils/test_kwargs_wrapper.py b/tests/python/utils/test_kwargs_wrapper.py
index 4c370deb..2b731b29 100644
--- a/tests/python/utils/test_kwargs_wrapper.py
+++ b/tests/python/utils/test_kwargs_wrapper.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import dataclasses
 import inspect
 from typing import Any
 
@@ -369,3 +370,84 @@ def test_optimized_default_types() -> None:
     assert wrapper(1, flag=False) == (1, None, False, "default")
     assert wrapper(1, name="custom") == (1, None, True, "custom")
     assert wrapper(1, b=2, flag=False, name="test") == (1, 2, False, "test")
+
+
+def test_map_dataclass_to_tuple() -> None:
+    """Test that dataclass arguments are converted to tuples via dataclasses.astuple."""
+
+    @dataclasses.dataclass
+    class Config:
+        x: int
+        y: int
+
+    @dataclasses.dataclass
+    class Nested:
+        value: int
+        cfg: Config
+
+    def target(*args: Any) -> tuple[Any, ...]:
+        return args
+
+    # Basic: one dataclass arg converted to tuple
+    wrapper = make_kwargs_wrapper(target, ["a", "cfg"], 
map_dataclass_to_tuple=["cfg"])
+    result = wrapper(1, Config(x=10, y=20))
+    assert result == (1, (10, 20))
+
+    # Dataclass passed as keyword argument
+    result = wrapper(a=1, cfg=Config(x=3, y=4))
+    assert result == (1, (3, 4))
+
+    # Multiple dataclass args
+    wrapper = make_kwargs_wrapper(target, ["a", "b"], 
map_dataclass_to_tuple=["a", "b"])
+    result = wrapper(Config(x=1, y=2), Config(x=3, y=4))
+    assert result == ((1, 2), (3, 4))
+
+    # Nested dataclass (astuple recurses)
+    wrapper = make_kwargs_wrapper(target, ["a", "nested"], 
map_dataclass_to_tuple=["nested"])
+    result = wrapper(1, Nested(value=5, cfg=Config(x=10, y=20)))
+    assert result == (1, (5, (10, 20)))
+
+    # Mixed: some args converted, others not
+    wrapper = make_kwargs_wrapper(target, ["a", "cfg", "b"], 
map_dataclass_to_tuple=["cfg"])
+    result = wrapper(1, Config(x=10, y=20), 3)
+    assert result == (1, (10, 20), 3)
+
+    # With defaults: dataclass arg has a default
+    default_cfg = Config(x=0, y=0)
+    wrapper = make_kwargs_wrapper(
+        target, ["a", "cfg"], arg_defaults=(default_cfg,), 
map_dataclass_to_tuple=["cfg"]
+    )
+    result = wrapper(1)
+    assert result == (1, (0, 0))
+    result = wrapper(1, Config(x=5, y=6))
+    assert result == (1, (5, 6))
+
+    # With keyword-only dataclass arg
+    wrapper = make_kwargs_wrapper(
+        target,
+        ["a"],
+        kwonly_names=["cfg"],
+        kwonly_defaults={"cfg": Config(x=0, y=0)},
+        map_dataclass_to_tuple=["cfg"],
+    )
+    result = wrapper(1)
+    assert result == (1, (0, 0))
+    result = wrapper(1, cfg=Config(x=7, y=8))
+    assert result == (1, (7, 8))
+
+    # Empty list: no conversion
+    wrapper = make_kwargs_wrapper(target, ["a", "b"], 
map_dataclass_to_tuple=[])
+    cfg = Config(x=1, y=2)
+    result = wrapper(1, cfg)
+    assert result == (1, cfg)
+    assert result[1] is cfg  # not converted
+
+    # Works with make_kwargs_wrapper_from_signature
+    def source_func(a: int, cfg: Config) -> None:
+        pass
+
+    wrapper = make_kwargs_wrapper_from_signature(
+        target, inspect.signature(source_func), map_dataclass_to_tuple=["cfg"]
+    )
+    result = wrapper(1, Config(x=10, y=20))
+    assert result == (1, (10, 20))

Reply via email to