guan404ming commented on issue #18315:
URL: https://github.com/apache/tvm/issues/18315#issuecomment-3591822500
This worked well on current main:
```python
# repro_tvm_bool_flag_crash.py
import torch
import torch.nn as nn
from torch.export import export as torch_export
from torch.utils.dlpack import to_dlpack

import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program


class TinyCond(nn.Module):
    # Minimal conditional: `flag` is a 0-D boolean
    def forward(self, x, flag):
        t = (x + 1).sum()
        f = (x - 1).sum()
        return torch.where(flag, t, f)


def main():
    torch.manual_seed(0)
    m = TinyCond().eval()

    # Inputs: x is a regular tensor; flag is a 0-D torch.bool scalar
    x = torch.randn(2, 3)
    flag = torch.randint(0, 2, (), dtype=torch.bool)

    # Sanity check on the PyTorch side
    with torch.inference_mode():
        _ = m(x, flag)

    # Export → Relax
    ep = torch_export(m, (x, flag))
    mod = from_exported_program(ep)

    # Build for CPU to minimize deps
    target = "llvm"
    dev = tvm.cpu(0)
    exec_mod = relax.build(mod, target=target)
    vm = relax.VirtualMachine(exec_mod, dev)

    # Critical step: feed the 0-D torch.bool via DLPack
    tvm_x = tvm.runtime.from_dlpack(to_dlpack(x))
    tvm_flag = tvm.runtime.from_dlpack(to_dlpack(flag))

    print("About to call TVM VM with (x, flag) where flag is a 0-D torch.bool via DLPack ...")
    # Expect crash: InternalError/ValueError from CheckTensorInfo on dtype mismatch
    vm["main"](tvm_x, tvm_flag)


if __name__ == "__main__":
    main()
```
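
As a side note, here is a minimal sketch for confirming what the 0-D bool looks like on each side of the DLPack handoff before invoking the VM. It reuses the same `tvm.runtime.from_dlpack` call as the repro; the exact dtype string TVM prints for booleans is an assumption and may vary by version:

```python
# Minimal sketch: inspect the flag before and after the DLPack handoff.
import torch
from torch.utils.dlpack import to_dlpack
import tvm

flag = torch.randint(0, 2, (), dtype=torch.bool)
tvm_flag = tvm.runtime.from_dlpack(to_dlpack(flag))

print("torch side:", flag.dtype, tuple(flag.shape))  # torch.bool, ()
# NDArray.dtype / NDArray.shape are standard TVM runtime attributes;
# the dtype string printed for booleans (e.g. "bool") may differ across versions.
print("tvm side:", tvm_flag.dtype, tvm_flag.shape)
```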