This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new e809f98  feat(container): add structural `__eq__`/`__ne__`/`__hash__` to Array, List, Map, Dict (#545)
e809f98 is described below

commit e809f98099790f870184aa69537ffccce78b5a86
Author: Junru Shao <[email protected]>
AuthorDate: Mon Apr 13 10:56:54 2026 -0700

    feat(container): add structural `__eq__`/`__ne__`/`__hash__` to Array, List, Map, Dict (#545)
    
    ## Summary
    
    - Add structural `__eq__`, `__ne__`, and `__hash__` methods to the four
    container classes (`Array`, `List`, `Map`, `Dict`) in
    `python/tvm_ffi/container.py`
    - Delegates to existing `RecursiveEq` and `RecursiveHash` C++ FFI
    functions — the same infrastructure used by `_install_dataclass_dunders`
    in `_dunder.py` for `@c_class`/`@py_class`
    - Returns `NotImplemented` for unrelated types so Python's default
    comparison fallback applies
    - `Shape`, `String`, and `Bytes` are unchanged (already inherit correct
    behavior from `tuple`, `str`, `bytes`)
    
    ## Breaking Change
    
    Code that relied on identity-based equality for containers will now see
    structural equality instead. For example,
    `Array([1, 2]) == Array([1, 2])` now returns `True` (previously `False`).
    Two container objects with the same contents now compare equal and
    produce the same hash.
    
    ## Test Plan
    
    - [x] 17 new tests added in `tests/python/test_container.py` covering:
      - Structural equality and inequality for all four container types
      - Empty containers
      - Nested containers (Array of Arrays)
      - `NotImplemented` return for unrelated types (plain list, dict, str)
      - Hash consistency (equal objects produce equal hashes)
      - Usability as set members and dict keys
    - [x] Full Python test suite passes: 2072 passed, 38 skipped, 3 xfailed
    - [x] All pre-commit hooks pass (ruff, ty check, etc.)
---
 python/tvm_ffi/container.py    |  64 +++++++++++++++++++++++
 tests/python/test_container.py | 116 +++++++++++++++++++++++++++++++++++++++++
 2 files changed, 180 insertions(+)

diff --git a/python/tvm_ffi/container.py b/python/tvm_ffi/container.py
index 5bdd9a9..4d25ce9 100644
--- a/python/tvm_ffi/container.py
+++ b/python/tvm_ffi/container.py
@@ -201,6 +201,22 @@ class Array(core.CContainerBase, core.Object, Sequence[T]):
         """Check if the array contains a value."""
         return _ffi_api.ArrayContains(self, value)
 
+    def __eq__(self, other: object) -> bool:
+        """Structural equality."""
+        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
+            return NotImplemented
+        return _ffi_api.RecursiveEq(self, other)
+
+    def __ne__(self, other: object) -> bool:
+        """Structural inequality."""
+        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
+            return NotImplemented
+        return not _ffi_api.RecursiveEq(self, other)
+
+    def __hash__(self) -> int:
+        """Structural hash."""
+        return _ffi_api.RecursiveHash(self)
+
     def __bool__(self) -> bool:
         """Return True if the array is non-empty."""
         return len(self) > 0
@@ -344,6 +360,22 @@ class List(core.CContainerBase, core.Object, MutableSequence[T]):
         """Check if the list contains a value."""
         return _ffi_api.ListContains(self, value)
 
+    def __eq__(self, other: object) -> bool:
+        """Structural equality."""
+        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
+            return NotImplemented
+        return _ffi_api.RecursiveEq(self, other)
+
+    def __ne__(self, other: object) -> bool:
+        """Structural inequality."""
+        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
+            return NotImplemented
+        return not _ffi_api.RecursiveEq(self, other)
+
+    def __hash__(self) -> int:
+        """Structural hash."""
+        return _ffi_api.RecursiveHash(self)
+
     def __bool__(self) -> bool:
         """Return True if the list is non-empty."""
         return len(self) > 0
@@ -499,6 +531,22 @@ class Map(core.CContainerBase, core.Object, Mapping[K, V]):
         """Return True if the map contains key `k`."""
         return _ffi_api.MapCount(self, k) != 0
 
+    def __eq__(self, other: object) -> bool:
+        """Structural equality."""
+        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
+            return NotImplemented
+        return _ffi_api.RecursiveEq(self, other)
+
+    def __ne__(self, other: object) -> bool:
+        """Structural inequality."""
+        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
+            return NotImplemented
+        return not _ffi_api.RecursiveEq(self, other)
+
+    def __hash__(self) -> int:
+        """Structural hash."""
+        return _ffi_api.RecursiveHash(self)
+
     def keys(self) -> KeysView[K]:
         """Return a dynamic view of the map's keys."""
         return KeysView(self)
@@ -607,6 +655,22 @@ class Dict(core.CContainerBase, core.Object, MutableMapping[K, V]):
         """Return True if the dict contains key `k`."""
         return _ffi_api.DictCount(self, k) != 0
 
+    def __eq__(self, other: object) -> bool:
+        """Structural equality."""
+        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
+            return NotImplemented
+        return _ffi_api.RecursiveEq(self, other)
+
+    def __ne__(self, other: object) -> bool:
+        """Structural inequality."""
+        if not (isinstance(other, type(self)) or isinstance(self, type(other))):
+            return NotImplemented
+        return not _ffi_api.RecursiveEq(self, other)
+
+    def __hash__(self) -> int:
+        """Structural hash."""
+        return _ffi_api.RecursiveHash(self)
+
     def __len__(self) -> int:
         """Return the number of items in the dict."""
         return _ffi_api.DictSize(self)
diff --git a/tests/python/test_container.py b/tests/python/test_container.py
index b4e85c8..9b74b0c 100644
--- a/tests/python/test_container.py
+++ b/tests/python/test_container.py
@@ -733,3 +733,119 @@ def test_map_cross_conv_incompatible_map_to_dict() -> None:
     m = tvm_ffi.Map({"a": "not_int", "b": "still_not_int"})
     with pytest.raises(TypeError):
         testing.schema_id_dict_str_int(m)  # type: ignore[invalid-argument-type]
+
+
+# ---------------------------------------------------------------------------
+# Structural __eq__ / __ne__ / __hash__ tests
+# ---------------------------------------------------------------------------
+
+
+def test_array_structural_eq() -> None:
+    a = tvm_ffi.Array([1, 2, 3])
+    b = tvm_ffi.Array([1, 2, 3])
+    c = tvm_ffi.Array([1, 2, 4])
+    assert a == b
+    assert a != c
+    assert not (a != b)
+    assert not (a == c)
+
+
+def test_array_eq_empty() -> None:
+    assert tvm_ffi.Array([]) == tvm_ffi.Array([])
+
+
+def test_array_eq_nested() -> None:
+    a = tvm_ffi.Array([tvm_ffi.Array([1, 2]), tvm_ffi.Array([3])])
+    b = tvm_ffi.Array([tvm_ffi.Array([1, 2]), tvm_ffi.Array([3])])
+    c = tvm_ffi.Array([tvm_ffi.Array([1, 2]), tvm_ffi.Array([4])])
+    assert a == b
+    assert a != c
+
+
+def test_array_eq_not_implemented_for_unrelated() -> None:
+    a = tvm_ffi.Array([1, 2, 3])
+    assert a.__eq__([1, 2, 3]) is NotImplemented
+    assert a.__ne__([1, 2, 3]) is NotImplemented
+    assert a.__eq__("hello") is NotImplemented
+
+
+def test_array_hash() -> None:
+    a = tvm_ffi.Array([1, 2, 3])
+    b = tvm_ffi.Array([1, 2, 3])
+    assert hash(a) == hash(b)
+    # Usable in sets and as dict keys
+    s = {a, b}
+    assert len(s) == 1
+    d = {a: "value"}
+    assert d[b] == "value"
+
+
+def test_list_structural_eq() -> None:
+    a = tvm_ffi.List([1, 2, 3])
+    b = tvm_ffi.List([1, 2, 3])
+    c = tvm_ffi.List([1, 2, 4])
+    assert a == b
+    assert a != c
+
+
+def test_list_eq_empty() -> None:
+    assert tvm_ffi.List([]) == tvm_ffi.List([])
+
+
+def test_list_eq_not_implemented_for_unrelated() -> None:
+    a = tvm_ffi.List([1, 2, 3])
+    assert a.__eq__([1, 2, 3]) is NotImplemented
+
+
+def test_list_hash() -> None:
+    a = tvm_ffi.List([1, 2, 3])
+    b = tvm_ffi.List([1, 2, 3])
+    assert hash(a) == hash(b)
+
+
+def test_map_structural_eq() -> None:
+    a = tvm_ffi.Map({"x": 1, "y": 2})
+    b = tvm_ffi.Map({"x": 1, "y": 2})
+    c = tvm_ffi.Map({"x": 1, "y": 3})
+    assert a == b
+    assert a != c
+
+
+def test_map_eq_empty() -> None:
+    assert tvm_ffi.Map({}) == tvm_ffi.Map({})
+
+
+def test_map_eq_not_implemented_for_unrelated() -> None:
+    a = tvm_ffi.Map({"x": 1})
+    assert a.__eq__({"x": 1}) is NotImplemented
+
+
+def test_map_hash() -> None:
+    a = tvm_ffi.Map({"x": 1, "y": 2})
+    b = tvm_ffi.Map({"x": 1, "y": 2})
+    assert hash(a) == hash(b)
+    s = {a, b}
+    assert len(s) == 1
+
+
+def test_dict_structural_eq() -> None:
+    a = tvm_ffi.Dict({"x": 1, "y": 2})
+    b = tvm_ffi.Dict({"x": 1, "y": 2})
+    c = tvm_ffi.Dict({"x": 1, "y": 3})
+    assert a == b
+    assert a != c
+
+
+def test_dict_eq_empty() -> None:
+    assert tvm_ffi.Dict({}) == tvm_ffi.Dict({})
+
+
+def test_dict_eq_not_implemented_for_unrelated() -> None:
+    a = tvm_ffi.Dict({"x": 1})
+    assert a.__eq__({"x": 1}) is NotImplemented
+
+
+def test_dict_hash() -> None:
+    a = tvm_ffi.Dict({"x": 1, "y": 2})
+    b = tvm_ffi.Dict({"x": 1, "y": 2})
+    assert hash(a) == hash(b)

Reply via email to