This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3eb86f78ed [Relax][TOPI] Add relax.vision.multibox_transform_loc for SSD/TFLite box decode (#18942)
3eb86f78ed is described below

commit 3eb86f78ed1bdb2111118924d16e92bf2d1b054d
Author: Dayuxiaoshui <[email protected]>
AuthorDate: Sat Mar 28 12:20:22 2026 +0800

    [Relax][TOPI] Add relax.vision.multibox_transform_loc for SSD/TFLite box decode (#18942)
    
    Introduce relax.vision.multibox_transform_loc with
    MultiboxTransformLocAttrs: decode center-size offsets against ltrb
    priors, softmax on class logits, and optional clip, threshold masking,
    and background score zeroing. Register the C++ op with FInferStructInfo
    checks for shapes and dtypes (including batch and 4*N consistency).
    Legalize to topi.vision.multibox_transform_loc.
    
    Add tests for struct inference, invalid inputs, Legalize+e2e on LLVM,
    attribute branches, and TVMScript roundtrip. Add a standalone numpy
    reference under topi/testing (not exported from tvm.topi.testing to
    avoid pulling scipy).
    
    Update TFLite frontend NotImplementedError text for
    DETECTION_POSTPROCESS and NON_MAX_SUPPRESSION_V5 to note multibox is
    available and link tracking issue #18928.
---
 include/tvm/relax/attrs/vision.h                   |  24 ++
 .../tvm/relax/frontend/tflite/tflite_frontend.py   |  12 +-
 python/tvm/relax/op/__init__.py                    |   2 +-
 python/tvm/relax/op/op_attrs.py                    |   5 +
 python/tvm/relax/op/vision/__init__.py             |   1 +
 .../tvm/relax/op/vision/multibox_transform_loc.py  |  85 +++++++
 python/tvm/relax/transform/legalize_ops/vision.py  |  24 ++
 python/tvm/topi/vision/__init__.py                 |   1 +
 python/tvm/topi/vision/multibox_transform_loc.py   | 121 +++++++++
 src/relax/op/vision/multibox_transform_loc.cc      | 204 +++++++++++++++
 src/relax/op/vision/multibox_transform_loc.h       |  42 +++
 tests/python/relax/test_op_vision.py               | 283 +++++++++++++++++++++
 .../relax/test_tvmscript_parser_op_vision.py       |  42 +++
 13 files changed, 839 insertions(+), 7 deletions(-)

diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h
index 59a1dd7314..4e3351bb90 100644
--- a/include/tvm/relax/attrs/vision.h
+++ b/include/tvm/relax/attrs/vision.h
@@ -73,6 +73,30 @@ struct ROIAlignAttrs : public 
AttrsNodeReflAdapter<ROIAlignAttrs> {
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs", 
ROIAlignAttrs, BaseAttrsNode);
 };  // struct ROIAlignAttrs
 
+/*! \brief Attributes for multibox_transform_loc (SSD / TFLite-style box 
decode). */
+struct MultiboxTransformLocAttrs : public 
AttrsNodeReflAdapter<MultiboxTransformLocAttrs> {
+  bool clip;
+  double threshold;
+  ffi::Array<double> variances;
+  bool keep_background;
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<MultiboxTransformLocAttrs>()
+        .def_ro("clip", &MultiboxTransformLocAttrs::clip,
+                "Clip decoded ymin,xmin,ymax,xmax to [0,1].")
+        .def_ro("threshold", &MultiboxTransformLocAttrs::threshold,
+                "After softmax, zero scores strictly below this value.")
+        .def_ro("variances", &MultiboxTransformLocAttrs::variances,
+                "(x,y,w,h) scales = TFLite 
1/x_scale,1/y_scale,1/w_scale,1/h_scale on "
+                "encodings. Very large w/h scales can overflow exp in decode.")
+        .def_ro("keep_background", &MultiboxTransformLocAttrs::keep_background,
+                "If false, force output scores[:,0,:] to 0 (background 
class).");
+  }
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.MultiboxTransformLocAttrs",
+                                    MultiboxTransformLocAttrs, BaseAttrsNode);
+};  // struct MultiboxTransformLocAttrs
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py 
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 5c73af18ad..435180dfee 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -3205,9 +3205,10 @@ class OperatorConverter:
     def convert_detection_postprocess(self, op):
         """Convert TFLite_Detection_PostProcess"""
         raise NotImplementedError(
-            "DETECTION_POSTPROCESS requires vision ops (multibox_transform_loc, "
-            "non_max_suppression, get_valid_counts) not yet available in Relax. "
-            "See https://github.com/apache/tvm/issues/XXXX"
+            "DETECTION_POSTPROCESS is not wired in this frontend yet: it still needs "
+            "Relax NMS / get_valid_counts / related vision helpers (see dead code below). "
+            "relax.vision.multibox_transform_loc exists; tracking: "
+            "https://github.com/apache/tvm/issues/18928"
         )
         flexbuffer = op.CustomOptionsAsNumpy().tobytes()
         custom_options = FlexBufferDecoder(flexbuffer).decode()
