This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 2b59825 fix(py_class): support `super().__init__()` in `init=False`
subclasses (#532)
2b59825 is described below
commit 2b59825e87a70dac521490802306e461024cc4d1
Author: Junru Shao <[email protected]>
AuthorDate: Fri Apr 10 11:26:30 2026 -0700
fix(py_class): support `super().__init__()` in `init=False` subclasses
(#532)
## Summary
- Fix segfault when `@py_class(init=False)` subclasses use
`super().__init__()` followed by field assignment
- Add `ffi.NewEmpty` C++ function to allocate zero-initialized objects
by type index
- Detect super-init-from-subclass pattern in `_make_init()` and
pre-allocate instead of dispatching to `__ffi_init__`
## Motivation
The common Python pattern of defining a custom `__init__` that calls
`super().__init__()` then sets fields crashes when used with
`@py_class(init=False)`:
```python
@py_class(init=False)
class PointerType(Node):
    element_type: Object

    def __init__(self, element_type):
        super().__init__()  # parent's auto-init dispatches to child's C++ ctor → crash
        self.element_type = element_type
```
The parent's auto-generated `__init__` forwards to `__ffi_init__` using
the child type's C++ constructor, which expects field arguments that
were not provided.
## Design
**C++ side** (`src/ffi/extra/dataclass.cc`): Register
`ffi.NewEmpty(type_index) -> ObjectRef` that allocates a
zero-initialized object via `CreateEmptyObject`. The calloc'd state is
valid: `None` for `Any`/`ObjectRef` fields, `0` for scalars.
**Python side** (`python/tvm_ffi/registry.py`):
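- `_ffi_alloc_empty(obj, type_index)`: If `obj` has no handle yet
(`obj.__chandle__() == 0`), initializes it via
`obj.__init_handle_by_constructor__` with `ffi.NewEmpty` (looked up once and
cached by `_new_empty()`), so `obj` gets an empty FFI object of
`type_index`. Callers pass `type(obj).__tvm_ffi_type_info__.type_index`.
No-op if `obj` is already allocated.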
- Super-init detection (inlined in the generated `__init__`; there is no
separate helper function): compares `type(self).__tvm_ffi_type_info__`
identity against the declaring class's `type_info`. The two differ only
for registered `@py_class` subclasses. Undecorated subclasses inherit the
parent's `type_info`, so the identity check correctly filters them out.
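- Both `_make_init` code paths (with/without `__post_init__`) intercept
the no-args super-init-from-subclass call and allocate an empty object
instead of dispatching to `__ffi_init__`.
- `_install_init`: for `init=False` classes that define their own
`__init__`, wraps it so the empty object is pre-allocated (via
`_ffi_alloc_empty`) before the user's body runs. Field setters therefore
work immediately, including for classes that inherit directly from
`Object`. A subsequent no-args `super().__init__()` into a generated parent
`__init__` then finds the handle already set, so `_ffi_alloc_empty` does
nothing.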
## Test plan
- [x] `uv run pytest -vvs
tests/python/test_dataclass_py_class.py::TestSuperInitPattern` — 12
tests pass
- [x] `uv run pytest -vvs tests/python/` — 1970 passed, 38 skipped, 3
xfailed, 0 failures
Tests cover: basic super-init, deep hierarchy
(Node→BaseType→PointerType), calloc defaults, intermediate custom/auto
inits, normal init unaffected, non-py_class subclass error handling,
isinstance checks, field overwrite, copy/deepcopy, and
direct-from-Object inheritance.
---
python/tvm_ffi/registry.py | 46 ++++++
src/ffi/extra/dataclass.cc | 5 +
tests/python/test_dataclass_py_class.py | 280 ++++++++++++++++++++++++++++++++
3 files changed, 331 insertions(+)
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 93ad92c..7d6359a 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -18,6 +18,7 @@
from __future__ import annotations
+import functools
import inspect
import json
import sys
@@ -346,6 +347,16 @@ def init_ffi_api(namespace: str, target_module_name: str | None = None) -> None:
__SENTINEL = object()
[email protected]_cache(maxsize=None)
+def _new_empty() -> Any:
+ return core._get_global_func("ffi.NewEmpty", False)
+
+
+def _ffi_alloc_empty(obj: Any, type_index: int) -> None:  # back obj with an empty FFI object of type_index
+ if obj.__chandle__() == 0:  # only when obj has no handle yet; otherwise this is a no-op
+ obj.__init_handle_by_constructor__(_new_empty(), type_index)  # constructs via ffi.NewEmpty(type_index)
+
+
def _make_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
"""Build a Python ``__init__`` that delegates to the C++ auto-generated
``__ffi_init__``.
@@ -354,6 +365,20 @@ def _make_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
signature. The ``__init__`` body is a trivial adapter — all validation
(too many positional, duplicates, missing required, kw_only enforcement,
unknown kwargs) is handled by C++.
+
+ When this generated ``__init__`` is called with no arguments via ``super().__init__()``
+ from a registered subclass (detected by ``type(self).__tvm_ffi_type_info__ is not type_info``),
+ it allocates an empty zero-initialized FFI object of the correct child
+ type instead of forwarding to ``__ffi_init__``. This supports the common
+ Python pattern::
+
+ @py_class(init=False)
+ class Child(Parent):
+ x: int
+
+ def __init__(self, x):
+ super().__init__()
+ self.x = x
"""
sig = _make_init_signature(type_info)
kwargs_obj = core.KWARGS
@@ -362,6 +387,10 @@ def _make_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
if has_post_init:
def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
+ actual_type_info = type(self).__tvm_ffi_type_info__
+ if not args and not kwargs and type_info is not actual_type_info:
+ _ffi_alloc_empty(self, actual_type_info.type_index)
+ return
ffi_args: list[Any] = list(args)
ffi_args.append(kwargs_obj)
for key, val in kwargs.items():
@@ -373,6 +402,10 @@ def _make_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
else:
def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
+ actual_type_info = type(self).__tvm_ffi_type_info__
+ if not args and not kwargs and type_info is not actual_type_info:
+ _ffi_alloc_empty(self, actual_type_info.type_index)
+ return
ffi_args: list[Any] = list(args)
ffi_args.append(kwargs_obj)
for key, val in kwargs.items():
@@ -532,6 +565,19 @@ def _install_init(cls: type, *, enabled: bool) -> None:
``__init__``.
"""
if "__init__" in cls.__dict__:
+ if not enabled:
+ # Wrap user's __init__ to pre-allocate the C++ object so that
+ # field setters work immediately (the calloc'd object is valid:
+ # None for Any/ObjectRef fields, 0 for scalars).
+ user_init = cls.__dict__["__init__"]
+
+ @functools.wraps(user_init)
+ def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
+ actual_type_info = type(self).__tvm_ffi_type_info__
+ _ffi_alloc_empty(self, actual_type_info.type_index)
+ user_init(self, *args, **kwargs)
+
+ setattr(cls, "__init__", __init__)
return
type_info: TypeInfo | None = getattr(cls, "__tvm_ffi_type_info__", None)
if type_info is None:
diff --git a/src/ffi/extra/dataclass.cc b/src/ffi/extra/dataclass.cc
index 8b5eebb..f4f55b2 100644
--- a/src/ffi/extra/dataclass.cc
+++ b/src/ffi/extra/dataclass.cc
@@ -1979,6 +1979,11 @@ TVM_FFI_STATIC_INIT_BLOCK() {
refl::GlobalDef().def("ffi.MakeFieldSetter", MakeFieldSetter);
refl::GlobalDef().def("ffi.MakeFFINew", MakeFFINew);
refl::GlobalDef().def("ffi.RegisterAutoInit", refl::RegisterAutoInit);
+ // Create an empty (zero-initialized) object by type index.
+ // Used by Python super().__init__() in @py_class(init=False) subclasses.
+ refl::GlobalDef().def("ffi.NewEmpty", [](int32_t type_index) -> ObjectRef {
+ return ObjectRef(CreateEmptyObject(TVMFFIGetTypeInfo(type_index)));
+ });
// Deep copy
refl::EnsureTypeAttrColumn(refl::type_attr::kShallowCopy);
refl::GlobalDef().def("ffi.DeepCopy", DeepCopy);
diff --git a/tests/python/test_dataclass_py_class.py b/tests/python/test_dataclass_py_class.py
index f9dc534..e1d3dac 100644
--- a/tests/python/test_dataclass_py_class.py
+++ b/tests/python/test_dataclass_py_class.py
@@ -4561,3 +4561,283 @@ class TestPyMethodIntrospection:
# system methods still present
assert "__ffi_init__" in names
assert "__ffi_shallow_copy__" in names
+
+
+# ---------------------------------------------------------------------------
+# super().__init__() support for @py_class(init=False) subclasses
+# ---------------------------------------------------------------------------
+class TestSuperInitPattern:
+ """Regression tests for using ``super().__init__()`` + field assignment
+ in ``@py_class(init=False)`` custom ``__init__`` methods.
+
+ Previously, ``super().__init__()`` would dispatch to the parent's
+ auto-generated ``__init__`` which called ``self.__ffi_init__()`` with
+ the **child** type's C++ constructor (requiring field arguments that
+ weren't provided), causing a crash.
+ """
+
+ def test_basic_super_init_with_field_setters(self) -> None:
+ """The original crash scenario: init=False with super().__init__() and field setters."""
+
+ @py_class(_unique_key("SIBase"))
+ class SIBase(Object):
+ pass
+
+ @py_class(_unique_key("SIChild"), init=False)
+ class SIChild(SIBase):
+ x: int
+ y: str
+
+ def __init__(self, x: int, y: str) -> None:
+ super().__init__()
+ self.x = x
+ self.y = y
+
+ obj = SIChild(42, "hello")
+ assert obj.x == 42
+ assert obj.y == "hello"
+
+ def test_super_init_deep_hierarchy(self) -> None:
+ """super().__init__() through multiple levels of py_class inheritance."""
+
+ @py_class(_unique_key("SIDH_Node"))
+ class Node(Object):
+ pass
+
+ @py_class(_unique_key("SIDH_BaseType"))
+ class BaseType(Node):
+ pass
+
+ @py_class(_unique_key("SIDH_PtrType"), init=False)
+ class PtrType(BaseType):
+ base_type: Any
+ specifiers: list
+ use_bracket: bool
+
+ def __init__(
+ self,
+ base_type: Any,
+ specifiers: Optional[list] = None,
+ use_bracket: bool = False,
+ ) -> None:
+ super().__init__()
+ self.base_type = base_type
+ self.specifiers = list(specifiers) if specifiers else []
+ self.use_bracket = use_bracket
+
+ void_p = PtrType("void")
+ assert void_p.base_type == "void"
+ assert list(void_p.specifiers) == []
+ assert void_p.use_bracket is False
+
+ int_p = PtrType("int", ["const", "volatile"], True)
+ assert int_p.base_type == "int"
+ assert list(int_p.specifiers) == ["const", "volatile"]
+ assert int_p.use_bracket is True
+
+ def test_super_init_with_defaults_from_calloc(self) -> None:
+ """Fields not set after super().__init__() should be zero-initialized."""
+
+ @py_class(_unique_key("SIDef"))
+ class SIDefBase(Object):
+ pass
+
+ @py_class(_unique_key("SIDef_Child"), init=False)
+ class SIDefChild(SIDefBase):
+ a: int
+ b: float
+ c: bool
+ d: Any
+
+ def __init__(self) -> None:
+ super().__init__()
+ # Don't set any fields — they should be zero/None from calloc
+
+ obj = SIDefChild()
+ assert obj.a == 0
+ assert obj.b == 0.0
+ assert obj.c is False
+ assert obj.d is None
+
+ def test_super_init_intermediate_custom_init(self) -> None:
+ """Chained super().__init__() through an intermediate init=False class."""
+
+ @py_class(_unique_key("SIChain_A"))
+ class A(Object):
+ pass
+
+ @py_class(_unique_key("SIChain_B"), init=False)
+ class B(A):
+ x: int
+
+ def __init__(self, x: int) -> None:
+ super().__init__()
+ self.x = x
+
+ @py_class(_unique_key("SIChain_C"), init=False)
+ class C(B):
+ y: str
+
+ def __init__(self, x: int, y: str) -> None:
+ super().__init__(x)
+ self.y = y
+
+ obj = C(10, "world")
+ assert obj.x == 10
+ assert obj.y == "world"
+
+ def test_super_init_intermediate_auto_init(self) -> None:
+ """super().__init__() where the intermediate parent has init=True (auto-generated)."""
+
+ @py_class(_unique_key("SIAI_Mid"))
+ class Mid(Object):
+ a: int
+
+ @py_class(_unique_key("SIAI_Leaf"), init=False)
+ class Leaf(Mid):
+ b: str
+
+ def __init__(self, a: int, b: str) -> None:
+ super().__init__() # type: ignore[missing-argument]
+ self.a = a
+ self.b = b
+
+ obj = Leaf(5, "hi")
+ assert obj.a == 5
+ assert obj.b == "hi"
+
+ def test_normal_init_unaffected(self) -> None:
+ """Normal init=True construction must still work correctly."""
+
+ @py_class(_unique_key("SINorm"))
+ class SINorm(Object):
+ x: int
+ y: str
+
+ obj = SINorm(1, "a")
+ assert obj.x == 1
+ assert obj.y == "a"
+
+ def test_non_pyclass_subclass_no_args_errors(self) -> None:
+ """A non-py_class subclass calling parent init with no args should still error
+ for required fields (not silently create an empty object).
+ """
+
+ @py_class(_unique_key("SINonPC"))
+ class SINonPC(Object):
+ x: int
+
+ class Plain(SINonPC):
+ pass
+
+ with pytest.raises(TypeError):
+ Plain() # type: ignore[missing-argument]
+
+ def test_super_init_isinstance(self) -> None:
+ """Objects created via super().__init__() pattern have correct isinstance."""
+
+ @py_class(_unique_key("SIInst_B"))
+ class Base(Object):
+ pass
+
+ @py_class(_unique_key("SIInst_C"), init=False)
+ class Child(Base):
+ val: int
+
+ def __init__(self, val: int) -> None:
+ super().__init__()
+ self.val = val
+
+ obj = Child(99)
+ assert isinstance(obj, Child)
+ assert isinstance(obj, Base)
+ assert isinstance(obj, Object)
+
+ def test_super_init_field_overwrite(self) -> None:
+ """Fields can be overwritten multiple times after super().__init__()."""
+
+ @py_class(_unique_key("SIOverwrite_B"))
+ class Base(Object):
+ pass
+
+ @py_class(_unique_key("SIOverwrite_C"), init=False)
+ class Child(Base):
+ x: int
+
+ def __init__(self, x: int) -> None:
+ super().__init__()
+ self.x = 0
+ self.x = x
+
+ obj = Child(42)
+ assert obj.x == 42
+
+ def test_super_init_copy_deepcopy(self) -> None:
+ """copy/deepcopy work on objects created via the super().__init__() pattern."""
+
+ @py_class(_unique_key("SICopyBase"))
+ class SICopyBase(Object):
+ pass
+
+ @py_class(_unique_key("SICopy"), init=False)
+ class SICopy(SICopyBase):
+ x: int
+ y: str
+
+ def __init__(self, x: int, y: str) -> None:
+ super().__init__()
+ self.x = x
+ self.y = y
+
+ obj = SICopy(42, "hello")
+ obj2 = copy.copy(obj)
+ assert obj2.x == 42
+ assert obj2.y == "hello"
+ assert not obj.same_as(obj2)
+
+ obj3 = copy.deepcopy(obj)
+ assert obj3.x == 42
+ assert obj3.y == "hello"
+ assert not obj.same_as(obj3)
+
+ def test_super_init_direct_from_object(self) -> None:
+ """super().__init__() works when inheriting directly from Object (no intermediate)."""
+
+ @py_class(_unique_key("SIDirect"), init=False)
+ class SIDirect(Object):
+ x: int
+ y: str
+
+ def __init__(self, x: int, y: str) -> None:
+ super().__init__()
+ self.x = x
+ self.y = y
+
+ obj = SIDirect(10, "direct")
+ assert obj.x == 10
+ assert obj.y == "direct"
+ assert isinstance(obj, Object)
+
+ def test_super_init_direct_from_object_copy(self) -> None:
+ """copy/deepcopy work for init=False classes inheriting directly from Object."""
+
+ @py_class(_unique_key("SIDirectCopy"), init=False)
+ class SIDirectCopy(Object):
+ x: int
+ y: str
+
+ def __init__(self, x: int, y: str) -> None:
+ super().__init__()
+ self.x = x
+ self.y = y
+
+ obj = SIDirectCopy(42, "hello")
+ obj2 = copy.copy(obj)
+ assert obj2.x == 42
+ assert obj2.y == "hello"
+ assert not obj.same_as(obj2)
+
+ obj3 = copy.deepcopy(obj)
+ assert obj3.x == 42
+ assert obj3.y == "hello"
+ assert not obj.same_as(obj3)