tinywisdom opened a new issue, #18617:
URL: https://github.com/apache/tvm/issues/18617

   ## Summary
   
   `tvm.compile` segfaults when compiling a Relax module (imported from 
`torch.export`) **even with a pure CPU target** (`llvm -keys=cpu`). The crash 
occurs inside the PTX-specific pass:
   
   * `tvm::tir::transform::InjectPTXLDG32(bool)`
   * `tvm::tir::PTXRewriter::VisitStmt_(BufferStoreNode const*)`
   * `tvm::tir::BufferStore::BufferStore(...)`
   
   This is unexpected because the target is **LLVM/CPU**, yet the compilation pipeline still enters a PTX rewriting pass. Removing `tir.ptx_ldg32` from the `PassContext` config avoids the crash.
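   
   As a quick sanity check, the compilation no longer crashes once the flag is dropped. A minimal sketch of that workaround, reusing the variables built in the repro script below:
   
   ```python
   # Workaround sketch: identical PassContext, but without "tir.ptx_ldg32".
   # `ir_mod`, `target`, `relax_pipeline`, and `tir_pipeline` are the same
   # objects constructed in the repro script further down.
   with tvm.transform.PassContext(opt_level=0, disabled_pass=["LoopPartition"]):
       tvm.compile(ir_mod, target=target, relax_pipeline=relax_pipeline,
                   tir_pipeline=tir_pipeline)
   ```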
   
   This suggests either:
   
   1. `tir.ptx_ldg32` enables a PTX-only pass without checking whether the 
target is CUDA/PTX-capable, or
   2. `InjectPTXLDG32` lacks a defensive early-exit / target predicate and can crash on non-PTX code paths (a target-kind check is sketched below).
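   
   For reference, the target kind is straightforward to query from Python; the snippet below is illustrative only and is not the pass's actual gating logic:
   
   ```python
   import tvm
   
   # `Target.kind.name` distinguishes CPU (llvm) targets from CUDA targets.
   cpu = tvm.target.Target("llvm -keys=cpu -mtriple=x86_64-unknown-linux-gnu")
   gpu = tvm.target.Target("cuda")
   print(cpu.kind.name)  # "llvm" -> PTX-specific passes should not apply
   print(gpu.kind.name)  # "cuda" -> PTX-specific passes are meaningful here
   ```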
   
   ---
   
   ## Environment
   
   From the repro output:
   
   * TVM: `0.22.0`
   * Commit: `9dbf3f22ff6f44962472f9af310fda368ca85ef2`
   * LLVM: `17.0.6`
   * Python: `3.10.16` (from stack paths)
   * NumPy: `2.2.6`
   * PyTorch: `2.9.0+cu128`
   
   Target used in repro:
   
   ```
   llvm -keys=cpu -mtriple=x86_64-unknown-linux-gnu
   ```
   
   ---
   
   ## Reproduction Steps
   
   1. Convert a small PyTorch module to Relax via `torch.export.export` + 
`tvm.relax.frontend.torch.from_exported_program`.
   2. Call `tvm.compile` under a PassContext with `config={"tir.ptx_ldg32": 1}`.
   3. Observe segfault during compilation.
   
   ---
   
   ## Minimal Repro Script
   
   ```python
   #!/usr/bin/env python3
   # -*- coding: utf-8 -*-
   
   import numpy as np
   import torch
   import torch.nn as nn
   import torch.nn.functional as F
   import tvm
   from tvm import tir
   
   
   def print_env_info():
       print("==== Environment Info ====")
       print("TVM version:", getattr(tvm, "__version__", "unknown"))
       try:
           li = tvm.support.libinfo()
           print("TVM git commit:", li.get("GIT_COMMIT_HASH", "unknown"))
           print("TVM LLVM version:", li.get("LLVM_VERSION", "unknown"))
       except Exception:
           pass
       print("Python (numpy) version:", np.__version__)
       print("PyTorch version:", torch.__version__)
       print("==========================\n")
   
   
   class BranchNet(nn.Module):
       def __init__(self, k: int):
           super().__init__()
           self.conv1 = nn.Conv2d(1, 16, k, 1)
           self.conv2 = nn.Conv2d(16, 32, 3, 1)
           self.pool = nn.MaxPool2d(2)
   
           s1 = 28 - k + 1
           s2 = s1 - 2
           sp = s2 // 2
           self.fc = nn.Linear(32 * sp * sp, 10)
   
       def forward(self, x):
           x = F.relu(self.conv1(x))
           x = F.relu(self.conv2(x))
           x = self.pool(x)
           x = x.reshape(x.shape[0], -1)
           return self.fc(x)
   
   
   class M(nn.Module):
       def __init__(self):
           super().__init__()
           self.b1 = BranchNet(3)
           self.b2 = BranchNet(5)
           self.b3 = BranchNet(7)
           self.out = nn.Linear(30, 10)
   
       def forward(self, x):
           a = self.b1(x)
           b = self.b2(x)
           c = self.b3(x)
           y = self.out(torch.cat([a, b, c], dim=1))
           return F.log_softmax(y, dim=1)
   
   
   def export_to_relax(mod: nn.Module, x: torch.Tensor) -> tvm.IRModule:
       mod = mod.to("cpu").eval()
       x = x.to("cpu")
       ep = torch.export.export(mod, (x,))
       from tvm.relax.frontend.torch import from_exported_program
       return from_exported_program(ep)
   
   
   def main():
       print_env_info()
   
       target = tvm.target.Target("llvm -keys=cpu -mtriple=x86_64-unknown-linux-gnu")
       tir_pipeline = tir.get_default_tir_pipeline(target)  # explicit_default
       relax_pipeline = "default"
   
       x = torch.rand(1, 1, 28, 28, dtype=torch.float32)
       ir_mod = export_to_relax(M(), x)
   
       pc = {
           "opt_level": 0,
           "disabled_pass": ["LoopPartition"],
           "config": {
               "tir.ptx_ldg32": 1,
           },
       }
   
       print("[repro] target:", target)
       print("[repro] tir_pipeline: explicit_default")
       print("[repro] compiling with tvm.compile ...")
       with tvm.transform.PassContext(**pc):
           tvm.compile(ir_mod, target=target, relax_pipeline=relax_pipeline,
                       tir_pipeline=tir_pipeline)
   
   
   if __name__ == "__main__":
       main()
   ```
   
   ---
   
   ## Actual Behavior
   
   Segfault during compilation. The stack trace shows the PTX rewrite pass running even though the target is LLVM:
   
   ```text
   tvm::tir::BufferStore::BufferStore(...)
   tvm::tir::PTXRewriter::VisitStmt_(tvm::tir::BufferStoreNode const*)
   ...
   tvm::tir::transform::InjectPTXLDG32(bool)
   Segmentation fault (core dumped)
   ```
   
   ---
   
   ## Expected Behavior
   
   On a CPU/LLVM target:
   
   1. Setting `tir.ptx_ldg32=1` should either be ignored (a no-op) or rejected with a clear error message (see the sketch after this list), and
   2. PTX-specific passes such as `InjectPTXLDG32` should **not run** for 
non-CUDA targets, and
   3. TVM should never segfault; failures should be surfaced as Python 
exceptions with diagnostics.
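   
   To illustrate the "clear error message" expectation in point 1, the kind of pre-check one would expect is sketched below. This is a hypothetical helper, not an existing TVM API; only `Target.kind.name` is a real attribute:
   
   ```python
   import tvm
   
   
   def check_pass_config(target: tvm.target.Target, config: dict) -> None:
       """Hypothetical validation: reject PTX-only TIR options on non-CUDA targets."""
       ptx_keys = [k for k in config if k.startswith("tir.ptx_")]
       if ptx_keys and target.kind.name != "cuda":
           raise ValueError(
               f"PTX-specific options {ptx_keys} are not applicable to "
               f"target kind '{target.kind.name}'"
           )
   ```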
   
   
   ### Triage
   
   Please refer to the list of label tags 
[here](https://github.com/apache/tvm/wiki/Issue-Triage-Labels) to find the 
relevant tags and add them below in a bullet format (example below).
   
   * needs-triage
   * bug
   

