This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2f2469e637 [Relax] Add affine_grid operator with PyTorch and ONNX 
frontend support (#18933)
2f2469e637 is described below

commit 2f2469e6371dd4c9f89cba3924d877d090230861
Author: HoYi <[email protected]>
AuthorDate: Fri Mar 27 10:38:10 2026 +0800

    [Relax] Add affine_grid operator with PyTorch and ONNX frontend support 
(#18933)
    
    ## Summary
    
    Add `relax.image.affine_grid` operator for Spatial Transformer Networks,
    along with PyTorch and ONNX frontend integration.
    
    TOPI compute (`topi.image.affine_grid`) already exists. This PR
    completes the Relax-level registration and frontend support, following
    the existing `resize2d` / `grid_sample` pattern.
    
    ## Changes
    
    **Relax op registration:**
    - C++ op function, FFI registration, and struct info inference
    (`resize.h`, `resize.cc`)
    - Python wrapper with flexible size parameter handling (`image.py`)
    - Legalization to `topi.image.affine_grid` with `PrimExpr` → `int`
    conversion
    - Op-level tests (struct info inference + e2e numerical correctness) and
    legalization test
    
    **PyTorch frontend:**
    - Converter for `aten.affine_grid_generator.default`
    - Layout conversion from TVM `[N,2,H,W]` to PyTorch `[N,H,W,2]` via
    `permute_dims`
    - Single-kernel path is 5.6x faster than the decomposed path (30+ ops)
    - Structural IR test and numerical correctness test
    
    **ONNX frontend:**
    - `AffineGrid` converter with `_impl_v20` (opset 20, when the op was
    first introduced)
    - Support for constant size tensor `[N,C,H,W]`
    - Layout conversion from TVM `[N,2,H,W]` to ONNX `[N,H,W,2]`
    - End-to-end correctness test against ONNX Runtime
    
    ## Limitations
    
    - Only `align_corners=True` is supported (matches current TOPI
    implementation)
    - Only 2D affine grid is supported
    
    ## Validation
    
    ```bash
    pytest tests/python/relax/test_op_image.py -k "affine_grid" -v           # 
8 passed
    pytest tests/python/relax/test_transform_legalize_ops_image.py -k 
"affine_grid" -v  # 1 passed
    pytest tests/python/relax/test_frontend_from_exported_program.py -k 
"affine_grid" -v  # 2 passed
    pytest tests/python/relax/test_frontend_onnx.py -k "affine_grid" -v     # 1 
passed
    ```
    
    All 12 tests passed.
    
    ---------
    
    Co-authored-by: Claude Opus 4.6 <[email protected]>
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py    |  34 ++++++
 .../frontend/torch/exported_program_translator.py  |  24 ++++
 python/tvm/relax/op/image/__init__.py              |   2 +-
 python/tvm/relax/op/image/image.py                 |  41 +++++++
 python/tvm/relax/transform/legalize_ops/image.py   |  18 ++-
 src/relax/op/image/resize.cc                       |  94 +++++++++++++++
 src/relax/op/image/resize.h                        |   3 +
 .../relax/test_frontend_from_exported_program.py   |  57 +++++++++
 tests/python/relax/test_frontend_onnx.py           |  26 ++++
 tests/python/relax/test_op_image.py                | 131 +++++++++++++++++++++
 .../relax/test_transform_legalize_ops_image.py     |  36 ++++++
 11 files changed, 464 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index c8d4c469fc..a117317125 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -2411,6 +2411,39 @@ class Resize(OnnxOpConverter):
             )
 
 
+class AffineGrid(OnnxOpConverter):
+    """Converts an onnx AffineGrid node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v20(cls, bb, inputs, attr, params):
+        theta = inputs[0]  # [N, 2, 3] for 2D
+        size = get_constant(inputs[1], params)  # [N, C, H, W] for 2D
+        align_corners = attr.get("align_corners", 0)
+
+        if align_corners != 1:
+            raise NotImplementedError(
+                "AffineGrid with align_corners=0 is not yet supported in TVM"
+            )
+
+        # Extract size values
+        if isinstance(size, relax.Constant):
+            size_vals = size.data.numpy().astype("int64").tolist()
+        elif isinstance(size, relax.expr.ShapeExpr):
+            size_vals = [int(v.value) for v in size.values]
+        else:
+            raise NotImplementedError(f"Dynamic size of type {type(size)} is 
not supported")
+
+        # Only 2D is supported: size = [N, C, H, W]
+        if len(size_vals) != 4:
+            raise ValueError("Only 2D AffineGrid (size=[N,C,H,W]) is 
supported")
+        target_h, target_w = size_vals[2], size_vals[3]
+
+        # Relax affine_grid outputs [N, 2, H, W]
+        grid = bb.emit(relax.op.image.affine_grid(theta, (target_h, target_w)))
+        # Permute to ONNX convention [N, H, W, 2]
+        return bb.emit(relax.op.permute_dims(grid, axes=[0, 2, 3, 1]))
+
+
 class Einsum(OnnxOpConverter):
     """Converts an onnx Einsum node into an equivalent Relax expression."""
 
@@ -4151,6 +4184,7 @@ def _get_convert_map():
         "NonMaxSuppression": NonMaxSuppression,
         "AllClassNMS": AllClassNMS,
         "GridSample": GridSample,
+        "AffineGrid": AffineGrid,
         "Upsample": Upsample,
         # others
         "DepthToSpace": DepthToSpace,
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py 
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 47633c69b5..cc37554bf3 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1123,6 +1123,29 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             )
         )
 
+    def _affine_grid_generator(self, node: fx.Node) -> relax.Var:
+        """Convert torch.nn.functional.affine_grid to 
relax.op.image.affine_grid."""
+        args = self.retrieve_args(node)
+        theta = args[0]  # [N, 2, 3]
+        size = args[1]  # [N, C, H, W]
+        align_corners = args[2] if len(args) > 2 else False
+
+        if not align_corners:
+            raise NotImplementedError(
+                "affine_grid with align_corners=False is not yet supported in 
TVM"
+            )
+
+        # Extract spatial dimensions (H, W) from PyTorch's [N, C, H, W] size
+        target_h = size[2]
+        target_w = size[3]
+
+        # Relax affine_grid outputs [N, 2, H, W]
+        grid = self.block_builder.emit(
+            relax.op.image.affine_grid(theta, (target_h, target_w))
+        )
+        # Permute to PyTorch convention [N, H, W, 2]
+        return self.block_builder.emit(relax.op.permute_dims(grid, axes=[0, 2, 
3, 1]))
+
     def _torchvision_roi_align(self, node: fx.Node) -> relax.Var:
         """Convert torchvision.ops.roi_align to relax.op.vision.roi_align."""
         args = self.retrieve_args(node)
@@ -1768,6 +1791,7 @@ class ExportedProgramImporter(BaseFXGraphImporter):
             "zeros.default": self._zeros,
             "zeros_like.default": self._zeros_like,
             "grid_sampler_2d.default": self._grid_sampler_2d,
+            "affine_grid_generator.default": self._affine_grid_generator,
             "roi_align.default": self._torchvision_roi_align,
             # datatype
             "to.dtype": self._to,
diff --git a/python/tvm/relax/op/image/__init__.py 
b/python/tvm/relax/op/image/__init__.py
index 6b02c32199..dcc0d1f883 100644
--- a/python/tvm/relax/op/image/__init__.py
+++ b/python/tvm/relax/op/image/__init__.py
@@ -17,4 +17,4 @@
 # under the License.
 """Image operators."""
 
-from .image import grid_sample, resize2d, resize3d
+from .image import affine_grid, grid_sample, resize2d, resize3d
diff --git a/python/tvm/relax/op/image/image.py 
b/python/tvm/relax/op/image/image.py
index b267f40709..323bfa74b5 100644
--- a/python/tvm/relax/op/image/image.py
+++ b/python/tvm/relax/op/image/image.py
@@ -16,6 +16,8 @@
 # under the License.
 """Image operators."""
 
+from typing import cast
+
 from tvm import DataType
 from tvm.ir.expr import PrimExpr
 
@@ -23,6 +25,7 @@ from ...expr import Expr, ShapeExpr
 from . import _ffi_api
 
 PrimExprLike = int | PrimExpr
+SizeLike = PrimExprLike | tuple[PrimExprLike, ...]
 
 
 def resize2d(
@@ -229,3 +232,41 @@ def grid_sample(
         padding_mode,
         align_corners,
     )
+
+
+def affine_grid(
+    data: Expr,
+    size: Expr | SizeLike,
+) -> Expr:
+    """Generate a 2D sampling grid using an affine transformation matrix.
+
+    This operation is described in https://arxiv.org/pdf/1506.02025.pdf.
+    It generates a uniform sampling grid within the target shape, normalizes it
+    to [-1, 1], and applies the provided affine transformation.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input affine matrix tensor with shape [batch, 2, 3].
+
+    size : Union[Expr, PrimExprLike, Tuple[PrimExprLike, PrimExprLike]]
+        The target output spatial shape (H, W). If a single integer or PrimExpr
+        is provided, it is interpreted as a square output shape (size, size).
+
+    Returns
+    -------
+    result : relax.Expr
+        The output grid tensor with shape [batch, 2, H, W].
+
+    Note
+    ----
+    Only `align_corners=True` is supported by this operator, matching the
+    behavior of the underlying TOPI implementation. When using this operator
+    via PyTorch or ONNX frontends, `align_corners=False` will be rejected.
+    """
+    if isinstance(size, int | PrimExpr):
+        size = (size, size)
+    if isinstance(size, tuple | list):
+        size = ShapeExpr(size)
+
+    return cast(Expr, _ffi_api.affine_grid(data, size))
diff --git a/python/tvm/relax/transform/legalize_ops/image.py 
b/python/tvm/relax/transform/legalize_ops/image.py
index 1e7aaebceb..19431a2731 100644
--- a/python/tvm/relax/transform/legalize_ops/image.py
+++ b/python/tvm/relax/transform/legalize_ops/image.py
@@ -17,7 +17,7 @@
 # pylint: disable=invalid-name
 """Default legalization function for image operators."""
 
-from tvm import topi
+from tvm import tir, topi
 
 from ...block_builder import BlockBuilder
 from ...expr import Call, Expr
@@ -54,6 +54,22 @@ def _image_grid_sample(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.image.affine_grid")
+def _image_affine_grid(bb: BlockBuilder, call: Call) -> Expr:
+    for v in call.args[1].values:
+        if not isinstance(v, (int, tir.IntImm)):
+            raise ValueError(
+                "affine_grid legalization requires static target_shape, "
+                f"got symbolic value: {v}"
+            )
+    target_shape = [int(v) for v in call.args[1].values]
+    return bb.call_te(
+        topi.image.affine_grid,
+        call.args[0],
+        target_shape=target_shape,
+    )
+
+
 @register_legalize("relax.image.resize3d")
 def _image_resize3d(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index fe9df47dd5..ba7de8115e 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -340,5 +340,99 @@ TVM_REGISTER_OP("relax.image.grid_sample")
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", 
MixedPrecisionPolicyKind::kFollow)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.image.affine_grid */
+
+Expr affine_grid(Expr data, Expr size) {
+  static const Op& op = Op::Get("relax.image.affine_grid");
+  return Call(op, {std::move(data), std::move(size)}, Attrs(), {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.image.affine_grid", affine_grid);
+}
+
+StructInfo InferStructInfoAffineGrid(const Call& call, const BlockBuilder& 
ctx) {
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "AffineGrid expects two arguments, while the given 
number of arguments is "
+                     << call->args.size());
+  }
+
+  const auto* data_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* size_sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[1]);
+  const auto* size_value = call->args[1].as<ShapeExprNode>();
+
+  if (data_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "AffineGrid expects the input data to be a Tensor, while the given 
data is "
+        << call->args[0]->GetTypeKey());
+  }
+  if (size_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "AffineGrid expects the target size to be a Shape, while the given 
one is "
+        << call->args[1]->GetTypeKey());
+  }
+  if (size_sinfo->ndim != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "AffineGrid expects the target size to be a 2-dim 
shape, while the given "
+                        "one has ndim "
+                     << size_sinfo->ndim);
+  }
+
+  // data should be 3-D: [batch, 2, 3]
+  if (data_sinfo->ndim != -1 && data_sinfo->ndim != 3) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "AffineGrid expects the input data to be 3-D (batch, 
2, 3), but got ndim "
+                     << data_sinfo->ndim);
+  }
+
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  if (data_shape != nullptr) {
+    // Check that the affine matrix has shape [batch, 2, 3]
+    if (data_shape->values.size() >= 2) {
+      auto* dim1 = data_shape->values[1].as<IntImmNode>();
+      if (dim1 != nullptr && dim1->value != 2) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "AffineGrid expects the second dimension of input 
to be 2, but got "
+                         << dim1->value);
+      }
+    }
+    if (data_shape->values.size() >= 3) {
+      auto* dim2 = data_shape->values[2].as<IntImmNode>();
+      if (dim2 != nullptr && dim2->value != 3) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << "AffineGrid expects the third dimension of input 
to be 3, but got "
+                         << dim2->value);
+      }
+    }
+  }
+
+  DataType out_dtype = data_sinfo->dtype;
+
+  if (data_shape == nullptr || size_value == nullptr) {
+    return TensorStructInfo(out_dtype, /*ndim=*/4, data_sinfo->vdevice);
+  }
+
+  // Output shape: [batch, 2, target_height, target_width]
+  ffi::Array<PrimExpr> out_shape;
+  out_shape.push_back(data_shape->values[0]);  // batch
+  out_shape.push_back(IntImm(DataType::Int(64), 2));  // 2 (spatial dimensions)
+  out_shape.push_back(size_value->values[0]);  // target_height
+  out_shape.push_back(size_value->values[1]);  // target_width
+
+  return TensorStructInfo(ShapeExpr(out_shape), out_dtype, 
data_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.image.affine_grid")
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input affine matrix tensor.")
+    .add_argument("size", "Shape", "The target output shape (H, W).")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAffineGrid)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", 
MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/image/resize.h b/src/relax/op/image/resize.h
index c769cf91f5..06a927d3a7 100644
--- a/src/relax/op/image/resize.h
+++ b/src/relax/op/image/resize.h
@@ -48,6 +48,9 @@ Expr resize3d(Expr data, Expr size, ffi::Array<FloatImm> roi, 
ffi::String layout
 Expr grid_sample(Expr data, Expr grid, ffi::String method, ffi::String layout,
                  ffi::String padding_mode, bool align_corners);
 
+/*! \brief Image affine_grid operator. */
+Expr affine_grid(Expr data, Expr size);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py 
b/tests/python/relax/test_frontend_from_exported_program.py
index 7a3548b4cf..6029499372 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -9095,5 +9095,62 @@ def test_cond_nested():
     )
 
 
+def test_affine_grid():
+    class AffineGrid(Module):
+        def forward(self, theta):
+            return torch.nn.functional.affine_grid(
+                theta, [1, 3, 16, 16], align_corners=True
+            )
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            theta: R.Tensor((1, 2, 3), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((1, 16, 16, 2), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 2, 16, 16), dtype="float32") = 
R.image.affine_grid(
+                    theta, size=(16, 16)
+                )
+                lv1: R.Tensor((1, 16, 16, 2), dtype="float32") = 
R.permute_dims(
+                    lv, axes=[0, 2, 3, 1]
+                )
+                gv: R.Tuple(R.Tensor((1, 16, 16, 2), dtype="float32")) = (lv1,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 2, 3, dtype=torch.float32),)
+    # Disable decomposition to keep aten.affine_grid_generator as a single op
+    verify_model(AffineGrid(), example_args, {}, expected, 
run_ep_decomposition=False)
+
+
+def test_affine_grid_numerically():
+    """Verify affine_grid numerical correctness: PyTorch vs TVM via our 
converter."""
+
+    class AffineGrid(Module):
+        def forward(self, theta):
+            return torch.nn.functional.affine_grid(
+                theta, [2, 3, 8, 12], align_corners=True
+            )
+
+    model = AffineGrid()
+    example_args = (torch.randn(2, 2, 3, dtype=torch.float32),)
+
+    with torch.no_grad():
+        pytorch_output = model(*example_args)
+
+    exported_program = export(model, args=example_args)
+    mod = from_exported_program(exported_program, run_ep_decomposition=False)
+
+    exe = tvm.compile(mod, target="llvm")
+    vm = relax.VirtualMachine(exe, tvm.cpu())
+
+    tvm_args = [tvm.runtime.tensor(arg.numpy()) for arg in example_args]
+    tvm_output = vm["main"](*tvm_args)
+    tvm_output_np = tvm_output[0].numpy()
+
+    tvm.testing.assert_allclose(tvm_output_np, pytorch_output.numpy(), 
rtol=1e-5, atol=1e-5)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index 86fa533874..887533f261 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4011,6 +4011,32 @@ def test_nms_score_threshold():
         )
 
 
+def test_affine_grid():
+    affine_grid_node = helper.make_node(
+        "AffineGrid",
+        ["theta", "size"],
+        ["grid"],
+        align_corners=1,
+    )
+
+    graph = helper.make_graph(
+        [affine_grid_node],
+        "affine_grid_test",
+        inputs=[
+            helper.make_tensor_value_info("theta", TensorProto.FLOAT, [2, 2, 
3]),
+        ],
+        initializer=[
+            helper.make_tensor("size", TensorProto.INT64, [4], [2, 3, 16, 16]),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("grid", TensorProto.FLOAT, [2, 16, 
16, 2]),
+        ],
+    )
+
+    model = helper.make_model(graph, producer_name="affine_grid_test")
+    check_correctness(model, opset=20)
+
+
 @pytest.mark.parametrize("mode", ["bilinear", "nearest", "bicubic"])
 @pytest.mark.parametrize("padding_mode", ["zeros", "border", "reflection"])
 @pytest.mark.parametrize("align_corners", [0, 1])
diff --git a/tests/python/relax/test_op_image.py 
b/tests/python/relax/test_op_image.py
index 6650fc359b..3009b9414a 100644
--- a/tests/python/relax/test_op_image.py
+++ b/tests/python/relax/test_op_image.py
@@ -14,10 +14,12 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import numpy as np
 import pytest
 
 import tvm
 import tvm.testing
+import tvm.topi.testing
 from tvm import TVMError, relax, tir
 from tvm.ir import Op, VDevice
 from tvm.script import relax as R
@@ -26,6 +28,8 @@ from tvm.script import relax as R
 def test_op_correctness():
     x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
     assert relax.op.image.resize2d(x, (28, 28)).op == 
Op.get("relax.image.resize2d")
+    theta = relax.Var("theta", R.Tensor((2, 2, 3), "float32"))
+    assert relax.op.image.affine_grid(theta, (16, 16)).op == 
Op.get("relax.image.affine_grid")
     y = relax.Var("y", R.Tensor((2, 3, 8, 16, 32), "float32"))
     assert relax.op.image.resize3d(y, (4, 8, 12)).op == 
Op.get("relax.image.resize3d")
 
@@ -356,5 +360,132 @@ def test_resize2d_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.image.resize2d(x2, s0))
 
 
+def test_affine_grid_infer_struct_info():
+    bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
+    x0 = relax.Var("x", R.Tensor((2, 2, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 2, 3), "float32", vdev0))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x3 = relax.Var("x", R.Tensor("float32"))
+    x4 = relax.Var("x", R.Tensor(ndim=3))
+
+    _check_inference(
+        bb,
+        relax.op.image.affine_grid(x0, (16, 16)),
+        relax.TensorStructInfo((2, 2, 16, 16), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.image.affine_grid(x1, (16, 16)),
+        relax.TensorStructInfo((2, 2, 16, 16), "float32", vdev0),
+    )
+    _check_inference(
+        bb,
+        relax.op.image.affine_grid(x0, size=16),
+        relax.TensorStructInfo((2, 2, 16, 16), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.image.affine_grid(x0, size=(16, 20)),
+        relax.TensorStructInfo((2, 2, 16, 20), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.image.affine_grid(x2, size=(16, 16)),
+        relax.TensorStructInfo(dtype="float32", ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.image.affine_grid(x3, size=(16, 16)),
+        relax.TensorStructInfo(dtype="float32", ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.image.affine_grid(x4, size=(16, 16)),
+        relax.TensorStructInfo(dtype="", ndim=4),
+    )
+
+
+def test_affine_grid_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    oh = tir.Var("oh", "int64")
+    ow = tir.Var("ow", "int64")
+    x0 = relax.Var("x", R.Tensor((n, 2, 3), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.image.affine_grid(x0, size=(oh, ow)),
+        relax.TensorStructInfo((n, 2, oh, ow), "float32"),
+    )
+
+
+def test_affine_grid_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 2, 3)))
+    x1 = relax.Var("x", R.Tensor((2, 2, 3), "float32"))
+    s0 = relax.Var("s", R.Tensor((3, 3)))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.affine_grid(x0, size=(16, 16)))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.affine_grid(x1, s0))
+
+
+def test_affine_grid_wrong_input_ndim():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.affine_grid(x0, size=(16, 16)))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.affine_grid(x1, size=(16, 16)))
+
+
+def test_affine_grid_wrong_size_ndim():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 2, 3), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.affine_grid(x0, (16, 16, 16)))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.image.affine_grid(x0, (16,)))
+
+
[email protected](
+    "batch, target_h, target_w",
+    [
+        (1, 16, 16),
+        (2, 8, 12),
+        (4, 32, 32),
+    ],
+)
+def test_affine_grid_e2e(batch, target_h, target_w):
+    """End-to-end numerical correctness test: build, run, compare with numpy 
reference."""
+
+    @tvm.script.ir_module
+    class AffineGridModule:
+        @R.function
+        def main(theta: R.Tensor(("batch", 2, 3), "float32")) -> 
R.Tensor("float32", ndim=4):
+            gv = R.image.affine_grid(theta, size=(target_h, target_w))
+            return gv
+
+    target = "llvm"
+    dev = tvm.cpu()
+    exe = tvm.compile(AffineGridModule, target=target)
+    vm = relax.VirtualMachine(exe, dev)
+
+    theta_np = np.random.uniform(-1, 1, size=(batch, 2, 3)).astype("float32")
+    theta_nd = tvm.runtime.tensor(theta_np, dev)
+
+    out_nd = vm["main"](theta_nd)
+    out_np = out_nd.numpy()
+
+    ref_np = tvm.topi.testing.affine_grid_python(theta_np, (target_h, 
target_w))
+
+    tvm.testing.assert_allclose(out_np, ref_np, rtol=1e-5, atol=1e-5)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_image.py 
b/tests/python/relax/test_transform_legalize_ops_image.py
index 48166d24c4..5c80ce0375 100644
--- a/tests/python/relax/test_transform_legalize_ops_image.py
+++ b/tests/python/relax/test_transform_legalize_ops_image.py
@@ -102,6 +102,42 @@ def test_image_resize2d_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_image_affine_grid():
+    # fmt: off
+    @tvm.script.ir_module
+    class AffineGrid:
+        @R.function
+        def main(theta: R.Tensor((2, 2, 3), "float32")) -> R.Tensor((2, 2, 16, 
16), "float32"):
+            gv: R.Tensor((2, 2, 16, 16), "float32") = 
R.image.affine_grid(theta, size=(16, 16))
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(theta: R.Tensor((2, 2, 3), "float32")) -> R.Tensor((2, 2, 16, 
16), "float32"):
+            gv = R.call_tir(Expected.affine_grid, (theta,), R.Tensor((2, 2, 
16, 16), dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def affine_grid(var_theta: T.handle, var_compute: T.handle):
+            T.func_attr({"tir.noalias": True})
+            theta = T.match_buffer(var_theta, (T.int64(2), T.int64(2), 
T.int64(3)))
+            compute = T.match_buffer(var_compute, (T.int64(2), T.int64(2), 
T.int64(16), T.int64(16)))
+            with T.block("root"):
+                T.reads()
+                T.writes()
+                for n, dim, i, j in T.grid(T.int64(2), T.int64(2), 
T.int64(16), T.int64(16)):
+                    with T.block("compute"):
+                        v_n, v_dim, v_i, v_j = T.axis.remap("SSSS", [n, dim, 
i, j])
+                        T.reads(theta[v_n, v_dim, T.int64(0):T.int64(3)])
+                        T.writes(compute[v_n, v_dim, v_i, v_j])
+                        compute[v_n, v_dim, v_i, v_j] = theta[v_n, v_dim, 
T.int64(0)] * (T.float32(-1.0) + T.Cast("float32", v_j) * 
T.float32(0.13333332666666667)) + theta[v_n, v_dim, T.int64(1)] * 
(T.float32(-1.0) + T.Cast("float32", v_i) * T.float32(0.13333332666666667)) + 
theta[v_n, v_dim, T.int64(2)]
+    # fmt: on
+
+    mod = LegalizeOps()(AffineGrid)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_image_resize3d():
     # fmt: off
     @tvm.script.ir_module

Reply via email to