Hi all,
Currently, we are working on tensorization for some abstract intrinsics. Once a computation is tensorized, another pass can do the HW intrinsic selection, so the tensorize step does not need to care about tail loop iterations (uneven split cases). Here is a test case that does this:

```
import tvm
from tvm import te


def intrin_vadd(xo, m, n):
    x = te.placeholder((n,), name="vx")
    y = te.placeholder((n,), name="vy")
    if m % n == 0:
        body = lambda i: x[i] + y[i]
    else:
        # Guard the tail: lanes past m produce 0, like the split condition.
        body = lambda i: tvm.tir.Select(
            xo * n + i < m, x[i] + y[i], tvm.tir.const(0, dtype=x.dtype)
        )
    z = te.compute(x.shape, body, name="z")

    def intrin_func(ins, outs):
        xx, yy = ins
        zz = outs[0]
        return tvm.tir.call_packed("vadd", xx, yy, zz)

    buffer_params = {"offset_factor": 16}
    return te.decl_tensor_intrin(z.op, intrin_func, default_buffer_params=buffer_params)


def add(m):
    x = te.placeholder((m,), name="x")
    y = te.placeholder((m,), name="y")
    z = te.compute(x.shape, lambda i: x[i] + y[i], name="z")
    return x, y, z


def check_cache_write(m, factor):
    x, y, z = add(m)
    s = te.create_schedule(z.op)
    _, _ = s[z].split(z.op.axis[0], factor=factor)

    z_global = s.cache_write(z, "global")
    xo, xi = z_global.op.axis

    cond = xo * factor + xi < m
    vadd = intrin_vadd(xo, m, factor)
    s[z_global].tensorize(xi, vadd)
    tvm.lower(s, [x, y, z])


check_cache_write(129, 16)
```

After the axis of `z` is split, a condition like `xo * factor + xi < m` guards the computation. I used the same expression to describe the intrinsic, but got a mismatch error:

```
File "~/apache/tvm/src/te/operation/tensorize.cc", line 336
TVMError:
---------------------------------------------------------------
An internal invariant was violated during the execution of TVM.
Please read TVM's error reporting guidelines.
More details can be found here: https://discuss.tvm.ai/t/error-reporting/7793.
---------------------------------------------------------------
  Check failed: expr_equal(lhs, rhs) == false: Failed to match the compute with TensorIntrin tensor_intrin's declaration provided= select((((i.outer.c*16) + i) < 129), (vx[i] + vy[i]), 0f), intrin= select((((i.outer.c*16) + i) < 129), (vx[i] + vy[i]), 0f)
```

The two compute bodies print identically, but they refer to different `i.outer.c` variables, because the objects have different addresses. After some investigation, I found that the `schedule.normalize()` pass rebases the outer loop var `xo` and replaces it with a new IterVarNode at a new address.

> ref: [tvm/schedule_dataflow_rewrite.cc at main · apache/tvm (github.com)](https://github.com/apache/tvm/blob/main/src/te/schedule/schedule_dataflow_rewrite.cc#L482)

The check then fails because the `i.outer.c` captured in the intrinsic declaration no longer matches the one in the compute body. The compute op, by contrast, does run this substitution on its own body.

> ref: [tvm/compute_op.cc at main · apache/tvm (github.com)](https://github.com/apache/tvm/blob/main/src/te/operation/compute_op.cc#L372) and [tvm/compute_op.cc at main · apache/tvm (github.com)](https://github.com/apache/tvm/blob/main/src/te/operation/compute_op.cc#L382)

Finally, I tried to fix this issue by changing this line [tvm/tensorize.cc at main · apache/tvm (github.com)](https://github.com/apache/tvm/blob/main/src/te/operation/tensorize.cc#L330) to

```
PrimExpr rhs = ana.Simplify(Substitute(intrin_compute->body[i], value_map));
```

and it works.

There is a similar issue here:

* [how to add split condition inside tensorize scope - Questions - Apache TVM Discuss](https://discuss.tvm.apache.org/t/how-to-add-split-condition-inside-tensorize-scope/790)

All in all, this issue raises some questions:

1. Does Tensorize impose this constraint deliberately?
2. Can we use Tensorize to cover the tail loop cases? We would process the tail part on the HW side or in another pass.
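To make the tail guard concrete: with `m = 129` and `factor = 16`, the split produces ceil(129 / 16) = 9 outer iterations, and only the first lane of the last chunk is in range. The sketch below is plain Python, not TVM. The helpers `masked_vadd_chunk` and `tiled_vadd` are hypothetical, and they only emulate what the `Select` in `intrin_vadd` computes for each outer iteration `xo`.

```python
import math


def masked_vadd_chunk(x, y, xo, m, n):
    """Hypothetical helper emulating one call of the guarded intrin_vadd body.

    Lane i of chunk xo covers global index xo * n + i. In-range lanes
    compute x + y; out-of-range lanes produce 0, mirroring
    tvm.tir.Select(xo * n + i < m, x[i] + y[i], 0).
    """
    out = []
    for i in range(n):
        g = xo * n + i
        # The condition is evaluated first, so x[g] is never read past m.
        out.append(x[g] + y[g] if g < m else 0.0)
    return out


def tiled_vadd(x, y, n):
    """Run the split loop: ceil(m / n) outer iterations of the masked chunk."""
    m = len(x)
    outer = math.ceil(m / n)
    z = []
    for xo in range(outer):
        z.extend(masked_vadd_chunk(x, y, xo, m, n))
    # In this toy, the padded lanes of the last chunk are simply dropped.
    return z[:m], outer


m, factor = 129, 16
x = [float(k) for k in range(m)]
y = [1.0] * m
z, outer = tiled_vadd(x, y, factor)
print(outer)                                       # 9
print(masked_vadd_chunk(x, y, 8, m, factor)[:2])   # [129.0, 0.0]
```

Every chunk has the same shape here. That uniformity is what lets a single intrinsic declaration cover the tail, leaving the masked lanes to the HW or a later pass.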
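The mismatch can be reproduced outside TVM with a toy expression model. This is a sketch under stated assumptions. `Var`, `BinOp`, `guard` and `substitute` are hypothetical stand-ins, not TVM's API. Variables compare by identity, much as the tensorize check compares var nodes, so two bodies can print identically and still be unequal. Substituting the stale variable first makes them match. That is roughly the idea behind the `Substitute(intrin_compute->body[i], value_map)` change, assuming `value_map` maps the old loop variable to its rebased replacement.

```python
from dataclasses import dataclass


class Var:
    """Hypothetical loop variable; no __eq__, so == falls back to identity."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


@dataclass(frozen=True)
class BinOp:
    """Hypothetical binary node; equality compares fields recursively."""

    op: str
    lhs: object
    rhs: object

    def __repr__(self):
        return f"({self.lhs!r} {self.op} {self.rhs!r})"


def substitute(expr, vmap):
    """Rebuild expr, replacing any Var that is a key of vmap (by identity)."""
    if isinstance(expr, Var):
        return vmap.get(expr, expr)
    if isinstance(expr, BinOp):
        return BinOp(expr.op, substitute(expr.lhs, vmap), substitute(expr.rhs, vmap))
    return expr  # constants are left unchanged


def guard(xo, i, n=16, m=129):
    """The split condition from the post: (xo * n + i) < m."""
    return BinOp("<", BinOp("+", BinOp("*", xo, n), i), m)


i = Var("i")
xo_original = Var("i.outer.c")  # captured when the intrinsic was declared
xo_rebased = Var("i.outer.c")   # fresh variable standing in for normalize()'s rebase

provided = guard(xo_rebased, i)  # compute body, already substituted
intrin = guard(xo_original, i)   # intrinsic body, never substituted

print(repr(provided) == repr(intrin))  # True: both print the same
print(provided == intrin)              # False: different Var objects
fixed = substitute(intrin, {xo_original: xo_rebased})
print(provided == fixed)               # True once the stale var is replaced
```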