This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c581abeee4 [Relax][PyTorch] Fix _slice and _expand for dynamic shapes in PyTorch ExportedProgram frontend (#18918)
c581abeee4 is described below
commit c581abeee410cc5426a667f64b3ac27b59ace2af
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Fri Mar 20 15:18:52 2026 +0900
[Relax][PyTorch] Fix _slice and _expand for dynamic shapes in PyTorch ExportedProgram frontend (#18918)
Fixes two issues when translating PyTorch models with dynamic shapes:
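1. **_slice**: Resolve `fx.Node` references in the start/end/step arguments
(e.g. symbolic sizes from dynamic shapes) to their translated values, and
detect identity slices (start 0, step 1) whose symbolic end equals the size
of the sliced dimension, so that no redundant `strided_slice` op is emitted.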
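2. **_expand**: When a size argument is `-1` and `shape_of()` returns `None`
for an input with unknown shape, fall back to the shape recorded in the FX
node metadata (`meta["val"]`, a FakeTensor); raise a `ValueError` if that
metadata is also unavailable.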
---
.../frontend/torch/base_fx_graph_translator.py | 17 +++++-
.../frontend/torch/exported_program_translator.py | 21 ++++++-
.../relax/test_frontend_from_exported_program.py | 67 ++++++++++++++++++++++
3 files changed, 101 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 803b4b7e11..c146cf6c00 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1754,13 +1754,24 @@ class BaseFXGraphImporter(metaclass=abc.ABCMeta):
def _expand(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
sizes = args[1:] if len(args) > 2 else args[1]
- broadcast_shape, in_shape = [], self.shape_of(args[0])
+ x = args[0]
+ broadcast_shape = []
+ in_shape = self.shape_of(x)
for idx, i in enumerate(sizes):
if isinstance(i, int) and i == -1:
- broadcast_shape.append(in_shape[idx])
+ if in_shape is not None:
+ broadcast_shape.append(in_shape[idx])
+ elif hasattr(node.args[0], "meta") and "val" in
node.args[0].meta:
+ # Fallback: get shape from FX node metadata (FakeTensor)
+ fake_shape = node.args[0].meta["val"].shape
+ broadcast_shape.append(fake_shape[idx])
+ else:
+ raise ValueError(
+ f"Cannot use -1 in expand for dim {idx} when input
shape is unknown"
+ )
else:
broadcast_shape.append(i)
- return self.block_builder.emit(relax.op.broadcast_to(args[0],
broadcast_shape))
+ return self.block_builder.emit(relax.op.broadcast_to(x,
broadcast_shape))
def _expand_as(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py
b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 2487b904c6..fd03f67332 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -29,7 +29,7 @@ import torch
from torch import fx
import tvm
-from tvm import relax
+from tvm import relax, tir
from .base_fx_graph_translator import BaseFXGraphImporter
@@ -937,6 +937,14 @@ class ExportedProgramImporter(BaseFXGraphImporter):
end_val = node.args[3] if len(node.args) > 3 else None
step = node.args[4] if len(node.args) > 4 else 1
+ # Resolve fx.Node references (e.g. symbolic sizes from dynamic shapes)
+ if isinstance(start, fx.Node):
+ start = self.env[start]
+ if isinstance(end_val, fx.Node):
+ end_val = self.env[end_val]
+ if isinstance(step, fx.Node):
+ step = self.env[step]
+
if start is None:
start = 0
if end_val is None:
@@ -956,6 +964,17 @@ class ExportedProgramImporter(BaseFXGraphImporter):
):
return x
+ # Skip identity slice where end_val is a symbolic expression equal to
the
+ # tensor's own dimension size (common with dynamic shapes).
+ if isinstance(start, int) and start == 0 and isinstance(step, int) and
step == 1:
+ in_shape = self.shape_of(x)
+ if in_shape is not None and isinstance(end_val, tir.PrimExpr):
+ actual_dim = dim if dim >= 0 else len(in_shape) + dim
+ dim_expr = in_shape[actual_dim]
+ if isinstance(dim_expr, tir.PrimExpr):
+ if tir.analysis.expr_deep_equal(end_val, dim_expr):
+ return x
+
axes = [dim]
begin = [start]
end = [end_val]
diff --git a/tests/python/relax/test_frontend_from_exported_program.py
b/tests/python/relax/test_frontend_from_exported_program.py
index a9cea19fdc..e1cadb9d02 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -5634,6 +5634,73 @@ def test_slice_scatter():
verify_model(SliceScatterNegative(), example_args, {},
expected_slice_scatter)
+def test_slice_with_symbolic_end():
+ """_slice correctly handles symbolic end values from dynamic shapes."""
+
+ class SliceIdentityModel(torch.nn.Module):
+ def forward(self, x):
+ # x[:, :x.size(1)] is an identity slice that torch.export emits
+ # as slice(x, 1, 0, sym_size_int(x, 1), 1) with dynamic shapes.
+ seq_len = x.size(1)
+ return x[:, :seq_len] + 0.0 # +0.0 to ensure output is a new
tensor
+
+ # The identity slice is elided; only x + 0.0 remains.
+ @I.ir_module
+ class ExpectedIdentity:
+ @R.function
+ def main(x: R.Tensor(("s0", "s1", 4), dtype="float32")) -> R.Tuple(
+ R.Tensor(("s0", "s1", 4), dtype="float32")
+ ):
+ s0 = T.int64(is_size_var=True)
+ s1 = T.int64(is_size_var=True)
+ R.func_attr({"tir_var_lower_bound": {"s27": 2, "s77": 2}})
+ with R.dataflow():
+ lv: R.Tensor((s0, s1, 4), dtype="float32") = R.add(x,
R.const(0.0, "float32"))
+ gv: R.Tuple(R.Tensor((s0, s1, 4), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args = (torch.randn(2, 8, 4, dtype=torch.float32),)
+ batch = torch.export.Dim("batch", min=2)
+ seq = torch.export.Dim("seq", min=2)
+ dynamic_shapes = {"x": {0: batch, 1: seq}}
+
+ verify_model(
+ SliceIdentityModel(),
+ example_args,
+ {},
+ ExpectedIdentity,
+ dynamic_shapes=dynamic_shapes,
+ map_free_vars=True,
+ )
+
+ class SliceStaticModel(torch.nn.Module):
+ def forward(self, x):
+ # A non-identity static slice
+ return x[:, :3]
+
+ @tvm.script.ir_module
+ class ExpectedStatic:
+ @R.function
+ def main(x: R.Tensor((2, 8, 4), dtype="float32")) -> R.Tuple(
+ R.Tensor((2, 3, 4), dtype="float32")
+ ):
+ with R.dataflow():
+ lv: R.Tensor((2, 3, 4), dtype="float32") = R.strided_slice(
+ x,
+ axes=[1],
+ begin=[0],
+ end=[3],
+ strides=[1],
+ )
+ gv: R.Tuple(R.Tensor((2, 3, 4), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ example_args_static = (torch.randn(2, 8, 4, dtype=torch.float32),)
+ verify_model(SliceStaticModel(), example_args_static, {}, ExpectedStatic)
+
+
def test_split():
class Chunk(Module):
def forward(self, input):