This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b14b023080 [Relax][Frontend][TFLite] Implement DETECTION_POSTPROCESS
tflite operator (#19345)
b14b023080 is described below
commit b14b02308049f9569e2843bd937194dc90fe3171
Author: HoYi <[email protected]>
AuthorDate: Sat Apr 11 13:42:09 2026 +0800
[Relax][Frontend][TFLite] Implement DETECTION_POSTPROCESS tflite operator
(#19345)
## Summary
- Implemented the TFLite `DETECTION_POSTPROCESS` operator conversion to
Relax IR.
- Wired up the previously unimplemented operator to support object
detection post-processing workflows in Relax.
- Relates to #18928
## Changes
- **Operator Registration**: Implemented `convert_detection_postprocess`
in `python/tvm/relax/frontend/tflite/tflite_frontend.py`.
- **Core Logic**:
- Integrated `multibox_transform_loc` for coordinate decoding and
variance scaling.
- Supported the `use_regular_nms` attribute to switch between all-class NMS
and class-agnostic NMS paths.
- Leveraged `all_class_non_max_suppression` for efficient box filtering.
- **Output Alignment**: Used `topk`, `gather_nd`, and `where` operators
to ensure the output tensors (boxes, classes, scores, num_detections)
match the TFLite specification in terms of shape and layout.
- **Attribute Validation**: Added strict validation for required custom
options such as `num_classes`, `max_detections`, and scaling factors.
## Validation
Verified with linting and pre-commit hooks:
```bash
# Lint check
python -m ruff check python/tvm/relax/frontend/tflite/tflite_frontend.py
# Pre-commit checks
python -m pre_commit run --files
python/tvm/relax/frontend/tflite/tflite_frontend.py
```
Result:
- **Passed**: All static checks and style guidelines are met.
---
.../tvm/relax/frontend/tflite/tflite_flexbuffer.py | 27 ++-
.../tvm/relax/frontend/tflite/tflite_frontend.py | 223 ++++++++++++-----
python/tvm/relax/transform/legalize_ops/vision.py | 19 +-
tests/python/relax/test_frontend_tflite.py | 268 +++++++++++++++++++++
4 files changed, 460 insertions(+), 77 deletions(-)
diff --git a/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py
b/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py
index dc8ce1df21..5152b6996e 100644
--- a/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py
+++ b/python/tvm/relax/frontend/tflite/tflite_flexbuffer.py
@@ -78,12 +78,7 @@ class FlexBufferDecoder:
def indirect_jump(self, offset, byte_width):
"""Helper function to read the offset value and jump"""
- unpack_str = ""
- if byte_width == 1:
- unpack_str = "<B"
- elif byte_width == 4:
- unpack_str = "<i"
- assert unpack_str != ""
+ unpack_str = {1: "<B", 2: "<H", 4: "<I", 8: "<Q"}[byte_width]
back_jump = struct.unpack(unpack_str, self.buffer[offset : offset +
byte_width])[0]
return offset - back_jump
@@ -107,19 +102,26 @@ class FlexBufferDecoder:
# Each entry in the vector can have different datatype. Each entry is
of fixed length. The
# format is a sequence of all values followed by a sequence of
datatype of all values. For
# example - (4)(3.56)(int)(float) The end here points to the start of
the values.
+ # Each type byte contains: (type << 2) | bit_width, where bit_width
determines actual size.
values = list()
for i in range(0, size):
value_type_pos = end + size * byte_width + i
- value_type = FlexBufferType(self.buffer[value_type_pos] >> 2)
- value_bytes = self.buffer[end + i * byte_width : end + (i + 1) *
byte_width]
+ value_type_packed = self.buffer[value_type_pos]
+ value_type = FlexBufferType(value_type_packed >> 2)
+ value_bit_width = BitWidth(value_type_packed & 3)
+ value_byte_width = 1 << value_bit_width
+ value_bytes = self.buffer[end + i * byte_width : end + i *
byte_width + value_byte_width]
if value_type == FlexBufferType.FBT_BOOL:
value = bool(value_bytes[0])
elif value_type == FlexBufferType.FBT_INT:
- value = struct.unpack("<i", value_bytes)[0]
+ fmt = {1: "<b", 2: "<h", 4: "<i", 8: "<q"}[value_byte_width]
+ value = struct.unpack(fmt, value_bytes)[0]
elif value_type == FlexBufferType.FBT_UINT:
- value = struct.unpack("<I", value_bytes)[0]
+ fmt = {1: "<B", 2: "<H", 4: "<I", 8: "<Q"}[value_byte_width]
+ value = struct.unpack(fmt, value_bytes)[0]
elif value_type == FlexBufferType.FBT_FLOAT:
- value = struct.unpack("<f", value_bytes)[0]
+ fmt = {4: "<f", 8: "<d"}[value_byte_width]
+ value = struct.unpack(fmt, value_bytes)[0]
else:
raise Exception
values.append(value)
@@ -128,7 +130,8 @@ class FlexBufferDecoder:
def decode_map(self, end, byte_width, parent_byte_width):
"""Decodes the flexbuffer map and returns a dict"""
mid_loc = self.indirect_jump(end, parent_byte_width)
- map_size = struct.unpack("<i", self.buffer[mid_loc - byte_width :
mid_loc])[0]
+ size_fmt = {1: "<b", 2: "<h", 4: "<i", 8: "<q"}[byte_width]
+ map_size = struct.unpack(size_fmt, self.buffer[mid_loc - byte_width :
mid_loc])[0]
# Find keys
keys_offset = mid_loc - byte_width * 3
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index b344d9361a..16d5cb636b 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -2832,7 +2832,9 @@ class OperatorConverter:
new_b_shape = [1] * max(0, rank_a - rank_b) + [int(s) for s in
shape_b]
max_rank = max(rank_a, rank_b)
- batch_shape = [max(new_a_shape[i], new_b_shape[i]) for i in
range(max_rank - 2)]
+ batch_shape = [
+ max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank -
2)
+ ]
a_broadcast = batch_shape + [int(shape_a[-2]), int(shape_a[-1])]
b_broadcast = batch_shape + [int(shape_b[-2]), int(shape_b[-1])]
@@ -3225,16 +3227,49 @@ class OperatorConverter:
def convert_detection_postprocess(self, op):
"""Convert TFLite_Detection_PostProcess"""
- raise NotImplementedError(
- "DETECTION_POSTPROCESS is not wired in this frontend yet: it still
needs "
- "Relax NMS / get_valid_counts / related vision helpers (see dead
code below). "
- "relax.vision.multibox_transform_loc exists; tracking: "
- "https://github.com/apache/tvm/issues/18928"
- )
flexbuffer = op.CustomOptionsAsNumpy().tobytes()
custom_options = FlexBufferDecoder(flexbuffer).decode()
- use_regular_nms = "use_regular_nms" in custom_options and
custom_options["use_regular_nms"]
+ use_regular_nms = bool(custom_options.get("use_regular_nms", False))
+
+ required_attrs = [
+ "num_classes",
+ "max_detections",
+ "detections_per_class",
+ "nms_iou_threshold",
+ "nms_score_threshold",
+ "x_scale",
+ "y_scale",
+ "w_scale",
+ "h_scale",
+ ]
+ missing_attrs = [key for key in required_attrs if key not in
custom_options]
+ if missing_attrs:
+ raise ValueError(
+ "DETECTION_POSTPROCESS custom options miss required
attributes: "
+ + ", ".join(missing_attrs)
+ )
+
+ num_classes = int(custom_options["num_classes"])
+ max_detections = int(custom_options["max_detections"])
+ detections_per_class = int(custom_options["detections_per_class"])
+ iou_threshold = float(custom_options["nms_iou_threshold"])
+ score_threshold = float(custom_options["nms_score_threshold"])
+ x_scale = float(custom_options["x_scale"])
+ y_scale = float(custom_options["y_scale"])
+ w_scale = float(custom_options["w_scale"])
+ h_scale = float(custom_options["h_scale"])
+
+ if num_classes <= 0:
+ raise ValueError("DETECTION_POSTPROCESS requires num_classes > 0.")
+ if max_detections <= 0:
+ raise ValueError("DETECTION_POSTPROCESS requires max_detections >
0.")
+ if detections_per_class <= 0:
+ raise ValueError("DETECTION_POSTPROCESS requires
detections_per_class > 0.")
+ if not 0.0 <= iou_threshold <= 1.0:
+ raise ValueError("DETECTION_POSTPROCESS requires nms_iou_threshold
in [0, 1].")
+ if x_scale <= 0.0 or y_scale <= 0.0 or w_scale <= 0.0 or h_scale <=
0.0:
+ raise ValueError("DETECTION_POSTPROCESS requires x/y/w/h_scale to
be > 0.")
inputs = self.get_input_tensors(op)
assert len(inputs) == 3, "inputs length should be 3"
@@ -3296,67 +3331,139 @@ class OperatorConverter:
# attributes for multibox_transform_loc
multibox_transform_loc_attrs = {}
multibox_transform_loc_attrs["clip"] = False
- multibox_transform_loc_attrs["threshold"] = (
- 0.0 if use_regular_nms else custom_options["nms_score_threshold"]
- )
+ multibox_transform_loc_attrs["threshold"] = 0.0 if use_regular_nms
else score_threshold
multibox_transform_loc_attrs["variances"] = (
- 1 / custom_options["x_scale"],
- 1 / custom_options["y_scale"],
- 1 / custom_options["w_scale"],
- 1 / custom_options["h_scale"],
+ 1 / x_scale,
+ 1 / y_scale,
+ 1 / w_scale,
+ 1 / h_scale,
)
multibox_transform_loc_attrs["keep_background"] = use_regular_nms
- ret = relax.op.vision.multibox_transform_loc(
- # reshape cls_pred so it can be consumed by
- # multibox_transform_loc
- relax.op.permute_dims(cls_pred, [0, 2, 1]),
- loc_prob,
- anchor_expr,
- **multibox_transform_loc_attrs,
+ multibox_res = self.bb.emit(
+ relax.op.vision.multibox_transform_loc(
+ # reshape cls_pred so it can be consumed by
+ # multibox_transform_loc
+ relax.op.permute_dims(cls_pred, [0, 2, 1]),
+ loc_prob,
+ anchor_expr,
+ **multibox_transform_loc_attrs,
+ )
+ )
+ transformed_boxes = self.bb.emit(relax.TupleGetItem(multibox_res, 0))
+ transformed_scores = self.bb.emit(relax.TupleGetItem(multibox_res, 1))
+
+ if use_regular_nms:
+ nms_out = self.bb.emit(
+ relax.op.vision.all_class_non_max_suppression(
+ transformed_boxes,
+ transformed_scores,
+ relax.const(detections_per_class, "int64"),
+ relax.const(iou_threshold, "float32"),
+ relax.const(score_threshold, "float32"),
+ output_format="tensorflow",
+ )
+ )
+ selected_indices = self.bb.emit(relax.TupleGetItem(nms_out, 0))
+ selected_scores = self.bb.emit(relax.TupleGetItem(nms_out, 1))
+ num_detections = self.bb.emit(relax.TupleGetItem(nms_out, 2))
+ class_id_from_score = None
+ else:
+ topk_res = self.bb.emit(
+ relax.op.topk(transformed_scores, k=1, axis=1,
ret_type="both", largest=True)
+ )
+ max_scores = self.bb.emit(relax.TupleGetItem(topk_res, 0))
+ class_id_from_score = self.bb.emit(relax.TupleGetItem(topk_res, 1))
+ nms_out = self.bb.emit(
+ relax.op.vision.all_class_non_max_suppression(
+ transformed_boxes,
+ max_scores,
+ relax.const(max_detections, "int64"),
+ relax.const(iou_threshold, "float32"),
+ relax.const(score_threshold, "float32"),
+ output_format="tensorflow",
+ )
+ )
+ selected_indices = self.bb.emit(relax.TupleGetItem(nms_out, 0))
+ selected_scores = self.bb.emit(relax.TupleGetItem(nms_out, 1))
+ num_detections = self.bb.emit(relax.TupleGetItem(nms_out, 2))
+ class_id_from_score = relax.op.squeeze(class_id_from_score,
axis=[1])
+
+ selected_score_slots = selected_scores.struct_info.shape.values[1]
+ selected_detection_positions = relax.op.expand_dims(
+ relax.op.arange(selected_score_slots, dtype="int64"), axis=0
+ )
+ selected_valid_detection_mask = relax.op.less(
+ selected_detection_positions, relax.op.expand_dims(num_detections,
axis=1)
+ )
+ masked_selected_scores = relax.op.where(
+ selected_valid_detection_mask,
+ selected_scores,
+ relax.const(-1.0, "float32"),
+ )
+ topk_scores_res = self.bb.emit(
+ relax.op.topk(
+ masked_selected_scores, k=max_detections, axis=1,
ret_type="both", largest=True
+ )
+ )
+ detection_scores = self.bb.emit(relax.TupleGetItem(topk_scores_res, 0))
+ top_positions = self.bb.emit(relax.TupleGetItem(topk_scores_res, 1))
+ num_detections = relax.op.minimum(
+ num_detections, relax.const([max_detections], dtype="int64")
+ )
+ detection_positions = relax.op.expand_dims(
+ relax.op.arange(max_detections, dtype="int64"), axis=0
+ )
+ valid_detection_mask = relax.op.less(
+ detection_positions, relax.op.expand_dims(num_detections, axis=1)
+ )
+ top_positions_expanded = relax.op.expand_dims(top_positions, axis=2)
+ top_positions_for_pairs = relax.op.repeat(top_positions_expanded, 2,
axis=2)
+ top_index_pairs = relax.op.gather_elements(
+ selected_indices, top_positions_for_pairs, axis=1
+ )
+ top_box_ids = relax.op.squeeze(
+ relax.op.strided_slice(top_index_pairs, axes=[2], begin=[1],
end=[2]),
+ axis=[2],
+ )
+ top_box_ids_for_gather =
relax.op.expand_dims(relax.op.astype(top_box_ids, "int64"), axis=2)
+ detection_boxes = relax.op.gather_nd(
+ transformed_boxes, top_box_ids_for_gather, batch_dims=1
)
if use_regular_nms:
- # box coordinates need to be converted from ltrb to (ymin, xmin,
ymax, xmax)
- _, transformed_boxes = relax.op.split(ret[0], (2,), axis=2)
- box_l, box_t, box_r, box_b = relax.op.split(transformed_boxes, 4,
axis=2)
- transformed_boxes = relax.op.concat([box_t, box_l, box_b, box_r],
axis=2)
-
- return relax.op.vision.regular_non_max_suppression(
- boxes=transformed_boxes,
- scores=cls_pred,
-
max_detections_per_class=custom_options["detections_per_class"],
- max_detections=custom_options["max_detections"],
- num_classes=custom_options["num_classes"],
- iou_threshold=custom_options["nms_iou_threshold"],
- score_threshold=custom_options["nms_score_threshold"],
+ detection_classes = relax.op.squeeze(
+ relax.op.strided_slice(top_index_pairs, axes=[2], begin=[0],
end=[1]),
+ axis=[2],
+ )
+ detection_classes = relax.op.astype(detection_classes, "int32")
+ else:
+ top_box_ids_for_class = relax.op.expand_dims(
+ relax.op.astype(top_box_ids, "int64"), axis=2
+ )
+ detection_classes = relax.op.gather_nd(
+ class_id_from_score, top_box_ids_for_class, batch_dims=1
)
- # attributes for non_max_suppression
- non_max_suppression_attrs = {}
- non_max_suppression_attrs["return_indices"] = False
- non_max_suppression_attrs["iou_threshold"] =
custom_options["nms_iou_threshold"]
- non_max_suppression_attrs["force_suppress"] = True
- non_max_suppression_attrs["top_k"] = anchor_boxes
- non_max_suppression_attrs["max_output_size"] =
custom_options["max_detections"]
- non_max_suppression_attrs["invalid_to_bottom"] = False
-
- ret = relax.op.vision.non_max_suppression(
- ret[0], ret[1], ret[1], **non_max_suppression_attrs
+ detection_mask = relax.op.expand_dims(valid_detection_mask, axis=2)
+ detection_boxes = relax.op.where(
+ detection_mask,
+ detection_boxes,
+ relax.op.zeros((batch_size, max_detections, 4), dtype="float32"),
+ )
+ detection_classes = relax.op.where(
+ valid_detection_mask,
+ detection_classes,
+ relax.op.zeros((batch_size, max_detections), dtype="int32"),
)
- ret = relax.op.vision.get_valid_counts(ret, 0)
- valid_count = ret[0]
- # keep only the top 'max_detections' rows
- ret = relax.op.strided_slice(
- ret[1], [0, 0, 0], [batch_size, custom_options["max_detections"],
6]
+ detection_scores = relax.op.where(
+ valid_detection_mask,
+ detection_scores,
+ relax.op.zeros((batch_size, max_detections), dtype="float32"),
)
- # the output needs some reshaping to match tflite
- ret = relax.op.split(ret, 6, axis=2)
- cls_ids = relax.op.reshape(ret[0], [batch_size, -1])
- scores = relax.op.reshape(ret[1], [batch_size, -1])
- boxes = relax.op.concat([ret[3], ret[2], ret[5], ret[4]], axis=2)
- ret = relax.Tuple(relax.Tuple([boxes, cls_ids, scores, valid_count]),
size=4)
- return ret
+ detection_classes = relax.op.astype(detection_classes, "float32")
+ num_detections = relax.op.astype(num_detections, "float32")
+ return relax.Tuple([detection_boxes, detection_classes,
detection_scores, num_detections])
def convert_nms_v5(self, op):
"""Convert TFLite NonMaxSuppressionV5"""
diff --git a/python/tvm/relax/transform/legalize_ops/vision.py
b/python/tvm/relax/transform/legalize_ops/vision.py
index 7d8586ab52..c515fc8fe8 100644
--- a/python/tvm/relax/transform/legalize_ops/vision.py
+++ b/python/tvm/relax/transform/legalize_ops/vision.py
@@ -32,11 +32,15 @@ def _all_class_non_max_suppression(block_builder:
BlockBuilder, call: Call) -> E
Returns
-------
- result : Tuple[Tensor, Tensor]
- A tuple of (trimmed_indices, num_total_detections) where:
- - trimmed_indices: Tensor of shape (num_total_detections, 3)
containing only
- valid detection indices (batch_id, class_id, box_id)
- - num_total_detections: Tensor of shape (1,) with the count of valid
detections
+ result : Expr
+ The legalized NMS result.
+
+ - For ONNX output format, returns a tuple of
+ `(trimmed_indices, num_total_detections)`, where `trimmed_indices`
+ contains only valid detection indices.
+ - For TensorFlow output format, returns the TOPI result directly to
+ preserve the `(selected_indices, selected_scores, num_detections)`
+ layout expected by the Relax op.
"""
boxes = call.args[0]
scores = call.args[1]
@@ -69,8 +73,9 @@ def _all_class_non_max_suppression(block_builder:
BlockBuilder, call: Call) -> E
output_format,
)
- # Dynamic output trimming using dynamic_strided_slice
- # Extract selected_indices and num_total_detections from the NMS result
+ if output_format == "tensorflow":
+ return nms_result
+
selected_indices = block_builder.emit(TupleGetItem(nms_result, 0))
num_total_detections = block_builder.emit(TupleGetItem(nms_result, 1))
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index 02282f3d41..c237d4db8f 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -27,6 +27,7 @@ import tflite.Model
from tensorflow.keras import applications as keras_app
import tvm
+import tvm.relax.frontend.tflite.tflite_frontend as tflite_frontend
from tvm import relax
from tvm.relax.frontend.tflite import from_tflite
from tvm.script.parser import ir as I
@@ -1082,6 +1083,142 @@ def _build_nms_v5_mod(num_boxes, max_output_size,
iou_threshold, score_threshold
return mod, instance.func
+class _StubDetectionPostprocessTensor:
+ def __init__(self, shape, name):
+ self._shape = list(shape)
+ self._name = name
+
+ def Shape(self, index):
+ return self._shape[index]
+
+ def Name(self):
+ return self._name
+
+ def Type(self):
+ return 0
+
+
+class _StubDetectionPostprocessOp:
+ def __init__(self, custom_options):
+ self._custom_options =
_encode_detection_postprocess_custom_options(custom_options)
+
+ def CustomOptionsAsNumpy(self):
+ return np.frombuffer(self._custom_options, dtype="uint8")
+
+
+_DETECTION_POSTPROCESS_ANCHORS = np.array(
+ [
+ [0.5, 0.5, 1.0, 1.0],
+ [0.5, 0.2, 1.0, 1.0],
+ [0.1, 0.1, 0.5, 0.5],
+ [0.8, 0.8, 0.2, 0.2],
+ ],
+ dtype="float32",
+)
+
+
+def _encode_detection_postprocess_custom_options(custom_options):
+ from flatbuffers import flexbuffers
+
+ builder = flexbuffers.Builder()
+ with builder.Map():
+ for key, value in custom_options.items():
+ if isinstance(value, bool):
+ builder.Bool(key, value)
+ elif isinstance(value, int):
+ builder.Int(key, value)
+ else:
+ builder.Float(key, float(value))
+ return bytes(builder.Finish())
+
+
+def _make_detection_postprocess_tensor_wrapper(tensor_idx, shape, name):
+ return tflite_frontend.TensorWrapper(
+ tensor_idx,
+ _StubDetectionPostprocessTensor(shape, name),
+ None,
+ )
+
+
+def _build_detection_postprocess_mod(
+ *,
+ num_classes=1,
+ max_detections=4,
+ detections_per_class=4,
+ use_regular_nms=False,
+ nms_iou_threshold=0.5,
+ nms_score_threshold=0.3,
+ x_scale=10.0,
+ y_scale=10.0,
+ w_scale=5.0,
+ h_scale=5.0,
+ batch_size=2,
+ num_anchors=4,
+ input_num_classes=None,
+):
+ custom_options = {
+ "num_classes": num_classes,
+ "max_detections": max_detections,
+ "detections_per_class": detections_per_class,
+ "nms_iou_threshold": nms_iou_threshold,
+ "nms_score_threshold": nms_score_threshold,
+ "x_scale": x_scale,
+ "y_scale": y_scale,
+ "w_scale": w_scale,
+ "h_scale": h_scale,
+ "use_regular_nms": use_regular_nms,
+ }
+ return _convert_detection_postprocess_with_options(
+ custom_options,
+ batch_size=batch_size,
+ num_anchors=num_anchors,
+ num_classes=num_classes,
+ input_num_classes=input_num_classes,
+ )
+
+
+def _convert_detection_postprocess_with_options(
+ custom_options,
+ *,
+ batch_size=2,
+ num_anchors=4,
+ num_classes=1,
+ input_num_classes=None,
+ build_module=True,
+):
+ input_num_classes = num_classes if input_num_classes is None else
input_num_classes
+ loc = relax.Var("loc", relax.TensorStructInfo((batch_size, num_anchors,
4), "float32"))
+ cls = relax.Var(
+ "cls", relax.TensorStructInfo((batch_size, num_anchors,
input_num_classes), "float32")
+ )
+ inputs = [
+ _make_detection_postprocess_tensor_wrapper(0, (batch_size,
num_anchors, 4), "loc"),
+ _make_detection_postprocess_tensor_wrapper(
+ 1, (batch_size, num_anchors, input_num_classes), "cls"
+ ),
+ _make_detection_postprocess_tensor_wrapper(2, (num_anchors, 4),
"anchors"),
+ ]
+ converter =
tflite_frontend.OperatorConverter.__new__(tflite_frontend.OperatorConverter)
+ converter.bb = relax.BlockBuilder()
+ converter.exp_tab = tflite_frontend.ExprTable()
+ converter.get_input_tensors = lambda op: inputs
+ converter.get_expr = lambda tensor_idx: {0: loc, 1: cls}[tensor_idx]
+ converter.get_tensor_value = (
+ lambda tensor: _DETECTION_POSTPROCESS_ANCHORS if tensor.tensor_idx ==
2 else None
+ )
+ converter.get_tensor_type_str = lambda tensor_type: "float32"
+ op = _StubDetectionPostprocessOp(custom_options)
+ if not build_module:
+ return converter.convert_detection_postprocess(op)
+ bb = converter.bb
+ with bb.function("main", [loc, cls]):
+ with bb.dataflow():
+ output = converter.convert_detection_postprocess(op)
+ gv = bb.emit_output(output)
+ bb.emit_func_output(gv)
+ return bb.get()
+
+
def _make_valid_boxes(rng, n):
"""Generate n random boxes with y1<=y2, x1<=x2 using the given RNG."""
raw = rng.random((n, 4), dtype=np.float32)
@@ -1207,6 +1344,137 @@ def test_nms_v5_ir():
assert f"R.Tensor(({max_output_size},)" in ir
+_DETECTION_POSTPROCESS_SMOKE_CASES = [
+ pytest.param(
+ {
+ "num_classes": 2,
+ "input_num_classes": 3,
+ "max_detections": 2,
+ "detections_per_class": 2,
+ "use_regular_nms": False,
+ "nms_iou_threshold": 0.5,
+ "nms_score_threshold": 0.5,
+ "batch_size": 1,
+ "num_anchors": 4,
+ },
+ 2,
+ False,
+ id="basic_fast_nms",
+ ),
+ pytest.param(
+ {
+ "num_classes": 2,
+ "input_num_classes": 3,
+ "max_detections": 3,
+ "detections_per_class": 2,
+ "use_regular_nms": True,
+ "nms_iou_threshold": 0.45,
+ "nms_score_threshold": 0.25,
+ "batch_size": 2,
+ "num_anchors": 4,
+ },
+ 1,
+ True,
+ id="regular_nms_multi_batch",
+ ),
+]
+
+
+_DETECTION_POSTPROCESS_SHAPE_CASES = [
+ pytest.param(
+ {
+ "num_classes": 2,
+ "input_num_classes": 5,
+ "max_detections": 2,
+ "detections_per_class": 2,
+ "use_regular_nms": False,
+ "nms_iou_threshold": 0.5,
+ "nms_score_threshold": 0.5,
+ "batch_size": 1,
+ "num_anchors": 4,
+ },
+ id="wider_input_classes",
+ ),
+ pytest.param(
+ {
+ "num_classes": 2,
+ "input_num_classes": 3,
+ "max_detections": 4,
+ "detections_per_class": 4,
+ "use_regular_nms": False,
+ "nms_iou_threshold": 0.5,
+ "nms_score_threshold": 0.5,
+ "batch_size": 1,
+ "num_anchors": 4,
+ },
+ id="larger_max_detections",
+ ),
+]
+
+
[email protected](
+ "build_kwargs,expected_topk_count,expected_keep_background",
+ _DETECTION_POSTPROCESS_SMOKE_CASES,
+)
+def test_detection_postprocess_smoke(
+ build_kwargs, expected_topk_count, expected_keep_background
+):
+ mod = _build_detection_postprocess_mod(**build_kwargs)
+ ir = mod.script()
+
+ assert "R.vision.multibox_transform_loc" in ir
+ assert "R.vision.all_class_non_max_suppression" in ir
+ assert 'output_format="tensorflow"' in ir
+ assert "R.where" in ir
+ assert "R.gather_elements" in ir
+ assert "R.gather_nd" in ir
+ assert ir.count("R.topk(") == expected_topk_count
+ assert f"keep_background={expected_keep_background}" in ir
+ expected_batch = build_kwargs["batch_size"]
+ expected_max_detections = build_kwargs["max_detections"]
+ tvm.ir.assert_structural_equal(
+ mod["main"].ret_struct_info,
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((expected_batch,
expected_max_detections, 4), "float32"),
+ relax.TensorStructInfo((expected_batch,
expected_max_detections), "float32"),
+ relax.TensorStructInfo((expected_batch,
expected_max_detections), "float32"),
+ relax.TensorStructInfo((expected_batch,), "float32"),
+ ]
+ ),
+ )
+
+ legalized = relax.transform.LegalizeOps()(mod)
+ legalized_ir = legalized.script()
+ assert "R.vision.all_class_non_max_suppression(" not in legalized_ir
+ assert "R.call_tir(" in legalized_ir
+ tvm.ir.assert_structural_equal(legalized["main"].ret_struct_info,
mod["main"].ret_struct_info)
+
+
[email protected]("build_kwargs", _DETECTION_POSTPROCESS_SHAPE_CASES)
+def test_detection_postprocess_shape_variations(build_kwargs):
+ mod = _build_detection_postprocess_mod(**build_kwargs)
+ batch_size = build_kwargs["batch_size"]
+ num_anchors = build_kwargs["num_anchors"]
+ input_num_classes = build_kwargs["input_num_classes"]
+ max_detections = build_kwargs["max_detections"]
+
+ tvm.ir.assert_structural_equal(
+ mod["main"].params[1].struct_info,
+ relax.TensorStructInfo((batch_size, num_anchors, input_num_classes),
"float32"),
+ )
+ tvm.ir.assert_structural_equal(
+ mod["main"].ret_struct_info,
+ relax.TupleStructInfo(
+ [
+ relax.TensorStructInfo((batch_size, max_detections, 4),
"float32"),
+ relax.TensorStructInfo((batch_size, max_detections),
"float32"),
+ relax.TensorStructInfo((batch_size, max_detections),
"float32"),
+ relax.TensorStructInfo((batch_size,), "float32"),
+ ]
+ ),
+ )
+
def _make_resize_expected(
input_shape, output_size, method, coordinate_transformation_mode,
rounding_method
):