Starting from this IR:
```python
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A: T.Buffer((T.int64(1), T.int64(784)), "float32"), B:
T.Buffer((T.int64(16), T.int64(784)), "float32"),
T_matmul: T.Buffer((T.int64(1), T.int64(16)), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
for ax0, ax1, k in T.grid(T.int64(1), T.int64(16), T.int64(784)):
with T.block("T_matmul"):
v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k])
T.reads(A[v_ax0, v_k], B[v_ax1, v_k])
T.writes(T_matmul[v_ax0, v_ax1])
with T.init():
T_matmul[v_ax0, v_ax1] = T.float32(0)
T_matmul[v_ax0, v_ax1] = T_matmul[v_ax0, v_ax1] + A[v_ax0, v_k]
* B[v_ax1, v_k]
```
I would like to apply a padding of 19 on the batch axis:
```
sch = tvm.tir.Schedule(Module)
block = sch.get_block("T_matmul")
sch.pad_einsum(block, [19, 0, 0])
```
However I get this error:
```
ScheduleError: An error occurred in the schedule primitive 'pad-einsum'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
def main(A_handle: T.handle, B_handle: T.handle, T_matmul_handle: T.handle):
T.func_attr({"tir.noalias": T.bool(True)})
A = T.match_buffer(A_handle, (T.int64(1), T.int64(784)))
B = T.match_buffer(B_handle, (T.int64(16), T.int64(784)))
T_matmul = T.match_buffer(T_matmul_handle, (T.int64(1), T.int64(16)))
with T.block("root"):
T.reads()
T.writes()
for ax0 in range(T.int64(1)):
for ax1 in range(T.int64(16)):
for k in range(T.int64(784)):
# tir.Block#0
with T.block("T_matmul"):
^^^^^^^^^^^^^^^^^^^^^^^^^
v_ax0 = T.axis.spatial(T.int64(1), ax0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v_ax1 = T.axis.spatial(T.int64(16), ax1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
v_k = T.axis.reduce(T.int64(784), k)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.reads(A[v_ax0, v_k], B[v_ax1, v_k])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T.writes(T_matmul[v_ax0, v_ax1])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
with T.init():
^^^^^^^^^^^^^^
T_matmul[v_ax0, v_ax1] = T.float32(0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
T_matmul[v_ax0, v_ax1] = T_matmul[v_ax0, v_ax1] +
A[v_ax0, v_k] * B[v_ax1, v_k]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Error message: The padding for the block tir.Block#0 are invalid. It should be
a list of 3 positive integers. Got [19, 0, 0]
```
However, changing the padding to something like `[19, 1, 1]` does work.
Why does `[19, 0, 0]` fail?
How can I apply padding only on the batch axis in this case?
---
[Visit Topic](https://discuss.tvm.apache.org/t/tir-schedule-pad-einsum-not-working-with-padding-only-on-the-batch-axis/15500/1) to respond.
You are receiving this because you enabled mailing list mode.
To unsubscribe from these emails, [click here](https://discuss.tvm.apache.org/email/unsubscribe/5d8c24934c19ddaa415ce53037ddcd2755e47c58bdbd3ef053ad9a182b5f4eb0).