This is an automated email from the ASF dual-hosted git repository.

yaxingcai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new b7c30ea  fix(init): register __ffi_init__ as TypeMethod and use MISSING sentinel (#546)
b7c30ea is described below

commit b7c30ea6e8264f8189c0ab94f2ebb04e955812ad
Author: Junru Shao <[email protected]>
AuthorDate: Mon Apr 13 19:31:28 2026 -0700

    fix(init): register __ffi_init__ as TypeMethod and use MISSING sentinel (#546)
    
    ## Summary
    
    - **C++ `def(init<>)`**: Register `__ffi_init__` as a TypeMethod
    (preserving `type_schema` metadata for stub generation), then mirror the
    registered method into the `__ffi_init__` TypeAttrColumn for backward
    compatibility. This aligns `ObjectDef::def(init<>)` with
    `OverloadObjectDef::def(init<>)`, which already did this.
    - **Python `__ffi_init__` lookup**: Resolve `__ffi_init__` from the
    type's TypeMethods first and fall back to the TypeAttrColumn, in both
    `_install_init` and `_install_dataclass_dunders`.
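    - **Python `_install_ffi_init_attr`**: New helper that installs
    `__ffi_init__` as an instance method with a type-owner guard: calling a
    class's `__ffi_init__` on an instance of a different type (e.g. a
    subclass) raises `TypeError`, preventing subclasses from accidentally
    using a parent's `__ffi_init__`.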
    - **MISSING sentinel**: Use `core.MISSING` (instead of the module-private
    `__SENTINEL` object) as the default value of optional parameters in
    auto-generated `__init__` signatures, and drop kwargs whose value is
    MISSING instead of forwarding them to C++.
    
    ## Test plan
    
    - [x] 23 new regression tests added to `test_dataclass_init.py`
      - TypeMethod registration for `def(init<>)` classes
      - `__ffi_init__` as callable instance method
      - Type-owner guard preventing subclass misuse
      - `core.MISSING` sentinel for default parameters
      - Dual TypeMethod/TypeAttrColumn availability
    - [x] All 2095 existing Python tests pass
    - [x] All pre-commit hooks pass (ruff, ty, clang-format)
    
    🤖 Generated with [Claude Code](https://claude.com/claude-code)
    
    Co-authored-by: Claude Opus 4.6 <[email protected]>
---
 include/tvm/ffi/reflection/registry.h |  19 +++--
 python/tvm_ffi/_dunder.py             |  26 ++++--
 python/tvm_ffi/_ffi_api.py            |   6 +-
 python/tvm_ffi/registry.py            |  58 ++++++++++++-
 tests/python/test_dataclass_init.py   | 150 ++++++++++++++++++++++++++++++++++
 5 files changed, 240 insertions(+), 19 deletions(-)

diff --git a/include/tvm/ffi/reflection/registry.h 
b/include/tvm/ffi/reflection/registry.h
index 94cbc64..8f7c68c 100644
--- a/include/tvm/ffi/reflection/registry.h
+++ b/include/tvm/ffi/reflection/registry.h
@@ -850,11 +850,20 @@ class ObjectDef : public ReflectionDefBase {
   template <typename... Args, typename... Extra>
   TVM_FFI_INLINE ObjectDef& def([[maybe_unused]] init<Args...> init_func, 
Extra&&... extra) {
     has_explicit_init_ = true;
-    Function init_fn = GetMethod(std::string(type_key_) + "." + 
kInitMethodName,
-                                 &init<Args...>::template execute<Class>);
-    TVMFFIByteArray attr_name = AsByteArray(type_attr::kInit);
-    TVMFFIAny attr_value = AnyView(init_fn).CopyToTVMFFIAny();
-    TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterAttr(type_index_, &attr_name, 
&attr_value));
+    // Register as TypeMethod (preserves type_schema metadata for Python stub 
generation).
+    RegisterMethod(kInitMethodName, true, &init<Args...>::template 
execute<Class>,
+                   std::forward<Extra>(extra)...);
+    // Also mirror into __ffi_init__ TypeAttrColumn for runtime dispatch.
+    const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(type_index_);
+    constexpr TVMFFIByteArray attr_name = AsByteArray(type_attr::kInit);
+    for (int32_t i = 0; i < tinfo->num_methods; ++i) {
+      if (tinfo->methods[i].name.size == attr_name.size &&
+          std::strncmp(tinfo->methods[i].name.data, attr_name.data, 
attr_name.size) == 0) {
+        TVM_FFI_CHECK_SAFE_CALL(
+            TVMFFITypeRegisterAttr(type_index_, &attr_name, 
&tinfo->methods[i].method));
+        break;
+      }
+    }
     return *this;
   }
 
diff --git a/python/tvm_ffi/_dunder.py b/python/tvm_ffi/_dunder.py
index e15f394..af3cc98 100644
--- a/python/tvm_ffi/_dunder.py
+++ b/python/tvm_ffi/_dunder.py
@@ -27,8 +27,6 @@ from .core import TypeInfo, object_repr
 if TYPE_CHECKING:
     from .core import Function
 
-__SENTINEL = object()
-
 
 def _make_init(
     type_cls: type,
@@ -55,6 +53,7 @@ def _make_init(
     """
     sig = _make_init_signature(type_info)
     kwargs_obj = core.KWARGS
+    missing = core.MISSING
     has_post_init = hasattr(type_cls, "__post_init__")
 
     if inplace:
@@ -64,8 +63,9 @@ def _make_init(
             if kwargs:
                 ffi_args.append(kwargs_obj)
                 for key, val in kwargs.items():
-                    ffi_args.append(key)
-                    ffi_args.append(val)
+                    if val is not missing:
+                        ffi_args.append(key)
+                        ffi_args.append(val)
             ffi_init(*ffi_args)
             if has_post_init:
                 self.__post_init__()
@@ -84,8 +84,9 @@ def _make_init(
             if kwargs:
                 ffi_args.append(kwargs_obj)
                 for key, val in kwargs.items():
-                    ffi_args.append(key)
-                    ffi_args.append(val)
+                    if val is not missing:
+                        ffi_args.append(key)
+                        ffi_args.append(val)
             self.__init_handle_by_constructor__(ffi_init, *ffi_args)
             if has_post_init:
                 self.__post_init__()
@@ -138,14 +139,14 @@ def _make_init_signature(type_info: TypeInfo) -> 
inspect.Signature:
 
     for name, _has_default in pos_default:
         params.append(
-            inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD, 
default=__SENTINEL)
+            inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD, 
default=core.MISSING)
         )
 
     for name, _has_default in kw_required:
         params.append(inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY))
 
     for name, _has_default in kw_default:
-        params.append(inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, 
default=__SENTINEL))
+        params.append(inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, 
default=core.MISSING))
 
     return inspect.Signature(params)
 
@@ -253,7 +254,14 @@ def _install_dataclass_dunders(  # noqa: PLR0912, PLR0915
     type_index: int = type_info.type_index
     ffi_new: Function | None = core._lookup_type_attr(type_index, 
"__ffi_new__")
     ffi_init_inplace: Function | None = core._lookup_type_attr(type_index, 
"__ffi_init_inplace__")
-    ffi_init: Function | None = core._lookup_type_attr(type_index, 
"__ffi_init__")
+    # Look up __ffi_init__ from TypeMethod (preferred) or TypeAttrColumn 
(fallback).
+    ffi_init: Function | None = None
+    for method in type_info.methods:
+        if method.name == "__ffi_init__":
+            ffi_init = method.func
+            break
+    if ffi_init is None:
+        ffi_init = core._lookup_type_attr(type_index, "__ffi_init__")
     ffi_shallow_copy: Function | None = core._lookup_type_attr(type_index, 
"__ffi_shallow_copy__")
     pyobject_new = core.Object.__new__
 
diff --git a/python/tvm_ffi/_ffi_api.py b/python/tvm_ffi/_ffi_api.py
index 3a05827..a0e176d 100644
--- a/python/tvm_ffi/_ffi_api.py
+++ b/python/tvm_ffi/_ffi_api.py
@@ -57,7 +57,6 @@ if TYPE_CHECKING:
     def FunctionFromExternC(_0: c_void_p, _1: c_void_p, _2: c_void_p, /) -> 
Callable[..., Any]: ...
     def FunctionListGlobalNamesFunctor() -> Callable[..., Any]: ...
     def FunctionRemoveGlobal(_0: str, /) -> bool: ...
-    def GetFieldGetter(_0: int, /) -> int: ...
     def GetFirstStructuralMismatch(_0: Any, _1: Any, _2: bool, _3: bool, /) -> 
tuple[AccessPath, AccessPath] | None: ...
     def GetGlobalFuncMetadata(_0: str, /) -> str: ...
     def GetInvalidObject() -> Object: ...
@@ -76,6 +75,7 @@ if TYPE_CHECKING:
     def ListReverse(_0: MutableSequence[Any], /) -> None: ...
     def ListSetItem(_0: MutableSequence[Any], _1: int, _2: Any, /) -> None: ...
     def ListSize(_0: MutableSequence[Any], /) -> int: ...
+    def MakeFieldGetter(_0: int, /) -> int: ...
     def MakeFieldSetter(_0: int, _1: int, _2: int, /) -> Callable[..., Any]: 
...
     def MakeObjectFromPackedArgs(*args: Any) -> Any: ...
     def Map(*args: Any) -> Any: ...
@@ -115,6 +115,7 @@ if TYPE_CHECKING:
     def ToJSONGraph(_0: Any, _1: Any, /) -> Any: ...
     def ToJSONGraphString(_0: Any, _1: Any, /) -> str: ...
     def _PyClassRegisterTypeAttrColumns(_0: int, _1: int, /) -> None: ...
+    def _RegisterFFIInit(_0: int, /) -> None: ...
 # fmt: on
 # tvm-ffi-stubgen(end)
 
@@ -142,7 +143,6 @@ __all__ = [
     "FunctionFromExternC",
     "FunctionListGlobalNamesFunctor",
     "FunctionRemoveGlobal",
-    "GetFieldGetter",
     "GetFirstStructuralMismatch",
     "GetGlobalFuncMetadata",
     "GetInvalidObject",
@@ -161,6 +161,7 @@ __all__ = [
     "ListReverse",
     "ListSetItem",
     "ListSize",
+    "MakeFieldGetter",
     "MakeFieldSetter",
     "MakeObjectFromPackedArgs",
     "Map",
@@ -200,5 +201,6 @@ __all__ = [
     "ToJSONGraph",
     "ToJSONGraphString",
     "_PyClassRegisterTypeAttrColumns",
+    "_RegisterFFIInit",
     # tvm-ffi-stubgen(end)
 ]
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 69a3cfe..fb31bcf 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -24,7 +24,7 @@ import warnings
 from typing import Any, Callable, Literal, Sequence, TypeVar, overload
 
 from . import core
-from .core import TypeInfo
+from .core import Function, TypeInfo
 
 # whether we simplify skip unknown objects regtistration
 _SKIP_UNKNOWN_OBJECTS = False
@@ -353,7 +353,7 @@ def init_ffi_api(namespace: str, target_module_name: str | 
None = None) -> None:
 
 
 def _install_init(cls: type, type_info: TypeInfo) -> None:
-    """Install ``__init__`` from the C++ ``__ffi_init__`` TypeAttrColumn.
+    """Install ``__init__`` from ``__ffi_init__`` TypeMethod or TypeAttrColumn.
 
     Skipped if the class body already defines ``__init__``.
     This ensures that ``register_object`` alone provides a working
@@ -366,7 +366,14 @@ def _install_init(cls: type, type_info: TypeInfo) -> None:
     """
     if "__init__" in cls.__dict__:
         return
-    ffi_init = core._lookup_type_attr(type_info.type_index, "__ffi_init__")
+    # Look up __ffi_init__ from TypeMethod (preferred) or TypeAttrColumn 
(fallback).
+    ffi_init = None
+    for method in type_info.methods:
+        if method.name == "__ffi_init__":
+            ffi_init = method.func
+            break
+    if ffi_init is None:
+        ffi_init = core._lookup_type_attr(type_info.type_index, "__ffi_init__")
     if ffi_init is not None:
         from ._dunder import _make_init  # noqa: PLC0415
 
@@ -395,13 +402,58 @@ def _add_class_attrs(type_cls: type, type_info: TypeInfo) 
-> type:
         name = field.name
         if not hasattr(type_cls, name):  # skip already defined attributes
             setattr(type_cls, name, field.as_property(type_cls))
+    has_ffi_init = False
     for method in type_info.methods:
         name = method.name
+        if name == "__ffi_init__":
+            _install_ffi_init_attr(type_cls, type_info, method.func)
+            has_ffi_init = True
+            continue
         if not hasattr(type_cls, name):
             setattr(type_cls, name, method.as_callable(type_cls))
+    # Also check TypeAttrColumn for auto-generated __ffi_init__.
+    if not has_ffi_init:
+        ffi_init = core._lookup_type_attr(type_info.type_index, "__ffi_init__")
+        if ffi_init is not None:
+            _install_ffi_init_attr(type_cls, type_info, ffi_init)
     return type_cls
 
 
+def _install_ffi_init_attr(cls: type, type_info: TypeInfo, ffi_init: Function) 
-> None:
+    """Install ``__ffi_init__`` as a method that delegates to 
``__init_handle_by_constructor__``.
+
+    Custom ``__init__`` methods call ``self.__ffi_init__(*args, **kwargs)`` to
+    construct the underlying C++ object. This installs a wrapper that 
translates
+    that call into ``self.__init_handle_by_constructor__(ffi_init, *ffi_args)``
+    with kwargs packed using the FFI KWARGS protocol.
+
+    The wrapper includes a type-owner guard (same as ``_make_init``) to prevent
+    subclasses from accidentally using a parent's ``__ffi_init__``.
+    """
+    kwargs_obj = core.KWARGS
+    missing = core.MISSING
+    type_name = cls.__name__
+
+    def __ffi_init__(self: Any, *args: Any, **kwargs: Any) -> None:
+        if type_info is not type(self).__tvm_ffi_type_info__:
+            raise TypeError(
+                f"Calling `{type_name}.__ffi_init__()` on a 
`{type(self).__name__}` "
+                f"instance is not supported. Define `{type(self).__name__}` 
with init=True."
+            )
+        ffi_args: list[Any] = list(args)
+        if kwargs:
+            ffi_args.append(kwargs_obj)
+            for key, val in kwargs.items():
+                if val is not missing:
+                    ffi_args.append(key)
+                    ffi_args.append(val)
+        self.__init_handle_by_constructor__(ffi_init, *ffi_args)
+
+    __ffi_init__.__qualname__ = f"{cls.__qualname__}.__ffi_init__"
+    __ffi_init__.__module__ = cls.__module__
+    cls.__ffi_init__ = __ffi_init__  # type: ignore[attr-defined]
+
+
 def _warn_missing_field_annotations(cls: type, type_info: TypeInfo, *, 
stacklevel: int) -> None:
     """Emit a warning if any C++ reflected fields lack Python annotations on 
*cls*.
 
diff --git a/tests/python/test_dataclass_init.py 
b/tests/python/test_dataclass_init.py
index e36db48..62d0b1e 100644
--- a/tests/python/test_dataclass_init.py
+++ b/tests/python/test_dataclass_init.py
@@ -37,12 +37,17 @@ from typing import Any
 import pytest
 from tvm_ffi import core
 from tvm_ffi.testing import (
+    TestCompare,
+    TestHash,
+    TestIntPair,
     _TestCxxAutoInit,
     _TestCxxAutoInitAllInitOff,
     _TestCxxAutoInitChild,
     _TestCxxAutoInitKwOnlyDefaults,
     _TestCxxAutoInitParent,
     _TestCxxAutoInitSimple,
+    _TestCxxClassDerived,
+    _TestCxxClassDerivedDerived,
     _TestCxxNoAutoInit,
 )
 
@@ -995,3 +1000,148 @@ class TestInitInplace:
         )
         with pytest.raises(TypeError, match="keyword-only"):
             _ffi_init_inplace(obj, 1)
+
+
+# ###########################################################################
+#  __ffi_init__ TypeMethod registration regression tests
+#
+#  These tests verify that def(init<>) in C++ registers __ffi_init__ as a
+#  TypeMethod (not just TypeAttrColumn), preserving type_schema metadata.
+#  They also verify the __ffi_init__ instance method wrapper and the MISSING
+#  sentinel used for default parameter values.
+#  See commit 2987899a for the original fix.
+# ###########################################################################
+
+
+class TestFfiInitAsTypeMethod:
+    """Verify that classes using ``def(init<>)`` expose __ffi_init__ as a 
TypeMethod."""
+
+    @pytest.mark.parametrize("cls", [TestIntPair, TestCompare, TestHash])
+    def test_explicit_init_has_ffi_init_type_method(self, cls: type) -> None:
+        type_info = cls.__tvm_ffi_type_info__  # ty: 
ignore[unresolved-attribute]
+        method_names = {m.name for m in type_info.methods}
+        assert "__ffi_init__" in method_names
+
+    def test_auto_init_does_not_have_ffi_init_type_method(self) -> None:
+        type_info = _TestCxxAutoInit.__tvm_ffi_type_info__  # ty: 
ignore[unresolved-attribute]
+        method_names = {m.name for m in type_info.methods}
+        assert "__ffi_init__" not in method_names
+
+    @pytest.mark.parametrize("cls", [TestIntPair, TestCompare, TestHash])
+    def test_ffi_init_type_method_has_func(self, cls: type) -> None:
+        type_info = cls.__tvm_ffi_type_info__  # ty: 
ignore[unresolved-attribute]
+        for method in type_info.methods:
+            if method.name == "__ffi_init__":
+                assert method.func is not None
+                break
+        else:
+            pytest.fail(f"__ffi_init__ not found in {cls.__name__} methods")
+
+
+class TestFfiInitAsInstanceMethod:
+    """Verify that __ffi_init__ is available as a callable method on 
registered classes."""
+
+    def test_explicit_init_class_has_ffi_init_attr(self) -> None:
+        assert hasattr(TestIntPair, "__ffi_init__")
+
+    def test_auto_init_class_has_ffi_init_attr(self) -> None:
+        assert hasattr(_TestCxxAutoInit, "__ffi_init__")
+
+    def test_ffi_init_constructs_explicit_init_object(self) -> None:
+        obj = TestIntPair.__new__(TestIntPair)
+        obj.__ffi_init__(10, 20)  # ty: ignore[unresolved-attribute]
+        assert obj.a == 10
+        assert obj.b == 20
+        assert obj.sum() == 30
+
+    def test_ffi_init_constructs_auto_init_object(self) -> None:
+        obj = _TestCxxAutoInitSimple.__new__(_TestCxxAutoInitSimple)
+        obj.__ffi_init__(7, 8)  # ty: ignore[unresolved-attribute]
+        assert obj.x == 7
+        assert obj.y == 8
+
+    def test_ffi_init_with_kwargs_protocol(self) -> None:
+        obj = _TestCxxAutoInit.__new__(_TestCxxAutoInit)
+        obj.__ffi_init__(a=1, c=3)  # ty: ignore[unresolved-attribute]
+        assert obj.a == 1
+        assert obj.b == 42
+        assert obj.c == 3
+        assert obj.d == 99
+
+
+class TestFfiInitTypeOwnerGuard:
+    """Verify the type-owner guard prevents subclass misuse of __ffi_init__."""
+
+    def test_same_type_ffi_init_succeeds(self) -> None:
+        obj = TestIntPair.__new__(TestIntPair)
+        TestIntPair.__ffi_init__(obj, 1, 2)  # ty: ignore[unresolved-attribute]
+        assert obj.a == 1
+        assert obj.b == 2
+
+    def test_subclass_ffi_init_raises_type_error(self) -> None:
+        obj = _TestCxxClassDerivedDerived.__new__(_TestCxxClassDerivedDerived)
+        with pytest.raises(TypeError, match="not supported"):
+            _TestCxxClassDerived.__ffi_init__(obj, 1, 2, 3.0)  # ty: 
ignore[unresolved-attribute]
+
+
+class TestMissingSentinelDefaults:
+    """Verify that auto-generated __init__ signatures use core.MISSING for 
defaults."""
+
+    def test_default_params_use_core_missing(self) -> None:
+        sig = inspect.signature(_TestCxxAutoInit.__init__)
+        d_param = sig.parameters["d"]
+        assert d_param.default is not inspect.Parameter.empty
+        assert d_param.default is core.MISSING
+
+    def test_kw_only_default_uses_core_missing(self) -> None:
+        sig = inspect.signature(_TestCxxAutoInitKwOnlyDefaults.__init__)
+        assert sig.parameters["k_default"].default is core.MISSING
+        assert sig.parameters["p_default"].default is core.MISSING
+
+    def test_required_params_have_no_default(self) -> None:
+        sig = inspect.signature(_TestCxxAutoInit.__init__)
+        assert sig.parameters["a"].default is inspect.Parameter.empty
+        assert sig.parameters["c"].default is inspect.Parameter.empty
+
+    def test_missing_not_sent_to_ffi(self) -> None:
+        """MISSING sentinel values should be stripped, not forwarded to C++."""
+        obj = _TestCxxAutoInitKwOnlyDefaults(p_required=1, k_required=2)
+        assert obj.p_default == 11
+        assert obj.k_default == 22
+        assert obj.hidden == 33
+
+    def test_derived_class_defaults(self) -> None:
+        obj = _TestCxxClassDerived(v_i64=1, v_i32=2, v_f64=3.0)
+        assert obj.v_f32 == 8.0
+
+
+class TestFfiInitDualRegistration:
+    """Verify __ffi_init__ is available from both TypeMethod and 
TypeAttrColumn."""
+
+    @pytest.mark.parametrize("cls", [TestIntPair, TestCompare])
+    def test_explicit_init_in_type_attr_column(self, cls: type) -> None:
+        type_info = cls.__tvm_ffi_type_info__  # ty: 
ignore[unresolved-attribute]
+        ffi_init = core._lookup_type_attr(type_info.type_index, "__ffi_init__")
+        assert ffi_init is not None
+
+    def test_auto_init_in_type_attr_column(self) -> None:
+        type_info = _TestCxxAutoInit.__tvm_ffi_type_info__  # ty: 
ignore[unresolved-attribute]
+        ffi_init = core._lookup_type_attr(type_info.type_index, "__ffi_init__")
+        assert ffi_init is not None
+
+    def test_explicit_init_method_and_attr_produce_same_result(self) -> None:
+        type_info = TestIntPair.__tvm_ffi_type_info__  # ty: 
ignore[unresolved-attribute]
+        method_func = None
+        for method in type_info.methods:
+            if method.name == "__ffi_init__":
+                method_func = method.func
+                break
+        assert method_func is not None
+        attr_func = core._lookup_type_attr(type_info.type_index, 
"__ffi_init__")
+        assert attr_func is not None
+        obj1 = TestIntPair.__new__(TestIntPair)
+        obj1.__init_handle_by_constructor__(method_func, 10, 20)
+        obj2 = TestIntPair.__new__(TestIntPair)
+        obj2.__init_handle_by_constructor__(attr_func, 10, 20)
+        assert obj1.a == obj2.a == 10
+        assert obj1.b == obj2.b == 20

Reply via email to