This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new bf6ed31302 [Relax][PyTorch] Add 3D interpolate support using resize3d
(#18937)
bf6ed31302 is described below
commit bf6ed31302486fe189c9872642c4a5dbd5c7988f
Author: Nirdesh Devadiya <[email protected]>
AuthorDate: Thu Mar 26 20:14:27 2026 +0530
[Relax][PyTorch] Add 3D interpolate support using resize3d (#18937)
Adds support for 3D interpolation via torch.nn.functional.interpolate
(5D inputs, e.g. mode="trilinear") in the Relax PyTorch FX frontend.
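- Handles 5D inputs (NCDHW, or NDHWC when default_image_layout="NDHWC")
- Maps 5D inputs to relax.op.image.resize3d (other inputs still use resize2d) and normalizes "tri*" modes like "bi*" modes (e.g. "trilinear" -> "linear")
- Ensures correct layout handling: passes NCDHW/NDHWC to resize3d and treats NDHWC as channels-last when locating the spatial dimensions
- Adds tests for scale_factor and size cases, with align_corners both True and False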
All tests pass locally.
part of #18928
Signed-off-by: nirdesh17 <[email protected]>
---
python/tvm/relax/frontend/torch/fx_translator.py | 40 ++--
tests/python/relax/test_frontend_from_fx.py | 222 +++++++++++++++++++++++
2 files changed, 251 insertions(+), 11 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py
b/python/tvm/relax/frontend/torch/fx_translator.py
index e7fcb0c202..c81768f6d9 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -527,7 +527,7 @@ class TorchFXImporter(BaseFXGraphImporter):
# Determine spatial dimension indices based on layout
# NCHW: spatial dims are [2, 3, ...] (skip batch and channel)
# NHWC: spatial dims are [1, 2, ...] (skip batch, before channel)
- if self.default_image_layout == "NHWC":
+ if self.default_image_layout in ("NHWC", "NDHWC"):
spatial_start = 1
spatial_end = len(shape) - 1
else: # NCHW or other layouts
@@ -547,25 +547,43 @@ class TorchFXImporter(BaseFXGraphImporter):
if method.startswith("nearest"):
method = "nearest_neighbor"
- elif method[0:2] == "bi":
+ elif method.startswith("bi"):
method = method[2:]
+ elif method.startswith("tri"):
+ method = method[3:]
if method == "nearest_neighbor":
coord_trans = "asymmetric"
- elif align_corners:
+ elif align_corners is True:
coord_trans = "align_corners"
else:
coord_trans = "half_pixel"
- return self.block_builder.emit(
- relax.op.image.resize2d(
- data,
- size,
- layout=self.default_image_layout,
- method=method,
- coordinate_transformation_mode=coord_trans,
+ if data.struct_info.ndim == 5:
+ if self.default_image_layout == "NDHWC":
+ layout_3d = "NDHWC"
+ else:
+ layout_3d = "NCDHW"
+
+ return self.block_builder.emit(
+ relax.op.image.resize3d(
+ data,
+ size,
+ layout=layout_3d,
+ method=method,
+ coordinate_transformation_mode=coord_trans,
+ )
+ )
+ else:
+ return self.block_builder.emit(
+ relax.op.image.resize2d(
+ data,
+ size,
+ layout=self.default_image_layout,
+ method=method,
+ coordinate_transformation_mode=coord_trans,
+ )
)
- )
def _linear_module(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
diff --git a/tests/python/relax/test_frontend_from_fx.py
b/tests/python/relax/test_frontend_from_fx.py
index 102b8a0f42..4d9060bf72 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -3671,6 +3671,149 @@ def test_interpolate():
verify_model(Interpolate4(), input_info, {}, expected4)
+ input_info_5d = [([1, 3, 4, 10, 10], "float32")]
+ class Interpolate5(Module):
+ def forward(self, input):
+ return torch.nn.functional.interpolate(
+ input,
+ size=None,
+ scale_factor=(2.0, 2.0, 2.0),
+ mode="trilinear",
+ align_corners=False,
+ )
+ @tvm.script.ir_module
+ class expected5:
+ @R.function
+ def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) ->
R.Tensor(
+ (1, 3, 8, 20, 20), dtype="float32"
+ ):
+
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 8, 20, 20), dtype="float32") =
R.image.resize3d(
+ input_5,
+ (8, 20, 20),
+ roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
0.000000],
+ layout="NCDHW",
+ method="linear",
+ coordinate_transformation_mode="half_pixel",
+ rounding_method="",
+ cubic_alpha=-0.75,
+ cubic_exclude=0,
+ extrapolation_value=0,
+ out_dtype="",
+ )
+ gv: R.Tensor((1, 3, 8, 20, 20), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Interpolate5(), input_info_5d, {}, expected5)
+
+ class Interpolate6(Module):
+ def forward(self, input):
+ return torch.nn.functional.interpolate(
+ input,
+ size=None,
+ scale_factor=(2.0,4.0,4.0),
+ mode="trilinear",
+ align_corners=False,
+ )
+ @tvm.script.ir_module
+ class expected6:
+ @R.function
+ def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) ->
R.Tensor(
+ (1, 3, 8, 40, 40), dtype="float32"
+ ):
+
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") =
R.image.resize3d(
+ input_5,
+ (8, 40, 40),
+ roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
0.000000],
+ layout="NCDHW",
+ method="linear",
+ coordinate_transformation_mode="half_pixel",
+ rounding_method="",
+ cubic_alpha=-0.75,
+ cubic_exclude=0,
+ extrapolation_value=0,
+ out_dtype="",
+ )
+ gv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Interpolate6(), input_info_5d, {}, expected6)
+
+ class Interpolate7(Module):
+ def forward(self, input):
+ return torch.nn.functional.interpolate(
+ input,
+ size=(8,40,40),
+ mode="trilinear",
+ align_corners=False,
+ )
+ @tvm.script.ir_module
+ class expected7:
+ @R.function
+ def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) ->
R.Tensor(
+ (1, 3, 8, 40, 40), dtype="float32"
+ ):
+
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") =
R.image.resize3d(
+ input_5,
+ (8, 40, 40),
+ roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
0.000000],
+ layout="NCDHW",
+ method="linear",
+ coordinate_transformation_mode="half_pixel",
+ rounding_method="",
+ cubic_alpha=-0.75,
+ cubic_exclude=0,
+ extrapolation_value=0,
+ out_dtype="",
+ )
+ gv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Interpolate7(), input_info_5d, {}, expected7)
+
+ class Interpolate8(Module):
+ def forward(self, input):
+ return torch.nn.functional.interpolate(
+ input,
+ size=(8,40,40),
+ mode="trilinear",
+ align_corners=True,
+ )
+ @tvm.script.ir_module
+ class expected8:
+ @R.function
+ def main(input_5: R.Tensor((1, 3, 4, 10, 10), dtype="float32")) ->
R.Tensor(
+ (1, 3, 8, 40, 40), dtype="float32"
+ ):
+
+ with R.dataflow():
+ lv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") =
R.image.resize3d(
+ input_5,
+ (8, 40, 40),
+ roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
0.000000],
+ layout="NCDHW",
+ method="linear",
+ coordinate_transformation_mode="align_corners",
+ rounding_method="",
+ cubic_alpha=-0.75,
+ cubic_exclude=0,
+ extrapolation_value=0,
+ out_dtype="",
+ )
+ gv: R.Tensor((1, 3, 8, 40, 40), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ verify_model(Interpolate8(), input_info_5d, {}, expected8)
+
def test_interpolate_nhwc_layout():
# First verify backward compatibility - default should still be NCHW
@@ -3786,6 +3929,85 @@ def test_interpolate_nhwc_layout():
mod2 = from_fx(graph_model2, input_info, default_image_layout="NHWC")
tvm.ir.assert_structural_equal(mod2, expected_nhwc2)
+ input_info_5d = [([1, 4, 10, 10, 3], "float32")]
+
+ class InterpolateNHWC3(Module):
+ def forward(self, input):
+ return torch.nn.functional.interpolate(
+ input,
+ size=None,
+ scale_factor=(2.0,4.0,4.0),
+ mode="trilinear",
+ align_corners=False,
+ )
+ @tvm.script.ir_module
+ class expected_nhwc3:
+ @R.function
+ def main(input_5: R.Tensor((1, 4, 10, 10, 3), dtype="float32")) ->
R.Tensor(
+ (1, 8, 40, 40, 3), dtype="float32"
+ ):
+
+ with R.dataflow():
+ lv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") =
R.image.resize3d(
+ input_5,
+ (8, 40, 40),
+ roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
0.000000],
+ layout="NDHWC",
+ method="linear",
+ coordinate_transformation_mode="half_pixel",
+ rounding_method="",
+ cubic_alpha=-0.75,
+ cubic_exclude=0,
+ extrapolation_value=0,
+ out_dtype="",
+ )
+ gv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ graph_model3 = fx.symbolic_trace(InterpolateNHWC3())
+ with torch.no_grad():
+ mod3 = from_fx(graph_model3, input_info_5d,
default_image_layout="NDHWC")
+ tvm.ir.assert_structural_equal(mod3, expected_nhwc3)
+
+ class InterpolateNHWC4(Module):
+ def forward(self, input):
+ return torch.nn.functional.interpolate(
+ input,
+ size=None,
+ scale_factor=(2.0,4.0,4.0),
+ mode="trilinear",
+ align_corners=True,
+ )
+ @tvm.script.ir_module
+ class expected_nhwc4:
+ @R.function
+ def main(input_5: R.Tensor((1, 4, 10, 10, 3), dtype="float32")) ->
R.Tensor(
+ (1, 8, 40, 40, 3), dtype="float32"
+ ):
+
+ with R.dataflow():
+ lv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") =
R.image.resize3d(
+ input_5,
+ (8, 40, 40),
+ roi=[0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
0.000000],
+ layout="NDHWC",
+ method="linear",
+ coordinate_transformation_mode="align_corners",
+ rounding_method="",
+ cubic_alpha=-0.75,
+ cubic_exclude=0,
+ extrapolation_value=0,
+ out_dtype="",
+ )
+ gv: R.Tensor((1, 8, 40, 40, 3), dtype="float32") = lv
+ R.output(gv)
+ return gv
+
+ graph_model4 = fx.symbolic_trace(InterpolateNHWC4())
+ with torch.no_grad():
+ mod4 = from_fx(graph_model4, input_info_5d,
default_image_layout="NDHWC")
+ tvm.ir.assert_structural_equal(mod4, expected_nhwc4)
def test_addmm():
input_info = [