This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 6a61156c feat(python): add field registration, _register_py_class, and
Field descriptor (#505)
6a61156c is described below
commit 6a61156cb6780101d633a72a5151b65602f8b7e2
Author: Junru Shao <[email protected]>
AuthorDate: Fri Mar 20 12:05:12 2026 -0700
feat(python): add field registration, _register_py_class, and Field
descriptor (#505)
## Summary
- Implement the Python-side type registration pipeline for
`@py_class`-decorated classes: allocates dynamic type indices, computes
native field layouts, registers getters/setters with type-converting
callbacks, and wires up `__ffi_new__`/`__ffi_init__` for automatic
construction.
- Add `Field` descriptor and `field()` factory function mirroring stdlib
`dataclasses.field()` API.
- Fix dunder family-skip logic so that a user-defined `__eq__` suppresses
the generated `__ne__` (and vice versa); the same rule applies to the
ordering operators `__lt__`, `__le__`, `__gt__`, and `__ge__`.
## Architecture
Three-layer implementation:
- **Cython `object.pxi`**: `_register_py_class` (dynamic type index
allocation, ancestor chain, TypeInfo insertion) and `_rollback_py_class`
(cleanup on phase-2 failure).
- **Cython `type_info.pxi`**: Field registration engine — computes
per-field native layout from `parent_type_info.total_size`, obtains C
getter/setter function pointers, and registers via
`TVMFFITypeRegisterField`. Installs `MakeFFINew`/`RegisterAutoInit`.
- **Python `field.py`**: Pure-Python `Field` descriptor (declared with
`__slots__`) and a `field()` factory whose signature mirrors stdlib
`dataclasses.field()`, plus a `KW_ONLY` sentinel with a fallback class
for Python 3.9.
## Behavioral Changes
- If the class has a `__post_init__` method, it is now called right after
`__ffi_init__` (previously it was silently ignored).
- Dunder family-skip: defining any member of `{__eq__, __ne__}` or
`{__lt__, __le__, __gt__, __ge__}` suppresses the entire family's
auto-generation.
## Test plan
- [x] Python tests pass: `uv run pytest -vvs tests/python`
- [ ] Full `@py_class` end-to-end tests (in follow-up commits)
- [x] Pre-commit lints pass
---
python/tvm_ffi/cython/object.pxi | 88 +++++++++++
python/tvm_ffi/cython/type_info.pxi | 291 ++++++++++++++++++++++++++++++++++++
python/tvm_ffi/dataclasses/field.py | 204 +++++++++++++++++++++++++
python/tvm_ffi/registry.py | 44 ++++--
4 files changed, 618 insertions(+), 9 deletions(-)
diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 489837a5..0f4fdabe 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -573,6 +573,94 @@ def _lookup_or_register_type_info_from_type_key(type_key: str) -> TypeInfo:
return info
+def _register_py_class(parent_type_info, str type_key, object type_cls):
+    """Register a new Python-defined TVM-FFI type.
+
+    Allocates a dynamic type index for *type_key* as a child of
+    *parent_type_info* and registers it in the global type tables.
+
+    Parameters
+    ----------
+    parent_type_info : TypeInfo
+        The parent type's TypeInfo (e.g., Object's TypeInfo).
+    type_key : str
+        The unique type key string for the new type.
+    type_cls : type
+        The Python class to associate with this type.
+
+    Returns
+    -------
+    TypeInfo
+        The newly created TypeInfo with ``fields=None`` (pending registration).
+
+    Raises
+    ------
+    ValueError
+        If *type_key* is already present in ``TYPE_KEY_TO_INFO``.
+    """
+    # Reject type keys that are already in the Python-side registry
+    if type_key in TYPE_KEY_TO_INFO:
+        raise ValueError(
+            f"Type key '{type_key}' is already registered"
+        )
+
+    cdef int32_t parent_type_index = parent_type_info.type_index
+    cdef int32_t parent_type_depth = len(parent_type_info.type_ancestors)
+    cdef int32_t type_depth = parent_type_depth + 1
+    cdef ByteArrayArg type_key_arg = ByteArrayArg(c_str(type_key))
+    cdef int32_t type_index
+
+    # Allocate a new type index
+    # static_type_index=-1 means dynamic allocation
+    # num_child_slots=0, child_slots_can_overflow=1
+    type_index = TVMFFITypeGetOrAllocIndex(
+        type_key_arg.cptr(),
+        -1,  # static_type_index (dynamic)
+        type_depth,
+        0,  # num_child_slots
+        1,  # child_slots_can_overflow
+        parent_type_index,
+    )
+
+    # Build the ancestor chain: the parent's ancestors, then the parent itself
+    cdef list ancestors = list(parent_type_info.type_ancestors)
+    ancestors.append(parent_type_index)
+
+    # Create TypeInfo with fields=None (pending _register_fields call)
+    cdef object info = TypeInfo(
+        type_cls=type_cls,
+        type_index=type_index,
+        type_key=type_key,
+        type_ancestors=ancestors,
+        fields=None,
+        methods=[],
+        parent_type_info=parent_type_info,
+    )
+
+    _update_registry(type_index, type_key, info, type_cls)
+    return info
+
+
+def _rollback_py_class(object type_info):
+    """Roll back a ``_register_py_class`` call from the Python-level registry.
+
+    Called by ``@py_class`` when phase-2 (field validation) fails, so
+    the type key can be reused after the user fixes the error. The
+    C-level type index is permanently consumed (cannot be reclaimed);
+    only the key/class dict entries are removed and the per-index slots
+    reset to ``None``, so that a retry does not hit "already registered".
+    """
+    cdef int32_t idx = type_info.type_index
+    cdef str key = type_info.type_key
+    cdef object cls = type_info.type_cls
+    TYPE_KEY_TO_INFO.pop(key, None)
+    if cls is not None:
+        TYPE_CLS_TO_INFO.pop(cls, None)
+    if 0 <= idx < len(TYPE_INDEX_TO_INFO):
+        TYPE_INDEX_TO_INFO[idx] = None
+        TYPE_INDEX_TO_CLS[idx] = None
+
+
def _lookup_type_attr(type_index: int32_t, attr_key: str) -> Any:
cdef ByteArrayArg attr_key_bytes = ByteArrayArg(c_str(attr_key))
cdef const TVMFFITypeAttrColumn* column =
TVMFFIGetTypeAttrColumn(&attr_key_bytes.cdata)
diff --git a/python/tvm_ffi/cython/type_info.pxi b/python/tvm_ffi/cython/type_info.pxi
index 024ba268..7a0b8c5e 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -704,6 +704,297 @@ class TypeInfo:
end = f_end
return (end + 7) & ~7 # align to 8 bytes
+    def _register_fields(self, fields):
+        """Register Field descriptors and set up __ffi_new__/__ffi_init__.
+
+        Delegates to the module-level _register_fields function,
+        stores the resulting list[TypeField] on self.fields,
+        then reads back methods registered by C++ via _register_methods.
+
+        Can only be called once (asserts that self.fields is None beforehand).
+        """
+        assert self.fields is None, (
+            f"_register_fields already called for {self.type_key!r}"
+        )
+        self.fields = _register_fields(self, fields)
+        self._register_methods()
+
+    def _register_methods(self):
+        """Read methods from the C type table into self.methods.
+
+        Called after C++ registers __ffi_init__, __ffi_shallow_copy__, etc.
+        """
+        cdef const TVMFFITypeInfo* c_info = TVMFFIGetTypeInfo(self.type_index)
+        cdef const TVMFFIMethodInfo* mi
+        self.methods = []
+        for i in range(c_info.num_methods):
+            mi = &(c_info.methods[i])
+            self.methods.append(TypeMethod(
+                name=bytearray_to_str(&mi.name),
+                doc=bytearray_to_str(&mi.doc) if mi.doc.size != 0 else None,
+                func=_get_method_from_method_info(mi),
+                is_static=(mi.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod) != 0,
+                metadata=json.loads(bytearray_to_str(&mi.metadata)) if mi.metadata.size != 0 else {},
+            ))
+
+
+# ---------------------------------------------------------------------------
+# Python-defined type field registration helpers
+# ---------------------------------------------------------------------------
+
+# Native layout for each TypeSchema origin: (size, alignment, field_static_type_index)
+_ORIGIN_NATIVE_LAYOUT = {
+    "int": (8, 8, kTVMFFIInt),
+    "float": (8, 8, kTVMFFIFloat),
+    "bool": (1, 1, kTVMFFIBool),
+    "ctypes.c_void_p": (8, 8, kTVMFFIOpaquePtr),
+    "dtype": (4, 2, kTVMFFIDataType),
+    "Device": (8, 4, kTVMFFIDevice),
+    "Any": (16, 8, -1),  # kTVMFFIAny = -1
+    # str/bytes can be SmallStr/SmallBytes (inline, not ObjectRef),
+    # so store as Any (16 bytes) to handle both inline and heap variants.
+    "str": (16, 8, -1),
+    "bytes": (16, 8, -1),
+    # Optional/Union can hold any type including inline scalars
+    "Optional": (16, 8, -1),
+    "Union": (16, 8, -1),
+}
+
+cdef _register_one_field(
+    int32_t type_index,
+    object py_field,
+    int64_t offset,
+    int64_t size,
+    int64_t alignment,
+    int32_t field_type_index,
+    TVMFFIFieldGetter getter,
+    CObject setter_fn,
+):
+    """Build a TVMFFIFieldInfo and register it for the given type."""
+    cdef TVMFFIFieldInfo info
+    cdef int c_api_ret_code
+
+    # --- name ---
+    name_bytes = c_str(py_field.name)
+    cdef ByteArrayArg name_arg = ByteArrayArg(name_bytes)
+    info.name = name_arg.cdata
+
+    # --- doc ---
+    cdef ByteArrayArg doc_arg
+    if py_field.doc is not None:
+        doc_bytes = c_str(py_field.doc)
+        doc_arg = ByteArrayArg(doc_bytes)
+        info.doc = doc_arg.cdata
+    else:
+        info.doc.data = NULL
+        info.doc.size = 0
+
+    # --- metadata (JSON with type_schema) ---
+    metadata_str = json.dumps({"type_schema": py_field.ty.to_json()})
+    metadata_bytes = c_str(metadata_str)
+    cdef ByteArrayArg metadata_arg = ByteArrayArg(metadata_bytes)
+    info.metadata = metadata_arg.cdata
+
+    # --- flags ---
+    cdef int64_t flags = kTVMFFIFieldFlagBitMaskWritable | kTVMFFIFieldFlagBitSetterIsFunctionObj
+    if py_field.default is not MISSING or py_field.default_factory is not MISSING:
+        flags |= kTVMFFIFieldFlagBitMaskHasDefault
+    if py_field.default_factory is not MISSING:
+        flags |= kTVMFFIFieldFlagBitMaskDefaultFromFactory
+    if not py_field.init:
+        flags |= kTVMFFIFieldFlagBitMaskInitOff
+    if not py_field.repr:
+        flags |= kTVMFFIFieldFlagBitMaskReprOff
+    if not py_field.hash:
+        flags |= kTVMFFIFieldFlagBitMaskHashOff
+    if not py_field.compare:
+        flags |= kTVMFFIFieldFlagBitMaskCompareOff
+    if py_field.kw_only:
+        flags |= kTVMFFIFieldFlagBitMaskKwOnly
+    info.flags = flags
+
+    # --- native layout ---
+    info.size = size
+    info.alignment = alignment
+    info.offset = offset
+
+    # --- getter / setter ---
+    info.getter = getter
+    info.setter = <void*>setter_fn.chandle
+
+    # --- default value ---
+    cdef TVMFFIAny default_any
+    default_any.type_index = kTVMFFINone
+    default_any.v_int64 = 0
+    # Determine which Python object (if any) to store as the default.
+    # No memory leak: TVMFFIAny is a POD struct; TVMFFITypeRegisterField
+    # copies the bytes into the type table, which owns the reference.
+    cdef object default_obj = MISSING
+    if py_field.default is not MISSING:
+        default_obj = py_field.default
+    elif py_field.default_factory is not MISSING:
+        default_obj = py_field.default_factory
+    if default_obj is not MISSING:
+        TVMFFIPyPyObjectToFFIAny(
+            TVMFFIPyArgSetterFactory_,
+            <PyObject*>default_obj,
+            &default_any,
+            &c_api_ret_code
+        )
+        CHECK_CALL(c_api_ret_code)
+    info.default_value_or_factory = default_any
+
+    # --- field_static_type_index ---
+    info.field_static_type_index = field_type_index
+
+    CHECK_CALL(TVMFFITypeRegisterField(type_index, &info))
+
+
+cdef int _f_type_convert(void* type_converter, const TVMFFIAny* value, TVMFFIAny* result) noexcept with gil:
+    """C callback for type conversion, called from C++ MakeFieldSetter.
+
+    Parameters
+    ----------
+    type_converter : void*
+        A PyObject* pointing to a _TypeConverter instance (borrowed reference).
+    value : const TVMFFIAny*
+        The packed value to convert (borrowed from the caller).
+    result : TVMFFIAny*
+        Output: the converted value (caller takes ownership).
+
+    Returns 0 on success, -1 on error (error stored in TLS via set_last_ffi_error).
+    """
+    cdef TVMFFIAny temp
+    cdef _TypeConverter conv
+    cdef CAny cany
+    try:
+        # Unpack the packed AnyView to a Python object.
+        # We must IncRef if it's an object, because make_ret takes ownership.
+        temp = value[0]
+        if temp.type_index >= kTVMFFIStaticObjectBegin:
+            if temp.v_obj != NULL:
+                TVMFFIObjectIncRef(<TVMFFIObjectHandle>temp.v_obj)
+        py_value = make_ret(temp)
+        # Dispatch directly through the C-level converter
+        conv = <_TypeConverter>type_converter
+        cany = _type_convert_impl(conv, py_value)
+        # Transfer ownership from CAny to result (zero cany to prevent double-free)
+        result[0] = cany.cdata
+        cany.cdata.type_index = kTVMFFINone
+        cany.cdata.v_int64 = 0
+        return 0
+    except Exception as err:
+        set_last_ffi_error(err)
+        return -1
+
+
+def _register_fields(type_info, fields):
+    """Register Field descriptors for a Python-defined type and set up __ffi_new__/__ffi_init__.
+
+    For each Field:
+    1. Computes native layout (size, alignment, offset)
+    2. Obtains a C getter function pointer
+    3. Creates a FunctionObj setter with type conversion
+    4. Registers via TVMFFITypeRegisterField
+
+    After all fields, registers __ffi_new__ (object allocator) and
+    __ffi_init__ (auto-generated constructor).
+
+    Parameters
+    ----------
+    type_info : TypeInfo
+        The TypeInfo of the type being defined.
+    fields : list[Field]
+        The Field descriptors to register.
+
+    Returns
+    -------
+    list[TypeField]
+        The registered field descriptors.
+    """
+    cdef int32_t type_index = type_info.type_index
+    # Start field offsets AFTER all parent fields (not at fixed offset 24).
+    # This is critical for inheritance: child fields must not overlap parent memory.
+    cdef int64_t current_offset = type_info.parent_type_info.total_size
+    cdef int64_t size, alignment
+    cdef int32_t field_type_index
+    cdef TVMFFIFieldGetter getter
+    cdef FieldGetter fgetter
+    cdef FieldSetter fsetter
+
+    # Get global functions
+    _get_field_getter = _get_global_func("ffi.GetFieldGetter", False)
+    _make_field_setter = _get_global_func("ffi.MakeFieldSetter", False)
+    _make_ffi_new = _get_global_func("ffi.MakeFFINew", False)
+    _register_auto_init = _get_global_func("ffi.RegisterAutoInit", False)
+
+    cdef list type_fields = []
+
+    for py_field in fields:
+        # 1. Get layout
+        layout = _ORIGIN_NATIVE_LAYOUT.get(py_field.ty.origin, (8, 8, kTVMFFIObject))
+        size = layout[0]
+        alignment = layout[1]
+        field_type_index = layout[2]
+
+        # 2. Compute offset (align up)
+        current_offset = (current_offset + alignment - 1) & ~(alignment - 1)
+        field_offset = current_offset
+        current_offset += size
+
+        # 3. Get getter (C function pointer) and setter (FunctionObj).
+        # Pointers are transported as int64_t through the FFI boundary.
+        getter = <TVMFFIFieldGetter><int64_t>_get_field_getter(field_type_index)
+        setter_fn = <CObject>_make_field_setter(
+            field_type_index,
+            <int64_t><void*>py_field.ty._converter,
+            <int64_t>&_f_type_convert,
+        )
+
+        # 4. Register field in the C type table
+        _register_one_field(
+            type_index, py_field, field_offset, size, alignment,
+            field_type_index, getter, setter_fn,
+        )
+
+        # 5. Build the Python-side TypeField descriptor
+        fgetter = FieldGetter.__new__(FieldGetter)
+        fgetter.getter = getter
+        fgetter.offset = field_offset
+        fsetter = FieldSetter.__new__(FieldSetter)
+        fsetter.setter = <void*>setter_fn.chandle
+        fsetter.offset = field_offset
+        fsetter.flags = <int64_t>(kTVMFFIFieldFlagBitMaskWritable | kTVMFFIFieldFlagBitSetterIsFunctionObj)
+        type_fields.append(
+            TypeField(
+                name=py_field.name,
+                doc=py_field.doc,
+                size=size,
+                offset=field_offset,
+                frozen=False,
+                metadata={"type_schema": py_field.ty.to_json()},
+                getter=fgetter,
+                setter=fsetter,
+                ty=py_field.ty,
+                c_init=py_field.init,
+                c_kw_only=py_field.kw_only,
+                c_has_default=(py_field.default is not MISSING or py_field.default_factory is not MISSING),
+            )
+        )
+
+    # 6. Align total size to 8 bytes (never below the TVMFFIObject header size)
+    cdef int64_t total_size = (current_offset + 7) & ~7
+    if total_size < sizeof(TVMFFIObject):
+        total_size = sizeof(TVMFFIObject)
+
+    # 7. Register __ffi_new__ + deleter
+    _make_ffi_new(type_index, total_size)
+
+    # 8. Register __ffi_init__ (auto-generated constructor)
+    _register_auto_init(type_index)
+
+    return type_fields
+
def _member_method_wrapper(method_func: Callable[..., Any]) -> Callable[...,
Any]:
def wrapper(self: Any, *args: Any) -> Any:
diff --git a/python/tvm_ffi/dataclasses/field.py b/python/tvm_ffi/dataclasses/field.py
new file mode 100644
index 00000000..9295f13a
--- /dev/null
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -0,0 +1,204 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Field descriptor and ``field()`` helper for Python-defined TVM-FFI types."""
+
+from __future__ import annotations
+
+import sys
+from collections.abc import Callable
+from typing import Any
+
+from ..core import MISSING, TypeSchema
+
+# Re-export the stdlib KW_ONLY sentinel so type checkers recognise
+# ``_: KW_ONLY`` as a keyword-only boundary rather than a real field.
+# dataclasses.KW_ONLY was added in Python 3.10; on older runtimes we
+# define a class sentinel (a class, not an instance, so that ``_: KW_ONLY``
+# is a valid type annotation for static analysers targeting 3.9).
+if sys.version_info >= (3, 10):
+    from dataclasses import KW_ONLY
+else:
+
+    class KW_ONLY:
+        """Sentinel type: annotations after ``_: KW_ONLY`` are keyword-only."""
+
+
+class Field:
+    """Descriptor for a single field in a Python-defined TVM-FFI type.
+
+    When constructed directly (low-level API), *name* and *ty* should be
+    provided. When returned by :func:`field` (``@py_class`` workflow),
+    *name* and *ty* are ``None`` and filled in by the decorator.
+
+    Parameters
+    ----------
+    name : str | None
+        The field name. ``None`` when created via :func:`field`; filled
+        in by the ``@py_class`` decorator.
+    ty : TypeSchema | None
+        The type schema. ``None`` when created via :func:`field`; filled
+        in by the ``@py_class`` decorator.
+    default : object
+        Default value for the field. Mutually exclusive with *default_factory*.
+        ``MISSING`` when not set.
+    default_factory : Callable[[], object] | None
+        A zero-argument callable that produces the default value.
+        Mutually exclusive with *default*. ``MISSING`` when not set.
+    init : bool
+        Whether this field appears in the auto-generated ``__init__``.
+    repr : bool
+        Whether this field appears in ``__repr__`` output.
+    hash : bool | None
+        Whether this field participates in recursive hashing.
+        ``None`` means "follow *compare*" (the native dataclass default).
+    compare : bool
+        Whether this field participates in recursive comparison.
+    kw_only : bool | None
+        Whether this field is keyword-only in ``__init__``.
+        ``None`` means "inherit from the decorator-level *kw_only* flag".
+    doc : str | None
+        Optional docstring for the field.
+
+    """
+
+    __slots__ = (
+        "compare",
+        "default",
+        "default_factory",
+        "doc",
+        "hash",
+        "init",
+        "kw_only",
+        "name",
+        "repr",
+        "ty",
+    )
+    name: str | None
+    ty: TypeSchema | None
+    default: object
+    default_factory: Callable[[], object] | None
+    init: bool
+    repr: bool
+    hash: bool | None
+    compare: bool
+    kw_only: bool | None
+    doc: str | None
+
+    def __init__(
+        self,
+        name: str | None = None,
+        ty: TypeSchema | None = None,
+        *,
+        default: object = MISSING,
+        default_factory: Callable[[], object] | None = MISSING,  # type: ignore[assignment]
+        init: bool = True,
+        repr: bool = True,
+        hash: bool | None = True,
+        compare: bool = False,
+        kw_only: bool | None = False,
+        doc: str | None = None,
+    ) -> None:
+        # MISSING means "parameter not provided".
+        # An explicit None from the user fails the callable() check,
+        # matching stdlib dataclasses semantics.
+        if default_factory is not MISSING:
+            if default is not MISSING:
+                raise ValueError("cannot specify both default and default_factory")
+            if not callable(default_factory):
+                raise TypeError(
+                    f"default_factory must be a callable, got {type(default_factory).__name__}"
+                )
+        self.name = name
+        self.ty = ty
+        self.default = default
+        self.default_factory = default_factory
+        self.init = init
+        self.repr = repr
+        self.hash = hash
+        self.compare = compare
+        self.kw_only = kw_only
+        self.doc = doc
+
+
+def field(
+    *,
+    default: object = MISSING,
+    default_factory: Callable[[], object] | None = MISSING,  # type: ignore[assignment]
+    init: bool = True,
+    repr: bool = True,
+    hash: bool | None = None,
+    compare: bool = True,
+    kw_only: bool | None = None,
+    doc: str | None = None,
+) -> Any:
+    """Customize a field in a ``@py_class``-decorated class.
+
+    Returns a :class:`Field` sentinel whose *name* and *ty* are
+    ``None``. The ``@py_class`` decorator fills them in later
+    from the class annotations.
+
+    The return type is ``Any`` because ``dataclass_transform`` field
+    specifiers must be assignable to any annotated type (e.g.
+    ``x: int = field(default=0)``).
+
+    Parameters
+    ----------
+    default
+        Default value for the field. Mutually exclusive with *default_factory*.
+    default_factory
+        A zero-argument callable that produces the default value.
+        Mutually exclusive with *default*.
+    init
+        Whether this field appears in the auto-generated ``__init__``.
+    repr
+        Whether this field appears in ``__repr__`` output.
+    hash
+        Whether this field participates in recursive hashing.
+        ``None`` (default) means "follow *compare*".
+    compare
+        Whether this field participates in recursive comparison.
+    kw_only
+        Whether this field is keyword-only in ``__init__``.
+        ``None`` means "inherit from the decorator-level ``kw_only`` flag".
+    doc
+        Optional docstring for the field.
+
+    Returns
+    -------
+    Any
+        A :class:`Field` sentinel recognised by ``@py_class``.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        @py_class
+        class Point(Object):
+            x: float
+            y: float = field(default=0.0, repr=False)
+
+    """
+    return Field(
+        default=default,
+        default_factory=default_factory,
+        init=init,
+        repr=repr,
+        hash=hash,
+        compare=compare,
+        kw_only=kw_only,
+        doc=doc,
+    )
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 82540494..8a84bff1 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -356,16 +356,30 @@ def _make_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
"""
sig = _make_init_signature(type_info)
kwargs_obj = core.KWARGS
+ has_post_init = hasattr(type_cls, "__post_init__")
- def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
- ffi_args: list[Any] = list(args)
- ffi_args.append(kwargs_obj)
- for key, val in kwargs.items():
- ffi_args.append(key)
- ffi_args.append(val)
- self.__ffi_init__(*ffi_args)
-
- __init__.__signature__ = sig # ty: ignore[unresolved-attribute]
+ if has_post_init:
+
+ def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
+ ffi_args: list[Any] = list(args)
+ ffi_args.append(kwargs_obj)
+ for key, val in kwargs.items():
+ ffi_args.append(key)
+ ffi_args.append(val)
+ self.__ffi_init__(*ffi_args)
+ self.__post_init__()
+
+ else:
+
+ def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
+ ffi_args: list[Any] = list(args)
+ ffi_args.append(kwargs_obj)
+ for key, val in kwargs.items():
+ ffi_args.append(key)
+ ffi_args.append(val)
+ self.__ffi_init__(*ffi_args)
+
+ __init__.__signature__ = sig # ty: ignore[invalid-assignment]
__init__.__qualname__ = f"{type_cls.__qualname__}.__init__"
__init__.__module__ = type_cls.__module__
return __init__
@@ -666,7 +680,19 @@ def _install_dataclass_dunders(
dunders["__gt__"] = __gt__
dunders["__ge__"] = __ge__
+ # Install dunders respecting user-defined overrides.
+ # Semantic families (__eq__/__ne__, __lt__/__le__/__gt__/__ge__) are
+ # treated as a unit: if the user defines any member, the whole family
+ # is skipped so generated and user-defined methods don't disagree.
+ _eq_family = {"__eq__", "__ne__"}
+ _order_family = {"__lt__", "__le__", "__gt__", "__ge__"}
+ skip_eq = bool(_eq_family & set(cls.__dict__))
+ skip_order = bool(_order_family & set(cls.__dict__))
for name, impl in dunders.items():
+ if name in _eq_family and skip_eq:
+ continue
+ if name in _order_family and skip_order:
+ continue
if name not in cls.__dict__:
setattr(cls, name, impl)