This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f0863fd5c7 [Relax][ONNX] Support ConcatFromSequence/SequenceInsert with
new_axis=1 (#19361)
f0863fd5c7 is described below
commit f0863fd5c711e9fec08e2f3b2cb69b8e02933a31
Author: Bana <[email protected]>
AuthorDate: Tue Apr 7 04:05:00 2026 +0300
[Relax][ONNX] Support ConcatFromSequence/SequenceInsert with new_axis=1
(#19361)
## Description
This PR adds support for `new_axis=1` in the ONNX `ConcatFromSequence`
operator, which previously raised a `NotImplementedError`.
*Note regarding the tracking issue:* The tracking issue listed this task
as "SequenceInsert — Does not support inserting with new axis.", but I
believe it meant `ConcatFromSequence`. ONNX's `SequenceInsert` does not
have a `new_axis` attribute, whereas `ConcatFromSequence` does and was
throwing the "Insert new axis is not supported yet" error. This PR
implements the missing feature.
## Changes:
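- Replaced the `NotImplementedError` in `ConcatFromSequence` with an
implementation that inserts a new dimension into each sequence element via
`relax.op.expand_dims(t, axis=axis)` and then concatenates the results
along `axis` (equivalent to stacking the elements).
- Added validation so that `new_axis` values other than 0 or 1 raise a
`ValueError`.
- `SequenceInsert` now converts a constant `position` to `int` and raises a
`ValueError` when the position is out of bounds; `SequenceErase` was
simplified accordingly.
- Removed the `pytest.skip` from `test_concat_from_sequence`.
- Parameterized the test over `(new_axis, axis)` to explicitly check both standard concatenation
(e.g. `new_axis=0, axis=0` yielding `[64, 32]`) and stacking operations
(e.g. `new_axis=1, axis=1` yielding `[32, 2, 32]`, plus a negative `axis`).
- Added tests for `new_axis=1` with a three-element sequence and for
rejecting an invalid `new_axis`.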
## Testing
I tested the implementation by running:
```
pytest tests/python/relax/test_frontend_onnx.py::test_concat_from_sequence
```
and all tests passed:
<img width="1044" height="89" alt="image"
src="https://github.com/user-attachments/assets/25d6ad26-ad4b-4437-9fa5-e29efc9e0c9f"
/>
## Reference
https://onnx.ai/onnx/operators/onnx__ConcatFromSequence.html
https://onnx.ai/onnx/operators/onnx__SequenceInsert.html
Partially addresses #18945.
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 30 ++++++++------
tests/python/relax/test_frontend_onnx.py | 53 ++++++++++++++++++++++---
2 files changed, 65 insertions(+), 18 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index d54877646c..64fbf94076 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -4242,10 +4242,9 @@ class SequenceErase(OnnxOpConverter):
if position < 0:
position = seq_len + position
- # Convert sequence to a list, insert tensors before erased, and
repackage as Tuple.
- tensor_list = [input_sequence[i] for i in range(seq_len) if i !=
position]
- # Create new tuple and return.
- return relax.Tuple(tensor_list)
+ seq_list = list(input_sequence)
+ items = [t for i, t in enumerate(seq_list) if i != position]
+ return relax.Tuple(items)
class SequenceInsert(OnnxOpConverter):
@@ -4261,19 +4260,22 @@ class SequenceInsert(OnnxOpConverter):
position = inputs[2]
# Non constant position is not supported.
if isinstance(position, relax.Constant):
- position = position.data.numpy()
+ position = int(position.data.numpy())
else:
raise NotImplementedError("Position must be a constant.")
else:
position = -1
+ seq_len = len(input_sequence)
if position < 0:
- position = len(input_sequence) + position + 1
- # Convert sequence to a list, insert new tensor, and repackage as
Tuple.
- tensor_list = [input_sequence[i] for i in range(len(input_sequence))]
- # Insert new tensor.
+ position = seq_len + position + 1
+ # Upper bound is inclusive: position == seq_len appends at the end.
+ if not 0 <= position <= seq_len:
+ raise ValueError(
+ f"SequenceInsert position out of bounds for length {seq_len},
got {position}"
+ )
+ tensor_list = list(input_sequence)
tensor_list.insert(position, tensor_to_insert)
- # Create new tuple and return.
return relax.Tuple(tensor_list)
@@ -4294,10 +4296,14 @@ class ConcatFromSequence(OnnxOpConverter):
axis = attr.get("axis", 0)
new_axis = attr.get("new_axis", 0)
+ if new_axis not in (0, 1):
+ raise ValueError(f"ConcatFromSequence only supports new_axis in
(0, 1), got {new_axis}")
+
+ tensors = list(inputs[0])
if new_axis == 1:
- raise NotImplementedError("Insert new axis is not supported yet.")
+ tensors = [relax.op.expand_dims(t, axis=axis) for t in tensors]
- return relax.op.concat(inputs[0], axis=axis)
+ return relax.op.concat(tensors, axis=axis)
class SplitToSequence(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 8111b95c4b..29e2d9499d 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -3762,24 +3762,65 @@ def test_sequence_insert(explicit_position: bool):
check_correctness(model)
[email protected]("new_axis", [0, 1])
-def test_concat_from_sequence(new_axis: Literal[0, 1]):
- if new_axis == 1:
- pytest.skip("ConcatFromSequence with new_axis=1 is not supported yet")
[email protected](
+ "new_axis,axis,expected_shape",
+ [
+ (0, 0, [64, 32]),
+ (0, 1, [32, 64]),
+ (1, 0, [2, 32, 32]),
+ (1, 1, [32, 2, 32]),
+ (1, -1, [32, 32, 2]),
+ ],
+)
+def test_concat_from_sequence(new_axis: int, axis: int, expected_shape:
list[int]):
seq_node, graph_inputs = construct_sequence(input_shape=[32, 32],
num_tensors=2)
concat_from_sequence_node = helper.make_node(
- "ConcatFromSequence", ["sequence"], ["output"], axis=1
+ "ConcatFromSequence", ["sequence"], ["output"], axis=axis,
new_axis=new_axis
)
graph = helper.make_graph(
[seq_node, concat_from_sequence_node],
"test_concat_from_sequence",
inputs=graph_inputs,
- outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT,
[64, 32])],
+ outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT,
expected_shape)],
)
model = helper.make_model(graph, producer_name="test_concat_from_sequence")
check_correctness(model)
+def test_concat_from_sequence_new_axis_three_tensors():
+ """new_axis=1 with three sequence elements (stack then concat along
axis)."""
+ seq_node, graph_inputs = construct_sequence(input_shape=[16, 8],
num_tensors=3)
+ concat_node = helper.make_node(
+ "ConcatFromSequence", ["sequence"], ["output"], axis=0, new_axis=1
+ )
+ graph = helper.make_graph(
+ [seq_node, concat_node],
+ "test_concat_from_sequence_new_axis_three",
+ inputs=graph_inputs,
+ outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT,
[3, 16, 8])],
+ )
+ model = helper.make_model(graph,
producer_name="test_concat_from_sequence_new_axis_three")
+ check_correctness(model)
+
+
+def test_concat_from_sequence_invalid_new_axis():
+ """Verify that new_axis values other than 0 or 1 raise a ValueError."""
+ seq_node, graph_inputs = construct_sequence(input_shape=[16, 8],
num_tensors=2)
+ concat_node = helper.make_node(
+ "ConcatFromSequence", ["sequence"], ["output"], axis=0, new_axis=2
+ )
+ graph = helper.make_graph(
+ [seq_node, concat_node],
+ "test_concat_from_sequence_invalid_new_axis",
+ inputs=graph_inputs,
+ outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT,
[32, 8])],
+ )
+ model = helper.make_model(graph,
producer_name="test_concat_from_sequence_invalid_new_axis")
+
+ with pytest.raises(ValueError, match="ConcatFromSequence only supports
new_axis in"):
+ from_onnx(model, opset=11)
+
+
@pytest.mark.parametrize("split", [2, [16, 48]])
def test_split_to_sequence(split):
split_to_sequence_node = helper.make_node(