dutZ1855 opened a new issue, #18607:
URL: https://github.com/apache/tvm/issues/18607

   ### Expected behavior
   
   TVM should be able to import and compile an ONNX `PRelu` model when the 
`slope` input is a **broadcastable initializer** that is not 1-D (e.g. shape 
`(1, 1)` while `X` has shape `(1, 22)`).
   
   Per the ONNX `PRelu` operator spec, `slope` does not have to be 1-D: it may 
have any shape, including a lower rank than `X`, as long as it is 
unidirectionally broadcastable to `X`:
   
   - Spec: https://onnx.ai/onnx/operators/onnx__PRelu.html

   ### Actual behavior
   For the following model:
   <img width="210" height="297" alt="Image" src="https://github.com/user-attachments/assets/88164499-d512-401a-8b1f-14ea4c308d8a" />
   
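   When importing the attached model with TVM Relax 
(`tvm.relax.frontend.onnx.from_onnx`) and applying 
`relax.transform.LegalizeOps`, TVM fails with an `AssertionError`. The Relax 
legalization of `nn.prelu` (`_nn_prelu` in 
`python/tvm/relax/transform/legalize_ops/nn.py`) passes the slope unchanged to 
TOPI `topi.nn.prelu`, which only accepts a 1-D slope: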
   
   - `assert len(slope.shape) == 1`
   
   ```
   Traceback (most recent call last):
     File "DLCompilers/bug/tvm/prelu_slope_rank_broadcast/run_repro.py", line 
58, in _run_tvm
       mod = relax.transform.LegalizeOps()(mod)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "DLCompilers/tvm/python/tvm/ir/transform.py", line 167, in __call__
       return _ffi_transform_api.RunPass(self, mod)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
     File "python/tvm_ffi/cython/function.pxi", line 904, in 
tvm_ffi.core.Function.__call__
     File "DLCompilers/tvm/src/ir/transform.cc", line 544, in operator()
       [](Pass pass, ffi::RValueRef<IRModule> mod) { return 
pass(*std::move(mod)); });
       
     File "DLCompilers/tvm/src/ir/transform.cc", line 290, in 
tvm::transform::Pass::operator()(tvm::IRModule) const
       return this->operator()(std::move(mod), PassContext::Current());
       
     File "DLCompilers/tvm/src/ir/transform.cc", line 306, in 
tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext 
const&) const
       ret = node->operator()(std::move(mod), pass_ctx);
       
     File "DLCompilers/tvm/src/ir/transform.cc", line 414, in 
tvm::transform::ModulePassNode::operator()(tvm::IRModule, 
tvm::transform::PassContext const&) const
       mod = pass_func(std::move(mod), pass_ctx);
       
     File "DLCompilers/tvm/src/relax/transform/legalize_ops.cc", line 416, in 
operator()
       mod = LegalizeMutator(mod, cmap, skip_ops, enable_warning).Transform();
       
     File "DLCompilers/tvm/src/relax/transform/legalize_ops.cc", line 84, in 
tvm::relax::LegalizeMutator::Transform()
       auto updated_func = Downcast<Function>(this->VisitExpr(func));
       
     File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 554, in 
tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
       return builder_->Normalize(ExprFunctor::VisitExpr(expr));
       
     File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 132, in 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>::VisitExpr(tvm::RelaxExpr const&)
       return vtable(n, this, std::forward<Args>(args)...);
       
     File "DLCompilers/tvm/include/tvm/node/functor.h", line 102, in 
tvm::NodeFunctor<tvm::RelaxExpr (tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>*)>::operator()(tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
       return (*func_[n->type_index() - begin_type_index_])(n, 
std::forward<Args>(args)...);
       
     File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 170, in 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>*)#8}::_FUN(tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)
       RELAX_EXPR_FUNCTOR_DISPATCH(FunctionNode);
       
     File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 170, in 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>*)#8}::operator()(tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
       RELAX_EXPR_FUNCTOR_DISPATCH(FunctionNode);
       
     File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 585, in 
tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
       Expr body = this->VisitWithNewScope(op->body, params);
       
     File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 817, in 
tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelaxExpr const&, 
tvm::ffi::Optional<tvm::ffi::Array<tvm::relax::Var, void>, void>)
       Expr ret = this->VisitExpr(expr);
       
     File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 554, in 
tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
       return builder_->Normalize(ExprFunctor::VisitExpr(expr));
       
     File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 132, in 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>::VisitExpr(tvm::RelaxExpr const&)
       return vtable(n, this, std::forward<Args>(args)...);
       
     File "DLCompilers/tvm/include/tvm/node/functor.h", line 102, in 
tvm::NodeFunctor<tvm::RelaxExpr (tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>*)>::operator()(tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
       return (*func_[n->type_index() - begin_type_index_])(n, 
std::forward<Args>(args)...);
       
     File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 172, in 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>*)#10}::_FUN(tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)
       RELAX_EXPR_FUNCTOR_DISPATCH(SeqExprNode);
       
     File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 172, in 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>*)#10}::operator()(tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
       RELAX_EXPR_FUNCTOR_DISPATCH(SeqExprNode);
       
     File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 628, in 
tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
       BindingBlock new_block = this->VisitBindingBlock(block);
       
     File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 776, in 
tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
       ret = VisitBindingBlock_(node);
       
     File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 734, in 
tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::DataflowBlockNode 
const*)
       this->VisitBinding(binding);
       
     File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 652, in 
tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, 
tvm::relax::ConstantNode const*)
       RELAX_EXPR_MUTATOR_VISIT_BINDING_IMPL(ConstantNode);
       
     File "DLCompilers/tvm/src/relax/ir/expr_functor.cc", line 554, in 
tvm::relax::ExprMutator::VisitExpr(tvm::RelaxExpr const&)
       return builder_->Normalize(ExprFunctor::VisitExpr(expr));
       
     File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 132, in 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>::VisitExpr(tvm::RelaxExpr const&)
       return vtable(n, this, std::forward<Args>(args)...);
       
     File "DLCompilers/tvm/include/tvm/node/functor.h", line 102, in 
tvm::NodeFunctor<tvm::RelaxExpr (tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>*)>::operator()(tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
       return (*func_[n->type_index() - begin_type_index_])(n, 
std::forward<Args>(args)...);
       
     File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 171, in 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>*)#9}::_FUN(tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*)
       RELAX_EXPR_FUNCTOR_DISPATCH(CallNode);
       
     File "DLCompilers/tvm/include/tvm/relax/expr_functor.h", line 171, in 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>::InitVTable()::{lambda(tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr 
const&)>*)#9}::operator()(tvm::ffi::ObjectRef const&, 
tvm::relax::ExprFunctor<tvm::RelaxExpr (tvm::RelaxExpr const&)>*) const
       RELAX_EXPR_FUNCTOR_DISPATCH(CallNode);
       
     File "DLCompilers/tvm/src/relax/transform/legalize_ops.cc", line 357, in 
tvm::relax::LegalizeMutator::VisitExpr_(tvm::relax::CallNode const*)
       Expr legalized = legalization_func(builder_, visited_call);
       
     File "python/tvm_ffi/cython/function.pxi", line 1058, in 
tvm_ffi.core.tvm_ffi_callback
     File "DLCompilers/tvm/python/tvm/relax/transform/legalize_ops/nn.py", line 
493, in _nn_prelu
       return bb.call_te(topi.nn.prelu, call.args[0], call.args[1], 
call.attrs.axis)
       
     File "DLCompilers/tvm/python/tvm/relax/block_builder.py", line 361, in 
call_te
       tir_func, call_args, output_sinfo, tir_vars = gen_call_tir_inputs(func, 
*args, **kwargs)
       
     File "DLCompilers/tvm/python/tvm/relax/utils.py", line 355, in 
gen_call_tir_inputs
       te_out = func(*te_args, **te_kwargs)
       
     File "DLCompilers/tvm/python/tvm/te/tag.py", line 57, in tagged_fdecl
       return fdecl(*args, **kwargs)
       
     File "DLCompilers/tvm/python/tvm/topi/nn/elemwise.py", line 130, in prelu
       assert len(slope.shape) == 1
       
   AssertionError
   ```
   
   ONNX Runtime executes the same model successfully:
   ```
   ONNXRuntime:
    [array([[ 1.117622  , -5.554099  , -1.7080084 , -3.217593  ,  0.60142773, 
-0.3002757 ,  0.05969319, -0.12815356, -0.7426875 ,  1.2047737 ,  0.77745306, 
-5.438607  ,  0.76982784, -3.484421  ,
            1.0997635 , -3.8377662 , -5.1048746 , -5.466864  , -5.903262  ,  
0.4335317 , -1.313807  , -3.8877916 ]], dtype=float32)]
   [ort] output y shape= (1, 22) dtype= float32 min/max= (-5.903262, 1.2047737)
   ```
   
   ### Environment
   
   - Operating system: Ubuntu 22.04.4 LTS
   - TVM version: 0.23.0dev
   - PyTorch version: 2.9.1
   - ONNX Runtime version: 1.23.2
   - ONNX version: 1.20.0
   - Python version: 3.11.14
   
   ### Steps to reproduce
   
   [model.zip](https://github.com/user-attachments/files/24324974/model.zip)
   
   Download and extract `model.zip`, then run the following script to reproduce the results:
   `python run_repro.py --model model.onnx --oracle oracle.pkl`
   ```
   from __future__ import annotations
   
   import argparse
   import os
   import pickle
   import sys
   import traceback
   from pathlib import Path
   
   import numpy as np
   
   
   def _ensure_repo_tvm() -> None:
       repo_root = Path(__file__).resolve().parents[3]
       tvm_python = repo_root / "tvm" / "python"
       tvm_build = repo_root / "tvm" / "build"
       if tvm_python.exists():
           sys.path.insert(0, tvm_python.as_posix())
       if "TVM_LIBRARY_PATH" not in os.environ and tvm_build.exists():
           os.environ["TVM_LIBRARY_PATH"] = tvm_build.as_posix()
   
   
   def _load_oracle_inputs(path: Path) -> dict[str, np.ndarray]:
       obj = pickle.loads(path.read_bytes())
       inp = obj.get("input", obj)
       if not isinstance(inp, dict):
           raise ValueError("oracle.pkl does not contain a dict input")
       return {k: np.array(v) for k, v in inp.items()}
   
   
   def _run_ort(model_path: Path, inputs: dict[str, np.ndarray]) -> None:
       import onnxruntime as ort  # type: ignore
   
       np.set_printoptions(threshold=np.inf, linewidth=200)
       sess = ort.InferenceSession(model_path.as_posix(), 
providers=["CPUExecutionProvider"])
       outs = sess.run(None, inputs)
       outs_np = [np.array(v) for v in outs]
       print("ONNXRuntime:\n", outs_np)
       for o, a in zip(sess.get_outputs(), outs_np):
           print("[ort] output", o.name, "shape=", a.shape, "dtype=", a.dtype, 
"min/max=", (a.min(), a.max()))
   
   
   def _run_tvm(model_path: Path, inputs: dict[str, np.ndarray]) -> None:
       """Import the model into Relax, legalize it, and build it for LLVM.

       The TVM/ONNX imports and ``onnx.load`` run before the ``try``. Any
       exception raised during conversion, the transform passes, or
       ``relax.build`` is caught and reported (exception type, traceback, and
       repr) rather than propagated. If the whole pipeline finishes, the
       function prints an "UNEXPECTED: succeeded" marker, because this repro
       expects ``LegalizeOps`` to fail on a slope that is not 1-D.

       Args:
           model_path: Path to the ``.onnx`` model.
           inputs: Input arrays; only their shapes are used (as ``shape_dict``).
       """
       _ensure_repo_tvm()
       import onnx  # type: ignore
       import tvm  # type: ignore
       from tvm import relax  # type: ignore
       from tvm.relax.frontend import onnx as rx_onnx  # type: ignore

       onnx_model = onnx.load(model_path.as_posix())
       shape_dict = {k: v.shape for k, v in inputs.items()}
       print("[tvm] shape_dict:", shape_dict)
       try:
           converted = rx_onnx.from_onnx(onnx_model, shape_dict=shape_dict)
           # Accept either a bare module or a sequence whose first item is it.
           mod = (
               converted[0] if isinstance(converted, (list, tuple)) else converted
           )
           mod = relax.transform.DecomposeOpsForInference()(mod)
           # Expected to FAIL here due to topi.nn.prelu requiring 1-D slope
           mod = relax.transform.LegalizeOps()(mod)
           mod, params = relax.frontend.detach_params(mod)
           tgt = tvm.target.Target("llvm")
           pipeline = relax.pipeline.get_default_pipeline(tgt)
           with tvm.transform.PassContext(
               opt_level=3, config={"tir.enable_debug": False}
           ):
               _ = relax.build(
                   mod, target=tgt, params=params, relax_pipeline=pipeline
               )
           print("[tvm] UNEXPECTED: succeeded")
       except Exception as e:  # deliberately report every failure, don't abort
           print("[tvm] FAILED:", type(e).__name__)
           tb = traceback.format_exc()
           print(tb, end="" if tb.endswith("\n") else "\n")
           print("\n[tvm] error repr:\n" + repr(e))
   
   
   def main() -> int:
       """Command-line entry point for the repro.

       Parses ``--model`` and ``--oracle`` (defaults ``model.onnx`` and
       ``oracle.pkl``), loads the oracle inputs, then runs ONNX Runtime followed
       by TVM on the same model and inputs. Returns 0 once both runners return.
       """
       parser = argparse.ArgumentParser()
       parser.add_argument("--model", type=Path, default=Path("model.onnx"))
       parser.add_argument("--oracle", type=Path, default=Path("oracle.pkl"))
       cli = parser.parse_args()

       model_file = cli.model.resolve()
       feeds = _load_oracle_inputs(cli.oracle.resolve())

       # Reference run first, then the TVM run under test.
       for runner in (_run_ort, _run_tvm):
           runner(model_file, feeds)
       return 0
   
   
   if __name__ == "__main__":
       # Use main()'s return value as the process exit status.
       sys.exit(main())
   ```
   
   
   
   
   ### Triage
   
   * needs-triage
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to