@@ -3340,9 +3341,8 @@ class OperatorConverter:
         """Convert TFLite NonMaxSuppressionV5"""
         # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/non-max-suppression-v5
         raise NotImplementedError(
-            "NON_MAX_SUPPRESSION_V5 requires vision ops (get_valid_counts, "
-            "non_max_suppression) not yet available in Relax. "
-            "See https://github.com/apache/tvm/issues/XXXX"
+            "NON_MAX_SUPPRESSION_V5 is not wired in this frontend yet (needs get_valid_counts, "
+            "non_max_suppression, etc.). Tracking: https://github.com/apache/tvm/issues/18928"
         )
 
         input_tensors = self.get_input_tensors(op)
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 0bc3f65784..ee1a2c2420 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -157,7 +157,7 @@ from .unary import (
     tanh,
     trunc,
 )
-from .vision import all_class_non_max_suppression, roi_align
+from .vision import all_class_non_max_suppression, multibox_transform_loc, 
roi_align
 
 
 def _register_op_make():
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index a3b6544dcc..e8c91f04b4 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -251,6 +251,11 @@ class ROIAlignAttrs(Attrs):
     """Attributes for vision.roi_align"""
 
 
+@tvm_ffi.register_object("relax.attrs.MultiboxTransformLocAttrs")
+class MultiboxTransformLocAttrs(Attrs):
+    """Attributes for vision.multibox_transform_loc"""
+
+
 @tvm_ffi.register_object("relax.attrs.Conv1DAttrs")
 class Conv1DAttrs(Attrs):
     """Attributes for nn.conv1d"""
diff --git a/python/tvm/relax/op/vision/__init__.py 
b/python/tvm/relax/op/vision/__init__.py
index 76d9ea35a1..58266c5b2a 100644
--- a/python/tvm/relax/op/vision/__init__.py
+++ b/python/tvm/relax/op/vision/__init__.py
@@ -17,5 +17,6 @@
 # under the License.
 """VISION operators."""
 
+from .multibox_transform_loc import *
 from .nms import *
 from .roi_align import *
