This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 2b59825  fix(py_class): support `super().__init__()` in `init=False` subclasses (#532)
2b59825 is described below

commit 2b59825e87a70dac521490802306e461024cc4d1
Author: Junru Shao <[email protected]>
AuthorDate: Fri Apr 10 11:26:30 2026 -0700

    fix(py_class): support `super().__init__()` in `init=False` subclasses (#532)
    
    ## Summary
    
    - Fix segfault when `@py_class(init=False)` subclasses use
    `super().__init__()` followed by field assignment
    - Add `ffi.NewEmpty` C++ function to allocate zero-initialized objects
    by type index
    - Detect super-init-from-subclass pattern in `_make_init()` and
    pre-allocate instead of dispatching to `__ffi_init__`
    
    ## Motivation
    
    The common Python pattern of defining a custom `__init__` that calls
    `super().__init__()` then sets fields crashes when used with
    `@py_class(init=False)`:
    
    ```python
    @py_class(init=False)
    class PointerType(Node):
        element_type: Object
    
        def __init__(self, element_type):
            super().__init__()       # parent's auto-init dispatches to child's C++ ctor → crash
            self.element_type = element_type
    ```
    
    The parent's auto-generated `__init__` forwards to `__ffi_init__` using
    the child type's C++ constructor, which expects field arguments that
    were not provided.
    
    ## Design
    
    **C++ side** (`src/ffi/extra/dataclass.cc`): Register
    `ffi.NewEmpty(type_index) -> ObjectRef` that allocates a
    zero-initialized object via `CreateEmptyObject`. The calloc'd state is
    valid: `None` for `Any`/`ObjectRef` fields, `0` for scalars.
    
    **Python side** (`python/tvm_ffi/registry.py`):
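    - `_ffi_alloc_empty(obj, type_index)`: Calls `ffi.NewEmpty` to allocate an
    empty FFI object of the given type index and installs the handle into
    `obj` via `__init_handle_by_constructor__`. No-op if `obj` already has a
    non-zero handle.
    - Super-init-from-subclass detection (inlined in the generated
    `__init__`): compares `type(self).__tvm_ffi_type_info__` identity against
    the declaring class's `type_info`. The check is true only for registered
    `@py_class` subclasses — undecorated subclasses inherit the parent's
    `type_info`, so the identity check correctly filters them out.
    - Both `_make_init` code paths (with/without `__post_init__`) intercept
    the no-args super-init-from-subclass call and allocate an empty object
    instead of dispatching to `__ffi_init__`.
    - `_install_init`: when called with `enabled=False` on a class that
    defines its own `__init__`, that `__init__` is wrapped so the empty
    object is allocated before the user code runs, making field setters
    usable immediately.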
    
    ## Test plan
    
    - [x] `uv run pytest -vvs
    tests/python/test_dataclass_py_class.py::TestSuperInitPattern` — 12
    tests pass
    - [x] `uv run pytest -vvs tests/python/` — 1970 passed, 38 skipped, 3
    xfailed, 0 failures
    
    Tests cover: basic super-init, deep hierarchy
    (Node→BaseType→PointerType), calloc defaults, intermediate custom/auto
    inits, normal init unaffected, non-py_class subclass error handling,
    isinstance checks, field overwrite, copy/deepcopy, and
    direct-from-Object inheritance.
---
 python/tvm_ffi/registry.py              |  46 ++++++
 src/ffi/extra/dataclass.cc              |   5 +
 tests/python/test_dataclass_py_class.py | 280 ++++++++++++++++++++++++++++++++
 3 files changed, 331 insertions(+)

diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 93ad92c..7d6359a 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -18,6 +18,7 @@
 
 from __future__ import annotations
 
+import functools
 import inspect
 import json
 import sys
@@ -346,6 +347,16 @@ def init_ffi_api(namespace: str, target_module_name: str | None = None) -> None:
 __SENTINEL = object()
 
 
+@functools.lru_cache(maxsize=None)
+def _new_empty() -> Any:
+    return core._get_global_func("ffi.NewEmpty", False)
+
+
+def _ffi_alloc_empty(obj: Any, type_index: int) -> None:
+    if obj.__chandle__() == 0:
+        obj.__init_handle_by_constructor__(_new_empty(), type_index)
+
+
 def _make_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
     """Build a Python ``__init__`` that delegates to the C++ auto-generated ``__ffi_init__``.
 
@@ -354,6 +365,20 @@ def _make_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
     signature.  The ``__init__`` body is a trivial adapter — all validation
     (too many positional, duplicates, missing required, kw_only enforcement,
     unknown kwargs) is handled by C++.
+
+    When this generated ``__init__`` is called with no arguments via
+    ``super().__init__()`` from a registered subclass (detected by
+    ``type(self).__tvm_ffi_type_info__ is not type_info``), it allocates an
+    empty zero-initialized FFI object of the child type instead of forwarding
+    to ``__ffi_init__``.  This supports the common Python pattern::
+
+        @py_class(init=False)
+        class Child(Parent):
+            x: int
+
+            def __init__(self, x):
+                super().__init__()
+                self.x = x
     """
     sig = _make_init_signature(type_info)
     kwargs_obj = core.KWARGS
@@ -362,6 +387,10 @@ def _make_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
     if has_post_init:
 
         def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
+            actual_type_info = type(self).__tvm_ffi_type_info__
+            if not args and not kwargs and type_info is not actual_type_info:
+                _ffi_alloc_empty(self, actual_type_info.type_index)
+                return
             ffi_args: list[Any] = list(args)
             ffi_args.append(kwargs_obj)
             for key, val in kwargs.items():
@@ -373,6 +402,10 @@ def _make_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
     else:
 
         def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
+            actual_type_info = type(self).__tvm_ffi_type_info__
+            if not args and not kwargs and type_info is not actual_type_info:
+                _ffi_alloc_empty(self, actual_type_info.type_index)
+                return
             ffi_args: list[Any] = list(args)
             ffi_args.append(kwargs_obj)
             for key, val in kwargs.items():
@@ -532,6 +565,19 @@ def _install_init(cls: type, *, enabled: bool) -> None:
     ``__init__``.
     """
     if "__init__" in cls.__dict__:
+        if not enabled:
+            # Wrap user's __init__ to pre-allocate the C++ object so that
+            # field setters work immediately (the calloc'd object is valid:
+            # None for Any/ObjectRef fields, 0 for scalars).
+            user_init = cls.__dict__["__init__"]
+
+            @functools.wraps(user_init)
+            def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
+                actual_type_info = type(self).__tvm_ffi_type_info__
+                _ffi_alloc_empty(self, actual_type_info.type_index)
+                user_init(self, *args, **kwargs)
+
+            setattr(cls, "__init__", __init__)
         return
     type_info: TypeInfo | None = getattr(cls, "__tvm_ffi_type_info__", None)
     if type_info is None:
diff --git a/src/ffi/extra/dataclass.cc b/src/ffi/extra/dataclass.cc
index 8b5eebb..f4f55b2 100644
--- a/src/ffi/extra/dataclass.cc
+++ b/src/ffi/extra/dataclass.cc
@@ -1979,6 +1979,11 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   refl::GlobalDef().def("ffi.MakeFieldSetter", MakeFieldSetter);
   refl::GlobalDef().def("ffi.MakeFFINew", MakeFFINew);
   refl::GlobalDef().def("ffi.RegisterAutoInit", refl::RegisterAutoInit);
+  // Create an empty (zero-initialized) object by type index.
+  // Used by Python super().__init__() in @py_class(init=False) subclasses.
+  refl::GlobalDef().def("ffi.NewEmpty", [](int32_t type_index) -> ObjectRef {
+    return ObjectRef(CreateEmptyObject(TVMFFIGetTypeInfo(type_index)));
+  });
   // Deep copy
   refl::EnsureTypeAttrColumn(refl::type_attr::kShallowCopy);
   refl::GlobalDef().def("ffi.DeepCopy", DeepCopy);
diff --git a/tests/python/test_dataclass_py_class.py b/tests/python/test_dataclass_py_class.py
index f9dc534..e1d3dac 100644
--- a/tests/python/test_dataclass_py_class.py
+++ b/tests/python/test_dataclass_py_class.py
@@ -4561,3 +4561,283 @@ class TestPyMethodIntrospection:
         # system methods still present
         assert "__ffi_init__" in names
         assert "__ffi_shallow_copy__" in names
+
+
+# ---------------------------------------------------------------------------
+# super().__init__() support for @py_class(init=False) subclasses
+# ---------------------------------------------------------------------------
+class TestSuperInitPattern:
+    """Regression tests for using ``super().__init__()`` + field assignment
+    in ``@py_class(init=False)`` custom ``__init__`` methods.
+
+    Previously, ``super().__init__()`` would dispatch to the parent's
+    auto-generated ``__init__`` which called ``self.__ffi_init__()`` with
+    the **child** type's C++ constructor (requiring field arguments that
+    weren't provided), causing a crash.
+    """
+
+    def test_basic_super_init_with_field_setters(self) -> None:
+        """The original crash scenario: init=False with super().__init__() and field setters."""
+
+        @py_class(_unique_key("SIBase"))
+        class SIBase(Object):
+            pass
+
+        @py_class(_unique_key("SIChild"), init=False)
+        class SIChild(SIBase):
+            x: int
+            y: str
+
+            def __init__(self, x: int, y: str) -> None:
+                super().__init__()
+                self.x = x
+                self.y = y
+
+        obj = SIChild(42, "hello")
+        assert obj.x == 42
+        assert obj.y == "hello"
+
+    def test_super_init_deep_hierarchy(self) -> None:
+        """super().__init__() through multiple levels of py_class inheritance."""
+
+        @py_class(_unique_key("SIDH_Node"))
+        class Node(Object):
+            pass
+
+        @py_class(_unique_key("SIDH_BaseType"))
+        class BaseType(Node):
+            pass
+
+        @py_class(_unique_key("SIDH_PtrType"), init=False)
+        class PtrType(BaseType):
+            base_type: Any
+            specifiers: list
+            use_bracket: bool
+
+            def __init__(
+                self,
+                base_type: Any,
+                specifiers: Optional[list] = None,
+                use_bracket: bool = False,
+            ) -> None:
+                super().__init__()
+                self.base_type = base_type
+                self.specifiers = list(specifiers) if specifiers else []
+                self.use_bracket = use_bracket
+
+        void_p = PtrType("void")
+        assert void_p.base_type == "void"
+        assert list(void_p.specifiers) == []
+        assert void_p.use_bracket is False
+
+        int_p = PtrType("int", ["const", "volatile"], True)
+        assert int_p.base_type == "int"
+        assert list(int_p.specifiers) == ["const", "volatile"]
+        assert int_p.use_bracket is True
+
+    def test_super_init_with_defaults_from_calloc(self) -> None:
+        """Fields not set after super().__init__() should be zero-initialized."""
+
+        @py_class(_unique_key("SIDef"))
+        class SIDefBase(Object):
+            pass
+
+        @py_class(_unique_key("SIDef_Child"), init=False)
+        class SIDefChild(SIDefBase):
+            a: int
+            b: float
+            c: bool
+            d: Any
+
+            def __init__(self) -> None:
+                super().__init__()
+                # Don't set any fields — they should be zero/None from calloc
+
+        obj = SIDefChild()
+        assert obj.a == 0
+        assert obj.b == 0.0
+        assert obj.c is False
+        assert obj.d is None
+
+    def test_super_init_intermediate_custom_init(self) -> None:
+        """Chained super().__init__() through an intermediate init=False class."""
+
+        @py_class(_unique_key("SIChain_A"))
+        class A(Object):
+            pass
+
+        @py_class(_unique_key("SIChain_B"), init=False)
+        class B(A):
+            x: int
+
+            def __init__(self, x: int) -> None:
+                super().__init__()
+                self.x = x
+
+        @py_class(_unique_key("SIChain_C"), init=False)
+        class C(B):
+            y: str
+
+            def __init__(self, x: int, y: str) -> None:
+                super().__init__(x)
+                self.y = y
+
+        obj = C(10, "world")
+        assert obj.x == 10
+        assert obj.y == "world"
+
+    def test_super_init_intermediate_auto_init(self) -> None:
+        """super().__init__() where the intermediate parent has init=True (auto-generated)."""
+
+        @py_class(_unique_key("SIAI_Mid"))
+        class Mid(Object):
+            a: int
+
+        @py_class(_unique_key("SIAI_Leaf"), init=False)
+        class Leaf(Mid):
+            b: str
+
+            def __init__(self, a: int, b: str) -> None:
+                super().__init__()  # type: ignore[missing-argument]
+                self.a = a
+                self.b = b
+
+        obj = Leaf(5, "hi")
+        assert obj.a == 5
+        assert obj.b == "hi"
+
+    def test_normal_init_unaffected(self) -> None:
+        """Normal init=True construction must still work correctly."""
+
+        @py_class(_unique_key("SINorm"))
+        class SINorm(Object):
+            x: int
+            y: str
+
+        obj = SINorm(1, "a")
+        assert obj.x == 1
+        assert obj.y == "a"
+
+    def test_non_pyclass_subclass_no_args_errors(self) -> None:
+        """A non-py_class subclass calling parent init with no args should still error
+        for required fields (not silently create an empty object).
+        """
+
+        @py_class(_unique_key("SINonPC"))
+        class SINonPC(Object):
+            x: int
+
+        class Plain(SINonPC):
+            pass
+
+        with pytest.raises(TypeError):
+            Plain()  # type: ignore[missing-argument]
+
+    def test_super_init_isinstance(self) -> None:
+        """Objects created via super().__init__() pattern have correct isinstance."""
+
+        @py_class(_unique_key("SIInst_B"))
+        class Base(Object):
+            pass
+
+        @py_class(_unique_key("SIInst_C"), init=False)
+        class Child(Base):
+            val: int
+
+            def __init__(self, val: int) -> None:
+                super().__init__()
+                self.val = val
+
+        obj = Child(99)
+        assert isinstance(obj, Child)
+        assert isinstance(obj, Base)
+        assert isinstance(obj, Object)
+
+    def test_super_init_field_overwrite(self) -> None:
+        """Fields can be overwritten multiple times after super().__init__()."""
+
+        @py_class(_unique_key("SIOverwrite_B"))
+        class Base(Object):
+            pass
+
+        @py_class(_unique_key("SIOverwrite_C"), init=False)
+        class Child(Base):
+            x: int
+
+            def __init__(self, x: int) -> None:
+                super().__init__()
+                self.x = 0
+                self.x = x
+
+        obj = Child(42)
+        assert obj.x == 42
+
+    def test_super_init_copy_deepcopy(self) -> None:
+        """copy/deepcopy work on objects created via the super().__init__() pattern."""
+
+        @py_class(_unique_key("SICopyBase"))
+        class SICopyBase(Object):
+            pass
+
+        @py_class(_unique_key("SICopy"), init=False)
+        class SICopy(SICopyBase):
+            x: int
+            y: str
+
+            def __init__(self, x: int, y: str) -> None:
+                super().__init__()
+                self.x = x
+                self.y = y
+
+        obj = SICopy(42, "hello")
+        obj2 = copy.copy(obj)
+        assert obj2.x == 42
+        assert obj2.y == "hello"
+        assert not obj.same_as(obj2)
+
+        obj3 = copy.deepcopy(obj)
+        assert obj3.x == 42
+        assert obj3.y == "hello"
+        assert not obj.same_as(obj3)
+
+    def test_super_init_direct_from_object(self) -> None:
+        """super().__init__() works when inheriting directly from Object (no intermediate)."""
+
+        @py_class(_unique_key("SIDirect"), init=False)
+        class SIDirect(Object):
+            x: int
+            y: str
+
+            def __init__(self, x: int, y: str) -> None:
+                super().__init__()
+                self.x = x
+                self.y = y
+
+        obj = SIDirect(10, "direct")
+        assert obj.x == 10
+        assert obj.y == "direct"
+        assert isinstance(obj, Object)
+
+    def test_super_init_direct_from_object_copy(self) -> None:
+        """copy/deepcopy work for init=False classes inheriting directly from Object."""
+
+        @py_class(_unique_key("SIDirectCopy"), init=False)
+        class SIDirectCopy(Object):
+            x: int
+            y: str
+
+            def __init__(self, x: int, y: str) -> None:
+                super().__init__()
+                self.x = x
+                self.y = y
+
+        obj = SIDirectCopy(42, "hello")
+        obj2 = copy.copy(obj)
+        assert obj2.x == 42
+        assert obj2.y == "hello"
+        assert not obj.same_as(obj2)
+
+        obj3 = copy.deepcopy(obj)
+        assert obj3.x == 42
+        assert obj3.y == "hello"
+        assert not obj.same_as(obj3)

Reply via email to