This is an automated email from the ASF dual-hosted git repository.
guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 00eb226f8d [Refactor] Update type references from tir to tirx in PyTorch ExportedProgram frontend (#18920)
00eb226f8d is described below
commit 00eb226f8d6f569f93a5770510af23ab2159a850
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Fri Mar 20 21:04:46 2026 +0900
[Refactor] Update type references from tir to tirx in PyTorch ExportedProgram frontend (#18920)
Follow-up to #18913 and #18917.
---
python/tvm/relax/frontend/torch/exported_program_translator.py | 8 ++++----
.../relax/test_transform_legalize_ops_index_linear_algebra.py | 2 +-
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index fd03f67332..67e0e45da0 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -29,7 +29,7 @@ import torch
from torch import fx
import tvm
-from tvm import relax, tir
+from tvm import relax
from .base_fx_graph_translator import BaseFXGraphImporter
@@ -968,11 +968,11 @@ class ExportedProgramImporter(BaseFXGraphImporter):
# tensor's own dimension size (common with dynamic shapes).
if isinstance(start, int) and start == 0 and isinstance(step, int) and
step == 1:
in_shape = self.shape_of(x)
- if in_shape is not None and isinstance(end_val, tir.PrimExpr):
+ if in_shape is not None and isinstance(end_val, tvm.tirx.PrimExpr):
actual_dim = dim if dim >= 0 else len(in_shape) + dim
dim_expr = in_shape[actual_dim]
- if isinstance(dim_expr, tir.PrimExpr):
- if tir.analysis.expr_deep_equal(end_val, dim_expr):
+ if isinstance(dim_expr, tvm.tirx.PrimExpr):
+ if tvm.tirx.analysis.expr_deep_equal(end_val, dim_expr):
return x
axes = [dim]
diff --git a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index b8dbe1934b..9f45c7031f 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -277,7 +277,7 @@ def test_strided_slice_negative_axes():
@T.prim_func(private=True)
def strided_slice(rxplaceholder: T.Buffer((T.int64(8), T.int64(9),
T.int64(10)), "float32"), T_strided_slice_with_axes: T.Buffer((T.int64(8),
T.int64(9), T.int64(3)), "float32")):
- T.func_attr({"tir.noalias": True})
+ T.func_attr({"tirx.noalias": True})
for ax0, ax1, ax2 in T.grid(T.int64(8), T.int64(9), T.int64(3)):
with T.sblock("T_strided_slice_with_axes"):
v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])