This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new be9a827  fix(py_class): support `tvm_ffi.dtype` and `tvm_ffi.Device` as field type annotations (#540)
be9a827 is described below

commit be9a8274623252b48ac252961f5968255bf51114
Author: Junru Shao <[email protected]>
AuthorDate: Sun Apr 12 13:05:04 2026 -0700

    fix(py_class): support `tvm_ffi.dtype` and `tvm_ffi.Device` as field type annotations (#540)
    
    ## Summary
    
    - `TypeSchema.from_annotation()` in the Cython layer only recognized the
    C-level `DataType`/`Device` cdef classes as dtype/device annotations
    (via identity checks such as `annotation is DataType`). The public Python
    wrapper classes — `tvm_ffi.dtype` (`class dtype(str)` in `_dtype.py`) and
    `tvm_ffi.Device` — are distinct types, so using them as `@py_class` field
    type annotations raised `TypeError: Cannot convert <class '...'> to
    TypeSchema`.
    - The fix extends the identity checks to also match the Python wrapper
    classes via the existing module-level `_CLASS_DTYPE` / `_CLASS_DEVICE`
    class references, each guarded by an `is not None` check in case it is
    unset.
    - Added 6 regression tests covering dtype fields, Device fields,
    combined usage, setter mutation, and `Optional` variants.
    
    ## Test plan
    
    - [x] `uv run pytest -vvs tests/python/test_dataclass_py_class.py` — all
    342 tests pass (including 6 new `TestDtypeDeviceFields` tests)
    - [ ] CI: lint, C++ tests, full Python test suite, Rust tests
    
    🤖 Generated with [Claude Code](https://claude.com/claude-code)
---
 python/tvm_ffi/cython/type_info.pxi     |  4 +--
 tests/python/test_dataclass_py_class.py | 64 +++++++++++++++++++++++++++++++++
 2 files changed, 66 insertions(+), 2 deletions(-)

diff --git a/python/tvm_ffi/cython/type_info.pxi b/python/tvm_ffi/cython/type_info.pxi
index 67be0d8..118e067 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -404,9 +404,9 @@ class TypeSchema:
                 return TypeSchema("bytes")
 
         # --- Non-CObject cdef classes with known origins ---
-        if annotation is DataType:
+        if annotation is DataType or (_CLASS_DTYPE is not None and annotation 
is _CLASS_DTYPE):
             return TypeSchema("dtype")
-        if annotation is Device:
+        if annotation is Device or (_CLASS_DEVICE is not None and annotation 
is _CLASS_DEVICE):
             return TypeSchema("Device")
 
         # --- ctypes.c_void_p ---
diff --git a/tests/python/test_dataclass_py_class.py b/tests/python/test_dataclass_py_class.py
index 57a7c1c..a9c5358 100644
--- a/tests/python/test_dataclass_py_class.py
+++ b/tests/python/test_dataclass_py_class.py
@@ -4969,3 +4969,67 @@ class TestSuperInitPattern:
         assert obj3.x == 42
         assert obj3.y == "hello"
         assert not obj.same_as(obj3)
+
+
+class TestDtypeDeviceFields:
+    """Regression: @py_class should accept tvm_ffi.dtype and tvm_ffi.Device as 
field types."""
+
+    def test_dtype_field(self) -> None:
+        @py_class(_unique_key("DtypeField"))
+        class DtypeHolder(Object):
+            dt: tvm_ffi.dtype
+
+        obj = DtypeHolder(dt=tvm_ffi.dtype("float32"))
+        assert obj.dt == "float32"
+        assert isinstance(obj.dt, tvm_ffi.dtype)
+
+    def test_dtype_field_setter(self) -> None:
+        @py_class(_unique_key("DtypeFieldSet"))
+        class DtypeHolder2(Object):
+            dt: tvm_ffi.dtype
+
+        obj = DtypeHolder2(dt=tvm_ffi.dtype("float32"))
+        obj.dt = tvm_ffi.dtype("int8")
+        assert obj.dt == "int8"
+
+    def test_device_field(self) -> None:
+        @py_class(_unique_key("DeviceField"))
+        class DeviceHolder(Object):
+            dev: tvm_ffi.Device
+
+        dev = tvm_ffi.device("cpu", 0)
+        obj = DeviceHolder(dev=dev)
+        assert obj.dev == dev
+
+    def test_dtype_device_together(self) -> None:
+        @py_class(_unique_key("DtypeDeviceTogether"))
+        class DtypeDeviceHolder(Object):
+            dt: tvm_ffi.dtype
+            dev: tvm_ffi.Device
+            name: str
+
+        dev = tvm_ffi.device("cpu", 0)
+        obj = DtypeDeviceHolder(dt=tvm_ffi.dtype("float16"), dev=dev, 
name="test")
+        assert obj.dt == "float16"
+        assert obj.dev == dev
+        assert obj.name == "test"
+
+    def test_optional_dtype_field(self) -> None:
+        @py_class(_unique_key("OptDtype"))
+        class OptDtype(Object):
+            dt: Optional[tvm_ffi.dtype] = None
+
+        obj_none = OptDtype()
+        assert obj_none.dt is None
+        obj_val = OptDtype(dt=tvm_ffi.dtype("bfloat16"))
+        assert obj_val.dt == "bfloat16"
+
+    def test_optional_device_field(self) -> None:
+        @py_class(_unique_key("OptDevice"))
+        class OptDevice(Object):
+            dev: Optional[tvm_ffi.Device] = None
+
+        obj_none = OptDevice()
+        assert obj_none.dev is None
+        obj_val = OptDevice(dev=tvm_ffi.device("cpu", 0))
+        assert obj_val.dev == tvm_ffi.device("cpu", 0)

Reply via email to