Hi,

I am writing a simple GEMM kernel and speeding it up with the TVM auto-scheduler.
However, I am having trouble understanding what the lowered TIR is doing and how
to interpret what the auto-scheduler did to accelerate the code. I have pasted
the baseline lowered TIR and the accelerated version below. It would help a lot
if someone could point me to documentation on lowered TIR, or help me interpret
these scripts!

    # The code that I am trying to accelerate
    from tvm import te, auto_scheduler

    @auto_scheduler.register_workload  # Note the auto_scheduler decorator
    def matmul(M, N, K, dtype):
        A = te.placeholder((M, K), name="A", dtype=dtype)
        B = te.placeholder((K, N), name="B", dtype=dtype)

        k = te.reduce_axis((0, K), name="k")
        matmul = te.compute(
            (M, N),
            lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
            name="matmul",
            attrs={"layout_free_placeholders": [B]},  # enable automatic layout 
transform for tensor B
        )

        return [A, B, matmul]
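
In case it helps, this is roughly how I generate and print the two scripts (a
sketch; the target, trial count, and log-file name are placeholders rather than
my exact settings):

    import tvm
    from tvm import auto_scheduler

    target = tvm.target.Target("llvm")
    task = auto_scheduler.SearchTask(
        func=matmul, args=(2048, 2048, 2048, "float32"), target=target
    )

    # Tune, keeping the measurement records in a log file.
    tune_option = auto_scheduler.TuningOptions(
        num_measure_trials=1000,
        measure_callbacks=[auto_scheduler.RecordToFile("matmul.json")],
    )
    task.tune(tune_option)

    # Apply the best schedule found and print the lowered TIR.
    sch, args = task.apply_best("matmul.json")
    print(tvm.lower(sch, args, simple_mode=True))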

    Base script:
        @main = primfn(A_1: handle, B_1: handle, matmul_1: handle) -> ()
          attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
          buffers = {matmul: Buffer(matmul_2: Pointer(uint8), uint8, [32, 128], []),
                     B: Buffer(B_2: Pointer(uint8), uint8, [64, 128], []),
                     A: Buffer(A_2: Pointer(uint8), uint8, [32, 64], [])}
          buffer_map = {A_1: A, B_1: B, matmul_1: matmul} {
          for (i: int32, 0, 32) {
            for (j: int32, 0, 128) {
              matmul_2[((i*128) + j)] = 0u8
              for (k: int32, 0, 64) {
                let cse_var_1: int32 = ((i*128) + j)
                matmul_2[cse_var_1] = ((uint8*)matmul_2[cse_var_1] + ((uint8*)A_2[((i*64) + k)]*(uint8*)B_2[((k*128) + j)]))
              }
            }
          }
        }
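
If I read it correctly, the base script is just the naive triple loop. In plain
NumPy-flavored Python it would be something like this (a sketch using the uint8
shapes printed above):

    import numpy as np

    # Naive reading of the base TIR: a (32 x 64) @ (64 x 128) uint8 matmul,
    # initializing each output element and accumulating over k in place.
    A = np.random.randint(0, 255, (32, 64)).astype("uint8")
    B = np.random.randint(0, 255, (64, 128)).astype("uint8")
    matmul = np.empty((32, 128), dtype="uint8")

    for i in range(32):
        for j in range(128):
            matmul[i, j] = 0
            for k in range(64):
                matmul[i, j] += A[i, k] * B[k, j]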

    Autotuned script:
        @main = primfn(A_1: handle, B_1: handle, matmul_1: handle) -> ()
          attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True}
          buffers = {matmul: Buffer(matmul_2: Pointer(float32), float32, [2048, 2048], []),
                     B: Buffer(B_2: Pointer(float32), float32, [2048, 2048], []),
                     A: Buffer(A_2: Pointer(float32), float32, [2048, 2048], [])}
          buffer_map = {A_1: A, B_1: B, matmul_1: matmul} {
          allocate(auto_scheduler_layout_transform: Pointer(global float32), float32, [4194304]), storage_scope = global {
            for (ax0.ax1.fused.ax2.fused: int32, 0, 32) "parallel" {
              for (ax4: int32, 0, 64) {
                for (ax6: int32, 0, 32) {
                  for (ax7: int32, 0, 64) {
                    auto_scheduler_layout_transform[((((ax0.ax1.fused.ax2.fused*131072) + (ax4*2048)) + (ax6*64)) + ax7)] = (float32*)B_2[((((ax4*65536) + (ax6*2048)) + (ax0.ax1.fused.ax2.fused*64)) + ax7)]
                  }
                }
              }
            }
            for (i.outer.outer.j.outer.outer.fused.i.outer.inner.fused: int32, 0, 512) "parallel" {
              allocate(matmul.local: Pointer(local float32), float32, [2048]), storage_scope = local;
              for (j.outer.inner: int32, 0, 4) {
                for (i.c.outer.inner.init: int32, 0, 32) {
                  for (j.c.inner.init: int32, 0, 64) {
                    matmul.local[((i.c.outer.inner.init*64) + j.c.inner.init)] = 0f32
                  }
                }
                for (k.outer: int32, 0, 64) {
                  for (i.c.outer.inner: int32, 0, 32) {
                    for (k.inner: int32, 0, 32) {
                      for (j.c.inner: int32, 0, 64) {
                        let cse_var_1: int32 = ((i.c.outer.inner*64) + j.c.inner)
                        matmul.local[cse_var_1] = ((float32*)matmul.local[cse_var_1] + ((float32*)A_2[(((((floordiv(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 256)*2097152) + (floormod(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 32)*65536)) + (i.c.outer.inner*2048)) + (k.outer*32)) + k.inner)]*(float32*)auto_scheduler_layout_transform[(((((floordiv(floormod(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 256), 32)*524288) + (j.outer.inner*131072)) + (k.outer*2048)) + (k.inner*64)) + j.c.inner)]))
                      }
                    }
                  }
                }
                for (i.inner: int32, 0, 32) {
                  for (j.inner: int32, 0, 64) {
                    matmul_2[((((((floordiv(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 256)*2097152) + (floormod(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 32)*65536)) + (i.inner*2048)) + (floordiv(floormod(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 256), 32)*256)) + (j.outer.inner*64)) + j.inner)] = (float32*)matmul.local[((i.inner*64) + j.inner)]
                  }
                }
              }
            }
          }
        }
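
For reference, here is my tentative reading of the autotuned script as a
Python-style sketch, with the fused loop variable unpacked into
(i_oo, j_oo, i_oi). The tile sizes come from the loop extents above, but the
index arithmetic is my own reconstruction and may be off:

    import numpy as np

    M = N = K = 2048
    A = np.zeros((M, K), dtype="float32")
    B = np.zeros((K, N), dtype="float32")
    out = np.empty((M, N), dtype="float32")

    # Step 1: pack B into auto_scheduler_layout_transform, i.e.
    # B_packed[jo, ko, ki, ji] = B[ko*32 + ki, jo*64 + ji]
    # (the first "parallel" loop nest; 32*64*32*64 = 4194304 elements).
    B_packed = np.empty((32, 64, 32, 64), dtype="float32")
    for jo in range(32):
        for ko in range(64):
            for ki in range(32):
                for ji in range(64):
                    B_packed[jo, ko, ki, ji] = B[ko * 32 + ki, jo * 64 + ji]

    # Step 2: the 512-iteration "parallel" loop is the fusion of
    # (i_oo, j_oo, i_oi) = (2, 8, 32); the floordiv/floormod expressions in
    # the TIR recover the three components. Each fused iteration fills a
    # 32x256 output block through a 32x64 local accumulator (matmul.local).
    for i_oo in range(2):
        for j_oo in range(8):
            for i_oi in range(32):
                for j_oi in range(4):
                    local = np.zeros((32, 64), dtype="float32")
                    for k_o in range(64):
                        for i_c in range(32):
                            for k_i in range(32):
                                for j_c in range(64):
                                    i = (i_oo * 32 + i_oi) * 32 + i_c
                                    local[i_c, j_c] += (
                                        A[i, k_o * 32 + k_i]
                                        * B_packed[j_oo * 4 + j_oi, k_o, k_i, j_c]
                                    )
                    # Write the accumulator back to the global output.
                    for i_i in range(32):
                        for j_i in range(64):
                            i = (i_oo * 32 + i_oi) * 32 + i_i
                            j = (j_oo * 4 + j_oi) * 64 + j_i
                            out[i, j] = local[i_i, j_i]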