diff --git a/python/tvm/relax/op/vision/multibox_transform_loc.py 
b/python/tvm/relax/op/vision/multibox_transform_loc.py
new file mode 100644
index 0000000000..6830b1dc63
--- /dev/null
+++ b/python/tvm/relax/op/vision/multibox_transform_loc.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Multibox location transform for object detection."""
+
+from . import _ffi_api
+
+
+def multibox_transform_loc(
+    cls_pred,
+    loc_pred,
+    anchor,
+    clip=False,
+    threshold=0.0,
+    variances=(1.0, 1.0, 1.0, 1.0),
+    keep_background=True,
+):
+    """SSD / TFLite-style decode: priors + offsets → boxes; logits → softmax 
scores.
+
+    Box decode follows TFLite ``DecodeCenterSizeBoxes``; expected tensor 
layout matches
+    ``tflite_frontend.convert_detection_postprocess`` (loc reorder yxhw→xywh, 
anchor ltrb).
+
+    Parameters
+    ----------
+    cls_pred : relax.Expr
+        ``[B, C, N]`` class logits (pre-softmax).
+    loc_pred : relax.Expr
+        ``[B, 4*N]`` per-anchor encodings as ``(x,y,w,h)`` after reorder (see 
above).
+    anchor : relax.Expr
+        ``[1, N, 4]`` priors: ``(left, top, right, bottom)``.
+    clip : bool
+        If True, clip ``ymin,xmin,ymax,xmax`` to ``[0, 1]``.
+    threshold : float
+        After softmax, multiply scores by mask ``(score >= threshold)``.
+    variances : tuple of 4 floats
+        ``(x,y,w,h)`` = TFLite ``1/x_scale, 1/y_scale, 1/w_scale, 1/h_scale``.
+        Use magnitudes consistent with the model: very large ``w``/``h`` 
entries scale the
+        encoded height/width terms inside ``exp(...)`` and can overflow in 
float32/float16.
+    keep_background : bool
+        If False, set output scores at class index 0 to zero.
+
+    Returns
+    -------
+    result : relax.Expr
+        Tuple ``(boxes, scores)``: ``boxes`` is ``[B, N, 4]`` as 
``(ymin,xmin,ymax,xmax)``;
+        ``scores`` is ``[B, C, N]`` softmax, post-processed like the 
implementation.
+
+    Notes
+    -----
+    **Shape/dtype (checked in ``FInferStructInfo`` when static):**
+
+    - ``cls_pred``: 3-D; ``loc_pred``: 2-D; ``anchor``: 3-D.
+    - ``cls_pred``, ``loc_pred``, ``anchor`` dtypes must match.
+    - ``N = cls_pred.shape[2]``; ``loc_pred.shape[1] == 4*N``; ``anchor.shape 
== [1,N,4]``.
+    - ``loc_pred.shape[1]`` must be divisible by 4.
+    - ``cls_pred.shape[0]`` must equal ``loc_pred.shape[0]`` (batch).
+
+    If ``cls_pred`` has **unknown** shape, inference only returns generic 
rank-3 tensor
+    struct info for the two outputs; it does **not** verify ``4*N`` vs 
``loc_pred`` or
+    ``anchor.shape[1]`` vs ``N``, because ``N`` is not available statically. 
Other checks
+    (ranks, dtypes, ``loc_pred.shape[1] % 4 == 0`` when known, batch match 
when both batch
+    axes are known, etc.) still run where applicable.
+    """
+    return _ffi_api.multibox_transform_loc(
+        cls_pred,
+        loc_pred,
+        anchor,
+        clip,
+        threshold,
+        variances,
+        keep_background,
+    )
diff --git a/python/tvm/relax/transform/legalize_ops/vision.py 
b/python/tvm/relax/transform/legalize_ops/vision.py
index 7a1e305f39..28367a67a3 100644
--- a/python/tvm/relax/transform/legalize_ops/vision.py
+++ b/python/tvm/relax/transform/legalize_ops/vision.py
@@ -118,3 +118,27 @@ def _roi_align(bb: BlockBuilder, call: Call) -> Expr:
         aligned=call.attrs.aligned,
         layout=call.attrs.layout,
     )
+
+
+@register_legalize("relax.vision.multibox_transform_loc")
+def _multibox_transform_loc(bb: BlockBuilder, call: Call) -> Expr:
+    variances = tuple(float(x) for x in call.attrs.variances)
+
+    def _te(cls_pred, loc_pred, anchor):
+        return topi.vision.multibox_transform_loc(
+            cls_pred,
+            loc_pred,
+            anchor,
+            variances,
+            clip=call.attrs.clip,
+            threshold=call.attrs.threshold,
+            keep_background=call.attrs.keep_background,
+        )
+
+    return bb.call_te(
+        _te,
+        call.args[0],
+        call.args[1],
+        call.args[2],
+        primfunc_name_hint="multibox_transform_loc",
+    )
diff --git a/python/tvm/topi/vision/__init__.py 
b/python/tvm/topi/vision/__init__.py
index 75725a8a4b..cb0467c98c 100644
--- a/python/tvm/topi/vision/__init__.py
+++ b/python/tvm/topi/vision/__init__.py
@@ -17,5 +17,6 @@
 # under the License.
 """Vision operators."""
 
+from .multibox_transform_loc import *
 from .nms import *
 from .roi_align import *
diff --git a/python/tvm/topi/vision/multibox_transform_loc.py 
b/python/tvm/topi/vision/multibox_transform_loc.py
new file mode 100644
index 0000000000..ab965e7981
--- /dev/null
+++ b/python/tvm/topi/vision/multibox_transform_loc.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Multibox location transform (SSD / TFLite DetectionPostProcess decode)."""
+
+import tvm
+from tvm import te, topi
+
+
+def multibox_transform_loc(
+    cls_pred,
+    loc_pred,
+    anchor,
+    variances,
+    clip=False,
+    threshold=0.0,
+    keep_background=True,
+):
+    """TFLite ``DecodeCenterSizeBoxes``-style decode + softmax score 
post-process.
+
+    Inputs must match Relax op contracts: ``cls_pred [B,C,N]``, ``loc_pred 
[B,4*N]``,
+    ``anchor [1,N,4]`` ltrb; per-anchor loc order ``(x,y,w,h)`` after 
yxhw→xywh reorder.
+
+    Parameters
+    ----------
+    cls_pred : te.Tensor
+        ``[B, C, N]`` logits.
+    loc_pred : te.Tensor
+        ``[B, 4*N]`` encodings ``(x,y,w,h)`` per anchor.
+    anchor : te.Tensor
+        ``[1, N, 4]`` ``(left, top, right, bottom)``.
+    variances : tuple of 4 float
+        ``(x,y,w,h)`` = ``1/x_scale, 1/y_scale, 1/w_scale, 1/h_scale`` 
(TFLite).
+    clip : bool
+        Clip ``ymin,xmin,ymax,xmax`` to ``[0,1]``.
+    threshold : float
+        After softmax: ``scores *= (scores >= threshold)``.
+    keep_background : bool
+        If False: ``scores[:,0,:] = 0``.
+
+    Returns
+    -------
+    boxes : te.Tensor
+        ``[B, N, 4]`` as ``(ymin,xmin,ymax,xmax)``.
+    scores : te.Tensor
+        ``[B, C, N]`` softmax, then threshold mask and optional background 
zero.
+    """
+    dtype = cls_pred.dtype
+    B = cls_pred.shape[0]
+    num_anchors = cls_pred.shape[2]
+    loc_reshaped = topi.reshape(loc_pred, [B, num_anchors, 4])
+
+    vx = tvm.tirx.const(float(variances[0]), dtype)
+    vy = tvm.tirx.const(float(variances[1]), dtype)
+    vw = tvm.tirx.const(float(variances[2]), dtype)
+    vh = tvm.tirx.const(float(variances[3]), dtype)
+    half = tvm.tirx.const(0.5, dtype)
+    zero = tvm.tirx.const(0.0, dtype)
+    one = tvm.tirx.const(1.0, dtype)
+    th = tvm.tirx.const(float(threshold), dtype)
+
+    def decode_bbox(b, a, k):
+        l = anchor[0, a, 0]
+        t = anchor[0, a, 1]
+        r = anchor[0, a, 2]
+        br = anchor[0, a, 3]
+        ay = (t + br) * half
+        ax = (l + r) * half
+        ah = br - t
+        aw = r - l
+        ex = loc_reshaped[b, a, 0]
+        ey = loc_reshaped[b, a, 1]
+        ew = loc_reshaped[b, a, 2]
+        eh = loc_reshaped[b, a, 3]
+        ycenter = ey * vy * ah + ay
+        xcenter = ex * vx * aw + ax
+        half_h = half * te.exp(eh * vh) * ah
+        half_w = half * te.exp(ew * vw) * aw
+        ymin = ycenter - half_h
+        xmin = xcenter - half_w
+        ymax = ycenter + half_h
+        xmax = xcenter + half_w
+        if clip:
+            ymin = te.max(zero, te.min(one, ymin))
+            xmin = te.max(zero, te.min(one, xmin))
+            ymax = te.max(zero, te.min(one, ymax))
+            xmax = te.max(zero, te.min(one, xmax))
+        return tvm.tirx.Select(
+            k == 0,
+            ymin,
+            tvm.tirx.Select(k == 1, xmin, tvm.tirx.Select(k == 2, ymax, xmax)),
+        )
+
+    boxes = te.compute((B, num_anchors, 4), decode_bbox, name="multibox_boxes")
+
+    scores = topi.nn.softmax(cls_pred, axis=1)
+    mask = topi.cast(topi.greater_equal(scores, th), dtype)
+    scores = scores * mask
+    if not keep_background:
+
+        def zero_bg(b, c, n):
+            s = scores[b, c, n]
+            return te.if_then_else(c == 0, zero, s)
+
+        scores = te.compute(scores.shape, zero_bg, name="multibox_scores_bg")
+
+    return [boxes, scores]
diff --git a/src/relax/op/vision/multibox_transform_loc.cc 
b/src/relax/op/vision/multibox_transform_loc.cc
new file mode 100644
index 0000000000..e01e569b78
--- /dev/null
+++ b/src/relax/op/vision/multibox_transform_loc.cc
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file multibox_transform_loc.cc
+ * \brief Multibox transform (location decode) for object detection.
+ */
+
+#include "multibox_transform_loc.h"
+
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/relax/struct_info.h>
+
+#include <utility>
+
+namespace tvm {
+namespace relax {
+
+TVM_FFI_STATIC_INIT_BLOCK() { MultiboxTransformLocAttrs::RegisterReflection(); 
}
+
+Expr multibox_transform_loc(Expr cls_pred, Expr loc_pred, Expr anchor, bool 
clip, double threshold,
+                            ffi::Array<double> variances, bool 
keep_background) {
+  TVM_FFI_ICHECK_EQ(variances.size(), 4)
+      << "multibox_transform_loc: variances must be length 4 (x,y,w,h), got " 
<< variances.size();
+
+  auto attrs = ffi::make_object<MultiboxTransformLocAttrs>();
+  attrs->clip = clip;
+  attrs->threshold = threshold;
+  attrs->variances = std::move(variances);
+  attrs->keep_background = keep_background;
+
+  static const Op& op = Op::Get("relax.vision.multibox_transform_loc");
+  return Call(op, {std::move(cls_pred), std::move(loc_pred), 
std::move(anchor)}, Attrs(attrs), {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.vision.multibox_transform_loc", 
multibox_transform_loc);
+}
+
+/*!
+ * \brief Infer struct info for relax.vision.multibox_transform_loc.
+ *
+ * \note Shape cross-checks that need the anchor count N (e.g. 
loc_pred.shape[1] == 4*N,
+ * anchor.shape[1] == N with N = cls_pred.shape[2]) run only when cls_pred has 
a known
+ * static shape. If cls_pred shape is unknown, inference returns generic 
rank-3 outputs and
+ * skips those N-based relations; other checks (ndim, dtype, loc dim divisible 
by 4, etc.)
+ * still apply when their inputs are known.
+ */
+StructInfo InferStructInfoMultiboxTransformLoc(const Call& call, const 
BlockBuilder& ctx) {
+  if (call->args.size() != 3) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "multibox_transform_loc: expected 3 inputs (cls_pred, 
loc_pred, anchor), "
+                        "got "
+                     << call->args.size());
+  }
+
+  ffi::Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, 
ctx);
+  const auto cls_sinfo = input_sinfo[0];
+  const auto loc_sinfo = input_sinfo[1];
+  const auto anchor_sinfo = input_sinfo[2];
+
+  if (!cls_sinfo->IsUnknownNdim() && cls_sinfo->ndim != 3) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "multibox_transform_loc: cls_pred must be 3-D [B, 
num_classes, N], got "
+                        "ndim "
+                     << cls_sinfo->ndim);
+  }
+  if (!loc_sinfo->IsUnknownNdim() && loc_sinfo->ndim != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "multibox_transform_loc: loc_pred must be 2-D [B, 
4*N], got ndim "
+                     << loc_sinfo->ndim);
+  }
+  if (!anchor_sinfo->IsUnknownNdim() && anchor_sinfo->ndim != 3) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "multibox_transform_loc: anchor must be 3-D [1, N, 4] 
ltrb, got ndim "
+                     << anchor_sinfo->ndim);
+  }
+
+  if (!cls_sinfo->IsUnknownDtype() && !loc_sinfo->IsUnknownDtype() &&
+      cls_sinfo->dtype != loc_sinfo->dtype) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "multibox_transform_loc: cls_pred and loc_pred dtype 
must match, got "
+                     << cls_sinfo->dtype << " vs " << loc_sinfo->dtype);
+  }
+  if (!cls_sinfo->IsUnknownDtype() && !anchor_sinfo->IsUnknownDtype() &&
+      cls_sinfo->dtype != anchor_sinfo->dtype) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "multibox_transform_loc: cls_pred and anchor dtype 
must match, got "
+                     << cls_sinfo->dtype << " vs " << anchor_sinfo->dtype);
+  }
+
+  auto vdev = cls_sinfo->vdevice;
+  const auto* cls_shape = cls_sinfo->shape.as<ShapeExprNode>();
+  const auto* loc_shape = loc_sinfo->shape.as<ShapeExprNode>();
+  const auto* anchor_shape = anchor_sinfo->shape.as<ShapeExprNode>();
+
+  if (loc_shape != nullptr) {
+    const auto* loc_dim1 = loc_shape->values[1].as<IntImmNode>();
+    if (loc_dim1 != nullptr && loc_dim1->value % 4 != 0) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "multibox_transform_loc: loc_pred.shape[1] must be 
divisible by 4, got "
+                       << loc_dim1->value);
+    }
+  }
+
+  if (cls_shape != nullptr && loc_shape != nullptr) {
+    const auto* cls_b = cls_shape->values[0].as<IntImmNode>();
+    const auto* loc_b = loc_shape->values[0].as<IntImmNode>();
+    if (cls_b != nullptr && loc_b != nullptr && cls_b->value != loc_b->value) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "multibox_transform_loc: cls_pred.shape[0] must 
match loc_pred.shape[0], "
+                          "got B="
+                       << cls_b->value << " vs " << loc_b->value);
+    }
+  }
+
+  if (anchor_shape != nullptr) {
+    const auto* anchor_batch = anchor_shape->values[0].as<IntImmNode>();
+    if (anchor_batch != nullptr && anchor_batch->value != 1) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "multibox_transform_loc: anchor.shape[0] must be 1, 
got "
+                       << anchor_batch->value);
+    }
+    const auto* anchor_last = anchor_shape->values[2].as<IntImmNode>();
+    if (anchor_last != nullptr && anchor_last->value != 4) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "multibox_transform_loc: anchor.shape[2] must be 4 
(ltrb), got "
+                       << anchor_last->value);
+    }
+  }
+
+  if (cls_shape == nullptr) {
+    ffi::Array<StructInfo> fields = {TensorStructInfo(cls_sinfo->dtype, 3, 
vdev),
+                                     TensorStructInfo(cls_sinfo->dtype, 3, 
vdev)};
+    return TupleStructInfo(fields);
+  }
+
+  const auto& batch = cls_shape->values[0];
+  const auto& num_classes = cls_shape->values[1];
+  const auto& num_anchors = cls_shape->values[2];
+
+  if (loc_shape != nullptr) {
+    const auto* num_anchors_imm = num_anchors.as<IntImmNode>();
+    const auto* loc_dim1 = loc_shape->values[1].as<IntImmNode>();
+    if (num_anchors_imm != nullptr && loc_dim1 != nullptr &&
+        loc_dim1->value != num_anchors_imm->value * 4) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "multibox_transform_loc: loc_pred.shape[1] must 
equal 4*N with "
+                          "N=cls_pred.shape[2], got loc_dim="
+                       << loc_dim1->value << ", N=" << num_anchors_imm->value);
+    }
+  }
+  if (anchor_shape != nullptr) {
+    const auto* num_anchors_imm = num_anchors.as<IntImmNode>();
+    const auto* anchor_num_anchors = anchor_shape->values[1].as<IntImmNode>();
+    if (num_anchors_imm != nullptr && anchor_num_anchors != nullptr &&
+        anchor_num_anchors->value != num_anchors_imm->value) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "multibox_transform_loc: anchor.shape[1] must equal 
N=cls_pred.shape[2], "
+                          "got anchor_N="
+                       << anchor_num_anchors->value << ", N=" << 
num_anchors_imm->value);
+    }
+  }
+
+  ffi::Array<PrimExpr> boxes_shape = {batch, num_anchors, Integer(4)};
+  ffi::Array<PrimExpr> scores_shape = {batch, num_classes, num_anchors};
+  ffi::Array<StructInfo> fields = {
+      TensorStructInfo(ShapeExpr(boxes_shape), cls_sinfo->dtype, vdev),
+      TensorStructInfo(ShapeExpr(scores_shape), cls_sinfo->dtype, vdev)};
+  return TupleStructInfo(fields);
+}
+
+TVM_REGISTER_OP("relax.vision.multibox_transform_loc")
+    .describe("Decode SSD/TFLite-style priors and offsets into boxes and 
softmax scores. If "
+              "cls_pred shape is unknown, N-based loc/anchor shape checks are 
skipped in "
+              "inference. Very large variances (w,h) can overflow exp in half 
box sizes.")
+    .set_attrs_type<MultiboxTransformLocAttrs>()
+    .set_num_inputs(3)
+    .add_argument("cls_pred", "Tensor", "[B,C,N] class logits (pre-softmax).")
+    .add_argument("loc_pred", "Tensor",
+                  "[B,4*N] box encodings (x,y,w,h); TFLite yxhw order remapped 
to xywh.")
+    .add_argument("anchor", "Tensor", "[1,N,4] priors as ltrb 
(left,top,right,bottom).")
+    .set_attr<FInferStructInfo>("FInferStructInfo", 
InferStructInfoMultiboxTransformLoc)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/vision/multibox_transform_loc.h 
b/src/relax/op/vision/multibox_transform_loc.h
new file mode 100644
index 0000000000..726bc4c0e5
--- /dev/null
+++ b/src/relax/op/vision/multibox_transform_loc.h
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file multibox_transform_loc.h
+ * \brief The functions to make Relax multibox_transform_loc operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_VISION_MULTIBOX_TRANSFORM_LOC_H_
+#define TVM_RELAX_OP_VISION_MULTIBOX_TRANSFORM_LOC_H_
+
+#include <tvm/relax/attrs/vision.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Decode SSD box encodings and prepare class scores 
(TFLite-compatible). */
+Expr multibox_transform_loc(Expr cls_pred, Expr loc_pred, Expr anchor, bool 
clip, double threshold,
+                            ffi::Array<double> variances, bool 
keep_background);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_VISION_MULTIBOX_TRANSFORM_LOC_H_
diff --git a/tests/python/relax/test_op_vision.py 
b/tests/python/relax/test_op_vision.py
index b902518b49..cded9f5f29 100644
--- a/tests/python/relax/test_op_vision.py
+++ b/tests/python/relax/test_op_vision.py
@@ -286,6 +286,7 @@ def 
test_all_class_non_max_suppression_legalize_dynamic_trim():
     )
 
 
+@tvm.testing.requires_llvm
 def test_all_class_non_max_suppression_legalize_e2e():
     @tvm.script.ir_module
     class NMSModule:
@@ -344,5 +345,287 @@ def test_all_class_non_max_suppression_legalize_e2e():
     tvm.testing.assert_allclose(selected_indices.shape, (num_total_detections, 3))
 
 
+def test_multibox_transform_loc_op_correctness():  # builder emits the registered op
+    cls = relax.Var("cls", R.Tensor((1, 5, 10), "float32"))
+    loc = relax.Var("loc", R.Tensor((1, 40), "float32"))
+    anc = relax.Var("anc", R.Tensor((1, 10, 4), "float32"))
+    assert (
+        relax.op.vision.multibox_transform_loc(
+            cls, loc, anc, False, 0.0, (1.0, 1.0, 1.0, 1.0), True
+        ).op
+        == Op.get("relax.vision.multibox_transform_loc")
+    )
+
+
+def test_multibox_transform_loc_infer_struct_info():  # boxes (B, N, 4); scores = cls shape
+    bb = relax.BlockBuilder()
+    cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32"))
+    loc = relax.Var("loc", R.Tensor((2, 20), "float32"))
+    anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+    _check_inference(
+        bb,
+        relax.op.vision.multibox_transform_loc(
+            cls, loc, anc, False, 0.0, (0.1, 0.1, 0.2, 0.2), True
+        ),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((2, 5, 4), "float32"),
+                relax.TensorStructInfo((2, 3, 5), "float32"),
+            ]
+        ),
+    )
+
+
+def test_multibox_transform_loc_wrong_cls_ndim():  # 2-D cls_pred must be rejected
+    bb = relax.BlockBuilder()
+    cls = relax.Var("cls", R.Tensor((2, 3), "float32"))
+    loc = relax.Var("loc", R.Tensor((2, 20), "float32"))
+    anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc))
+
+
+def test_multibox_transform_loc_wrong_shape_relation():  # loc dim must equal 4 * N
+    bb = relax.BlockBuilder()
+    cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32"))
+    anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+    loc_bad_div = relax.Var("loc_bad_div", R.Tensor((2, 19), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc_bad_div, anc))
+    # Divisible by 4 but loc_dim != 4*N (N=5 -> expect 20, not 24)
+    loc_bad_n = relax.Var("loc_bad_n", R.Tensor((2, 24), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc_bad_n, anc))
+
+
+def test_multibox_transform_loc_wrong_anchor_shape():  # anchor batch 2 / last dim 5 rejected
+    bb = relax.BlockBuilder()
+    cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32"))
+    loc = relax.Var("loc", R.Tensor((2, 20), "float32"))
+    anc_bad_batch = relax.Var("anc_bad_batch", R.Tensor((2, 5, 4), "float32"))
+    anc_bad_last = relax.Var("anc_bad_last", R.Tensor((1, 5, 5), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc_bad_batch))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc_bad_last))
+
+
+def test_multibox_transform_loc_wrong_dtype():  # float16 loc vs float32 cls/anchor
+    bb = relax.BlockBuilder()
+    cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32"))
+    loc = relax.Var("loc", R.Tensor((2, 20), "float16"))
+    anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc))
+
+
+def test_multibox_transform_loc_wrong_batch():  # cls batch 2 vs loc batch 1
+    bb = relax.BlockBuilder()
+    cls = relax.Var("cls", R.Tensor((2, 3, 5), "float32"))
+    loc = relax.Var("loc", R.Tensor((1, 20), "float32"))
+    anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.multibox_transform_loc(cls, loc, anc))
+
+
+def _multibox_ref_numpy(
+    cls_pred, loc_pred, anchor, variances, clip=False, threshold=0.0, keep_background=True
+):
+    """Numpy reference aligned with ``topi.vision.multibox_transform_loc``."""
+
+    def _softmax(x, axis):
+        x_max = np.max(x, axis=axis, keepdims=True)
+        exp = np.exp(x - x_max)
+        return exp / np.sum(exp, axis=axis, keepdims=True)
+
+    B, C, N = cls_pred.shape
+    loc = loc_pred.reshape(B, N, 4)
+    scores = _softmax(cls_pred.astype("float64"), axis=1).astype(np.float32)
+    if threshold > 0.0:
+        scores = np.where(scores >= threshold, scores, 0.0).astype(np.float32)
+    if not keep_background:
+        scores = scores.copy()
+        scores[:, 0, :] = 0.0
+    vx, vy, vw, vh = variances
+    boxes = np.zeros((B, N, 4), dtype=np.float32)
+    for b in range(B):
+        for a in range(N):
+            l, t, r, br = anchor[0, a, :]  # ltrb prior; anchor batch 0 shared by all b
+            ay = (t + br) * 0.5
+            ax = (l + r) * 0.5
+            ah = br - t
+            aw = r - l
+            ex, ey, ew, eh = loc[b, a, :]
+            ycenter = ey * vy * ah + ay
+            xcenter = ex * vx * aw + ax
+            half_h = 0.5 * np.exp(eh * vh) * ah
+            half_w = 0.5 * np.exp(ew * vw) * aw
+            ymin = ycenter - half_h
+            xmin = xcenter - half_w
+            ymax = ycenter + half_h
+            xmax = xcenter + half_w
+            if clip:
+                ymin = np.clip(ymin, 0.0, 1.0)
+                xmin = np.clip(xmin, 0.0, 1.0)
+                ymax = np.clip(ymax, 0.0, 1.0)
+                xmax = np.clip(xmax, 0.0, 1.0)
+            boxes[b, a, :] = (ymin, xmin, ymax, xmax)  # output order is (ymin, xmin, ymax, xmax)
+    return boxes, scores
+
+
+@tvm.testing.requires_llvm
+def test_multibox_transform_loc_legalize_e2e():  # unit variances, no clip/threshold
+    @tvm.script.ir_module
+    class Mod:
+        @R.function
+        def main(
+            cls: R.Tensor((1, 3, 5), "float32"),
+            loc: R.Tensor((1, 20), "float32"),
+            anc: R.Tensor((1, 5, 4), "float32"),
+        ) -> R.Tuple(R.Tensor((1, 5, 4), "float32"), R.Tensor((1, 3, 5), "float32")):
+            return R.vision.multibox_transform_loc(
+                cls,
+                loc,
+                anc,
+                clip=False,
+                threshold=0.0,
+                variances=(1.0, 1.0, 1.0, 1.0),
+                keep_background=True,
+            )
+
+    cls_data = np.random.randn(1, 3, 5).astype(np.float32)
+    loc_data = np.random.randn(1, 20).astype(np.float32) * 0.05
+    anc_data = np.array(
+        [
+            [
+                [0.1, 0.1, 0.5, 0.5],
+                [0.2, 0.2, 0.6, 0.6],
+                [0.0, 0.0, 1.0, 1.0],
+                [0.3, 0.3, 0.7, 0.7],
+                [0.05, 0.05, 0.45, 0.45],
+            ]
+        ],
+        dtype=np.float32,
+    )
+
+    mod = LegalizeOps()(Mod)
+    exe = tvm.compile(mod, target="llvm")
+    vm = relax.VirtualMachine(exe, tvm.cpu())
+    ref_b, ref_s = _multibox_ref_numpy(cls_data, loc_data, anc_data, (1.0, 1.0, 1.0, 1.0))
+    out = vm["main"](
+        tvm.runtime.tensor(cls_data, tvm.cpu()),
+        tvm.runtime.tensor(loc_data, tvm.cpu()),
+        tvm.runtime.tensor(anc_data, tvm.cpu()),
+    )
+    tvm.testing.assert_allclose(out[0].numpy(), ref_b, rtol=1e-4, atol=1e-5)
+    tvm.testing.assert_allclose(out[1].numpy(), ref_s, rtol=1e-4, atol=1e-5)
+
+
+@tvm.testing.requires_llvm
+def test_multibox_transform_loc_legalize_e2e_nonunity_variances():  # SSD-style variances
+    @tvm.script.ir_module
+    class Mod:
+        @R.function
+        def main(
+            cls: R.Tensor((1, 3, 5), "float32"),
+            loc: R.Tensor((1, 20), "float32"),
+            anc: R.Tensor((1, 5, 4), "float32"),
+        ) -> R.Tuple(R.Tensor((1, 5, 4), "float32"), R.Tensor((1, 3, 5), "float32")):
+            return R.vision.multibox_transform_loc(
+                cls,
+                loc,
+                anc,
+                clip=False,
+                threshold=0.0,
+                variances=(0.1, 0.1, 0.2, 0.2),
+                keep_background=True,
+            )
+
+    cls_data = np.random.randn(1, 3, 5).astype(np.float32)
+    loc_data = np.random.randn(1, 20).astype(np.float32) * 0.05
+    anc_data = np.array(
+        [
+            [
+                [0.1, 0.1, 0.5, 0.5],
+                [0.2, 0.2, 0.6, 0.6],
+                [0.0, 0.0, 1.0, 1.0],
+                [0.3, 0.3, 0.7, 0.7],
+                [0.05, 0.05, 0.45, 0.45],
+            ]
+        ],
+        dtype=np.float32,
+    )
+
+    mod = LegalizeOps()(Mod)
+    exe = tvm.compile(mod, target="llvm")
+    vm = relax.VirtualMachine(exe, tvm.cpu())
+    ref_b, ref_s = _multibox_ref_numpy(cls_data, loc_data, anc_data, (0.1, 0.1, 0.2, 0.2))
+    out = vm["main"](
+        tvm.runtime.tensor(cls_data, tvm.cpu()),
+        tvm.runtime.tensor(loc_data, tvm.cpu()),
+        tvm.runtime.tensor(anc_data, tvm.cpu()),
+    )
+    tvm.testing.assert_allclose(out[0].numpy(), ref_b, rtol=1e-4, atol=1e-5)
+    tvm.testing.assert_allclose(out[1].numpy(), ref_s, rtol=1e-4, atol=1e-5)
+
+
+@tvm.testing.requires_llvm
+def test_multibox_transform_loc_legalize_attr_branches():  # clip/threshold/no-bg
+    @tvm.script.ir_module
+    class Mod:
+        @R.function
+        def main(
+            cls: R.Tensor((1, 3, 4), "float32"),
+            loc: R.Tensor((1, 16), "float32"),
+            anc: R.Tensor((1, 4, 4), "float32"),
+        ) -> R.Tuple(R.Tensor((1, 4, 4), "float32"), R.Tensor((1, 3, 4), "float32")):
+            return R.vision.multibox_transform_loc(
+                cls,
+                loc,
+                anc,
+                clip=True,
+                threshold=0.4,
+                variances=(1.0, 1.0, 1.0, 1.0),
+                keep_background=False,
+            )
+
+    cls_data = np.array(
+        [[[2.0, 0.1, -0.5, 0.0], [0.2, 2.2, 0.3, -1.0], [0.1, 0.4, 2.0, 0.5]]],
+        dtype=np.float32,
+    )
+    loc_data = np.array(
+        [[0.1, -0.2, 0.0, 0.0, -0.2, 0.1, 0.3, -0.1, 0.0, 0.0, 0.8, 0.8, 0.2, 0.2, -0.6, -0.6]],
+        dtype=np.float32,
+    )
+    anc_data = np.array(
+        [[[0.1, 0.1, 0.5, 0.5], [0.2, 0.2, 0.6, 0.6], [0.0, 0.0, 1.0, 1.0], [0.4, 0.4, 1.2, 1.2]]],
+        dtype=np.float32,
+    )
+
+    mod = LegalizeOps()(Mod)
+    exe = tvm.compile(mod, target="llvm")
+    vm = relax.VirtualMachine(exe, tvm.cpu())
+    ref_b, ref_s = _multibox_ref_numpy(
+        cls_data,
+        loc_data,
+        anc_data,
+        (1.0, 1.0, 1.0, 1.0),
+        clip=True,
+        threshold=0.4,
+        keep_background=False,
+    )
+    out = vm["main"](
+        tvm.runtime.tensor(cls_data, tvm.cpu()),
+        tvm.runtime.tensor(loc_data, tvm.cpu()),
+        tvm.runtime.tensor(anc_data, tvm.cpu()),
+    )
+    boxes = out[0].numpy()
+    scores = out[1].numpy()
+    tvm.testing.assert_allclose(boxes, ref_b, rtol=1e-4, atol=1e-5)
+    tvm.testing.assert_allclose(scores, ref_s, rtol=1e-4, atol=1e-5)
+    assert np.all(boxes >= 0.0) and np.all(boxes <= 1.0)
+    tvm.testing.assert_allclose(scores[:, 0, :], np.zeros_like(scores[:, 0, :]))
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py b/tests/python/relax/test_tvmscript_parser_op_vision.py
index c4e8ff0c9d..f053e36744 100644
--- a/tests/python/relax/test_tvmscript_parser_op_vision.py
+++ b/tests/python/relax/test_tvmscript_parser_op_vision.py
@@ -75,6 +75,48 @@ def test_all_class_non_max_suppression():
     _check(foo, bb.get()["foo"])
 
 
