Starting from this IR:

```python
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((T.int64(1), T.int64(784)), "float32"), B: T.Buffer((T.int64(16), T.int64(784)), "float32"), T_matmul: T.Buffer((T.int64(1), T.int64(16)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, k in T.grid(T.int64(1), T.int64(16), T.int64(784)):
            with T.block("T_matmul"):
                v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k])
                T.reads(A[v_ax0, v_k], B[v_ax1, v_k])
                T.writes(T_matmul[v_ax0, v_ax1])
                with T.init():
                    T_matmul[v_ax0, v_ax1] = T.float32(0)
                T_matmul[v_ax0, v_ax1] = T_matmul[v_ax0, v_ax1] + A[v_ax0, v_k] * B[v_ax1, v_k]
```

I would like to apply a padding of 19 on the batch axis:

```python
import tvm

# Module is the IRModule defined above.
sch = tvm.tir.Schedule(Module)
block = sch.get_block("T_matmul")
sch.pad_einsum(block, [19, 0, 0])
```

However, I get this error:

```
ScheduleError: An error occurred in the schedule primitive 'pad-einsum'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A_handle: T.handle, B_handle: T.handle, T_matmul_handle: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        A = T.match_buffer(A_handle, (T.int64(1), T.int64(784)))
        B = T.match_buffer(B_handle, (T.int64(16), T.int64(784)))
        T_matmul = T.match_buffer(T_matmul_handle, (T.int64(1), T.int64(16)))
        with T.block("root"):
            T.reads()
            T.writes()
            for ax0 in range(T.int64(1)):
                for ax1 in range(T.int64(16)):
                    for k in range(T.int64(784)):
                        # tir.Block#0
                        with T.block("T_matmul"):
                        ^^^^^^^^^^^^^^^^^^^^^^^^^
                            v_ax0 = T.axis.spatial(T.int64(1), ax0)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            v_ax1 = T.axis.spatial(T.int64(16), ax1)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            v_k = T.axis.reduce(T.int64(784), k)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T.reads(A[v_ax0, v_k], B[v_ax1, v_k])
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T.writes(T_matmul[v_ax0, v_ax1])
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            with T.init():
                            ^^^^^^^^^^^^^^
                                T_matmul[v_ax0, v_ax1] = T.float32(0)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T_matmul[v_ax0, v_ax1] = T_matmul[v_ax0, v_ax1] + A[v_ax0, v_k] * B[v_ax1, v_k]
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Error message: The padding for the block tir.Block#0 are invalid. It should be a list of 3 positive integers. Got [19, 0, 0]
```

Changing the padding value to something like `[19, 1, 1]` does work, however.

Why does `[19, 0, 0]` fail? How can I apply padding solely on the batch axis in this case?
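
One reading that fits both the error message and the fact that `[19, 1, 1]` is accepted: each entry may be a factor that the corresponding block iterator's extent is rounded up to a multiple of, rather than a number of elements to add. That is an assumption about the primitive's semantics, not something confirmed in this thread. The helper `padded_extents` below is hypothetical and is not part of TVM; it only sketches the arithmetic in plain Python.

```python
def padded_extents(extents, factors):
    """Round each extent up to the nearest multiple of its factor.

    Hypothetical helper modelling the ASSUMED round-up-to-multiple
    semantics of the padding list; it does not call TVM.
    """
    if len(extents) != len(factors):
        raise ValueError("need one factor per block iterator")
    if any(f <= 0 for f in factors):
        # Mirrors the "positive integers" requirement in the error message.
        raise ValueError(f"factors must be positive integers, got {factors}")
    # -(-e // f) is integer ceiling division.
    return [-(-e // f) * f for e, f in zip(extents, factors)]


# Extents of (ax0, ax1, k) in the T_matmul block above.
extents = [1, 16, 784]
print(padded_extents(extents, [19, 1, 1]))
```

Under that assumption, a factor of 1 leaves an axis unchanged, so `[19, 1, 1]` would pad only the batch axis, while a factor of 0 has no meaningful multiple and is rejected.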