This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1e08eb2fa9 [Relax][ONNX][Torch] Add roi_align support and frontend
integration (#18936)
1e08eb2fa9 is described below
commit 1e08eb2fa9dcb59a4f34917a94c659d81ad36c0c
Author: YinHanke <[email protected]>
AuthorDate: Thu Mar 26 22:41:54 2026 +0800
[Relax][ONNX][Torch] Add roi_align support and frontend integration (#18936)
## Summary
Add Relax `roi_align` support and wire it through the ONNX and PyTorch
frontends.
## Changes
- add `relax.vision.roi_align`, including attrs, Python wrapper, struct
info inference, and legalization
- add TOPI `roi_align` compute and keep both legacy and aligned ROIAlign
semantics
- support ONNX `RoiAlign`, including `coordinate_transformation_mode`
handling for `output_half_pixel` and `half_pixel`
- support PyTorch `torchvision.ops.roi_align` in the exported-program
frontend, including the `aligned` flag
- add regression tests for Relax op inference, legalization, TVMScript
parsing, ONNX frontend import, and PyTorch frontend import
- add aligned ROIAlign test coverage to make sure sub-pixel RoIs no
longer use the legacy `min=1.0` clamp
## Validation
- `pytest tests/python/relax/test_op_vision.py -k roi_align`
- `pytest tests/python/relax/test_tvmscript_parser_op_vision.py -k
roi_align`
- `pytest tests/python/relax/test_frontend_onnx.py -k roi_align`
- `pytest tests/python/relax/test_frontend_from_exported_program.py -k
roi_align`
This PR completes the Relax/ONNX/Torch roi_align work tracked in #18928.
---
include/tvm/relax/attrs/vision.h | 25 +++
python/tvm/relax/frontend/onnx/onnx_frontend.py | 67 ++++++-
.../frontend/torch/exported_program_translator.py | 37 ++++
python/tvm/relax/op/__init__.py | 2 +-
python/tvm/relax/op/op_attrs.py | 5 +
python/tvm/relax/op/vision/__init__.py | 1 +
python/tvm/relax/op/vision/roi_align.py | 78 ++++++++
python/tvm/relax/transform/legalize_ops/vision.py | 15 ++
python/tvm/topi/testing/roi_align_python.py | 15 +-
python/tvm/topi/vision/__init__.py | 1 +
python/tvm/topi/vision/roi_align.py | 204 +++++++++++++++++++++
src/relax/op/vision/roi_align.cc | 141 ++++++++++++++
src/relax/op/vision/roi_align.h | 42 +++++
.../relax/test_frontend_from_exported_program.py | 60 ++++++
tests/python/relax/test_frontend_onnx.py | 60 +++++-
tests/python/relax/test_op_vision.py | 168 +++++++++++++++++
.../relax/test_tvmscript_parser_op_vision.py | 32 ++++
17 files changed, 944 insertions(+), 9 deletions(-)
diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h
index 2fd98533b5..59a1dd7314 100644
--- a/include/tvm/relax/attrs/vision.h
+++ b/include/tvm/relax/attrs/vision.h
@@ -48,6 +48,31 @@ struct AllClassNonMaximumSuppressionAttrs
AllClassNonMaximumSuppressionAttrs,
BaseAttrsNode);
}; // struct AllClassNonMaximumSuppressionAttrs
+/*! \brief Attributes used in ROIAlign operator */
+struct ROIAlignAttrs : public AttrsNodeReflAdapter<ROIAlignAttrs> {
+ ffi::Array<int64_t> pooled_size;
+ double spatial_scale;
+ int sample_ratio;
+ bool aligned;
+ ffi::String layout;
+ ffi::String mode;
+
+ static void RegisterReflection() {
+ namespace refl = tvm::ffi::reflection;
+ refl::ObjectDef<ROIAlignAttrs>()
+ .def_ro("pooled_size", &ROIAlignAttrs::pooled_size, "Output size of
roi align.")
+ .def_ro("spatial_scale", &ROIAlignAttrs::spatial_scale,
+ "Ratio of input feature map height (or width) to raw image
height (or width).")
+ .def_ro("sample_ratio", &ROIAlignAttrs::sample_ratio,
+ "Optional sampling ratio of ROI align, using adaptive size by
default.")
+ .def_ro("aligned", &ROIAlignAttrs::aligned,
+ "Whether to use the aligned ROIAlign semantics without the
legacy 1-pixel clamp.")
+ .def_ro("layout", &ROIAlignAttrs::layout, "Dimension ordering of the
input data.")
+ .def_ro("mode", &ROIAlignAttrs::mode, "Mode for ROI Align. Can be
'avg' or 'max'.");
+ }
+ TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.ROIAlignAttrs",
ROIAlignAttrs, BaseAttrsNode);
+}; // struct ROIAlignAttrs
+
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index d828025e0a..c8d4c469fc 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2420,6 +2420,71 @@ class Einsum(OnnxOpConverter):
return bb.emit_te(topi.einsum, equation, *inputs)
+class RoiAlign(OnnxOpConverter):
+ """Converts an onnx RoiAlign node into an equivalent Relax expression."""
+
+ @classmethod
+ def _impl(cls, bb, inputs, attr, params,
default_coordinate_transformation_mode):
+ if len(inputs) != 3:
+ raise ValueError("RoiAlign expects exactly 3 inputs")
+
+ data = inputs[0]
+ rois = inputs[1]
+ batch_indices = inputs[2]
+ rois_dtype = rois.struct_info.dtype
+
+ mode = attr.get("mode", b"avg")
+ if isinstance(mode, bytes):
+ mode = mode.decode("ascii")
+ if mode not in ("avg", "max"):
+ raise NotImplementedError("RoiAlign in Relax only supports avg and
max modes")
+
+ output_height = attr.get("output_height", 1)
+ output_width = attr.get("output_width", 1)
+ sampling_ratio = attr.get("sampling_ratio", 0)
+ spatial_scale = attr.get("spatial_scale", 1.0)
+ coordinate_transformation_mode = attr.get(
+ "coordinate_transformation_mode",
default_coordinate_transformation_mode
+ )
+ if isinstance(coordinate_transformation_mode, bytes):
+ coordinate_transformation_mode =
coordinate_transformation_mode.decode("ascii")
+
+ if coordinate_transformation_mode == "half_pixel":
+ offset = relax.const([-0.5, -0.5, -0.5, -0.5], rois_dtype)
+ rois = relax.op.add(rois, offset)
+ aligned = True
+ elif coordinate_transformation_mode != "output_half_pixel":
+ raise NotImplementedError(
+ "RoiAlign only supports coordinate_transformation_mode "
+ "'half_pixel' and 'output_half_pixel'"
+ )
+ else:
+ aligned = False
+
+ batch_indices = relax.op.expand_dims(batch_indices, axis=1)
+ batch_indices = relax.op.astype(batch_indices, rois_dtype)
+ rois = relax.op.concat([batch_indices, rois], axis=1)
+
+ return relax.op.vision.roi_align(
+ data,
+ rois,
+ pooled_size=(output_height, output_width),
+ spatial_scale=spatial_scale,
+ sample_ratio=sampling_ratio,
+ aligned=aligned,
+ layout="NCHW",
+ mode=mode,
+ )
+
+ @classmethod
+ def _impl_v10(cls, bb, inputs, attr, params):
+ return cls._impl(bb, inputs, attr, params, b"output_half_pixel")
+
+ @classmethod
+ def _impl_v16(cls, bb, inputs, attr, params):
+ return cls._impl(bb, inputs, attr, params, b"half_pixel")
+
+
class Range(OnnxOpConverter):
"""Converts an onnx Range node into an equivalent Relax expression."""
@@ -4082,7 +4147,7 @@ def _get_convert_map():
"NonZero": NonZero,
# "If": If,
# "MaxRoiPool": MaxRoiPool,
- # "RoiAlign": RoiAlign,
+ "RoiAlign": RoiAlign,
"NonMaxSuppression": NonMaxSuppression,
"AllClassNMS": AllClassNMS,
"GridSample": GridSample,
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 67e0e45da0..47633c69b5 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1123,6 +1123,42 @@ class ExportedProgramImporter(BaseFXGraphImporter):
)
)
+ def _torchvision_roi_align(self, node: fx.Node) -> relax.Var:
+ """Convert torchvision.ops.roi_align to relax.op.vision.roi_align."""
+ args = self.retrieve_args(node)
+ data = args[0]
+ rois = args[1]
+ spatial_scale = args[2] if len(args) > 2 else 1.0
+ pooled_height = args[3] if len(args) > 3 else 1
+ pooled_width = args[4] if len(args) > 4 else pooled_height
+ sampling_ratio = args[5] if len(args) > 5 else -1
+ aligned = args[6] if len(args) > 6 else False
+
+ if aligned:
+ batch_indices = self.block_builder.emit(
+ relax.op.strided_slice(rois, axes=[1], begin=[0], end=[1])
+ )
+ boxes = self.block_builder.emit(
+ relax.op.strided_slice(rois, axes=[1], begin=[1], end=[5])
+ )
+ boxes = self.block_builder.emit(
+ relax.op.subtract(boxes, relax.const(0.5,
rois.struct_info.dtype))
+ )
+ rois = self.block_builder.emit(relax.op.concat([batch_indices,
boxes], axis=1))
+
+ return self.block_builder.emit(
+ relax.op.vision.roi_align(
+ data,
+ rois,
+ pooled_size=(pooled_height, pooled_width),
+ spatial_scale=spatial_scale,
+ sample_ratio=sampling_ratio,
+ aligned=aligned,
+ layout="NCHW",
+ mode="avg",
+ )
+ )
+
def _scalar_tensor(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
scalar_value = args[0]
@@ -1732,6 +1768,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
"zeros.default": self._zeros,
"zeros_like.default": self._zeros_like,
"grid_sampler_2d.default": self._grid_sampler_2d,
+ "roi_align.default": self._torchvision_roi_align,
# datatype
"to.dtype": self._to,
"to.dtype_layout": self._to,
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 7c3f75298b..0bc3f65784 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -157,7 +157,7 @@ from .unary import (
tanh,
trunc,
)
-from .vision import all_class_non_max_suppression
+from .vision import all_class_non_max_suppression, roi_align
def _register_op_make():
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index f07623bd38..a3b6544dcc 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -246,6 +246,11 @@ class AllClassNonMaximumSuppressionAttrs(Attrs):
"""Attributes for vision.all_class_non_max_suppression"""
+@tvm_ffi.register_object("relax.attrs.ROIAlignAttrs")
+class ROIAlignAttrs(Attrs):
+ """Attributes for vision.roi_align"""
+
+
@tvm_ffi.register_object("relax.attrs.Conv1DAttrs")
class Conv1DAttrs(Attrs):
"""Attributes for nn.conv1d"""
diff --git a/python/tvm/relax/op/vision/__init__.py
b/python/tvm/relax/op/vision/__init__.py
index ea20d2b400..76d9ea35a1 100644
--- a/python/tvm/relax/op/vision/__init__.py
+++ b/python/tvm/relax/op/vision/__init__.py
@@ -18,3 +18,4 @@
"""VISION operators."""
from .nms import *
+from .roi_align import *
diff --git a/python/tvm/relax/op/vision/roi_align.py
b/python/tvm/relax/op/vision/roi_align.py
new file mode 100644
index 0000000000..8db694c7f2
--- /dev/null
+++ b/python/tvm/relax/op/vision/roi_align.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""ROI Align operator"""
+
+from ..base import Expr
+from . import _ffi_api
+
+
+def roi_align(
+ data: Expr,
+ rois: Expr,
+ pooled_size: int | tuple[int, int] | list[int],
+ spatial_scale: float,
+ sample_ratio: int = -1,
+ aligned: bool = False,
+ layout: str = "NCHW",
+ mode: str = "avg",
+):
+ """ROI Align operator.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ 4-D input tensor.
+
+ rois : relax.Expr
+ 2-D input tensor with shape `(num_roi, 5)` in
+ `[batch_idx, x1, y1, x2, y2]` format.
+
+ pooled_size : Union[int, Tuple[int, int], List[int]]
+ Output pooled size.
+
+ spatial_scale : float
+ Ratio of input feature map height (or width) to raw image height (or
width).
+
+ sample_ratio : int, optional
+ Sampling ratio for ROI align. Non-positive values use adaptive
sampling.
+
+ aligned : bool, optional
+ Whether to use aligned ROIAlign semantics without the legacy 1-pixel
clamp.
+
+ layout : str, optional
+ Layout of the input data. Supported values are `NCHW` and `NHWC`.
+
+ mode : str, optional
+ Mode for ROI align. Supported values are `avg` and `max`.
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ if isinstance(pooled_size, int):
+ pooled_size = (pooled_size, pooled_size)
+ return _ffi_api.roi_align(
+ data,
+ rois,
+ pooled_size,
+ spatial_scale,
+ sample_ratio,
+ aligned,
+ layout,
+ mode,
+ )
diff --git a/python/tvm/relax/transform/legalize_ops/vision.py
b/python/tvm/relax/transform/legalize_ops/vision.py
index f95dfa35b6..7a1e305f39 100644
--- a/python/tvm/relax/transform/legalize_ops/vision.py
+++ b/python/tvm/relax/transform/legalize_ops/vision.py
@@ -103,3 +103,18 @@ def _all_class_non_max_suppression(block_builder:
BlockBuilder, call: Call) -> E
# Return trimmed indices along with num_total_detections for compatibility
return relax.Tuple([trimmed_indices, num_total_detections])
+
+
+@register_legalize("relax.vision.roi_align")
+def _roi_align(bb: BlockBuilder, call: Call) -> Expr:
+ return bb.call_te(
+ topi.vision.roi_align,
+ call.args[0],
+ call.args[1],
+ pooled_size=call.attrs.pooled_size,
+ spatial_scale=call.attrs.spatial_scale,
+ mode=call.attrs.mode,
+ sample_ratio=call.attrs.sample_ratio,
+ aligned=call.attrs.aligned,
+ layout=call.attrs.layout,
+ )
diff --git a/python/tvm/topi/testing/roi_align_python.py
b/python/tvm/topi/testing/roi_align_python.py
index 9fc72074e4..19725a0c5b 100644
--- a/python/tvm/topi/testing/roi_align_python.py
+++ b/python/tvm/topi/testing/roi_align_python.py
@@ -59,6 +59,7 @@ def roi_align_common(
pooled_size_w,
spatial_scale,
sample_ratio,
+ aligned,
avg_mode,
max_mode,
height,
@@ -72,8 +73,8 @@ def roi_align_common(
roi = rois_np[i]
batch_index = int(roi[0])
roi_start_w, roi_start_h, roi_end_w, roi_end_h = roi[1:] *
spatial_scale
- roi_h = max(roi_end_h - roi_start_h, 1.0)
- roi_w = max(roi_end_w - roi_start_w, 1.0)
+ roi_h = roi_end_h - roi_start_h if aligned else max(roi_end_h -
roi_start_h, 1.0)
+ roi_w = roi_end_w - roi_start_w if aligned else max(roi_end_w -
roi_start_w, 1.0)
bin_h = roi_h / pooled_size_h
bin_w = roi_w / pooled_size_w
@@ -115,7 +116,9 @@ def roi_align_common(
return b_np
-def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale,
sample_ratio, mode=b"avg"):
+def roi_align_nchw_python(
+ a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg",
aligned=False
+):
"""Roi align NCHW in python"""
avg_mode = mode in (b"avg", "avg", 0)
max_mode = mode in (b"max", "max", 1)
@@ -137,6 +140,7 @@ def roi_align_nchw_python(a_np, rois_np, pooled_size,
spatial_scale, sample_rati
pooled_size_w,
spatial_scale,
sample_ratio,
+ aligned,
avg_mode,
max_mode,
height,
@@ -145,7 +149,9 @@ def roi_align_nchw_python(a_np, rois_np, pooled_size,
spatial_scale, sample_rati
)
-def roi_align_nhwc_python(a_np, rois_np, pooled_size, spatial_scale,
sample_ratio, mode=b"avg"):
+def roi_align_nhwc_python(
+ a_np, rois_np, pooled_size, spatial_scale, sample_ratio, mode=b"avg",
aligned=False
+):
"""Roi align NHWC in python"""
avg_mode = mode in (b"avg", "avg", 0)
max_mode = mode in (b"max", "max", 1)
@@ -169,6 +175,7 @@ def roi_align_nhwc_python(a_np, rois_np, pooled_size,
spatial_scale, sample_rati
pooled_size_w,
spatial_scale,
sample_ratio,
+ aligned,
avg_mode,
max_mode,
height,
diff --git a/python/tvm/topi/vision/__init__.py
b/python/tvm/topi/vision/__init__.py
index c637b9cab2..75725a8a4b 100644
--- a/python/tvm/topi/vision/__init__.py
+++ b/python/tvm/topi/vision/__init__.py
@@ -18,3 +18,4 @@
"""Vision operators."""
from .nms import *
+from .roi_align import *
diff --git a/python/tvm/topi/vision/roi_align.py
b/python/tvm/topi/vision/roi_align.py
new file mode 100644
index 0000000000..2c2d0faec1
--- /dev/null
+++ b/python/tvm/topi/vision/roi_align.py
@@ -0,0 +1,204 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""ROI Align operator"""
+
+import tvm
+from tvm import te
+
+from ..cpp.utils import bilinear_sample_nchw, bilinear_sample_nhwc
+
+
+def _sample_common(
+ i,
+ c,
+ ph,
+ pw,
+ rois,
+ pooled_size_h,
+ pooled_size_w,
+ spatial_scale,
+ sample_ratio,
+ aligned,
+ dtype,
+ avg_mode,
+ bilinear_func,
+):
+ roi = rois[i]
+ batch_index = roi[0].astype("int32")
+ roi_start_w = roi[1] * spatial_scale
+ roi_start_h = roi[2] * spatial_scale
+ roi_end_w = roi[3] * spatial_scale
+ roi_end_h = roi[4] * spatial_scale
+
+ if aligned:
+ roi_h = roi_end_h - roi_start_h
+ roi_w = roi_end_w - roi_start_w
+ else:
+ roi_h = te.max(roi_end_h - roi_start_h, tvm.tir.const(1.0, dtype))
+ roi_w = te.max(roi_end_w - roi_start_w, tvm.tir.const(1.0, dtype))
+
+ pooled_size_h_const = tvm.tir.const(pooled_size_h, dtype)
+ pooled_size_w_const = tvm.tir.const(pooled_size_w, dtype)
+ bin_h = roi_h / pooled_size_h_const
+ bin_w = roi_w / pooled_size_w_const
+
+ if sample_ratio > 0:
+ roi_bin_grid_h = tvm.tir.const(sample_ratio, "int32")
+ roi_bin_grid_w = tvm.tir.const(sample_ratio, "int32")
+ else:
+ roi_bin_grid_h = te.ceil(roi_h / pooled_size_h_const).astype("int32")
+ roi_bin_grid_w = te.ceil(roi_w / pooled_size_w_const).astype("int32")
+
+ count = roi_bin_grid_h * roi_bin_grid_w
+ rh = te.reduce_axis((0, roi_bin_grid_h), name="rh")
+ rw = te.reduce_axis((0, roi_bin_grid_w), name="rw")
+ roi_start_h = roi_start_h + tvm.tir.Cast(dtype, ph) * bin_h
+ roi_start_w = roi_start_w + tvm.tir.Cast(dtype, pw) * bin_w
+
+ def sample_value(rh_idx, rw_idx):
+ return bilinear_func(
+ batch_index,
+ c,
+ roi_start_h
+ + (tvm.tir.Cast(dtype, rh_idx) + tvm.tir.const(0.5, dtype))
+ * bin_h
+ / tvm.tir.Cast(dtype, roi_bin_grid_h),
+ roi_start_w
+ + (tvm.tir.Cast(dtype, rw_idx) + tvm.tir.const(0.5, dtype))
+ * bin_w
+ / tvm.tir.Cast(dtype, roi_bin_grid_w),
+ )
+
+ if avg_mode:
+ return te.sum(
+ sample_value(rh, rw) / tvm.tir.Cast(dtype, count),
+ axis=[rh, rw],
+ )
+ return te.max(sample_value(rh, rw), axis=[rh, rw])
+
+
+def roi_align_nchw(data, rois, pooled_size, spatial_scale, mode,
sample_ratio=-1, aligned=False):
+ """ROI align operator in NCHW layout."""
+ avg_mode = mode in (b"avg", "avg", 0)
+ max_mode = mode in (b"max", "max", 1)
+ assert avg_mode or max_mode, "Mode must be avg or max. Please pass in a
valid mode."
+
+ _, channel, height, width = data.shape
+ num_roi, _ = rois.shape
+ dtype = rois.dtype
+
+ if isinstance(pooled_size, int):
+ pooled_size_h = pooled_size_w = pooled_size
+ else:
+ pooled_size_h, pooled_size_w = pooled_size
+
+ height_f = tvm.tir.Cast(dtype, height)
+ width_f = tvm.tir.Cast(dtype, width)
+ zero = tvm.tir.const(0.0, data.dtype)
+
+ def _bilinear(n, c, y, x):
+ outside = tvm.tir.any(y < -1.0, x < -1.0, y > height_f, x > width_f)
+ y = te.min(te.max(y, 0.0), tvm.tir.Cast(dtype, height - 1))
+ x = te.min(te.max(x, 0.0), tvm.tir.Cast(dtype, width - 1))
+ val = bilinear_sample_nchw(data, (n, c, y, x), height - 1, width - 1)
+ return tvm.tir.if_then_else(outside, zero, val)
+
+ return te.compute(
+ (num_roi, channel, pooled_size_h, pooled_size_w),
+ lambda i, c, ph, pw: _sample_common(
+ i,
+ c,
+ ph,
+ pw,
+ rois,
+ pooled_size_h,
+ pooled_size_w,
+ spatial_scale,
+ sample_ratio,
+ aligned,
+ dtype,
+ avg_mode,
+ _bilinear,
+ ),
+ tag="pool,roi_align_nchw",
+ )
+
+
+def roi_align_nhwc(data, rois, pooled_size, spatial_scale, mode,
sample_ratio=-1, aligned=False):
+ """ROI align operator in NHWC layout."""
+ avg_mode = mode in (b"avg", "avg", 0)
+ max_mode = mode in (b"max", "max", 1)
+ assert avg_mode or max_mode, "Mode must be avg or max. Please pass in a
valid mode."
+
+ _, height, width, channel = data.shape
+ num_roi, _ = rois.shape
+ dtype = rois.dtype
+
+ if isinstance(pooled_size, int):
+ pooled_size_h = pooled_size_w = pooled_size
+ else:
+ pooled_size_h, pooled_size_w = pooled_size
+
+ height_f = tvm.tir.Cast(dtype, height)
+ width_f = tvm.tir.Cast(dtype, width)
+ zero = tvm.tir.const(0.0, data.dtype)
+
+ def _bilinear(n, c, y, x):
+ outside = tvm.tir.any(y < -1.0, x < -1.0, y > height_f, x > width_f)
+ y = te.min(te.max(y, 0.0), tvm.tir.Cast(dtype, height - 1))
+ x = te.min(te.max(x, 0.0), tvm.tir.Cast(dtype, width - 1))
+ val = bilinear_sample_nhwc(data, (n, y, x, c), height - 1, width - 1)
+ return tvm.tir.if_then_else(outside, zero, val)
+
+ return te.compute(
+ (num_roi, pooled_size_h, pooled_size_w, channel),
+ lambda i, ph, pw, c: _sample_common(
+ i,
+ c,
+ ph,
+ pw,
+ rois,
+ pooled_size_h,
+ pooled_size_w,
+ spatial_scale,
+ sample_ratio,
+ aligned,
+ dtype,
+ avg_mode,
+ _bilinear,
+ ),
+ tag="pool,roi_align_nhwc",
+ )
+
+
+def roi_align(
+ data,
+ rois,
+ pooled_size,
+ spatial_scale,
+ mode="avg",
+ sample_ratio=-1,
+ aligned=False,
+ layout="NCHW",
+):
+ """ROI align operator."""
+ if layout == "NCHW":
+ return roi_align_nchw(data, rois, pooled_size, spatial_scale, mode,
sample_ratio, aligned)
+ if layout == "NHWC":
+ return roi_align_nhwc(data, rois, pooled_size, spatial_scale, mode,
sample_ratio, aligned)
+ raise ValueError(f"Unsupported layout for roi_align: {layout}")
diff --git a/src/relax/op/vision/roi_align.cc b/src/relax/op/vision/roi_align.cc
new file mode 100644
index 0000000000..ae5185d6d4
--- /dev/null
+++ b/src/relax/op/vision/roi_align.cc
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file roi_align.cc
+ * \brief ROI Align operators.
+ */
+
+#include "roi_align.h"
+
+#include <tvm/ffi/reflection/registry.h>
+
+#include <utility>
+
+namespace tvm {
+namespace relax {
+
+TVM_FFI_STATIC_INIT_BLOCK() { ROIAlignAttrs::RegisterReflection(); }
+
+Expr roi_align(Expr data, Expr rois, ffi::Array<int64_t> pooled_size, double
spatial_scale,
+ int sample_ratio, bool aligned, ffi::String layout, ffi::String
mode) {
+ if (pooled_size.size() == 1) {
+ pooled_size.push_back(pooled_size[0]);
+ }
+ TVM_FFI_ICHECK_EQ(pooled_size.size(), 2)
+ << "The input pooled_size length is expected to be 2. However, the given
pooled_size is "
+ << pooled_size;
+
+ auto attrs = ffi::make_object<ROIAlignAttrs>();
+ attrs->pooled_size = std::move(pooled_size);
+ attrs->spatial_scale = spatial_scale;
+ attrs->sample_ratio = sample_ratio;
+ attrs->aligned = aligned;
+ attrs->layout = layout;
+ attrs->mode = mode;
+
+ static const Op& op = Op::Get("relax.vision.roi_align");
+ return Call(op, {std::move(data), std::move(rois)}, Attrs(attrs), {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def("relax.op.vision.roi_align", roi_align);
+}
+
+StructInfo InferStructInfoROIAlign(const Call& call, const BlockBuilder& ctx) {
+ if (call->args.size() != 2) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "ROIAlign expects two arguments, while the given
number of arguments is "
+ << call->args.size());
+ }
+
+ const auto* data_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ const auto* rois_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[1]);
+ if (data_sinfo == nullptr) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "ROIAlign expects the input data to be a Tensor, while
the given data is "
+ << call->args[0]->GetTypeKey());
+ }
+ if (rois_sinfo == nullptr) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "ROIAlign expects the rois to be a Tensor, while the
given rois is "
+ << call->args[1]->GetTypeKey());
+ }
+ if (!data_sinfo->IsUnknownNdim() && data_sinfo->ndim != 4) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "ROIAlign expects the input data to be 4-D, while the
given data has ndim "
+ << data_sinfo->ndim);
+ }
+ if (!rois_sinfo->IsUnknownNdim() && rois_sinfo->ndim != 2) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "ROIAlign expects the rois tensor to be 2-D, while the
given rois has ndim "
+ << rois_sinfo->ndim);
+ }
+
+ const auto* attrs = call->attrs.as<ROIAlignAttrs>();
+ TVM_FFI_ICHECK(attrs != nullptr) << "Invalid ROIAlign attrs";
+ if (attrs->layout != "NCHW" && attrs->layout != "NHWC") {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "ROIAlign only supports NCHW and NHWC layout, but got
" << attrs->layout);
+ }
+ if (attrs->mode != "avg" && attrs->mode != "max") {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "ROIAlign only supports avg and max mode, but got " <<
attrs->mode);
+ }
+
+ const auto* rois_shape = rois_sinfo->shape.as<ShapeExprNode>();
+ if (rois_shape != nullptr) {
+ const auto* last_dim = rois_shape->values[1].as<IntImmNode>();
+ if (last_dim != nullptr && last_dim->value != 5) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "ROIAlign expects rois to have shape (num_roi, 5),
but got last "
+ "dimension "
+ << last_dim->value);
+ }
+ }
+
+ if (data_sinfo->shape.as<ShapeExprNode>() == nullptr || rois_shape ==
nullptr) {
+ return TensorStructInfo(data_sinfo->dtype, 4, data_sinfo->vdevice);
+ }
+
+ ffi::Array<PrimExpr> data_shape =
data_sinfo->shape.as<ShapeExprNode>()->values;
+ ffi::Array<PrimExpr> out_shape;
+ if (attrs->layout == "NCHW") {
+ out_shape = {rois_shape->values[0], data_shape[1],
Integer(attrs->pooled_size[0]),
+ Integer(attrs->pooled_size[1])};
+ } else {
+ out_shape = {rois_shape->values[0], Integer(attrs->pooled_size[0]),
+ Integer(attrs->pooled_size[1]), data_shape[3]};
+ }
+ return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype,
data_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.vision.roi_align")
+ .set_attrs_type<ROIAlignAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("rois", "Tensor",
+ "The input rois with shape (num_roi, 5) in [batch_idx, x1,
y1, x2, y2] format.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoROIAlign)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/op/vision/roi_align.h b/src/relax/op/vision/roi_align.h
new file mode 100644
index 0000000000..e2b861ac64
--- /dev/null
+++ b/src/relax/op/vision/roi_align.h
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file roi_align.h
+ * \brief The functions to make Relax ROI Align operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_VISION_ROI_ALIGN_H_
+#define TVM_RELAX_OP_VISION_ROI_ALIGN_H_
+
+#include <tvm/relax/attrs/vision.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*! \brief ROI Align operator. */
+Expr roi_align(Expr data, Expr rois, ffi::Array<int64_t> pooled_size, double
spatial_scale,
+ int sample_ratio, bool aligned, ffi::String layout, ffi::String
mode);
+
+} // namespace relax
+} // namespace tvm
+
+#endif // TVM_RELAX_OP_VISION_ROI_ALIGN_H_
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index e1cadb9d02..7a3548b4cf 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -20,6 +20,7 @@ import operator
import numpy as np
import pytest
import torch
+import torchvision
from torch import nn
from torch.export import export
from torch.nn import Module
@@ -8746,6 +8747,65 @@ def test_grid_sample():
verify_model(GridSample(), example_args, {}, expected)
+def test_torchvision_roi_align():
+ class ROIAlign(Module):
+ def forward(self, input, rois):
+ return torchvision.ops.roi_align(
+ input,
+ rois,
+ output_size=(3, 3),
+ spatial_scale=1.0,
+ sampling_ratio=2,
+ aligned=False,
+ )
+
+ @tvm.script.ir_module
+ class expected:
+ @R.function
+ def main(
+ input_1: R.Tensor((1, 3, 8, 8), dtype="float32"),
+ rois: R.Tensor((2, 5), dtype="float32"),
+ ) -> R.Tuple(R.Tensor((2, 3, 3, 3), dtype="float32")):
+ with R.dataflow():
+ lv: R.Tensor((2, 3, 3, 3), dtype="float32") =
R.vision.roi_align(
+ input_1,
+ rois,
+ pooled_size=(3, 3),
+ spatial_scale=1.0,
+ sample_ratio=2,
+ layout="NCHW",
+ mode="avg",
+ )
+ gv: R.Tuple(R.Tensor((2, 3, 3, 3), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (
+ torch.randn(1, 3, 8, 8, dtype=torch.float32),
+ torch.tensor([[0.0, 1.0, 1.0, 6.0, 6.0], [0.0, 0.5, 0.5, 7.0, 7.0]],
dtype=torch.float32),
+ )
+ verify_model(ROIAlign(), example_args, {}, expected)
+
+
+def test_torchvision_roi_align_aligned():
+ class ROIAlign(Module):
+ def forward(self, input, rois):
+ return torchvision.ops.roi_align(
+ input,
+ rois,
+ output_size=(1, 1),
+ spatial_scale=1.0,
+ sampling_ratio=2,
+ aligned=True,
+ )
+
+ example_args = (
+ torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4),
+ torch.tensor([[0.0, 1.0, 1.0, 1.2, 1.2]], dtype=torch.float32),
+ )
+ verify_model_numerically(ROIAlign(), example_args, rtol=1e-5, atol=1e-5)
+
+
def test_upsample_nearest2d():
class UpsampleNearest2dScale(Module):
def forward(self, input):
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index f56adfbfdb..86fa533874 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4010,6 +4010,7 @@ def test_nms_score_threshold():
tvm_selected[:min_rows], ort_selected[:min_rows], rtol=1e-5,
atol=1e-5
)
+
@pytest.mark.parametrize("mode", ["bilinear", "nearest", "bicubic"])
@pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"])
@pytest.mark.parametrize("align_corners", [0, 1])
@@ -4052,6 +4053,7 @@ def test_grid_sample(mode, padding_mode, align_corners):
opset=16,
)
+
def test_grid_sample_linear_mode_translation():
"""Test that ONNX mode='linear' is correctly translated to 'bilinear'.
@@ -4078,7 +4080,9 @@ def test_grid_sample_linear_mode_translation():
helper.make_tensor_value_info("grid", TensorProto.FLOAT,
grid_shape),
],
outputs=[
- helper.make_tensor_value_info("Y", TensorProto.FLOAT, [x_shape[0],
x_shape[1], grid_shape[1], grid_shape[2]]),
+ helper.make_tensor_value_info(
+ "Y", TensorProto.FLOAT, [x_shape[0], x_shape[1],
grid_shape[1], grid_shape[2]]
+ ),
],
)
@@ -4113,7 +4117,9 @@ def test_grid_sample_cubic_mode_translation():
helper.make_tensor_value_info("grid", TensorProto.FLOAT,
grid_shape),
],
outputs=[
- helper.make_tensor_value_info("Y", TensorProto.FLOAT, [x_shape[0],
x_shape[1], grid_shape[1], grid_shape[2]]),
+ helper.make_tensor_value_info(
+ "Y", TensorProto.FLOAT, [x_shape[0], x_shape[1],
grid_shape[1], grid_shape[2]]
+ ),
],
)
@@ -4122,6 +4128,54 @@ def test_grid_sample_cubic_mode_translation():
# Verify 'cubic' was translated to 'bicubic' in the Relax IR
assert 'method="bicubic"' in str(tvm_model)
+
[email protected](
+ ("coordinate_transformation_mode", "rois"),
+ [
+ (
+ "output_half_pixel",
+ np.array([[1.0, 1.0, 6.0, 6.0], [2.0, 0.5, 7.0, 7.0]],
dtype="float32"),
+ ),
+ ("half_pixel", np.array([[1.0, 1.0, 1.2, 1.2], [2.0, 0.5, 1.1, 1.1]],
dtype="float32")),
+ ],
+)
+def test_roi_align(coordinate_transformation_mode, rois):
+ x_shape = [1, 4, 8, 8]
+ rois_shape = [2, 4]
+ batch_indices_shape = [2]
+ out_shape = [2, 4, 3, 3]
+
+ node = helper.make_node(
+ "RoiAlign",
+ inputs=["X", "rois", "batch_indices"],
+ outputs=["Y"],
+ output_height=3,
+ output_width=3,
+ sampling_ratio=2,
+ spatial_scale=1.0,
+ mode="avg",
+ coordinate_transformation_mode=coordinate_transformation_mode,
+ )
+
+ graph = helper.make_graph(
+ [node],
+ "roi_align_test",
+ inputs=[
+ helper.make_tensor_value_info("X", TensorProto.FLOAT, x_shape),
+ helper.make_tensor_value_info("rois", TensorProto.FLOAT,
rois_shape),
+ helper.make_tensor_value_info("batch_indices", TensorProto.INT64,
batch_indices_shape),
+ ],
+ outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT,
out_shape)],
+ )
+
+ model = helper.make_model(graph, producer_name="roi_align_test")
+ inputs = {
+ "X": rg.standard_normal(size=x_shape).astype("float32"),
+ "rois": rois,
+ "batch_indices": np.array([0, 0], dtype="int64"),
+ }
+ check_correctness(model, inputs=inputs, opset=16, rtol=1e-5, atol=1e-5)
+
+
if __name__ == "__main__":
tvm.testing.main()
-
diff --git a/tests/python/relax/test_op_vision.py
b/tests/python/relax/test_op_vision.py
index 753ee14140..b902518b49 100644
--- a/tests/python/relax/test_op_vision.py
+++ b/tests/python/relax/test_op_vision.py
@@ -21,6 +21,7 @@ import pytest
import tvm
import tvm.testing
+from tvm import TVMError, relax, tir
+from tvm.ir import Op
from tvm.relax.transform import LegalizeOps
from tvm.script import relax as R
@@ -30,6 +31,173 @@ def _check_inference(bb: relax.BlockBuilder, call:
relax.Call, expected_sinfo: r
tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
+def test_roi_align_op_correctness():
+ x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+ rois = relax.Var("rois", R.Tensor((4, 5), "float32"))
+ assert relax.op.vision.roi_align(x, rois, (7, 7), 1.0).op ==
Op.get("relax.vision.roi_align")
+
+
+def test_roi_align_infer_struct_info():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32"))
+ rois = relax.Var("rois", R.Tensor((5, 5), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.vision.roi_align(x0, rois, (7, 7), 0.25),
+ relax.TensorStructInfo((5, 3, 7, 7), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.vision.roi_align(x1, rois, (5, 7), 1.0, layout="NHWC"),
+ relax.TensorStructInfo((5, 5, 7, 3), "float32"),
+ )
+
+
+def test_roi_align_infer_struct_info_aligned():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+ rois = relax.Var("rois", R.Tensor((5, 5), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.vision.roi_align(x, rois, (7, 7), 1.0, aligned=True),
+ relax.TensorStructInfo((5, 3, 7, 7), "float32"),
+ )
+
+
+def test_roi_align_infer_struct_info_shape_var():
+ bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    c = tir.Var("c", "int64")
+    h = tir.Var("h", "int64")
+    w = tir.Var("w", "int64")
+    num_roi = tir.Var("num_roi", "int64")
+
+ x = relax.Var("x", R.Tensor((n, c, h, w), "float32"))
+ rois = relax.Var("rois", R.Tensor((num_roi, 5), "float32"))
+
+ _check_inference(
+ bb,
+ relax.op.vision.roi_align(x, rois, (7, 7), 0.5),
+ relax.TensorStructInfo((num_roi, c, 7, 7), "float32"),
+ )
+
+
+def test_roi_align_wrong_input_ndim():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 32), "float32"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+ rois0 = relax.Var("rois", R.Tensor((4,), "float32"))
+ rois1 = relax.Var("rois", R.Tensor((4, 5), "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.roi_align(x0, rois1, (7, 7), 1.0))
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.roi_align(x1, rois0, (7, 7), 1.0))
+
+
+def test_roi_align_wrong_rois_last_dim():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+ rois = relax.Var("rois", R.Tensor((4, 4), "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.roi_align(x, rois, (7, 7), 1.0))
+
+
+def test_roi_align_wrong_layout():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+ rois = relax.Var("rois", R.Tensor((4, 5), "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.vision.roi_align(x, rois, (7, 7), 1.0,
layout="HWCN"))
+
+
+def test_roi_align_legalize():
+ @tvm.script.ir_module
+ class ROIAlign:
+ @R.function
+ def main(
+ x: R.Tensor((1, 2, 8, 8), "float32"),
+ rois: R.Tensor((2, 5), "float32"),
+ ) -> R.Tensor((2, 2, 3, 3), "float32"):
+ gv: R.Tensor((2, 2, 3, 3), "float32") = R.vision.roi_align(
+ x,
+ rois,
+ pooled_size=(3, 3),
+ spatial_scale=1.0,
+ sample_ratio=2,
+ layout="NCHW",
+ mode="avg",
+ )
+ return gv
+
+ mod = LegalizeOps()(ROIAlign)
+ assert "call_tir" in str(mod)
+ tvm.ir.assert_structural_equal(
+ mod["main"].ret_struct_info,
+ relax.TensorStructInfo((2, 2, 3, 3), "float32"),
+ )
+
+
+def test_roi_align_legalize_aligned():
+ @tvm.script.ir_module
+ class ROIAlign:
+ @R.function
+ def main(
+ x: R.Tensor((1, 1, 4, 4), "float32"),
+ rois: R.Tensor((1, 5), "float32"),
+ ) -> R.Tensor((1, 1, 1, 1), "float32"):
+ gv: R.Tensor((1, 1, 1, 1), "float32") = R.vision.roi_align(
+ x,
+ rois,
+ pooled_size=(1, 1),
+ spatial_scale=1.0,
+ sample_ratio=2,
+ aligned=True,
+ layout="NCHW",
+ mode="avg",
+ )
+ return gv
+
+ mod = LegalizeOps()(ROIAlign)
+ assert "call_tir" in str(mod)
+ tvm.ir.assert_structural_equal(
+ mod["main"].ret_struct_info,
+ relax.TensorStructInfo((1, 1, 1, 1), "float32"),
+ )
+
+
+def test_roi_align_legalize_sample_ratio_zero():
+ @tvm.script.ir_module
+ class ROIAlign:
+ @R.function
+ def main(
+ x: R.Tensor((1, 2, 8, 8), "float32"),
+ rois: R.Tensor((1, 5), "float32"),
+ ) -> R.Tensor((1, 2, 2, 2), "float32"):
+ gv: R.Tensor((1, 2, 2, 2), "float32") = R.vision.roi_align(
+ x,
+ rois,
+ pooled_size=(2, 2),
+ spatial_scale=1.0,
+ sample_ratio=0,
+ layout="NCHW",
+ mode="avg",
+ )
+ return gv
+
+ mod = LegalizeOps()(ROIAlign)
+ assert "call_tir" in str(mod)
+ tvm.ir.assert_structural_equal(
+ mod["main"].ret_struct_info,
+ relax.TensorStructInfo((1, 2, 2, 2), "float32"),
+ )
+
+
def test_all_class_non_max_suppression_infer_struct_info():
bb = relax.BlockBuilder()
batch_size, num_classes, num_boxes = 10, 8, 5
diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py
b/tests/python/relax/test_tvmscript_parser_op_vision.py
index 10817c1dc4..c4e8ff0c9d 100644
--- a/tests/python/relax/test_tvmscript_parser_op_vision.py
+++ b/tests/python/relax/test_tvmscript_parser_op_vision.py
@@ -75,5 +75,37 @@ def test_all_class_non_max_suppression():
_check(foo, bb.get()["foo"])
+def test_roi_align():
+ @R.function
+ def foo(
+ x: R.Tensor((1, 2, 8, 8), "float32"),
+ rois: R.Tensor((2, 5), "float32"),
+ ) -> R.Tensor((2, 2, 3, 3), "float32"):
+ gv: R.Tensor((2, 2, 3, 3), "float32") = R.vision.roi_align(
+ x,
+ rois,
+ pooled_size=(3, 3),
+ spatial_scale=1.0,
+ sample_ratio=2,
+ layout="NCHW",
+ mode="avg",
+ )
+ return gv
+
+ x = relax.Var("x", R.Tensor((1, 2, 8, 8), "float32"))
+ rois = relax.Var("rois", R.Tensor((2, 5), "float32"))
+
+ bb = relax.BlockBuilder()
+ with bb.function("foo", [x, rois]):
+ gv = bb.emit(
+ relax.op.vision.roi_align(
+ x, rois, (3, 3), 1.0, sample_ratio=2, layout="NCHW", mode="avg"
+ )
+ )
+ bb.emit_func_output(gv)
+
+ _check(foo, bb.get()["foo"])
+
+
if __name__ == "__main__":
tvm.testing.main()