dutZ1855 opened a new issue, #18606: URL: https://github.com/apache/tvm/issues/18606
### Expected behavior

TVM should run the model correctly.

### Actual behavior

For the following model,

<img width="339" height="472" alt="Image" src="https://github.com/user-attachments/assets/fb2b9759-cab2-49d6-9933-075dc7f19ec6" />
<img width="497" height="364" alt="Image" src="https://github.com/user-attachments/assets/29aac0d7-9930-41b3-aa7d-9688bd394891" />

ONNX Runtime executes it successfully, with the following results:

ONNXRuntime:
```
[array([5.627121 , 3.1315434, 6.488241 , 4.1415935, 5.716218 , 4.9254465, 4.222067 , 6.5320206, 4.5028763, 4.85569 , 4.468175 , 6.8016195, 5.443865 , 5.117418 , 5.684251 , 6.9260955, 4.885092 , 5.473915 , 5.516653 , 5.740866 , 5.9058757, 4.898214 , 5.5144825, 5.5342417, 4.397482 , 4.531957 , 4.3176513, 3.5980804, 4.0122795, 3.5998032, 5.380638 , 5.7510695], dtype=float32)]
```

However, when the model is compiled and run with TVM, TVM crashes:

`tvm.error.InternalError: In Op(relax.nn.prelu), the input axis 1 is out of range. The input tensor has 1 dimensions, so axis should be in range [-1, 1).`

The message shows that `relax.nn.prelu` is invoked with `axis=1` on a 1-D input. ONNX `PRelu` accepts inputs of any rank, because the slope only has to be unidirectionally broadcastable to the input.

### Environment

- Operating System: Ubuntu 22.04.4 LTS
- TVM version: 0.23.0dev
- ORT version: 1.23.2
- ONNX version: 1.20.0
- Python: 3.11.14

### Steps to reproduce

This bug can be reproduced with the following code and the model in the attachment.
# Reproduction script for apache/tvm#18606: relax.nn.prelu axis error on a 1-D input.
# Model attachment: [model.zip](https://github.com/user-attachments/files/24314690/model.zip)
from __future__ import annotations

import argparse
import os
import pickle
import sys
import traceback
from pathlib import Path

import numpy as np


def _ensure_repo_tvm(levels_up: int = 3) -> None:
    """Put an in-repo TVM checkout on ``sys.path`` when one is present.

    The repository root is ``Path(__file__).resolve().parents[levels_up]``.
    ``<root>/tvm/python`` is prepended to ``sys.path`` if it exists. Unless
    ``TVM_LIBRARY_PATH`` is already set, it is pointed at ``<root>/tvm/build``
    (if that exists). Call this before ``import tvm`` so the in-repo build is
    picked up.

    Args:
        levels_up: Index into the script's ancestor directories that is treated
            as the repository root. If the script is not nested that deeply
            (e.g. it is run from ``/tmp``), this function does nothing and
            whichever ``tvm`` is importable is used.
    """
    ancestors = Path(__file__).resolve().parents
    # An unconditional ``parents[3]`` raises IndexError for shallow script
    # locations, which aborted the repro before TVM was ever exercised.
    if len(ancestors) <= levels_up:
        return
    repo_root = ancestors[levels_up]
    tvm_python = repo_root / "tvm" / "python"
    tvm_build = repo_root / "tvm" / "build"
    if tvm_python.exists():
        sys.path.insert(0, tvm_python.as_posix())
    if "TVM_LIBRARY_PATH" not in os.environ and tvm_build.exists():
        os.environ["TVM_LIBRARY_PATH"] = tvm_build.as_posix()


def _load_oracle_inputs(path: Path) -> dict[str, np.ndarray]:
    """Load the model inputs stored in an oracle pickle.

    The pickle may contain either ``{"input": {name: array_like, ...}, ...}``
    or the ``{name: array_like}`` mapping itself.

    Args:
        path: Location of the pickle file.

    Returns:
        A new dict mapping each input name to a numpy array copy of its value.

    Raises:
        ValueError: If the pickle does not hold a dict of inputs.
    """
    # NOTE(security): pickle.loads can execute arbitrary code. Only load oracle
    # files from a source you trust (here: the attachment on the issue).
    obj = pickle.loads(path.read_bytes())
    # Only dicts have .get(). Any other payload must reach the ValueError below
    # instead of surfacing as an AttributeError.
    inp = obj.get("input", obj) if isinstance(obj, dict) else obj
    if not isinstance(inp, dict):
        raise ValueError("oracle.pkl does not contain a dict input")
    return {k: np.array(v) for k, v in inp.items()}


