gemini-code-assist[bot] commented on code in PR #18663:
URL: https://github.com/apache/tvm/pull/18663#discussion_r2694182595
##########
tests/python/relax/test_transform_static_plan_block_memory.py:
##########
@@ -1018,6 +1018,245 @@ def main(x: R.Tensor((2, "n"), dtype="float32")) ->
R.Tensor(("2 * n + 2",), dty
tvm.ir.assert_structural_equal(mod, Expected)
+def test_lower_bound_only():
+ # fmt: off
+ @tvm.script.ir_module
+ class Module:
+ @T.prim_func
+ def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add:
T.handle):
+ T.evaluate(0)
+
+ @T.prim_func
+ def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+ T.evaluate(0)
+
+ @T.prim_func
+ def relu(rxplaceholder: T.handle, compute: T.handle):
+ T.evaluate(0)
+
+ @T.prim_func
+ def log(rxplaceholder: T.handle, compute: T.handle):
+ T.evaluate(0)
+
+ @T.prim_func
+ def exp(rxplaceholder: T.handle, compute: T.handle):
+ T.evaluate(0)
+
+ @T.prim_func
+ def pad(rxplaceholder: T.handle, PadInput: T.handle):
+ T.evaluate(0)
+
+ @R.function
+ def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n +
2",), dtype="float32"):
+ R.func_attr({"tir_var_lower_bound": {"n": 2}, "relax.force_pure":
True})
+ n = T.int64()
+ cls = Module
+ alloc: R.Tensor((2, n), dtype="float32") =
R.builtin.alloc_tensor(R.shape([2, n]), dtype="float32", runtime_device_index=0)
+ _: R.Tuple() = cls.exp(x, alloc)
+ lv: R.Tensor((2, n), dtype="float32") = alloc
+ lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, (2 * n,))
+ alloc1: R.Tensor((2 * n,), dtype="float32") =
R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32",
runtime_device_index=0)
+ _1: R.Tuple() = cls.relu(lv1, alloc1)
+ lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+ alloc2: R.Tensor((2 * n,), dtype="float32") =
R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32",
runtime_device_index=0)
+ _2: R.Tuple() = cls.add(lv2, R.const(1, "float32"), alloc2)
+ lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+ alloc3: R.Tensor((2 * n + 2,), dtype="float32") =
R.builtin.alloc_tensor(R.shape([2 * n + 2]), dtype="float32",
runtime_device_index=0)
+ _3: R.Tuple() = cls.pad(lv3, alloc3)
+ lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+ alloc4: R.Tensor((2 * n + 2,), dtype="float32") =
R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0)
+ _4: R.Tuple() = cls.log(lv4, alloc4)
+ gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add:
T.handle):
+ T.evaluate(0)
+
+ @T.prim_func
+ def exp(rxplaceholder: T.handle, compute: T.handle):
+ T.evaluate(0)
+
+ @T.prim_func
+ def log(rxplaceholder: T.handle, compute: T.handle):
+ T.evaluate(0)
+
+ @T.prim_func
+ def pad(rxplaceholder: T.handle, PadInput: T.handle):
+ T.evaluate(0)
+
+ @T.prim_func
+ def relu(rxplaceholder: T.handle, compute: T.handle):
+ T.evaluate(0)
+
+ @T.prim_func
+ def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+ T.evaluate(0)
+
+ @R.function
+ def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n +
2",), dtype="float32"):
+ n = T.int64()
+ R.func_attr({"tir_var_lower_bound": {"n": 2}, "relax.force_pure":
True})
+ cls = Expected
+ storage: R.Object = R.memory.alloc_storage(R.shape([8 * n]),
R.prim_value(0), R.str("global"), R.dtype("float32"))
+ alloc: R.Tensor((2, n), dtype="float32") =
R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, n]),
R.dtype("float32"), R.prim_value(0))
+ _: R.Tuple = cls.exp(x, alloc)
+ lv: R.Tensor((2, n), dtype="float32") = alloc
+ lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv,
R.shape([2 * n]))
+ storage1: R.Object = R.memory.alloc_storage(R.shape([4 * (2 *
n)]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+ alloc1: R.Tensor((2 * n,), dtype="float32") =
R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2 * n]),
R.dtype("float32"))
+ _1: R.Tuple = cls.relu(lv1, alloc1)
+ lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+ alloc2: R.Tensor((2 * n,), dtype="float32") =
R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2 * n]),
R.dtype("float32"))
+ _2: R.Tuple = cls.add(lv2, R.const(1, "float32"), alloc2)
+ lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+ storage2: R.Object = R.memory.alloc_storage(R.shape([4 * (2 * n +
2)]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+ alloc3: R.Tensor((2 * n + 2,), dtype="float32") =
R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([2 * n + 2]),
R.dtype("float32"), R.prim_value(0))
+ _3: R.Tuple = cls.pad(lv3, alloc3)
+ lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+ alloc4: R.Tensor((2 * n + 2,), dtype="float32") =
R.builtin.alloc_tensor(R.shape([10]), R.dtype("float32"), R.prim_value(0))
+ _4: R.Tuple = cls.log(lv4, alloc4)
+ gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+ return gv
+ # fmt: on
+
+ mod = relax.transform.StaticPlanBlockMemory()(Module)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
Review Comment:

The test functions `test_lower_bound_only` and `test_upper_and_lower_bounds`
are very similar and contain a lot of duplicated code. Consider refactoring
them into a single parameterized test using `pytest.mark.parametrize` to
improve maintainability.
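You could parameterize the function attributes and the expected memory
allocation logic. The expected sizes are static when an upper bound on `n` is
available, and symbolic (`n`-dependent) when only a lower bound is given.
Here is a conceptual example: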
```python
import pytest
@pytest.mark.parametrize(
    "func_attr, is_static",
    [
        ({"tir_var_lower_bound": {"n": 2}, "relax.force_pure": True}, False),
        (
            {"tir_var_upper_bound": {"n": 4}, "tir_var_lower_bound": {"n": 2}, "relax.force_pure": True},
            True,
        ),
    ],
)
def test_bounds(func_attr, is_static):
    @tvm.script.ir_module
    class Module:
        # ... (common module definition)
        @R.function
        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
            R.func_attr(func_attr)
            # ... (common function body)
    @I.ir_module
    class Expected:
        # ... (common primfuncs)
        @R.function
        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
            n = T.int64()
            R.func_attr(func_attr)
            cls = Expected
            if is_static:
                storage: R.Object = R.memory.alloc_storage(R.shape([32]), ...)
                # ... static allocation logic
            else:
                storage: R.Object = R.memory.alloc_storage(R.shape([8 * n]), ...)
                # ... dynamic allocation logic
    mod = relax.transform.StaticPlanBlockMemory()(Module)
    tvm.ir.assert_structural_equal(mod, Expected)
```
This approach would make the tests more concise and easier to maintain.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]