This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 1e08eb2fa9 [Relax][ONNX][Torch] Add roi_align support and frontend integration (#18936)
1e08eb2fa9 is described below

commit 1e08eb2fa9dcb59a4f34917a94c659d81ad36c0c
Author: YinHanke <[email protected]>
AuthorDate: Thu Mar 26 22:41:54 2026 +0800

    [Relax][ONNX][Torch] Add roi_align support and frontend integration (#18936)
    
    ## Summary
    
    Add Relax `roi_align` support and wire it through the ONNX and PyTorch
    frontends.
    
    ## Changes
    
    - add `relax.vision.roi_align`, including attrs, Python wrapper, struct
    info inference, and legalization
    - add TOPI `roi_align` compute and keep both legacy and aligned ROIAlign
    semantics
    - support ONNX `RoiAlign`, including `coordinate_transformation_mode`
    handling for `output_half_pixel` and `half_pixel`
    - support PyTorch `torchvision.ops.roi_align` in the exported-program
    frontend, including the `aligned` flag
    - add regression tests for Relax op inference, legalization, TVMScript
    parsing, ONNX frontend import, and PyTorch frontend import
    - add aligned ROIAlign test coverage to make sure sub-pixel RoIs no
    longer use the legacy `min=1.0` clamp
    
    ## Validation
    
    - `pytest tests/python/relax/test_op_vision.py -k roi_align`
    - `pytest tests/python/relax/test_tvmscript_parser_op_vision.py -k roi_align`
    - `pytest tests/python/relax/test_frontend_onnx.py -k roi_align`
    - `pytest tests/python/relax/test_frontend_from_exported_program.py -k roi_align`
    
    This PR completes the Relax/ONNX/Torch roi_align work tracked in #18928.
---
 include/tvm/relax/attrs/vision.h                   |  25 +++
 python/tvm/relax/frontend/onnx/onnx_frontend.py    |  67 ++++++-
 .../frontend/torch/exported_program_translator.py  |  37 ++++
 python/tvm/relax/op/__init__.py                    |   2 +-
 python/tvm/relax/op/op_attrs.py                    |   5 +
 python/tvm/relax/op/vision/__init__.py             |   1 +
 python/tvm/relax/op/vision/roi_align.py            |  78 ++++++++
 python/tvm/relax/transform/legalize_ops/vision.py  |  15 ++
 python/tvm/topi/testing/roi_align_python.py        |  15 +-
 python/tvm/topi/vision/__init__.py                 |   1 +
 python/tvm/topi/vision/roi_align.py                | 204 +++++++++++++++++++++
 src/relax/op/vision/roi_align.cc                   | 141 ++++++++++++++
 src/relax/op/vision/roi_align.h                    |  42 +++++
 .../relax/test_frontend_from_exported_program.py   |  60 ++++++
 tests/python/relax/test_frontend_onnx.py           |  60 +++++-
 tests/python/relax/test_op_vision.py               | 168 +++++++++++++++++
 .../relax/test_tvmscript_parser_op_vision.py       |  32 ++++
 17 files changed, 944 insertions(+), 9 deletions(-)

diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h
index 2fd98533b5..59a1dd7314 100644
--- a/include/tvm/relax/attrs/vision.h
+++ b/include/tvm/relax/attrs/vision.h
@@ -48,6 +48,31 @@ struct AllClassNonMaximumSuppressionAttrs
                                     AllClassNonMaximumSuppressionAttrs, BaseAttrsNode);
 };  // struct AllClassNonMaximumSuppressionAttrs
 
+/*! \brief Attributes used in ROIAlign operator */
+struct ROIAlignAttrs : public AttrsNodeReflAdapter<ROIAlignAttrs> {
+  ffi::Array<int64_t> pooled_size;
+  double spatial_scale;
+  int sample_ratio;
+  bool aligned;
+  ffi::String layout;
+  ffi::String mode;
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<ROIAlignAttrs>()
+        .def_ro("pooled_size", &ROIAlignAttrs::pooled_size, "Output size of roi align.")
+        .def_ro("spatial_scale", &ROIAlignAttrs::spatial_scale,
+                "Ratio of input feature map height (or width) to raw image height (or width).")
+        .def_ro("sample_ratio", &ROIAlignAttrs::sample_ratio,
+                "Optional sampling ratio of ROI align, using adaptive size by default.")
+        .def_ro("aligned", &ROIAlignAttrs::aligned,
+                "Whether to use the aligned ROIAlign semantics without the legacy 1-pixel clamp.")
+        .def_ro("layout", &ROIAlignAttrs::layout, "Dimension ordering of the input data.")
+        .def_ro("mode", &ROIAlignAttrs::mode, "Mode for ROI Align. Can be 'avg' or 'max'.");
+  }
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs", ROIAlignAttrs, BaseAttrsNode);
+};  // struct ROIAlignAttrs
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index d828025e0a..c8d4c469fc 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2420,6 +2420,71 @@ class Einsum(OnnxOpConverter):
         return bb.emit_te(topi.einsum, equation, *inputs)
 
 
