This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 643cf60f40 [Relax][TFLite] Add test coverage for Reduction operations (#18971) (#19370)
643cf60f40 is described below

commit 643cf60f40f330173299ef1484abc9f650dbf798
Author: Bana <[email protected]>
AuthorDate: Wed Apr 8 21:38:24 2026 +0300

    [Relax][TFLite] Add test coverage for Reduction operations (#18971) (#19370)
    
    Closes part of #18971
    
    ---
    
    ## Description
    This PR improves the test coverage of the TFLite frontend to Relax
    converter by adding comprehensive tests for the **"Reductions" group.**
    
    **Operations covered:**
    * `SUM` (`tf.reduce_sum`)
    * `MEAN` (`tf.reduce_mean`)
    * `REDUCE_MAX` (`tf.reduce_max`)
    * `REDUCE_MIN` (`tf.reduce_min`)
    * `REDUCE_PROD` (`tf.reduce_prod`)
    
    ## Changes made:
    * Added a parameterized testing function (`test_reduction_ops`) to cover
    combinations of the aforementioned reduction operators.
    * Covered a variety of axis configurations, including positive scalars,
    lists of axes, negative indices, and `None` (global reduction).
    * Tested with different combinations of the `keepdims` flag
    (`True`/`False`) and dtypes (`float32`/`int32`).
    * Handled global reductions (`axis=None`) in the `_make_reduce_expected`
    utility, which builds the expected Relax IRModule, by expanding `None`
    to all input axes (`list(range(len(input_shape)))`), mirroring the
    frontend's output graph structure.
    
    ## Testing
    ```bash
    pytest tests/python/relax/test_frontend_tflite.py::test_reduction_ops
    ```
    <img width="931" height="42" alt="Screenshot 2026-04-08 102859"
    src="https://github.com/user-attachments/assets/f666199b-b839-42e7-a6ad-2753b92b45b0"
    />
---
 tests/python/relax/test_frontend_tflite.py | 48 ++++++++++++++++++++++++++++++
 1 file changed, 48 insertions(+)

diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 275e162b81..1f314e7d57 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1254,5 +1254,53 @@ def test_resize_nearest_neighbor(input_shape, output_size, tf_op, coordinate_tra
     verify(ResizeNearest, expected)
 
 
+def _make_reduce_expected(relax_op, input_shape, axes, keepdims, dtype):
+    if axes is None:
+        axes = list(range(len(input_shape)))
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", relax.TensorStructInfo(input_shape, dtype))
+    with bb.function("main", [x]):
+        with bb.dataflow():
+            gv = bb.emit_output(relax_op(x, axis=axes, keepdims=keepdims))
+        bb.emit_func_output(gv)
+    mod = bb.get()
+    mod["main"] = mod["main"].with_attr("num_input", 1)
+    return mod
+
+
[email protected](
+    "tf_op, relax_op",
+    [
+        (tf.reduce_sum, relax.op.sum),
+        (tf.reduce_mean, relax.op.mean),
+        (tf.reduce_max, relax.op.max),
+        (tf.reduce_min, relax.op.min),
+        (tf.reduce_prod, relax.op.prod),
+    ],
+)
[email protected](
+    "input_shape, axes",
+    [
+        ((1, 8, 8, 3), 1),
+        ((1, 8, 8, 3), [1, 2]),
+        ((1, 8, 8, 3), -1),
+        ((1, 8, 8, 3), None),
+        ((30,), 0),
+        ((2, 5, 2), [0, 2]),
+    ],
+)
[email protected]("keepdims", [True, False])
[email protected]("dtype", [tf.float32, tf.int32])
+def test_reduction_ops(tf_op, relax_op, input_shape, axes, keepdims, dtype):
+    class ReduceModule(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=input_shape, dtype=dtype)])
+        def func(self, x):
+            return tf_op(x, axis=axes, keepdims=keepdims)
+
+    relax_dtype = "float32" if dtype == tf.float32 else "int32"
+    expected = _make_reduce_expected(relax_op, input_shape, axes, keepdims, relax_dtype)
+    verify(ReduceModule, expected)
+
+
 if __name__ == "__main__":
     pytest.main(["-s", __file__])

Reply via email to