+def test_multibox_transform_loc():  # parsed TVMScript must match BlockBuilder IR
+    @R.function
+    def foo(
+        cls: R.Tensor((1, 3, 5), "float32"),
+        loc: R.Tensor((1, 20), "float32"),
+        anc: R.Tensor((1, 5, 4), "float32"),
+    ) -> R.Tuple(R.Tensor((1, 5, 4), "float32"), R.Tensor((1, 3, 5), "float32")):
+        gv: R.Tuple(R.Tensor((1, 5, 4), "float32"), R.Tensor((1, 3, 5), "float32")) = (
+            R.vision.multibox_transform_loc(
+                cls,
+                loc,
+                anc,
+                clip=False,
+                threshold=0.0,
+                variances=(1.0, 1.0, 1.0, 1.0),
+                keep_background=True,
+            )
+        )
+        return gv
+
+    cls = relax.Var("cls", R.Tensor((1, 3, 5), "float32"))
+    loc = relax.Var("loc", R.Tensor((1, 20), "float32"))
+    anc = relax.Var("anc", R.Tensor((1, 5, 4), "float32"))
+
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [cls, loc, anc]):
+        gv = bb.emit(
+            relax.op.vision.multibox_transform_loc(
+                cls,
+                loc,
+                anc,
+                clip=False,
+                threshold=0.0,
+                variances=(1.0, 1.0, 1.0, 1.0),
+                keep_background=True,
+            )
+        )
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 def test_roi_align():
     @R.function
     def foo(


Reply via email to