## Background and Motivation
Currently, TVM uses the `tir.Simplify` pass to remove some redundant expressions,
such as nested equivalent if-conditions. For example, given a simple softmax
operation like
```
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax",
"tir.noalias": True}
buffers = {T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32),
float32, [2, 10, 257, 1025], []),
placeholder: Buffer(placeholder_2: Pointer(float32), float32, [2,
10, 257, 1025], [])}
buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
allocate(T_softmax_maxelem: Pointer(global float32), float32, [2, 10, 257]),
storage_scope = global {
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")]
"thread_extent" = 6;
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")]
"thread_extent" = 1024 {
if @tir.likely((floordiv(floordiv((threadIdx.x + (blockIdx.x*1024)),
257), 10) < 2), dtype=bool) {
if @tir.likely((floordiv((threadIdx.x + (blockIdx.x*1024)), 257) < 20),
dtype=bool) {
if @tir.likely(((threadIdx.x + (blockIdx.x*1024)) < 5140),
dtype=bool) {
T_softmax_maxelem[((blockIdx.x*1024) + threadIdx.x)] =
-3.40282e+38f32
}
}
}
// ...
```
`tir.Simplify` will simplify this to
```
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax",
"tir.noalias": True}
buffers = {T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32),
float32, [2, 10, 257, 1025], []),
placeholder: Buffer(placeholder_2: Pointer(float32), float32, [2,
10, 257, 1025], [])}
buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
allocate(T_softmax_maxelem: Pointer(global float32), float32, [2, 10, 257]),
storage_scope = global {
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")]
"thread_extent" = 6;
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")]
"thread_extent" = 1024 {
if @tir.likely((((blockIdx.x*1024) + threadIdx.x) < 5140), dtype=bool) {
T_softmax_maxelem[((blockIdx.x*1024) + threadIdx.x)] = -3.40282e+38f32
}
// ...
```
where the three equivalent conditions are simplified to one.
However, things are different when the input has a dynamic shape. The current
`tir.Simplify` fails to remove this redundancy for an input with a dynamic
shape, because the analyzer (actually the `RewriteSimplifier` and the
`CanonicalSimplifier`) used by this pass lacks corresponding rules for this
"non-const" situation. In the rest of this post we will continue to use
this simple softmax example to discuss the problem. We try to fix it
by adding more rules to both `RewriteSimplifier` and `CanonicalSimplifier`.
This is still an experimental idea; if you find something wrong or
improper, feel free to correct us in this post directly.
## Proposal
We will show our proposal by solving this problem in our simple softmax example
here. As shown in [this
post](https://discuss.tvm.apache.org/t/pre-rfc-dynamic-shape-use-sizevar-instead-of-var-when-convert-any-in-the-getshape-function/10625),
we can eliminate some redundant expressions by introducing sign information
into tensor shapes. But there are still some other redundant expressions that
are not covered by this solution. These redundant expressions can be eliminated
by the `tir.Simplify` pass when the input's shape is static as shown before.
For the dynamic situation, the reasons that prevent `tir.Simplify` from
removing this redundancy are as follows:
1. `RewriteSimplifier` only has rules for `IntImm` to simplify `floordiv(x, c1)
< c2` to `x < c1 * c2`.
2. After rewriting `floordiv(x, c1) < c2` to `x < c1 * c2`, the product folds
directly into a new constant `c3 = c1 * c2` provided `c1` and `c2` are
`IntImm`. But if we are given variables (or even worse, expressions), the
simplifier cannot tell that `v1 * v2` and `v2 * v1` are equal.
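The constant rule in item 1 relies on floor semantics and a positive divisor. The following is a minimal standalone sketch (plain C++, no TVM dependency; `FloorDiv`, `RewriteHoldsForPositiveDivisor`, and `TruncDivBreaksRewrite` are hypothetical helper names introduced here) that brute-forces the equivalence over a small range and shows that it would not hold with C++'s truncating `/`.

```cpp
#include <cassert>
#include <cstdint>

// Floor division: rounds toward negative infinity, like TIR's floordiv.
// Requires b != 0.
int64_t FloorDiv(int64_t a, int64_t b) {
  assert(b != 0);
  int64_t q = a / b;
  int64_t r = a % b;
  // C++ '/' truncates toward zero; step down by one when the remainder is
  // non-zero and its sign differs from the divisor's sign.
  if (r != 0 && ((r < 0) != (b < 0))) --q;
  return q;
}

// Compares floordiv(x, c1) < c2 with x < c1 * c2 for every
// c1 in [1, range] and x, c2 in [-range, range].
// Returns true when the two predicates agree on every triple.
bool RewriteHoldsForPositiveDivisor(int64_t range) {
  for (int64_t c1 = 1; c1 <= range; ++c1)
    for (int64_t c2 = -range; c2 <= range; ++c2)
      for (int64_t x = -range; x <= range; ++x)
        if ((FloorDiv(x, c1) < c2) != (x < c1 * c2)) return false;
  return true;
}

// Same comparison with truncating division instead of floor division.
// Returns true when at least one counterexample exists in the range.
bool TruncDivBreaksRewrite(int64_t range) {
  for (int64_t c1 = 1; c1 <= range; ++c1)
    for (int64_t c2 = -range; c2 <= range; ++c2)
      for (int64_t x = -range; x <= range; ++x)
        if ((x / c1 < c2) != (x < c1 * c2)) return true;
  return false;
}
```

For instance, with `x = -1`, `c1 = 2`, `c2 = 0`, floor division gives `-1 < 0` (true, matching `x < 0`), while truncation gives `0 < 0` (false).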
For clarity, we use `SizeVar` `d0`, `d1`, `d2`, and `d3` for the shape in our
simple softmax example. The output of the current `tir.Simplify` is
```c++
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax",
"tir.noalias": True}
buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0:
int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32,
stride_2: int32, stride_3: int32], type="auto"),
T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32),
float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32,
stride_7: int32], type="auto")}
buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]),
storage_scope = global {
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")]
"thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")]
"thread_extent" = 512 {
if @tir.likely((floordiv(floordiv(((blockIdx.x*512) + threadIdx.x), d2),
d1) < d0), dtype=bool) {
if @tir.likely((floordiv(((blockIdx.x*512) + threadIdx.x), d2) <
(d0*d1)), dtype=bool) {
if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)),
dtype=bool) {
T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] =
-3.40282e+38f32
}
}
}
```
To simplify `floordiv(((blockIdx.x*512) + threadIdx.x), d2) < (d0*d1)` and
`floordiv(floordiv(((blockIdx.x*512) + threadIdx.x), d2), d1) < d0` to
`(blockIdx.x*512) + threadIdx.x < ((d0*d1)*d2)`, we add a new rule in `PrimExpr
RewriteSimplifier::Impl::VisitExpr_(const LTNode* op);`:
```c++
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const LTNode* op) {
// ...
PVar<PrimExpr> x, y, z, s1, s2;
// ...
TVM_TRY_REWRITE_IF(floordiv(x, s1) < s2, x < s1 * s2,
analyzer_->const_int_bound(s1.Eval())->min_value >= 0);
// ...
}
```
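As a sanity check on applying this rule twice to the nested condition, the sketch below brute-forces small strictly positive shapes and confirms that the three bound checks from the dump accept exactly the same flattened indices. It is plain C++ without TVM; `NestedBoundChecksAgree` is a hypothetical name, and the restriction to positive dimensions is an assumption of the sketch.

```cpp
#include <cassert>
#include <cstdint>

// Brute-force check that the three nested bound checks select the same
// flattened indices x. All operands are non-negative here, so C++ '/'
// coincides with floordiv. Dimensions are kept strictly positive.
bool NestedBoundChecksAgree(int64_t max_dim) {
  assert(max_dim >= 1);
  for (int64_t d0 = 1; d0 <= max_dim; ++d0) {
    for (int64_t d1 = 1; d1 <= max_dim; ++d1) {
      for (int64_t d2 = 1; d2 <= max_dim; ++d2) {
        const int64_t total = d0 * d1 * d2;
        // Go a little past `total`: the launch grid is rounded up, so some
        // threads fall outside the valid range and must be rejected.
        for (int64_t x = 0; x < total + 8; ++x) {
          const bool outer = (x / d2) / d1 < d0;  // floordiv(floordiv(x, d2), d1) < d0
          const bool middle = x / d2 < d0 * d1;   // floordiv(x, d2) < d0*d1
          const bool inner = x < total;           // x < (d0*d1)*d2
          if (outer != middle || middle != inner) return false;
        }
      }
    }
  }
  return true;
}
```

Calling `NestedBoundChecksAgree(6)` covers every shape with dimensions from 1 to 6.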
Here comes our first worry. The corresponding `IntImm` version of this rule is
```c++
TVM_TRY_REWRITE_IF(floordiv(x, c1) < c2, x < c1 * c2,
c1.Eval()->value > 0);
```
where the side condition is `> 0` instead of `>= 0`. For the `PrimExpr` version
we can only get non-negativity from the `ConstIntBoundAnalyzer` (this
bound information comes from simple facts like `SizeVar + SizeVar >= 0` or
`SizeVar * SizeVar >= 0`). Although `s1 = 0` is an invalid case (it is a
division by zero), the transformation is not equivalent there and may hide some
run-time errors, because the rewritten condition no longer contains the
division. [This
post](https://discuss.tvm.apache.org/t/discuss-embed-more-bound-information-into-var-or-expr/4079)
shows some possible solutions for this issue, but there may be simpler
ones (if you have any ideas, please share them with us).
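To make the `s1 = 0` concern concrete, here is a small standalone sketch (C++17, no TVM; `OriginalCheck` and `RewrittenCheck` are hypothetical names, and modelling the division-by-zero error as `std::nullopt` is an assumption made for illustration).

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Original predicate floordiv(x, s1) < s2 for non-negative x and s1.
// Returns std::nullopt when s1 == 0, standing in for the run-time error
// that a division by zero would raise.
std::optional<bool> OriginalCheck(int64_t x, int64_t s1, int64_t s2) {
  assert(x >= 0 && s1 >= 0);
  if (s1 == 0) return std::nullopt;
  return (x / s1) < s2;  // operands are non-negative, so '/' equals floordiv
}

// Rewritten predicate x < s1 * s2. It is defined for every input, so the
// zero divisor is no longer observable.
bool RewrittenCheck(int64_t x, int64_t s1, int64_t s2) {
  return x < s1 * s2;
}
```

With `x = 5`, `s1 = 0`, `s2 = 3`, `OriginalCheck` reports the error while `RewrittenCheck` quietly evaluates `5 < 0` and returns `false`, so the guarded statement is skipped instead of failing. For `s1 > 0` the two agree.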
After adding this rule to the rewrite simplifier, we get:
```c++
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax",
"tir.noalias": True}
buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0:
int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32,
stride_2: int32, stride_3: int32], type="auto"),
T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32),
float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32,
stride_7: int32], type="auto")}
buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]),
storage_scope = global {
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")]
"thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")]
"thread_extent" = 512 {
if @tir.likely((((blockIdx.x*512) + threadIdx.x) < (d2*(d1*d0))),
dtype=bool) {
if @tir.likely((((blockIdx.x*512) + threadIdx.x) < (d2*(d0*d1))),
dtype=bool) {
if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)),
dtype=bool) {
T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] =
-3.40282e+38f32
}
}
}
```
Now the question becomes how to recognize that `(d2*(d1*d0))`,
`(d2*(d0*d1))`, and `((d0*d1)*d2)` are equal. We add a rule in `PrimExpr
CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op);` to get a canonical
form of multiplication:
```c++
PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) {
// ...
// normalize
PrimExpr a = this->CanonicalMutate(op->a);
PrimExpr b = this->CanonicalMutate(op->b);
// ...
// var * expr => expr * var
if (a.as<VarNode>() && !b.as<VarNode>()) {
std::swap(a, b);
}
// if given var * var or expr * expr, use their
// structural hash value to sort
if (a.as<VarNode>() || !b.as<VarNode>()) {
auto ah = StructuralHash()(a);
auto bh = StructuralHash()(b);
if (ah > bh) {
std::swap(a, b);
}
}
// ...
}
```
I think sorting by structural hash value is a bit ugly, but I have no better
idea currently. After adding this rule to the canonical simplifier, we get:
```c++
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax",
"tir.noalias": True}
buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0:
int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32,
stride_2: int32, stride_3: int32], type="auto"),
T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32),
float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32,
stride_7: int32], type="auto")}
buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]),
storage_scope = global {
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")]
"thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")]
"thread_extent" = 512 {
if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)),
dtype=bool) {
if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)),
dtype=bool) {
if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)),
dtype=bool) {
T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] =
-3.40282e+38f32
}
}
}
```
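One possible alternative to hash-based ordering, sketched here only for discussion, is to flatten a nested product into its operands and sort them by a stable key. The example below is standalone C++ and does not use TVM types; `Expr`, `Var`, `Mul`, `Flatten`, and `CanonicalMul` are hypothetical, and sorting by variable name is an assumption that would not be safe as-is in TVM, where distinct variables can share a name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical mini expression tree: a named variable leaf or a product node.
struct Expr {
  std::string name;                // non-empty for a variable leaf
  std::shared_ptr<Expr> lhs, rhs;  // both set for a product node
};
using ExprPtr = std::shared_ptr<Expr>;

ExprPtr Var(const std::string& name) {
  return std::make_shared<Expr>(Expr{name, nullptr, nullptr});
}
ExprPtr Mul(ExprPtr a, ExprPtr b) {
  return std::make_shared<Expr>(Expr{"", std::move(a), std::move(b)});
}

// Collects the leaf names of a nested product into `out`.
void Flatten(const ExprPtr& e, std::vector<std::string>* out) {
  assert(e != nullptr);
  if (e->lhs) {
    Flatten(e->lhs, out);
    Flatten(e->rhs, out);
  } else {
    out->push_back(e->name);
  }
}

// Prints the product left-associated with its operands sorted by name.
std::string CanonicalMul(const ExprPtr& e) {
  std::vector<std::string> ops;
  Flatten(e, &ops);
  std::sort(ops.begin(), ops.end());
  std::string s = ops[0];
  for (std::size_t i = 1; i < ops.size(); ++i) s = "(" + s + "*" + ops[i] + ")";
  return s;
}
```

Under these assumptions, `d2*(d1*d0)`, `d2*(d0*d1)`, and `(d0*d1)*d2` all print as `((d0*d1)*d2)`.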
Next we need to find a way to remove these literally equivalent conditions.
Actually, `RewriteSimplifier` already has such a mechanism:
```c++
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
// add condition context to if_then_else
PrimExpr ret = IRMutatorWithAnalyzer::VisitExpr_(op);
op = ret.as<CallNode>();
// ...
ExprDeepEqual expr_equal;
if (op->op.same_as(tir::builtin::likely())) {
for (const auto& constraint : literal_constraints_) {
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
if (expr_equal(constraint, op->args[0])) {
return make_const(op->dtype, true);
}
}
}
return ret;
}
```
However, all `constraint`s in `literal_constraints_` have already been
processed by `CanonicalSimplify` by the time they are entered as constraints:
```c++
Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
PrimExpr condition = this->VisitExpr(op->condition); // HERE
PrimExpr real_condition = condition;
static auto op_likely = Op::Get("tir.likely");
if (auto call = condition.as<CallNode>()) {
if (call->op.same_as(op_likely)) {
real_condition = call->args[0];
}
}
Stmt then_case, else_case;
{
With<ConstraintContext> ctx(analyzer_, real_condition);
then_case = this->VisitStmt(op->then_case);
}
if (op->else_case.defined()) {
With<ConstraintContext> ctx(analyzer_,
analyzer_->rewrite_simplify(Not(real_condition)));
else_case = this->VisitStmt(op->else_case);
}
// ...
}
```
while `op->args[0]` has not, since we are still inside `RewriteSimplify` and
`CanonicalSimplify` runs only after it:
```c++
PrimExpr Analyzer::Simplify(const PrimExpr& expr, int steps) {
if (tir::is_const_int(expr)) return expr;
PrimExpr res = expr;
for (int i = 0; i < steps; ++i) {
res = this->rewrite_simplify(res); // RewriteSimplify
if (tir::is_const_int(res) || ++i == steps) return res; // is ++i proper here?
res = this->canonical_simplify(res); // CanonicalSimplify
if (tir::is_const_int(res)) return res;
}
return res;
}
```
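Regarding the `++i` question in the comment above: the extra increment makes each individual simplifier call count as one step, so `steps` is the total number of alternating passes. The sketch below mirrors only the control flow of that loop; `PassOrder` is a hypothetical helper, the simplifiers are replaced by log entries, and the `is_const_int` early exits are left out.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Records which simplifier would run at each step of the loop in
// Analyzer::Simplify for a given `steps` value.
std::vector<std::string> PassOrder(int steps) {
  assert(steps >= 0);
  std::vector<std::string> order;
  for (int i = 0; i < steps; ++i) {
    order.push_back("rewrite");
    // The rewrite pass consumed one step; stop if the budget is used up.
    if (++i == steps) return order;
    order.push_back("canonical");
    // The loop's own ++i accounts for the canonical pass.
  }
  return order;
}
```

With `steps = 2` the order is rewrite, then canonical; with `steps = 3` a second rewrite pass follows. In every case the rewrite simplifier runs before the canonical simplifier has seen the expression.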
Because of this ordering, `op->args[0]` looks something like `((threadIdx.x: int32 +
(blockIdx.x: int32*512)) < (((d1: int32*d0: int32)*d2: int32)*d3: int32))`
while the constraint looks like `(((blockIdx.x: int32*512) + threadIdx.x:
int32) < (((d1: int32*d0: int32)*d2: int32)*d3: int32))`. To solve this
problem, we can perform a canonical simplification before the comparison:
```c++
PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) {
// ...
ExprDeepEqual expr_equal;
if (op->op.same_as(tir::builtin::likely())) {
auto condition = analyzer_->canonical_simplify(op->args[0]);
for (const auto& constraint : literal_constraints_) {
// Cases such as for (i, 0, bound) {if (likely(iter_var < bound)) { .. } }
if (expr_equal(constraint, condition)) {
return make_const(op->dtype, true);
}
}
}
return ret;
}
```
After that we get:
```c++
primfn(placeholder_1: handle, T_softmax_norm_1: handle) -> ()
attr = {"from_legacy_te_schedule": True, "global_symbol": "fused_nn_softmax",
"tir.noalias": True}
buffers = {placeholder: Buffer(placeholder_2: Pointer(float32), float32, [d0:
int32, d1: int32, d2: int32, d3: int32], [stride: int32, stride_1: int32,
stride_2: int32, stride_3: int32], type="auto"),
T_softmax_norm: Buffer(T_softmax_norm_2: Pointer(float32),
float32, [d0, d1, d2, d3], [stride_4: int32, stride_5: int32, stride_6: int32,
stride_7: int32], type="auto")}
buffer_map = {placeholder_1: placeholder, T_softmax_norm_1: T_softmax_norm} {
allocate(T_softmax_maxelem: Pointer(global float32), float32, [d0, d1, d2]),
storage_scope = global {
attr [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x")]
"thread_extent" = floordiv((((d0*d1)*d2) + 511), 512);
attr [IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")]
"thread_extent" = 512 {
if @tir.likely((((blockIdx.x*512) + threadIdx.x) < ((d0*d1)*d2)),
dtype=bool) {
T_softmax_maxelem[((blockIdx.x*512) + threadIdx.x)] = -3.40282e+38f32
}
```
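The overall effect of the constraint check can be pictured with a toy model. This is a sketch under strong simplifying assumptions: conditions are plain strings that are already in canonical form, and `SimplifyNestedIfs` is a hypothetical function, not part of TVM. A condition that literally matches an enclosing one is treated as `true` and dropped.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Takes the conditions of a chain of nested if-statements, outermost first,
// and returns the conditions that remain after dropping every condition that
// literally equals an enclosing one.
std::vector<std::string> SimplifyNestedIfs(const std::vector<std::string>& conds) {
  std::vector<std::string> constraints;  // stand-in for literal_constraints_
  std::vector<std::string> kept;
  for (const auto& cond : conds) {
    bool implied = false;
    for (const auto& c : constraints) {
      if (c == cond) {  // stand-in for ExprDeepEqual on canonical forms
        implied = true;
        break;
      }
    }
    if (!implied) kept.push_back(cond);
    constraints.push_back(cond);  // entering the body adds the constraint
  }
  assert(kept.size() <= conds.size());
  return kept;
}
```

Feeding it three copies of the same canonical condition leaves a single condition, which matches the final dump above. If the copies differed only in operand order, nothing would be dropped, which is why the canonical simplification has to happen before the comparison.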
Again, this is only an experimental idea and there are still some issues to be
solved. If you have any better ideas, please feel free to suggest below.