I was trying to define argmax using **comm_reducer**, and I couldn't find any way to define a comm_reducer that **just computes the index of the maximum value**, and does not compute the value as well.
For example, take the comm_reducer for argmax as defined in [`tests/python/integration/test_reduce.py`](https://github.com/apache/tvm/blob/main/tests/python/integration/test_reduce.py#L319):

```python
import tvm
from tvm import te


def fcombine(x, y):
    # x and y are (index, value) pairs; keep the pair with the larger value
    lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
    rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
    return lhs, rhs


def fidentity(t0, t1):
    # identity element: index -1, smallest representable value
    return tvm.tir.const(-1, t0), tvm.te.min_value(t1)


argmax = te.comm_reducer(fcombine, fidentity, name="argmax")
m = te.size_var("m")
n = te.size_var("n")
idx = te.placeholder((m, n), name="idx", dtype="int32")
val = te.placeholder((m, n), name="val", dtype="float32")
k = te.reduce_axis((0, n), "k")
T0, T1 = te.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="T")
s = te.create_schedule(T0.op)
irm = tvm.lower(s, [T0, T1, idx, val])
```

This generates the following IR:

```
primfn(T_2: handle, T_3: handle, idx_1: handle, val_1: handle) -> ()
  attr = {"global_symbol": "main", "tir.noalias": True}
  buffers = {val: Buffer(val_2: Pointer(float32), float32, [m: int32, n: int32], [stride: int32, stride_1: int32], type=auto_broadcast),
             idx: Buffer(idx_2: Pointer(int32), int32, [m, n], [stride_2: int32, stride_3: int32], type=auto_broadcast),
             T_1: Buffer(T_4: Pointer(float32), float32, [m], [stride_4: int32], type=auto_broadcast),
             T: Buffer(T_5: Pointer(int32), int32, [m], [stride_5: int32], type=auto_broadcast)}
  buffer_map = {T_2: T, T_3: T_1, idx_1: idx, val_1: val} {
  for (i: int32, 0, m) {
    T_5[(i*stride_5)] = -1
    T_4[(i*stride_4)] = -3.40282e+38f32
    for (k: int32, 0, n) {
      T_5[(i*stride_5)] = @tir.if_then_else(((float32*)T_4[(i*stride_4)] < (float32*)val_2[((i*stride) + (k*stride_1))]), (int32*)idx_2[((i*stride_2) + (k*stride_3))], (int32*)T_5[(i*stride_5)], dtype=int32)
      T_4[(i*stride_4)] = max((float32*)T_4[(i*stride_4)], (float32*)val_2[((i*stride) + (k*stride_1))])
    }
  }
}
```

Since the argmax reducer has to take both the index and the value, the output of the comm_reducer is also expected to be two tensors: the max_index
(`T_5` in the above IR) and the max_value (`T_4` in the above IR).

I see that even though the documentation of [argmax in topi](https://tvm.apache.org/docs/api/python/topi.html?highlight=argmax) says that it only returns the indices of the max values along the axis, the IR generated for it also computes the value and then ignores it.

What I wanted to ask is whether it is possible to define a comm_reducer for argmax that does not compute the maximum value, so that both the memory allocated for the max value and the extra computation can be avoided.

Thanks in advance
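
To make the question concrete, here is a minimal plain-Python sketch (not TVM code) of the two behaviours I am comparing. `argmax_pairs` mirrors the `fcombine`/`fidentity` reducer above and carries an (index, value) pair. `argmax_index_only` is a hypothetical loop that keeps only the index and re-reads the input at that index for each comparison. Both function names are made up for illustration, and whether something like the second form can be expressed with `te.comm_reducer` is exactly what I am unsure about.

```python
from functools import reduce


def fcombine(x, y):
    """Combine two (index, value) pairs, keeping the pair with the larger value."""
    return x if x[1] >= y[1] else y


def argmax_pairs(row):
    """Reduce over (index, value) pairs, like the comm_reducer above."""
    identity = (-1, float("-inf"))  # plays the role of fidentity
    return reduce(fcombine, enumerate(row), identity)


def argmax_index_only(row):
    """Hypothetical variant: the accumulator is just the index.

    The current best value is not stored; it is re-read from the input
    as row[best] whenever a comparison is needed.
    """
    best = -1
    for k, v in enumerate(row):
        if best < 0 or v > row[best]:
            best = k
    return best


row = [1.0, 5.0, 3.0, 5.0]
print(argmax_pairs(row))       # (index, value) pair
print(argmax_index_only(row))  # index only
```

Note that the second version is no longer a reduction over independent elements, because the combine step has to look back into the input at the stored index. That is the part I could not find a way to write as a comm_reducer.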