This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 427b66da1a [Frontend][ONNX] Fix SplitToSequence keepdims=0 and uneven last chunk (#19341)
427b66da1a is described below

commit 427b66da1acd5d36c481773570ed500f0e6f1c1b
Author: Kryptonite <[email protected]>
AuthorDate: Sat Apr 4 00:31:03 2026 +0300

    [Frontend][ONNX] Fix SplitToSequence keepdims=0 and uneven last chunk (#19341)
    
    ### Summary
    
    Fixes two spec violations in `SplitToSequence`:
    
    1. **keepdims=0** was raising `NotImplementedError`. The fix squeezes
    the split axis from each chunk when `split` is scalar and `keepdims=0`.
    Per spec:
    > "If input 'split' is specified [as a 1-D array], this attribute is
    ignored" —
       verified against ORT.
    
    2. **Uneven last chunk** was raising `ValueError`. The spec states: "The
    last chunk alone may be smaller than 'split' if the input size is not
    divisible by 'split'." Fixed by using index-based splitting via
    `range(chunk_size, dim_size, chunk_size)` instead of a count.
    
    Reference: https://onnx.ai/onnx/operators/onnx__SplitToSequence.html
    Closes part of #18945
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 31 +++++++----
 tests/python/relax/test_frontend_onnx.py        | 70 +++++++++++++++++++++++++
 2 files changed, 92 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index ab1ea2b292..d54877646c 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -4310,12 +4310,10 @@ class SplitToSequence(OnnxOpConverter):
 
         input_tensor = inputs[0]
         input_shape = input_tensor.struct_info.shape
+        split_is_scalar = False
 
-        # If split is not provided, we split all values along axis.
         if len(inputs) == 1:
             split = _np.array(1)
-            if not keepdims:
-                raise NotImplementedError("Only keepdims=1 is supported for now")
         else:
             split = inputs[1]
             if not isinstance(split, relax.Constant):
@@ -4326,15 +4324,30 @@ class SplitToSequence(OnnxOpConverter):
             split = _np.cumsum(split)
             split = list(split[:-1])
         else:
-            chunk_size, dim_size = int(split), input_shape[axis]
-            if dim_size % chunk_size != 0:
-                raise ValueError(
-                    f"Dimension of size {dim_size} along axis {axis} must be "
-                    f"evenly divisible by chunk size {chunk_size}"
+            chunk_size = int(split)
+            dim_size = input_shape[axis]
+
+            if isinstance(dim_size, (int, tirx.IntImm)):
+                dim_size_int = int(dim_size)
+                split = math.ceil(dim_size_int / chunk_size)
+            else:
+                raise NotImplementedError(
+                    "SplitToSequence with dynamic dim size and scalar split is not supported."
                 )
-            split = dim_size // chunk_size
 
         output = relax.op.split(input_tensor, split, axis=axis)
+
+        # keepdims=0 applies when split is a scalar (whether provided or defaulted to 1)
+        # Per ONNX spec: "If input 'split' is specified, this attribute is ignored."
+        if not keepdims and len(inputs) == 1:
+            output = bb.emit(output)
+            n = len(output.struct_info.fields)
+            squeezed = [
+                relax.op.squeeze(bb.emit(relax.TupleGetItem(output, i)), axis=[axis])
+                for i in range(n)
+            ]
+            return relax.Tuple(squeezed)
+        
         return output
 
 
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 7f9cd177ad..8111b95c4b 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -5458,5 +5458,75 @@ def test_arg_min_max_select_last_index_ir(op_name):
     assert "relax.subtract" in call_ops, f"Expected relax.subtract in IR, got {call_ops}"
 
 
[email protected]("axis", [0, 1, 2])
+def test_split_to_sequence_keepdims_0(axis: int):
+    """keepdims=0, no split input: each chunk of size 1 has the split axis squeezed out."""
+    shape = [3, 4, 5]
+    out_shape = [s for i, s in enumerate(shape) if i != axis]
+
+    split_to_seq_node = helper.make_node(
+        "SplitToSequence",
+        ["data"],          # no split input — keepdims applies here
+        ["output"],
+        axis=axis,
+        keepdims=0,
+    )
+    graph = helper.make_graph(
+        [split_to_seq_node],
+        f"test_split_to_sequence_keepdims_0_axis{axis}",
+        inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, shape)],
+        outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, out_shape)],
+    )
+    model = helper.make_model(graph, producer_name="test_split_to_sequence_keepdims_0")
+    check_correctness(model)
+
+
+def test_split_to_sequence_keepdims_ignored_when_split_provided():
+    """Per spec: keepdims is ignored when split input is provided.
+    TVM follows the spec — output keeps the split axis even with keepdims=0."""
+    split_node = make_constant_node("split", TensorProto.INT64, (), [1])
+    split_to_seq_node = helper.make_node(
+        "SplitToSequence",
+        ["data", "split"],
+        ["output"],
+        axis=0,
+        keepdims=0,
+    )
+    graph = helper.make_graph(
+        [split_node, split_to_seq_node],
+        "test_split_to_sequence_keepdims_ignored",
+        inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, [4, 5])],
+        outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, [1, 5])],
+    )
+    model = helper.make_model(
+        graph,
+        producer_name="test_split_to_sequence_keepdims_ignored",
+        opset_imports=[helper.make_opsetid("", 11)],
+    )
+    model.ir_version = 8
+    # Cannot use check_correctness here as ORT deviates from the spec for this case
+    from tvm.relax.frontend.onnx import from_onnx
+    tvm_model = from_onnx(model, opset=11, keep_params_in_input=True)
+    assert tvm_model is not None
+
+
[email protected]("axis", [0, 1])
+def test_split_to_sequence_uneven_last_chunk(axis: int):
+    """Spec: last chunk may be smaller if dim is not divisible by scalar split."""
+    shape = [5, 4] if axis == 0 else [3, 5]
+    split_node = make_constant_node("split", TensorProto.INT64, (), [2])
+    split_to_seq_node = helper.make_node(
+        "SplitToSequence", ["data", "split"], ["output"], axis=axis, keepdims=1
+    )
+    graph = helper.make_graph(
+        [split_node, split_to_seq_node],
+        f"test_split_to_sequence_uneven_axis{axis}",
+        inputs=[helper.make_tensor_value_info("data", TensorProto.FLOAT, shape)],
+        outputs=[helper.make_tensor_sequence_value_info("output", TensorProto.FLOAT, None)],
+    )
+    model = helper.make_model(graph, producer_name="test_split_to_sequence_uneven")
+    check_correctness(model)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to