gemini-code-assist[bot] commented on code in PR #18626:
URL: https://github.com/apache/tvm/pull/18626#discussion_r2654854704


##########
python/tvm/relax/transform/legalize_ops/statistical.py:
##########
@@ -53,6 +53,40 @@ def _te_variance(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Ten
     # return _te_mean(x * x, axis, keepdims) - mean * mean
 
 
+def _te_median(
+    x: te.Tensor, axis: List[tir.IntImm], keepdims: bool
+) -> Union[te.Tensor, Tuple[te.Tensor, te.Tensor]]:
+    # currently only supports one axis or no axis ~ same pytorch
+    # todo: support multiple axis ~ same numpy
+    shape_prod = _compute_shape_prod(x, axis)
+    mid_index = (shape_prod - 1) // 2
+
+    if axis is None or len(axis) == 0:
+        x = topi.reshape(x, [shape_prod.value])
+        ax = -1
+    else:
+        ax = axis[0].value
+    index_sorted = topi.argsort(x, axis=ax, is_ascend=True, dtype="int64")
+    x_sorted = topi.gather(x, axis=ax, indices=index_sorted)
+
+    new_shape = list(x.shape)
+    new_shape[ax] = 1
+    indices = topi.full(new_shape, fill_value=mid_index, dtype="int64")
+
+    median_val = topi.gather(x_sorted, axis=ax, indices=indices)
+    median_idx = topi.gather(index_sorted, axis=ax, indices=indices)
+
+    if axis is None or len(axis) == 0:
+        return median_val if keepdims else topi.squeeze(median_val, axis=axis)

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   When `axis` is `None` and `keepdims` is `True`, the output tensor should have the same rank as the input, with every dimension of size 1. The current implementation instead returns a tensor of shape `(1,)`, because `x` has already been flattened to 1-D by the time `median_val` is computed, so the original rank is lost.
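
For reference, NumPy's `np.median` follows this convention. The NumPy-only sketch below (not TVM code) contrasts it with the shape produced when the input is flattened before the middle element is gathered, which is what the `axis is None` branch effectively does. It uses an odd element count so that the lower median taken by the PR and NumPy's median coincide.

```python
import numpy as np

# 15 elements (odd count), so the lower median and NumPy's median are the same value.
x = np.arange(15, dtype="float32").reshape(1, 3, 5)

# NumPy reference: reducing over all axes with keepdims=True preserves the rank.
expected = np.median(x, axis=None, keepdims=True)
print(expected.shape)  # (1, 1, 1)

# Mimic the current axis=None branch: flatten, sort, then gather the middle element
# with an index array of shape (1,). The result has shape (1,), not (1, 1, 1).
flat = x.reshape(-1)
mid_index = (flat.size - 1) // 2
current = np.take_along_axis(np.sort(flat), np.full((1,), mid_index), axis=-1)
print(current.shape)  # (1,)

# Reshaping to [1] * original ndim recovers the expected keepdims shape.
fixed = current.reshape([1] * x.ndim)
print(fixed.shape, fixed.item())  # (1, 1, 1) 7.0
```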
   
   To fix this, we should store the original rank of `x` before any modifications and use it to reshape `median_val` when `keepdims` is `True` and `axis` is `None`.
   
   ```python
       # Record the input rank before `x` is flattened below.
       orig_ndim = len(x.shape)
       shape_prod = _compute_shape_prod(x, axis)
       mid_index = (shape_prod - 1) // 2
   
       if axis is None or len(axis) == 0:
           x = topi.reshape(x, [shape_prod.value])
           ax = -1
       else:
           ax = axis[0].value
       index_sorted = topi.argsort(x, axis=ax, is_ascend=True, dtype="int64")
       x_sorted = topi.gather(x, axis=ax, indices=index_sorted)
   
       new_shape = list(x.shape)
       new_shape[ax] = 1
       indices = topi.full(new_shape, fill_value=mid_index, dtype="int64")
   
       median_val = topi.gather(x_sorted, axis=ax, indices=indices)
       median_idx = topi.gather(index_sorted, axis=ax, indices=indices)
   
       if axis is None or len(axis) == 0:
           if keepdims:
               # Rank-preserving shape of all 1s, e.g. (1, 1, 1) for a 3-D input.
               return topi.reshape(median_val, [1] * orig_ndim)
           return topi.squeeze(median_val, axis=axis)
   ```
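
A NumPy oracle that mirrors the suggested TE steps (argsort, gather along the sorted indices, gather at the lower middle index, restore the shape) could make a regression test for this case straightforward. `median_reference` below is a hypothetical helper written for illustration, not a TVM API. It assumes at most one reduction axis, as the diff does, and it assumes that with an explicit axis `keepdims=False` drops that axis, since that part of the diff is not shown. For an even element count it returns the lower of the two middle values, as `torch.median` does, rather than NumPy's average.

```python
from typing import Optional

import numpy as np


def median_reference(x: np.ndarray, axis: Optional[int], keepdims: bool) -> np.ndarray:
    """Lower median over one axis (or all axes), mirroring the TE steps in the diff.

    Hypothetical NumPy oracle for testing; not a TVM API.
    """
    orig_ndim = x.ndim  # record the rank before any flattening
    if axis is None:
        x = x.reshape(-1)
        ax = -1
    else:
        ax = axis
    n = x.shape[ax]
    mid_index = (n - 1) // 2  # lower median for even n

    # Sort along the reduction axis via argsort + gather, as the TE code does.
    index_sorted = np.argsort(x, axis=ax, kind="stable")
    x_sorted = np.take_along_axis(x, index_sorted, axis=ax)

    # Gather the middle element with an index array that has size 1 along `ax`.
    new_shape = list(x.shape)
    new_shape[ax] = 1
    indices = np.full(new_shape, mid_index, dtype="int64")
    median_val = np.take_along_axis(x_sorted, indices, axis=ax)

    if axis is None:
        # The fix under review: all-1s shape of the original rank instead of (1,).
        return median_val.reshape([1] * orig_ndim) if keepdims else median_val.reshape(())
    # Assumption: with an explicit axis, keepdims=False drops that axis.
    return median_val if keepdims else np.squeeze(median_val, axis=ax)


x = np.array([[3.0, 1.0, 2.0, 4.0], [9.0, 7.0, 8.0, 6.0]])
print(median_reference(x, None, True).shape)  # (1, 1)
print(median_reference(x, 1, False))  # [2. 7.]
```

Comparing the legalized op's output against such an oracle for both `keepdims` settings, with and without `axis`, would cover the case raised in this comment.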



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
