This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4d36a45b9f [Relax][ONNX] Add image.resize3d op and wire 5D Resize (#18931)
4d36a45b9f is described below

commit 4d36a45b9f5d3a52b1337c77bd6162fd71796c87
Author: Dayuxiaoshui <[email protected]>
AuthorDate: Thu Mar 26 12:37:25 2026 +0800

    [Relax][ONNX] Add image.resize3d op and wire 5D Resize (#18931)
    
    ## Summary
    
    - Add Relax `image.resize3d` end-to-end: attrs, C++ op
    registration/inference, Python API, and legalization to
    `topi.image.resize3d`.
    - Update ONNX 5D `Resize` to emit `relax.image.resize3d` instead of
    direct `emit_te(topi.image.resize3d)`.
    - Reuse the existing `resize2d` implementation pattern, which let us
    move faster while keeping behavior consistent and risk low.
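    - Add tests for op inference, TVMScript parser, legalization, ONNX
    import, and `resize3d` negative/error cases.
    - Tighten `resize2d` struct-info inference to require exactly two
    arguments (`data` and `size`) instead of accepting one or two.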
    
    ## Test Plan
    
    - `python3 -m pytest -q tests/python/relax/test_op_image.py -k
    'resize3d'`
    - `python3 -m pytest -q
    tests/python/relax/test_transform_legalize_ops_image.py -k 'resize3d'`
    - `python3 -m pytest -q
    tests/python/relax/test_tvmscript_parser_op_image.py -k 'resize3d'`
    - `python3 -m pytest -q tests/python/relax/test_frontend_onnx.py -k
    'resize_nd_sizes or resize_5d_emits_relax_resize3d'`
---
 include/tvm/relax/attrs/image.h                    |  49 ++++++++
 python/tvm/relax/frontend/onnx/onnx_frontend.py    |  21 ++--
 python/tvm/relax/op/image/__init__.py              |   2 +-
 python/tvm/relax/op/image/image.py                 |  52 +++++++++
 python/tvm/relax/transform/legalize_ops/image.py   |  17 +++
 src/relax/op/image/resize.cc                       | 123 ++++++++++++++++++++-
 src/relax/op/image/resize.h                        |   6 +
 tests/python/relax/test_frontend_onnx.py           |  31 ++++++
 tests/python/relax/test_op_image.py                | 110 ++++++++++++++++++
 .../relax/test_transform_legalize_ops_image.py     |  37 +++++++
 .../python/relax/test_tvmscript_parser_op_image.py |  17 +++
 11 files changed, 448 insertions(+), 17 deletions(-)

diff --git a/include/tvm/relax/attrs/image.h b/include/tvm/relax/attrs/image.h
index b367ce5843..52aac58dcd 100644
--- a/include/tvm/relax/attrs/image.h
+++ b/include/tvm/relax/attrs/image.h
@@ -78,6 +78,55 @@ struct Resize2DAttrs : public AttrsNodeReflAdapter<Resize2DAttrs> {
   TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Resize2DAttrs", Resize2DAttrs, BaseAttrsNode);
 };  // struct Resize2dAttrs
 
+/*! \brief Attributes used in image resize3d operator */
+struct Resize3DAttrs : public AttrsNodeReflAdapter<Resize3DAttrs> {
+  ffi::Array<FloatImm> roi;
+  ffi::String layout;
+  ffi::String method;
+  ffi::String coordinate_transformation_mode;
+  ffi::String rounding_method;
+  double cubic_alpha;
+  int cubic_exclude;
+  double extrapolation_value;
+  DataType out_dtype;
+
+  static void RegisterReflection() {
+    namespace refl = tvm::ffi::reflection;
+    refl::ObjectDef<Resize3DAttrs>()
+        .def_ro("roi", &Resize3DAttrs::roi,
+                "Region of Interest for coordinate transformation mode 'tf_crop_and_resize'")
+        .def_ro("layout", &Resize3DAttrs::layout,
+                "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc. "
+                "'N', 'C', 'D', 'H', 'W' stand for batch, channel, depth, height, and width "
+                "dimensions respectively. Resize is applied on the 'D', 'H' and "
+                "'W' dimensions.")
+        .def_ro("method", &Resize3DAttrs::method,
+                "Specify the mode to use for scaling. "
+                "nearest_neighbor - Nearest Neighbor; "
+                "linear - Trilinear Interpolation; "
+                "cubic - Tricubic Interpolation")
+        .def_ro("coordinate_transformation_mode", &Resize3DAttrs::coordinate_transformation_mode,
+                "Describes how to transform the coordinate in the resized tensor "
+                "to the coordinate in the original tensor. "
+                "Refer to the ONNX Resize operator specification for details. "
+                "Available options are half_pixel, align_corners and asymmetric")
+        .def_ro("rounding_method", &Resize3DAttrs::rounding_method,
+                "indicates how to find the \"nearest\" pixel in nearest_neighbor method. "
+                "Available options are round, floor, and ceil.")
+        .def_ro("cubic_alpha", &Resize3DAttrs::cubic_alpha,
+                "Spline Coefficient for Tricubic Interpolation")
+        .def_ro("cubic_exclude", &Resize3DAttrs::cubic_exclude,
+                "Flag to exclude exterior of the image during tricubic interpolation")
+        .def_ro("extrapolation_value", &Resize3DAttrs::extrapolation_value,
+                "Value to return when roi is outside of the image")
+        .def_ro(
+            "out_dtype", &Resize3DAttrs::out_dtype,
+            "The dtype of the output tensor. If it is not specified, the output will have the same "
+            "dtype as input.");
+  }
+  TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.Resize3DAttrs", Resize3DAttrs, BaseAttrsNode);
+};  // struct Resize3DAttrs
+
 /*! \brief Attributes used in image grid_sample operator */
 struct GridSampleAttrs : public AttrsNodeReflAdapter<GridSampleAttrs> {
   ffi::String method;
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index f08505951d..d828025e0a 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2397,18 +2397,17 @@ class Resize(OnnxOpConverter):
                 extrapolation_value=extrapolation_value,
             )
         else:  # ndims == 5
-            return bb.emit_te(
-                topi.image.resize3d,
+            return relax.op.image.resize3d(
                 x,
-                roi,
-                sizes,
-                "NCDHW",
-                topi_mode,
-                coord_mode,
-                rounding_method,
-                cubic_coeff_a,
-                exclude_outside,
-                extrapolation_value,
+                size=relax.ShapeExpr(sizes),
+                roi=roi,
+                layout="NCDHW",
+                method=relax_mode,
+                coordinate_transformation_mode=coord_mode,
+                rounding_method=rounding_method,
+                cubic_alpha=cubic_coeff_a,
+                cubic_exclude=exclude_outside,
+                extrapolation_value=extrapolation_value,
             )
 
 
diff --git a/python/tvm/relax/op/image/__init__.py b/python/tvm/relax/op/image/__init__.py
index e5a5b74b58..6b02c32199 100644
--- a/python/tvm/relax/op/image/__init__.py
+++ b/python/tvm/relax/op/image/__init__.py
@@ -17,4 +17,4 @@
 # under the License.
 """Image operators."""
 
-from .image import grid_sample, resize2d
+from .image import grid_sample, resize2d, resize3d
diff --git a/python/tvm/relax/op/image/image.py b/python/tvm/relax/op/image/image.py
index 5e48bb6522..b267f40709 100644
--- a/python/tvm/relax/op/image/image.py
+++ b/python/tvm/relax/op/image/image.py
@@ -130,6 +130,104 @@ def resize2d(
     )
 
 
+def resize3d(
+    data: Expr,
+    size: Expr | PrimExprLike | tuple[PrimExprLike],
+    roi: float | tuple[float] | None = None,
+    layout: str = "NCDHW",
+    method: str = "linear",
+    coordinate_transformation_mode: str = "half_pixel",
+    rounding_method: str = "",
+    cubic_alpha: float = -0.75,
+    cubic_exclude: int = 0,
+    extrapolation_value: float = 0.0,
+    out_dtype: str | DataType | None = None,
+) -> Expr:
+    """Image resize3d operator.
+
+    This operator takes data as input and does 3D scaling to the given output size.
+    In the default case, where data layout is `NCDHW`
+    with data of shape (n, c, d, h, w),
+    the output has shape (n, c, size[0], size[1], size[2]).
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    size : Union[Expr, PrimExprLike, Tuple[PrimExprLike]]
+        The output size along the D, H and W axes. A single int/PrimExpr, or a
+        tuple/list of length 1, is broadcast to all three axes; otherwise three
+        values are expected. An Expr must be a 3-dim shape.
+
+    roi : Optional[Union[float, Tuple[float]]]
+        The region of interest, used by the ``tf_crop_and_resize`` coordinate
+        transformation mode. ``None`` means six zeros and a single float is
+        repeated six times; the resulting floats are stored in the op's ``roi``
+        attribute.
+
+    layout : str
+        Dimension ordering of the input data, e.g. "NCDHW" or "NDHWC".
+
+    method : str
+        Scaling method: "nearest_neighbor", "linear" or "cubic".
+
+    coordinate_transformation_mode : str
+        How a coordinate in the resized tensor maps to the original tensor.
+        Refer to the ONNX Resize operator specification for details.
+
+    rounding_method : str
+        How to find the "nearest" pixel when method is "nearest_neighbor".
+
+    cubic_alpha : float
+        Spline coefficient for tricubic interpolation.
+
+    cubic_exclude : int
+        Flag to exclude the exterior of the image during tricubic interpolation.
+
+    extrapolation_value : float
+        Value to return when roi is outside of the image.
+
+    out_dtype : Optional[Union[str, DataType]]
+        The dtype of the output tensor. Defaults to the dtype of ``data``.
+
+    Returns
+    -------
+    result : relax.Expr
+        The resized result.
+    """
+    if roi is None:
+        roi = (0.0, 0.0, 0.0, 0.0, 0.0, 0.0)  # type: ignore
+    elif isinstance(roi, float):
+        roi = (roi, roi, roi, roi, roi, roi)  # type: ignore
+    elif isinstance(roi, tuple | list):
+        roi = tuple(val if isinstance(val, float) else float(val) for val in roi)
+    else:
+        raise NotImplementedError(f"Unsupported roi type {type(roi)}")
+
+    if isinstance(size, int | PrimExpr):
+        size = (size, size, size)
+    if isinstance(size, tuple | list):
+        if len(size) == 1:
+            size = ShapeExpr([size[0], size[0], size[0]])
+        else:
+            size = ShapeExpr(size)
+
+    return _ffi_api.resize3d(  # type: ignore
+        data,
+        size,
+        roi,
+        layout,
+        method,
+        coordinate_transformation_mode,
+        rounding_method,
+        cubic_alpha,
+        cubic_exclude,
+        extrapolation_value,
+        out_dtype,
+    )
+
+
 def grid_sample(
     data: Expr,
     grid: Expr,
diff --git a/python/tvm/relax/transform/legalize_ops/image.py b/python/tvm/relax/transform/legalize_ops/image.py
index 2ce22b424e..1e7aaebceb 100644
--- a/python/tvm/relax/transform/legalize_ops/image.py
+++ b/python/tvm/relax/transform/legalize_ops/image.py
@@ -52,3 +52,23 @@ def _image_grid_sample(bb: BlockBuilder, call: Call) -> Expr:
         padding_mode=call.attrs.padding_mode,
         align_corners=call.attrs.align_corners,
     )
+
+
+@register_legalize("relax.image.resize3d")
+def _image_resize3d(bb: BlockBuilder, call: Call) -> Expr:
+    # The op's inferred struct info already resolves ``out_dtype`` (falling back to the
+    # input dtype), so forward it to topi to keep the lowered result consistent with it.
+    return bb.call_te(
+        topi.image.resize3d,
+        call.args[0],
+        roi=call.attrs.roi,
+        size=call.args[1],
+        layout=call.attrs.layout,
+        method=call.attrs.method,
+        coordinate_transformation_mode=call.attrs.coordinate_transformation_mode,
+        rounding_method=call.attrs.rounding_method,
+        bicubic_alpha=call.attrs.cubic_alpha,
+        bicubic_exclude=call.attrs.cubic_exclude,
+        extrapolation_value=call.attrs.extrapolation_value,
+        out_dtype=call.struct_info.dtype,
+    )
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index 15fdcf3eb9..fe9df47dd5 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -32,6 +32,7 @@ namespace tvm {
 namespace relax {
 
 TVM_FFI_STATIC_INIT_BLOCK() { Resize2DAttrs::RegisterReflection(); }
+TVM_FFI_STATIC_INIT_BLOCK() { Resize3DAttrs::RegisterReflection(); }
 
 /* relax.resize2d */
 
@@ -60,11 +61,10 @@ TVM_FFI_STATIC_INIT_BLOCK() {
 }
 
 StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) {
-  if (call->args.size() != 1 && call->args.size() != 2) {
-    ctx->ReportFatal(
-        Diagnostic::Error(call)
-        << "Resize2D expects either one or two arguments, while the given number of arguments is "
-        << call->args.size());
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Resize2D expects 2 arguments, while the given number of arguments is "
+                     << call->args.size());
   }
 
   const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
@@ -149,6 +149,132 @@ TVM_REGISTER_OP("relax.image.resize2d")
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.resize3d */
+
+Expr resize3d(Expr data, Expr size, ffi::Array<FloatImm> roi, ffi::String layout,
+              ffi::String method, ffi::String coordinate_transformation_mode,
+              ffi::String rounding_method, double cubic_alpha, int cubic_exclude,
+              double extrapolation_value, ffi::Optional<DataType> out_dtype) {
+  ObjectPtr<Resize3DAttrs> attrs = ffi::make_object<Resize3DAttrs>();
+  attrs->roi = std::move(roi);
+  attrs->layout = std::move(layout);
+  attrs->method = std::move(method);
+  attrs->coordinate_transformation_mode = std::move(coordinate_transformation_mode);
+  attrs->rounding_method = std::move(rounding_method);
+  attrs->cubic_alpha = cubic_alpha;
+  attrs->cubic_exclude = cubic_exclude;
+  attrs->extrapolation_value = extrapolation_value;
+  attrs->out_dtype = out_dtype.value_or(DataType::Void());
+
+  static const Op& op = Op::Get("relax.image.resize3d");
+  return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.image.resize3d", resize3d);
+}
+
+/*!
+ * \brief Infer the output struct info of relax.image.resize3d.
+ *
+ * The D, H and W extents of the input (located via attrs->layout) are replaced by the three
+ * values of `size`; all other axes are kept. When the input shape cannot be obtained or `size`
+ * is not a ShapeExpr, only the output dtype and ndim are inferred.
+ */
+StructInfo InferStructInfoResize3D(const Call& call, const BlockBuilder& ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Resize3D expects 2 arguments, while the given number of arguments is "
+                     << call->args.size());
+  }
+
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* size_sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[1]);
+  const auto* size_value = call->args[1].as<ShapeExprNode>();
+  if (data_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Resize3D expects the input data to be a Tensor, while the given data is "
+                     << call->args[0]->GetTypeKey());
+  }
+  if (size_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "Resize3D expects the given output image size to be a Shape, while the given one is "
+        << call->args[1]->GetTypeKey());
+  }
+  if (size_sinfo->ndim != 3) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "Resize3D expects the given output image size to "
+                                                "be a 3-dim shape, while the given one has ndim "
+                                             << size_sinfo->ndim);
+  }
+
+  const auto* attrs = call->attrs.as<Resize3DAttrs>();
+  TVM_FFI_ICHECK(attrs != nullptr) << "Invalid Call";
+  auto [data_layout, data2NCDHW] = CheckTensorLayout(call, ctx, attrs->layout,  //
+                                                     /*tgt_layout=*/"NCDHW",    //
+                                                     /*tensor_name=*/"data");
+
+  DataType out_dtype = attrs->out_dtype.is_void() ? data_sinfo->dtype : attrs->out_dtype;
+
+  ffi::Optional<ShapeExpr> data_shape = CheckNdimPerLayoutAndGetShape(
+      call, ctx, ffi::GetRef<TensorStructInfo>(data_sinfo), data_layout);
+  if (!data_shape.defined() || size_value == nullptr) {
+    return TensorStructInfo(out_dtype, data_layout.ndim(), data_sinfo->vdevice);
+  }
+
+  ffi::Array<PrimExpr> data_NCDHW_shape = data2NCDHW.ForwardShape(data_shape.value()->values);
+  ffi::Array<PrimExpr> out_NCDHW_shape(data_NCDHW_shape);
+  out_NCDHW_shape.Set(2, size_value->values[0]);
+  out_NCDHW_shape.Set(3, size_value->values[1]);
+  out_NCDHW_shape.Set(4, size_value->values[2]);
+
+  ffi::Array<PrimExpr> out_shape = data2NCDHW.BackwardShape(out_NCDHW_shape);
+  return TensorStructInfo(ShapeExpr(out_shape), out_dtype, data_sinfo->vdevice);
+}
+
+/*! \brief Propagate the data layout through relax.image.resize3d and rewrite attrs->layout. */
+InferLayoutOutput InferLayoutResize3d(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  auto it = desired_layouts.find("relax.image.resize3d");
+  const auto* attrs = call->attrs.as<Resize3DAttrs>();
+  TVM_FFI_ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision data_layout;
+  ObjectPtr<Resize3DAttrs> new_attrs = ffi::make_object<Resize3DAttrs>(*attrs);
+
+  if (it != desired_layouts.end()) {
+    // Only the first entry (the data layout) is used; make sure it exists before indexing.
+    TVM_FFI_ICHECK(!(*it).second.empty())
+        << "relax.image.resize3d expects at least one desired layout";
+    Layout desired_data_layout = (*it).second[0];
+    TVM_FFI_ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal())
+        << "Axis swap only";
+    data_layout = TransposeLike(InitialLayout(5), attrs->layout, desired_data_layout);
+    new_attrs->layout = (*it).second[0];
+  } else {
+    // Follow the producer's layout; layouts with split axes fall back to InitialLayout(5).
+    data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+    if (data_layout->layout.ndim() != data_layout->layout.ndim_primal()) {
+      data_layout = LayoutDecision(InitialLayout(5));
+    }
+    new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(5), data_layout->layout).name();
+  }
+  return InferLayoutOutput({data_layout, InitialNLayout(call->args[1])}, {data_layout},
+                           Attrs(new_attrs));
+}
+
+TVM_REGISTER_OP("relax.image.resize3d")
+    .set_attrs_type<Resize3DAttrs>()
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("size", "Shape", "The output image shape.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoResize3D)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutResize3d)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.grid_sample */
 
 TVM_FFI_STATIC_INIT_BLOCK() { GridSampleAttrs::RegisterReflection(); }
diff --git a/src/relax/op/image/resize.h b/src/relax/op/image/resize.h
index a208aae092..c769cf91f5 100644
--- a/src/relax/op/image/resize.h
+++ b/src/relax/op/image/resize.h
@@ -38,6 +38,12 @@ Expr resize2d(Expr data, Expr size, ffi::Array<FloatImm> roi, ffi::String layout
               ffi::String rounding_method, double cubic_alpha, int cubic_exclude,
               double extrapolation_value, ffi::Optional<DataType> out_dtype);
 
+/*! \brief Image resize3d operator: resizes the 'D', 'H' and 'W' axes of data to size. */
+Expr resize3d(Expr data, Expr size, ffi::Array<FloatImm> roi, ffi::String layout,
+              ffi::String method, ffi::String coordinate_transformation_mode,
+              ffi::String rounding_method, double cubic_alpha, int cubic_exclude,
+              double extrapolation_value, ffi::Optional<DataType> out_dtype);
+
 /*! \brief Image grid_sample operator. */
 Expr grid_sample(Expr data, Expr grid, ffi::String method, ffi::String layout,
                  ffi::String padding_mode, bool align_corners);
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 8740720205..f56adfbfdb 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -2839,6 +2839,38 @@ def test_resize_nd_sizes():
         check_correctness(model, opset=18)
 
 
+def test_resize_5d_emits_relax_resize3d():
+    """A 5-D ONNX Resize should be imported as a relax.image.resize3d call."""
+    resize_node = helper.make_node(
+        "Resize",
+        ["X", "", "", "sizes"],
+        ["Y"],
+        mode="nearest",
+        coordinate_transformation_mode="asymmetric",
+        nearest_mode="floor",
+    )
+    graph = helper.make_graph(
+        [resize_node],
+        "resize3d_ir_check",
+        inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 3, 4, 5])],
+        initializer=[helper.make_tensor("sizes", TensorProto.INT64, [5], [1, 1, 4, 6, 7])],
+        outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 4, 6, 7])],
+    )
+    model = helper.make_model(graph, producer_name="resize3d_ir_check")
+    tvm_model = from_onnx(model, opset=18, keep_params_in_input=True)
+
+    seen_resize3d = False
+
+    def _visit(expr):
+        nonlocal seen_resize3d
+        if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op):
+            if expr.op.name == "relax.image.resize3d":
+                seen_resize3d = True
+
+    relax.analysis.post_order_visit(tvm_model["main"].body, _visit)
+    assert seen_resize3d
+
+
 def test_einsum():
     eqn = "ij->i"
     einsum_node = helper.make_node("Einsum", ["x"], ["y"], equation=eqn)
diff --git a/tests/python/relax/test_op_image.py b/tests/python/relax/test_op_image.py
index 43ebc79298..6650fc359b 100644
--- a/tests/python/relax/test_op_image.py
+++ b/tests/python/relax/test_op_image.py
@@ -26,6 +26,8 @@ from tvm.script import relax as R
 def test_op_correctness():
     x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
     assert relax.op.image.resize2d(x, (28, 28)).op == Op.get("relax.image.resize2d")
+    y = relax.Var("y", R.Tensor((2, 3, 8, 16, 32), "float32"))
+    assert relax.op.image.resize3d(y, (4, 8, 12)).op == Op.get("relax.image.resize3d")
 
 
 def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
@@ -185,6 +187,114 @@ def test_resize2d_infer_struct_info_more_input_dtype():
     )
 
 
+def test_resize3d_infer_struct_info():
+    bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
+    x0 = relax.Var("x", R.Tensor((2, 3, 8, 16, 32), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 8, 16, 32, 3), "float32"))
+    x2 = relax.Var("x", R.Tensor((2, 4, 8, 16, 32, 8), "float32"))
+    x3 = relax.Var("x", R.Tensor("float32", ndim=5))
+    x4 = relax.Var("x", R.Tensor((2, 3, 8, 16, 32), "float32", vdev0))
+
+    _check_inference(
+        bb,
+        relax.op.image.resize3d(x0, (4, 8, 12)),
+        relax.TensorStructInfo((2, 3, 4, 8, 12), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.image.resize3d(x4, (4, 8, 12)),
+        relax.TensorStructInfo((2, 3, 4, 8, 12), "float32", vdev0),
+    )
+    _check_inference(
+        bb,
+        relax.op.image.resize3d(x0, 7),
+        relax.TensorStructInfo((2, 3, 7, 7, 7), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.image.resize3d(x1, (4, 8, 12), layout="NDHWC"),
+        relax.TensorStructInfo((2, 4, 8, 12, 3), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.image.resize3d(x2, (4, 8, 12), layout="NCDHW8c"),
+        relax.TensorStructInfo((2, 4, 4, 8, 12, 8), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.image.resize3d(x0, (4, 8, 12), out_dtype="float16"),
+        relax.TensorStructInfo((2, 3, 4, 8, 12), "float16"),
+    )
+    _check_inference(
+        bb, relax.op.image.resize3d(x3, (4, 8, 12)), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+
+
+def test_resize3d_infer_struct_info_wrong_layout_string():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 8, 16, 32), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.resize3d(x, size=(4, 8, 12), layout="OIHW"))
+
+
+def test_resize3d_wrong_input_ndim():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 8, 16, 32), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 8, 16, 32, 3), "float32"))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=4))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.resize3d(x0, size=(4, 8, 12), layout="NCDHW8c"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.resize3d(x1, size=(4, 8, 12), layout="NCDHW"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.resize3d(x2, size=(4, 8, 12)))
+
+
+def test_resize3d_wrong_size_ndim():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 8, 16, 32), "float16"))
+    s0 = relax.ShapeExpr((3, 3))
+    s1 = relax.Var("s", relax.ShapeStructInfo((30, 30, 30, 30)))
+    s2 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+    s3 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
+    s4 = relax.Var("s", relax.ShapeStructInfo(ndim=1))
+    s5 = relax.Var("s", relax.ShapeStructInfo(ndim=0))
+    s6 = relax.Var("s", relax.ShapeStructInfo())
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.resize3d(x0, (3, 3)))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.resize3d(x0, s0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.resize3d(x0, s1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.resize3d(x0, s2))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.resize3d(x0, s3))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.resize3d(x0, s4))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.resize3d(x0, s5))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.resize3d(x0, s6))
+
+
+def test_resize3d_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 8, 16, 32)))
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 8, 16, 32), "float32")))
+    x2 = relax.Var("x", R.Tensor((2, 3, 8, 16, 32), "float32"))
+    s0 = relax.Var("s", R.Tensor((3, 3, 3)))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.resize3d(x0, size=(4, 8, 12)))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.resize3d(x1, size=(4, 8, 12)))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.resize3d(x2, s0))
+
+
 def test_resize2d_infer_struct_info_wrong_layout_string():
     bb = relax.BlockBuilder()
     x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
diff --git a/tests/python/relax/test_transform_legalize_ops_image.py b/tests/python/relax/test_transform_legalize_ops_image.py
index 23c128eea5..48166d24c4 100644
--- a/tests/python/relax/test_transform_legalize_ops_image.py
+++ b/tests/python/relax/test_transform_legalize_ops_image.py
@@ -18,6 +18,7 @@
 
 import tvm
 import tvm.testing
+from tvm import relax
 from tvm.relax.transform import LegalizeOps
 from tvm.script import relax as R
 from tvm.script import tirx as T
@@ -101,5 +102,41 @@ def test_image_resize2d_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_image_resize3d():
+    # fmt: off
+    @tvm.script.ir_module
+    class Resize3D:
+        @R.function
+        def main(x: R.Tensor((2, 3, 8, 8, 8), "float32")) -> R.Tensor((2, 3, 4, 6, 7), "float32"):
+            gv: R.Tensor((2, 3, 4, 6, 7), "float32") = R.image.resize3d(
+                x,
+                size=(4, 6, 7),
+                layout="NCDHW",
+                method="nearest_neighbor",
+                coordinate_transformation_mode="asymmetric",
+                rounding_method="floor",
+            )
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Resize3D)
+
+    seen_call_tir = False
+    seen_resize3d_relax_op = False
+
+    def _visit(expr):
+        nonlocal seen_call_tir, seen_resize3d_relax_op
+        if isinstance(expr, relax.Call):
+            if isinstance(expr.op, tvm.ir.Op):
+                if expr.op.name == "relax.call_tir":
+                    seen_call_tir = True
+                if expr.op.name == "relax.image.resize3d":
+                    seen_resize3d_relax_op = True
+
+    relax.analysis.post_order_visit(mod["main"].body, _visit)
+    assert seen_call_tir
+    assert not seen_resize3d_relax_op
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_image.py b/tests/python/relax/test_tvmscript_parser_op_image.py
index 47a7e3f196..9bee7cd1af 100644
--- a/tests/python/relax/test_tvmscript_parser_op_image.py
+++ b/tests/python/relax/test_tvmscript_parser_op_image.py
@@ -49,5 +49,22 @@ def test_resize2d():
     _check(foo, bb.get()["foo"])
 
 
+def test_resize3d():
+    @R.function
+    def foo(x: R.Tensor((2, 3, 8, 8, 8), "float32")) -> R.Tensor((2, 3, 4, 6, 7), "float32"):
+        gv: R.Tensor((2, 3, 4, 6, 7), "float32") = R.image.resize3d(
+            x, size=(4, 6, 7), layout="NCDHW"
+        )
+        return gv
+
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 8, 8, 8), "float32"))
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.image.resize3d(x, (4, 6, 7), layout="NCDHW"))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to