This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9acbf4ae6b [BugFix][Relax] Add structural_equal verification to subroutine cache lookup (#18962)
9acbf4ae6b is described below

commit 9acbf4ae6b075d441b1ed7f9bd2f37a64521aec2
Author: 3em0 <[email protected]>
AuthorDate: Wed Apr 8 11:18:05 2026 +0800

    [BugFix][Relax] Add structural_equal verification to subroutine cache lookup (#18962)
    
    ## Summary
    
    - `SubroutineMixin._get_subroutine()` used `structural_hash` as the sole
    cache key without `structural_equal` verification. If two different
    `arg_sinfo` values produced the same 64-bit hash (collision), the cache
    would return a previously compiled function with mismatched parameter
    shapes, leading to silently incorrect compiled output.
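    - Changed the cache to store a list of `(arg_sinfo, result)` pairs per
    hash bucket and verify with `structural_equal` on lookup, consistent
    with the pattern in `block_builder.cc`. The lookup key now also
    includes `old_forward`, so cache entries are additionally partitioned
    by it.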
    - Added a security advisory document and regression test.
    
    ## Root Cause
    
    The subroutine cache (`cls._gvar`) was keyed by
    `(structural_hash(arg_sinfo, map_free_vars=True), is_dataflow)`. A hash
    match was treated as proof of structural equality, skipping the
    necessary `structural_equal` check. This is a hash-only lookup
    anti-pattern — the hash selects the bucket, but equality must confirm
    the match.
    
    For comparison, `block_builder.cc` correctly uses `StructuralHash` +
    `StructuralEqual` together as the hash and equality functions for
    `std::unordered_map`.
    
    ## Test plan
    
    - [ ] Existing test `test_linear` passes (no regression)
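    - [ ] New test `test_different_shapes_produce_distinct_subroutines`
    passes — verifies that the same Module class with different input shapes
    generates distinct subroutines. Note that the two shapes used almost
    certainly hash to different values, so this test exercises the ordinary
    cache path and does not force a hash collision; reproducing the
    collision itself would require stubbing `ir.structural_hash` to return
    a constant.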
    
    🤖 Generated with [Claude Code](https://claude.com/claude-code)
    
    ---------
    
    Co-authored-by: [email protected] <[email protected]>
    Co-authored-by: Claude Opus 4.6 (1M context) <[email protected]>
    Co-authored-by: gemini-code-assist[bot] 
<176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
 docs/reference/security.rst                        | 29 ++++++++++++
 python/tvm/relax/frontend/nn/subroutine.py         | 14 ++++--
 tests/python/relax/test_frontend_nn_subroutines.py | 53 ++++++++++++++++++++++
 3 files changed, 91 insertions(+), 5 deletions(-)

diff --git a/docs/reference/security.rst b/docs/reference/security.rst
index 8380b6c15a..de3ebf464d 100644
--- a/docs/reference/security.rst
+++ b/docs/reference/security.rst
@@ -49,3 +49,32 @@ we expect users to put in only trusted URLs.
 
 RPC data exchange between the tracker, server and client are in plain-text.
 It is recommended to use them under trusted networking environment or 
encrypted channels.
+
+
+Security Advisories
+-------------------
+
+Subroutine Cache Hash Collision
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+``SubroutineMixin._get_subroutine()`` in ``python/tvm/relax/frontend/nn/subroutine.py``
+used ``ir.structural_hash`` as the sole cache lookup key without a subsequent
+``structural_equal`` verification. If two different ``arg_sinfo`` values produced the
+same 64-bit hash, the cache would return a previously compiled function with
+mismatched parameter shapes, leading to silently incorrect compiled output.
+
+**Severity**: Low. The ``structural_hash`` function returns a 64-bit integer.
+By the birthday bound, a natural hash collision becomes likely only after roughly
+2^32 distinct inputs, making accidental collision extremely unlikely in normal
+compilation workflows. The issue is primarily a correctness defect rather than a
+practically exploitable security vulnerability.
+
+**Root Cause**: The subroutine cache (``cls._gvar``) was keyed by
+``(structural_hash(arg_sinfo, map_free_vars=True), is_dataflow)``.
+A hash match was treated as proof of structural equality, skipping the necessary
+``structural_equal`` check.
+
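+**Fix**: The cache now stores a list of ``(arg_sinfo, result)`` pairs per hash bucket.
+On lookup, each candidate is verified with ``structural_equal`` (using
+``map_free_vars=True``, matching the hash) before it is returned.
+This follows the standard hash-table pattern: hash for bucket selection, equality
+for final verification. The lookup key now also includes ``old_forward``, so cache
+entries are additionally partitioned by it.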
diff --git a/python/tvm/relax/frontend/nn/subroutine.py 
b/python/tvm/relax/frontend/nn/subroutine.py
index c62491be9f..0716f19a92 100644
--- a/python/tvm/relax/frontend/nn/subroutine.py
+++ b/python/tvm/relax/frontend/nn/subroutine.py
@@ -25,6 +25,7 @@ import re
 import typing
 
 from tvm import ir, relax
+from tvm.ir import structural_equal
 from tvm.relax.frontend import nn
 
 
@@ -141,10 +142,11 @@ class SubroutineMixin:
 
         arg_sinfo = _get_struct_info([*func_args.values(), *model_params])
         is_dataflow = block_builder.current_block_is_dataflow()
-        lookup_key = (ir.structural_hash(arg_sinfo, map_free_vars=True), 
is_dataflow)
+        lookup_key = (old_forward, ir.structural_hash(arg_sinfo, 
map_free_vars=True), is_dataflow)
 
-        if lookup_key in cls._gvar:
-            return cls._gvar[lookup_key]
+        for cached_sinfo, cached_result in cls._gvar.get(lookup_key, []):
+            if structural_equal(cached_sinfo, arg_sinfo, map_free_vars=True):
+                return cached_result
 
         func_name = _camel_to_snake(cls.__name__)
         func_params = [relax.Var(name, sinfo) for name, sinfo in 
zip(func_args, arg_sinfo.fields)]
@@ -175,5 +177,7 @@ class SubroutineMixin:
         mod = block_builder.get()
         mod.update_func(gvar, relax.utils.copy_with_new_vars(mod[gvar]))
 
-        cls._gvar[lookup_key] = (gvar, is_nn_tensor_output)
-        return cls._gvar[lookup_key]
+        result = (gvar, is_nn_tensor_output)
+        bucket = cls._gvar.setdefault(lookup_key, [])
+        bucket.append((arg_sinfo, result))
+        return result
diff --git a/tests/python/relax/test_frontend_nn_subroutines.py 
b/tests/python/relax/test_frontend_nn_subroutines.py
index 9ea44781b8..a06fa05c77 100644
--- a/tests/python/relax/test_frontend_nn_subroutines.py
+++ b/tests/python/relax/test_frontend_nn_subroutines.py
@@ -97,5 +97,58 @@ def test_linear():
     assert_structural_equal(Expected, tvm_mod, True)
 
 
+def test_different_shapes_produce_distinct_subroutines():
+    """Regression test: same Module class with different input shapes
+    must generate distinct subroutines, not reuse a cached one."""
+
+    class Linear(nn.Module):
+        define_subroutine = True
+
+        def __init__(self, in_features, out_features):
+            self.weights = nn.Parameter((in_features, out_features), 
dtype="float32")
+
+        def forward(self, input: relax.Expr) -> relax.Var:
+            return nn.op.matmul(input, self.weights)
+
+    class Model(nn.Module):
+        def __init__(self):
+            self.linear_a = Linear(32, 16)
+            self.linear_b = Linear(64, 16)
+
+        def forward(self, x: relax.Expr, y: relax.Expr) -> relax.Var:
+            a = self.linear_a(x)
+            b = self.linear_b(y)
+            return nn.op.add(a, b)
+
+    mod = Model()
+    batch_size = tvm.tirx.Var("batch_size", "int64")
+    tvm_mod, _ = mod.export_tvm(
+        spec={
+            "forward": {
+                "x": nn.spec.Tensor((batch_size, 32), "float32"),
+                "y": nn.spec.Tensor((batch_size, 64), "float32"),
+            }
+        },
+        debug=True,
+    )
+
+    # Collect all private functions (subroutines) in the module
+    subroutine_funcs = [
+        func
+        for gvar, func in tvm_mod.functions.items()
+        if isinstance(func, relax.Function)
+        and gvar.name_hint not in (
+            "forward",
+            "_initialize_effect",
+        )
+    ]
+
+    # There must be two distinct Linear subroutines (one for in_features=32,
+    # one for in_features=64), not a single cached one reused for both.
+    assert len(subroutine_funcs) == 2, (
+        f"Expected 2 distinct subroutines for different input shapes, got 
{len(subroutine_funcs)}"
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to