This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5efa4b72dc [Test][TFLite] Add unit tests for `PRELU` (#19402)
5efa4b72dc is described below
commit 5efa4b72dca7c33b5e4f6658720ba2e15dae28be
Author: Felix Hirwa Nshuti <[email protected]>
AuthorDate: Tue Apr 14 07:55:38 2026 +0200
[Test][TFLite] Add unit tests for `PRELU` (#19402)
This PR adds unit test coverage for `PRELU` activation
in the Relax TFLite frontend, as part of
https://github.com/apache/tvm/issues/18971
- Added a unit test for `PRELU` and
enabled the converter to handle alpha broadcasting more cleanly across
constant and expression-backed alpha inputs.
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 24 ++------
tests/python/relax/test_frontend_tflite.py | 70 ++++++++++++++++++++--
2 files changed, 70 insertions(+), 24 deletions(-)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 9c99e98e01..584a65e1f4 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -2259,7 +2259,7 @@ class OperatorConverter:
# Create axes list for all dimensions being sliced
axes = list(range(input_tensor_rank))
begin = [int(v) for v in begin]
- end = [int(v) for v in end]
+ end = [int(v) for v in end]
out = relax.op.strided_slice(in_expr, axes=axes, begin=begin, end=end)
return out
@@ -2840,9 +2840,7 @@ class OperatorConverter:
new_b_shape = [1] * max(0, rank_a - rank_b) + [int(s) for s in
shape_b]
max_rank = max(rank_a, rank_b)
- batch_shape = [
- max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank -
2)
- ]
+ batch_shape = [max(new_a_shape[i], new_b_shape[i]) for i in
range(max_rank - 2)]
a_broadcast = batch_shape + [int(shape_a[-2]), int(shape_a[-1])]
b_broadcast = batch_shape + [int(shape_b[-2]), int(shape_b[-1])]
@@ -2987,21 +2985,11 @@ class OperatorConverter:
input_tensor = input_tensors[0]
alpha_tensor = input_tensors[1]
- if self.has_expr(alpha_tensor.tensor_idx):
- alpha_expr = self.get_expr(alpha_tensor.tensor_idx)
- else:
- alpha_tensor_type = alpha_tensor.tensor.Type()
- alpha_tensor_type_str = self.get_tensor_type_str(alpha_tensor_type)
- alpha_expr = self.exp_tab.new_const(
- self.get_tensor_value(alpha_tensor),
- dtype=alpha_tensor_type_str,
- source_name=alpha_tensor.tensor.Name(),
- )
- in_expr = self.get_expr(input_tensor.tensor_idx)
data_shape = to_int_list(self.get_tensor_shape(input_tensor))
-
- alpha_expr = relax.op.broadcast_to(alpha_expr, data_shape)
- alpha_expr = relax.op.reshape(alpha_expr, [-1])
+ alpha_expr = self.get_tensor_expr(alpha_tensor)
+ alpha_expr = self.bb.normalize(relax.op.broadcast_to(alpha_expr,
data_shape))
+ alpha_expr = self.bb.normalize(relax.op.reshape(alpha_expr, [-1]))
+ in_expr = self.get_tensor_expr(input_tensor)
out = relax.op.nn.prelu(_op.reshape(in_expr, [-1]), alpha_expr, axis=0)
out = relax.op.reshape(out, data_shape)
return out
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index 37a6b9cd93..bf6ef8e819 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -322,6 +322,7 @@ def test_tile(input_shape, multiples, dtype):
verify(Tile)
+
def test_concat_v2():
class ConcatV2(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=(1, 30),
dtype=tf.float32)])
@@ -804,6 +805,7 @@ def test_transpose_conv():
verify(TransposeConv)
+
def test_l2_pool2d():
class L2Pool2D(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=(1, 8, 8, 2),
dtype=tf.float32)])
@@ -815,9 +817,9 @@ def test_l2_pool2d():
@I.ir_module
class Expected:
@R.function
- def main(
- data: R.Tensor((1, 8, 8, 2), dtype="float32")
- ) -> R.Tensor((1, 8, 8, 2), dtype="float32"):
+ def main(data: R.Tensor((1, 8, 8, 2), dtype="float32")) -> R.Tensor(
+ (1, 8, 8, 2), dtype="float32"
+ ):
R.func_attr({"num_input": 1})
with R.dataflow():
squared = R.power(data, R.const(2.0, "float32"))
@@ -883,6 +885,7 @@ def test_reverse_v2():
verify(ReverseV2, Expected)
+
def _make_conv2d_module(data_shape, kernel_shape, data_format, strides,
padding):
class Conv2DModule(tf.Module):
@tf.function(
@@ -1590,9 +1593,7 @@ _DETECTION_POSTPROCESS_SHAPE_CASES = [
"build_kwargs,expected_topk_count,expected_keep_background",
_DETECTION_POSTPROCESS_SMOKE_CASES,
)
-def test_detection_postprocess_smoke(
- build_kwargs, expected_topk_count, expected_keep_background
-):
+def test_detection_postprocess_smoke(build_kwargs, expected_topk_count,
expected_keep_background):
mod = _build_detection_postprocess_mod(**build_kwargs)
ir = mod.script()
@@ -1649,6 +1650,7 @@ def
test_detection_postprocess_shape_variations(build_kwargs):
),
)
+
def _make_resize_expected(
input_shape, output_size, method, coordinate_transformation_mode,
rounding_method
):
@@ -2109,5 +2111,61 @@ def test_relu_n1_to_1():
verify(ReLU_N1_to_1, Expected)
[email protected](
+ "shared_axes",
+ [
+ pytest.param([1, 2], id="channelwise_shared_axes"),
+ pytest.param([1, 2, 3], id="scalar_shared_axes"),
+ pytest.param(None, id="elementwise_no_shared_axes"),
+ ],
+)
+def test_prelu(shared_axes):
+ inputs = tf.keras.Input(shape=(4, 4, 3), batch_size=1, dtype=tf.float32)
+ prelu_kwargs = {
+ "alpha_initializer": tf.initializers.constant(0.25),
+ }
+ if shared_axes is not None:
+ prelu_kwargs["shared_axes"] = shared_axes
+ outputs = tf.keras.layers.PReLU(**prelu_kwargs)(inputs)
+ keras_model = tf.keras.Model(inputs, outputs)
+
+ converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+ tflite_model_buf = converter.convert()
+ if hasattr(tflite.Model, "Model"):
+ tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
+ else:
+ tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
+
+ mod = from_tflite(tflite_model)
+ mod["main"] = mod["main"].without_attr("params")
+
+ if shared_axes == [1, 2]:
+ alpha_const = np.full((1, 1, 3), 0.25, dtype=np.float32)
+ elif shared_axes == [1, 2, 3]:
+ alpha_const = np.full((1, 1, 1), 0.25, dtype=np.float32)
+ else:
+ alpha_const = np.full((4, 4, 3), 0.25, dtype=np.float32)
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((1, 4, 4, 3), dtype="float32")) -> R.Tensor(
+ (1, 4, 4, 3), dtype="float32"
+ ):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ lv: R.Tensor((1, 4, 4, 3), dtype="float32") = R.broadcast_to(
+ R.const(alpha_const), R.shape([1, 4, 4, 3])
+ )
+ lv1: R.Tensor((48,), dtype="float32") = R.reshape(x,
R.shape([48]))
+ lv2: R.Tensor((48,), dtype="float32") = R.reshape(lv,
R.shape([48]))
+ lv3: R.Tensor((48,), dtype="float32") = R.nn.prelu(lv1, lv2,
axis=0)
+ gv: R.Tensor((1, 4, 4, 3), dtype="float32") = R.reshape(lv3,
R.shape([1, 4, 4, 3]))
+ R.output(gv)
+ return gv
+
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
pytest.main(["-s", __file__])