tinywisdom opened a new issue, #18599:
URL: https://github.com/apache/tvm/issues/18599
## Summary
Compiling a Relax IRModule converted from a PyTorch `torch.export` program
crashes with a **segmentation fault** inside TVM’s TIR pass pipeline,
specifically in:
* `tvm::tir::transform::InjectPTXLDG32(bool)`
* `tvm::tir::PTXRewriter::VisitStmt_(tvm::tir::BufferStoreNode const*)`
* `tvm::tir::BufferStore::BufferStore(...)`
This occurs while invoking `tvm.compile(...)` with:
* `target = tvm.target.Target("llvm")` (CPU-only)
* `tir_pipeline = tir.get_default_tir_pipeline(target)`
* `relax_pipeline = "default"`
* `PassContext.config` includes `"tir.ptx_ldg32": 1` plus several other flags
Even though the target is LLVM (CPU), the stack trace shows a
**PTX-specific pass and rewriter** running and then segfaulting.
This is not a Python exception but a hard crash
(`Segmentation fault (core dumped)`). It likely points to one of three
causes: a bug in how the pass is gated on the target, a bug in pipeline
selection, or an unsafe assumption inside `InjectPTXLDG32` when it runs
under this pipeline.
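As a possible workaround until the gating question is resolved, a caller could drop the CUDA-oriented options from the config before building the `PassContext` for a non-CUDA target. The sketch below is only an illustration and not TVM API: the helper `strip_gpu_only_options` and the `GPU_ONLY_KEYS` set are hypothetical, and the assumption that these three keys only matter for CUDA targets is inferred from their names and from the stack trace, not confirmed against the TVM source.

```python
# Hypothetical helper, not part of TVM. The key list is an assumption
# based on the option names used in this report.
GPU_ONLY_KEYS = frozenset({
    "tir.ptx_ldg32",
    "tir.use_async_copy",
    "relax.backend.use_cuda_graph",
})


def strip_gpu_only_options(config, target_kind):
    """Return a copy of `config` without CUDA-oriented keys for non-CUDA targets."""
    if target_kind == "cuda":
        return dict(config)
    return {k: v for k, v in config.items() if k not in GPU_ONLY_KEYS}


if __name__ == "__main__":
    cfg = {"tir.ptx_ldg32": 1, "tir.enable_debug": 1}
    print(strip_gpu_only_options(cfg, "llvm"))
```

In the repro script this would be called as `strip_gpu_only_options(pass_config, target.kind.name)` before constructing the `PassContext`. It hides the crash rather than fixing it, so it is only useful for unblocking a CPU build.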
## Environment
From the repro output:
* TVM version: `0.22.0`
* TVM git commit: `9dbf3f22ff6f44962472f9af310fda368ca85ef2`
* LLVM: `17.0.6`
* PyTorch: `2.9.0+cu128`
* Python: `3.10.16` (inferred from stack paths)
* NumPy: `2.2.6` *(printed as “Python version” in script; see note below)*
* OS: Linux x86_64
## Minimal Repro Script
```python
import random

import numpy as np
import torch
import torch.nn as nn
import tvm
from tvm import tir


def print_env_info():
    print("==== Environment Info ====")
    print("TVM version:", getattr(tvm, "__version__", "unknown"))
    print("TVM git commit:", tvm.support.libinfo()["GIT_COMMIT_HASH"])
    print("TVM LLVM version:", tvm.support.libinfo().get("LLVM_VERSION", "unknown"))
    print("NumPy version:", np.__version__)
    print("PyTorch version:", torch.__version__)
    print("==========================\n")


def set_seed(seed: int = 0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(100, 64, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
            nn.Conv2d(3, 8, 3, 1, 1, bias=False),
            nn.BatchNorm2d(8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(8, 1, 1, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        y = self.net(x)
        return y.reshape(-1)


def export_to_relax(mod: nn.Module, x: torch.Tensor) -> tvm.IRModule:
    mod = mod.to("cpu").eval()
    x = x.to("cpu")
    ep = torch.export.export(mod, (x,))
    from tvm.relax.frontend.torch import from_exported_program

    return from_exported_program(ep)


def main():
    print_env_info()
    set_seed(0)

    target = tvm.target.Target("llvm")
    tir_pipeline = tir.get_default_tir_pipeline(target)
    relax_pipeline = "default"

    B = 64
    x = torch.rand(B, 100, 1, 1, dtype=torch.float32)
    model = Model()

    print("[repro] exporting torch -> relax ...")
    ir_mod = export_to_relax(model, x)

    pass_config = {
        "relax.FuseOps.max_depth": 4,
        "relax.backend.use_cuda_graph": 1,
        "tir.disable_storage_rewrite": 1,
        "tir.disable_vectorize": 1,
        "tir.enable_debug": 1,
        "tir.enable_equiv_terms_in_cse_tir": 1,
        "tir.ptx_ldg32": 1,
        "tir.use_async_copy": 1,
    }
    pc_kwargs = {
        "opt_level": 1,
        "disabled_pass": [
            "CanonicalizeBindings",
            "Simplify",
            "VectorizeLoop",
            "RemoveNoOp",
        ],
        "config": pass_config,
    }

    print("[repro] target:", target)
    print("[repro] tir_pipeline: explicit_default")
    print("[repro] PassContext.config keys:", sorted(pass_config.keys()))
    print("[repro] compiling with tvm.compile ...")
    with tvm.transform.PassContext(**pc_kwargs):
        _ = tvm.compile(
            ir_mod,
            target=target,
            relax_pipeline=relax_pipeline,
            tir_pipeline=tir_pipeline,
        )
    print("[repro] compile finished (no crash).")


if __name__ == "__main__":
    main()
```
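Since the config above sets eight flags at once, it may help triage to find which of them are actually needed to trigger the crash. The following is a rough sketch of a greedy, one-pass reduction (simpler than full delta debugging, so the result is not guaranteed to be globally minimal). The function name `minimize_config` and the `still_crashes` callback are hypothetical; in practice the callback would run the compile step in a child process and report whether it segfaulted. The predicate used in the demo is a stand-in for illustration, not a measured result.

```python
def minimize_config(config, still_crashes):
    """Greedily drop keys from `config` while `still_crashes` keeps returning True."""
    current = dict(config)
    for key in sorted(config):
        # Try the config without this key; keep the removal only if the
        # crash still reproduces.
        trial = {k: v for k, v in current.items() if k != key}
        if still_crashes(trial):
            current = trial
    return current


if __name__ == "__main__":
    full = {
        "tir.enable_debug": 1,
        "tir.ptx_ldg32": 1,
        "tir.use_async_copy": 1,
    }

    # Stand-in predicate for illustration only: pretends the crash needs
    # exactly one flag. Replace it with a real check that runs the repro.
    def fake_still_crashes(cfg):
        return "tir.ptx_ldg32" in cfg

    print(minimize_config(full, fake_still_crashes))
```

The same loop could be applied to the `disabled_pass` list to check whether disabling `Simplify` or `RemoveNoOp` is part of the trigger.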
---
## Actual Behavior
`tvm.compile(...)` crashes with a segfault. The stack trace consistently
includes:
* `tvm::tir::BufferStore::BufferStore(...)`
* `tvm::tir::PTXRewriter::VisitStmt_(BufferStoreNode const*)`
* `tvm::tir::transform::InjectPTXLDG32(bool)`
Excerpt:
```text
!!!!!!! Segfault encountered !!!!!!!
...
tvm::tir::BufferStore::BufferStore(...)
tvm::tir::PTXRewriter::VisitStmt_(tvm::tir::BufferStoreNode const*)
...
tvm::tir::transform::PrimFuncPassNode::operator()(...)
...
tvm::tir::transform::InjectPTXLDG32(bool)
Segmentation fault (core dumped)
```
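Because the failure is a signal rather than a Python exception, it cannot be caught with `try`/`except` in the same process. One way to observe it from a test harness is to run the repro in a child process and inspect the return code. The sketch below uses only the Python standard library; the file name `repro.py` is a placeholder for the script above, and the helpers `describe_returncode` and `run_isolated` are hypothetical names. It assumes a POSIX system, where a negative return code means the child was killed by a signal.

```python
import signal
import subprocess
import sys


def describe_returncode(returncode):
    """Translate a subprocess return code into a short human-readable label."""
    if returncode == 0:
        return "ok"
    if returncode < 0:
        # On POSIX, a negative return code means the child died from a signal.
        try:
            return "killed by " + signal.Signals(-returncode).name
        except ValueError:
            return "killed by signal %d" % -returncode
    return "exited with code %d" % returncode


def run_isolated(argv, timeout=600):
    """Run `argv` in a child process so a crash does not take down the caller."""
    proc = subprocess.run(argv, capture_output=True, text=True, timeout=timeout)
    return describe_returncode(proc.returncode), proc.stderr


if __name__ == "__main__":
    # `repro.py` is a placeholder file name for the script above.
    # `-X faulthandler` asks CPython to dump the Python traceback on a fatal signal.
    status, stderr = run_isolated(
        [sys.executable, "-X", "faulthandler", "repro.py"]
    )
    print(status)
```

Returning `stderr` alongside the status keeps the native stack trace available for attaching to the report.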
## Triage
* needs-triage
* bug