# Problem Statement

The existing CUDA "[scatter_nd](https://github.com/apache/tvm/blob/main/python/tvm/topi/cuda/scatter.py#L726)" op (which is written in TIR) has 2 problems that block me from deploying it to real-world GPU devices:

1. There is an integer overflow bug in its TIR implementation, for which I proposed a PR: https://github.com/apache/tvm/pull/8415
2. It has very poor performance on GPU. In my case, it is almost **1000x** slower than my naive hand-written CUDA implementation.
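To illustrate problem 1 in generic terms, the sketch below shows how 32-bit index arithmetic wraps around for large tensors. It is only a generic illustration of the failure mode, not the exact expression that overflows in the TIR implementation; see the PR above for the actual fix.

<pre>
import numpy as np

# A flattened offset such as ((b*H + h)*W + w)*C + c stops fitting in int32
# once a tensor has more than 2**31 - 1 elements, and 32-bit arithmetic then
# silently wraps around to a wrong (even negative) offset.
shape = (64, 512, 512, 256)                      # 2**32 elements in total
last_index = np.prod(shape, dtype=np.int64) - 1  # 4294967295
print(last_index > np.iinfo(np.int32).max)       # True: out of int32 range
print(np.array([last_index]).astype(np.int32))   # [-1] after wrap-around
</pre>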
# Code to Reproduce

<pre>
import tvm
import numpy as np
import tvm.relay as relay

dev = tvm.cuda()
target = tvm.target.Target("cuda")

# Input data:
data_np = np.zeros((32, 128, 128, 256)).astype(np.float32)
indices_np = np.random.uniform(1, 5, (32, 600, 3)).astype(np.int64)
updates_np = np.random.rand(32, 600, 256).astype(np.float32)

# Construct relay input nodes:
data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
indices = relay.var("indices", shape=indices_np.shape, dtype=str(indices_np.dtype))
updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype))

# Move the coordinate axis of indices to the front, as scatter_nd expects:
indices_dim = len(indices_np.shape)
axes = list(range(indices_dim))
indices_t = relay.transpose(indices, axes[-1:] + axes[:-1])

# Construct relay scatter_nd op:
out = relay.op.scatter_nd(data, indices_t, updates, "update")
func = relay.Function([data, indices, updates], out)

# Execute scatter_nd:
intrp = relay.create_executor("debug", device=dev, target=target)
op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)
</pre>

# Result

The script above takes 2.89 s on an NVIDIA T4. In comparison, a very naive CUDA implementation I wrote takes only 2.27 ms.

# Root cause

The scatter_nd algorithm consists of 2 stages:

1. Initialize the output tensor, setting all of its elements to 0;
2. Update part of the output tensor with the given values.

(A NumPy sketch of these two stages is at the end of this post.)

Since the output tensor and the updates tensor have different shapes, they need different thread/block counts to reach the best performance on GPU. That leads to 2 possible ways to implement this op:

1. Use 2 CUDA kernels, one for init and one for update, each with its own thread/block configuration tuned for best performance;
2. Use 1 CUDA kernel that does the init and the update together.

There is a trade-off here: approach 1 achieves the best hardware utilization, while approach 2 saves some kernel launch overhead. The existing scatter_nd op takes the latter, implementing a single kernel with TIR. Unfortunately, approach 2 performs very poorly in real-world cases. The generated kernel's launch configuration shows why:

* Input tensor: 32 * 128 * 128 * 256
* Updates tensor: 32 * 600 * 256
* CUDA kernel config: grid (1, 1, 1), block (256, 1, 1), so every thread has to initialize 32 * 128 * 128 output elements and apply 32 * 600 updates.

Now we can see clearly where the problem is: a single block of **256** threads is far too little parallelism to fully utilize the GPU's SMs. That is exactly why it is so slow on GPU.

# My questions

1. How should this scatter_nd performance problem be addressed? Maybe by adopting the 2-kernel implementation approach? I don't know.
2. What is the long-term plan for the ops implemented with TIR, including scatter_nd? I noticed the new AutoTIR work is in progress; will it be able to solve this kind of problem?
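For reference, here is a minimal NumPy sketch of the two stages described in the root-cause section above. It is only a sketch: the function name, the loop structure, and the last-write-wins handling of duplicate indices are my own assumptions, not taken from the TOPI code. It uses the un-transposed indices layout from the repro script, i.e. `indices[b, r] = (d0, d1, d2)`.

<pre>
import numpy as np

def scatter_nd_update_ref(data, indices, updates):
    # Stage 1: initialize the output; with the all-zero `data` from the repro
    # script this is equivalent to zero-filling the output tensor.
    out = data.copy()

    # Stage 2: scatter each update row to the coordinate given by `indices`.
    # For duplicate coordinates, the last write wins (my assumption here).
    batch, rows, _ = indices.shape
    for b in range(batch):
        for r in range(rows):
            d0, d1, d2 = indices[b, r]
            out[d0, d1, d2, :] = updates[b, r, :]
    return out

# Shapes from the repro script:
#   data    (32, 128, 128, 256) float32
#   indices (32, 600, 3)        int64
#   updates (32, 600, 256)      float32
</pre>

In the 2-kernel approach from question 1, stage 1 and stage 2 would become two separately launched kernels, each free to pick a grid/block size that matches its own workload (~134M output elements vs. 19,200 update rows).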