This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new ef511d4 feat(py_class): support `frozen=True` for immutable instances
(#542)
ef511d4 is described below
commit ef511d41402fed19927f663da90315b18db20cc8
Author: Junru Shao <[email protected]>
AuthorDate: Mon Apr 13 01:00:22 2026 -0700
feat(py_class): support `frozen=True` for immutable instances (#542)
## Summary
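- Add a `frozen` parameter to the `@py_class()` decorator and the `field()`
function, modeled on Python's `dataclasses.dataclass(frozen=True)`.
Immutability is enforced by the generated `FFIProperty` descriptors; no
`__setattr__`/`__delattr__` override is installed.
- **Class-level frozen**: `@py_class(frozen=True)` marks every field the
class itself declares as frozen. Each such field's descriptor has no public
setter, so assigning or deleting the attribute raises `AttributeError`.
Inherited fields keep the mutability of the class that declared them.
`frozen=True` does not add `__hash__` by itself; pass `eq=True,
unsafe_hash=True` for hashable instances.
- **Field-level frozen**: `field(frozen=True)` sets `TypeField.frozen=True`,
so the property's public setter (`fset`) is `None`. This works on
non-frozen classes. On a frozen class every own field is frozen regardless
of its `field(frozen=...)` argument.
- **Escape hatch**: `FFIProperty.set(obj, value)`, e.g.
`type(obj).x.set(obj, value)`, writes a field directly and bypasses the
frozen guard.
- `__replace__` (`copy.replace`) now writes each field through
`FFIProperty.set()`, bypassing the frozen guard for both `py_class` and
`c_class` types. Replacing a read-only C++ field (e.g. `TestIntPair.a`)
therefore now succeeds instead of raising `AttributeError`.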
## Test plan
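- [x] 51 new tests in `tests/python/test_dataclass_frozen.py` covering:
- Class-level frozen basic behavior (assignment and `del` raise `AttributeError`)
- Frozen fields of various types (optional, object, list, dict) and with defaults
- Field-level frozen (individual field override on a mutable class)
- Inheritance (frozen parent + mutable child, mixed, three levels)
- Interactions with `eq`, `hash`, `copy`, `deepcopy`, `replace`, `kw_only`, and `__post_init__`
- `type(obj).field.set(obj, value)` escape hatch
- `FFIProperty` descriptor checks (`fset` is `None` only for frozen fields)
- [x] Updated `tests/python/test_dataclass_copy.py`: `__replace__` on a read-only C++ field now succeeds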
- [x] Full test suite: 2105 passed, 0 failed, 38 skipped, 3 xfailed
---
python/tvm_ffi/_dunder.py | 3 +-
python/tvm_ffi/cython/type_info.pxi | 27 +-
python/tvm_ffi/dataclasses/field.py | 13 +
python/tvm_ffi/dataclasses/py_class.py | 18 +-
tests/python/test_cubin_launcher.py | 2 +-
tests/python/test_dataclass_copy.py | 10 +-
tests/python/test_dataclass_frozen.py | 650 +++++++++++++++++++++++++++++++++
tests/python/test_dtype.py | 2 +-
tests/python/test_error.py | 2 +-
tests/python/test_load_inline.py | 2 +-
tests/python/utils/test_embed_cubin.py | 6 +-
11 files changed, 717 insertions(+), 18 deletions(-)
diff --git a/python/tvm_ffi/_dunder.py b/python/tvm_ffi/_dunder.py
index 486666d..e15f394 100644
--- a/python/tvm_ffi/_dunder.py
+++ b/python/tvm_ffi/_dunder.py
@@ -193,8 +193,9 @@ def _make_replace(_type_info: TypeInfo) -> Callable[...,
Any]:
def __replace__(self: Any, **kwargs: Any) -> Any:
obj = copy_copy(self)
+ cls = type(obj)
for key, value in kwargs.items():
- setattr(obj, key, value)
+ getattr(cls, key).set(obj, value)
return obj
return __replace__
diff --git a/python/tvm_ffi/cython/type_info.pxi
b/python/tvm_ffi/cython/type_info.pxi
index 1d3c10c..a8c3cd3 100644
--- a/python/tvm_ffi/cython/type_info.pxi
+++ b/python/tvm_ffi/cython/type_info.pxi
@@ -593,6 +593,24 @@ def _annotation_cobject(cls, targs):
return TypeSchema(origin, origin_type_index=info.type_index)
+class FFIProperty(property):
+ """Property descriptor for FFI-backed fields.
+
+ When *frozen* is True the public setter (``fset``) is suppressed so
+ that normal attribute assignment raises ``AttributeError``. The
+ real setter is stashed in :attr:`_fset` and exposed via the
+ :meth:`set` escape-hatch.
+ """
+
+ def __init__(self, fget, fset, frozen, fdel=None, doc=None):
+ super().__init__(fget, None if frozen else fset, fdel, doc)
+ self._fset = fset
+
+ def set(self, obj, value):
+ """Force-set the field value, bypassing the frozen guard."""
+ self._fset(obj, value)
+
+
@dataclasses.dataclass(eq=False)
class TypeField:
"""Description of a single reflected field on an FFI-backed type."""
@@ -616,7 +634,7 @@ class TypeField:
assert self.getter is not None
def as_property(self, object cls):
- """Create a Python ``property`` object for this field on ``cls``."""
+ """Create an :class:`FFIProperty` descriptor for this field on
``cls``."""
cdef str name = self.name
cdef FieldGetter fget = self.getter
cdef FieldSetter fset = self.setter
@@ -624,9 +642,10 @@ class TypeField:
fget.__name__ = fset.__name__ = name
fget.__module__ = fset.__module__ = cls.__module__
fget.__qualname__ = fset.__qualname__ = f"{cls.__qualname__}.{name}"
- ret = property(
+ ret = FFIProperty(
fget=fget,
- fset=fset if (not self.frozen) else None,
+ fset=fset,
+ frozen=self.frozen,
)
if self.doc:
ret.__doc__ = self.doc
@@ -1003,7 +1022,7 @@ def _register_fields(type_info, fields,
structure_kind=None):
doc=py_field.doc,
size=size,
offset=field_offset,
- frozen=False,
+ frozen=py_field.frozen,
metadata={"type_schema": py_field.ty.to_json()},
getter=fgetter,
setter=fsetter,
diff --git a/python/tvm_ffi/dataclasses/field.py
b/python/tvm_ffi/dataclasses/field.py
index 3f36761..75238ee 100644
--- a/python/tvm_ffi/dataclasses/field.py
+++ b/python/tvm_ffi/dataclasses/field.py
@@ -58,6 +58,8 @@ class Field:
default_factory : Callable[[], object] | None
A zero-argument callable that produces the default value.
Mutually exclusive with *default*. ``None`` when not set.
+ frozen : bool
+ Whether this field is read-only after ``__init__``.
init : bool
Whether this field appears in the auto-generated ``__init__``.
repr : bool
@@ -91,6 +93,7 @@ class Field:
"default",
"default_factory",
"doc",
+ "frozen",
"hash",
"init",
"kw_only",
@@ -103,6 +106,7 @@ class Field:
ty: TypeSchema | None
default: object
default_factory: Callable[[], object] | None
+ frozen: bool
init: bool
repr: bool
hash: bool | None
@@ -123,6 +127,7 @@ class Field:
*,
default: object = MISSING,
default_factory: Callable[[], object] | None = MISSING, # type:
ignore[assignment]
+ frozen: bool = False,
init: bool = True,
repr: bool = True,
hash: bool | None = True,
@@ -151,6 +156,7 @@ class Field:
self.ty = ty
self.default = default
self.default_factory = default_factory
+ self.frozen = frozen
self.init = init
self.repr = repr
self.hash = hash
@@ -164,6 +170,7 @@ def field(
*,
default: object = MISSING,
default_factory: Callable[[], object] | None = MISSING, # type:
ignore[assignment]
+ frozen: bool = False,
init: bool = True,
repr: bool = True,
hash: bool | None = None,
@@ -189,6 +196,11 @@ def field(
default_factory
A zero-argument callable that produces the default value.
Mutually exclusive with *default*.
+ frozen
+ Whether this field is read-only after ``__init__``. When True,
+ the Python property descriptor has no setter; use the
+ ``type(obj).field_name.set(obj, value)`` escape hatch when
+ mutation is necessary.
init
Whether this field appears in the auto-generated ``__init__``.
repr
@@ -234,6 +246,7 @@ def field(
return Field(
default=default,
default_factory=default_factory,
+ frozen=frozen,
init=init,
repr=repr,
hash=hash,
diff --git a/python/tvm_ffi/dataclasses/py_class.py
b/python/tvm_ffi/dataclasses/py_class.py
index e6075eb..ae9bd3b 100644
--- a/python/tvm_ffi/dataclasses/py_class.py
+++ b/python/tvm_ffi/dataclasses/py_class.py
@@ -134,10 +134,11 @@ def _rollback_registration(cls: type, type_info: Any) ->
None:
# ---------------------------------------------------------------------------
-def _collect_own_fields(
+def _collect_own_fields( # noqa: PLR0912
cls: type,
hints: dict[str, Any],
decorator_kw_only: bool,
+ decorator_frozen: bool,
) -> list[Field]:
"""Parse own annotations into :class:`Field` objects.
@@ -194,6 +195,10 @@ def _collect_own_fields(
if f.kw_only is None:
f.kw_only = kw_only_active
+ # Apply class-level frozen when the field doesn't explicitly set it
+ if decorator_frozen and not f.frozen:
+ f.frozen = True
+
# Resolve hash=None → follow compare (native dataclass semantics)
if f.hash is None:
f.hash = f.compare
@@ -248,7 +253,7 @@ def _register_fields_into_type(
except (NameError, AttributeError):
return False
- own_fields = _collect_own_fields(cls, hints, params["kw_only"])
+ own_fields = _collect_own_fields(cls, hints, params["kw_only"],
params["frozen"])
py_methods = _collect_py_methods(cls)
# Register fields and type-level structural eq/hash kind with the C layer.
@@ -414,11 +419,12 @@ _FFI_RECOGNIZED_METHODS: frozenset[str] =
_FFI_TYPE_ATTR_NAMES
order_default=False,
field_specifiers=(field, Field),
)
-def py_class(
+def py_class( # noqa: PLR0913
cls_or_type_key: type | str | None = None,
/,
*,
type_key: str | None = None,
+ frozen: bool = False,
init: bool = True,
repr: bool = True,
eq: bool = False,
@@ -465,6 +471,11 @@ def py_class(
type_key
Explicit FFI type key. Auto-generated from
``{module}.{qualname}`` when omitted.
+ frozen
+ If True, all fields are read-only after ``__init__`` by default.
+ Individual fields can still be marked ``field(frozen=True)`` on a
+ non-frozen class. Use ``type(obj).field_name.set(obj, value)``
+ as an escape hatch when mutation is necessary.
init
If True (default), generate ``__init__`` from field annotations.
repr
@@ -514,6 +525,7 @@ def py_class(
effective_type_key = type_key
params: dict[str, Any] = {
+ "frozen": frozen,
"init": init,
"repr": repr,
"eq": eq,
diff --git a/tests/python/test_cubin_launcher.py
b/tests/python/test_cubin_launcher.py
index 4139679..d6f4e35 100644
--- a/tests/python/test_cubin_launcher.py
+++ b/tests/python/test_cubin_launcher.py
@@ -93,7 +93,7 @@ def _compile_kernel_to_cubin() -> bytes:
)
if result.returncode != 0:
- pytest.skip(f"nvcc not available or compilation failed:
{result.stderr}")
+ pytest.skip(f"nvcc not available or compilation failed:
{result.stderr}") # ty: ignore[invalid-argument-type,
too-many-positional-arguments]
return cubin_file.read_bytes()
diff --git a/tests/python/test_dataclass_copy.py
b/tests/python/test_dataclass_copy.py
index ba4b485..73b1287 100644
--- a/tests/python/test_dataclass_copy.py
+++ b/tests/python/test_dataclass_copy.py
@@ -934,10 +934,14 @@ class TestReplace:
obj.__replace__(v_i64=100) # ty: ignore[unresolved-attribute]
assert obj.v_i64 == 5 # ty: ignore[unresolved-attribute]
- def test_replace_readonly_field_raises(self) -> None:
+ def test_replace_readonly_field(self) -> None:
+ # __replace__ uses the FFIProperty.set() escape hatch,
+ # so it works even on frozen / read-only fields.
pair = tvm_ffi.testing.TestIntPair(3, 4)
- with pytest.raises(AttributeError):
- pair.__replace__(a=10) # ty: ignore[unresolved-attribute]
+ pair2 = pair.__replace__(a=10) # ty: ignore[unresolved-attribute]
+ assert pair2.a == 10
+ assert pair2.b == 4
+ assert pair.a == 3 # original unchanged
def test_auto_replace_for_cxx_class(self) -> None:
# _TestCxxClassBase is copy-constructible, so replace is auto-enabled
diff --git a/tests/python/test_dataclass_frozen.py
b/tests/python/test_dataclass_frozen.py
new file mode 100644
index 0000000..9e2a2fd
--- /dev/null
+++ b/tests/python/test_dataclass_frozen.py
@@ -0,0 +1,650 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for frozen support in ``@py_class``."""
+
+# ruff: noqa: D102
+from __future__ import annotations
+
+import copy
+import itertools
+from typing import Dict, List, Optional
+
+import pytest
+from tvm_ffi.core import FFIProperty, Object # ty: ignore[unresolved-import]
+from tvm_ffi.dataclasses import field, py_class
+
+_counter = itertools.count()
+
+
+def _unique_key(base: str) -> str:
+ return f"testing.frozen.{base}_{next(_counter)}"
+
+
+# ---------------------------------------------------------------------------
+# Basic frozen class
+# ---------------------------------------------------------------------------
+class TestFrozenClassBasic:
+ def test_init_works_normally(self) -> None:
+ @py_class(_unique_key("basic_init"), frozen=True)
+ class Pt(Object):
+ x: int
+ y: int
+
+ p = Pt(1, 2)
+ assert p.x == 1
+ assert p.y == 2
+
+ def test_assignment_blocked_after_init(self) -> None:
+ @py_class(_unique_key("basic_blocked"), frozen=True)
+ class Pt(Object):
+ x: int
+ y: int
+
+ p = Pt(1, 2)
+ with pytest.raises(AttributeError):
+ p.x = 10 # ty: ignore[invalid-assignment]
+
+ def test_all_fields_blocked(self) -> None:
+ @py_class(_unique_key("all_blocked"), frozen=True)
+ class Rec(Object):
+ a: int
+ b: str
+ c: float
+
+ r = Rec(1, "hi", 3.0)
+ for attr in ("a", "b", "c"):
+ with pytest.raises(AttributeError):
+ setattr(r, attr, None)
+
+ def test_reading_fields_works(self) -> None:
+ @py_class(_unique_key("read_ok"), frozen=True)
+ class Pt(Object):
+ x: int
+
+ p = Pt(42)
+ assert p.x == 42
+
+ def test_del_blocked(self) -> None:
+ @py_class(_unique_key("del_blocked"), frozen=True)
+ class Pt(Object):
+ x: int
+
+ p = Pt(1)
+ with pytest.raises(AttributeError):
+ del p.x
+
+
+# ---------------------------------------------------------------------------
+# Frozen with various field types
+# ---------------------------------------------------------------------------
+class TestFrozenFieldTypes:
+ def test_frozen_optional_field(self) -> None:
+ @py_class(_unique_key("opt_field"), frozen=True)
+ class Opt(Object):
+ v: Optional[int] # noqa: UP045
+
+ o1 = Opt(None)
+ assert o1.v is None
+ o2 = Opt(42)
+ assert o2.v == 42
+ with pytest.raises(AttributeError):
+ o1.v = 5 # ty: ignore[invalid-assignment]
+
+ def test_frozen_object_field(self) -> None:
+ @py_class(_unique_key("obj_inner"), frozen=True)
+ class Inner(Object):
+ val: int
+
+ @py_class(_unique_key("obj_outer"), frozen=True)
+ class Outer(Object):
+ child: Inner
+
+ inner = Inner(10)
+ outer = Outer(inner)
+ assert outer.child.val == 10
+ with pytest.raises(AttributeError):
+ outer.child = Inner(20) # ty: ignore[invalid-assignment]
+
+ def test_frozen_list_field(self) -> None:
+ @py_class(_unique_key("list_field"), frozen=True)
+ class HasList(Object):
+ items: List[int] # noqa: UP006
+
+ h = HasList([1, 2, 3])
+ assert len(h.items) == 3
+ with pytest.raises(AttributeError):
+ h.items = [4, 5] # ty: ignore[invalid-assignment]
+
+ def test_frozen_dict_field(self) -> None:
+ @py_class(_unique_key("dict_field"), frozen=True)
+ class HasDict(Object):
+ mapping: Dict[str, int] # noqa: UP006
+
+ h = HasDict({"a": 1})
+ assert h.mapping["a"] == 1
+ with pytest.raises(AttributeError):
+ h.mapping = {"b": 2} # ty: ignore[invalid-assignment]
+
+
+# ---------------------------------------------------------------------------
+# Frozen with defaults
+# ---------------------------------------------------------------------------
+class TestFrozenDefaults:
+ def test_frozen_field_with_default(self) -> None:
+ @py_class(_unique_key("def_val"), frozen=True)
+ class Cfg(Object):
+ x: int = 10
+
+ c = Cfg()
+ assert c.x == 10
+ with pytest.raises(AttributeError):
+ c.x = 20 # ty: ignore[invalid-assignment]
+
+ def test_frozen_field_with_default_factory(self) -> None:
+ @py_class(_unique_key("def_factory"), frozen=True)
+ class Cfg(Object):
+ items: List[int] = field(default_factory=list) # noqa: UP006
+
+ c = Cfg()
+ assert len(c.items) == 0
+ with pytest.raises(AttributeError):
+ c.items = [1] # ty: ignore[invalid-assignment]
+
+ def test_frozen_init_false_with_default(self) -> None:
+ @py_class(_unique_key("init_false_def"), frozen=True)
+ class Cfg(Object):
+ x: int
+ tag: str = field(default="default", init=False)
+
+ c = Cfg(5)
+ assert c.tag == "default"
+ with pytest.raises(AttributeError):
+ c.tag = "other" # ty: ignore[invalid-assignment]
+
+
+# ---------------------------------------------------------------------------
+# Per-field frozen on a non-frozen class
+# ---------------------------------------------------------------------------
+class TestPerFieldFrozen:
+ def test_field_frozen_true_on_mutable_class(self) -> None:
+ @py_class(_unique_key("per_field"))
+ class Rec(Object):
+ mutable: int
+ immutable: int = field(frozen=True)
+
+ r = Rec(1, 2)
+ r.mutable = 10 # OK
+ assert r.mutable == 10
+ with pytest.raises(AttributeError):
+ r.immutable = 20
+
+ def test_field_frozen_true_init_works(self) -> None:
+ @py_class(_unique_key("per_field_init"))
+ class Rec(Object):
+ x: int = field(frozen=True)
+
+ r = Rec(42)
+ assert r.x == 42
+
+ def test_multiple_mixed_fields(self) -> None:
+ @py_class(_unique_key("mixed_fields"))
+ class Rec(Object):
+ a: int = field(frozen=True)
+ b: int
+ c: int = field(frozen=True)
+
+ r = Rec(1, 2, 3)
+ r.b = 20 # OK
+ with pytest.raises(AttributeError):
+ r.a = 10
+ with pytest.raises(AttributeError):
+ r.c = 30
+
+
+# ---------------------------------------------------------------------------
+# Escape hatch: type(obj).field.set(obj, val)
+# ---------------------------------------------------------------------------
+class TestEscapeHatch:
+ def test_escape_hatch_sets_frozen_field(self) -> None:
+ @py_class(_unique_key("esc_basic"), frozen=True)
+ class Pt(Object):
+ x: int
+
+ p = Pt(1)
+ type(p).x.set(p, 99) # ty: ignore[unresolved-attribute]
+ assert p.x == 99
+
+ def test_escape_hatch_on_field_level_frozen(self) -> None:
+ @py_class(_unique_key("esc_field"))
+ class Rec(Object):
+ val: int = field(frozen=True)
+
+ r = Rec(5)
+ type(r).val.set(r, 50) # ty: ignore[unresolved-attribute]
+ assert r.val == 50
+
+ def test_escape_hatch_multiple_fields(self) -> None:
+ @py_class(_unique_key("esc_multi"), frozen=True)
+ class Pt(Object):
+ x: int
+ y: int
+
+ p = Pt(1, 2)
+ type(p).x.set(p, 10) # ty: ignore[unresolved-attribute]
+ type(p).y.set(p, 20) # ty: ignore[unresolved-attribute]
+ assert p.x == 10
+ assert p.y == 20
+
+ def test_escape_hatch_preserves_type_coercion(self) -> None:
+ @py_class(_unique_key("esc_coerce"), frozen=True)
+ class HasList(Object):
+ items: List[int] # noqa: UP006
+
+ h = HasList([1])
+ type(h).items.set(h, [10, 20]) # ty: ignore[unresolved-attribute]
+ assert len(h.items) == 2
+
+ def test_regular_setattr_still_blocked_after_escape_hatch(self) -> None:
+ @py_class(_unique_key("esc_still_frozen"), frozen=True)
+ class Pt(Object):
+ x: int
+
+ p = Pt(1)
+ type(p).x.set(p, 99) # ty: ignore[unresolved-attribute]
+ with pytest.raises(AttributeError):
+ p.x = 100 # ty: ignore[invalid-assignment]
+
+ def test_escape_hatch_on_mutable_field_also_works(self) -> None:
+ @py_class(_unique_key("esc_mutable"))
+ class Pt(Object):
+ x: int
+
+ p = Pt(1)
+ type(p).x.set(p, 99) # ty: ignore[unresolved-attribute]
+ assert p.x == 99
+
+
+# ---------------------------------------------------------------------------
+# copy / deepcopy / __replace__
+# ---------------------------------------------------------------------------
+class TestFrozenCopy:
+ def test_copy_copy_frozen_class(self) -> None:
+ @py_class(_unique_key("copy_basic"), frozen=True)
+ class Pt(Object):
+ x: int
+ y: int
+
+ p = Pt(1, 2)
+ p2 = copy.copy(p)
+ assert p2.x == 1 and p2.y == 2
+ assert not p.same_as(p2)
+
+ def test_copy_copy_result_is_also_frozen(self) -> None:
+ @py_class(_unique_key("copy_frozen"), frozen=True)
+ class Pt(Object):
+ x: int
+
+ p2 = copy.copy(Pt(1))
+ with pytest.raises(AttributeError):
+ p2.x = 10 # ty: ignore[invalid-assignment]
+
+ def test_deepcopy_frozen_class(self) -> None:
+ @py_class(_unique_key("deepcopy"), frozen=True)
+ class Pt(Object):
+ x: int
+
+ p = Pt(42)
+ p2 = copy.deepcopy(p)
+ assert p2.x == 42
+ assert not p.same_as(p2)
+
+ def test_deepcopy_result_is_also_frozen(self) -> None:
+ @py_class(_unique_key("deepcopy_frozen"), frozen=True)
+ class Pt(Object):
+ x: int
+
+ p2 = copy.deepcopy(Pt(1))
+ with pytest.raises(AttributeError):
+ p2.x = 10 # ty: ignore[invalid-assignment]
+
+ def test_replace_on_frozen_class(self) -> None:
+ @py_class(_unique_key("replace"), frozen=True)
+ class Pt(Object):
+ x: int
+ y: int
+
+ p = Pt(1, 2)
+ p2 = p.__replace__(x=10) # ty: ignore[unresolved-attribute]
+ assert p2.x == 10 and p2.y == 2
+ assert p.x == 1 # original unchanged
+
+ def test_replace_multiple_fields(self) -> None:
+ @py_class(_unique_key("replace_multi"), frozen=True)
+ class Pt(Object):
+ x: int
+ y: int
+
+ p = Pt(1, 2)
+ p2 = p.__replace__(x=10, y=20) # ty: ignore[unresolved-attribute]
+ assert p2.x == 10 and p2.y == 20
+
+ def test_replace_result_is_frozen(self) -> None:
+ @py_class(_unique_key("replace_frozen"), frozen=True)
+ class Pt(Object):
+ x: int
+
+ p2 = Pt(1).__replace__(x=10) # ty: ignore[unresolved-attribute]
+ with pytest.raises(AttributeError):
+ p2.x = 99
+
+
+# ---------------------------------------------------------------------------
+# Inheritance
+# ---------------------------------------------------------------------------
+class TestFrozenInheritance:
+ def test_frozen_parent_mutable_child(self) -> None:
+ @py_class(_unique_key("inh_parent_frozen"), frozen=True)
+ class Parent(Object):
+ a: int
+
+ @py_class(_unique_key("inh_child_mutable"))
+ class Child(Parent): # ty: ignore[invalid-frozen-dataclass-subclass]
+ b: int
+
+ c = Child(1, 2)
+ assert c.a == 1 and c.b == 2
+ # Parent field stays frozen (property installed by Parent class)
+ with pytest.raises(AttributeError):
+ c.a = 10 # ty: ignore[invalid-assignment]
+ # Child's own field is mutable
+ c.b = 20 # ty: ignore[invalid-assignment]
+ assert c.b == 20
+
+ def test_frozen_parent_frozen_child(self) -> None:
+ @py_class(_unique_key("inh_both_frozen_p"), frozen=True)
+ class Parent(Object):
+ a: int
+
+ @py_class(_unique_key("inh_both_frozen_c"), frozen=True)
+ class Child(Parent):
+ b: int
+
+ c = Child(1, 2)
+ with pytest.raises(AttributeError):
+ c.a = 10 # ty: ignore[invalid-assignment]
+ with pytest.raises(AttributeError):
+ c.b = 20 # ty: ignore[invalid-assignment]
+
+ def test_mutable_parent_frozen_child(self) -> None:
+ @py_class(_unique_key("inh_parent_mutable"))
+ class Parent(Object):
+ a: int
+
+ @py_class(_unique_key("inh_child_frozen"), frozen=True)
+ class Child(Parent): # ty: ignore[invalid-frozen-dataclass-subclass]
+ b: int
+
+ c = Child(1, 2)
+ # Parent field is mutable (property installed by Parent class)
+ c.a = 10 # ty: ignore[invalid-assignment]
+ assert c.a == 10
+ # Child's own field is frozen
+ with pytest.raises(AttributeError):
+ c.b = 20 # ty: ignore[invalid-assignment]
+
+ def test_three_level_frozen_inheritance(self) -> None:
+ @py_class(_unique_key("inh_l1"), frozen=True)
+ class L1(Object):
+ a: int
+
+ @py_class(_unique_key("inh_l2"), frozen=True)
+ class L2(L1):
+ b: int
+
+ @py_class(_unique_key("inh_l3"), frozen=True)
+ class L3(L2):
+ c: int
+
+ obj = L3(1, 2, 3)
+ assert obj.a == 1 and obj.b == 2 and obj.c == 3
+ for attr in ("a", "b", "c"):
+ with pytest.raises(AttributeError):
+ setattr(obj, attr, 99)
+
+ def test_escape_hatch_on_inherited_frozen_field(self) -> None:
+ @py_class(_unique_key("inh_esc_p"), frozen=True)
+ class Parent(Object):
+ a: int
+
+ @py_class(_unique_key("inh_esc_c"), frozen=True)
+ class Child(Parent):
+ b: int
+
+ c = Child(1, 2)
+ # Escape hatch for parent field must go through Parent class descriptor
+ Parent.a.set(c, 10) # ty: ignore[unresolved-attribute]
+ assert c.a == 10
+
+ def test_replace_on_inherited_frozen(self) -> None:
+ @py_class(_unique_key("inh_replace_p"), frozen=True)
+ class Parent(Object):
+ a: int
+
+ @py_class(_unique_key("inh_replace_c"), frozen=True)
+ class Child(Parent):
+ b: int
+
+ c = Child(1, 2)
+ c2 = c.__replace__(a=10, b=20) # ty: ignore[unresolved-attribute]
+ assert c2.a == 10 and c2.b == 20
+
+
+# ---------------------------------------------------------------------------
+# kw_only + frozen
+# ---------------------------------------------------------------------------
+class TestFrozenKwOnly:
+ def test_frozen_with_kw_only(self) -> None:
+ @py_class(_unique_key("kw_frozen"), frozen=True, kw_only=True)
+ class Cfg(Object):
+ x: int
+ y: int
+
+ c = Cfg(x=1, y=2)
+ assert c.x == 1 and c.y == 2
+ with pytest.raises(AttributeError):
+ c.x = 10 # ty: ignore[invalid-assignment]
+
+
+# ---------------------------------------------------------------------------
+# __post_init__ interaction
+# ---------------------------------------------------------------------------
+class TestFrozenPostInit:
+ def test_post_init_called(self) -> None:
+ log: list[bool] = []
+
+ @py_class(_unique_key("post_init_called"), frozen=True)
+ class Pt(Object):
+ x: int
+
+ def __post_init__(self) -> None:
+ log.append(True)
+
+ Pt(1)
+ assert log == [True]
+
+ def test_post_init_can_read_fields(self) -> None:
+ captured: list[int] = []
+
+ @py_class(_unique_key("post_init_read"), frozen=True)
+ class Pt(Object):
+ x: int
+
+ def __post_init__(self) -> None:
+ captured.append(self.x)
+
+ Pt(42)
+ assert captured == [42]
+
+ def test_post_init_cannot_set_frozen_fields(self) -> None:
+ @py_class(_unique_key("post_init_set"), frozen=True)
+ class Pt(Object):
+ x: int
+
+ def __post_init__(self) -> None:
+ self.x = 999 # should fail # ty: ignore[invalid-assignment]
+
+ with pytest.raises(AttributeError):
+ Pt(1)
+
+ def test_post_init_can_use_escape_hatch(self) -> None:
+ @py_class(_unique_key("post_init_esc"), frozen=True)
+ class Pt(Object):
+ x: int
+
+ def __post_init__(self) -> None:
+ type(self).x.set(self, self.x * 2) # ty:
ignore[unresolved-attribute]
+
+ p = Pt(5)
+ assert p.x == 10
+
+
+# ---------------------------------------------------------------------------
+# eq / hash + frozen
+# ---------------------------------------------------------------------------
+class TestFrozenEqHash:
+ def test_frozen_with_eq(self) -> None:
+ @py_class(_unique_key("eq"), frozen=True, eq=True)
+ class Pt(Object):
+ x: int
+
+ assert Pt(1) == Pt(1)
+ assert Pt(1) != Pt(2)
+
+ def test_frozen_with_hash(self) -> None:
+ @py_class(_unique_key("hash"), frozen=True, eq=True, unsafe_hash=True)
+ class Pt(Object):
+ x: int
+
+ assert hash(Pt(1)) == hash(Pt(1))
+ s = {Pt(1), Pt(2)}
+ assert len(s) == 2
+
+
+# ---------------------------------------------------------------------------
+# FFIProperty descriptor checks
+# ---------------------------------------------------------------------------
+class TestFFIPropertyDescriptor:
+ def test_frozen_field_is_ffi_property(self) -> None:
+ @py_class(_unique_key("desc_type"), frozen=True)
+ class Pt(Object):
+ x: int
+
+ assert isinstance(Pt.__dict__["x"], FFIProperty)
+
+ def test_frozen_field_fset_is_none(self) -> None:
+ @py_class(_unique_key("desc_fset"), frozen=True)
+ class Pt(Object):
+ x: int
+
+ assert Pt.__dict__["x"].fset is None
+
+ def test_mutable_field_is_ffi_property(self) -> None:
+ @py_class(_unique_key("desc_mutable"))
+ class Pt(Object):
+ x: int
+
+ assert isinstance(Pt.__dict__["x"], FFIProperty)
+
+ def test_mutable_field_fset_is_not_none(self) -> None:
+ @py_class(_unique_key("desc_mutable_fset"))
+ class Pt(Object):
+ x: int
+
+ assert Pt.__dict__["x"].fset is not None
+
+ def test_mutable_field_set_method_also_works(self) -> None:
+ @py_class(_unique_key("desc_mutable_set"))
+ class Pt(Object):
+ x: int
+
+ p = Pt(1)
+ type(p).x.set(p, 99) # ty: ignore[unresolved-attribute]
+ assert p.x == 99
+
+
+# ---------------------------------------------------------------------------
+# Edge cases
+# ---------------------------------------------------------------------------
+class TestFrozenEdgeCases:
+ def test_frozen_class_no_own_fields(self) -> None:
+ @py_class(_unique_key("no_fields"), frozen=True)
+ class Empty(Object):
+ pass
+
+ Empty() # should not raise
+
+ def test_frozen_single_field(self) -> None:
+ @py_class(_unique_key("single"), frozen=True)
+ class Single(Object):
+ x: int
+
+ s = Single(1)
+ assert s.x == 1
+ with pytest.raises(AttributeError):
+ s.x = 2 # ty: ignore[invalid-assignment]
+
+ def test_frozen_field_with_none_default(self) -> None:
+ @py_class(_unique_key("none_def"), frozen=True)
+ class Opt(Object):
+ v: Optional[int] = None # noqa: UP045
+
+ o = Opt()
+ assert o.v is None
+ with pytest.raises(AttributeError):
+ o.v = 1 # ty: ignore[invalid-assignment]
+
+ def test_multiple_instances_independent(self) -> None:
+ @py_class(_unique_key("multi_inst"), frozen=True)
+ class Pt(Object):
+ x: int
+
+ a = Pt(1)
+ b = Pt(2)
+ assert a.x == 1
+ assert b.x == 2
+ with pytest.raises(AttributeError):
+ a.x = 99 # ty: ignore[invalid-assignment]
+ with pytest.raises(AttributeError):
+ b.x = 99 # ty: ignore[invalid-assignment]
+
+ def test_frozen_instance_as_field_value(self) -> None:
+ @py_class(_unique_key("inner_frozen"), frozen=True)
+ class Inner(Object):
+ val: int
+
+ @py_class(_unique_key("outer_mut"))
+ class Outer(Object):
+ child: Inner
+
+ inner = Inner(10)
+ outer = Outer(inner)
+ # Inner is still frozen even when held in a mutable outer
+ with pytest.raises(AttributeError):
+ outer.child.val = 99 # ty: ignore[invalid-assignment]
+ # But the outer field itself is mutable
+ outer.child = Inner(20)
+ assert outer.child.val == 20
diff --git a/tests/python/test_dtype.py b/tests/python/test_dtype.py
index da8c183..5864123 100644
--- a/tests/python/test_dtype.py
+++ b/tests/python/test_dtype.py
@@ -148,7 +148,7 @@ def test_ml_dtypes_dtype_conversion() -> None:
np = pytest.importorskip("numpy")
ml_dtypes = pytest.importorskip("ml_dtypes")
if Version(ml_dtypes.__version__) < Version("0.4.0"):
- pytest.skip("ml_dtypes < 0.4.0")
+ pytest.skip("ml_dtypes < 0.4.0") # ty: ignore[invalid-argument-type,
too-many-positional-arguments]
return
_check_dtype(np.dtype(ml_dtypes.int2), 0, 2, 1)
_check_dtype(np.dtype(ml_dtypes.int4), 0, 4, 1)
diff --git a/tests/python/test_error.py b/tests/python/test_error.py
index c6c36f9..70d1f0a 100644
--- a/tests/python/test_error.py
+++ b/tests/python/test_error.py
@@ -88,7 +88,7 @@ def test_error_from_nested_pyfunc() -> None:
assert pos_cxx_raise < pos_lambda
assert pos_lambda < pos_cxx_apply
except Exception as e:
- pytest.xfail("May fail if debug symbols are missing")
+ pytest.xfail("May fail if debug symbols are missing") # ty:
ignore[invalid-argument-type, too-many-positional-arguments]
def test_error_traceback_update() -> None:
diff --git a/tests/python/test_load_inline.py b/tests/python/test_load_inline.py
index 7677a3a..672a264 100644
--- a/tests/python/test_load_inline.py
+++ b/tests/python/test_load_inline.py
@@ -211,7 +211,7 @@ def test_load_inline_cuda() -> None:
def test_load_inline_with_env_tensor_allocator() -> None:
assert torch is not None
if not hasattr(torch.Tensor, "__dlpack_c_exchange_api__"):
- pytest.skip("Torch does not support __dlpack_c_exchange_api__")
+ pytest.skip("Torch does not support __dlpack_c_exchange_api__") # ty:
ignore[invalid-argument-type, too-many-positional-arguments]
mod: Module = tvm_ffi.cpp.load_inline(
name="hello",
cpp_sources=r"""
diff --git a/tests/python/utils/test_embed_cubin.py
b/tests/python/utils/test_embed_cubin.py
index 8ef7f81..7954ba0 100644
--- a/tests/python/utils/test_embed_cubin.py
+++ b/tests/python/utils/test_embed_cubin.py
@@ -84,7 +84,7 @@ def _create_test_object_file(obj_path: Path) -> None:
continue
if compiler is None:
- pytest.skip("No C++ compiler found (tried g++, clang++, c++)")
+ pytest.skip("No C++ compiler found (tried g++, clang++, c++)") # ty:
ignore[invalid-argument-type, too-many-positional-arguments]
assert isinstance(compiler, str), "Compiler is not a string"
@@ -93,7 +93,7 @@ def _create_test_object_file(obj_path: Path) -> None:
try:
subprocess.run(compile_cmd, check=True, capture_output=True)
except subprocess.CalledProcessError as e:
- pytest.skip(f"Failed to compile test object file:
{e.stderr.decode('utf-8')}")
+ pytest.skip(f"Failed to compile test object file:
{e.stderr.decode('utf-8')}") # ty: ignore[invalid-argument-type,
too-many-positional-arguments]
finally:
# Clean up temporary C++ file
cpp_file.unlink(missing_ok=True)
@@ -121,7 +121,7 @@ def _check_symbols_in_object(obj_path: Path,
expected_symbols: list[str]) -> boo
return False
return True
except (subprocess.CalledProcessError, FileNotFoundError):
- pytest.skip("nm tool not available")
+ pytest.skip("nm tool not available") # ty:
ignore[invalid-argument-type, too-many-positional-arguments]
return False