dutZ1855 opened a new issue, #18605: URL: https://github.com/apache/tvm/issues/18605
TVM shows **inconsistent ONNX Cast(to=BOOL) behavior for NaN depending on how the NaN is produced**:

- **Direct NaN constant**: `Constant(NaN) -> Cast` returns **`True`** (matches ONNX Runtime / PyTorch).
- **NaN produced by computation**: `x -> (NaN-producing op) -> Cast` returns **`False`** in TVM, while ONNX Runtime / PyTorch return **`True`**.

### Expected behavior

Per the ONNX `Cast` operator spec, casting from floating point to bool maps:

- `+/-0.0` → `False`
- **all else** → `True`

Therefore:

- `Cast(NaN -> bool)` **should be `True`** (NaN is not `+0.0`/`-0.0`, so it falls under "all else").
- In this repro, `Asin(5.0)` is NaN because arcsine's real domain is `[-1, 1]`, so the final output should be `True`.

### Actual behavior

Taking this model as an example:

<img width="218" height="281" alt="Image" src="https://github.com/user-attachments/assets/6d3c89bd-fa2c-4a62-a303-c49c6ca48be9" />

Repro model (computed NaN → Cast): `Constant(5.0) -> Asin -> Cast(to=BOOL)` (opset 18, input-free)

- **ONNX Runtime**: `True`
- **PyTorch**: `True`
- **TVM (Relax, LLVM target)**: `False`

We have also tried other possible ways to generate NaN:

- `Asin(x)` with `x=5.0`
- `Acos(x)` with `x=2.0`
- `Sqrt(x)` with `x=-1.0`
- `Log(x)` with `x=-1.0`
- `Div(x, x)` with `x=0.0` (0/0)

The results are consistent with the above: TVM returns `False` in every case, while ONNX Runtime and PyTorch return `True`.
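
For reference, the repro graph can also be built programmatically; below is a minimal sketch using `onnx.helper` (tensor and node names such as `c`, `asin_out`, and `y` are illustrative, not taken from the model attached below):

```python
# Minimal sketch of an equivalent repro graph: Constant(5.0) -> Asin -> Cast(to=BOOL).
# Tensor/node names ("c", "asin_out", "y") are illustrative only.
import onnx
from onnx import TensorProto, helper

const_node = helper.make_node(
    "Constant",
    inputs=[],
    outputs=["c"],
    value=helper.make_tensor("c_val", TensorProto.FLOAT, dims=[], vals=[5.0]),
)
asin_node = helper.make_node("Asin", inputs=["c"], outputs=["asin_out"])
cast_node = helper.make_node("Cast", inputs=["asin_out"], outputs=["y"], to=TensorProto.BOOL)

graph = helper.make_graph(
    nodes=[const_node, asin_node, cast_node],
    name="cast_nan_to_bool",
    inputs=[],  # input-free graph, matching the description above
    outputs=[helper.make_tensor_value_info("y", TensorProto.BOOL, [])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)])
onnx.checker.check_model(model)
onnx.save(model, "cast_nan_to_bool.onnx")
```
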
### Environment

- Operating System: Ubuntu 22.04.4 LTS
- TVM version: 0.23.0dev
- PyTorch version: 2.9.1
- onnxruntime version: 1.23.2
- onnx version: 1.20.0
- Python: 3.11.14

### Steps to reproduce

[model.zip](https://github.com/user-attachments/files/24312598/model.zip)

Download the model and run the following code to obtain the results:

`python cast_compare.py --model model.onnx`

```python
from __future__ import annotations

import argparse
import os
import sys
from pathlib import Path

import numpy as np
import onnx


def _ensure_repo_tvm() -> None:
    """
    Avoid mixing:
    - repo TVM python (newer)
    - site-packages TVM runtime (older)
    Force-import TVM from this repo's `tvm/python`, and point TVM to `tvm/build`.
    """
    repo_root = Path(__file__).resolve().parents[3]
    tvm_python = repo_root / "tvm" / "python"
    tvm_build = repo_root / "tvm" / "build"
    if tvm_python.exists():
        sys.path.insert(0, tvm_python.as_posix())
    if "TVM_LIBRARY_PATH" not in os.environ and tvm_build.exists():
        os.environ["TVM_LIBRARY_PATH"] = tvm_build.as_posix()
    for k in list(sys.modules.keys()):
        if k == "tvm" or k.startswith("tvm."):
            del sys.modules[k]


def _run_torch() -> bool | None:
    try:
        import torch
    except Exception:
        return None
    # Directly test the Cast semantics on NaN.
    a = torch.tensor(float("nan"), dtype=torch.float32)
    y = a.to(torch.bool)
    return bool(y.item())


def _run_ort(model_bytes: bytes) -> bool:
    import onnxruntime as ort  # type: ignore

    sess = ort.InferenceSession(model_bytes, providers=["CPUExecutionProvider"])
    outs = sess.run(None, {})
    if len(outs) != 1:
        raise RuntimeError(f"ORT returned {len(outs)} outputs, expected 1")
    y = np.array(outs[0]).item()
    return bool(y)


def _run_tvm(model_path: Path) -> bool:
    _ensure_repo_tvm()
    import tvm  # type: ignore
    from tvm import relax  # type: ignore
    from tvm.relax.frontend import onnx as rx_onnx  # type: ignore

    onnx_model = onnx.load(model_path.as_posix())
    converted = rx_onnx.from_onnx(onnx_model, shape_dict={})
    mod = converted[0] if isinstance(converted, (list, tuple)) else converted
    tgt = tvm.target.Target("llvm")
    pipeline = relax.pipeline.get_default_pipeline(tgt)
    with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": False}):
        ex = relax.build(mod, target=tgt, relax_pipeline=pipeline)
    vm = relax.VirtualMachine(ex, tvm.cpu())
    vm.set_input("main")
    vm.invoke_stateful("main")
    out = vm.get_outputs("main")
    if isinstance(out, tuple):
        out = out[0]
    if hasattr(out, "numpy"):
        arr = out.numpy()
    else:
        arr = np.array(out)
    return bool(np.array(arr).item())


def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=Path, default=Path("cast_nan_to_bool.onnx"))
    args = ap.parse_args()
    model_path = args.model.resolve()
    if not model_path.exists():
        print("error: model not found:", model_path)
        return 1
    model_bytes = model_path.read_bytes()

    y_ort = _run_ort(model_bytes)
    y_torch = _run_torch()
    y_tvm = _run_tvm(model_path)

    # Minimal output: just the three backend results.
    print("ort :", y_ort)
    print("torch:", "skip" if y_torch is None else y_torch)
    print("tvm :", y_tvm)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
```
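
As a quick cross-check of the expected "all else → `True`" semantics outside of ONNX Runtime / PyTorch / TVM, plain numpy agrees (a small sketch, not part of the attached repro):

```python
import numpy as np

# Per the Cast spec expectation: only +/-0.0 should map to False.
x = np.array([np.nan, 0.0, -0.0, 1.0, np.inf], dtype=np.float32)
print(x.astype(np.bool_))  # -> [ True False False  True  True]
```
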
### Triage

* needs-triage