In my schedule there are two ops: one computes the result with a GEMM, and the other reshapes it. The lowered IR looks like this:
```
  for (i.outer.outer, 0, 98) {
    for (j.outer.outer, 0, 16) {
      for (ii, 0, 8) {
        for (jj, 0, 8) {
          gemm_C[((((i.outer.outer*1024) + (j.outer.outer*64)) + (ii*8)) + jj)] = gemm_C.wmma.accumulator[((((i.outer.outer*1024) + (j.outer.outer*64)) + (ii*8)) + jj)]
        }
      }
    }
  }

  for (n.oh.fused.ow.fused.outer.outer.outer, 0, 98) {
    for (oc.outer.outer.outer, 0, 16) {
      for (n.oh.fused.ow.fused.inner, 0, 8) {
        for (oc.inner, 0, 8) {
          output[((((n.oh.fused.ow.fused.outer.outer.outer*1024) + (n.oh.fused.ow.fused.inner*128)) + (oc.outer.outer.outer*8)) + oc.inner)] = gemm_C[((((n.oh.fused.ow.fused.outer.outer.outer*1024) + (oc.outer.outer.outer*64)) + (n.oh.fused.ow.fused.inner*8)) + oc.inner)]
        }
      }
    }
  }
```
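Roughly, the two stages correspond to something like the following (a sketch with placeholder shapes and names, not my exact definitions): the tensor-core GEMM produces `gemm_C` in a `(98, 16, 8, 8)` tiled layout, and `output` un-tiles it back to a plain `(784, 128)` matrix.

```python
from tvm import te

# Sketch only: this placeholder stands in for the real tensor-core GEMM stage.
gemm_C = te.placeholder((98, 16, 8, 8), name="gemm_C")

# The reshape stage: un-tile (98, 16, 8, 8) back to (784, 128).
output = te.compute(
    (784, 128),
    lambda r, c: gemm_C[r // 8, c // 8, r % 8, c % 8],
    name="output",
)
```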

I want these two operations to be fused into the same kernel, with the `gemm_C` result kept in shared memory. I first bind the output axes to blocks and threads.
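The bind step looks roughly like this (a sketch; `fused_outer` and `oc_outer` are placeholders for the actual axis variables coming out of my earlier split/fuse calls):

```python
from tvm import te

# Sketch only: `s` is the schedule and `output` is the reshape stage;
# `fused_outer` (98 iterations) and `oc_outer` (16 iterations) are placeholders
# for the axis variables returned by my earlier split/fuse calls.
block_x = te.thread_axis("blockIdx.x")
block_y = te.thread_axis("blockIdx.y")
s[output].bind(fused_outer, block_x)
s[output].bind(oc_outer, block_y)
# threadIdx.y / threadIdx.z are also bound, each with extent 1 (omitted here).
```

After binding, the lowered IR looks like this: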
```
 for (i, 0, 98) {
    for (j, 0, 16) {
      for (ii, 0, 8) {
        for (jj, 0, 8) {
          gemm_C[((((i*1024) + (j*64)) + (ii*8)) + jj)] = gemm_C.wmma.accumulator[((((i*1024) + (j*64)) + (ii*8)) + jj)]
        }
      }
    }
  }

  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 98
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 16
  // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
  // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
  for (n.oh.fused.ow.fused.inner, 0, 8) {
    for (oc.inner, 0, 8) {
      output[((((blockIdx.x*1024) + (n.oh.fused.ow.fused.inner*128)) + (blockIdx.y*8)) + oc.inner)] = gemm_C[((((blockIdx.x*1024) + (blockIdx.y*64)) + (n.oh.fused.ow.fused.inner*8)) + oc.inner)]
    }
  }
```
Then I try to give `gemm_C` shared scope, either with `s[gemm_C].set_scope('shared')` or with `compute_at()` (both calls are sketched after the IR below). Both approaches produce a result like this:

```
for (i, 0, 98) {
    for (j, 0, (16 - blockIdx.y)) {
      for (ii, 0, 8) {
        for (jj, 0, 8) {
          if (likely(((j + blockIdx.y) < 16))) {
            gemm_C[(((((i*(16 - blockIdx.y))*64) + (j*64)) + (ii*8)) + jj)] = gemm_C.wmma.accumulator[(((((i*(16 - blockIdx.y))*64) + (j*64)) + (ii*8)) + jj)]
          }
        }
      }
    }
  }

  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 98
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 16
  // attr [iter_var(threadIdx.y, , threadIdx.y)] thread_extent = 1
  // attr [iter_var(threadIdx.z, , threadIdx.z)] thread_extent = 1
  for (n.oh.fused.ow.fused.inner, 0, 8) {
    for (oc.inner, 0, 8) {
      output[((((blockIdx.x*1024) + (n.oh.fused.ow.fused.inner*128)) + (blockIdx.y*8)) + oc.inner)] = gemm_C[((((blockIdx.x*(16 - blockIdx.y))*64) + (n.oh.fused.ow.fused.inner*8)) + oc.inner)]
    }
  }
```
The `j` axis of `gemm_C` is inferred as `(j, 0, (16 - blockIdx.y))`. Because of this strange bound inference, I can't bind this axis to `block_y`.
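
For reference, the schedule calls I'm experimenting with look roughly like this (a sketch; `block_y_axis` is a placeholder for whichever `output` axis is bound to `blockIdx.y`):

```python
# Sketch only.
# Attempt 1: mark the intermediate buffer as shared memory.
s[gemm_C].set_scope("shared")

# Attempt 2: additionally nest the gemm_C copy under the output's block axis.
s[gemm_C].compute_at(s[output], block_y_axis)
```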

Am I going about this the right way to achieve my goal? What could cause the `iter_var` bounds to be inferred like this, and how can I solve it?




