This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 69336ac4a5 [TOPI] Fix strided_slice_with_axes to handle negative axis
values (#18917)
69336ac4a5 is described below
commit 69336ac4a5f9e244deaaad94a031417c64114b2c
Author: Masahiro Hiramori <[email protected]>
AuthorDate: Fri Mar 20 15:17:36 2026 +0900
[TOPI] Fix strided_slice_with_axes to handle negative axis values (#18917)
Negative axis values (e.g., `axes=[-1]`) in `strided_slice_with_axes`
were used directly as array indices without normalization, causing an
`IndexError` during `LegalizeOps`.
This PR normalizes negative axes to positive equivalents before passing
them to `StridedSliceCanonicalizeBegin`, `StridedSliceOutputShape`, and
the compute lambda.
---
include/tvm/topi/transform.h | 29 ++++++++++++++------
..._transform_legalize_ops_index_linear_algebra.py | 31 ++++++++++++++++++++++
2 files changed, 52 insertions(+), 8 deletions(-)
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index 93938c601d..5c3ec5986c 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -904,28 +904,41 @@ inline Tensor strided_slice_with_axes(const Tensor& x,
const ffi::Array<Integer>
std::string slice_mode = "end",
std::string name =
"T_strided_slice_with_axes",
std::string tag = kInjective) {
- const size_t src_tensor_dim = x->shape.size();
- TVM_FFI_ICHECK(axes.size() <= src_tensor_dim);
+ const int64_t src_tensor_dim = static_cast<int64_t>(x->shape.size());
+ TVM_FFI_ICHECK(static_cast<int64_t>(axes.size()) <= src_tensor_dim);
TVM_FFI_ICHECK(axes.size() == begin.size() && axes.size() == end.size() &&
axes.size() == strides.size());
+ // Normalize negative axes
+ ffi::Array<Integer> normalized_axes;
+ for (size_t i = 0; i < axes.size(); ++i) {
+ int64_t axis = axes[i].IntValue();
+ if (axis < 0) {
+ axis += src_tensor_dim;
+ }
+ TVM_FFI_ICHECK(axis >= 0 && axis < src_tensor_dim)
+ << "Axis " << axes[i].IntValue() << " is out of bounds for tensor with
" << src_tensor_dim
+ << " dimensions";
+ normalized_axes.push_back(Integer(axis));
+ }
+
std::vector<int64_t> begin_vec, end_vec, strides_vec;
std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end,
strides, slice_mode);
- auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec,
strides_vec, axes,
+ auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec,
strides_vec, normalized_axes,
begin[0]->dtype, slice_mode);
- auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec,
strides_vec, axes,
- slice_mode, begin_expr);
+ auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec,
strides_vec,
+ normalized_axes, slice_mode,
begin_expr);
return te::compute(
out_shape,
[&](const ffi::Array<tir::Var>& indices) {
ffi::Array<PrimExpr> real_indices;
for (size_t i = 0; i < out_shape.size(); ++i)
real_indices.push_back(indices[i]);
- for (size_t i = 0; i < axes.size(); ++i) {
+ for (size_t i = 0; i < normalized_axes.size(); ++i) {
auto stride = make_const(strides[i].dtype(), strides_vec[i]);
- PrimExpr ind = indices[axes[i].IntValue()] * stride + begin_expr[i];
- real_indices.Set(axes[i].IntValue(), ind);
+ PrimExpr ind = indices[normalized_axes[i].IntValue()] * stride +
begin_expr[i];
+ real_indices.Set(normalized_axes[i].IntValue(), ind);
}
return x(real_indices);
},
diff --git
a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index 1be7a39781..b8dbe1934b 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -259,6 +259,37 @@ def test_strided_slice_no_strides():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_strided_slice_negative_axes():
+ # fmt: off
+ @tvm.script.ir_module
+ class StridedSlice:
+ @R.function
+ def main(x: R.Tensor((8, 9, 10), "float32")) -> R.Tensor((8, 9, 3),
"float32"):
+ gv: R.Tensor((8, 9, 3), "float32") = R.strided_slice(x, axes=[-1],
begin=[2], end=[5])
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((8, 9, 10), dtype="float32")) -> R.Tensor((8, 9,
3), dtype="float32"):
+ gv = R.call_tir(Expected.strided_slice, (x,),
out_sinfo=R.Tensor((8, 9, 3), dtype="float32"))
+ return gv
+
+ @T.prim_func(private=True)
+ def strided_slice(rxplaceholder: T.Buffer((T.int64(8), T.int64(9),
T.int64(10)), "float32"), T_strided_slice_with_axes: T.Buffer((T.int64(8),
T.int64(9), T.int64(3)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ for ax0, ax1, ax2 in T.grid(T.int64(8), T.int64(9), T.int64(3)):
+ with T.block("T_strided_slice_with_axes"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2 + T.int64(2)])
+ T.writes(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2])
+ T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2] =
rxplaceholder[v_ax0, v_ax1, v_ax2 + T.int64(2)]
+ # fmt: on
+
+ mod = LegalizeOps()(StridedSlice)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_strided_slice_symbolic_sliced_axis():
# fmt: off
@tvm.script.ir_module