This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bf6ed31302 [Relax][PyTorch] Add 3D interpolate support using resize3d (#18937)
bf6ed31302 is described below

commit bf6ed31302486fe189c9872642c4a5dbd5c7988f
Author: Nirdesh Devadiya <[email protected]>
AuthorDate: Thu Mar 26 20:14:27 2026 +0530

    [Relax][PyTorch] Add 3D interpolate support using resize3d (#18937)
    
    Adds support for 3D torch.nn.functional.interpolate (5D inputs, e.g.
    mode="trilinear") in the Relax PyTorch FX frontend.
    
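    - Handles 5D inputs (NCDHW, or NDHWC when default_image_layout="NDHWC")
    - Maps them to relax.op.image.resize3d; "trilinear" becomes method "linear"
    - Selects the resize3d layout (NCDHW or NDHWC) from default_image_layout
    - Adds tests for scale_factor and size inputs (NCDHW) and scale_factor
      inputs (NDHWC), with align_corners both True and False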
    
    All tests pass locally.
    
    part of #18928
    
    Signed-off-by: nirdesh17 <[email protected]>
---
 python/tvm/relax/frontend/torch/fx_translator.py |  40 ++--
 tests/python/relax/test_frontend_from_fx.py      | 222 +++++++++++++++++++++++
 2 files changed, 251 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relax/frontend/torch/fx_translator.py 
b/python/tvm/relax/frontend/torch/fx_translator.py
index e7fcb0c202..c81768f6d9 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -527,7 +527,7 @@ class TorchFXImporter(BaseFXGraphImporter):
             # Determine spatial dimension indices based on layout
             # NCHW: spatial dims are [2, 3, ...] (skip batch and channel)
             # NHWC: spatial dims are [1, 2, ...] (skip batch, before channel)
-            if self.default_image_layout == "NHWC":
+            if self.default_image_layout in ("NHWC", "NDHWC"):
                 spatial_start = 1
                 spatial_end = len(shape) - 1
             else:  # NCHW or other layouts
@@ -547,25 +547,43 @@ class TorchFXImporter(BaseFXGraphImporter):
 
         if method.startswith("nearest"):
             method = "nearest_neighbor"
-        elif method[0:2] == "bi":
+        elif method.startswith("bi"):
             method = method[2:]
+        elif method.startswith("tri"):
+            method = method[3:]
 
         if method == "nearest_neighbor":
             coord_trans = "asymmetric"
-        elif align_corners:
+        elif align_corners is True:
             coord_trans = "align_corners"
         else:
             coord_trans = "half_pixel"
 
-        return self.block_builder.emit(
-            relax.op.image.resize2d(
-                data,
-                size,
-                layout=self.default_image_layout,
-                method=method,
-                coordinate_transformation_mode=coord_trans,
+        if data.struct_info.ndim == 5:
+            if self.default_image_layout == "NDHWC":
+                layout_3d = "NDHWC"
+            else:
+                layout_3d = "NCDHW"
+                
+            return self.block_builder.emit(
+                relax.op.image.resize3d(
+                    data,
+                    size,
+                    layout=layout_3d,
+                    method=method,
+                    coordinate_transformation_mode=coord_trans,
+                )
+            )
+        else:
+            return self.block_builder.emit(
+                relax.op.image.resize2d(
+                    data,
+                    size,
+                    layout=self.default_image_layout,
+                    method=method,
+                    coordinate_transformation_mode=coord_trans,
+                )
             )
-        )
 
     def _linear_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
diff --git a/tests/python/relax/test_frontend_from_fx.py 
b/tests/python/relax/test_frontend_from_fx.py
index 102b8a0f42..4d9060bf72 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3671,6 +3671,149 @@ def test_interpolate():
 
     verify_model(Interpolate4(), input_info, {}, expected4)
 
+    input_info_5d = [([1, 3, 4, 10, 10], "float32")]
+    class Interpolate5(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input,
+                size=None,
+                scale_factor=(2.0, 2.0, 2.0),
+                mode="trilinear",
+                align_corners=False,
+            )
+    @tvm.script.ir_module
+    class expected5:
+        @R.function
+        def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> 
R.Tensor(
+            (1, 3, 8, 20, 20), dtype="float32"
+        ):
+
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 8, 20, 20), dtype="float32") = 
R.image.resize3d(
+                    input_5,
+                    (8, 20, 20),
+                    roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 
0.000000],
+                    layout="NCDHW",
+                    method="linear",
+                    coordinate_transformation_mode="half_pixel",
+                    rounding_method="",
+                    cubic_alpha=-0.75,
+                    cubic_exclude=0,
+                    extrapolation_value=0,
+                    out_dtype="",
+                )
+                gv: R.Tensor((1, 3, 8, 20, 20), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Interpolate5(), input_info_5d, {}, expected5)
+
+    class Interpolate6(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input,
+                size=None,
+                scale_factor=(2.0,4.0,4.0),
+                mode="trilinear",
+                align_corners=False,
+            )
+    @tvm.script.ir_module
+    class expected6:
+        @R.function
+        def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> 
R.Tensor(
+            (1, 3, 8, 40, 40), dtype="float32"
+        ):
+
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = 
R.image.resize3d(
+                    input_5,
+                    (8, 40, 40),
+                    roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 
0.000000],
+                    layout="NCDHW",
+                    method="linear",
+                    coordinate_transformation_mode="half_pixel",
+                    rounding_method="",
+                    cubic_alpha=-0.75,
+                    cubic_exclude=0,
+                    extrapolation_value=0,
+                    out_dtype="",
+                )
+                gv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Interpolate6(), input_info_5d, {}, expected6)
+
+    class Interpolate7(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input,
+                size=(8,40,40),
+                mode="trilinear",
+                align_corners=False,
+            )
+    @tvm.script.ir_module
+    class expected7:
+        @R.function
+        def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> 
R.Tensor(
+            (1, 3, 8, 40, 40), dtype="float32"
+        ):
+
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = 
R.image.resize3d(
+                    input_5,
+                    (8, 40, 40),
+                    roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 
0.000000],
+                    layout="NCDHW",
+                    method="linear",
+                    coordinate_transformation_mode="half_pixel",
+                    rounding_method="",
+                    cubic_alpha=-0.75,
+                    cubic_exclude=0,
+                    extrapolation_value=0,
+                    out_dtype="",
+                )
+                gv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Interpolate7(), input_info_5d, {}, expected7)
+
+    class Interpolate8(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input,
+                size=(8,40,40),
+                mode="trilinear",
+                align_corners=True,
+            )
+    @tvm.script.ir_module
+    class expected8:
+        @R.function
+        def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) -> 
R.Tensor(
+            (1, 3, 8, 40, 40), dtype="float32"
+        ):
+
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = 
R.image.resize3d(
+                    input_5,
+                    (8, 40, 40),
+                    roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 
0.000000],
+                    layout="NCDHW",
+                    method="linear",
+                    coordinate_transformation_mode="align_corners",
+                    rounding_method="",
+                    cubic_alpha=-0.75,
+                    cubic_exclude=0,
+                    extrapolation_value=0,
+                    out_dtype="",
+                )
+                gv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    verify_model(Interpolate8(), input_info_5d, {}, expected8)
+
 
 def test_interpolate_nhwc_layout():
     # First verify backward compatibility - default should still be NCHW
@@ -3786,6 +3929,85 @@ def test_interpolate_nhwc_layout():
         mod2 = from_fx(graph_model2, input_info, default_image_layout="NHWC")
     tvm.ir.assert_structural_equal(mod2, expected_nhwc2)
 
+    input_info_5d = [([1, 4, 10, 10, 3], "float32")]
+
+    class InterpolateNHWC3(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input,
+                size=None,
+                scale_factor=(2.0,4.0,4.0),
+                mode="trilinear",
+                align_corners=False,
+            )
+    @tvm.script.ir_module
+    class expected_nhwc3:
+        @R.function
+        def main(input_5: R.Tensor((1, 4, 10, 10, 3), dtype="float32")) -> 
R.Tensor(
+            (1, 8, 40, 40, 3), dtype="float32"
+        ):
+
+            with R.dataflow():
+                lv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") = 
R.image.resize3d(
+                    input_5,
+                    (8, 40, 40),
+                    roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 
0.000000],
+                    layout="NDHWC",
+                    method="linear",
+                    coordinate_transformation_mode="half_pixel",
+                    rounding_method="",
+                    cubic_alpha=-0.75,
+                    cubic_exclude=0,
+                    extrapolation_value=0,
+                    out_dtype="",
+                )
+                gv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    graph_model3 = fx.symbolic_trace(InterpolateNHWC3())
+    with torch.no_grad():
+        mod3 = from_fx(graph_model3, input_info_5d, 
default_image_layout="NDHWC")
+    tvm.ir.assert_structural_equal(mod3, expected_nhwc3)
+
+    class InterpolateNHWC4(Module):
+        def forward(self, input):
+            return torch.nn.functional.interpolate(
+                input,
+                size=None,
+                scale_factor=(2.0,4.0,4.0),
+                mode="trilinear",
+                align_corners=True,
+            )
+    @tvm.script.ir_module
+    class expected_nhwc4:
+        @R.function
+        def main(input_5: R.Tensor((1, 4, 10, 10, 3), dtype="float32")) -> 
R.Tensor(
+            (1, 8, 40, 40, 3), dtype="float32"
+        ):
+
+            with R.dataflow():
+                lv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") = 
R.image.resize3d(
+                    input_5,
+                    (8, 40, 40),
+                    roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 
0.000000],
+                    layout="NDHWC",
+                    method="linear",
+                    coordinate_transformation_mode="align_corners",
+                    rounding_method="",
+                    cubic_alpha=-0.75,
+                    cubic_exclude=0,
+                    extrapolation_value=0,
+                    out_dtype="",
+                )
+                gv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    graph_model4 = fx.symbolic_trace(InterpolateNHWC4())
+    with torch.no_grad():
+        mod4 = from_fx(graph_model4, input_info_5d, 
default_image_layout="NDHWC")
+    tvm.ir.assert_structural_equal(mod4, expected_nhwc4)
 
 def test_addmm():
     input_info = [

Reply via email to