def _run_ort(model_path: Path, inputs: dict[str, np.ndarray]) -> None:
    """Run the model with ONNX Runtime on CPU and print every output in full.

    Args:
        model_path: Path to the ONNX model.
        inputs: Feed dict passed unchanged to ``InferenceSession.run``.
    """
    import onnxruntime as ort  # type: ignore

    sess = ort.InferenceSession(model_path.as_posix(), providers=["CPUExecutionProvider"])
    outs = sess.run(None, inputs)
    out_names = [o.name for o in sess.get_outputs()]
    outs_np = [np.array(v) for v in outs]
    # Scope the print options to this report instead of mutating numpy's global
    # state for the rest of the process. sys.maxsize disables summarization.
    with np.printoptions(threshold=sys.maxsize, linewidth=200):
        print("ONNXRuntime:\n", outs_np)
        for name, arr in zip(out_names, outs_np):
            # min()/max() raise on zero-size arrays, so report None for those.
            extrema = (arr.min(), arr.max()) if arr.size else None
            print("[ort] output", name, "shape=", arr.shape, "dtype=", arr.dtype, "min/max=", extrema)


def _run_tvm(model_path: Path, inputs: dict[str, np.ndarray]) -> None:
    """Import the model through the Relax ONNX frontend and build it for LLVM.

    The input shapes are taken from ``inputs``. Any exception raised while
    converting or building is caught, and its type, traceback and message are
    printed, so the script keeps running. A successful build prints an
    "UNEXPECTED" line, because this repro expects the prelu axis error.

    Args:
        model_path: Path to the ONNX model.
        inputs: Model inputs; only their shapes are used.
    """
    _ensure_repo_tvm()  # must run before ``import tvm``
    import onnx  # type: ignore
    import tvm  # type: ignore
    from tvm import relax  # type: ignore
    from tvm.relax.frontend import onnx as rx_onnx  # type: ignore

    onnx_model = onnx.load(model_path.as_posix())
    shape_dict = {k: v.shape for k, v in inputs.items()}
    print("[tvm] shape_dict:", shape_dict)
    try:
        converted = rx_onnx.from_onnx(onnx_model, shape_dict=shape_dict)
        mod = converted[0] if isinstance(converted, (list, tuple)) else converted
        mod = relax.transform.DecomposeOpsForInference()(mod)
        mod = relax.transform.LegalizeOps()(mod)
        mod, params = relax.frontend.detach_params(mod)
        tgt = tvm.target.Target("llvm")
        pipeline = relax.pipeline.get_default_pipeline(tgt)
        with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": False}):
            relax.build(mod, target=tgt, params=params, relax_pipeline=pipeline)
        print("[tvm] UNEXPECTED: succeeded (no axis error)")
    except Exception as e:  # top-level boundary of the repro: report, don't crash
        print("[tvm] FAILED:", type(e).__name__)
        traceback.print_exc()
        print("\n[tvm] error message:\n" + str(e))


def main() -> int:
    """Parse the command line, then run the model under ONNX Runtime and TVM.

    Returns:
        0 once both runs have been attempted. A TVM failure is reported on
        stdout, not through the exit status.
    """
    ap = argparse.ArgumentParser(description="Reproduce apache/tvm#18606 (prelu axis error).")
    ap.add_argument("--model", type=Path, default=Path("prelu_axis_1d.onnx"))
    ap.add_argument("--oracle", type=Path, default=Path("oracle.pkl"))
    args = ap.parse_args()
    model_path = args.model.resolve()
    oracle_path = args.oracle.resolve()
    # If an input file is missing, exit with a usage error (status 2) rather
    # than a traceback from deep inside pickle or onnxruntime.
    for label, path in (("model", model_path), ("oracle", oracle_path)):
        if not path.is_file():
            ap.error(f"{label} file not found: {path}")
    inputs = _load_oracle_inputs(oracle_path)
    _run_ort(model_path, inputs)
    _run_tvm(model_path, inputs)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())

# ---------------------------------------------------------------------------
# Issue triage label: needs-triage
#
# -- This is an automated message from the Apache Git Service. To respond to
# the message, please log on to GitHub and use the URL above to go to the
# specific comment. To unsubscribe, e-mail: [email protected] For queries
# about this service, please contact Infrastructure at: [email protected]
# ---------------------------------------------------------------------
# To unsubscribe, e-mail: [email protected] For additional commands,
# e-mail: [email protected]
