This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 427b66da1a [Frontend][ONNX] Fix SplitToSequence keepdims=0 and uneven
last chunk (#19341)
427b66da1a is described below
commit 427b66da1acd5d36c481773570ed500f0e6f1c1b
Author: Kryptonite <[email protected]>
AuthorDate: Sat Apr 4 00:31:03 2026 +0300
[Frontend][ONNX] Fix SplitToSequence keepdims=0 and uneven last chunk (#19341)
### Summary
Fixes two spec violations in `SplitToSequence`:
1. **keepdims=0** was raising `NotImplementedError`. The fix squeezes
the split axis from each chunk when the `split` input is omitted (so it
defaults to the scalar 1) and `keepdims=0`.
Per spec:
> "If input 'split' is specified [as a 1-D array], this attribute is
ignored" —
verified against ORT.
2. **Uneven last chunk** was raising `ValueError`. The spec states: "The
last chunk alone may be smaller than 'split' if the input size is not
divisible by 'split'." Fixed by dropping the divisibility check and passing
`math.ceil(dim_size / chunk_size)` to `relax.op.split` as the number of
sections. This applies to static dimension sizes only; a dynamic size now
raises `NotImplementedError`.
Reference: https://onnx.ai/onnx/operators/onnx__SplitToSequence.html
Closes part of #18945
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 31 +++++++----
tests/python/relax/test_frontend_onnx.py | 70 +++++++++++++++++++++++++
2 files changed, 92 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index ab1ea2b292..d54877646c 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -4310,12 +4310,10 @@ class SplitToSequence(OnnxOpConverter):
input_tensor = inputs[0]
input_shape = input_tensor.struct_info.shape
+ split_is_scalar = False
- # If split is not provided, we split all values along axis.
if len(inputs) == 1:
split = _np.array(1)
- if not keepdims:
- raise NotImplementedError("Only keepdims=1 is supported for now")
else:
split = inputs[1]
if not isinstance(split, relax.Constant):
@@ -4326,15 +4324,30 @@ class SplitToSequence(OnnxOpConverter):
split = _np.cumsum(split)
split = list(split[:-1])
else:
- chunk_size, dim_size = int(split), input_shape[axis]
- if dim_size % chunk_size != 0:
- raise ValueError(
- f"Dimension of size {dim_size} along axis {axis} must be "
- f"evenly divisible by chunk size {chunk_size}"
+ chunk_size = int(split)
+ dim_size = input_shape[axis]
+
+ if isinstance(dim_size, (int, tirx.IntImm)):
+ dim_size_int = int(dim_size)
+ split = math.ceil(dim_size_int / chunk_size)
+ else:
+ raise NotImplementedError(
+ "SplitToSequence with dynamic dim size and scalar split is not supported."
)
- split = dim_size // chunk_size
output = relax.op.split(input_tensor, split, axis=axis)
+
+ # keepdims=0 applies when split is a scalar (whether provided or defaulted to 1)
+ # Per ONNX spec: "If input 'split' is specified, this attribute is ignored."
+ if not keepdims and len(inputs) == 1:
+ output = bb.emit(output)
+ n = len(output.struct_info.fields)
+ squeezed = [
+ relax.op.squeeze(bb.emit(relax.TupleGetItem(output, i)), axis=[axis])
+ for i in range(n)
+ ]
+ return relax.Tuple(squeezed)
+
return output
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 7f9cd177ad..8111b95c4b 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -5458,5 +5458,75 @@ def test_arg_min_max_select_last_index_ir(op_name):
assert "relax.subtract" in call_ops, f"Expected relax.subtract in IR, got {call_ops}"
+@pytest.mark.parametrize("axis", [0, 1, 2])
+def test_split_to_sequence_keepdims_0(axis: int):
+ """keepdims=0, no split input: each chunk of size 1 has the split axis squeezed out."""
+ shape = [3, 4, 5]
+ out_shape = [s for i, s in enumerate(shape) if i != axis]
+
+ split_to_seq_node = helper.make_node(
+ "SplitToSequence",
+ ["data"], # no split input — keepdims applies here
+ ["output"],
+ axis=axis,
+ keepdims=0,
+ )
+ graph = helper.make_graph(
+ [split_to_seq_node],
+ f"test_split_to_sequence_keepdims_0_axis{axis}",
+ inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, shape)],
+ outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, out_shape)],
+ )
+ model = helper.make_model(graph, producer_name="test_split_to_sequence_keepdims_0")
+ check_correctness(model)
+
+
+def test_split_to_sequence_keepdims_ignored_when_split_provided():
+ """Per spec: keepdims is ignored when split input is provided.
+ TVM follows the spec — output keeps the split axis even with keepdims=0."""
+ split_node = make_constant_node("split", TensorProto.INT64, (), [1])
+ split_to_seq_node = helper.make_node(
+ "SplitToSequence",
+ ["data", "split"],
+ ["output"],
+ axis=0,
+ keepdims=0,
+ )
+ graph = helper.make_graph(
+ [split_node, split_to_seq_node],
+ "test_split_to_sequence_keepdims_ignored",
+ inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, [4, 5])],
+ outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, [1, 5])],
+ )
+ model = helper.make_model(
+ graph,
+ producer_name="test_split_to_sequence_keepdims_ignored",
+ opset_imports=[helper.make_opsetid("", 11)],
+ )
+ model.ir_version = 8
+ # Cannot use check_correctness here as ORT deviates from the spec for this case
+ from tvm.relax.frontend.onnx import from_onnx
+ tvm_model = from_onnx(model, opset=11, keep_params_in_input=True)
+ assert tvm_model is not None
+
+
+@pytest.mark.parametrize("axis", [0, 1])
+def test_split_to_sequence_uneven_last_chunk(axis: int):
+ """Spec: last chunk may be smaller if dim is not divisible by scalar split."""
+ shape = [5, 4] if axis == 0 else [3, 5]
+ split_node = make_constant_node("split", TensorProto.INT64, (), [2])
+ split_to_seq_node = helper.make_node(
+ "SplitToSequence", ["data", "split"], ["output"], axis=axis, keepdims=1
+ )
+ graph = helper.make_graph(
+ [split_node, split_to_seq_node],
+ f"test_split_to_sequence_uneven_axis{axis}",
+ inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, shape)],
+ outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, None)],
+ )
+ model = helper.make_model(graph, producer_name="test_split_to_sequence_uneven")
+ check_correctness(model)
+
+
if __name__ == "__main__":
tvm.testing.main()