This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 6a61156c feat(python): add field registration, _register_py_class, and Field descriptor (#505)
6a61156c is described below

commit 6a61156cb6780101d633a72a5151b65602f8b7e2
Author: Junru Shao <[email protected]>
AuthorDate: Fri Mar 20 12:05:12 2026 -0700

    feat(python): add field registration, _register_py_class, and Field descriptor (#505)
    
    ## Summary
    
    - Implement the Python-side type registration pipeline for
    `@py_class`-decorated classes: allocates dynamic type indices, computes
    native field layouts, registers getters/setters with type-converting
    callbacks, and wires up `__ffi_new__`/`__ffi_init__` for automatic
    construction.
    - Add `Field` descriptor and `field()` factory function mirroring stdlib
    `dataclasses.field()` API.
    - Fix dunder family-skip logic so user-defined `__eq__` suppresses
    generated `__ne__` (and vice versa), same for ordering operators.
    
    ## Architecture
    
    Three-layer implementation:
    - **Cython `object.pxi`**: `_register_py_class` (dynamic type index
    allocation, ancestor chain, TypeInfo insertion) and `_rollback_py_class`
    (cleanup when phase 2, field validation, fails).
    - **Cython `type_info.pxi`**: Field registration engine — computes
    per-field native layout from `parent_type_info.total_size`, obtains C
    getter/setter function pointers, and registers via
    `TVMFFITypeRegisterField`. Installs `MakeFFINew`/`RegisterAutoInit`.
    - **Python `field.py`**: Pure-Python `Field` descriptor with `__slots__`
    matching stdlib `dataclasses.field()` signature, plus `KW_ONLY` sentinel
    (3.9-compat).
    
    ## Behavioral Changes
    
    - `__post_init__` is now called after `__ffi_init__` if defined
    (previously silently ignored).
    - Dunder family-skip: defining any member of `{__eq__, __ne__}` or
    `{__lt__, __le__, __gt__, __ge__}` suppresses the entire family's
    auto-generation.
    
    ## Test plan
    
    - [x] Python tests pass: `uv run pytest -vvs tests/python`
    - [ ] Full `@py_class` end-to-end tests (in follow-up commits)
    - [x] Pre-commit lints pass
---
 python/tvm_ffi/cython/object.pxi    |  88 +++++++++++
 python/tvm_ffi/cython/type_info.pxi | 291 ++++++++++++++++++++++++++++++++++++
 python/tvm_ffi/dataclasses/field.py | 204 +++++++++++++++++++++++++
 python/tvm_ffi/registry.py          |  44 ++++--
 4 files changed, 618 insertions(+), 9 deletions(-)

diff --git a/python/tvm_ffi/cython/object.pxi b/python/tvm_ffi/cython/object.pxi
index 489837a5..0f4fdabe 100644
--- a/python/tvm_ffi/cython/object.pxi
+++ b/python/tvm_ffi/cython/object.pxi
@@ -573,6 +573,94 @@ def _lookup_or_register_type_info_from_type_key(type_key: str) -> TypeInfo:
     return info
 
 
+def _register_py_class(parent_type_info, str type_key, object type_cls):
+    """Allocate a dynamic type index and register a Python-defined type.
+
+    The new type is created as a direct child of *parent_type_info*.  Its
+    TypeInfo is recorded in the Python-side lookup tables right away, but
+    its ``fields`` remain ``None`` until ``TypeInfo._register_fields`` runs.
+
+    Parameters
+    ----------
+    parent_type_info : TypeInfo
+        TypeInfo of the parent type (e.g. the TypeInfo of ``Object``).
+    type_key : str
+        Unique type key for the new type.
+    type_cls : type
+        Python class to associate with the new type.
+
+    Returns
+    -------
+    TypeInfo
+        The new TypeInfo, with ``fields=None`` and ``methods=[]``.
+
+    Raises
+    ------
+    ValueError
+        If *type_key* is already present in ``TYPE_KEY_TO_INFO``.
+    """
+    cdef int32_t parent_index
+    cdef int32_t new_depth
+    cdef int32_t new_index
+    cdef ByteArrayArg key_arg
+    cdef list ancestor_chain
+    cdef object new_info
+
+    # Only the Python-side table is consulted.  NOTE(review): a key known
+    # solely to the C registry (not yet mirrored into TYPE_KEY_TO_INFO) is
+    # not detected here -- confirm that cannot happen for @py_class types.
+    if type_key in TYPE_KEY_TO_INFO:
+        raise ValueError(
+            f"Type key '{type_key}' is already registered"
+        )
+
+    parent_index = parent_type_info.type_index
+    # The new type sits one level below its parent.
+    new_depth = len(parent_type_info.type_ancestors) + 1
+    key_arg = ByteArrayArg(c_str(type_key))
+
+    new_index = TVMFFITypeGetOrAllocIndex(
+        key_arg.cptr(),
+        -1,            # static_type_index: -1 requests dynamic allocation
+        new_depth,     # type_depth
+        0,             # num_child_slots
+        1,             # child_slots_can_overflow
+        parent_index,  # parent_type_index
+    )
+
+    # Ancestor chain = parent's ancestors followed by the parent itself.
+    ancestor_chain = list(parent_type_info.type_ancestors) + [parent_index]
+
+    new_info = TypeInfo(
+        type_cls=type_cls,
+        type_index=new_index,
+        type_key=type_key,
+        type_ancestors=ancestor_chain,
+        fields=None,  # set later by TypeInfo._register_fields
+        methods=[],
+        parent_type_info=parent_type_info,
+    )
+
+    # Record the new type in the Python-side lookup tables.
+    _update_registry(new_index, type_key, new_info, type_cls)
+    return new_info
+
+
+def _rollback_py_class(object type_info):
+    """Undo the Python-side bookkeeping of a ``_register_py_class`` call.
+
+    ``@py_class`` calls this when phase 2 (field validation) fails, so the
+    same type key can be registered again once the user fixes the error.
+    The C-level type index cannot be reclaimed and stays consumed.  Only
+    the Python tables are cleaned, so a retry does not hit "already
+    registered".
+
+    Parameters
+    ----------
+    type_info : TypeInfo
+        The TypeInfo returned by the failed ``_register_py_class`` call.
+    """
+    cdef int32_t index = type_info.type_index
+    cdef str key = type_info.type_key
+    cdef object owner_cls = type_info.type_cls
+
+    # Forget key -> info (no-op if already absent).
+    TYPE_KEY_TO_INFO.pop(key, None)
+    # Forget class -> info, when a class was attached.
+    if owner_cls is not None:
+        TYPE_CLS_TO_INFO.pop(owner_cls, None)
+    # Clear the per-index slots without shrinking the tables; the index
+    # itself remains allocated on the C side.
+    if not (0 <= index < len(TYPE_INDEX_TO_INFO)):
+        return
+    TYPE_INDEX_TO_INFO[index] = None
+    TYPE_INDEX_TO_CLS[index] = None
+
+
 def _lookup_type_attr(type_index: int32_t, attr_key: str) -> Any:
     cdef ByteArrayArg attr_key_bytes = ByteArrayArg(c_str(attr_key))
     cdef const TVMFFITypeAttrColumn* column = 
TVMFFIGetTypeAttrColumn(&attr_key_bytes.cdata)
diff --git a/python/tvm_ffi/cython/type_info.pxi b/python/tvm_ffi/cython/type_info.pxi
index 024ba268..7a0b8c5e 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -704,6 +704,297 @@ class TypeInfo:
                     end = f_end
         return (end + 7) & ~7  # align to 8 bytes
 
+    def _register_fields(self, fields):
+        """Register *fields* for this type, then refresh ``self.methods``.
+
+        This is a thin wrapper around the module-level ``_register_fields``.
+        The returned ``list[TypeField]`` is stored on ``self.fields``.  The
+        method table is then re-read from C via ``_register_methods``, so
+        methods installed during field registration (such as
+        ``__ffi_init__``) become visible.
+
+        May be called at most once per TypeInfo: ``self.fields`` must still
+        be ``None`` on entry (checked with ``assert``).
+        """
+        assert self.fields is None, (
+            f"_register_fields already called for {self.type_key!r}"
+        )
+        registered = _register_fields(self, fields)
+        self.fields = registered
+        self._register_methods()
+
+    def _register_methods(self):
+        """Rebuild ``self.methods`` from the C-level type table.
+
+        Wraps every TVMFFIMethodInfo registered for ``self.type_index`` in a
+        ``TypeMethod``.  This includes methods installed on the C++ side,
+        e.g. ``__ffi_init__``.  Previous contents of ``self.methods`` are
+        discarded.
+        """
+        cdef const TVMFFITypeInfo* c_info = TVMFFIGetTypeInfo(self.type_index)
+        cdef const TVMFFIMethodInfo* method_info
+        collected = []
+        self.methods = collected
+        for k in range(c_info.num_methods):
+            method_info = &(c_info.methods[k])
+            name = bytearray_to_str(&method_info.name)
+            # Empty doc / metadata byte arrays mean "not provided".
+            doc = (
+                bytearray_to_str(&method_info.doc)
+                if method_info.doc.size != 0
+                else None
+            )
+            func = _get_method_from_method_info(method_info)
+            is_static = (method_info.flags & kTVMFFIFieldFlagBitMaskIsStaticMethod) != 0
+            if method_info.metadata.size != 0:
+                metadata = json.loads(bytearray_to_str(&method_info.metadata))
+            else:
+                metadata = {}
+            collected.append(TypeMethod(
+                name=name,
+                doc=doc,
+                func=func,
+                is_static=is_static,
+                metadata=metadata,
+            ))
+
+
+# ---------------------------------------------------------------------------
+# Python-defined type field registration helpers
+# ---------------------------------------------------------------------------
+
+# Native storage layout keyed by TypeSchema origin:
+#     origin -> (size_in_bytes, alignment_in_bytes, field_static_type_index)
+# Fixed-size scalars are stored inline in the object.
+_ORIGIN_NATIVE_LAYOUT = {
+    "int": (8, 8, kTVMFFIInt),
+    "float": (8, 8, kTVMFFIFloat),
+    "bool": (1, 1, kTVMFFIBool),
+    "ctypes.c_void_p": (8, 8, kTVMFFIOpaquePtr),
+    "dtype": (4, 2, kTVMFFIDataType),
+    "Device": (8, 4, kTVMFFIDevice),
+}
+# The remaining origins use a 16-byte, 8-aligned TVMFFIAny slot with static
+# type index -1 (kTVMFFIAny).  str/bytes may be SmallStr/SmallBytes (inline,
+# not an ObjectRef), and Optional/Union may hold inline scalars, so they
+# need the full Any storage rather than an object-pointer slot.
+_ORIGIN_NATIVE_LAYOUT.update(
+    (_origin, (16, 8, -1))
+    for _origin in ("Any", "str", "bytes", "Optional", "Union")
+)
+
+cdef _register_one_field(
+    int32_t type_index,
+    object py_field,
+    int64_t offset,
+    int64_t size,
+    int64_t alignment,
+    int32_t field_type_index,
+    TVMFFIFieldGetter getter,
+    CObject setter_fn,
+):
+    """Build a TVMFFIFieldInfo for *py_field* and register it on *type_index*.
+
+    Parameters
+    ----------
+    type_index : int32_t
+        Type index of the type that owns the field.
+    py_field : Field
+        Field descriptor; ``name`` and ``ty`` must already be set.
+    offset, size, alignment : int64_t
+        Native layout of the field within the object.
+    field_type_index : int32_t
+        Value stored as ``field_static_type_index`` (-1 means Any storage).
+    getter : TVMFFIFieldGetter
+        C getter function pointer for the field.
+    setter_fn : CObject
+        Function object used as the setter.  Its handle goes into
+        ``info.setter``, together with
+        ``kTVMFFIFieldFlagBitSetterIsFunctionObj``.
+
+    Notes
+    -----
+    When ``py_field.hash`` is ``None`` it is resolved to
+    ``py_field.compare`` ("follow compare", as documented for
+    ``Field``/``field()``).  It is not read as falsy.  Reading it as falsy
+    would turn hashing off for every field left at ``field()``'s default
+    ``hash=None`` whenever the caller has not already resolved it.
+
+    Raises whatever ``CHECK_CALL`` raises if converting the default value
+    or registering the field fails.
+    """
+    cdef TVMFFIFieldInfo info
+    cdef int c_api_ret_code
+
+    # --- name ---
+    # The ByteArrayArg locals below (name/doc/metadata) live until the
+    # function returns, i.e. past TVMFFITypeRegisterField.  ``info`` only
+    # copies their ``cdata`` (presumably a borrowed view -- TODO confirm).
+    name_bytes = c_str(py_field.name)
+    cdef ByteArrayArg name_arg = ByteArrayArg(name_bytes)
+    info.name = name_arg.cdata
+
+    # --- doc ---
+    cdef ByteArrayArg doc_arg
+    if py_field.doc is not None:
+        doc_bytes = c_str(py_field.doc)
+        doc_arg = ByteArrayArg(doc_bytes)
+        info.doc = doc_arg.cdata
+    else:
+        info.doc.data = NULL
+        info.doc.size = 0
+
+    # --- metadata (JSON with type_schema) ---
+    metadata_str = json.dumps({"type_schema": py_field.ty.to_json()})
+    metadata_bytes = c_str(metadata_str)
+    cdef ByteArrayArg metadata_arg = ByteArrayArg(metadata_bytes)
+    info.metadata = metadata_arg.cdata
+
+    # --- flags ---
+    cdef int64_t flags = kTVMFFIFieldFlagBitMaskWritable | kTVMFFIFieldFlagBitSetterIsFunctionObj
+    if py_field.default is not MISSING or py_field.default_factory is not MISSING:
+        flags |= kTVMFFIFieldFlagBitMaskHasDefault
+    if py_field.default_factory is not MISSING:
+        flags |= kTVMFFIFieldFlagBitMaskDefaultFromFactory
+    if not py_field.init:
+        flags |= kTVMFFIFieldFlagBitMaskInitOff
+    if not py_field.repr:
+        flags |= kTVMFFIFieldFlagBitMaskReprOff
+    # ``hash=None`` means "follow compare"; resolve it explicitly so that
+    # None is not misread as "hash off".
+    hash_enabled = py_field.compare if py_field.hash is None else py_field.hash
+    if not hash_enabled:
+        flags |= kTVMFFIFieldFlagBitMaskHashOff
+    if not py_field.compare:
+        flags |= kTVMFFIFieldFlagBitMaskCompareOff
+    # An unresolved ``kw_only=None`` ("inherit from decorator") is treated
+    # as not keyword-only; presumably @py_class resolves it beforehand.
+    if py_field.kw_only:
+        flags |= kTVMFFIFieldFlagBitMaskKwOnly
+    info.flags = flags
+
+    # --- native layout ---
+    info.size = size
+    info.alignment = alignment
+    info.offset = offset
+
+    # --- getter / setter ---
+    info.getter = getter
+    info.setter = <void*>setter_fn.chandle
+
+    # --- default value ---
+    cdef TVMFFIAny default_any
+    default_any.type_index = kTVMFFINone
+    default_any.v_int64 = 0
+    # Pick the Python object (if any) to store as the default.  A factory
+    # is stored as-is; kTVMFFIFieldFlagBitMaskDefaultFromFactory (above)
+    # marks that it must be called.  TVMFFIAny is a POD struct, and
+    # TVMFFITypeRegisterField is expected to copy the bytes into the type
+    # table and take over the reference.  NOTE(review): if registration
+    # fails, that reference is not released here.
+    cdef object default_obj = MISSING
+    if py_field.default is not MISSING:
+        default_obj = py_field.default
+    elif py_field.default_factory is not MISSING:
+        default_obj = py_field.default_factory
+    if default_obj is not MISSING:
+        TVMFFIPyPyObjectToFFIAny(
+            TVMFFIPyArgSetterFactory_,
+            <PyObject*>default_obj,
+            &default_any,
+            &c_api_ret_code
+        )
+        CHECK_CALL(c_api_ret_code)
+    info.default_value_or_factory = default_any
+
+    # --- field_static_type_index ---
+    info.field_static_type_index = field_type_index
+
+    CHECK_CALL(TVMFFITypeRegisterField(type_index, &info))
+
+
+cdef int _f_type_convert(void* type_converter, const TVMFFIAny* value, TVMFFIAny* result) noexcept with gil:
+    """C callback for type conversion, called from C++ MakeFieldSetter.
+
+    Parameters
+    ----------
+    type_converter : void*
+        A PyObject* pointing to a _TypeConverter instance (borrowed reference).
+    value : const TVMFFIAny*
+        The packed value to convert (borrowed from the caller).
+    result : TVMFFIAny*
+        Output: the converted value (caller takes ownership).
+
+    Returns 0 on success, -1 on error (error stored in TLS via set_last_ffi_error).
+
+    The function is declared ``noexcept``, so Python exceptions must not
+    escape into C++.  Exceptions are caught below and reported through the
+    return code.
+    NOTE(review): only ``Exception`` subclasses are caught.  A
+    BaseException (e.g. KeyboardInterrupt) would bypass the handler, and
+    with ``noexcept`` Cython reports it as unraisable and returns a default
+    value -- verify this cannot be mistaken for success with ``result``
+    left unset.
+    """
+    cdef TVMFFIAny temp
+    cdef _TypeConverter conv
+    cdef CAny cany
+    try:
+        # Unpack the packed AnyView to a Python object.
+        # We must IncRef if it's an object, because make_ret takes ownership.
+        # ``temp`` is a by-value copy of the caller's view; the caller's
+        # TVMFFIAny itself is not modified.
+        temp = value[0]
+        if temp.type_index >= kTVMFFIStaticObjectBegin:
+            if temp.v_obj != NULL:
+                TVMFFIObjectIncRef(<TVMFFIObjectHandle>temp.v_obj)
+        py_value = make_ret(temp)
+        # Dispatch directly through the C-level converter.
+        # ``type_converter`` is borrowed: the caller must keep the
+        # _TypeConverter alive for the duration of this call.
+        conv = <_TypeConverter>type_converter
+        cany = _type_convert_impl(conv, py_value)
+        # Transfer ownership from CAny to result (zero cany to prevent double-free):
+        # after resetting ``cany.cdata`` to None, CAny's cleanup presumably
+        # has nothing left to release.
+        result[0] = cany.cdata
+        cany.cdata.type_index = kTVMFFINone
+        cany.cdata.v_int64 = 0
+        return 0
+    except Exception as err:
+        # Record the error for the C++ caller and signal failure.
+        set_last_ffi_error(err)
+        return -1
+
+
+def _register_fields(type_info, fields):
+    """Register *fields* on a Python-defined type and install its constructor.
+
+    Each Field is processed in order:
+
+    1. Look up its native layout (size, alignment, static type index) in
+       ``_ORIGIN_NATIVE_LAYOUT`` by ``ty.origin``.  Unknown origins fall
+       back to an 8-byte, 8-aligned ``kTVMFFIObject`` slot.
+    2. Place it at the next suitably aligned offset.  Offsets start after
+       all inherited fields (``parent_type_info.total_size``).
+    3. Obtain a C getter pointer (``ffi.GetFieldGetter``) and a FunctionObj
+       setter wired to ``_f_type_convert`` (``ffi.MakeFieldSetter``).
+    4. Register the field in the C type table via ``_register_one_field``.
+    5. Build the matching Python-side ``TypeField``.
+
+    Afterwards the instance size is rounded up to 8 bytes, and to at least
+    ``sizeof(TVMFFIObject)``.  ``__ffi_new__`` is then installed via
+    ``ffi.MakeFFINew`` and ``__ffi_init__`` via ``ffi.RegisterAutoInit``.
+
+    Parameters
+    ----------
+    type_info : TypeInfo
+        The TypeInfo of the type being defined.
+    fields : list[Field]
+        The Field descriptors to register.
+
+    Returns
+    -------
+    list[TypeField]
+        The registered field descriptors, in input order.
+    """
+    cdef int32_t type_index = type_info.type_index
+    # Child fields start after every inherited field, so they never overlap
+    # parent storage.
+    cdef int64_t current_offset = type_info.parent_type_info.total_size
+    cdef int64_t size, alignment
+    cdef int64_t field_offset
+    cdef int32_t field_type_index
+    cdef TVMFFIFieldGetter c_getter
+    cdef FieldGetter field_getter
+    cdef FieldSetter field_setter
+    cdef int64_t setter_flags = kTVMFFIFieldFlagBitMaskWritable | kTVMFFIFieldFlagBitSetterIsFunctionObj
+    cdef int64_t total_size
+
+    # Runtime global functions used below.
+    get_field_getter = _get_global_func("ffi.GetFieldGetter", False)
+    make_field_setter = _get_global_func("ffi.MakeFieldSetter", False)
+    make_ffi_new = _get_global_func("ffi.MakeFFINew", False)
+    register_auto_init = _get_global_func("ffi.RegisterAutoInit", False)
+
+    cdef list type_fields = []
+
+    for py_field in fields:
+        # Step 1: layout.
+        size, alignment, field_type_index = _ORIGIN_NATIVE_LAYOUT.get(
+            py_field.ty.origin, (8, 8, kTVMFFIObject)
+        )
+
+        # Step 2: round up to the (power-of-two) alignment, then reserve.
+        field_offset = (current_offset + alignment - 1) & ~(alignment - 1)
+        current_offset = field_offset + size
+
+        # Step 3: getter/setter.  Pointers cross the FFI boundary as int64_t.
+        c_getter = <TVMFFIFieldGetter><int64_t>get_field_getter(field_type_index)
+        # NOTE(review): the converter is passed as a borrowed pointer.  It is
+        # kept alive through ``py_field.ty`` (stored on the TypeField below),
+        # assuming ``ty._converter`` is never rebound afterwards.
+        setter_fn = <CObject>make_field_setter(
+            field_type_index,
+            <int64_t><void*>py_field.ty._converter,
+            <int64_t>&_f_type_convert,
+        )
+
+        # Step 4: register in the C type table.
+        _register_one_field(
+            type_index, py_field, field_offset, size, alignment,
+            field_type_index, c_getter, setter_fn,
+        )
+
+        # Step 5: Python-side TypeField descriptor.
+        field_getter = FieldGetter.__new__(FieldGetter)
+        field_getter.getter = c_getter
+        field_getter.offset = field_offset
+        field_setter = FieldSetter.__new__(FieldSetter)
+        field_setter.setter = <void*>setter_fn.chandle
+        field_setter.offset = field_offset
+        field_setter.flags = setter_flags
+        has_default = (
+            py_field.default is not MISSING
+            or py_field.default_factory is not MISSING
+        )
+        type_fields.append(
+            TypeField(
+                name=py_field.name,
+                doc=py_field.doc,
+                size=size,
+                offset=field_offset,
+                frozen=False,
+                metadata={"type_schema": py_field.ty.to_json()},
+                getter=field_getter,
+                setter=field_setter,
+                ty=py_field.ty,
+                c_init=py_field.init,
+                c_kw_only=py_field.kw_only,
+                c_has_default=has_default,
+            )
+        )
+
+    # Round the instance size up to 8 bytes, never below sizeof(TVMFFIObject).
+    total_size = (current_offset + 7) & ~7
+    if total_size < sizeof(TVMFFIObject):
+        total_size = sizeof(TVMFFIObject)
+
+    # Step 6: install __ffi_new__ (allocator + deleter).
+    make_ffi_new(type_index, total_size)
+    # Step 7: install the auto-generated __ffi_init__ constructor.
+    register_auto_init(type_index)
+
+    return type_fields
+
 
 def _member_method_wrapper(method_func: Callable[..., Any]) -> Callable[..., 
Any]:
     def wrapper(self: Any, *args: Any) -> Any:
diff --git a/python/tvm_ffi/dataclasses/field.py b/python/tvm_ffi/dataclasses/field.py
new file mode 100644
index 00000000..9295f13a
--- /dev/null
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -0,0 +1,204 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Field descriptor and ``field()`` helper for Python-defined TVM-FFI types."""
+
+from __future__ import annotations
+
+import sys
+from collections.abc import Callable
+from typing import Any
+
+from ..core import MISSING, TypeSchema
+
+# Re-export the stdlib KW_ONLY sentinel so type checkers recognise
+# ``_: KW_ONLY`` as a keyword-only boundary rather than a real field.
+# dataclasses.KW_ONLY was added in Python 3.10; on older runtimes we
+# define a class sentinel (a class, not an instance, so that ``_: KW_ONLY``
+# is a valid type annotation for static analysers targeting 3.9).
+# The explicit ``sys.version_info`` comparison (rather than a
+# try/except ImportError) is deliberate: static type checkers narrow on it.
+if sys.version_info >= (3, 10):
+    from dataclasses import KW_ONLY
+else:
+
+    class KW_ONLY:
+        """Sentinel type: annotations after ``_: KW_ONLY`` are keyword-only.
+
+        Fallback for Python < 3.10, where ``dataclasses.KW_ONLY`` is absent.
+        """
+
+
+class Field:
+    """Descriptor for a single field in a Python-defined TVM-FFI type.
+
+    When constructed directly (low-level API), *name* and *ty* should be
+    provided.  When returned by :func:`field` (``@py_class`` workflow),
+    *name* and *ty* are ``None`` and filled in by the decorator.
+
+    Parameters
+    ----------
+    name : str | None
+        The field name.  ``None`` when created via :func:`field`; filled
+        in by the ``@py_class`` decorator.
+    ty : TypeSchema | None
+        The type schema.  ``None`` when created via :func:`field`; filled
+        in by the ``@py_class`` decorator.
+    default : object
+        Default value for the field. Mutually exclusive with *default_factory*.
+        ``MISSING`` when not set.
+    default_factory : Callable[[], object] | None
+        A zero-argument callable that produces the default value.
+        Mutually exclusive with *default*.  ``None`` when not set.
+    init : bool
+        Whether this field appears in the auto-generated ``__init__``.
+    repr : bool
+        Whether this field appears in ``__repr__`` output.
+    hash : bool | None
+        Whether this field participates in recursive hashing.
+        ``None`` means "follow *compare*" (the native dataclass default).
+    compare : bool
+        Whether this field participates in recursive comparison.
+    kw_only : bool | None
+        Whether this field is keyword-only in ``__init__``.
+        ``None`` means "inherit from the decorator-level *kw_only* flag".
+    doc : str | None
+        Optional docstring for the field.
+
+    """
+
+    __slots__ = (
+        "compare",
+        "default",
+        "default_factory",
+        "doc",
+        "hash",
+        "init",
+        "kw_only",
+        "name",
+        "repr",
+        "ty",
+    )
+    name: str | None
+    ty: TypeSchema | None
+    default: object
+    default_factory: Callable[[], object] | None
+    init: bool
+    repr: bool
+    hash: bool | None
+    compare: bool
+    kw_only: bool | None
+    doc: str | None
+
+    def __init__(
+        self,
+        name: str | None = None,
+        ty: TypeSchema | None = None,
+        *,
+        default: object = MISSING,
+        default_factory: Callable[[], object] | None = MISSING,  # type: 
ignore[assignment]
+        init: bool = True,
+        repr: bool = True,
+        hash: bool | None = True,
+        compare: bool = False,
+        kw_only: bool | None = False,
+        doc: str | None = None,
+    ) -> None:
+        # MISSING means "parameter not provided".
+        # An explicit None from the user fails the callable() check,
+        # matching stdlib dataclasses semantics.
+        if default_factory is not MISSING:
+            if default is not MISSING:
+                raise ValueError("cannot specify both default and 
default_factory")
+            if not callable(default_factory):
+                raise TypeError(
+                    f"default_factory must be a callable, got 
{type(default_factory).__name__}"
+                )
+        self.name = name
+        self.ty = ty
+        self.default = default
+        self.default_factory = default_factory
+        self.init = init
+        self.repr = repr
+        self.hash = hash
+        self.compare = compare
+        self.kw_only = kw_only
+        self.doc = doc
+
+
+def field(
+    *,
+    default: object = MISSING,
+    default_factory: Callable[[], object] | None = MISSING,  # type: 
ignore[assignment]
+    init: bool = True,
+    repr: bool = True,
+    hash: bool | None = None,
+    compare: bool = True,
+    kw_only: bool | None = None,
+    doc: str | None = None,
+) -> Any:
+    """Customize a field in a ``@py_class``-decorated class.
+
+    Returns a :class:`Field` sentinel whose *name* and *ty* are
+    ``None``.  The ``@py_class`` decorator fills them in later
+    from the class annotations.
+
+    The return type is ``Any`` because ``dataclass_transform`` field
+    specifiers must be assignable to any annotated type (e.g.
+    ``x: int = field(default=0)``).
+
+    Parameters
+    ----------
+    default
+        Default value for the field.  Mutually exclusive with 
*default_factory*.
+    default_factory
+        A zero-argument callable that produces the default value.
+        Mutually exclusive with *default*.
+    init
+        Whether this field appears in the auto-generated ``__init__``.
+    repr
+        Whether this field appears in ``__repr__`` output.
+    hash
+        Whether this field participates in recursive hashing.
+        ``None`` (default) means "follow *compare*".
+    compare
+        Whether this field participates in recursive comparison.
+    kw_only
+        Whether this field is keyword-only in ``__init__``.
+        ``None`` means "inherit from the decorator-level ``kw_only`` flag".
+    doc
+        Optional docstring for the field.
+
+    Returns
+    -------
+    Any
+        A :class:`Field` sentinel recognised by ``@py_class``.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        @py_class
+        class Point(Object):
+            x: float
+            y: float = field(default=0.0, repr=False)
+
+    """
+    return Field(
+        default=default,
+        default_factory=default_factory,
+        init=init,
+        repr=repr,
+        hash=hash,
+        compare=compare,
+        kw_only=kw_only,
+        doc=doc,
+    )
diff --git a/python/tvm_ffi/registry.py b/python/tvm_ffi/registry.py
index 82540494..8a84bff1 100644
--- a/python/tvm_ffi/registry.py
+++ b/python/tvm_ffi/registry.py
@@ -356,16 +356,30 @@ def _make_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
     """
     sig = _make_init_signature(type_info)
     kwargs_obj = core.KWARGS
+    has_post_init = hasattr(type_cls, "__post_init__")
 
-    def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
-        ffi_args: list[Any] = list(args)
-        ffi_args.append(kwargs_obj)
-        for key, val in kwargs.items():
-            ffi_args.append(key)
-            ffi_args.append(val)
-        self.__ffi_init__(*ffi_args)
-
-    __init__.__signature__ = sig  # ty: ignore[unresolved-attribute]
+    if has_post_init:
+
+        def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
+            ffi_args: list[Any] = list(args)
+            ffi_args.append(kwargs_obj)
+            for key, val in kwargs.items():
+                ffi_args.append(key)
+                ffi_args.append(val)
+            self.__ffi_init__(*ffi_args)
+            self.__post_init__()
+
+    else:
+
+        def __init__(self: Any, *args: Any, **kwargs: Any) -> None:
+            ffi_args: list[Any] = list(args)
+            ffi_args.append(kwargs_obj)
+            for key, val in kwargs.items():
+                ffi_args.append(key)
+                ffi_args.append(val)
+            self.__ffi_init__(*ffi_args)
+
+    __init__.__signature__ = sig  # ty: ignore[invalid-assignment]
     __init__.__qualname__ = f"{type_cls.__qualname__}.__init__"
     __init__.__module__ = type_cls.__module__
     return __init__
@@ -666,7 +680,19 @@ def _install_dataclass_dunders(
         dunders["__gt__"] = __gt__
         dunders["__ge__"] = __ge__
 
+    # Install dunders respecting user-defined overrides.
+    # Semantic families (__eq__/__ne__, __lt__/__le__/__gt__/__ge__) are
+    # treated as a unit: if the user defines any member, the whole family
+    # is skipped so generated and user-defined methods don't disagree.
+    _eq_family = {"__eq__", "__ne__"}
+    _order_family = {"__lt__", "__le__", "__gt__", "__ge__"}
+    skip_eq = bool(_eq_family & set(cls.__dict__))
+    skip_order = bool(_order_family & set(cls.__dict__))
     for name, impl in dunders.items():
+        if name in _eq_family and skip_eq:
+            continue
+        if name in _order_family and skip_order:
+            continue
         if name not in cls.__dict__:
             setattr(cls, name, impl)
 

Reply via email to