To integrate a new intrinsic, I would like to apply a transformation to a TIR
schedule that inlines the addition of a bias into a matrix multiplication. I
have created a very simple example to reproduce my problem; let's assume the
following PrimFunc:
```python
@T.prim_func
def func(
    A: T.Buffer((16, 16), "int8"),
    B: T.Buffer((16, 16), "int8"),
    C: T.Buffer((16, 16), "int32"),
    D: T.Buffer((16, 16), "int32"),
) -> None:
    temp = T.alloc_buffer((16, 16), dtype="int32")
    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                temp[vi, vj] = T.int32(0)
            temp[vi, vj] = (
                temp[vi, vj]
                + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32")
            )
    for i, j in T.grid(16, 16):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = temp[vi, vj] + C[vi, vj]
```
I want to transform it to achieve the following:
```python
@T.prim_func
def expected_v1(
    A: T.Buffer((16, 16), "int8"),
    B: T.Buffer((16, 16), "int8"),
    C: T.Buffer((16, 16), "int32"),
    D: T.Buffer((16, 16), "int32"),
) -> None:
    temp = T.alloc_buffer((16, 16), dtype="int32")
    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                temp[vi, vj] = T.int32(0)
            temp[vi, vj] = (
                temp[vi, vj]
                + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32")
                + C[vi, vj]
            )
    for i, j in T.grid(16, 16):
        with T.block("add"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = temp[vi, vj]
```
Or, ideally:
```python
@T.prim_func
def expected_v2(
    A: T.Buffer((16, 16), "int8"),
    B: T.Buffer((16, 16), "int8"),
    C: T.Buffer((16, 16), "int32"),
    D: T.Buffer((16, 16), "int32"),
) -> None:
    temp = T.alloc_buffer((16, 16), dtype="int32")
    for i, j, k in T.grid(16, 16, 16):
        with T.block("multiply"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                D[vi, vj] = C[vi, vj]
            D[vi, vj] = D[vi, vj] + T.cast(A[vi, vk], "int32") * T.cast(B[vj, vk], "int32")
```
As you can see, all these computations are mathematically equivalent, so I
would expect there is some way of getting there, but everything I tried failed.
I tried compute_inline on the multiply block, reverse_compute_inline on the
add block, decompose_reduction followed by reverse_compute_inline...
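As a sanity check (this is my own NumPy sketch, not part of the schedule), the original `func` and `expected_v2` really do produce the same `D`: initializing the accumulator with the bias `C` instead of zero is the same as adding `C` after the reduction.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.integers(-128, 128, size=(16, 16), dtype=np.int8)
B = rng.integers(-128, 128, size=(16, 16), dtype=np.int8)
C = rng.integers(-1000, 1000, size=(16, 16), dtype=np.int32)

# func: temp[i, j] = sum_k A[i, k] * B[j, k], then D = temp + C
ref = A.astype(np.int32) @ B.astype(np.int32).T + C

# expected_v2: D is initialized to C, and the matmul is accumulated on top
v2 = C.copy()
for k in range(16):
    v2 += A[:, k].astype(np.int32)[:, None] * B[:, k].astype(np.int32)[None, :]

assert np.array_equal(ref, v2)
```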
Could someone confirm this is indeed not possible? And if that is the case,
why? These seem like valid transformations that should be possible in some
way, but I am probably missing the reason why they aren't.
Here is some example code showing part of what I tried. It fails with the
error: `The consumer block tir.Block#0 to be inlined is required to have only
a single producer block, and the producer block should be a complete block who
has only a single consumer`:
```python
import tvm
from tvm import tir

if __name__ == "__main__":
    sch = tir.Schedule(func)
    mult_block = sch.get_block("multiply")
    init_block = sch.decompose_reduction(mult_block, sch.get_loops(mult_block)[-1])
    update_block = sch.get_block("multiply_update")
    add_block = sch.get_block("add")
    sch.cache_write(add_block, 0, "local")
    sch.reverse_compute_inline(add_block)
```
---
[Visit
Topic](https://discuss.tvm.apache.org/t/tir-problem-inlining-addition-into-matmul-block/18066/1)
to respond.