This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4d36a45b9f [Relax][ONNX] Add image.resize3d op and wire 5D Resize (#18931)
4d36a45b9f is described below
commit 4d36a45b9f5d3a52b1337c77bd6162fd71796c87
Author: Dayuxiaoshui <[email protected]>
AuthorDate: Thu Mar 26 12:37:25 2026 +0800
[Relax][ONNX] Add image.resize3d op and wire 5D Resize (#18931)
## Summary
- Add Relax `image.resize3d` end-to-end: attrs, C++ op
registration/inference, Python API, and legalization to
`topi.image.resize3d`.
- Update ONNX 5D `Resize` to emit `relax.image.resize3d` instead of
direct `emit_te(topi.image.resize3d)`.
- Reuse the existing `resize2d` implementation pattern, which sped up development while keeping
behavior consistent with `resize2d` and the risk low.
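- Add tests for op inference, TVMScript parser, legalization, ONNX
import, and `resize3d` negative/error cases.
- Tighten `resize2d` struct-info inference to require exactly two arguments (`data` and
`size`); previously a single argument was also accepted.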
## Test Plan
- `python3 -m pytest -q tests/python/relax/test_op_image.py -k 'resize3d'`
- `python3 -m pytest -q tests/python/relax/test_transform_legalize_ops_image.py -k 'resize3d'`
- `python3 -m pytest -q tests/python/relax/test_tvmscript_parser_op_image.py -k 'resize3d'`
- `python3 -m pytest -q tests/python/relax/test_frontend_onnx.py -k 'resize_nd_sizes or resize_5d_emits_relax_resize3d'`
---
include/tvm/relax/attrs/image.h | 49 ++++++++
python/tvm/relax/frontend/onnx/onnx_frontend.py | 21 ++--
python/tvm/relax/op/image/__init__.py | 2 +-
python/tvm/relax/op/image/image.py | 52 +++++++++
python/tvm/relax/transform/legalize_ops/image.py | 17 +++
src/relax/op/image/resize.cc | 123 ++++++++++++++++++++-
src/relax/op/image/resize.h | 6 +
tests/python/relax/test_frontend_onnx.py | 31 ++++++
tests/python/relax/test_op_image.py | 110 ++++++++++++++++++
.../relax/test_transform_legalize_ops_image.py | 37 +++++++
.../python/relax/test_tvmscript_parser_op_image.py | 17 +++
11 files changed, 448 insertions(+), 17 deletions(-)
diff --git a/include/tvm/relax/attrs/image.h b/include/tvm/relax/attrs/image.h
index b367ce5843..52aac58dcd 100644
--- a/include/tvm/relax/attrs/image.h
+++ b/include/tvm/relax/attrs/image.h
@@ -78,6 +78,55 @@ struct Resize2DAttrs : public
AttrsNodeReflAdapter<Resize2DAttrs> {
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Resize2DAttrs",
Resize2DAttrs, BaseAttrsNode);
}; // struct Resize2dAttrs
+/*! \brief Attributes used in image resize3d operator */
+struct Resize3DAttrs : public AttrsNodeReflAdapter<Resize3DAttrs> {
+  /*! \brief Region of interest, used by the 'tf_crop_and_resize' coordinate mode. */
+  ffi::Array<FloatImm> roi;
+  /*! \brief Layout of the input data, e.g. "NCDHW" or "NDHWC". */
+  ffi::String layout;
+  /*! \brief Interpolation method: nearest_neighbor, linear or cubic. */
+  ffi::String method;
+  /*! \brief How output coordinates are mapped back onto input coordinates. */
+  ffi::String coordinate_transformation_mode;
+  /*! \brief Rounding used to pick the nearest pixel in nearest_neighbor mode. */
+  ffi::String rounding_method;
+  /*! \brief Spline coefficient for tricubic interpolation. */
+  double cubic_alpha;
+  /*! \brief Flag to exclude the image exterior during tricubic interpolation. */
+  int cubic_exclude;
+  /*! \brief Value returned when the roi falls outside of the image. */
+  double extrapolation_value;
+  /*! \brief Output dtype; a void dtype means "same as the input dtype". */
+  DataType out_dtype;
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    // NOTE: adjacent string literals are concatenated by the compiler, so every
+    // fragment except the last must end with a space.
+    refl::ObjectDef<Resize3DAttrs>()
+        .def_ro("roi", &Resize3DAttrs::roi,
+                "Region of Interest for coordinate transformation mode 'tf_crop_and_resize'")
+        .def_ro("layout", &Resize3DAttrs::layout,
+                "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc. "
+                "'N', 'C', 'D', 'H', 'W' stand for batch, channel, depth, height, and width "
+                "dimensions respectively. Resize is applied on the 'D', 'H' and "
+                "'W' dimensions.")
+        .def_ro("method", &Resize3DAttrs::method,
+                "Specify the mode to use for scaling. "
+                "nearest_neighbor - Nearest Neighbor, "
+                "linear - Trilinear Interpolation, "
+                "cubic - Tricubic Interpolation")
+        .def_ro("coordinate_transformation_mode", &Resize3DAttrs::coordinate_transformation_mode,
+                "Describes how to transform the coordinate in the resized tensor "
+                "to the coordinate in the original tensor. "
+                "Refer to the ONNX Resize operator specification for details. "
+                "Available options include half_pixel, align_corners, asymmetric "
+                "and tf_crop_and_resize.")
+        .def_ro("rounding_method", &Resize3DAttrs::rounding_method,
+                "Indicates how to find the \"nearest\" pixel in nearest_neighbor method. "
+                "Available options are round, floor, and ceil.")
+        .def_ro("cubic_alpha", &Resize3DAttrs::cubic_alpha,
+                "Spline Coefficient for Tricubic Interpolation")
+        .def_ro("cubic_exclude", &Resize3DAttrs::cubic_exclude,
+                "Flag to exclude exterior of the image during tricubic interpolation")
+        .def_ro("extrapolation_value", &Resize3DAttrs::extrapolation_value,
+                "Value to return when roi is outside of the image")
+        .def_ro("out_dtype", &Resize3DAttrs::out_dtype,
+                "The dtype of the output tensor. If it is not specified, the output will have "
+                "the same dtype as the input.");
+  }
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Resize3DAttrs", Resize3DAttrs, BaseAttrsNode);
+};  // struct Resize3DAttrs
+
/*! \brief Attributes used in image grid_sample operator */
struct GridSampleAttrs : public AttrsNodeReflAdapter<GridSampleAttrs> {
ffi::String method;
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index f08505951d..d828025e0a 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2397,18 +2397,17 @@ class Resize(OnnxOpConverter):
extrapolation_value=extrapolation_value,
)
else: # ndims == 5
- return bb.emit_te(
- topi.image.resize3d,
+ return relax.op.image.resize3d(
x,
- roi,
- sizes,
- "NCDHW",
- topi_mode,
- coord_mode,
- rounding_method,
- cubic_coeff_a,
- exclude_outside,
- extrapolation_value,
+ size=relax.ShapeExpr(sizes),
+ roi=roi,
+ layout="NCDHW",
+ method=relax_mode,
+ coordinate_transformation_mode=coord_mode,
+ rounding_method=rounding_method,
+ cubic_alpha=cubic_coeff_a,
+ cubic_exclude=exclude_outside,
+ extrapolation_value=extrapolation_value,
)
diff --git a/python/tvm/relax/op/image/__init__.py
b/python/tvm/relax/op/image/__init__.py
index e5a5b74b58..6b02c32199 100644
--- a/python/tvm/relax/op/image/__init__.py
+++ b/python/tvm/relax/op/image/__init__.py
@@ -17,4 +17,4 @@
# under the License.
"""Image operators."""
-from .image import grid_sample, resize2d
+from .image import grid_sample, resize2d, resize3d
diff --git a/python/tvm/relax/op/image/image.py
b/python/tvm/relax/op/image/image.py
index 5e48bb6522..b267f40709 100644
--- a/python/tvm/relax/op/image/image.py
+++ b/python/tvm/relax/op/image/image.py
@@ -130,6 +130,58 @@ def resize2d(
)
+def resize3d(
+    data: Expr,
+    size: Expr | PrimExprLike | tuple[PrimExprLike],
+    roi: float | tuple[float] | None = None,
+    layout: str = "NCDHW",
+    method: str = "linear",
+    coordinate_transformation_mode: str = "half_pixel",
+    rounding_method: str = "",
+    cubic_alpha: float = -0.75,
+    cubic_exclude: int = 0,
+    extrapolation_value: float = 0.0,
+    out_dtype: str | DataType | None = None,
+) -> Expr:
+    """Image resize3d operator.
+
+    This operator takes data as input and does 3D scaling to the given output size.
+    In the default case, where data layout is `NCDHW`
+    with data of shape (n, c, d, h, w),
+    the output has shape (n, c, size[0], size[1], size[2]).
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    size : Union[Expr, PrimExprLike, Tuple[PrimExprLike]]
+        The output (depth, height, width) to resize to. A single int/PrimExpr, or a
+        one-element tuple/list, is used for all three spatial dimensions. Any other
+        tuple/list is converted to a ShapeExpr and is expected to hold three values.
+        A Relax expression is forwarded unchanged and must have 3-dim shape struct info.
+
+    roi : Optional[Union[float, Tuple[float]]]
+        The region of interest, used by the ``tf_crop_and_resize`` coordinate
+        transformation mode. ``None`` means six zeros, and a single float is repeated
+        six times. A tuple/list is converted element-wise to float.
+
+    layout : str
+        Layout of the input data, e.g. "NCDHW" or "NDHWC".
+
+    method : str
+        Interpolation method: "nearest_neighbor", "linear" or "cubic".
+
+    coordinate_transformation_mode : str
+        How a coordinate in the resized tensor maps to a coordinate in the input
+        tensor, following the ONNX Resize specification (e.g. "half_pixel",
+        "align_corners", "asymmetric").
+
+    rounding_method : str
+        How to find the "nearest" pixel in nearest_neighbor mode
+        ("round", "floor" or "ceil").
+
+    cubic_alpha : float
+        Spline coefficient for tricubic interpolation.
+
+    cubic_exclude : int
+        Flag to exclude the exterior of the image during tricubic interpolation.
+
+    extrapolation_value : float
+        Value returned when the roi falls outside of the image.
+
+    out_dtype : Optional[Union[str, DataType]]
+        The dtype of the output tensor. If not specified, the output has the same
+        dtype as the input.
+
+    Returns
+    -------
+    result : relax.Expr
+        The resized result.
+    """
+    if roi is None:
+        roi = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0)  # type: ignore
+    elif isinstance(roi, float):
+        roi = (roi, roi, roi, roi, roi, roi)  # type: ignore
+    elif isinstance(roi, tuple | list):
+        roi = tuple(val if isinstance(val, float) else float(val) for val in roi)
+    else:
+        raise NotImplementedError(f"Unsupported roi type {type(roi)}")
+
+    # Normalize scalar / short sizes into a 3-element ShapeExpr.
+    if isinstance(size, int | PrimExpr):
+        size = (size, size, size)
+    if isinstance(size, tuple | list):
+        if len(size) == 1:
+            size = ShapeExpr([size[0], size[0], size[0]])
+        else:
+            size = ShapeExpr(size)
+
+    return _ffi_api.resize3d(  # type: ignore
+        data,
+        size,
+        roi,
+        layout,
+        method,
+        coordinate_transformation_mode,
+        rounding_method,
+        cubic_alpha,
+        cubic_exclude,
+        extrapolation_value,
+        out_dtype,
+    )
+
+
def grid_sample(
data: Expr,
grid: Expr,
diff --git a/python/tvm/relax/transform/legalize_ops/image.py
b/python/tvm/relax/transform/legalize_ops/image.py
index 2ce22b424e..1e7aaebceb 100644
--- a/python/tvm/relax/transform/legalize_ops/image.py
+++ b/python/tvm/relax/transform/legalize_ops/image.py
@@ -52,3 +52,20 @@ def _image_grid_sample(bb: BlockBuilder, call: Call) -> Expr:
padding_mode=call.attrs.padding_mode,
align_corners=call.attrs.align_corners,
)
+
+
+@register_legalize("relax.image.resize3d")
+def _image_resize3d(bb: BlockBuilder, call: Call) -> Expr:
+    """Legalize ``relax.image.resize3d`` into ``topi.image.resize3d``.
+
+    The output dtype is taken from the call's inferred struct info. This keeps the
+    generated TE compute consistent with the dtype the Relax op advertises. Struct
+    info inference already falls back to the input dtype when the ``out_dtype``
+    attribute is unset, so no special "void" handling is needed here.
+    """
+    return bb.call_te(
+        topi.image.resize3d,
+        call.args[0],
+        roi=call.attrs.roi,
+        size=call.args[1],
+        layout=call.attrs.layout,
+        method=call.attrs.method,
+        coordinate_transformation_mode=call.attrs.coordinate_transformation_mode,
+        rounding_method=call.attrs.rounding_method,
+        bicubic_alpha=call.attrs.cubic_alpha,
+        bicubic_exclude=call.attrs.cubic_exclude,
+        extrapolation_value=call.attrs.extrapolation_value,
+        out_dtype=call.struct_info.dtype,
+    )
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index 15fdcf3eb9..fe9df47dd5 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -32,6 +32,7 @@ namespace tvm {
namespace relax {
TVM_FFI_STATIC_INIT_BLOCK() { Resize2DAttrs::RegisterReflection(); }
+// Register reflection metadata (fields and their docs) for Resize3DAttrs.
+TVM_FFI_STATIC_INIT_BLOCK() { Resize3DAttrs::RegisterReflection(); }
/* relax.resize2d */
@@ -60,11 +61,10 @@ TVM_FFI_STATIC_INIT_BLOCK() {
}
StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) {
- if (call->args.size() != 1 && call->args.size() != 2) {
- ctx->ReportFatal(
- Diagnostic::Error(call)
- << "Resize2D expects either one or two arguments, while the given
number of arguments is "
- << call->args.size());
+ if (call->args.size() != 2) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Resize2D expects 2 arguments, while the given number
of arguments is "
+ << call->args.size());
}
const auto* data_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
@@ -149,6 +149,119 @@ TVM_REGISTER_OP("relax.image.resize2d")
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.resize3d */
+
+/*!
+ * \brief Construct a call to the relax.image.resize3d operator.
+ *
+ * All scalar options are packed into a Resize3DAttrs node. An absent out_dtype is
+ * stored as DataType::Void().
+ */
+Expr resize3d(Expr data, Expr size, ffi::Array<FloatImm> roi, ffi::String layout,
+              ffi::String method, ffi::String coordinate_transformation_mode,
+              ffi::String rounding_method, double cubic_alpha, int cubic_exclude,
+              double extrapolation_value, ffi::Optional<DataType> out_dtype) {
+  auto resize_attrs = ffi::make_object<Resize3DAttrs>();
+  // Reference-typed options are moved in; numeric options are copied.
+  resize_attrs->roi = std::move(roi);
+  resize_attrs->layout = std::move(layout);
+  resize_attrs->method = std::move(method);
+  resize_attrs->coordinate_transformation_mode = std::move(coordinate_transformation_mode);
+  resize_attrs->rounding_method = std::move(rounding_method);
+  resize_attrs->cubic_alpha = cubic_alpha;
+  resize_attrs->cubic_exclude = cubic_exclude;
+  resize_attrs->extrapolation_value = extrapolation_value;
+  resize_attrs->out_dtype = out_dtype.value_or(DataType::Void());
+
+  static const Op& resize3d_op = Op::Get("relax.image.resize3d");
+  return Call(resize3d_op, {std::move(data), std::move(size)}, Attrs(resize_attrs), {});
+}
+
+// Expose the C++ resize3d builder to Python as relax.op.image.resize3d.
+TVM_FFI_STATIC_INIT_BLOCK() {
+  tvm::ffi::reflection::GlobalDef().def("relax.op.image.resize3d", resize3d);
+}
+
+/*!
+ * \brief Infer the output struct info of relax.image.resize3d.
+ *
+ * Expects exactly two arguments: the input tensor and a 3-dim shape giving the
+ * new (D, H, W) extent. The output dtype is attrs->out_dtype, or the input dtype
+ * when out_dtype is void. The input's vdevice is propagated. When the input shape
+ * is unknown, or the size argument is not a literal ShapeExpr (e.g. a Var), only
+ * the output rank is inferred.
+ */
+StructInfo InferStructInfoResize3D(const Call& call, const BlockBuilder& ctx) {
+ if (call->args.size() != 2) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Resize3D expects 2 arguments, while the given number
of arguments is "
+ << call->args.size());
+ }
+
+ // Each ReportFatal below is relied on to abort inference (callers observe a
+ // TVMError), which guards the later dereferences of data_sinfo/size_sinfo.
+ const auto* data_sinfo =
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+ const auto* size_sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[1]);
+ // Non-null only when the size is a literal ShapeExpr.
+ const auto* size_value = call->args[1].as<ShapeExprNode>();
+ if (data_sinfo == nullptr) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "Resize3D expects the input data to be a Tensor, while
the given data is "
+ << call->args[0]->GetTypeKey());
+ }
+ if (size_sinfo == nullptr) {
+ ctx->ReportFatal(
+ Diagnostic::Error(call)
+ << "Resize3D expects the given output image size to be a Shape, while
the given one is "
+ << call->args[1]->GetTypeKey());
+ }
+ // An unknown-rank shape also fails this check.
+ if (size_sinfo->ndim != 3) {
+ ctx->ReportFatal(Diagnostic::Error(call) << "Resize3D expects the given
output image size to "
+ "be a 3-dim shape, while the
given one has ndim "
+ << size_sinfo->ndim);
+ }
+
+ // Build a converter from the user layout to canonical NCDHW, so D/H/W sit at axes 2..4.
+ const auto* attrs = call->attrs.as<Resize3DAttrs>();
+ auto [data_layout, data2NCDHW] = CheckTensorLayout(call, ctx, attrs->layout,
//
+ /*tgt_layout=*/"NCDHW",
//
+ /*tensor_name=*/"data");
+
+ // A void out_dtype means "keep the input dtype".
+ DataType out_dtype = attrs->out_dtype.is_void() ? data_sinfo->dtype :
attrs->out_dtype;
+
+ ffi::Optional<ShapeExpr> data_shape = CheckNdimPerLayoutAndGetShape(
+ call, ctx, ffi::GetRef<TensorStructInfo>(data_sinfo), data_layout);
+ // Shape not statically known: propagate dtype, rank and vdevice only.
+ if (!data_shape.defined() || size_value == nullptr) {
+ return TensorStructInfo(out_dtype, data_layout.ndim(),
data_sinfo->vdevice);
+ }
+
+ // Replace the spatial extents in NCDHW space, then map back to the user layout
+ // (this also covers packed layouts such as NCDHW8c).
+ ffi::Array<PrimExpr> data_NCDHW_shape =
data2NCDHW.ForwardShape(data_shape.value()->values);
+ ffi::Array<PrimExpr> out_NCDHW_shape(data_NCDHW_shape);
+ out_NCDHW_shape.Set(2, size_value->values[0]);
+ out_NCDHW_shape.Set(3, size_value->values[1]);
+ out_NCDHW_shape.Set(4, size_value->values[2]);
+
+ ffi::Array<PrimExpr> out_shape = data2NCDHW.BackwardShape(out_NCDHW_shape);
+ return TensorStructInfo(ShapeExpr(out_shape), out_dtype,
data_sinfo->vdevice);
+}
+
+/*!
+ * \brief Layout inference for relax.image.resize3d.
+ *
+ * If a desired layout is registered for this op, the data is transposed to it and
+ * attrs->layout is rewritten to that layout. Otherwise the incoming data's layout
+ * decision is kept, falling back to the initial 5-D layout when it has packed
+ * sub-axes (ndim != ndim_primal). The size argument keeps its initial layout, and
+ * the output shares the data layout.
+ */
+InferLayoutOutput InferLayoutResize3d(
+ const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ const auto& it = desired_layouts.find("relax.image.resize3d");
+ const auto* attrs = call->attrs.as<Resize3DAttrs>();
+ TVM_FFI_ICHECK(attrs) << "Invalid Call";
+
+ LayoutDecision data_layout;
+ // Copy the attrs so only the layout field is changed on the rewritten call.
+ ObjectPtr<Resize3DAttrs> new_attrs = ffi::make_object<Resize3DAttrs>(*attrs);
+
+ if (it != desired_layouts.end()) {
+ // Only pure axis permutations are supported here (no packed sub-axes).
+ Layout desired_data_layout = (*it).second[0];
+ TVM_FFI_ICHECK_EQ(desired_data_layout.ndim(),
desired_data_layout.ndim_primal())
+ << "Axis swap only";
+ data_layout = TransposeLike(InitialLayout(5), attrs->layout,
desired_data_layout);
+ new_attrs->layout = (*it).second[0];
+ } else {
+ data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ if (data_layout->layout.ndim() != data_layout->layout.ndim_primal()) {
+ data_layout = LayoutDecision(InitialLayout(5));
+ }
+ // Express the op's layout string in terms of the chosen data layout.
+ new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(5),
data_layout->layout).name();
+ }
+ return InferLayoutOutput({data_layout, InitialNLayout(call->args[1])},
{data_layout},
+ Attrs(new_attrs));
+}
+
+// Register relax.image.resize3d: two inputs (data, size), struct-info and layout
+// inference hooks, the kFollow mixed-precision policy, and purity.
+TVM_REGISTER_OP("relax.image.resize3d")
+ .set_attrs_type<Resize3DAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("size", "Shape", "The output image shape.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoResize3D)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutResize3d)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
+ .set_attr<Bool>("FPurity", Bool(true));
+
/* relax.grid_sample */
TVM_FFI_STATIC_INIT_BLOCK() { GridSampleAttrs::RegisterReflection(); }
diff --git a/src/relax/op/image/resize.h b/src/relax/op/image/resize.h
index a208aae092..c769cf91f5 100644
--- a/src/relax/op/image/resize.h
+++ b/src/relax/op/image/resize.h
@@ -38,6 +38,12 @@ Expr resize2d(Expr data, Expr size, ffi::Array<FloatImm>
roi, ffi::String layout
ffi::String rounding_method, double cubic_alpha, int
cubic_exclude,
double extrapolation_value, ffi::Optional<DataType> out_dtype);
+/*!
+ * \brief Image resize3d operator.
+ * \param data The input tensor.
+ * \param size The output (D, H, W) size, expected to be a 3-dim shape.
+ * \param roi Region of interest used by the 'tf_crop_and_resize' coordinate mode.
+ * \param layout Layout of \p data, e.g. "NCDHW".
+ * \param method Interpolation method (nearest_neighbor, linear or cubic).
+ * \param coordinate_transformation_mode Mapping from output to input coordinates.
+ * \param rounding_method Rounding used to pick the nearest pixel in nearest_neighbor.
+ * \param cubic_alpha Spline coefficient for tricubic interpolation.
+ * \param cubic_exclude Whether to exclude the image exterior in tricubic interpolation.
+ * \param extrapolation_value Value used when the roi falls outside the image.
+ * \param out_dtype Output dtype; when absent the input dtype is used.
+ * \return A call to relax.image.resize3d.
+ */
+Expr resize3d(Expr data, Expr size, ffi::Array<FloatImm> roi, ffi::String
layout,
+ ffi::String method, ffi::String coordinate_transformation_mode,
+ ffi::String rounding_method, double cubic_alpha, int
cubic_exclude,
+ double extrapolation_value, ffi::Optional<DataType> out_dtype);
+
/*! \brief Image grid_sample operator. */
Expr grid_sample(Expr data, Expr grid, ffi::String method, ffi::String layout,
ffi::String padding_mode, bool align_corners);
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 8740720205..f56adfbfdb 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -2839,6 +2839,37 @@ def test_resize_nd_sizes():
check_correctness(model, opset=18)
+def test_resize_5d_emits_relax_resize3d():
+    """A 5-D ONNX Resize must be imported as a relax.image.resize3d call."""
+    input_shape = [1, 1, 3, 4, 5]
+    output_shape = [1, 1, 4, 6, 7]
+    node = helper.make_node(
+        "Resize",
+        ["X", "", "", "sizes"],
+        ["Y"],
+        mode="nearest",
+        coordinate_transformation_mode="asymmetric",
+        nearest_mode="floor",
+    )
+    graph = helper.make_graph(
+        [node],
+        "resize3d_ir_check",
+        inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
+        initializer=[helper.make_tensor("sizes", TensorProto.INT64, [5], output_shape)],
+        outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, output_shape)],
+    )
+    model = helper.make_model(graph, producer_name="resize3d_ir_check")
+    tvm_model = from_onnx(model, opset=18, keep_params_in_input=True)
+
+    called_ops = []
+
+    def _collect_op_names(expr):
+        if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op):
+            called_ops.append(expr.op.name)
+
+    relax.analysis.post_order_visit(tvm_model["main"].body, _collect_op_names)
+    assert "relax.image.resize3d" in called_ops
+
+
def test_einsum():
eqn = "ij->i"
einsum_node = helper.make_node("Einsum", ["x"], ["y"], equation=eqn)
diff --git a/tests/python/relax/test_op_image.py
b/tests/python/relax/test_op_image.py
index 43ebc79298..6650fc359b 100644
--- a/tests/python/relax/test_op_image.py
+++ b/tests/python/relax/test_op_image.py
@@ -26,6 +26,8 @@ from tvm.script import relax as R
def test_op_correctness():
x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
assert relax.op.image.resize2d(x, (28, 28)).op ==
Op.get("relax.image.resize2d")
+ y = relax.Var("y", R.Tensor((2, 3, 8, 16, 32), "float32"))
+ assert relax.op.image.resize3d(y, (4, 8, 12)).op ==
Op.get("relax.image.resize3d")
def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo:
relax.StructInfo):
@@ -185,6 +187,114 @@ def test_resize2d_infer_struct_info_more_input_dtype():
)
+def test_resize3d_infer_struct_info():
+    """Inferred struct info across layouts, vdevices, dtypes and unknown shapes."""
+    bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
+    x0 = relax.Var("x", R.Tensor((2, 3, 8, 16, 32), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 8, 16, 32, 3), "float32"))
+    x2 = relax.Var("x", R.Tensor((2, 4, 8, 16, 32, 8), "float32"))
+    x3 = relax.Var("x", R.Tensor("float32", ndim=5))
+    x4 = relax.Var("x", R.Tensor((2, 3, 8, 16, 32), "float32", vdev0))
+    resize3d_fn = relax.op.image.resize3d
+
+    expectations = [
+        # Default NCDHW layout: D/H/W take the requested size.
+        (
+            resize3d_fn(x0, (4, 8, 12)),
+            relax.TensorStructInfo((2, 3, 4, 8, 12), "float32"),
+        ),
+        # The input's virtual device is carried to the output.
+        (
+            resize3d_fn(x4, (4, 8, 12)),
+            relax.TensorStructInfo((2, 3, 4, 8, 12), "float32", vdev0),
+        ),
+        # A scalar size is broadcast to all three spatial axes.
+        (
+            resize3d_fn(x0, 7),
+            relax.TensorStructInfo((2, 3, 7, 7, 7), "float32"),
+        ),
+        # Channels-last layout.
+        (
+            resize3d_fn(x1, (4, 8, 12), layout="NDHWC"),
+            relax.TensorStructInfo((2, 4, 8, 12, 3), "float32"),
+        ),
+        # Packed layout with an inner channel block.
+        (
+            resize3d_fn(x2, (4, 8, 12), layout="NCDHW8c"),
+            relax.TensorStructInfo((2, 4, 4, 8, 12, 8), "float32"),
+        ),
+        # Explicit output dtype overrides the input dtype.
+        (
+            resize3d_fn(x0, (4, 8, 12), out_dtype="float16"),
+            relax.TensorStructInfo((2, 3, 4, 8, 12), "float16"),
+        ),
+        # Unknown input shape: only dtype and rank are inferred.
+        (
+            resize3d_fn(x3, (4, 8, 12)),
+            relax.TensorStructInfo(dtype="float32", ndim=5),
+        ),
+    ]
+    for call, expected_sinfo in expectations:
+        _check_inference(bb, call, expected_sinfo)
+
+
+def test_resize3d_infer_struct_info_wrong_layout_string():
+    """A layout that cannot be mapped onto NCDHW is rejected."""
+    bb = relax.BlockBuilder()
+    data = relax.Var("x", R.Tensor((2, 3, 8, 16, 32), "float32"))
+    with pytest.raises(TVMError):
+        bad_call = relax.op.image.resize3d(data, size=(4, 8, 12), layout="OIHW")
+        bb.normalize(bad_call)
+
+
+def test_resize3d_wrong_input_ndim():
+    """The input rank must match the rank implied by the layout."""
+    bb = relax.BlockBuilder()
+    x5d = relax.Var("x", R.Tensor((2, 3, 8, 16, 32), "float32"))
+    x6d = relax.Var("x", R.Tensor((2, 3, 8, 16, 32, 3), "float32"))
+    x4d = relax.Var("x", R.Tensor("float32", ndim=4))
+    mismatches = [
+        (x5d, {"layout": "NCDHW8c"}),
+        (x6d, {"layout": "NCDHW"}),
+        (x4d, {}),
+    ]
+    for data, layout_kwargs in mismatches:
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.image.resize3d(data, size=(4, 8, 12), **layout_kwargs))
+
+
+def test_resize3d_wrong_size_ndim():
+    """Any size whose rank is not known to be 3 is rejected."""
+    bb = relax.BlockBuilder()
+    data = relax.Var("x", R.Tensor((2, 3, 8, 16, 32), "float16"))
+    bad_sizes = [
+        (3, 3),
+        relax.ShapeExpr((3, 3)),
+        relax.Var("s", relax.ShapeStructInfo((30, 30, 30, 30))),
+        relax.Var("s", relax.ShapeStructInfo(ndim=4)),
+        relax.Var("s", relax.ShapeStructInfo(ndim=2)),
+        relax.Var("s", relax.ShapeStructInfo(ndim=1)),
+        relax.Var("s", relax.ShapeStructInfo(ndim=0)),
+        relax.Var("s", relax.ShapeStructInfo()),
+    ]
+    for bad_size in bad_sizes:
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.image.resize3d(data, bad_size))
+
+
+def test_resize3d_infer_struct_info_wrong_input_type():
+    """Non-tensor data, or a non-shape size argument, must fail inference."""
+    bb = relax.BlockBuilder()
+    shape_var = relax.Var("x", relax.ShapeStructInfo((2, 3, 8, 16, 32)))
+    func_var = relax.Var(
+        "x", relax.FuncStructInfo([], R.Tensor((2, 3, 8, 16, 32), "float32"))
+    )
+    tensor_var = relax.Var("x", R.Tensor((2, 3, 8, 16, 32), "float32"))
+    tensor_size = relax.Var("s", R.Tensor((3, 3, 3)))
+
+    invalid_pairs = [
+        (shape_var, (4, 8, 12)),
+        (func_var, (4, 8, 12)),
+        (tensor_var, tensor_size),
+    ]
+    for data, size in invalid_pairs:
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.image.resize3d(data, size=size))
+
+
def test_resize2d_infer_struct_info_wrong_layout_string():
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
diff --git a/tests/python/relax/test_transform_legalize_ops_image.py
b/tests/python/relax/test_transform_legalize_ops_image.py
index 23c128eea5..48166d24c4 100644
--- a/tests/python/relax/test_transform_legalize_ops_image.py
+++ b/tests/python/relax/test_transform_legalize_ops_image.py
@@ -18,6 +18,7 @@
import tvm
import tvm.testing
+from tvm import relax
from tvm.relax.transform import LegalizeOps
from tvm.script import relax as R
from tvm.script import tirx as T
@@ -101,5 +102,41 @@ def test_image_resize2d_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_image_resize3d():
+    """LegalizeOps must lower relax.image.resize3d to call_tir, leaving no resize3d op."""
+    # fmt: off
+    @tvm.script.ir_module
+    class Resize3D:
+        @R.function
+        def main(x: R.Tensor((2, 3, 8, 8, 8), "float32")) -> R.Tensor((2, 3, 4, 6, 7), "float32"):
+            gv: R.Tensor((2, 3, 4, 6, 7), "float32") = R.image.resize3d(
+                x,
+                size=(4, 6, 7),
+                layout="NCDHW",
+                method="nearest_neighbor",
+                coordinate_transformation_mode="asymmetric",
+                rounding_method="floor",
+            )
+            return gv
+    # fmt: on
+
+    legalized = LegalizeOps()(Resize3D)
+
+    op_names = set()
+
+    def _record_op_name(expr):
+        if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op):
+            op_names.add(expr.op.name)
+
+    relax.analysis.post_order_visit(legalized["main"].body, _record_op_name)
+    assert "relax.call_tir" in op_names
+    assert "relax.image.resize3d" not in op_names
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_image.py
b/tests/python/relax/test_tvmscript_parser_op_image.py
index 47a7e3f196..9bee7cd1af 100644
--- a/tests/python/relax/test_tvmscript_parser_op_image.py
+++ b/tests/python/relax/test_tvmscript_parser_op_image.py
@@ -49,5 +49,22 @@ def test_resize2d():
_check(foo, bb.get()["foo"])
+def test_resize3d():
+    """Parsed R.image.resize3d matches the call built via the BlockBuilder API."""
+
+    @R.function
+    def foo(x: R.Tensor((2, 3, 8, 8, 8), "float32")) -> R.Tensor((2, 3, 4, 6, 7), "float32"):
+        gv: R.Tensor((2, 3, 4, 6, 7), "float32") = R.image.resize3d(
+            x, size=(4, 6, 7), layout="NCDHW"
+        )
+        return gv
+
+    builder = relax.BlockBuilder()
+    data = relax.Var("x", R.Tensor((2, 3, 8, 8, 8), "float32"))
+    with builder.function("foo", [data]):
+        resized = builder.emit(relax.op.image.resize3d(data, (4, 6, 7), layout="NCDHW"))
+        builder.emit_func_output(resized)
+
+    _check(foo, builder.get()["foo"])
+
+
if __name__ == "__main__":
tvm.testing.main()