Hi all:

I am learning the TVM CUDA backend and have a question about how a CUDA kernel is launched.

Below is my simple test program:
```
import tvm
from tvm import te
import tvm.testing  # for tvm.testing.assert_allclose below
import numpy as np

dtype = "float32"
# GEMM size
M=16;K=8;N=16
# declear algorithm 
k = te.reduce_axis((0, K), 'k') # loop over dimension K
A = te.placeholder((M, K), name='A')
B = te.placeholder((K, N), name='B')
C = te.compute(
           (M, N),
           lambda x, y: te.sum(A[x, k] * B[k, y], axis=k),
           name='C')
# default schedule
s = te.create_schedule(C.op)
#print(tvm.lower(s, [A, B, C], simple_mode=True))
# optimized schedule: tiling
bn = 4  # tile size 4 over both M and N
# outer -> inner
xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
#print(tvm.lower(s, [A, B, C], simple_mode=True))
AS = s.cache_read(A, 'shared', [C])
BS = s.cache_read(B, 'shared', [C])
s[AS].compute_at(s[C], xo)
s[BS].compute_at(s[C], yo)
s[C].bind(xo, te.thread_axis("blockIdx.x"))
s[C].bind(yo, te.thread_axis("blockIdx.y"))
s[C].bind(xi, te.thread_axis("threadIdx.x"))
s[C].bind(yi, te.thread_axis("threadIdx.y"))
target = 'cuda'
ctx = tvm.context(target, 0)
a = tvm.nd.array(np.random.rand(M, K).astype(dtype), ctx)
b = tvm.nd.array(np.random.rand(K, N).astype(dtype), ctx)
# compute the reference C through numpy
answer = np.dot(a.asnumpy(), b.asnumpy())

func = tvm.build(s, [A, B, C], target=target, name='mmult')
c = tvm.nd.array(np.zeros((M, N), dtype=dtype), ctx)
# a, b: input matrices, c: result
func(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), answer, rtol=1e-5)
#print(func.get_source())
dev_module = func.imported_modules[0]
print(dev_module)
print("-----GPU code-----")
print(dev_module.get_source())
```
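
Side note: re-running the commented-out `tvm.lower` print *after* the `bind()` calls shows where the bound extents get recorded. A minimal sketch (same arguments as the prints above):
```
# Sketch: print the lowered host-side IR after the thread bindings.
# Each bound axis shows up as an attribute along the lines of
#   // attr [iter_var(blockIdx.x, ...)] thread_extent = 4
# which records the extent of that blockIdx/threadIdx axis.
print(tvm.lower(s, [A, B, C], simple_mode=True))
```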

The generated CUDA code:
```
extern "C" __global__ void mmult_kernel0( float* __restrict__ A,  float* 
__restrict__ B,  float* __restrict__ C) {
  __shared__ float A_shared[32];
  __shared__ float B_shared[32];
  for (int ax0 = 0; ax0 < 4; ++ax0) {
    for (int ax1 = 0; ax1 < 8; ++ax1) {
      A_shared[(((ax0 * 8) + ax1))] = A[((((((int)blockIdx.x) * 32) + (ax0 * 
8)) + ax1))];
    }
  }
  for (int ax01 = 0; ax01 < 8; ++ax01) {
    for (int ax11 = 0; ax11 < 4; ++ax11) {
      B_shared[(((ax01 * 4) + ax11))] = B[((((ax01 * 16) + (((int)blockIdx.y) * 
4)) + ax11))];
    }
  }
  C[(((((((int)blockIdx.x) * 64) + (((int)threadIdx.x) * 16)) + 
(((int)blockIdx.y) * 4)) + ((int)threadIdx.y)))] = 0.000000e+00f;
  __syncthreads();
  for (int k = 0; k < 8; ++k) {
    C[(((((((int)blockIdx.x) * 64) + (((int)threadIdx.x) * 16)) + 
(((int)blockIdx.y) * 4)) + ((int)threadIdx.y)))] = (C[(((((((int)blockIdx.x) * 
64) + (((int)threadIdx.x) * 16)) + (((int)blockIdx.y) * 4)) + 
((int)threadIdx.y)))] + (A_shared[(((((int)threadIdx.x) * 8) + k))] * 
B_shared[(((k * 4) + ((int)threadIdx.y)))]));
  }
}
```

This is straightforward. But what confused me is how this kernel **mmult_kernel0** gets launched by the host (CPU, LLVM backend): I did not see where blockDim and gridDim are configured.
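
From the schedule I would expect gridDim = (4, 4) and blockDim = (4, 4), since M/bn = N/bn = 4 blocks and bn = 4 threads per axis. A quick pure-Python sanity check (my own sketch, not TVM output) of the index expression in the generated kernel is consistent with that:
```
# Sketch: enumerate the flattened C index from the generated kernel,
# assuming gridDim = (4, 4) and blockDim = (4, 4), and check that it
# covers every element of the 16x16 output exactly once.
seen = set()
for bx in range(4):                  # blockIdx.x
    for tx in range(4):              # threadIdx.x
        for by in range(4):          # blockIdx.y
            for ty in range(4):      # threadIdx.y
                seen.add(bx * 64 + tx * 16 + by * 4 + ty)
assert seen == set(range(16 * 16))   # 256 distinct flat indices
```
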
We know that normally a CUDA kernel is launched from the CPU with:
```
kernel<<<griddim, blockdim>>>(a, b, c);
```
How does TVM manage these settings?
Could anyone give me some tips?
@tqchen @FrozenGene
