This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e5d4c557e0 [relax][frontend][tflite] Add tests for fully_connected/depthwise_conv2d/transpose_conv/l2_pool2d (#19372)
e5d4c557e0 is described below

commit e5d4c557e0481d6cce7825aca5bad5e875d21acf
Author: Ahmad Jahaf <[email protected]>
AuthorDate: Sun Apr 12 05:05:09 2026 +0300

    [relax][frontend][tflite] Add tests for fully_connected/depthwise_conv2d/transpose_conv/l2_pool2d (#19372)
    
    This PR adds Relax TFLite frontend test coverage for:
    - FULLY_CONNECTED
    - DEPTHWISE_CONV_2D
    - TRANSPOSE_CONV
    - L2_POOL_2D
    
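    It also removes the `channels` and `kernel_size` keyword arguments from the
    IOHW-kernel convolution call in tflite_frontend.py, which the Relax op does
    not accept.
    
    Part of fixing #18971.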
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   |  2 -
 tests/python/relax/test_frontend_tflite.py         | 83 +++++++++++++++++++++-
 2 files changed, 82 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 8b2f70a0f5..9c99e98e01 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -3120,8 +3120,6 @@ class OperatorConverter:
                 weight_expr_iohw,
                 strides=(stride_h, stride_w),
                 padding=padding,
-                channels=int(out_channels),
-                kernel_size=(int(kernel_h), int(kernel_w)),
                 data_layout="NHWC",
                 kernel_layout="IOHW",
                 out_dtype=output_tensor_type_str,
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 195d2f5542..58af46cbc9 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -710,6 +710,88 @@ def test_reduce(tf_op, relax_op, axis, out_shape):
     verify(TfInput, Expected)
 
 
+def test_fully_connected():
+    class FullyConnected(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(1, 8), dtype=tf.float32)])
+        def func(self, x):
+            weight = tf.constant(np.arange(24, dtype=np.float32).reshape((3, 8)))
+            bias = tf.constant(np.array([0.5, 1.0, -1.0], dtype=np.float32))
+            out = tf.matmul(x, weight, transpose_b=True)
+            return tf.nn.bias_add(out, bias)
+
+    verify(FullyConnected)
+
+
+def test_depthwise_conv2d():
+    class DepthwiseConv2D(tf.Module):
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32),
+                tf.TensorSpec(shape=(3, 3, 2, 1), dtype=tf.float32),
+            ]
+        )
+        def func(self, data, kernel):
+            return tf.nn.depthwise_conv2d(
+                input=data,
+                filter=kernel,
+                strides=[1, 1, 1, 1],
+                padding="SAME",
+            )
+
+    verify(DepthwiseConv2D)
+
+
+def test_transpose_conv():
+    class TransposeConv(tf.Module):
+        @tf.function(
+            input_signature=[
+                tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32),
+                tf.TensorSpec(shape=(3, 3, 3, 2), dtype=tf.float32),
+            ]
+        )
+        def func(self, data, kernel):
+            output_shape = tf.constant([1, 8, 8, 3], dtype=tf.int32)
+            return tf.nn.conv2d_transpose(
+                input=data,
+                filters=kernel,
+                output_shape=output_shape,
+                strides=[1, 1, 1, 1],
+                padding="SAME",
+            )
+
+    verify(TransposeConv)
+
+def test_l2_pool2d():
+    class L2Pool2D(tf.Module):
+        @tf.function(input_signature=[tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32)])
+        def func(self, data):
+            squared = tf.math.square(data)
+            pooled = tf.nn.avg_pool2d(squared, ksize=[2, 2], strides=[1, 1], padding="SAME")
+            return tf.math.sqrt(pooled)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            data: R.Tensor((1, 8, 8, 2), dtype="float32")
+        ) -> R.Tensor((1, 8, 8, 2), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                squared = R.power(data, R.const(2.0, "float32"))
+                pooled = R.nn.avg_pool2d(
+                    squared,
+                    pool_size=[2, 2],
+                    strides=[1, 1],
+                    padding=[0, 0, 1, 1],
+                    layout="NHWC",
+                )
+                gv = R.sqrt(pooled)
+                R.output(gv)
+            return gv
+
+    verify(L2Pool2D, Expected)
+
+
 def test_l2_normalization():
     class L2Normalization(tf.Module):
         @tf.function(input_signature=[tf.TensorSpec(shape=(2, 4), dtype=tf.float32)])
@@ -758,7 +840,6 @@ def test_reverse_v2():
 
     verify(ReverseV2, Expected)
 
-
 def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding):
     class Conv2DModule(tf.Module):
         @tf.function(

Reply via email to