tlopex commented on code in PR #18491:
URL: https://github.com/apache/tvm/pull/18491#discussion_r2597401528
##########
src/relax/backend/vm/vm_shape_lower.cc:
##########
@@ -399,6 +399,23 @@ class VMShapeLowerMutator
       return ffi::GetRef<Expr>(op);
     }
+    // Check if all expressions are computed; if not, mark variables as ready and trigger computation
+    for (const PrimExpr& expr : op->values) {
+      if (!expr->IsInstance<IntImmNode>()) {
+        auto it = slot_map_.find(expr);
+        if (it != slot_map_.end() && !it->second->value_computed) {
+          // If it's a variable, mark it as ready for computation
+          if (expr.as<tir::VarNode>()) {
+            it->second->value_computed = true;
Review Comment:
Well, `ComputePrimValue` is intended only for evaluating a statically
evaluable `PrimExpr` into an `IntImm` (constant folding), so I think extending
`ComputePrimValue` would not address the root issue.
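As a minimal illustration (my own sketch, not the pass code: I'm using
`tvm.arith.Analyzer.simplify` as a stand-in for the kind of folding
`ComputePrimValue` performs, and `n` is a hypothetical symbolic var):
```
import tvm
from tvm import tir

analyzer = tvm.arith.Analyzer()

# Statically evaluable: folds to IntImm(5) -- the case ComputePrimValue targets.
print(analyzer.simplify(tir.const(2, "int64") + 3))

# Composite symbolic expression: stays `n + 1`; there is no IntImm to produce.
n = tir.Var("n", "int64")
print(analyzer.simplify(n + 1))
```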
The real problem is that `VMShapeLower` cannot consume a composite `PrimExpr`
directly. The correct solution here is to canonicalize the `ShapeExpr` earlier
by introducing a Relax ShapeVar binding for any non-trivial `PrimExpr`.
For example:
```
# 1. Compute the symbolic value first (Canonicalization)
s1 = R.prim_value(n + 1)
# 2. Pass the computed var to the shape (VMShapeLower is happy now)
lv = R.call_tir(cls.func, (x,), R.shape([s1]), dtype="float32")
```
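For contrast, the form that currently trips `VMShapeLower` embeds the
composite expression directly in the shape (same hypothetical `cls.func`,
`x`, and `n` as above):
```
# Before canonicalization: the composite PrimExpr `n + 1` sits directly in
# the shape, so VMShapeLower has no precomputed slot to read it from.
lv = R.call_tir(cls.func, (x,), R.shape([n + 1]), dtype="float32")
```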