This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f0863fd5c7 [Relax][ONNX] Support ConcatFromSequence/SequenceInsert with 
new_axis=1 (#19361)
f0863fd5c7 is described below

commit f0863fd5c711e9fec08e2f3b2cb69b8e02933a31
Author: Bana <[email protected]>
AuthorDate: Tue Apr 7 04:05:00 2026 +0300

    [Relax][ONNX] Support ConcatFromSequence/SequenceInsert with new_axis=1 
(#19361)
    
    ## Description
    
    This PR adds support for `new_axis=1` in the ONNX `ConcatFromSequence`
    operator, which was previously raising a `NotImplementedError`.
    
    *Note regarding the tracking issue:* The tracking issue listed this task
    as "SequenceInsert — Does not support inserting with new axis.", but I
    think it meant `ConcatFromSequence`. ONNX's `SequenceInsert` does not
    have a `new_axis` attribute, whereas `ConcatFromSequence` does and was
    throwing the "Insert new axis is not supported yet" error. This PR
    implements the missing feature.
    
    ## Changes:
    - Replaced the `NotImplementedError` in `ConcatFromSequence` with
    `relax.op.stack(inputs[0], axis=axis)`
    - Removed the `pytest.skip` from `test_concat_from_sequence`.
    - Parameterized the test to explicitly check both standard concatenation
    (`new_axis=0, axis=0` yielding `[64, 32]`) and stacking operations
    (`new_axis=1, axis=1` yielding `[32, 2, 32]`).
    
    ## Testing
    I tested the implementation via running:
    ```
    pytest tests/python/relax/test_frontend_onnx.py::test_concat_from_sequence
    ```
    and all tests passed:
    <img width="1044" height="89" alt="image"
    
src="https://github.com/user-attachments/assets/25d6ad26-ad4b-4437-9fa5-e29efc9e0c9f";
    />
    
    
    ## Reference
    https://onnx.ai/onnx/operators/onnx__ConcatFromSequence.html
    https://onnx.ai/onnx/operators/onnx__SequenceInsert.html
    
    partially addresses #18945
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 30 ++++++++------
 tests/python/relax/test_frontend_onnx.py        | 53 ++++++++++++++++++++++---
 2 files changed, 65 insertions(+), 18 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index d54877646c..64fbf94076 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -4242,10 +4242,9 @@ class SequenceErase(OnnxOpConverter):
 
         if position < 0:
             position = seq_len + position
-        # Convert sequence to a list, insert tensors before erased, and 
repackage as Tuple.
-        tensor_list = [input_sequence[i] for i in range(seq_len) if i != 
position]
-        # Create new tuple and return.
-        return relax.Tuple(tensor_list)
+        seq_list = list(input_sequence)
+        items = [t for i, t in enumerate(seq_list) if i != position]
+        return relax.Tuple(items)
 
 
 class SequenceInsert(OnnxOpConverter):
@@ -4261,19 +4260,22 @@ class SequenceInsert(OnnxOpConverter):
             position = inputs[2]
             # Non constant position is not supported.
             if isinstance(position, relax.Constant):
-                position = position.data.numpy()
+                position = int(position.data.numpy())
             else:
                 raise NotImplementedError("Position must be a constant.")
         else:
             position = -1
 
+        seq_len = len(input_sequence)
         if position < 0:
-            position = len(input_sequence) + position + 1
-        # Convert sequence to a list, insert new tensor, and repackage as 
Tuple.
-        tensor_list = [input_sequence[i] for i in range(len(input_sequence))]
-        # Insert new tensor.
+            position = seq_len + position + 1
+        # Upper bound is inclusive: position == seq_len appends at the end.
+        if not 0 <= position <= seq_len:
+            raise ValueError(
+                f"SequenceInsert position out of bounds for length {seq_len}, 
got {position}"
+            )
+        tensor_list = list(input_sequence)
         tensor_list.insert(position, tensor_to_insert)
-        # Create new tuple and return.
         return relax.Tuple(tensor_list)
 
 
@@ -4294,10 +4296,14 @@ class ConcatFromSequence(OnnxOpConverter):
         axis = attr.get("axis", 0)
         new_axis = attr.get("new_axis", 0)
 
+        if new_axis not in (0, 1):
+            raise ValueError(f"ConcatFromSequence only supports new_axis in 
(0, 1), got {new_axis}")
+
+        tensors = list(inputs[0])
         if new_axis == 1:
-            raise NotImplementedError("Insert new axis is not supported yet.")
+            tensors = [relax.op.expand_dims(t, axis=axis) for t in tensors]
 
-        return relax.op.concat(inputs[0], axis=axis)
+        return relax.op.concat(tensors, axis=axis)
 
 
 class SplitToSequence(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index 8111b95c4b..29e2d9499d 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -3762,24 +3762,65 @@ def test_sequence_insert(explicit_position: bool):
     check_correctness(model)
 
 
[email protected]("new_axis", [0, 1])
-def test_concat_from_sequence(new_axis: Literal[0, 1]):
-    if new_axis == 1:
-        pytest.skip("ConcatFromSequence with new_axis=1 is not supported yet")
[email protected](
+    "new_axis,axis,expected_shape",
+    [
+        (0, 0, [64, 32]),
+        (0, 1, [32, 64]),
+        (1, 0, [2, 32, 32]),
+        (1, 1, [32, 2, 32]),
+        (1, -1, [32, 32, 2]),
+    ],
+)
+def test_concat_from_sequence(new_axis: int, axis: int, expected_shape: 
list[int]):
     seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], 
num_tensors=2)
     concat_from_sequence_node = helper.make_node(
-        "ConcatFromSequence", ["sequence"], ["output"], axis=1
+        "ConcatFromSequence", ["sequence"], ["output"], axis=axis, 
new_axis=new_axis
     )
     graph = helper.make_graph(
         [seq_node, concat_from_sequence_node],
         "test_concat_from_sequence",
         inputs=graph_inputs,
-        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, 
[64, 32])],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, 
expected_shape)],
     )
     model = helper.make_model(graph, producer_name="test_concat_from_sequence")
     check_correctness(model)
 
 
