In my schedule there are two ops: one computes the result with a GEMM, and the other reshapes it. The function looks like this:

```
for (i.outer.outer, 0, 98) {
  for (j.outer.outer, 0, 16) {
    for (ii, 0, 8) {
      for (jj, 0, 8) {
        gemm_C[((((i.outer.outer*1024) + (j.outer.outer*64)) + (ii*8)) + jj)] = gemm_C.wmma.accumulator[((((i.outer.outer*1024) + (j.outer.outer*64)) + (ii*8)) + jj)]
      }
    }
  }
}
for (n.oh.fused.ow.fused.outer.outer.outer, 0, 98) {
  for (oc.outer.outer.outer, 0, 16) {
    for (n.oh.fused.ow.fused.inner, 0, 8) {
      for (oc.inner, 0, 8) {
        output[((((n.oh.fused.ow.fused.outer.outer.outer*1024) + (n.oh.fused.ow.fused.inner*128)) + (oc.outer.outer.outer*8)) + oc.inner)] = gemm_C[((((n.oh.fused.ow.fused.outer.outer.outer*1024) + (oc.outer.outer.outer*64)) + (n.oh.fused.ow.fused.inner*8)) + oc.inner)]
      }
    }
  }
}
```

I want these two operations to end up in the same kernel, with the `gemm_C` result stored in shared memory. I first bind the output axes to blocks and threads:

```
for (i, 0, 98) {
  for (j, 0, 16) {
    for (ii, 0, 8) {
      for (jj, 0, 8) {
        gemm_C[((((i*1024) + (j*64)) + (ii*8)) + jj)] = gemm_C.wmma.accumulator[((((i*1024) + (j*64)) + (ii*8)) + jj)]
      }
    }
  }
}
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 98
// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 16
// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
// attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
for (n.oh.fused.ow.fused.inner, 0, 8) {
  for (oc.inner, 0, 8) {
    output[((((blockIdx.x*1024) + (n.oh.fused.ow.fused.inner*128)) + (blockIdx.y*8)) + oc.inner)] = gemm_C[((((blockIdx.x*1024) + (blockIdx.y*64)) + (n.oh.fused.ow.fused.inner*8)) + oc.inner)]
  }
}
```
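For reference, a minimal sketch of the kind of schedule I am describing is below. This is not my real code: the placeholder `acc`, the plain-copy `gemm_C` stage, and the shapes (98 × 16 tiles of 8 × 8) are reconstructed from the loop extents in the dumps above, and the real accumulator stage is a tensorized wmma GEMM rather than a copy.

```python
import tvm
from tvm import te

idxd = tvm.tir.indexdiv
idxm = tvm.tir.indexmod

# Stand-in for the wmma accumulator result: 98 x 16 tiles of 8 x 8
# (shapes taken from the loop extents in the dumps above).
acc = te.placeholder((98, 16, 8, 8), name="acc", dtype="float16")

# Stage 1: copy the accumulator out into gemm_C (first loop nest above).
# In the real schedule this stage is the tensorized GEMM, not a plain copy.
gemm_C = te.compute((98, 16, 8, 8),
                    lambda i, j, ii, jj: acc[i, j, ii, jj],
                    name="gemm_C")

# Stage 2: "reshape" gemm_C's tiled layout into the flat output
# (second loop nest above).
output = te.compute(
    (98 * 8, 16 * 8),
    lambda x, y: gemm_C[idxd(x, 8), idxd(y, 8), idxm(x, 8), idxm(y, 8)],
    name="output")

s = te.create_schedule(output.op)

# Bind the output axes as in the bound IR above:
# 98 blocks in x, 16 blocks in y, one 8 x 8 tile of work per block.
x, y = output.op.axis
bx, xi = s[output].split(x, factor=8)
by, yi = s[output].split(y, factor=8)
s[output].reorder(bx, by, xi, yi)
s[output].bind(bx, te.thread_axis("blockIdx.x"))
s[output].bind(by, te.thread_axis("blockIdx.y"))

print(tvm.lower(s, [acc, output], simple_mode=True))
```

The `threadIdx.y`/`threadIdx.z` attrs with extent 1 in my dump come from the rest of the real schedule and are omitted from this sketch.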
Then I try to set the scope for `gemm_C`, either with `s[gemm_C].set_scope('shared')` or with `compute_at()`. Both methods give a result like this:

```
for (i, 0, 98) {
  for (j, 0, (16 - blockIdx.y)) {
    for (ii, 0, 8) {
      for (jj, 0, 8) {
        if (likely(((j + blockIdx.y) < 16))) {
          gemm_C[(((((i*(16 - blockIdx.y))*64) + (j*64)) + (ii*8)) + jj)] = gemm_C.wmma.accumulator[(((((i*(16 - blockIdx.y))*64) + (j*64)) + (ii*8)) + jj)]
        }
      }
    }
  }
}
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 98
// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 16
// attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
// attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
for (n.oh.fused.ow.fused.inner, 0, 8) {
  for (oc.inner, 0, 8) {
    output[((((blockIdx.x*1024) + (n.oh.fused.ow.fused.inner*128)) + (blockIdx.y*8)) + oc.inner)] = gemm_C[((((blockIdx.x*(16 - blockIdx.y))*64) + (n.oh.fused.ow.fused.inner*8)) + oc.inner)]
  }
}
```

The `j` axis of `gemm_C` is now inferred as `(j, 0, (16 - blockIdx.y))`, and because of this odd bound inference I can no longer bind that axis to `blockIdx.y`.

Am I going about this the right way to achieve my goal? What are the possible reasons for the `iter_var` bounds to be inferred like this, and how should I solve this problem?
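For completeness, the scope-related calls I am trying, written against the same sketch above (so `gemm_C`, `output`, and `by` are the placeholder handles from that sketch, with `by` bound to `blockIdx.y`), are roughly:

```python
# Variant 1: only move gemm_C into shared memory.
s[gemm_C].set_scope("shared")

# Variant 2: additionally compute the gemm_C stage under output's blockIdx.y axis.
s[gemm_C].compute_at(s[output], by)
```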