This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new e809f98 feat(container): add structural `__eq__`/`__ne__`/`__hash__`
to Array, List, Map, Dict (#545)
e809f98 is described below
commit e809f98099790f870184aa69537ffccce78b5a86
Author: Junru Shao <[email protected]>
AuthorDate: Mon Apr 13 10:56:54 2026 -0700
feat(container): add structural `__eq__`/`__ne__`/`__hash__` to Array,
List, Map, Dict (#545)
## Summary
- Add structural `__eq__`, `__ne__`, and `__hash__` methods to the four
container classes (`Array`, `List`, `Map`, `Dict`) in
`python/tvm_ffi/container.py`
- Delegates to existing `RecursiveEq` and `RecursiveHash` C++ FFI
functions — the same infrastructure used by `_install_dataclass_dunders`
in `_dunder.py` for `@c_class`/`@py_class`
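- Returns `NotImplemented` when the other operand is not an instance of
the same container class (e.g. a plain `list`, `dict`, or `str`, or a
different container such as `Array` vs. `List`), so Python's default
comparison fallback (identity comparison for `==`/`!=`) applies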
- `Shape`, `String`, and `Bytes` are unchanged; they already inherit the
correct behavior from `tuple`, `str`, and `bytes`
## Breaking Change
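Code that relied on identity-based equality for containers will now see
structural equality instead. For example, `Array([1, 2]) == Array([1, 2])`
now returns `True` (previously `False`). Two container objects with the
same contents now compare equal and produce the same hash. Because `List`
and `Dict` are mutable, their hash is derived from their current contents:
mutating one while it is a `set` member or a `dict` key changes its hash
and breaks subsequent lookups.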
## Test Plan
- [x] 17 new tests added in `tests/python/test_container.py` covering:
  - Structural equality and inequality for all four container types
  - Empty containers
  - Nested containers (Array of Arrays)
  - `NotImplemented` return for unrelated types (plain list, dict, str)
  - Hash consistency (equal objects produce equal hashes)
  - Usability as set members (Array, Map) and dict keys (Array)
- [x] Full Python test suite passes: 2072 passed, 38 skipped, 3 xfailed
- [x] All pre-commit hooks pass (ruff, ty check, etc.)
---
python/tvm_ffi/container.py | 64 +++++++++++++++++++++++
tests/python/test_container.py | 116 +++++++++++++++++++++++++++++++++++++++++
2 files changed, 180 insertions(+)
diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index 5bdd9a9..4d25ce9 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -201,6 +201,22 @@ class Array(core.CContainerBase, core.Object, Sequence[T]):
"""Check if the array contains a value."""
return _ffi_api.ArrayContains(self, value)
+ def __eq__(self, other: object) -> bool:
+ """Structural equality."""
+ if not (isinstance(other, type(self)) or isinstance(self,
type(other))):
+ return NotImplemented
+ return _ffi_api.RecursiveEq(self, other)
+
+ def __ne__(self, other: object) -> bool:
+ """Structural inequality."""
+ if not (isinstance(other, type(self)) or isinstance(self,
type(other))):
+ return NotImplemented
+ return not _ffi_api.RecursiveEq(self, other)
+
+ def __hash__(self) -> int:
+ """Structural hash."""
+ return _ffi_api.RecursiveHash(self)
+
def __bool__(self) -> bool:
"""Return True if the array is non-empty."""
return len(self) > 0
@@ -344,6 +360,22 @@ class List(core.CContainerBase, core.Object,
MutableSequence[T]):
"""Check if the list contains a value."""
return _ffi_api.ListContains(self, value)
+ def __eq__(self, other: object) -> bool:
+ """Structural equality."""
+ if not (isinstance(other, type(self)) or isinstance(self,
type(other))):
+ return NotImplemented
+ return _ffi_api.RecursiveEq(self, other)
+
+ def __ne__(self, other: object) -> bool:
+ """Structural inequality."""
+ if not (isinstance(other, type(self)) or isinstance(self,
type(other))):
+ return NotImplemented
+ return not _ffi_api.RecursiveEq(self, other)
+
+ def __hash__(self) -> int:
+ """Structural hash."""
+ return _ffi_api.RecursiveHash(self)
+
def __bool__(self) -> bool:
"""Return True if the list is non-empty."""
return len(self) > 0
@@ -499,6 +531,22 @@ class Map(core.CContainerBase, core.Object, Mapping[K, V]):
"""Return True if the map contains key `k`."""
return _ffi_api.MapCount(self, k) != 0
+ def __eq__(self, other: object) -> bool:
+ """Structural equality."""
+ if not (isinstance(other, type(self)) or isinstance(self,
type(other))):
+ return NotImplemented
+ return _ffi_api.RecursiveEq(self, other)
+
+ def __ne__(self, other: object) -> bool:
+ """Structural inequality."""
+ if not (isinstance(other, type(self)) or isinstance(self,
type(other))):
+ return NotImplemented
+ return not _ffi_api.RecursiveEq(self, other)
+
+ def __hash__(self) -> int:
+ """Structural hash."""
+ return _ffi_api.RecursiveHash(self)
+
def keys(self) -> KeysView[K]:
"""Return a dynamic view of the map's keys."""
return KeysView(self)
@@ -607,6 +655,22 @@ class Dict(core.CContainerBase, core.Object,
MutableMapping[K, V]):
"""Return True if the dict contains key `k`."""
return _ffi_api.DictCount(self, k) != 0
+ def __eq__(self, other: object) -> bool:
+ """Structural equality."""
+ if not (isinstance(other, type(self)) or isinstance(self,
type(other))):
+ return NotImplemented
+ return _ffi_api.RecursiveEq(self, other)
+
+ def __ne__(self, other: object) -> bool:
+ """Structural inequality."""
+ if not (isinstance(other, type(self)) or isinstance(self,
type(other))):
+ return NotImplemented
+ return not _ffi_api.RecursiveEq(self, other)
+
+ def __hash__(self) -> int:
+ """Structural hash."""
+ return _ffi_api.RecursiveHash(self)
+
def __len__(self) -> int:
"""Return the number of items in the dict."""
return _ffi_api.DictSize(self)
diff --git a/tests/python/test_container.py b/tests/python/test_container.py
index b4e85c8..9b74b0c 100644
--- a/tests/python/test_container.py
+++ b/tests/python/test_container.py
@@ -733,3 +733,119 @@ def test_map_cross_conv_incompatible_map_to_dict() ->
None:
m = tvm_ffi.Map({"a": "not_int", "b": "still_not_int"})
with pytest.raises(TypeError):
testing.schema_id_dict_str_int(m) # type:
ignore[invalid-argument-type]
+
+
+# ---------------------------------------------------------------------------
+# Structural __eq__ / __ne__ / __hash__ tests
+# ---------------------------------------------------------------------------
+
+
+def test_array_structural_eq() -> None:
+ a = tvm_ffi.Array([1, 2, 3])
+ b = tvm_ffi.Array([1, 2, 3])
+ c = tvm_ffi.Array([1, 2, 4])
+ assert a == b
+ assert a != c
+ assert not (a != b)
+ assert not (a == c)
+
+
+def test_array_eq_empty() -> None:
+ assert tvm_ffi.Array([]) == tvm_ffi.Array([])
+
+
+def test_array_eq_nested() -> None:
+ a = tvm_ffi.Array([tvm_ffi.Array([1, 2]), tvm_ffi.Array([3])])
+ b = tvm_ffi.Array([tvm_ffi.Array([1, 2]), tvm_ffi.Array([3])])
+ c = tvm_ffi.Array([tvm_ffi.Array([1, 2]), tvm_ffi.Array([4])])
+ assert a == b
+ assert a != c
+
+
+def test_array_eq_not_implemented_for_unrelated() -> None:
+ a = tvm_ffi.Array([1, 2, 3])
+ assert a.__eq__([1, 2, 3]) is NotImplemented
+ assert a.__ne__([1, 2, 3]) is NotImplemented
+ assert a.__eq__("hello") is NotImplemented
+
+
+def test_array_hash() -> None:
+ a = tvm_ffi.Array([1, 2, 3])
+ b = tvm_ffi.Array([1, 2, 3])
+ assert hash(a) == hash(b)
+ # Usable in sets and as dict keys
+ s = {a, b}
+ assert len(s) == 1
+ d = {a: "value"}
+ assert d[b] == "value"
+
+
+def test_list_structural_eq() -> None:
+ a = tvm_ffi.List([1, 2, 3])
+ b = tvm_ffi.List([1, 2, 3])
+ c = tvm_ffi.List([1, 2, 4])
+ assert a == b
+ assert a != c
+
+
+def test_list_eq_empty() -> None:
+ assert tvm_ffi.List([]) == tvm_ffi.List([])
+
+
+def test_list_eq_not_implemented_for_unrelated() -> None:
+ a = tvm_ffi.List([1, 2, 3])
+ assert a.__eq__([1, 2, 3]) is NotImplemented
+
+
+def test_list_hash() -> None:
+ a = tvm_ffi.List([1, 2, 3])
+ b = tvm_ffi.List([1, 2, 3])
+ assert hash(a) == hash(b)
+
+
+def test_map_structural_eq() -> None:
+ a = tvm_ffi.Map({"x": 1, "y": 2})
+ b = tvm_ffi.Map({"x": 1, "y": 2})
+ c = tvm_ffi.Map({"x": 1, "y": 3})
+ assert a == b
+ assert a != c
+
+
+def test_map_eq_empty() -> None:
+ assert tvm_ffi.Map({}) == tvm_ffi.Map({})
+
+
+def test_map_eq_not_implemented_for_unrelated() -> None:
+ a = tvm_ffi.Map({"x": 1})
+ assert a.__eq__({"x": 1}) is NotImplemented
+
+
+def test_map_hash() -> None:
+ a = tvm_ffi.Map({"x": 1, "y": 2})
+ b = tvm_ffi.Map({"x": 1, "y": 2})
+ assert hash(a) == hash(b)
+ s = {a, b}
+ assert len(s) == 1
+
+
+def test_dict_structural_eq() -> None:
+ a = tvm_ffi.Dict({"x": 1, "y": 2})
+ b = tvm_ffi.Dict({"x": 1, "y": 2})
+ c = tvm_ffi.Dict({"x": 1, "y": 3})
+ assert a == b
+ assert a != c
+
+
+def test_dict_eq_empty() -> None:
+ assert tvm_ffi.Dict({}) == tvm_ffi.Dict({})
+
+
+def test_dict_eq_not_implemented_for_unrelated() -> None:
+ a = tvm_ffi.Dict({"x": 1})
+ assert a.__eq__({"x": 1}) is NotImplemented
+
+
+def test_dict_hash() -> None:
+ a = tvm_ffi.Dict({"x": 1, "y": 2})
+ b = tvm_ffi.Dict({"x": 1, "y": 2})
+ assert hash(a) == hash(b)