This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 28aac4744d [Relax][TFLite] Add NON_MAX_SUPPRESSION_V5 support (#19349)
28aac4744d is described below
commit 28aac4744dbc1153b8f9f937616294720533026e
Author: Bana <[email protected]>
AuthorDate: Sat Apr 4 22:26:47 2026 +0300
[Relax][TFLite] Add NON_MAX_SUPPRESSION_V5 support (#19349)
Partially resolves #18928
To run the tests:
```bash
pytest tests/python/relax/test_frontend_tflite.py -v -k "nms"
```
<img width="728" height="44" alt="image"
src="https://github.com/user-attachments/assets/fbd4092a-bc7f-459d-8d51-b0ef926b241f"
/>
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 36 +++--
tests/python/relax/test_frontend_tflite.py | 162 +++++++++++++++++++++
2 files changed, 182 insertions(+), 16 deletions(-)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 435180dfee..1a437093a7 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -3339,16 +3339,12 @@ class OperatorConverter:
def convert_nms_v5(self, op):
"""Convert TFLite NonMaxSuppressionV5"""
- #
https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/non-max-suppression-v5
- raise NotImplementedError(
- "NON_MAX_SUPPRESSION_V5 is not wired in this frontend yet (needs
get_valid_counts, "
- "non_max_suppression, etc.). Tracking:
https://github.com/apache/tvm/issues/18928"
- )
-
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 6, "input tensor length should be 6"
- boxes = self.get_expr(input_tensors[0].tensor_idx)
- scores = self.get_expr(input_tensors[1].tensor_idx)
+
+ boxes = self.get_tensor_expr(input_tensors[0])
+ scores = self.get_tensor_expr(input_tensors[1])
+
max_output_size = self.get_tensor_value(input_tensors[2])
iou_threshold = self.get_tensor_value(input_tensors[3])
score_threshold = self.get_tensor_value(input_tensors[4])
@@ -3374,13 +3370,16 @@ class OperatorConverter:
"It is soft_nms when soft_nms_sigma != 0, which is not
supported!"
)
- scores_expand = relax.op.expand_dims(scores, axis=-1, num_newaxis=1)
- data = relax.op.concat([scores_expand, boxes], -1)
- data = relax.op.expand_dims(data, axis=0, num_newaxis=1)
+ scores_expand = relax.op.expand_dims(scores, axis=-1)
+ data = relax.op.concat([scores_expand, boxes], axis=-1)
+ data = relax.op.expand_dims(data, axis=0)
- count, data, indices = relax.op.vision.get_valid_counts(
+ valid_counts_ret = relax.op.vision.get_valid_counts(
data, score_threshold=score_threshold, id_index=-1, score_index=0
)
+ count = valid_counts_ret[0]
+ data = valid_counts_ret[1]
+ indices = valid_counts_ret[2]
nms_ret = relax.op.vision.non_max_suppression(
data=data,
@@ -3398,10 +3397,15 @@ class OperatorConverter:
)
selected_indices = relax.op.squeeze(nms_ret[0], axis=[0])
- selected_indices = relax.op.strided_slice(selected_indices, [0],
[max_output_size])
- valide_num = relax.op.squeeze(nms_ret[1], axis=[1])
- selected_scores = relax.op.take(scores, selected_indices, axis=0)
- out = _expr.TupleWrapper(_expr.Tuple([selected_indices,
selected_scores, valide_num]), 3)
+ selected_indices = relax.op.strided_slice(selected_indices, axes=[0],
begin=[0], end=[max_output_size])
+ num_valid = relax.op.reshape(nms_ret[1], [])
+
+ # Clamp out-of-bound padded indices to prevent take() crash.
+ num_boxes = int(self.get_tensor_shape(input_tensors[0])[0])
+ safe_indices = relax.op.clip(selected_indices, min=0, max=num_boxes -
1)
+ selected_scores = relax.op.take(scores, safe_indices, axis=0)
+
+ out = relax.Tuple([selected_indices, selected_scores, num_valid])
return out
def convert_expand_dims(self, op):
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 05ca3e4270..d5149fb161 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1021,5 +1021,167 @@ def test_batch_matmul_adj():
verify(BatchMatMulAdj, Expected)
+def _verify_nms_v5(mod, tf_func, boxes_np, scores_np):
+ """E2E verify for NMS: only run on nightly, compare valid outputs only."""
+ if "CI_ENV_NIGHTLY" not in os.environ:
+ return
+
+ tf_indices, tf_scores, tf_valid = tf_func(
+ tf.constant(boxes_np), tf.constant(scores_np)
+ )
+ n_valid = int(tf_valid.numpy())
+
+ tgt = tvm.target.Target("llvm")
+ ex = tvm.compile(mod, tgt)
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+ vm.set_input("main", boxes_np, scores_np)
+ vm.invoke_stateful("main")
+ tvm_indices, tvm_scores, tvm_valid = vm.get_outputs("main")
+
+ assert int(tvm_valid.numpy()) == n_valid
+ np.testing.assert_array_equal(
+ tf_indices.numpy()[:n_valid],
+ tvm_indices.numpy()[:n_valid],
+ )
+ np.testing.assert_allclose(
+ tf_scores.numpy()[:n_valid],
+ tvm_scores.numpy()[:n_valid],
+ rtol=1e-5,
+ atol=1e-5,
+ )
+
+
+def _build_nms_v5_mod(num_boxes, max_output_size, iou_threshold,
score_threshold):
+ """Convert a NonMaxSuppressionV5 TFLite model to a Relax module.
+
+ Scalar params must be Python literals (not tf.constant) so TFLite can
+ statically infer output shapes during conversion.
+ """
+
+ class NMSv5Module(tf.Module):
+ @tf.function(
+ input_signature=[
+ tf.TensorSpec(shape=(num_boxes, 4), dtype=tf.float32),
+ tf.TensorSpec(shape=(num_boxes,), dtype=tf.float32),
+ ]
+ )
+ def func(self, boxes, scores):
+ indices, out_scores, valid = tf.raw_ops.NonMaxSuppressionV5(
+ boxes=boxes,
+ scores=scores,
+ max_output_size=max_output_size,
+ iou_threshold=iou_threshold,
+ score_threshold=score_threshold,
+ soft_nms_sigma=0.0,
+ pad_to_max_output_size=True,
+ )
+ return indices, out_scores, valid
+
+ instance = NMSv5Module()
+ cf = instance.func.get_concrete_function()
+ mod = _get_mod_from_cfunc(cf)
+ return mod, instance.func
+
+
+def _make_valid_boxes(rng, n):
+ """Generate n random boxes with y1<=y2, x1<=x2 using the given RNG."""
+ raw = rng.random((n, 4), dtype=np.float32)
+ return np.stack(
+ [
+ np.minimum(raw[:, 0], raw[:, 2]), # y1
+ np.minimum(raw[:, 1], raw[:, 3]), # x1
+ np.maximum(raw[:, 0], raw[:, 2]), # y2
+ np.maximum(raw[:, 1], raw[:, 3]), # x2
+ ],
+ axis=1,
+ ).astype(np.float32)
+
+
+_NMS_V5_CASES = [
+ pytest.param(
+ 6, 3, 0.5, 0.0,
+ np.array([
+ [0.0, 0.0, 1.0, 1.0],
+ [0.0, 0.0, 1.0, 1.0],
+ [0.0, 0.1, 1.0, 1.1],
+ [0.0, 0.0, 1.0, 0.9],
+ [0.5, 0.5, 1.5, 1.5],
+ [0.0, 0.0, 0.3, 0.3],
+ ], dtype=np.float32),
+ np.array([0.9, 0.75, 0.6, 0.5, 0.4, 0.3], dtype=np.float32),
+ id="basic",
+ ),
+ pytest.param(
+ 8, 4, 0.5, 0.4,
+ _make_valid_boxes(np.random.default_rng(42), 8),
+ np.random.default_rng(42).random(8, dtype=np.float32),
+ id="score_threshold",
+ ),
+ pytest.param(
+ 5, 3, 0.5, 0.99,
+ _make_valid_boxes(np.random.default_rng(0), 5),
+ np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype=np.float32),
+ id="all_suppressed",
+ ),
+ pytest.param(
+ 6, 6, 0.1, 0.0,
+ np.array([
+ [0.0, 0.0, 0.4, 0.4],
+ [0.5, 0.5, 0.9, 0.9],
+ [0.1, 0.1, 0.5, 0.5],
+ [0.6, 0.6, 1.0, 1.0],
+ [0.0, 0.5, 0.4, 0.9],
+ [0.5, 0.0, 0.9, 0.4],
+ ], dtype=np.float32),
+ np.array([0.9, 0.85, 0.7, 0.65, 0.6, 0.55], dtype=np.float32),
+ id="iou_threshold",
+ ),
+ pytest.param(
+ 4, 10, 0.5, 0.0,
+ np.array([
+ [0.0, 0.0, 0.3, 0.3],
+ [0.5, 0.5, 0.8, 0.8],
+ [0.1, 0.1, 0.4, 0.4],
+ [0.6, 0.6, 0.9, 0.9],
+ ], dtype=np.float32),
+ np.array([0.9, 0.85, 0.7, 0.65], dtype=np.float32),
+ id="max_output_size_larger_than_boxes",
+ ),
+]
+
+
[email protected](
+ "num_boxes,max_output_size,iou_threshold,score_threshold,boxes,scores",
+ _NMS_V5_CASES,
+)
+def test_nms_v5(num_boxes, max_output_size, iou_threshold, score_threshold,
boxes, scores):
+ """NON_MAX_SUPPRESSION_V5: conversion smoke test + E2E correctness
(nightly only)."""
+ mod, tf_func = _build_nms_v5_mod(num_boxes, max_output_size,
iou_threshold, score_threshold)
+ _verify_nms_v5(mod, tf_func, boxes, scores)
+
+
+def test_nms_v5_ir():
+ """Verify the emitted Relax IR has correct structure for
NON_MAX_SUPPRESSION_V5."""
+ num_boxes = 6
+ max_output_size = 3
+ mod, _ = _build_nms_v5_mod(
+ num_boxes=num_boxes,
+ max_output_size=max_output_size,
+ iou_threshold=0.5,
+ score_threshold=0.0,
+ )
+
+ ir = mod.script()
+
+ # Validate correct sorting/id indices are passed to valid_counts
+ assert "score_index=0" in ir
+ assert "id_index=-1" in ir
+ # NMS size limit validation
+ assert f"max_output_size={max_output_size}" in ir
+ # Valid output shape must be () statically
+ assert 'R.Tensor((), dtype="int32")' in ir
+ # Bounding boxes / scores tensor bounds checks
+ assert f"R.Tensor(({max_output_size},)" in ir
+
if __name__ == "__main__":
pytest.main(["-s", __file__])