Asuka0630 opened a new pull request, #18528:
URL: https://github.com/apache/tvm/pull/18528

   Dear reviewers
   
   **Why**
   When forcing the use of MMA with MultiLevelTilingTensorCore, or when directly 
applying tensorization via the script below, the required shared memory size is 
significantly overestimated compared to the actual usage; at the same time, the 
accumulated result of MMA is also incorrect. This issue stems from two root 
causes:
   
   1. In `MmaToGlobal::Rewrite`, an extra threadIdx.x dimension is introduced 
when calling InsertCacheStage, which confuses the memory analysis and leads to 
inflated shared memory estimates.
   2. In `get_mma_sync_intrin`, the offset computation for fragment C in 
get_index_C is incorrect, resulting in erroneous accumulation results.
   
   This PR addresses both issues to ensure accurate shared memory estimation 
and correct tensor core accumulation behavior.
   ``` python
   import tvm
   import numpy as np
   from tvm.script import tir as T
   from tvm.tir.schedule import Schedule
   import tvm.tir.tensor_intrin  # pylint: disable=unused-import
   import tvm.testing
   import torch
   
   import pytest
   
   M, N, K = 4096, 4096, 4096
   np.random.seed(0)
   
   
   @tvm.script.ir_module
   class Gemm_F16F16F16:
       # fmt: off
       @T.prim_func
       def main(
           A: T.Buffer((M, K), "float16"),  # type: ignore
           B: T.Buffer((K, N), "float16"),  # type: ignore
           C: T.Buffer((M, N), "float16"),  # type: ignore
       ):
           for i, j, k in T.grid(M, N, K):
               with T.block("C"):
                   vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                   with T.init():
                       C[vi, vj] = T.float32(0)
                   C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
   
   
   @tvm.script.ir_module
   class Gemm_F16F16F32:
       # fmt: off
       @T.prim_func
       def main(
           A: T.Buffer((M, K), "float16"),  # type: ignore
           B: T.Buffer((K, N), "float16"),  # type: ignore
           C: T.Buffer((M, N), "float32"),  # type: ignore
       ):
           for i, j, k in T.grid(M, N, K):
               with T.block("C"):
                   vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                   with T.init():
                       C[vi, vj] = T.float32(0)
                   C[vi, vj] = C[vi, vj] + T.cast(A[vi, vk], "float32") * 
T.cast(B[vk, vj], "float32")
   
   
   def test_run_target(mod=None, tgt_str=None, in_dtype="float16", 
out_dtype="float16"):
       if mod is None:
           return
       tgt_str = tgt_str or "cuda"
       target = tvm.target.Target(target=tgt_str)
       with tvm.transform.PassContext(opt_level=3):
           # lib: tvm.runtime.Module = tvm.build(mod, target=target)
           lib: tvm.runtime.Module = tvm.compile(mod, target=target)
   
       dev = tvm.device(tgt_str, 0)
       a_np = np.random.rand(M, K).astype(in_dtype)
       b_np = np.random.rand(K, N).astype(in_dtype)
       c_np = np.ones((M, N), dtype=out_dtype)
       a = tvm.runtime.tensor(a_np, dev)
       b = tvm.runtime.tensor(b_np, dev)
       c = tvm.runtime.tensor(c_np, dev)
   
       f = lib["main"]
       f(a, b, c)
   
       c_th = torch.matmul(
           torch.tensor(a_np).to(tgt_str), torch.tensor(b_np).to(tgt_str)
       ).to(torch.float32 if out_dtype == "float32" else torch.float16)
       c_f = torch.tensor(c.numpy()).to(tgt_str)
       print(torch.allclose(c_th, c_f, rtol=0.05, atol=0.05))
   
   
   @tvm.testing.requires_cuda
   def test_f16f16f16_mma_gemm():
       # fmt: off
       mod = Gemm_F16F16F16
       sch = Schedule(mod)
       b0 = sch.get_block(name="C", func_name="main")
       b1 = sch.get_block(name="root", func_name="main")
       sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", 
ann_val="SSSRRSRS")
       b2 = sch.reindex(block=b0, buffer=("write", 0))
       b3 = sch.reindex(block=b0, buffer=("read", 0))
       b4 = sch.reindex(block=b0, buffer=("read", 1))
       sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda vi, 
vk: (vi, vk,), pad_value=None, assume_injective_transform=True)
       sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda vj, 
vk: (vk, vj,), pad_value=None, assume_injective_transform=True)
       sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda vi, 
vj: (vi, vj,), pad_value=None, assume_injective_transform=True)
       sch.transform_block_layout(block=b2, index_map=lambda vi, vj: (vi, vj,))
       sch.transform_block_layout(block=b3, index_map=lambda vi, vk: (vi, vk,))
       sch.transform_block_layout(block=b4, index_map=lambda vj, vk: (vk, vj,))
       sch.transform_block_layout(block=b0, index_map=lambda vi, vj, vk: (vi, 
vj, vk,))
       l5, l6, l7 = sch.get_loops(block=b0)
       l8, l9 = sch.split(loop=l7, factors=[None, 8], preserve_unit_iters=True, 
disable_predication=False)
       l10, l11 = sch.split(loop=l6, factors=[None, 8], 
preserve_unit_iters=True, disable_predication=False)
       l12, l13 = sch.split(loop=l5, factors=[None, 16], 
preserve_unit_iters=True, disable_predication=False)
       l14, l15, l16, l17, l18, l19 = sch.get_loops(block=b0)
       sch.reorder(l16, l18, l13, l11, l9)
       b20 = sch.blockize(target=l13, preserve_unit_iters=True)
       sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize", 
ann_val="mma_sync_m16n8k8_f16f16f16")
       sch.annotate(block_or_loop=b20, 
ann_key="meta_schedule.auto_tensorize_init", ann_val="mma_init_m16n8k8_f16")
       sch.annotate(block_or_loop=b20, ann_key="warp_execution", ann_val=1)
       l21, l22, l23 = sch.get_loops(block=b20)
       v24, v25, v26, v27, v28 = sch.sample_partitioned_tile(loop=l21, n=5, 
partition_pos=3, innerpart_factor=2, decision=[2, 16, 4, 1, 2])
       l29, l30, l31, l32, l33 = sch.split(loop=l21, factors=[v24, v25, v26, 
v27, v28], preserve_unit_iters=True, disable_predication=False)
       v34, v35, v36, v37, v38 = sch.sample_partitioned_tile(loop=l22, n=5, 
partition_pos=3, innerpart_factor=4, decision=[2, 16, 4, 1, 4])
       l39, l40, l41, l42, l43 = sch.split(loop=l22, factors=[v34, v35, v36, 
v37, v38], preserve_unit_iters=True, disable_predication=False)
       v44, v45, v46 = sch.sample_perfect_tile(loop=l23, n=3, 
max_innermost_factor=4, decision=[128, 1, 4])
       l47, l48, l49 = sch.split(loop=l23, factors=[v44, v45, v46], 
preserve_unit_iters=True, disable_predication=False)
       sch.reorder(l29, l39, l30, l40, l31, l41, l47, l48, l32, l42, l49, l33, 
l43)
       l50 = sch.fuse(l29, l39, preserve_unit_iters=True)
       sch.bind(loop=l50, thread_axis="blockIdx.y")
       l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
       sch.bind(loop=l51, thread_axis="blockIdx.x")
       l52 = sch.fuse(l31, l41, preserve_unit_iters=True)
       sch.bind(loop=l52, thread_axis="threadIdx.y")
       sch.annotate(block_or_loop=b20, 
ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)
       sch.annotate(block_or_loop=b20, 
ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)
       b53 = sch.write_at(loop=l52, block=b20, write_buffer_index=0, 
storage_scope="m16n8k8.matrixC")
       sch.reverse_compute_inline(block=b2)
       b54 = sch.read_at(loop=l47, block=b20, read_buffer_index=0, 
storage_scope="shared.dyn")
       sch.annotate(block_or_loop=b54, ann_key="permuted_layout", 
ann_val="g2s_A")
       b55 = sch.read_at(loop=l47, block=b20, read_buffer_index=1, 
storage_scope="shared.dyn")
       sch.annotate(block_or_loop=b55, ann_key="permuted_layout", 
ann_val="g2s_B")
       b56 = sch.cache_read(block=b20, read_buffer_index=0, 
storage_scope="m16n8k8.matrixA")
       sch.compute_at(block=b56, loop=l48, preserve_unit_loops=True, index=-1)
       l57, l58, l59, l60, l61, l62, l63 = sch.get_loops(block=b56)
       l64, l65 = sch.split(loop=l63, factors=[None, 8], 
preserve_unit_iters=True, disable_predication=False)
       l66, l67 = sch.split(loop=l62, factors=[None, 32], 
preserve_unit_iters=True, disable_predication=False)
       l68, l69, l70, l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b56)
       sch.reorder(l75, l67, l65)
       b77 = sch.blockize(target=l67, preserve_unit_iters=True)
       sch.annotate(block_or_loop=b77, ann_key="meta_schedule.auto_tensorize", 
ann_val="mma_load_m16n8k8_f16_A_shared_dyn")
       sch.annotate(block_or_loop=b77, ann_key="permuted_layout", 
ann_val="s2l_A")
       b78 = sch.cache_read(block=b20, read_buffer_index=1, 
storage_scope="m16n8k8.matrixB")
       sch.compute_at(block=b78, loop=l48, preserve_unit_loops=True, index=-1)
       l79, l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b78)
       l86, l87 = sch.split(loop=l85, factors=[None, 32], 
preserve_unit_iters=True, disable_predication=False)
       l88, l89 = sch.split(loop=l84, factors=[None, 8], 
preserve_unit_iters=True, disable_predication=False)
       l90, l91, l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b78)
       sch.reorder(l97, l89, l87)
       b99 = sch.blockize(target=l89, preserve_unit_iters=True)
       sch.annotate(block_or_loop=b99, ann_key="meta_schedule.auto_tensorize", 
ann_val="mma_load_m16n8k8_f16_B_shared_dyn")
       sch.annotate(block_or_loop=b99, ann_key="permuted_layout", 
ann_val="s2l_B")
       b100, = sch.get_producers(block=b54)
       sch.compute_inline(block=b100)
       sch.storage_align(block=b54, buffer_index=0, axis=-2, factor=32, 
offset=8)
       b101, = sch.get_producers(block=b55)
       sch.compute_inline(block=b101)
       sch.storage_align(block=b55, buffer_index=0, axis=-2, factor=32, 
offset=8)
       sch.annotate(block_or_loop=b54, ann_key="vector_bytes", ann_val=16)
       sch.annotate(block_or_loop=b55, ann_key="vector_bytes", ann_val=16)
       sch.annotate(block_or_loop=l48, ann_key="software_pipeline_stage", 
ann_val=[0, 0, 1])
       sch.annotate(block_or_loop=l48, ann_key="software_pipeline_order", 
ann_val=[0, 1, 2])
       sch.annotate(block_or_loop=l47, 
ann_key="software_pipeline_async_stages", ann_val=[0])
       sch.annotate(block_or_loop=l47, ann_key="software_pipeline_stage", 
ann_val=[0, 0, 1, 2, 2])
       sch.annotate(block_or_loop=l47, ann_key="software_pipeline_order", 
ann_val=[0, 1, 3, 2, 4])
       v102 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], 
probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 
0.20000000000000001, 0.20000000000000001], decision=0)
       sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", 
ann_val=v102)
       sch.enter_postproc()
       b103 = sch.get_block(name="root", func_name="main")
       sch.unannotate(block_or_loop=b103, 
ann_key="meta_schedule.unroll_explicit")
       b104, b105, b106, b107, b108, b109 = sch.get_child_blocks(b103)
       l110, l111, l112, l113 = sch.get_loops(block=b104)
       l114, l115, l116, l117 = sch.get_loops(block=b105)
       l118, l119, l120, l121, l122, l123, l124 = sch.get_loops(block=b106)
       l125, l126, l127, l128, l129, l130, l131 = sch.get_loops(block=b107)
       l132, l133, l134, l135, l136, l137, l138, l139, l140, l141 = 
sch.get_loops(block=b108)
       l142, l143, l144 = sch.get_loops(block=b109)
       b145 = sch.get_block(name="C_o", func_name="main")
       l146, l147, l148, l149, l150, l151, l152, l153, l154, l155 = 
sch.get_loops(block=b145)
       b156 = sch.decompose_reduction(block=b145, loop=l149)
       sch.unannotate(block_or_loop=b156, 
ann_key="meta_schedule.auto_tensorize")
       sch.annotate(block_or_loop=b156, ann_key="meta_schedule.auto_tensorize", 
ann_val="mma_init_m16n8k8_f16")
       sch.unannotate(block_or_loop=b145, 
ann_key="meta_schedule.auto_tensorize_init")
       sch.unannotate(block_or_loop=b156, 
ann_key="meta_schedule.auto_tensorize_init")
       b157 = sch.get_block(name="C_o_init", func_name="main")
       sch.unannotate(block_or_loop=b157, 
ann_key="meta_schedule.auto_tensorize")
       sch.tensorize(block_or_loop=b157, tensor_intrin="mma_init_m16n8k8_f16", 
preserve_unit_iters=True)
       b158 = sch.get_block(name="A_reindex_shared.dyn_m16n8k8.matrixA_o", 
func_name="main")
       sch.unannotate(block_or_loop=b158, 
ann_key="meta_schedule.auto_tensorize")
       sch.tensorize(block_or_loop=b158, 
tensor_intrin="mma_load_m16n8k8_f16_A_shared_dyn", preserve_unit_iters=True)
       b159 = sch.get_block(name="B_reindex_shared.dyn_m16n8k8.matrixB_o", 
func_name="main")
       sch.unannotate(block_or_loop=b159, 
ann_key="meta_schedule.auto_tensorize")
       sch.tensorize(block_or_loop=b159, 
tensor_intrin="mma_load_m16n8k8_f16_B_shared_dyn", preserve_unit_iters=True)
       b160 = sch.get_block(name="C_o_update", func_name="main")
       sch.unannotate(block_or_loop=b160, 
ann_key="meta_schedule.auto_tensorize")
       sch.tensorize(block_or_loop=b160, 
tensor_intrin="mma_sync_m16n8k8_f16f16f16", preserve_unit_iters=True)
       mod = sch.mod
       test_run_target(mod)
   
   
   @tvm.testing.requires_cuda
   def test_f16f16f32_mma_gemm():
       mod = Gemm_F16F16F32
       sch = Schedule(mod)
       # fmt: off
       sch = Schedule(mod)
       b0 = sch.get_block(name="C", func_name="main")
       b1 = sch.get_block(name="root", func_name="main")
       sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", 
ann_val="SSSRRSRS")
       b2 = sch.reindex(block=b0, buffer=("write", 0))
       b3 = sch.reindex(block=b0, buffer=("read", 0))
       b4 = sch.reindex(block=b0, buffer=("read", 1))
       sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda vi, 
vk: (vi, vk,), pad_value=None, assume_injective_transform=True)
       sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda vj, 
vk: (vk, vj,), pad_value=None, assume_injective_transform=True)
       sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda vi, 
vj: (vi, vj,), pad_value=None, assume_injective_transform=True)
       sch.transform_block_layout(block=b2, index_map=lambda vi, vj: (vi, vj,))
       sch.transform_block_layout(block=b3, index_map=lambda vi, vk: (vi, vk,))
       sch.transform_block_layout(block=b4, index_map=lambda vj, vk: (vk, vj,))
       sch.transform_block_layout(block=b0, index_map=lambda vi, vj, vk: (vi, 
vj, vk,))
       l5, l6, l7 = sch.get_loops(block=b0)
       l8, l9 = sch.split(loop=l7, factors=[None, 8], preserve_unit_iters=True, 
disable_predication=False)
       l10, l11 = sch.split(loop=l6, factors=[None, 8], 
preserve_unit_iters=True, disable_predication=False)
       l12, l13 = sch.split(loop=l5, factors=[None, 16], 
preserve_unit_iters=True, disable_predication=False)
       l14, l15, l16, l17, l18, l19 = sch.get_loops(block=b0)
       sch.reorder(l16, l18, l13, l11, l9)
       b20 = sch.blockize(target=l13, preserve_unit_iters=True)
       sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize", 
ann_val="mma_sync_m16n8k8_f16f16f32")
       sch.annotate(block_or_loop=b20, 
ann_key="meta_schedule.auto_tensorize_init", ann_val="mma_init_m16n8k8_f32")
       sch.annotate(block_or_loop=b20, ann_key="warp_execution", ann_val=1)
       l21, l22, l23 = sch.get_loops(block=b20)
       v24, v25, v26, v27, v28 = sch.sample_partitioned_tile(loop=l21, n=5, 
partition_pos=3, innerpart_factor=2, decision=[1, 16, 2, 2, 4])
       l29, l30, l31, l32, l33 = sch.split(loop=l21, factors=[v24, v25, v26, 
v27, v28], preserve_unit_iters=True, disable_predication=False)
       v34, v35, v36, v37, v38 = sch.sample_partitioned_tile(loop=l22, n=5, 
partition_pos=3, innerpart_factor=4, decision=[2, 16, 2, 4, 2])
       l39, l40, l41, l42, l43 = sch.split(loop=l22, factors=[v34, v35, v36, 
v37, v38], preserve_unit_iters=True, disable_predication=False)
       v44, v45, v46 = sch.sample_perfect_tile(loop=l23, n=3, 
max_innermost_factor=4, decision=[128, 1, 4])
       l47, l48, l49 = sch.split(loop=l23, factors=[v44, v45, v46], 
preserve_unit_iters=True, disable_predication=False)
       sch.reorder(l29, l39, l30, l40, l31, l41, l47, l48, l32, l42, l49, l33, 
l43)
       l50 = sch.fuse(l29, l39, preserve_unit_iters=True)
       sch.bind(loop=l50, thread_axis="blockIdx.y")
       l51 = sch.fuse(l30, l40, preserve_unit_iters=True)
       sch.bind(loop=l51, thread_axis="blockIdx.x")
       l52 = sch.fuse(l31, l41, preserve_unit_iters=True)
       sch.bind(loop=l52, thread_axis="threadIdx.y")
       sch.annotate(block_or_loop=b20, 
ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)
       sch.annotate(block_or_loop=b20, 
ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)
       b53 = sch.write_at(loop=l52, block=b20, write_buffer_index=0, 
storage_scope="m16n8k8.matrixC")
       sch.reverse_compute_inline(block=b2)
       b54 = sch.read_at(loop=l47, block=b20, read_buffer_index=0, 
storage_scope="shared.dyn")
       sch.annotate(block_or_loop=b54, ann_key="permuted_layout", 
ann_val="g2s_A")
       b55 = sch.read_at(loop=l47, block=b20, read_buffer_index=1, 
storage_scope="shared.dyn")
       sch.annotate(block_or_loop=b55, ann_key="permuted_layout", 
ann_val="g2s_B")
       b56 = sch.cache_read(block=b20, read_buffer_index=0, 
storage_scope="m16n8k8.matrixA")
       sch.compute_at(block=b56, loop=l48, preserve_unit_loops=True, index=-1)
       l57, l58, l59, l60, l61, l62, l63 = sch.get_loops(block=b56)
       l64, l65 = sch.split(loop=l63, factors=[None, 8], 
preserve_unit_iters=True, disable_predication=False)
       l66, l67 = sch.split(loop=l62, factors=[None, 32], 
preserve_unit_iters=True, disable_predication=False)
       l68, l69, l70, l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b56)
       sch.reorder(l75, l67, l65)
       b77 = sch.blockize(target=l67, preserve_unit_iters=True)
       sch.annotate(block_or_loop=b77, ann_key="meta_schedule.auto_tensorize", 
ann_val="mma_load_m16n8k8_f16_A_shared_dyn")
       sch.annotate(block_or_loop=b77, ann_key="permuted_layout", 
ann_val="s2l_A")
       b78 = sch.cache_read(block=b20, read_buffer_index=1, 
storage_scope="m16n8k8.matrixB")
       sch.compute_at(block=b78, loop=l48, preserve_unit_loops=True, index=-1)
       l79, l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b78)
       l86, l87 = sch.split(loop=l85, factors=[None, 32], 
preserve_unit_iters=True, disable_predication=False)
       l88, l89 = sch.split(loop=l84, factors=[None, 8], 
preserve_unit_iters=True, disable_predication=False)
       l90, l91, l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b78)
       sch.reorder(l97, l89, l87)
       b99 = sch.blockize(target=l89, preserve_unit_iters=True)
       sch.annotate(block_or_loop=b99, ann_key="meta_schedule.auto_tensorize", 
ann_val="mma_load_m16n8k8_f16_B_shared_dyn")
       sch.annotate(block_or_loop=b99, ann_key="permuted_layout", 
ann_val="s2l_B")
       b100, = sch.get_producers(block=b54)
       sch.compute_inline(block=b100)
       sch.storage_align(block=b54, buffer_index=0, axis=-2, factor=32, 
offset=8)
       b101, = sch.get_producers(block=b55)
       sch.compute_inline(block=b101)
       sch.storage_align(block=b55, buffer_index=0, axis=-2, factor=32, 
offset=8)
       sch.annotate(block_or_loop=b54, ann_key="vector_bytes", ann_val=16)
       sch.annotate(block_or_loop=b55, ann_key="vector_bytes", ann_val=16)
       sch.annotate(block_or_loop=l48, ann_key="software_pipeline_stage", 
ann_val=[0, 0, 1])
       sch.annotate(block_or_loop=l48, ann_key="software_pipeline_order", 
ann_val=[0, 1, 2])
       sch.annotate(block_or_loop=l47, 
ann_key="software_pipeline_async_stages", ann_val=[0])
       sch.annotate(block_or_loop=l47, ann_key="software_pipeline_stage", 
ann_val=[0, 0, 1, 2, 2])
       sch.annotate(block_or_loop=l47, ann_key="software_pipeline_order", 
ann_val=[0, 1, 3, 2, 4])
       v102 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], 
probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 
0.20000000000000001, 0.20000000000000001], decision=0)
       sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", 
ann_val=v102)
       sch.enter_postproc()
       b103 = sch.get_block(name="root", func_name="main")
       sch.unannotate(block_or_loop=b103, 
ann_key="meta_schedule.unroll_explicit")
       b104, b105, b106, b107, b108, b109 = sch.get_child_blocks(b103)
       l110, l111, l112, l113 = sch.get_loops(block=b104)
       l114, l115, l116, l117 = sch.get_loops(block=b105)
       l118, l119, l120, l121, l122, l123, l124 = sch.get_loops(block=b106)
       l125, l126, l127, l128, l129, l130, l131 = sch.get_loops(block=b107)
       l132, l133, l134, l135, l136, l137, l138, l139, l140, l141 = 
sch.get_loops(block=b108)
       sch.annotate(block_or_loop=l132, ann_key="pragma_auto_unroll_max_step", 
ann_val=0)
       sch.annotate(block_or_loop=l132, ann_key="pragma_unroll_explicit", 
ann_val=1)
       l142, l143, l144 = sch.get_loops(block=b109)
       b145 = sch.get_block(name="C_o", func_name="main")
       l146, l147, l148, l149, l150, l151, l152, l153, l154, l155 = 
sch.get_loops(block=b145)
       b156 = sch.decompose_reduction(block=b145, loop=l149)
       sch.unannotate(block_or_loop=b156, 
ann_key="meta_schedule.auto_tensorize")
       sch.annotate(block_or_loop=b156, ann_key="meta_schedule.auto_tensorize", 
ann_val="mma_init_m16n8k8_f32")
       sch.unannotate(block_or_loop=b145, 
ann_key="meta_schedule.auto_tensorize_init")
       sch.unannotate(block_or_loop=b156, 
ann_key="meta_schedule.auto_tensorize_init")
       b157 = sch.get_block(name="C_o_init", func_name="main")
       sch.unannotate(block_or_loop=b157, 
ann_key="meta_schedule.auto_tensorize")
       sch.tensorize(block_or_loop=b157, tensor_intrin="mma_init_m16n8k8_f32", 
preserve_unit_iters=True)
       b158 = sch.get_block(name="A_reindex_shared.dyn_m16n8k8.matrixA_o", 
func_name="main")
       sch.unannotate(block_or_loop=b158, 
ann_key="meta_schedule.auto_tensorize")
       sch.tensorize(block_or_loop=b158, 
tensor_intrin="mma_load_m16n8k8_f16_A_shared_dyn", preserve_unit_iters=True)
       b159 = sch.get_block(name="B_reindex_shared.dyn_m16n8k8.matrixB_o", 
func_name="main")
       sch.unannotate(block_or_loop=b159, 
ann_key="meta_schedule.auto_tensorize")
       sch.tensorize(block_or_loop=b159, 
tensor_intrin="mma_load_m16n8k8_f16_B_shared_dyn", preserve_unit_iters=True)
       b160 = sch.get_block(name="C_o_update", func_name="main")
       sch.unannotate(block_or_loop=b160, 
ann_key="meta_schedule.auto_tensorize")
       sch.tensorize(block_or_loop=b160, 
tensor_intrin="mma_sync_m16n8k8_f16f16f32", preserve_unit_iters=True)
       mod = sch.mod
       test_run_target(mod, out_dtype="float32")
   
   
   if __name__ == """__main__""":
       test_f16f16f16_mma_gemm()
       test_f16f16f32_mma_gemm()
   ```
   
   **How**
   This PR includes the following fixes:
   
   1. Skip the threadIdx.x dimension in `InsertCacheStage` when it is not 
required, to prevent spurious shared memory overestimation and repeated stores.
   2. Correct the offset calculation for fragment C in `get_index_C` to ensure 
accurate accumulation results during tensor core execution.
   
   **Result**
   The above script produces results that match those of PyTorch.
   
   **Env**
   NVIDIA A100-SXM4-80GB
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to