This is an automated email from the ASF dual-hosted git repository.
yaxingcai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new b7c30ea fix(init): register __ffi_init__ as TypeMethod and use MISSING sentinel (#546)
b7c30ea is described below
commit b7c30ea6e8264f8189c0ab94f2ebb04e955812ad
Author: Junru Shao <[email protected]>
AuthorDate: Mon Apr 13 19:31:28 2026 -0700
fix(init): register __ffi_init__ as TypeMethod and use MISSING sentinel (#546)
## Summary
- **C++ `def(init<>)`**: Register `__ffi_init__` as a TypeMethod
(preserving `type_schema` metadata for stub generation), then mirror it
into the `__ffi_init__` TypeAttrColumn for backward compatibility. This
aligns `ObjectDef::def(init<>)` with `OverloadObjectDef::def(init<>)`,
which already registered the initializer this way.
- **Python `__ffi_init__` lookup**: When resolving `__ffi_init__` in
`_install_init` and `_install_dataclass_dunders`, prefer the TypeMethod
entry and fall back to the TypeAttrColumn only if no TypeMethod exists.
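- **Python `_install_ffi_init_attr`**: New helper that installs
`__ffi_init__` as an instance method. The method forwards its arguments
(keyword arguments packed via the FFI KWARGS protocol) to
`__init_handle_by_constructor__`. A type-owner guard raises `TypeError`
when the method is called on an instance whose type carries different
type info (e.g. a registered subclass), preventing subclass misuse.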
- **MISSING sentinel**: Use `core.MISSING` (instead of the module-private
`__SENTINEL` object) as the default value of optional parameters in
auto-generated `__init__` signatures, and drop keyword arguments whose
value is MISSING instead of forwarding them to C++.
## Test plan
- [x] 23 new regression tests added to `test_dataclass_init.py`:
  - TypeMethod registration for `def(init<>)` classes
  - `__ffi_init__` as a callable instance method
  - Type-owner guard preventing subclass misuse
  - `core.MISSING` sentinel for default parameters
  - Dual TypeMethod/TypeAttrColumn availability
- [x] All 2095 existing Python tests pass
- [x] All pre-commit hooks pass (ruff, ty, clang-format)
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-authored-by: Claude Opus 4.6 <[email protected]>
---
include/tvm/ffi/reflection/registry.h | 19 +++--
python/tvm_ffi/_dunder.py | 26 ++++--
python/tvm_ffi/_ffi_api.py | 6 +-
python/tvm_ffi/registry.py | 58 ++++++++++++-
tests/python/test_dataclass_init.py | 150 ++++++++++++++++++++++++++++++++++
5 files changed, 240 insertions(+), 19 deletions(-)
diff --git a/include/tvm/ffi/reflection/registry.h
b/include/tvm/ffi/reflection/registry.h
index 94cbc64..8f7c68c 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -850,11 +850,20 @@ class ObjectDef : public ReflectionDefBase {
template <typename... Args, typename... Extra>
TVM_FFI_INLINE ObjectDef& def([[maybe_unused]] init<Args...> init_func,
Extra&&... extra) {
has_explicit_init_ = true;
- Function init_fn = GetMethod(std::string(type_key_) + "." +
kInitMethodName,
- &init<Args...>::template execute<Class>);
- TVMFFIByteArray attr_name = AsByteArray(type_attr::kInit);
- TVMFFIAny attr_value = AnyView(init_fn).CopyToTVMFFIAny();
- TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &attr_name,
&attr_value));
+ // Register as TypeMethod (preserves type_schema metadata for Python stub
generation).
+ RegisterMethod(kInitMethodName, true, &init<Args...>::template
execute<Class>,
+ std::forward<Extra>(extra)...);
+ // Also mirror into __ffi_init__ TypeAttrColumn for runtime dispatch.
+ const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(type_index_);
+ constexpr TVMFFIByteArray attr_name = AsByteArray(type_attr::kInit);
+ for (int32_t i = 0; i < tinfo->num_methods; ++i) {
+ if (tinfo->methods[i].name.size == attr_name.size &&
+ std::strncmp(tinfo->methods[i].name.data, attr_name.data,
attr_name.size) == 0) {
+ TVM_FFI_CHECK_SAFE_CALL(
+ TVMFFITypeRegisterAttr(type_index_, &attr_name,
&tinfo->methods[i].method));
+ break;
+ }
+ }
return *this;
}
diff --git a/python/tvm_ffi/_dunder.py b/python/tvm_ffi/_dunder.py
index e15f394..af3cc98 100644
--- a/python/tvm_ffi/_dunder.py
+++ b/python/tvm_ffi/_dunder.py
@@ -27,8 +27,6 @@ from .core import TypeInfo, object_repr
if TYPE_CHECKING:
from .core import Function
-__SENTINEL = object()
-
def _make_init(
type_cls: type,
@@ -55,6 +53,7 @@ def _make_init(
"""
sig = _make_init_signature(type_info)
kwargs_obj = core.KWARGS
+ missing = core.MISSING
has_post_init = hasattr(type_cls, "__post_init__")
if inplace:
@@ -64,8 +63,9 @@ def _make_init(
if kwargs:
ffi_args.append(kwargs_obj)
for key, val in kwargs.items():
- ffi_args.append(key)
- ffi_args.append(val)
+ if val is not missing:
+ ffi_args.append(key)
+ ffi_args.append(val)
ffi_init(*ffi_args)
if has_post_init:
self.__post_init__()
@@ -84,8 +84,9 @@ def _make_init(
if kwargs:
ffi_args.append(kwargs_obj)
for key, val in kwargs.items():
- ffi_args.append(key)
- ffi_args.append(val)
+ if val is not missing:
+ ffi_args.append(key)
+ ffi_args.append(val)
self.__init_handle_by_constructor__(ffi_init, *ffi_args)
if has_post_init:
self.__post_init__()
@@ -138,14 +139,14 @@ def _make_init_signature(type_info: TypeInfo) ->
inspect.Signature:
for name, _has_default in pos_default:
params.append(
- inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=__SENTINEL)
+ inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD,
default=core.MISSING)
)
for name, _has_default in kw_required:
params.append(inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY))
for name, _has_default in kw_default:
- params.append(inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY,
default=__SENTINEL))
+ params.append(inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY,
default=core.MISSING))
return inspect.Signature(params)
@@ -253,7 +254,14 @@ def _install_dataclass_dunders( # noqa: PLR0912, PLR0915
type_index: int = type_info.type_index
ffi_new: Function | None = core._lookup_type_attr(type_index,
"__ffi_new__")
ffi_init_inplace: Function | None = core._lookup_type_attr(type_index,
"__ffi_init_inplace__")
- ffi_init: Function | None = core._lookup_type_attr(type_index,
"__ffi_init__")
+ # Look up __ffi_init__ from TypeMethod (preferred) or TypeAttrColumn
(fallback).
+ ffi_init: Function | None = None
+ for method in type_info.methods:
+ if method.name == "__ffi_init__":
+ ffi_init = method.func
+ break
+ if ffi_init is None:
+ ffi_init = core._lookup_type_attr(type_index, "__ffi_init__")
ffi_shallow_copy: Function | None = core._lookup_type_attr(type_index,
"__ffi_shallow_copy__")
pyobject_new = core.Object.__new__
diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py
index 3a05827..a0e176d 100644
--- a/python/tvm_ffi/_ffi_api.py
+++ b/python/tvm_ffi/_ffi_api.py
@@ -57,7 +57,6 @@ if TYPE_CHECKING:
def FunctionFromExternC(_0: c_void_p, _1: c_void_p, _2: c_void_p, /) ->
Callable[..., Any]: ...
def FunctionListGlobalNamesFunctor() -> Callable[..., Any]: ...
def FunctionRemoveGlobal(_0: str, /) -> bool: ...
- def GetFieldGetter(_0: int, /) -> int: ...
def GetFirstStructuralMismatch(_0: Any, _1: Any, _2: bool, _3: bool, /) ->
tuple[AccessPath, AccessPath] | None: ...
def GetGlobalFuncMetadata(_0: str, /) -> str: ...
def GetInvalidObject() -> Object: ...
@@ -76,6 +75,7 @@ if TYPE_CHECKING:
def ListReverse(_0: MutableSequence[Any], /) -> None: ...
def ListSetItem(_0: MutableSequence[Any], _1: int, _2: Any, /) -> None: ...
def ListSize(_0: MutableSequence[Any], /) -> int: ...
+ def MakeFieldGetter(_0: int, /) -> int: ...
def MakeFieldSetter(_0: int, _1: int, _2: int, /) -> Callable[..., Any]:
...
def MakeObjectFromPackedArgs(*args: Any) -> Any: ...
def Map(*args: Any) -> Any: ...
@@ -115,6 +115,7 @@ if TYPE_CHECKING:
def ToJSONGraph(_0: Any, _1: Any, /) -> Any: ...
def ToJSONGraphString(_0: Any, _1: Any, /) -> str: ...
def _PyClassRegisterTypeAttrColumns(_0: int, _1: int, /) -> None: ...
+ def _RegisterFFIInit(_0: int, /) -> None: ...
# fmt: on
# tvm-ffi-stubgen(end)
@@ -142,7 +143,6 @@ __all__ = [
"FunctionFromExternC",
"FunctionListGlobalNamesFunctor",
"FunctionRemoveGlobal",
- "GetFieldGetter",
"GetFirstStructuralMismatch",
"GetGlobalFuncMetadata",
"GetInvalidObject",
@@ -161,6 +161,7 @@ __all__ = [
"ListReverse",
"ListSetItem",
"ListSize",
+ "MakeFieldGetter",
"MakeFieldSetter",
"MakeObjectFromPackedArgs",
"Map",
@@ -200,5 +201,6 @@ __all__ = [
"ToJSONGraph",
"ToJSONGraphString",
"_PyClassRegisterTypeAttrColumns",
+ "_RegisterFFIInit",
# tvm-ffi-stubgen(end)
]
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 69a3cfe..fb31bcf 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -24,7 +24,7 @@ import warnings
from typing import Any, Callable, Literal, Sequence, TypeVar, overload
from . import core
-from .core import TypeInfo
+from .core import Function, TypeInfo
# whether we simplify skip unknown objects regtistration
_SKIP_UNKNOWN_OBJECTS = False
@@ -353,7 +353,7 @@ def init_ffi_api(namespace: str, target_module_name: str |
None = None) -> None:
def _install_init(cls: type, type_info: TypeInfo) -> None:
- """Install ``__init__`` from the C++ ``__ffi_init__`` TypeAttrColumn.
+ """Install ``__init__`` from ``__ffi_init__`` TypeMethod or TypeAttrColumn.
Skipped if the class body already defines ``__init__``.
This ensures that ``register_object`` alone provides a working
@@ -366,7 +366,14 @@ def _install_init(cls: type, type_info: TypeInfo) -> None:
"""
if "__init__" in cls.__dict__:
return
- ffi_init = core._lookup_type_attr(type_info.type_index, "__ffi_init__")
+ # Look up __ffi_init__ from TypeMethod (preferred) or TypeAttrColumn
(fallback).
+ ffi_init = None
+ for method in type_info.methods:
+ if method.name == "__ffi_init__":
+ ffi_init = method.func
+ break
+ if ffi_init is None:
+ ffi_init = core._lookup_type_attr(type_info.type_index, "__ffi_init__")
if ffi_init is not None:
from ._dunder import _make_init # noqa: PLC0415
@@ -395,13 +402,58 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo)
-> type:
name = field.name
if not hasattr(type_cls, name): # skip already defined attributes
setattr(type_cls, name, field.as_property(type_cls))
+ has_ffi_init = False
for method in type_info.methods:
name = method.name
+ if name == "__ffi_init__":
+ _install_ffi_init_attr(type_cls, type_info, method.func)
+ has_ffi_init = True
+ continue
if not hasattr(type_cls, name):
setattr(type_cls, name, method.as_callable(type_cls))
+ # Also check TypeAttrColumn for auto-generated __ffi_init__.
+ if not has_ffi_init:
+ ffi_init = core._lookup_type_attr(type_info.type_index, "__ffi_init__")
+ if ffi_init is not None:
+ _install_ffi_init_attr(type_cls, type_info, ffi_init)
return type_cls
+def _install_ffi_init_attr(cls: type, type_info: TypeInfo, ffi_init: Function)
-> None:
+ """Install ``__ffi_init__`` as a method that delegates to
``__init_handle_by_constructor__``.
+
+ Custom ``__init__`` methods call ``self.__ffi_init__(*args, **kwargs)`` to
+ construct the underlying C++ object. This installs a wrapper that
translates
+ that call into ``self.__init_handle_by_constructor__(ffi_init, *ffi_args)``
+ with kwargs packed using the FFI KWARGS protocol.
+
+ The wrapper includes a type-owner guard (same as ``_make_init``) to prevent
+ subclasses from accidentally using a parent's ``__ffi_init__``.
+ """
+ kwargs_obj = core.KWARGS
+ missing = core.MISSING
+ type_name = cls.__name__
+
+ def __ffi_init__(self: Any, *args: Any, **kwargs: Any) -> None:
+ if type_info is not type(self).__tvm_ffi_type_info__:
+ raise TypeError(
+ f"Calling `{type_name}.__ffi_init__()` on a
`{type(self).__name__}` "
+ f"instance is not supported. Define `{type(self).__name__}`
with init=True."
+ )
+ ffi_args: list[Any] = list(args)
+ if kwargs:
+ ffi_args.append(kwargs_obj)
+ for key, val in kwargs.items():
+ if val is not missing:
+ ffi_args.append(key)
+ ffi_args.append(val)
+ self.__init_handle_by_constructor__(ffi_init, *ffi_args)
+
+ __ffi_init__.__qualname__ = f"{cls.__qualname__}.__ffi_init__"
+ __ffi_init__.__module__ = cls.__module__
+ cls.__ffi_init__ = __ffi_init__ # type: ignore[attr-defined]
+
+
def _warn_missing_field_annotations(cls: type, type_info: TypeInfo, *,
stacklevel: int) -> None:
"""Emit a warning if any C++ reflected fields lack Python annotations on
*cls*.
diff --git a/tests/python/test_dataclass_init.py
b/tests/python/test_dataclass_init.py
index e36db48..62d0b1e 100644
--- a/tests/python/test_dataclass_init.py
+++ b/tests/python/test_dataclass_init.py
@@ -37,12 +37,17 @@ from typing import Any
import pytest
from tvm_ffi import core
from tvm_ffi.testing import (
+ TestCompare,
+ TestHash,
+ TestIntPair,
_TestCxxAutoInit,
_TestCxxAutoInitAllInitOff,
_TestCxxAutoInitChild,
_TestCxxAutoInitKwOnlyDefaults,
_TestCxxAutoInitParent,
_TestCxxAutoInitSimple,
+ _TestCxxClassDerived,
+ _TestCxxClassDerivedDerived,
_TestCxxNoAutoInit,
)
@@ -995,3 +1000,148 @@ class TestInitInplace:
)
with pytest.raises(TypeError, match="keyword-only"):
_ffi_init_inplace(obj, 1)
+
+
+# ###########################################################################
+# __ffi_init__ TypeMethod registration regression tests
+#
+# These tests verify that def(init<>) in C++ registers __ffi_init__ as a
+# TypeMethod (not just TypeAttrColumn), preserving type_schema metadata.
+# They also verify the __ffi_init__ instance method wrapper and the MISSING
+# sentinel used for default parameter values.
+# See commit 2987899a for the original fix.
+# ###########################################################################
+
+
+class TestFfiInitAsTypeMethod:
+ """Verify that classes using ``def(init<>)`` expose __ffi_init__ as a
TypeMethod."""
+
+ @pytest.mark.parametrize("cls", [TestIntPair, TestCompare, TestHash])
+ def test_explicit_init_has_ffi_init_type_method(self, cls: type) -> None:
+ type_info = cls.__tvm_ffi_type_info__ # ty:
ignore[unresolved-attribute]
+ method_names = {m.name for m in type_info.methods}
+ assert "__ffi_init__" in method_names
+
+ def test_auto_init_does_not_have_ffi_init_type_method(self) -> None:
+ type_info = _TestCxxAutoInit.__tvm_ffi_type_info__ # ty:
ignore[unresolved-attribute]
+ method_names = {m.name for m in type_info.methods}
+ assert "__ffi_init__" not in method_names
+
+ @pytest.mark.parametrize("cls", [TestIntPair, TestCompare, TestHash])
+ def test_ffi_init_type_method_has_func(self, cls: type) -> None:
+ type_info = cls.__tvm_ffi_type_info__ # ty:
ignore[unresolved-attribute]
+ for method in type_info.methods:
+ if method.name == "__ffi_init__":
+ assert method.func is not None
+ break
+ else:
+ pytest.fail(f"__ffi_init__ not found in {cls.__name__} methods")
+
+
+class TestFfiInitAsInstanceMethod:
+ """Verify that __ffi_init__ is available as a callable method on
registered classes."""
+
+ def test_explicit_init_class_has_ffi_init_attr(self) -> None:
+ assert hasattr(TestIntPair, "__ffi_init__")
+
+ def test_auto_init_class_has_ffi_init_attr(self) -> None:
+ assert hasattr(_TestCxxAutoInit, "__ffi_init__")
+
+ def test_ffi_init_constructs_explicit_init_object(self) -> None:
+ obj = TestIntPair.__new__(TestIntPair)
+ obj.__ffi_init__(10, 20) # ty: ignore[unresolved-attribute]
+ assert obj.a == 10
+ assert obj.b == 20
+ assert obj.sum() == 30
+
+ def test_ffi_init_constructs_auto_init_object(self) -> None:
+ obj = _TestCxxAutoInitSimple.__new__(_TestCxxAutoInitSimple)
+ obj.__ffi_init__(7, 8) # ty: ignore[unresolved-attribute]
+ assert obj.x == 7
+ assert obj.y == 8
+
+ def test_ffi_init_with_kwargs_protocol(self) -> None:
+ obj = _TestCxxAutoInit.__new__(_TestCxxAutoInit)
+ obj.__ffi_init__(a=1, c=3) # ty: ignore[unresolved-attribute]
+ assert obj.a == 1
+ assert obj.b == 42
+ assert obj.c == 3
+ assert obj.d == 99
+
+
+class TestFfiInitTypeOwnerGuard:
+ """Verify the type-owner guard prevents subclass misuse of __ffi_init__."""
+
+ def test_same_type_ffi_init_succeeds(self) -> None:
+ obj = TestIntPair.__new__(TestIntPair)
+ TestIntPair.__ffi_init__(obj, 1, 2) # ty: ignore[unresolved-attribute]
+ assert obj.a == 1
+ assert obj.b == 2
+
+ def test_subclass_ffi_init_raises_type_error(self) -> None:
+ obj = _TestCxxClassDerivedDerived.__new__(_TestCxxClassDerivedDerived)
+ with pytest.raises(TypeError, match="not supported"):
+ _TestCxxClassDerived.__ffi_init__(obj, 1, 2, 3.0) # ty:
ignore[unresolved-attribute]
+
+
+class TestMissingSentinelDefaults:
+ """Verify that auto-generated __init__ signatures use core.MISSING for
defaults."""
+
+ def test_default_params_use_core_missing(self) -> None:
+ sig = inspect.signature(_TestCxxAutoInit.__init__)
+ d_param = sig.parameters["d"]
+ assert d_param.default is not inspect.Parameter.empty
+ assert d_param.default is core.MISSING
+
+ def test_kw_only_default_uses_core_missing(self) -> None:
+ sig = inspect.signature(_TestCxxAutoInitKwOnlyDefaults.__init__)
+ assert sig.parameters["k_default"].default is core.MISSING
+ assert sig.parameters["p_default"].default is core.MISSING
+
+ def test_required_params_have_no_default(self) -> None:
+ sig = inspect.signature(_TestCxxAutoInit.__init__)
+ assert sig.parameters["a"].default is inspect.Parameter.empty
+ assert sig.parameters["c"].default is inspect.Parameter.empty
+
+ def test_missing_not_sent_to_ffi(self) -> None:
+ """MISSING sentinel values should be stripped, not forwarded to C++."""
+ obj = _TestCxxAutoInitKwOnlyDefaults(p_required=1, k_required=2)
+ assert obj.p_default == 11
+ assert obj.k_default == 22
+ assert obj.hidden == 33
+
+ def test_derived_class_defaults(self) -> None:
+ obj = _TestCxxClassDerived(v_i64=1, v_i32=2, v_f64=3.0)
+ assert obj.v_f32 == 8.0
+
+
+class TestFfiInitDualRegistration:
+ """Verify __ffi_init__ is available from both TypeMethod and
TypeAttrColumn."""
+
+ @pytest.mark.parametrize("cls", [TestIntPair, TestCompare])
+ def test_explicit_init_in_type_attr_column(self, cls: type) -> None:
+ type_info = cls.__tvm_ffi_type_info__ # ty:
ignore[unresolved-attribute]
+ ffi_init = core._lookup_type_attr(type_info.type_index, "__ffi_init__")
+ assert ffi_init is not None
+
+ def test_auto_init_in_type_attr_column(self) -> None:
+ type_info = _TestCxxAutoInit.__tvm_ffi_type_info__ # ty:
ignore[unresolved-attribute]
+ ffi_init = core._lookup_type_attr(type_info.type_index, "__ffi_init__")
+ assert ffi_init is not None
+
+ def test_explicit_init_method_and_attr_produce_same_result(self) -> None:
+ type_info = TestIntPair.__tvm_ffi_type_info__ # ty:
ignore[unresolved-attribute]
+ method_func = None
+ for method in type_info.methods:
+ if method.name == "__ffi_init__":
+ method_func = method.func
+ break
+ assert method_func is not None
+ attr_func = core._lookup_type_attr(type_info.type_index,
"__ffi_init__")
+ assert attr_func is not None
+ obj1 = TestIntPair.__new__(TestIntPair)
+ obj1.__init_handle_by_constructor__(method_func, 10, 20)
+ obj2 = TestIntPair.__new__(TestIntPair)
+ obj2.__init_handle_by_constructor__(attr_func, 10, 20)
+ assert obj1.a == obj2.a == 10
+ assert obj1.b == obj2.b == 20