Completely agree with these perspectives. Another observation of mine is that projects built on TVM are often not straightforward: they typically require hacking the underlying TVM code. For example, in the Ladder project (based on Welder), we added MFMA and HIP code generation support to TVM and introduced our own FuseOps pass in C++. In the BitBLAS project, we introduced two or three additional ugly schedules to hack certain layout operations in order to achieve better performance ([Slides](https://leiblog.wang/static/2024-09-16/BitBLAS-Slides-20240911.pdf)), for example: https://github.com/microsoft/BitBLAS/blob/main/bitblas/gpu/matmul_mma_dequantize.py#L1376-L1390
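For readers who haven't dealt with this, here is a minimal, hypothetical sketch (plain upstream TIR, not the actual BitBLAS schedule) of what schedule-level layout rewriting looks like; the tiling index map is made up purely for illustration:

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def copy(A: T.Buffer((128, 128), "float16"), B: T.Buffer((128, 128), "float16")):
    for i, j in T.grid(128, 128):
        with T.block("copy"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]


sch = tvm.tir.Schedule(copy)
blk = sch.get_block("copy")
# Rewrite the read buffer into 16x16 tiles so that a later tensorized
# (e.g. MMA) stage sees the layout it expects; the mapping is illustrative only.
sch.transform_layout(blk, ("read", 0), lambda i, j: (i // 16, j // 16, i % 16, j % 16))
print(sch.mod.script())
```

The point is that every extra layout requirement tends to turn into another hand-maintained schedule of this shape.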
We also realized that relying on schedules made it difficult to describe some operators, such as FlashAttention, T-MAC, and Stream-K. Therefore, we designed some syntactic sugar for TIR so that it can be used as a Triton-like language (turning schedules into annotations such as Pipeline and Layout Transform). For example, Triton has a hard time describing dequantize GEMM, but with our sugared syntax we can dispatch the dequantize part to thread-level programming instead of Triton-style block-level programming:

```python
@T.prim_func
def main(
    A: T.Buffer(A_shape, dtypeAB),
    B: T.Buffer(B_shape, storage_dtype),
    C: T.Buffer((M, N), dtypeC),
):
    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
        A_shared = T.alloc_shared(A_shared_shape, dtypeAB)
        B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
        B_local = T.alloc_fragment([8], storage_dtype, "local")
        B_dequantize_local = T.alloc_fragment([16], dtypeAB, "local")
        B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, dtypeAB)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

        T.annotate_layout(
            {
                A_shared: make_swizzle_layout(A_shared),
                B_shared: make_swizzle_layout(B_shared),
            }
        )

        # Improve L2 Cache
        T.use_swizzle(panel_size=10)

        t = T.thread_binding(0, threads, thread="threadIdx.x")

        T.clear(C_local)
        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
            T.copy(A[by * block_M, k * block_K], A_shared)

            # Load the quantized B tile into shared memory
            for i, j in T.Parallel(block_N, block_K // num_elems_per_byte):
                B_shared[i, j] = B[bx * block_N + i, k * block_K // num_elems_per_byte + j]

            # Per-thread dequantization: shared -> registers -> dequantized shared
            for i in T.serial(block_N * block_K // num_elems_per_byte // (threads * 4)):
                for v in T.vectorized(0, 4):
                    vi = (i * threads * 4 + t * 4 + v) // (block_K // num_elems_per_byte)
                    vj = (i * threads * 4 + t * 4 + v) % (block_K // num_elems_per_byte)
                    B_local[v] = B_shared[vi, vj]
                for v in T.serial(0, 8):
                    B_dequantize_local[v] = _tir_packed_to_unsigned_convert("int", 8)(
                        num_bits,
                        B_local[v // 2],
                        v % 2,
                        dtype=dtypeAB,
                    )
                for v in T.vectorized(0, 8):
                    vi = (i * threads * 8 + t * 8 + v) // block_K
                    vj = (i * threads * 8 + t * 8 + v) % block_K
                    B_dequantize_shared[vi, vj] = B_dequantize_local[v]

            T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)

        T.copy(C_local, C[by * block_M, bx * block_N])
```

(It's awesome that T.Parallel can automatically map the thread binding and do vectorization while still building on the infrastructure from the TIR schedule transformations; T.Pipelined comes from the software pipelining pass, and annotate_layout from the LayoutTransformation pass.)

Anyway, the issue with all of these projects I was involved in is that they rely on different versions (or modifications) of TVM, and since the changes were often made as hotfixes to ship quickly (some of the hacks are ugly and inelegant), it is difficult to merge them upstream. One idea I have is that third-party developers could keep maintaining their own versions of TVM for development, but use a unified IRModule (TIR) plus Relax as the interface between projects. However, I ran into some problems while trying to implement this approach, such as conflicts when loading DLLs from different versions of TVM, so I don't know whether this is a worthwhile path.
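As a rough sketch of what the "unified IRModule as the interface" could look like (this is just the standard TVMScript printer/parser round-trip, not something that exists as a feature today): each project keeps its own TVM fork, but modules are handed over as serialized TVMScript text rather than by loading each other's shared libraries:

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def add_one(A: T.Buffer((16,), "float32"), B: T.Buffer((16,), "float32")):
    for i in T.serial(16):
        with T.block("add"):
            vi = T.axis.remap("S", [i])
            B[vi] = A[vi] + T.float32(1)


# Producer side (project X, built on its own TVM fork): export as plain text.
mod = tvm.IRModule({"main": add_one})
text = mod.script()

# Consumer side (project Y, built on a different fork): re-parse the text with
# its own parser, so no C++ objects or DLLs are shared across TVM versions.
reloaded = tvm.script.from_source(text)
assert isinstance(reloaded, tvm.IRModule)
```

This would sidestep the DLL-loading conflicts mentioned above, but it only works as long as the forks keep their TVMScript front ends compatible, which is exactly the part that tends to diverge.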