dutZ1855 opened a new issue, #18607: URL: https://github.com/apache/tvm/issues/18607
### Expected behavior

TVM should be able to import and compile an ONNX `PRelu` model when the `slope` input is a **broadcastable initializer** (e.g. shape `(1,1)`). Per the ONNX `PRelu` operator spec, `slope` may have a smaller rank than `X` as long as it is unidirectionally broadcastable to `X`:

- Spec: https://onnx.ai/onnx/operators/onnx__PRelu.html

### Actual behavior

For the following model,

<img width="210" height="297" alt="Image" src="https://github.com/user-attachments/assets/88164499-d512-401a-8b1f-14ea4c308d8a" />

importing it with TVM Relax (`tvm.relax.frontend.onnx.from_onnx`) and applying `relax.transform.LegalizeOps` fails with an `AssertionError` raised from TOPI `topi.nn.prelu`:

- `assert len(slope.shape) == 1`

```
Traceback (most recent call last):
  File "DLCompilers/bug/tvm/prelu_slope_rank_broadcast/run_repro.py", line 58, in _run_tvm
    mod = relax.transform.LegalizeOps()(mod)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "DLCompilers/tvm/python/tvm/ir/transform.py", line 167, in __call__
    return _ffi_transform_api.RunPass(self, mod)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "python/tvm_ffi/cython/function.pxi", line 904, in tvm_ffi.core.Function.__call__
  File "DLCompilers/tvm/src/ir/transform.cc", line 544, in operator()
    [](Pass pass, ffi::RValueRef<IRModule> mod) { return pass(*std::move(mod)); });
  File "DLCompilers/tvm/src/ir/transform.cc", line 290, in tvm::transform::Pass::operator()(tvm::IRModule) const
    return this->operator()(std::move(mod), PassContext::Current());
  File "DLCompilers/tvm/src/ir/transform.cc", line 306, in tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    ret = node->operator()(std::move(mod), pass_ctx);
  File "DLCompilers/tvm/src/ir/transform.cc", line 414, in tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
    mod = pass_func(std::move(mod), pass_ctx);
  File "DLCompilers/tvm/src/relax/transform/legalize_ops.cc", line 416, in operator()
    mod = LegalizeMutator(mod, cmap, skip_ops, enable_warning).Transform();
  File "DLCompilers/tvm/src/relax/transform/legalize_ops.cc", line 84, in tvm::relax::LegalizeMutator::Transform()
    auto updated_func = Downcast<Function>(this->VisitExpr(func));
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 554, in tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
    return builder_->Normalize(ExprFunctor::VisitExpr(expr));
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 132, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::VisitExpr(tvm::RelaxExpr const&)
    return vtable(n, this, std::forward<Args>(args)...);
  File "DLCompilers/tvm/include/tvm/node/functor.h", line 102, in tvm::NodeFunctor<tvm::RelaxExpr (tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)>::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
    return (*func_[n->type_index() - begin_type_index_])(n, std::forward<Args>(args)...);
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 170, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#8}::_FUN(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)
    RELAX_EXPR_FUNCTOR_DISPATCH(FunctionNode);
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 170, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#8}::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
    RELAX_EXPR_FUNCTOR_DISPATCH(FunctionNode);
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 585, in tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
    Expr body = this->VisitWithNewScope(op->body, params);
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 817, in tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelaxExpr const&, tvm::ffi::Optional<tvm::ffi::Array<tvm::relax::Var, void>, void>)
    Expr ret = this->VisitExpr(expr);
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 554, in tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
    return builder_->Normalize(ExprFunctor::VisitExpr(expr));
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 132, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::VisitExpr(tvm::RelaxExpr const&)
    return vtable(n, this, std::forward<Args>(args)...);
  File "DLCompilers/tvm/include/tvm/node/functor.h", line 102, in tvm::NodeFunctor<tvm::RelaxExpr (tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)>::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
    return (*func_[n->type_index() - begin_type_index_])(n, std::forward<Args>(args)...);
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 172, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#10}::_FUN(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)
    RELAX_EXPR_FUNCTOR_DISPATCH(SeqExprNode);
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 172, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#10}::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
    RELAX_EXPR_FUNCTOR_DISPATCH(SeqExprNode);
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 628, in tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
    BindingBlock new_block = this->VisitBindingBlock(block);
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 776, in tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
    ret = VisitBindingBlock_(node);
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 734, in tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::DataflowBlockNode const*)
    this->VisitBinding(binding);
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 652, in tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::ConstantNode const*)
    RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(ConstantNode);
  File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 554, in tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
    return builder_->Normalize(ExprFunctor::VisitExpr(expr));
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 132, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::VisitExpr(tvm::RelaxExpr const&)
    return vtable(n, this, std::forward<Args>(args)...);
  File "DLCompilers/tvm/include/tvm/node/functor.h", line 102, in tvm::NodeFunctor<tvm::RelaxExpr (tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)>::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
    return (*func_[n->type_index() - begin_type_index_])(n, std::forward<Args>(args)...);
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 171, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#9}::_FUN(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)
    RELAX_EXPR_FUNCTOR_DISPATCH(CallNode);
  File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 171, in tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)#9}::operator()(tvm::ffi::ObjectRef const&, tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
    RELAX_EXPR_FUNCTOR_DISPATCH(CallNode);
  File "DLCompilers/tvm/src/relax/transform/legalize_ops.cc", line 357, in tvm::relax::LegalizeMutator::VisitExpr_(tvm::relax::CallNode const*)
    Expr legalized = legalization_func(builder_, visited_call);
  File "python/tvm_ffi/cython/function.pxi", line 1058, in tvm_ffi.core.tvm_ffi_callback
  File "DLCompilers/tvm/python/tvm/relax/transform/legalize_ops/nn.py", line 493, in _nn_prelu
    return bb.call_te(topi.nn.prelu, call.args[0], call.args[1], call.attrs.axis)
  File "DLCompilers/tvm/python/tvm/relax/block_builder.py", line 361, in call_te
    tir_func, call_args, output_sinfo, tir_vars = gen_call_tir_inputs(func, *args, **kwargs)
  File "DLCompilers/tvm/python/tvm/relax/utils.py", line 355, in gen_call_tir_inputs
    te_out = func(*te_args, **te_kwargs)
  File "DLCompilers/tvm/python/tvm/te/tag.py", line 57, in tagged_fdecl
    return fdecl(*args, **kwargs)
  File "DLCompilers/tvm/python/tvm/topi/nn/elemwise.py", line 130, in prelu
    assert len(slope.shape) == 1
AssertionError
```

ONNX Runtime executes the same model successfully:

```
ONNXRuntime:
 [array([[ 1.117622  , -5.554099  , -1.7080084 , -3.217593  ,  0.60142773, -0.3002757 ,
           0.05969319, -0.12815356, -0.7426875 ,  1.2047737 ,  0.77745306, -5.438607  ,
           0.76982784, -3.484421  ,  1.0997635 , -3.8377662 , -5.1048746 , -5.466864  ,
          -5.903262  ,  0.4335317 , -1.313807  , -3.8877916 ]], dtype=float32)]
[ort] output y shape= (1, 22) dtype= float32 min/max= (-5.903262, 1.2047737)
```

### Environment

- Operating System: Ubuntu 22.04.4 LTS
- TVM version: 0.23.0dev
- PyTorch version: 2.9.1
- ONNX Runtime version: 1.23.2
- onnx version: 1.20.0
- Python: 3.11.14

### Steps to reproduce

[model.zip](https://github.com/user-attachments/files/24324974/model.zip)

Download the model and run the following code to obtain the results. (A standalone sketch that rebuilds an equivalent model without the attachment is included under "Additional context" below.)
`python run_repro.py --model model.onnx --oracle oracle.pkl`

```
from __future__ import annotations

import argparse
import os
import pickle
import sys
import traceback
from pathlib import Path

import numpy as np


def _ensure_repo_tvm() -> None:
    repo_root = Path(__file__).resolve().parents[3]
    tvm_python = repo_root / "tvm" / "python"
    tvm_build = repo_root / "tvm" / "build"
    if tvm_python.exists():
        sys.path.insert(0, tvm_python.as_posix())
    if "TVM_LIBRARY_PATH" not in os.environ and tvm_build.exists():
        os.environ["TVM_LIBRARY_PATH"] = tvm_build.as_posix()


def _load_oracle_inputs(path: Path) -> dict[str, np.ndarray]:
    obj = pickle.loads(path.read_bytes())
    inp = obj.get("input", obj)
    if not isinstance(inp, dict):
        raise ValueError("oracle.pkl does not contain a dict input")
    return {k: np.array(v) for k, v in inp.items()}


def _run_ort(model_path: Path, inputs: dict[str, np.ndarray]) -> None:
    import onnxruntime as ort  # type: ignore

    np.set_printoptions(threshold=np.inf, linewidth=200)
    sess = ort.InferenceSession(model_path.as_posix(), providers=["CPUExecutionProvider"])
    outs = sess.run(None, inputs)
    outs_np = [np.array(v) for v in outs]
    print("ONNXRuntime:\n", outs_np)
    for o, a in zip(sess.get_outputs(), outs_np):
        print("[ort] output", o.name, "shape=", a.shape, "dtype=", a.dtype, "min/max=", (a.min(), a.max()))


def _run_tvm(model_path: Path, inputs: dict[str, np.ndarray]) -> None:
    _ensure_repo_tvm()
    import onnx  # type: ignore
    import tvm  # type: ignore
    from tvm import relax  # type: ignore
    from tvm.relax.frontend import onnx as rx_onnx  # type: ignore

    onnx_model = onnx.load(model_path.as_posix())
    shape_dict = {k: v.shape for k, v in inputs.items()}
    print("[tvm] shape_dict:", shape_dict)
    try:
        converted = rx_onnx.from_onnx(onnx_model, shape_dict=shape_dict)
        mod = converted[0] if isinstance(converted, (list, tuple)) else converted
        mod = relax.transform.DecomposeOpsForInference()(mod)
        # Expected to FAIL here due to topi.nn.prelu requiring 1-D slope
        mod = relax.transform.LegalizeOps()(mod)
        mod, params = relax.frontend.detach_params(mod)
        tgt = tvm.target.Target("llvm")
        pipeline = relax.pipeline.get_default_pipeline(tgt)
        with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": False}):
            _ = relax.build(mod, target=tgt, params=params, relax_pipeline=pipeline)
        print("[tvm] UNEXPECTED: succeeded")
    except Exception as e:
        print("[tvm] FAILED:", type(e).__name__)
        tb = traceback.format_exc()
        print(tb, end="" if tb.endswith("\n") else "\n")
        print("\n[tvm] error repr:\n" + repr(e))


def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", type=Path, default=Path("model.onnx"))
    ap.add_argument("--oracle", type=Path, default=Path("oracle.pkl"))
    args = ap.parse_args()
    model_path = args.model.resolve()
    oracle_path = args.oracle.resolve()
    inputs = _load_oracle_inputs(oracle_path)
    _run_ort(model_path, inputs)
    _run_tvm(model_path, inputs)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
```

### Triage

* needs-triage
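### Additional context

For convenience, here is a minimal self-contained sketch that builds an equivalent `PRelu` model in code instead of using the attached `model.zip`. The shapes (`X` of `(1, 22)`, slope initializer of `(1, 1)`), the slope value, and the opset are assumptions based on the shapes reported above, not the exact attached model; it is expected to hit the same `assert len(slope.shape) == 1` path during `LegalizeOps`.

```
# minimal_prelu_repro.py -- hypothetical standalone reproducer (not the attached model).
# Assumed: X is (1, 22) float32, slope is a (1, 1) initializer, opset 16.
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 22])
y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 22])
# (1, 1) slope initializer: valid per the ONNX PRelu spec (unidirectional broadcast to X).
slope = numpy_helper.from_array(np.array([[0.1]], dtype=np.float32), name="slope")
node = helper.make_node("PRelu", inputs=["x", "slope"], outputs=["y"])
graph = helper.make_graph([node], "prelu_broadcast_slope", [x], [y], initializer=[slope])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 16)])
onnx.checker.check_model(model)

from tvm import relax
from tvm.relax.frontend import onnx as rx_onnx

converted = rx_onnx.from_onnx(model, shape_dict={"x": (1, 22)})
mod = converted[0] if isinstance(converted, (list, tuple)) else converted
# Expected to fail here with the AssertionError from topi.nn.prelu
# (the _nn_prelu -> bb.call_te(topi.nn.prelu, ...) path shown in the traceback above).
mod = relax.transform.LegalizeOps()(mod)
```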

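A possible (untested) workaround at the model level would be to materialize the broadcastable slope as an explicit 1-D per-channel tensor before importing, so that the 1-D assumption in `topi.nn.prelu` holds. The sketch below is only illustrative: the output file name and the channel count of 22 (taken from the `(1, 22)` output above) are assumptions, and whether this actually sidesteps the assert depends on how the Relax ONNX frontend maps the slope onto the prelu axis. The proper fix is presumably for the importer or the `LegalizeOps` lowering to handle broadcastable slopes directly.

```
# expand_prelu_slope.py -- untested workaround sketch (model surgery), not a proper fix.
import numpy as np
import onnx
from onnx import numpy_helper

model = onnx.load("model.onnx")
graph = model.graph
inits = {t.name: t for t in graph.initializer}

for node in graph.node:
    if node.op_type != "PRelu" or node.input[1] not in inits:
        continue
    old = inits[node.input[1]]
    slope = numpy_helper.to_array(old)
    if slope.ndim == 1 or slope.size != 1:
        continue  # only handle the scalar-like broadcast case seen in this model
    channels = 22  # assumed from the (1, 22) output shape reported above
    new_slope = np.full((channels,), slope.reshape(-1)[0], dtype=slope.dtype)
    graph.initializer.remove(old)
    graph.initializer.extend([numpy_helper.from_array(new_slope, name=old.name)])

onnx.save(model, "model_expanded_slope.onnx")
```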