# Problem Statement
The existing CUDA [scatter_nd](https://github.com/apache/tvm/blob/main/python/tvm/topi/cuda/scatter.py#L726) op (which is written in TIR) has two problems that block me from deploying it to real-world GPU devices:
1. There is an integer overflow bug in its TIR implementation, for which I proposed a PR: https://github.com/apache/tvm/pull/8415
2. It has very poor performance on GPU. In my case, it is almost **1000x** slower than my naive hand-written CUDA implementation.

# Code to Reproduce:
<pre>
import tvm
import numpy as np
import tvm.relay as relay

dev = tvm.cuda()
target = tvm.target.Target("cuda")

# input data:
data_np = np.zeros((32, 128, 128, 256)).astype(np.float32)
indices_np = np.random.uniform(1,5,(32, 600, 3)).astype(np.int64)
updates_np = np.random.rand(32, 600, 256).astype(np.float32)

# Construct relay input nodes:
data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
indices = relay.var("indices", shape=indices_np.shape, dtype=str(indices_np.dtype))
updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype))

# Transpose indices so the index axis comes first, as scatter_nd expects:
indices_dim = len(indices_np.shape)
axes = list(range(indices_dim))
indices_t = relay.transpose(indices, axes[-1:] + axes[:-1])

# Construct relay scatter_nd op:
out = relay.op.scatter_nd(data, indices_t, updates, "update")
func = relay.Function([data, indices, updates], out)

# Execute scatter_nd:
intrp = relay.create_executor("debug", device=dev, target=target)
op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
</pre>
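For reference, a minimal way to time the execution (this snippet is my addition and not part of the original script) is to append the following to the code above:

<pre>
import time

# Time one full scatter_nd execution with the debug executor.
start = time.time()
intrp.evaluate(func)(data_np, indices_np, updates_np)
print("scatter_nd took %.3f s" % (time.time() - start))
</pre>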

# Result
The script above takes 2.89 s on an Nvidia T4. In comparison, I wrote a very naive CUDA implementation, which takes only 2.27 ms.

# Root cause
The scatter_nd algorithm consists of two stages (sketched in NumPy right after this list):
1. Initialize the output tensor (copy `data`, so every element is 0 in this example);
2. Update part of the output tensor with the given values.
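To make the two stages concrete, here is a minimal NumPy sketch of the "update" semantics (my own reference code, not the TOPI implementation), using the transposed index layout produced by the script above:

<pre>
import numpy as np

def scatter_nd_update_ref(data, indices, updates):
    # indices has shape (M, Y0, ..., Yk); each index "column" selects a slice of out.
    # Stage 1: initialize the output from `data` (all zeros in this example).
    out = data.copy()
    # Stage 2: scatter each update slice into the location its index column selects.
    m = indices.shape[0]
    for y in np.ndindex(indices.shape[1:]):
        idx = tuple(indices[(i,) + y] for i in range(m))
        out[idx] = updates[y]
    return out

# With the shapes above: data (32,128,128,256), indices (3,32,600), updates (32,600,256),
# so each of the 32*600 index columns selects a 256-element slice of the output.
ref = scatter_nd_update_ref(data_np, np.transpose(indices_np, (2, 0, 1)), updates_np)
</pre>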

Since the output tensor and the updates tensor have different shapes, they need different thread/block counts to reach peak performance on the GPU. That leaves two ways to implement this op:
1. Use two CUDA kernels, one for init and one for update, each with its own thread/block configuration tuned for best performance;
2. Use a single CUDA kernel that does both the init and the update.

There is a trade-off here: approach 1 gives the best hardware utilization, while approach 2 saves one kernel launch. The existing scatter_nd op apparently takes the latter route and implements a single kernel in TIR (a sketch of the two-kernel alternative follows below).
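To make the two-kernel idea (approach 1) concrete, here is a minimal sketch using Numba's CUDA JIT rather than TVM (the kernel names and the flattened 1-D launch are my own simplifications, specialized to the shapes in the reproduce script): each kernel gets a launch configuration sized to the tensor it actually touches.

<pre>
import numpy as np
from numba import cuda

@cuda.jit
def init_kernel(out_flat, data_flat):
    # Stage 1: one thread per output element, copy `data` into the output.
    i = cuda.grid(1)
    if i < out_flat.shape[0]:
        out_flat[i] = data_flat[i]

@cuda.jit
def update_kernel(out, indices, updates):
    # Stage 2: one thread per element of the (much smaller) updates tensor.
    i = cuda.grid(1)
    n_points = updates.shape[1]
    n_ch = updates.shape[2]
    if i < updates.shape[0] * n_points * n_ch:
        b = i // (n_points * n_ch)
        p = (i // n_ch) % n_points
        c = i % n_ch
        out[indices[b, p, 0], indices[b, p, 1], indices[b, p, 2], c] = updates[b, p, c]

data_dev = cuda.to_device(data_np)
out_dev = cuda.device_array_like(data_np)
idx_dev = cuda.to_device(indices_np)
upd_dev = cuda.to_device(updates_np)

threads = 256
n_out = data_np.size      # 134,217,728 elements to initialize
n_upd = updates_np.size   #   4,915,200 elements to update
# Kernel 1: grid sized by the big output tensor.
init_kernel[(n_out + threads - 1) // threads, threads](
    out_dev.reshape(n_out), data_dev.reshape(n_out))
# Kernel 2: grid sized by the much smaller updates tensor.
update_kernel[(n_upd + threads - 1) // threads, threads](out_dev, idx_dev, upd_dev)
result = out_dev.copy_to_host()
</pre>

This per-kernel launch-configuration flexibility is exactly what a single fused kernel gives up.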

Unfortunately, approach 2 performs very poorly in real-world cases, as the generated CUDA code below shows:
![WeChatWorkScreenshot_3adede05-09c6-4f95-8ea2-822dad79813e|690x104](upload://dH6KQLVNpRSq9cC335p1TYixc4r.png)
 
Input tensor: 32 * 128 * 128 * 256
Updates tensor: 32 * 600 * 256
CUDA kernel config: grid (1, 1, 1), block (256, 1, 1), so every thread has to initialize 32 * 128 * 128 elements and apply 32 * 600 updates.

Now we can see clearly where the problem is: a single block of **256 threads** is far too small to fully utilize the GPU's SMs. That is exactly why it runs so slowly on the GPU.
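A quick back-of-the-envelope check (my arithmetic, assuming a T4 with 40 SMs) shows how little parallelism that configuration exposes:

<pre>
out_elems = 32 * 128 * 128 * 256   # 134,217,728 output elements to initialize
upd_elems = 32 * 600 * 256         #   4,915,200 update elements to apply
threads   = 1 * 256                # grid (1,1,1) x block (256,1,1)

print(out_elems // threads)        # 524288 init elements per thread
print(upd_elems // threads)        # 19200 update elements per thread
# A thread block runs on a single SM, so with one block a T4 (40 SMs) has at
# most ~1/40 of the chip busy, and each thread serially walks >500k elements.
</pre>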

# My questions
1. How should this scatter_nd performance problem be addressed? Perhaps by adopting the two-kernel implementation approach? I am not sure.
2. What is the long-term plan for ops implemented with TIR, including scatter_nd? I noticed the new AutoTIR work is in progress; will it be able to solve this kind of problem?




