This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9ba27ff271 [relax][frontend][tflite] Add tests for l2_normalization/slice/reverse_v2 (#19371)
9ba27ff271 is described below

commit 9ba27ff271d3b62f665d4bdac7bbb04f54607118
Author: Ahmad Jahaf <[email protected]>
AuthorDate: Sat Apr 11 20:33:58 2026 +0300

    [relax][frontend][tflite] Add tests for l2_normalization/slice/reverse_v2 (#19371)
    
    This PR adds Relax TFLite frontend test coverage for:
    - L2_NORMALIZATION
    - SLICE
    - REVERSE_V2
    
    Part of fixing #18971.
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 20 ++++++---
 tests/python/relax/test_frontend_tflite.py         | 49 ++++++++++++++++++++++
 2 files changed, 63 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 16d5cb636b..d7b56e597b 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -800,7 +800,12 @@ class OperatorConverter:
             )
 
         # TFL uses only the default epsilon value
-        out = relax.op.nn.l2_normalize(in_expr, eps=1e-12, axis=[input_tensor_rank - 1])
+        # Implement L2 normalization: output = input / sqrt(sum(input^2) + eps)
+        # L2 normalization is applied along the last axis
+        squared = relax.op.square(in_expr)
+        sum_squared = relax.op.sum(squared, axis=input_tensor_rank - 1, keepdims=True)
+        denom = relax.op.sqrt(relax.op.add(sum_squared, relax.const(1e-12, "float32")))
+        out = relax.op.divide(in_expr, denom)
 
         # if we have fused activation fn
         if output_tensor.qnn_params:
@@ -2251,8 +2256,11 @@ class OperatorConverter:
             else:
                 end[i] += begin[i]
 
-        out = relax.op.strided_slice(in_expr, begin, end)
-
+        # Create axes list for all dimensions being sliced
+        axes = list(range(input_tensor_rank))
+        begin = [int(v) for v in begin]
+        end   = [int(v) for v in end]
+        out = relax.op.strided_slice(in_expr, axes=axes, begin=begin, end=end)
         return out
 
     def convert_select(self, op):
@@ -3555,7 +3563,7 @@ class OperatorConverter:
         axis = self.get_tensor_value(input_tensors[1])
         if isinstance(axis, np.ndarray):
             assert axis.size == 1, "only one value is expected."
-            axis = int(axis)
+            axis = int(axis.flat[0])
 
         ndims = len(input_tensors[0].tensor.ShapeAsNumpy())
         assert -1 - ndims <= axis <= ndims, "axis out of range"
@@ -3628,9 +3636,9 @@ class OperatorConverter:
         axis = self.get_tensor_value(input_tensors[1])
         if isinstance(axis, np.ndarray):
             assert len(axis) == 1, "TFLite does not support multi-axis yet"
-            axis = int(axis)
+            axis = int(axis.flat[0])
 
-        out = relax.op.reverse(input_expr, axis)
+        out = relax.op.flip(input_expr, axis)
         return out
 
     def convert_matrix_set_diag(self, op):
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index c237d4db8f..c9a8470f42 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -709,6 +709,55 @@ def test_reduce(tf_op, relax_op, axis, out_shape):
     verify(TfInput, Expected)
 
 
+def test_l2_normalization():
+    class L2Normalization(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.float32)])
+        def func(self, x):
+            return tf.nn.l2_normalize(x, axis=-1)
+
+    verify(L2Normalization)
+
+
+def test_slice():
+    class Slice(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(3, 4), dtype=tf.float32)])
+        def func(self, x):
+            return tf.slice(x, begin=[1, 1], size=[2, 2])
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((3, 4), dtype="float32")) -> R.Tensor((2, 2), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((2, 2), dtype="float32") = R.strided_slice(
+                    x, axes=[0, 1], begin=[1, 1], end=[3, 3]
+                )
+                R.output(gv)
+            return gv
+
+    verify(Slice, Expected)
+
+
+def test_reverse_v2():
+    class ReverseV2(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(2, 3), dtype=tf.float32)])
+        def func(self, x):
+            return tf.reverse(x, axis=[1])
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                gv: R.Tensor((2, 3), dtype="float32") = R.flip(x, axis=1)
+                R.output(gv)
+            return gv
+
+    verify(ReverseV2, Expected)
+
+
 def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding):
     class Conv2DModule(tf.Module):
         @tf.function(

Reply via email to