Hi,
I am writing a simple GEMM kernel and speeding it up with the TVM auto-scheduler. However, I am having some difficulty understanding what the lowered TIR is doing and how to interpret what the auto-scheduler did to accelerate the code. I have pasted the base lowered TIR script and the accelerated version below. It would help a lot if someone could point me to documentation on lowered TIR scripts or help me interpret them! # The code that I am trying to accelerate @auto_scheduler.register_workload # Note the auto_scheduler decorator def matmul(M, N, K, dtype): A = te.placeholder((M, K), name="A", dtype=dtype) B = te.placeholder((K, N), name="B", dtype=dtype) k = te.reduce_axis((0, K), name="k") matmul = te.compute( (M, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="matmul", attrs={"layout_free_placeholders": [B]}, # enable automatic layout transform for tensor B ) return [A, B, matmul] Base script: @main = primfn(A_1: handle, B_1: handle, matmul_1: handle) -> () attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True} buffers = {matmul: Buffer(matmul_2: Pointer(uint8), uint8, [32, 128], []), B: Buffer(B_2: Pointer(uint8), uint8, [64, 128], []), A: Buffer(A_2: Pointer(uint8), uint8, [32, 64], [])} buffer_map = {A_1: A, B_1: B, matmul_1: matmul} { for (i: int32, 0, 32) { for (j: int32, 0, 128) { matmul_2[((i*128) + j)] = 0u8 for (k: int32, 0, 64) { let cse_var_1: int32 = ((i*128) + j) matmul_2[cse_var_1] = ((uint8*)matmul_2[cse_var_1] + ((uint8*)A_2[((i*64) + k)]*(uint8*)B_2[((k*128) + j)])) } } } } Autotuned script: @main = primfn(A_1: handle, B_1: handle, matmul_1: handle) -> () attr = {"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True} buffers = {matmul: Buffer(matmul_2: Pointer(float32), float32, [2048, 2048], []), B: Buffer(B_2: Pointer(float32), float32, [2048, 2048], []), A: Buffer(A_2: Pointer(float32), float32, [2048, 2048], [])} buffer_map = {A_1: A, B_1: B, matmul_1: matmul} { 
allocate(auto_scheduler_layout_transform: Pointer(global float32), float32, [4194304]), storage_scope = global { for (ax0.ax1.fused.ax2.fused: int32, 0, 32) "parallel" { for (ax4: int32, 0, 64) { for (ax6: int32, 0, 32) { for (ax7: int32, 0, 64) { auto_scheduler_layout_transform[((((ax0.ax1.fused.ax2.fused*131072) + (ax4*2048)) + (ax6*64)) + ax7)] = (float32*)B_2[((((ax4*65536) + (ax6*2048)) + (ax0.ax1.fused.ax2.fused*64)) + ax7)] } } } } for (i.outer.outer.j.outer.outer.fused.i.outer.inner.fused: int32, 0, 512) "parallel" { allocate(matmul.local: Pointer(local float32), float32, [2048]), storage_scope = local; for (j.outer.inner: int32, 0, 4) { for (i.c.outer.inner.init: int32, 0, 32) { for (j.c.inner.init: int32, 0, 64) { matmul.local[((i.c.outer.inner.init*64) + j.c.inner.init)] = 0f32 } } for (k.outer: int32, 0, 64) { for (i.c.outer.inner: int32, 0, 32) { for (k.inner: int32, 0, 32) { for (j.c.inner: int32, 0, 64) { let cse_var_1: int32 = ((i.c.outer.inner*64) + j.c.inner) matmul.local[cse_var_1] = ((float32*)matmul.local[cse_var_1] + ((float32*)A_2[(((((floordiv(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 256)*2097152) + (floormod(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 32)*65536)) + (i.c.outer.inner*2048)) + (k.outer*32)) + k.inner)]*(float32*)auto_scheduler_layout_transform[(((((floordiv(floormod(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 256), 32)*524288) + (j.outer.inner*131072)) + (k.outer*2048)) + (k.inner*64)) + j.c.inner)])) } } } } for (i.inner: int32, 0, 32) { for (j.inner: int32, 0, 64) { matmul_2[((((((floordiv(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 256)*2097152) + (floormod(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 32)*65536)) + (i.inner*2048)) + (floordiv(floormod(i.outer.outer.j.outer.outer.fused.i.outer.inner.fused, 256), 32)*256)) + (j.outer.inner*64)) + j.inner)] = (float32*)matmul.local[((i.inner*64) + j.inner)] } } } } } } --- [Visit 
Topic](https://discuss.tvm.apache.org/t/understanding-lower-tir/12977/1) to respond. You are receiving this because you enabled mailing list mode. To unsubscribe from these emails, [click here](https://discuss.tvm.apache.org/email/unsubscribe/2b6fb096b94de45ccfac86869274cd7d9dbd522ef913b3649dea3dafea3f8dad).