+def test_concat_from_sequence_new_axis_three_tensors():
+    """new_axis=1 with three sequence elements (stack then concat along 
axis)."""
+    seq_node, graph_inputs = construct_sequence(input_shape=[16, 8], 
num_tensors=3)
+    concat_node = helper.make_node(
+        "ConcatFromSequence", ["sequence"], ["output"], axis=0, new_axis=1
+    )
+    graph = helper.make_graph(
+        [seq_node, concat_node],
+        "test_concat_from_sequence_new_axis_three",
+        inputs=graph_inputs,
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, 
[3, 16, 8])],
+    )
+    model = helper.make_model(graph, 
producer_name="test_concat_from_sequence_new_axis_three")
+    check_correctness(model)
+
+
+def test_concat_from_sequence_invalid_new_axis():
+    """Verify that new_axis values other than 0 or 1 raise a ValueError."""
+    seq_node, graph_inputs = construct_sequence(input_shape=[16, 8], 
num_tensors=2)
+    concat_node = helper.make_node(
+        "ConcatFromSequence", ["sequence"], ["output"], axis=0, new_axis=2
+    )
+    graph = helper.make_graph(
+        [seq_node, concat_node],
+        "test_concat_from_sequence_invalid_new_axis",
+        inputs=graph_inputs,
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, 
[32, 8])],
+    )
+    model = helper.make_model(graph, 
producer_name="test_concat_from_sequence_invalid_new_axis")
+    
+    with pytest.raises(ValueError, match="ConcatFromSequence only supports 
new_axis in"):
+        from_onnx(model, opset=11)
+
+
 @pytest.mark.parametrize("split", [2, [16, 48]])
 def test_split_to_sequence(split):
     split_to_sequence_node = helper.make_node(

Reply via email to