Starting from this IR:
```python
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A: T.Buffer((T.int64(1), T.int64(784)), "float32"),
             B: T.Buffer((T.int64(16), T.int64(784)), "float32"),
             T_matmul: T.Buffer((T.int64(1), T.int64(16)), "float32")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        for ax0, ax1, k in T.grid(T.int64(1), T.int64(16), T.int64(784)):
            with T.block("T_matmul"):
                v_ax0, v_ax1, v_k = T.axis.remap("SSR", [ax0, ax1, k])
                T.reads(A[v_ax0, v_k], B[v_ax1, v_k])
                T.writes(T_matmul[v_ax0, v_ax1])
                with T.init():
                    T_matmul[v_ax0, v_ax1] = T.float32(0)
                T_matmul[v_ax0, v_ax1] = T_matmul[v_ax0, v_ax1] + A[v_ax0, v_k] * B[v_ax1, v_k]
```

I would like to apply a padding of 19 on the batch axis:
```
sch = tvm.tir.Schedule(Module)
block = sch.get_block("T_matmul")
sch.pad_einsum(block, [19, 0, 0])
```
However I get this error:
```
ScheduleError: An error occurred in the schedule primitive 'pad-einsum'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func
    def main(A_handle: T.handle, B_handle: T.handle, T_matmul_handle: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        A = T.match_buffer(A_handle, (T.int64(1), T.int64(784)))
        B = T.match_buffer(B_handle, (T.int64(16), T.int64(784)))
        T_matmul = T.match_buffer(T_matmul_handle, (T.int64(1), T.int64(16)))
        with T.block("root"):
            T.reads()
            T.writes()
            for ax0 in range(T.int64(1)):
                for ax1 in range(T.int64(16)):
                    for k in range(T.int64(784)):
                        # tir.Block#0
                        with T.block("T_matmul"):
                        ^^^^^^^^^^^^^^^^^^^^^^^^^
                            v_ax0 = T.axis.spatial(T.int64(1), ax0)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            v_ax1 = T.axis.spatial(T.int64(16), ax1)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            v_k = T.axis.reduce(T.int64(784), k)
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T.reads(A[v_ax0, v_k], B[v_ax1, v_k])
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T.writes(T_matmul[v_ax0, v_ax1])
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            with T.init():
                            ^^^^^^^^^^^^^^
                                T_matmul[v_ax0, v_ax1] = T.float32(0)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            T_matmul[v_ax0, v_ax1] = T_matmul[v_ax0, v_ax1] + A[v_ax0, v_k] * B[v_ax1, v_k]
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Error message: The padding for the block tir.Block#0 are invalid. It should be 
a list of 3 positive integers. Got [19, 0, 0]
```

However, changing the padding value to something like `[19, 1, 1]` does work.
Why does the `[19, 0, 0]` case fail?
How can I apply padding solely on the batch axis in this case?





---
[Visit Topic](https://discuss.tvm.apache.org/t/tir-schedule-pad-einsum-not-working-with-padding-only-on-the-batch-axis/15500/1) to respond.

You are receiving this because you enabled mailing list mode.

To unsubscribe from these emails, [click here](https://discuss.tvm.apache.org/email/unsubscribe/5d8c24934c19ddaa415ce53037ddcd2755e47c58bdbd3ef053ad9a182b5f4eb0).

Reply via email to