I was trying to define argmax using **comm_reducer**, and I couldn't find any
way to define a comm_reducer that **computes only the index of the maximum
value** and does not compute the value as well.

For example, take the comm_reducer for argmax as defined in
[`tests/python/integration/test_reduce.py`](https://github.com/apache/tvm/blob/main/tests/python/integration/test_reduce.py#L319):

    import tvm
    from tvm import te

    def fcombine(x, y):
        lhs = tvm.tir.Select((x[1] >= y[1]), x[0], y[0])
        rhs = tvm.tir.Select((x[1] >= y[1]), x[1], y[1])
        return lhs, rhs

    def fidentity(t0, t1):
        return tvm.tir.const(-1, t0), tvm.te.min_value(t1)

    argmax = te.comm_reducer(fcombine, fidentity, name="argmax")
    m = te.size_var("m")
    n = te.size_var("n")
    idx = te.placeholder((m, n), name="idx", dtype="int32")
    val = te.placeholder((m, n), name="val", dtype="float32")
    k = te.reduce_axis((0, n), "k")
    T0, T1 = te.compute(
        (m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name="T"
    )
    s = te.create_schedule(T0.op)

    irm = tvm.lower(s, [T0, T1, idx, val])

This generates the IR below:

    primfn(T_2: handle, T_3: handle, idx_1: handle, val_1: handle) -> ()
      attr = {"global_symbol": "main", "tir.noalias": True}
      buffers = {val: Buffer(val_2: Pointer(float32), float32, [m: int32, n: int32], [stride: int32, stride_1: int32], type=auto_broadcast),
                 idx: Buffer(idx_2: Pointer(int32), int32, [m, n], [stride_2: int32, stride_3: int32], type=auto_broadcast),
                 T_1: Buffer(T_4: Pointer(float32), float32, [m], [stride_4: int32], type=auto_broadcast),
                 T: Buffer(T_5: Pointer(int32), int32, [m], [stride_5: int32], type=auto_broadcast)}
      buffer_map = {T_2: T, T_3: T_1, idx_1: idx, val_1: val} {
      for (i: int32, 0, m) {
        T_5[(i*stride_5)] = -1
        T_4[(i*stride_4)] = -3.40282e+38f32
        for (k: int32, 0, n) {
          T_5[(i*stride_5)] = @tir.if_then_else(((float32*)T_4[(i*stride_4)] < (float32*)val_2[((i*stride) + (k*stride_1))]), (int32*)idx_2[((i*stride_2) + (k*stride_3))], (int32*)T_5[(i*stride_5)], dtype=int32)
          T_4[(i*stride_4)] = max((float32*)T_4[(i*stride_4)], (float32*)val_2[((i*stride) + (k*stride_1))])
        }
      }
    }

Since the argmax reducer has to take both the index and the value, the output
of comm_reducer is also two tensors: the max index (`T_5` in the IR above) and
the max value (`T_4` in the IR above).

I see that even though the documentation of [argmax in
topi](https://tvm.apache.org/docs/api/python/topi.html?highlight=argmax)
says it returns only the indices of the maximum values along the axis, the IR
generated for it also computes the values and then discards them.

What I wanted to ask is whether it is possible to define a comm_reducer for
argmax that does not compute the maximum value, so that both the memory
allocated for the max value and the extra computation can be avoided.

Thanks in advance

---
[Visit Topic](https://discuss.tvm.apache.org/t/is-it-possible-to-compute-only-index-when-defining-argmax/10429/1) to respond.