+class RoiAlign(OnnxOpConverter):
+    """Converts an onnx RoiAlign node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl(cls, bb, inputs, attr, params, default_coordinate_transformation_mode):
+        if len(inputs) != 3:
+            raise ValueError("RoiAlign expects exactly 3 inputs")
+
+        data = inputs[0]
+        rois = inputs[1]
+        batch_indices = inputs[2]
+        rois_dtype = rois.struct_info.dtype
+
+        mode = attr.get("mode", b"avg")
+        if isinstance(mode, bytes):
+            mode = mode.decode("ascii")
+        if mode not in ("avg", "max"):
+            raise NotImplementedError("RoiAlign in Relax only supports avg and max modes")
+
+        output_height = attr.get("output_height", 1)
+        output_width = attr.get("output_width", 1)
+        sampling_ratio = attr.get("sampling_ratio", 0)
+        spatial_scale = attr.get("spatial_scale", 1.0)
+        coordinate_transformation_mode = attr.get(
+            "coordinate_transformation_mode", default_coordinate_transformation_mode
+        )
+        if isinstance(coordinate_transformation_mode, bytes):
+            coordinate_transformation_mode = coordinate_transformation_mode.decode("ascii")
+
+        if coordinate_transformation_mode == "half_pixel":
+            # ONNX shifts by -0.5 after scaling; rois are scaled later, so pre-divide.
+            rois = relax.op.add(rois, relax.const([-0.5 / spatial_scale] * 4, rois_dtype))
+            aligned = True
+        elif coordinate_transformation_mode != "output_half_pixel":
+            raise NotImplementedError(
+                "RoiAlign only supports coordinate_transformation_mode "
+                "'half_pixel' and 'output_half_pixel'"
+            )
+        else:
+            aligned = False
+
+        batch_indices = relax.op.expand_dims(batch_indices, axis=1)
+        batch_indices = relax.op.astype(batch_indices, rois_dtype)
+        rois = relax.op.concat([batch_indices, rois], axis=1)
+
+        return relax.op.vision.roi_align(
+            data,
+            rois,
+            pooled_size=(output_height, output_width),
+            spatial_scale=spatial_scale,
+            sample_ratio=sampling_ratio,
+            aligned=aligned,
+            layout="NCHW",
+            mode=mode,
+        )
+
+    @classmethod
+    def _impl_v10(cls, bb, inputs, attr, params):
+        return cls._impl(bb, inputs, attr, params, b"output_half_pixel")
+
+    @classmethod
+    def _impl_v16(cls, bb, inputs, attr, params):
+        return cls._impl(bb, inputs, attr, params, b"half_pixel")
+
+
 class Range(OnnxOpConverter):
     """Converts an onnx Range node into an equivalent Relax expression."""
 
@@ -4082,7 +4147,7 @@ def _get_convert_map():
         "NonZero": NonZero,
         # "If": If,
         # "MaxRoiPool": MaxRoiPool,
-        # "RoiAlign": RoiAlign,
+        "RoiAlign": RoiAlign,
         "NonMaxSuppression": NonMaxSuppression,
         "AllClassNMS": AllClassNMS,
         "GridSample": GridSample,
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 67e0e45da0..47633c69b5 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1123,6 +1123,42 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             )
         )
 
+    def _torchvision_roi_align(self, node: fx.Node) -> relax.Var:
+        """Convert torchvision.ops.roi_align to relax.op.vision.roi_align."""
+        args = self.retrieve_args(node)
+        data = args[0]
+        rois = args[1]
+        spatial_scale = args[2] if len(args) > 2 else 1.0
+        pooled_height = args[3] if len(args) > 3 else 1
+        pooled_width = args[4] if len(args) > 4 else pooled_height
+        sampling_ratio = args[5] if len(args) > 5 else -1
+        aligned = args[6] if len(args) > 6 else False
+
+        if aligned:
+            # torchvision subtracts 0.5 after scaling, i.e. 0.5 / spatial_scale before it.
+            batch_indices = self.block_builder.emit(
+                relax.op.strided_slice(rois, axes=[1], begin=[0], end=[1])
+            )
+            boxes = self.block_builder.emit(
+                relax.op.strided_slice(rois, axes=[1], begin=[1], end=[5])
+            )
+            offset = relax.const(0.5 / spatial_scale, rois.struct_info.dtype)
+            boxes = self.block_builder.emit(relax.op.subtract(boxes, offset))
+            rois = self.block_builder.emit(relax.op.concat([batch_indices, boxes], axis=1))
+
+        return self.block_builder.emit(
+            relax.op.vision.roi_align(
+                data,
+                rois,
+                pooled_size=(pooled_height, pooled_width),
+                spatial_scale=spatial_scale,
+                sample_ratio=sampling_ratio,
+                aligned=aligned,
+                layout="NCHW",
+                mode="avg",
+            )
+        )
+
     def _scalar_tensor(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
         scalar_value = args[0]
@@ -1732,6 +1768,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "zeros.default": self._zeros,
             "zeros_like.default": self._zeros_like,
             "grid_sampler_2d.default": self._grid_sampler_2d,
+            "roi_align.default": self._torchvision_roi_align,
             # datatype
             "to.dtype": self._to,
             "to.dtype_layout": self._to,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 7c3f75298b..0bc3f65784 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -157,7 +157,7 @@ from .unary import (
     tanh,
     trunc,
 )
-from .vision import all_class_non_max_suppression
+from .vision import all_class_non_max_suppression, roi_align
 
 
 def _register_op_make():
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index f07623bd38..a3b6544dcc 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -246,6 +246,11 @@ class AllClassNonMaximumSuppressionAttrs(Attrs):
     """Attributes for vision.all_class_non_max_suppression"""
 
 
+@tvm_ffi.register_object("relax.attrs.ROIAlignAttrs")
+class ROIAlignAttrs(Attrs):
+    """Attributes for vision.roi_align"""
+
+
 @tvm_ffi.register_object("relax.attrs.Conv1DAttrs")
 class Conv1DAttrs(Attrs):
     """Attributes for nn.conv1d"""
diff --git a/python/tvm/relax/op/vision/__init__.py b/python/tvm/relax/op/vision/__init__.py
index ea20d2b400..76d9ea35a1 100644
--- a/python/tvm/relax/op/vision/__init__.py
+++ b/python/tvm/relax/op/vision/__init__.py
@@ -18,3 +18,4 @@
 """VISION operators."""
 
 from .nms import *
+from .roi_align import *
diff --git a/python/tvm/relax/op/vision/roi_align.py b/python/tvm/relax/op/vision/roi_align.py
new file mode 100644
index 0000000000..8db694c7f2
--- /dev/null
+++ b/python/tvm/relax/op/vision/roi_align.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""ROI Align operator"""
+
+from ..base import Expr
+from . import _ffi_api
+
+
+def roi_align(
+    data: Expr,
+    rois: Expr,
+    pooled_size: int | tuple[int, int] | list[int],
+    spatial_scale: float,
+    sample_ratio: int = -1,
+    aligned: bool = False,
+    layout: str = "NCHW",
+    mode: str = "avg",
+):
+    """ROI Align operator.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        4-D input tensor.
+
+    rois : relax.Expr
+        2-D input tensor with shape `(num_roi, 5)` in
+        `[batch_idx, x1, y1, x2, y2]` format.
+
+    pooled_size : Union[int, Tuple[int, int], List[int]]
+        Output pooled size.
+
+    spatial_scale : float
+        Ratio of input feature map height (or width) to raw image height (or width).
+
+    sample_ratio : int, optional
+        Sampling ratio for ROI align. Non-positive values use adaptive sampling.
+
+    aligned : bool, optional
+        If True, skip the legacy 1-pixel ROI size clamp. No half-pixel offset is applied here.
+
+    layout : str, optional
+        Layout of the input data. Supported values are `NCHW` and `NHWC`.
+
+    mode : str, optional
+        Mode for ROI align. Supported values are `avg` and `max`.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    if isinstance(pooled_size, int):
+        pooled_size = (pooled_size, pooled_size)
+    return _ffi_api.roi_align(
+        data,
+        rois,
+        pooled_size,
+        spatial_scale,
+        sample_ratio,
+        aligned,
+        layout,
+        mode,
+    )
diff --git a/python/tvm/relax/transform/legalize_ops/vision.py b/python/tvm/relax/transform/legalize_ops/vision.py
index f95dfa35b6..7a1e305f39 100644
--- a/python/tvm/relax/transform/legalize_ops/vision.py
+++ b/python/tvm/relax/transform/legalize_ops/vision.py
@@ -103,3 +103,18 @@ def _all_class_non_max_suppression(block_builder: BlockBuilder, call: Call) -> E
 
     # Return trimmed indices along with num_total_detections for compatibility
     return relax.Tuple([trimmed_indices, num_total_detections])
+
+
+@register_legalize("relax.vision.roi_align")
+def _roi_align(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(
+        topi.vision.roi_align,
+        call.args[0],
+        call.args[1],
+        pooled_size=call.attrs.pooled_size,
+        spatial_scale=call.attrs.spatial_scale,
+        mode=call.attrs.mode,
+        sample_ratio=call.attrs.sample_ratio,
+        aligned=call.attrs.aligned,
+        layout=call.attrs.layout,
+    )
diff --git a/python/tvm/topi/testing/roi_align_python.py b/python/tvm/topi/testing/roi_align_python.py
index 9fc72074e4..19725a0c5b 100644
--- a/python/tvm/topi/testing/roi_align_python.py
+++ b/python/tvm/topi/testing/roi_align_python.py
@@ -59,6 +59,7 @@ def roi_align_common(
     pooled_size_w,
     spatial_scale,
     sample_ratio,
+    aligned,
     avg_mode,
     max_mode,
     height,
@@ -72,8 +73,8 @@ def roi_align_common(
         roi = rois_np[i]
         batch_index = int(roi[0])
         roi_start_w, roi_start_h, roi_end_w, roi_end_h = roi[1:] * spatial_scale
-        roi_h = max(roi_end_h - roi_start_h, 1.0)
-        roi_w = max(roi_end_w - roi_start_w, 1.0)
+        roi_h = roi_end_h - roi_start_h if aligned else max(roi_end_h - roi_start_h, 1.0)
+        roi_w = roi_end_w - roi_start_w if aligned else max(roi_end_w - roi_start_w, 1.0)
 
         bin_h = roi_h / pooled_size_h
         bin_w = roi_w / pooled_size_w
@@ -115,7 +116,9 @@ def roi_align_common(
     return b_np
 
 
-def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg"):
+def roi_align_nchw_python(
+    a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg", aligned=False
+):
     """Roi align NCHW in python"""
     avg_mode = mode in (b"avg", "avg", 0)
     max_mode = mode in (b"max", "max", 1)
@@ -137,6 +140,7 @@ def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_rati
         pooled_size_w,
         spatial_scale,
         sample_ratio,
+        aligned,
         avg_mode,
         max_mode,
         height,
@@ -145,7 +149,9 @@ def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_rati
     )
 
 
-def roi_align_nhwc_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg"):
+def roi_align_nhwc_python(
+    a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg", aligned=False
+):
     """Roi align NHWC in python"""
     avg_mode = mode in (b"avg", "avg", 0)
     max_mode = mode in (b"max", "max", 1)
@@ -169,6 +175,7 @@ def roi_align_nhwc_python(a_np, rois_np, pooled_size, spatial_scale, sample_rati
         pooled_size_w,
         spatial_scale,
         sample_ratio,
+        aligned,
         avg_mode,
         max_mode,
         height,
diff --git a/python/tvm/topi/vision/__init__.py b/python/tvm/topi/vision/__init__.py
index c637b9cab2..75725a8a4b 100644
--- a/python/tvm/topi/vision/__init__.py
+++ b/python/tvm/topi/vision/__init__.py
@@ -18,3 +18,4 @@
 """Vision operators."""
 
 from .nms import *
+from .roi_align import *
diff --git a/python/tvm/topi/vision/roi_align.py b/python/tvm/topi/vision/roi_align.py
new file mode 100644
index 0000000000..2c2d0faec1
--- /dev/null
+++ b/python/tvm/topi/vision/roi_align.py
@@ -0,0 +1,204 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""ROI Align operator"""
+
+import tvm
+from tvm import te
+
+from ..cpp.utils import bilinear_sample_nchw, bilinear_sample_nhwc
+
+
+def _sample_common(
+    i,
+    c,
+    ph,
+    pw,
+    rois,
+    pooled_size_h,
+    pooled_size_w,
+    spatial_scale,
+    sample_ratio,
+    aligned,
+    dtype,
+    avg_mode,
+    bilinear_func,
+):
+    roi = rois[i]
+    batch_index = roi[0].astype("int32")
+    roi_start_w = roi[1] * spatial_scale
+    roi_start_h = roi[2] * spatial_scale
+    roi_end_w = roi[3] * spatial_scale
+    roi_end_h = roi[4] * spatial_scale
+
+    if aligned:
+        roi_h = roi_end_h - roi_start_h
+        roi_w = roi_end_w - roi_start_w
+    else:
+        roi_h = te.max(roi_end_h - roi_start_h, tvm.tirx.const(1.0, dtype))
+        roi_w = te.max(roi_end_w - roi_start_w, tvm.tirx.const(1.0, dtype))
+
+    pooled_size_h_const = tvm.tirx.const(pooled_size_h, dtype)
+    pooled_size_w_const = tvm.tirx.const(pooled_size_w, dtype)
+    bin_h = roi_h / pooled_size_h_const
+    bin_w = roi_w / pooled_size_w_const
+
+    if sample_ratio > 0:
+        roi_bin_grid_h = tvm.tirx.const(sample_ratio, "int32")
+        roi_bin_grid_w = tvm.tirx.const(sample_ratio, "int32")
+    else:
+        roi_bin_grid_h = te.ceil(roi_h / pooled_size_h_const).astype("int32")
+        roi_bin_grid_w = te.ceil(roi_w / pooled_size_w_const).astype("int32")
+
+    count = roi_bin_grid_h * roi_bin_grid_w
+    rh = te.reduce_axis((0, roi_bin_grid_h), name="rh")
+    rw = te.reduce_axis((0, roi_bin_grid_w), name="rw")
+    roi_start_h = roi_start_h + tvm.tirx.Cast(dtype, ph) * bin_h
+    roi_start_w = roi_start_w + tvm.tirx.Cast(dtype, pw) * bin_w
+
+    def sample_value(rh_idx, rw_idx):
+        return bilinear_func(
+            batch_index,
+            c,
+            roi_start_h
+            + (tvm.tirx.Cast(dtype, rh_idx) + tvm.tirx.const(0.5, dtype))
+            * bin_h
+            / tvm.tirx.Cast(dtype, roi_bin_grid_h),
+            roi_start_w
+            + (tvm.tirx.Cast(dtype, rw_idx) + tvm.tirx.const(0.5, dtype))
+            * bin_w
+            / tvm.tirx.Cast(dtype, roi_bin_grid_w),
+        )
+
+    if avg_mode:
+        return te.sum(
+            sample_value(rh, rw) / tvm.tirx.Cast(dtype, count),
+            axis=[rh, rw],
+        )
+    return te.max(sample_value(rh, rw), axis=[rh, rw])
+
+
+def roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1, aligned=False):
+    """ROI align operator in NCHW layout."""
+    avg_mode = mode in (b"avg", "avg", 0)
+    max_mode = mode in (b"max", "max", 1)
+    assert avg_mode or max_mode, "Mode must be avg or max. Please pass in a valid mode."
+
+    _, channel, height, width = data.shape
+    num_roi, _ = rois.shape
+    dtype = rois.dtype
+
+    if isinstance(pooled_size, int):
+        pooled_size_h = pooled_size_w = pooled_size
+    else:
+        pooled_size_h, pooled_size_w = pooled_size
+
+    height_f = tvm.tirx.Cast(dtype, height)
+    width_f = tvm.tirx.Cast(dtype, width)
+    zero = tvm.tirx.const(0.0, data.dtype)
+
+    def _bilinear(n, c, y, x):
+        outside = tvm.tirx.any(y < -1.0, x < -1.0, y > height_f, x > width_f)
+        y = te.min(te.max(y, 0.0), tvm.tirx.Cast(dtype, height - 1))
+        x = te.min(te.max(x, 0.0), tvm.tirx.Cast(dtype, width - 1))
+        val = bilinear_sample_nchw(data, (n, c, y, x), height - 1, width - 1)
+        return tvm.tirx.if_then_else(outside, zero, val)
+
+    return te.compute(
+        (num_roi, channel, pooled_size_h, pooled_size_w),
+        lambda i, c, ph, pw: _sample_common(
+            i,
+            c,
+            ph,
+            pw,
+            rois,
+            pooled_size_h,
+            pooled_size_w,
+            spatial_scale,
+            sample_ratio,
+            aligned,
+            dtype,
+            avg_mode,
+            _bilinear,
+        ),
+        tag="pool,roi_align_nchw",
+    )
+
+
+def roi_align_nhwc(data, rois, pooled_size, spatial_scale, mode, sample_ratio=-1, aligned=False):
+    """ROI align operator in NHWC layout."""
+    avg_mode = mode in (b"avg", "avg", 0)
+    max_mode = mode in (b"max", "max", 1)
+    assert avg_mode or max_mode, "Mode must be avg or max. Please pass in a valid mode."
+
+    _, height, width, channel = data.shape
+    num_roi, _ = rois.shape
+    dtype = rois.dtype
+
+    if isinstance(pooled_size, int):
+        pooled_size_h = pooled_size_w = pooled_size
+    else:
+        pooled_size_h, pooled_size_w = pooled_size
+
+    height_f = tvm.tirx.Cast(dtype, height)
+    width_f = tvm.tirx.Cast(dtype, width)
+    zero = tvm.tirx.const(0.0, data.dtype)
+
+    def _bilinear(n, c, y, x):
+        outside = tvm.tirx.any(y < -1.0, x < -1.0, y > height_f, x > width_f)
+        y = te.min(te.max(y, 0.0), tvm.tirx.Cast(dtype, height - 1))
+        x = te.min(te.max(x, 0.0), tvm.tirx.Cast(dtype, width - 1))
+        val = bilinear_sample_nhwc(data, (n, y, x, c), height - 1, width - 1)
+        return tvm.tirx.if_then_else(outside, zero, val)
+
+    return te.compute(
+        (num_roi, pooled_size_h, pooled_size_w, channel),
+        lambda i, ph, pw, c: _sample_common(
+            i,
+            c,
+            ph,
+            pw,
+            rois,
+            pooled_size_h,
+            pooled_size_w,
+            spatial_scale,
+            sample_ratio,
+            aligned,
+            dtype,
+            avg_mode,
+            _bilinear,
+        ),
+        tag="pool,roi_align_nhwc",
+    )
+
+
+def roi_align(
+    data,
+    rois,
+    pooled_size,
+    spatial_scale,
+    mode="avg",
+    sample_ratio=-1,
+    aligned=False,
+    layout="NCHW",
+):
+    """ROI align operator."""
+    if layout == "NCHW":
+        return roi_align_nchw(data, rois, pooled_size, spatial_scale, mode, sample_ratio, aligned)
+    if layout == "NHWC":
+        return roi_align_nhwc(data, rois, pooled_size, spatial_scale, mode, sample_ratio, aligned)
+    raise ValueError(f"Unsupported layout for roi_align: {layout}")
diff --git a/src/relax/op/vision/roi_align.cc b/src/relax/op/vision/roi_align.cc
new file mode 100644
index 0000000000..ae5185d6d4
--- /dev/null
+++ b/src/relax/op/vision/roi_align.cc
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file roi_align.cc
+ * \brief ROI Align operators.
+ */
+
+#include "roi_align.h"
+
+#include <tvm/ffi/reflection/registry.h>
+
+#include <utility>
+
+namespace tvm {
+namespace relax {
+
+TVM_FFI_STATIC_INIT_BLOCK() { ROIAlignAttrs::RegisterReflection(); }
+
+Expr roi_align(Expr data, Expr rois, ffi::Array<int64_t> pooled_size, double spatial_scale,
+               int sample_ratio, bool aligned, ffi::String layout, ffi::String mode) {
+  if (pooled_size.size() == 1) {
+    pooled_size.push_back(pooled_size[0]);
+  }
+  TVM_FFI_ICHECK_EQ(pooled_size.size(), 2)
+      << "The input pooled_size length is expected to be 2. However, the given pooled_size is "
+      << pooled_size;
+
+  auto attrs = ffi::make_object<ROIAlignAttrs>();
+  attrs->pooled_size = std::move(pooled_size);
+  attrs->spatial_scale = spatial_scale;
+  attrs->sample_ratio = sample_ratio;
+  attrs->aligned = aligned;
+  attrs->layout = layout;
+  attrs->mode = mode;
+
+  static const Op& op = Op::Get("relax.vision.roi_align");
+  return Call(op, {std::move(data), std::move(rois)}, Attrs(attrs), {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.vision.roi_align", roi_align);
+}
+
+StructInfo InferStructInfoROIAlign(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIAlign expects two arguments, while the given number of arguments is "
+                     << call->args.size());
+  }
+
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* rois_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+  if (data_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIAlign expects the input data to be a Tensor, while the given data is "
+                     << call->args[0]->GetTypeKey());
+  }
+  if (rois_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIAlign expects the rois to be a Tensor, while the given rois is "
+                     << call->args[1]->GetTypeKey());
+  }
+  if (!data_sinfo->IsUnknownNdim() && data_sinfo->ndim != 4) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIAlign expects the input data to be 4-D, while the given data has ndim "
+                     << data_sinfo->ndim);
+  }
+  if (!rois_sinfo->IsUnknownNdim() && rois_sinfo->ndim != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIAlign expects the rois tensor to be 2-D, while the given rois has ndim "
+                     << rois_sinfo->ndim);
+  }
+
+  const auto* attrs = call->attrs.as<ROIAlignAttrs>();
+  TVM_FFI_ICHECK(attrs != nullptr) << "Invalid ROIAlign attrs";
+  if (attrs->layout != "NCHW" && attrs->layout != "NHWC") {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIAlign only supports NCHW and NHWC layout, but got " << attrs->layout);
+  }
+  if (attrs->mode != "avg" && attrs->mode != "max") {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "ROIAlign only supports avg and max mode, but got " << attrs->mode);
+  }
+
+  const auto* rois_shape = rois_sinfo->shape.as<ShapeExprNode>();
+  if (rois_shape != nullptr) {
+    const auto* last_dim = rois_shape->values[1].as<IntImmNode>();
+    if (last_dim != nullptr && last_dim->value != 5) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << "ROIAlign expects rois to have shape (num_roi, 5), but got last "
+                          "dimension "
+                       << last_dim->value);
+    }
+  }
+
+  if (data_sinfo->shape.as<ShapeExprNode>() == nullptr || rois_shape == nullptr) {
+    return TensorStructInfo(data_sinfo->dtype, 4, data_sinfo->vdevice);
+  }
+
+  ffi::Array<PrimExpr> data_shape = data_sinfo->shape.as<ShapeExprNode>()->values;
+  ffi::Array<PrimExpr> out_shape;
+  if (attrs->layout == "NCHW") {
+    out_shape = {rois_shape->values[0], data_shape[1], Integer(attrs->pooled_size[0]),
+                 Integer(attrs->pooled_size[1])};
+  } else {
+    out_shape = {rois_shape->values[0], Integer(attrs->pooled_size[0]),
+                 Integer(attrs->pooled_size[1]), data_shape[3]};
+  }
+  return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.vision.roi_align")
+    .set_attrs_type<ROIAlignAttrs>()
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("rois", "Tensor",
+                  "The input rois with shape (num_roi, 5) in [batch_idx, x1, y1, x2, y2] format.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoROIAlign)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/vision/roi_align.h b/src/relax/op/vision/roi_align.h
new file mode 100644
index 0000000000..e2b861ac64
--- /dev/null
+++ b/src/relax/op/vision/roi_align.h
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file roi_align.h
+ * \brief The functions to make Relax ROI Align operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_VISION_ROI_ALIGN_H_
+#define TVM_RELAX_OP_VISION_ROI_ALIGN_H_
+
+#include <tvm/relax/attrs/vision.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Create a call to the relax.vision.roi_align operator.
+ * \param data The input tensor, laid out as "NCHW" or "NHWC" according to \p layout.
+ * \param rois The regions of interest, shape (num_roi, 5), one row per RoI in
+ *        [batch_idx, x1, y1, x2, y2] format.
+ * \param pooled_size The per-RoI output spatial size (pooled_h, pooled_w).
+ * \param spatial_scale Scale factor for RoI coordinates; corresponds to `spatial_scale` in
+ *        ONNX RoiAlign and torchvision.ops.roi_align.
+ * \param sample_ratio Sampling points per bin; 0 is accepted and presumably selects an
+ *        adaptive count (TODO confirm against the TOPI compute).
+ * \param aligned If true, use the aligned ROIAlign semantics, where sub-pixel RoIs are not
+ *        clamped to the legacy minimum size of 1.0; if false, keep the legacy behavior.
+ * \param layout Layout of \p data: "NCHW" or "NHWC".
+ * \param mode Pooling mode: "avg" or "max".
+ * \return The call expression. Its result has shape (num_roi, C, pooled_h, pooled_w) for
+ *         "NCHW" and (num_roi, pooled_h, pooled_w, C) for "NHWC".
+ */
+Expr roi_align(Expr data, Expr rois, ffi::Array<int64_t> pooled_size, double spatial_scale,
+               int sample_ratio, bool aligned, ffi::String layout, ffi::String mode);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_VISION_ROI_ALIGN_H_
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index e1cadb9d02..7a3548b4cf 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -20,6 +20,7 @@ import operator
 import numpy as np
 import pytest
 import torch
+import torchvision
 from torch import nn
 from torch.export import export
 from torch.nn import Module
@@ -8746,6 +8747,65 @@ def test_grid_sample():
     verify_model(GridSample(), example_args, {}, expected)
 
 
+def test_torchvision_roi_align():
+    """torchvision.ops.roi_align (aligned=False) converts to one R.vision.roi_align call.
+
+    output_size and sampling_ratio map to pooled_size and sample_ratio; the resulting
+    call uses NCHW layout and avg mode.
+    """
+    class ROIAlign(Module):
+        def forward(self, input, rois):
+            return torchvision.ops.roi_align(
+                input,
+                rois,
+                output_size=(3, 3),
+                spatial_scale=1.0,
+                sampling_ratio=2,
+                aligned=False,
+            )
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            input_1: R.Tensor((1, 3, 8, 8), dtype="float32"),
+            rois: R.Tensor((2, 5), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((2, 3, 3, 3), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((2, 3, 3, 3), dtype="float32") = R.vision.roi_align(
+                    input_1,
+                    rois,
+                    pooled_size=(3, 3),
+                    spatial_scale=1.0,
+                    sample_ratio=2,
+                    layout="NCHW",
+                    mode="avg",
+                )
+                gv: R.Tuple(R.Tensor((2, 3, 3, 3), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (
+        torch.randn(1, 3, 8, 8, dtype=torch.float32),
+        torch.tensor([[0.0, 1.0, 1.0, 6.0, 6.0], [0.0, 0.5, 0.5, 7.0, 7.0]], dtype=torch.float32),
+    )
+    verify_model(ROIAlign(), example_args, {}, expected)
+
+
+def test_torchvision_roi_align_aligned():
+    """aligned=True roi_align on a sub-pixel RoI is checked numerically against PyTorch.
+
+    The RoI is only 0.2 pixels wide and tall, so applying the legacy min=1.0 RoI-size
+    clamp would change the result.
+    """
+    class ROIAlign(Module):
+        def forward(self, input, rois):
+            return torchvision.ops.roi_align(
+                input, rois, output_size=(1, 1), spatial_scale=1.0, sampling_ratio=2, aligned=True
+            )
+
+    data = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
+    sub_pixel_roi = torch.tensor([[0.0, 1.0, 1.0, 1.2, 1.2]], dtype=torch.float32)
+    verify_model_numerically(ROIAlign(), (data, sub_pixel_roi), rtol=1e-5, atol=1e-5)
+
+
 def test_upsample_nearest2d():
     class UpsampleNearest2dScale(Module):
         def forward(self, input):
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index f56adfbfdb..86fa533874 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4010,6 +4010,7 @@ def test_nms_score_threshold():
             tvm_selected[:min_rows], ort_selected[:min_rows], rtol=1e-5, 
atol=1e-5
         )
 
+
 @pytest.mark.parametrize("mode", ["bilinear", "nearest", "bicubic"])
 @pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"])
 @pytest.mark.parametrize("align_corners", [0, 1])
@@ -4052,6 +4053,7 @@ def test_grid_sample(mode, padding_mode, align_corners):
         opset=16,
     )
 
+
 def test_grid_sample_linear_mode_translation():
     """Test that ONNX mode='linear' is correctly translated to 'bilinear'.
 
@@ -4078,7 +4080,9 @@ def test_grid_sample_linear_mode_translation():
             helper.make_tensor_value_info("grid", TensorProto.FLOAT, 
grid_shape),
         ],
         outputs=[
-            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [x_shape[0], 
x_shape[1], grid_shape[1], grid_shape[2]]),
+            helper.make_tensor_value_info(
+                "Y", TensorProto.FLOAT, [x_shape[0], x_shape[1], 
grid_shape[1], grid_shape[2]]
+            ),
         ],
     )
 
@@ -4113,7 +4117,9 @@ def test_grid_sample_cubic_mode_translation():
             helper.make_tensor_value_info("grid", TensorProto.FLOAT, 
grid_shape),
         ],
         outputs=[
-            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [x_shape[0], 
x_shape[1], grid_shape[1], grid_shape[2]]),
+            helper.make_tensor_value_info(
+                "Y", TensorProto.FLOAT, [x_shape[0], x_shape[1], 
grid_shape[1], grid_shape[2]]
+            ),
         ],
     )
 
@@ -4122,6 +4128,54 @@ def test_grid_sample_cubic_mode_translation():
     # Verify 'cubic' was translated to 'bicubic' in the Relax IR
     assert 'method="bicubic"' in str(tvm_model)
 
+
+@pytest.mark.parametrize(
+    ("coordinate_transformation_mode", "rois"),
+    [
+        (
+            "output_half_pixel",
+            np.array([[1.0, 1.0, 6.0, 6.0], [2.0, 0.5, 7.0, 7.0]], dtype="float32"),
+        ),
+        ("half_pixel", np.array([[1.0, 1.0, 1.2, 1.2], [2.0, 0.5, 1.1, 1.1]], dtype="float32")),
+    ],
+)
+def test_roi_align(coordinate_transformation_mode, rois):
+    """ONNX RoiAlign imports correctly under both coordinate_transformation_modes.
+
+    ONNX supplies RoIs as (num_rois, 4) [x1, y1, x2, y2] boxes plus a separate
+    batch_indices input. The half_pixel case uses sub-pixel RoIs; half_pixel presumably
+    maps to the aligned ROIAlign semantics.
+    NOTE(review): the second half_pixel RoI has x1=2.0 > x2=1.1; confirm this inverted
+    box is intentional.
+    """
+    x_shape = [1, 4, 8, 8]
+    rois_shape = [2, 4]
+    batch_indices_shape = [2]
+    out_shape = [2, 4, 3, 3]
+
+    node = helper.make_node(
+        "RoiAlign",
+        inputs=["X", "rois", "batch_indices"],
+        outputs=["Y"],
+        output_height=3,
+        output_width=3,
+        sampling_ratio=2,
+        spatial_scale=1.0,
+        mode="avg",
+        coordinate_transformation_mode=coordinate_transformation_mode,
+    )
+
+    graph = helper.make_graph(
+        [node],
+        "roi_align_test",
+        inputs=[
+            helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape),
+            helper.make_tensor_value_info("rois", TensorProto.FLOAT, rois_shape),
+            helper.make_tensor_value_info("batch_indices", TensorProto.INT64, batch_indices_shape),
+        ],
+        outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, out_shape)],
+    )
+
+    model = helper.make_model(graph, producer_name="roi_align_test")
+    inputs = {
+        "X": rg.standard_normal(size=x_shape).astype("float32"),
+        "rois": rois,
+        "batch_indices": np.array([0, 0], dtype="int64"),
+    }
+    check_correctness(model, inputs=inputs, opset=16, rtol=1e-5, atol=1e-5)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
-
diff --git a/tests/python/relax/test_op_vision.py 
b/tests/python/relax/test_op_vision.py
index 753ee14140..b902518b49 100644
--- a/tests/python/relax/test_op_vision.py
+++ b/tests/python/relax/test_op_vision.py
@@ -21,6 +21,7 @@ import pytest
 import tvm
 import tvm.testing
 from tvm import TVMError, relax, tirx
+from tvm.ir import Op
 from tvm.relax.transform import LegalizeOps
 from tvm.script import relax as R
 
@@ -30,6 +31,173 @@ def _check_inference(bb: relax.BlockBuilder, call: 
relax.Call, expected_sinfo: r
     tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
 
 
+def test_roi_align_op_correctness():
+    """The Python wrapper builds a Call to the registered relax.vision.roi_align op."""
+    x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    rois = relax.Var("rois", R.Tensor((4, 5), "float32"))
+    assert relax.op.vision.roi_align(x, rois, (7, 7), 1.0).op == Op.get("relax.vision.roi_align")
+
+
+def test_roi_align_infer_struct_info():
+    """Output is (num_roi, C, ph, pw) for NCHW data and (num_roi, ph, pw, C) for NHWC."""
+    bb = relax.BlockBuilder()
+    nchw = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    nhwc = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32"))
+    boxes = relax.Var("rois", R.Tensor((5, 5), "float32"))
+
+    cases = [
+        (relax.op.vision.roi_align(nchw, boxes, (7, 7), 0.25), (5, 3, 7, 7)),
+        (relax.op.vision.roi_align(nhwc, boxes, (5, 7), 1.0, layout="NHWC"), (5, 5, 7, 3)),
+    ]
+    for call, expected_shape in cases:
+        _check_inference(bb, call, relax.TensorStructInfo(expected_shape, "float32"))
+
+
+def test_roi_align_infer_struct_info_aligned():
+    """aligned=True does not change the inferred output shape."""
+    call = relax.op.vision.roi_align(
+        relax.Var("x", R.Tensor((2, 3, 32, 32), "float32")),
+        relax.Var("rois", R.Tensor((5, 5), "float32")),
+        (7, 7),
+        1.0,
+        aligned=True,
+    )
+    _check_inference(relax.BlockBuilder(), call, relax.TensorStructInfo((5, 3, 7, 7), "float32"))
+
+
+def test_roi_align_infer_struct_info_shape_var():
+    """Symbolic num_roi and channel dimensions propagate into the output shape."""
+    bb = relax.BlockBuilder()
+    n, c, h, w, num_roi = (tirx.Var(name, "int64") for name in ("n", "c", "h", "w", "num_roi"))
+
+    x = relax.Var("x", R.Tensor((n, c, h, w), "float32"))
+    rois = relax.Var("rois", R.Tensor((num_roi, 5), "float32"))
+
+    expected = relax.TensorStructInfo((num_roi, c, 7, 7), "float32")
+    _check_inference(bb, relax.op.vision.roi_align(x, rois, (7, 7), 0.5), expected)
+
+
+def test_roi_align_wrong_input_ndim():
+    """3-D data or 1-D rois raises TVMError."""
+    bb = relax.BlockBuilder()
+    data_3d = relax.Var("x", R.Tensor((2, 3, 32), "float32"))
+    data_4d = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    rois_1d = relax.Var("rois", R.Tensor((4,), "float32"))
+    rois_2d = relax.Var("rois", R.Tensor((4, 5), "float32"))
+
+    for data, rois in ((data_3d, rois_2d), (data_4d, rois_1d)):
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.vision.roi_align(data, rois, (7, 7), 1.0))
+
+
+def test_roi_align_wrong_rois_last_dim():
+    """rois whose last dimension is not 5 raises TVMError."""
+    data = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    rois_4col = relax.Var("rois", R.Tensor((4, 4), "float32"))
+    builder = relax.BlockBuilder()
+
+    with pytest.raises(TVMError):
+        builder.normalize(relax.op.vision.roi_align(data, rois_4col, (7, 7), 1.0))
+
+
+def test_roi_align_wrong_layout():
+    """An unsupported layout ("HWCN") raises TVMError."""
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    rois = relax.Var("rois", R.Tensor((4, 5), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.vision.roi_align(x, rois, (7, 7), 1.0, layout="HWCN"))
+
+
+def test_roi_align_legalize():
+    """LegalizeOps lowers roi_align (NCHW, avg, sample_ratio=2) to a call_tir.
+
+    The legalized function must keep the (2, 2, 3, 3) return struct info.
+    """
+    @tvm.script.ir_module
+    class ROIAlign:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 8, 8), "float32"),
+            rois: R.Tensor((2, 5), "float32"),
+        ) -> R.Tensor((2, 2, 3, 3), "float32"):
+            gv: R.Tensor((2, 2, 3, 3), "float32") = R.vision.roi_align(
+                x,
+                rois,
+                pooled_size=(3, 3),
+                spatial_scale=1.0,
+                sample_ratio=2,
+                layout="NCHW",
+                mode="avg",
+            )
+            return gv
+
+    mod = LegalizeOps()(ROIAlign)
+    assert "call_tir" in str(mod)
+    tvm.ir.assert_structural_equal(
+        mod["main"].ret_struct_info,
+        relax.TensorStructInfo((2, 2, 3, 3), "float32"),
+    )
+
+
+def test_roi_align_legalize_aligned():
+    """LegalizeOps lowers the aligned=True variant of roi_align to a call_tir.
+
+    The legalized function must keep the (1, 1, 1, 1) return struct info.
+    """
+    @tvm.script.ir_module
+    class ROIAlign:
+        @R.function
+        def main(
+            x: R.Tensor((1, 1, 4, 4), "float32"),
+            rois: R.Tensor((1, 5), "float32"),
+        ) -> R.Tensor((1, 1, 1, 1), "float32"):
+            gv: R.Tensor((1, 1, 1, 1), "float32") = R.vision.roi_align(
+                x,
+                rois,
+                pooled_size=(1, 1),
+                spatial_scale=1.0,
+                sample_ratio=2,
+                aligned=True,
+                layout="NCHW",
+                mode="avg",
+            )
+            return gv
+
+    mod = LegalizeOps()(ROIAlign)
+    assert "call_tir" in str(mod)
+    tvm.ir.assert_structural_equal(
+        mod["main"].ret_struct_info,
+        relax.TensorStructInfo((1, 1, 1, 1), "float32"),
+    )
+
+
+def test_roi_align_legalize_sample_ratio_zero():
+    """LegalizeOps accepts sample_ratio=0 and lowers roi_align to a call_tir.
+
+    NOTE(review): sample_ratio=0 presumably selects an adaptive per-bin sample count;
+    confirm against the TOPI compute.
+    """
+    @tvm.script.ir_module
+    class ROIAlign:
+        @R.function
+        def main(
+            x: R.Tensor((1, 2, 8, 8), "float32"),
+            rois: R.Tensor((1, 5), "float32"),
+        ) -> R.Tensor((1, 2, 2, 2), "float32"):
+            gv: R.Tensor((1, 2, 2, 2), "float32") = R.vision.roi_align(
+                x,
+                rois,
+                pooled_size=(2, 2),
+                spatial_scale=1.0,
+                sample_ratio=0,
+                layout="NCHW",
+                mode="avg",
+            )
+            return gv
+
+    mod = LegalizeOps()(ROIAlign)
+    assert "call_tir" in str(mod)
+    tvm.ir.assert_structural_equal(
+        mod["main"].ret_struct_info,
+        relax.TensorStructInfo((1, 2, 2, 2), "float32"),
+    )
+
+
 def test_all_class_non_max_suppression_infer_struct_info():
     bb = relax.BlockBuilder()
     batch_size, num_classes, num_boxes = 10, 8, 5
diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py 
b/tests/python/relax/test_tvmscript_parser_op_vision.py
index 10817c1dc4..c4e8ff0c9d 100644
--- a/tests/python/relax/test_tvmscript_parser_op_vision.py
+++ b/tests/python/relax/test_tvmscript_parser_op_vision.py
@@ -75,5 +75,37 @@ def test_all_class_non_max_suppression():
     _check(foo, bb.get()["foo"])
 
 
+def test_roi_align():
+    """TVMScript `R.vision.roi_align` parsing.
+
+    The parsed function is compared (via `_check`) with the same call built by BlockBuilder.
+    """
+    @R.function
+    def foo(
+        x: R.Tensor((1, 2, 8, 8), "float32"),
+        rois: R.Tensor((2, 5), "float32"),
+    ) -> R.Tensor((2, 2, 3, 3), "float32"):
+        gv: R.Tensor((2, 2, 3, 3), "float32") = R.vision.roi_align(
+            x,
+            rois,
+            pooled_size=(3, 3),
+            spatial_scale=1.0,
+            sample_ratio=2,
+            layout="NCHW",
+            mode="avg",
+        )
+        return gv
+
+    x = relax.Var("x", R.Tensor((1, 2, 8, 8), "float32"))
+    rois = relax.Var("rois", R.Tensor((2, 5), "float32"))
+
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x, rois]):
+        gv = bb.emit(
+            relax.op.vision.roi_align(
+                x, rois, (3, 3), 1.0, sample_ratio=2, layout="NCHW", mode="avg"
+            )
+        )
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()


Reply via email to