This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new cb5e290931 [Relax][ONNX] Add Optional and MatMulInteger16 frontend support (#18950)
cb5e290931 is described below
commit cb5e290931fa403110c047618b3aad0e9df60607
Author: HoYi <[email protected]>
AuthorDate: Mon Mar 30 19:33:35 2026 +0800
[Relax][ONNX] Add Optional and MatMulInteger16 frontend support (#18950)
## Summary
This PR adds Relax ONNX frontend support for:
- `Optional`
- `OptionalHasElement`
- `OptionalGetElement`
- `MatMulInteger16` from the `com.microsoft` domain
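The implementation follows existing TVM ONNX frontend patterns. Optional
values are erased during import: a non-empty Optional is represented by its
contained value, and an empty Optional is tracked with an internal sentinel
so that `OptionalHasElement` and `OptionalGetElement` can handle it explicitly.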
## Changes
- add ONNX frontend converters for `Optional`, `OptionalHasElement`, and
`OptionalGetElement`
- add ONNX frontend converter for `MatMulInteger16`
- extend ONNX attribute parsing to handle `TYPE_PROTO`
- represent empty Optional values with a sentinel during import, and unwrap
non-empty Optional values to their contained tensor or sequence
- register Optional-related ops and `MatMulInteger16` in the ONNX
converter map
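- handle Optional outputs in importer output counting, skip Relax
normalization for non-expression results (the empty-Optional sentinel), and
reject empty Optional graph outputs with a clear error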
- tighten converter docstrings and input validation for consistency with
neighboring ONNX converters in TVM
## Tests
Added or updated tests in `tests/python/relax/test_frontend_onnx.py` to
cover:
- numerical correctness for `MatMulInteger16`
- structural IR checks for `MatMulInteger16`
- invalid dtype rejection for `MatMulInteger16`
- tensor and sequence Optional round-trips
- empty Optional behavior for `OptionalHasElement`
- structural IR checks ensuring Optional ops are erased as expected
- missing `type` attribute rejection for empty `Optional`
- empty `OptionalGetElement` rejection
- empty Optional graph output rejection
- input-count validation for `OptionalHasElement`
## Validation
Validated with:
- `python -m ruff check python/tvm/relax/frontend/onnx/onnx_frontend.py tests/python/relax/test_frontend_onnx.py`
- `python -m pytest -n 1 tests/python/relax/test_frontend_onnx.py -k "optional or matmulinteger16" -v`
Result:
- `13 passed`
This PR completes the ONNX `MatMulInteger16` and `Optional` work tracked
in https://github.com/apache/tvm/issues/18945.
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 122 ++++++++-
tests/python/relax/test_frontend_onnx.py | 343 ++++++++++++++++++++++++
2 files changed, 455 insertions(+), 10 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 4cc4e99b7b..4af7115e5c 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -318,6 +318,37 @@ class MatMul(OnnxOpConverter):
return relax.op.matmul(inputs[0], inputs[1])
+class MatMulInteger16(OnnxOpConverter):
+ """Converts an ONNX MatMulInteger16 node into an equivalent Relax
expression."""
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
+ if len(inputs) != 2:
+ raise ValueError(f"MatMulInteger16 expects two inputs, but got
{len(inputs)}")
+ a, b = inputs
+ valid_types = ["int16", "uint16"]
+ if a.struct_info.dtype not in valid_types:
+ raise ValueError(
+ "MatMulInteger16 expects input A to have int16 or uint16
dtype, "
+ f"but got {a.struct_info.dtype}"
+ )
+ if b.struct_info.dtype not in valid_types:
+ raise ValueError(
+ "MatMulInteger16 expects input B to have int16 or uint16
dtype, "
+ f"but got {b.struct_info.dtype}"
+ )
+
+ out_dtype = (
+ "uint32"
+ if a.struct_info.dtype == "uint16" and b.struct_info.dtype ==
"uint16"
+ else "int32"
+ )
+ return relax.op.matmul(
+ relax.op.astype(a, out_dtype),
+ relax.op.astype(b, out_dtype),
+ )
+
+
def _to_numpy(x):
if isinstance(x, relax.PrimValue):
x = x.value
@@ -328,6 +359,19 @@ def _to_numpy(x):
return x.data.numpy()
+class _EmptyOptional:
+ """Sentinel object that preserves an empty ONNX Optional during import."""
+
+ def __init__(self, type_proto: onnx.onnx_ml_pb2.TypeProto):
+ self.type_proto = type_proto
+
+
+def _is_empty_optional(value: Any) -> bool:
+ """Returns whether the given value represents an empty ONNX Optional."""
+
+ return isinstance(value, _EmptyOptional)
+
+
class BinaryBase(OnnxOpConverter):
"""Converts an onnx BinaryBase node into an equivalent Relax expression."""
@@ -3686,6 +3730,50 @@ class SpaceToDepth(OnnxOpConverter):
)
+class Optional_(OnnxOpConverter):
+ """Converts an ONNX Optional node into an erased or empty Optional
representation."""
+
+ @classmethod
+ def _impl_v15(cls, bb, inputs, attr, params):
+ if len(inputs) > 1:
+ raise ValueError(f"Optional accepts at most one input, but got
{len(inputs)}")
+ if len(inputs) == 0 or inputs[0] is None:
+ if "type" not in attr:
+ raise ValueError("Optional without an input must specify the
type attribute.")
+ return _EmptyOptional(attr["type"])
+ return inputs[0]
+
+ _impl_v18 = _impl_v15
+
+
+class OptionalHasElement(OnnxOpConverter):
+ """Converts an ONNX OptionalHasElement node into a boolean constant."""
+
+ @classmethod
+ def _impl_v15(cls, bb, inputs, attr, params):
+ if len(inputs) != 1:
+ raise ValueError(f"OptionalHasElement expects one input, but got
{len(inputs)}")
+ if inputs[0] is None or _is_empty_optional(inputs[0]):
+ return relax.const(False, dtype="bool")
+ return relax.const(True, dtype="bool")
+
+ _impl_v18 = _impl_v15
+
+
+class OptionalGetElement(OnnxOpConverter):
+ """Converts an ONNX OptionalGetElement node by unwrapping a non-empty
Optional."""
+
+ @classmethod
+ def _impl_v15(cls, bb, inputs, attr, params):
+ if len(inputs) != 1:
+ raise ValueError(f"OptionalGetElement expects one input, but got
{len(inputs)}")
+ if inputs[0] is None or _is_empty_optional(inputs[0]):
+ raise ValueError("OptionalGetElement cannot access an empty
optional.")
+ return inputs[0]
+
+ _impl_v18 = _impl_v15
+
+
class SequenceConstruct(OnnxOpConverter):
"""Operator converter for sequence construction op."""
@@ -4111,9 +4199,9 @@ class MatMulInteger(OnnxOpConverter):
def _get_convert_map():
return {
# defs/experimental
- # "Optional": Optional_,
- # "OptionalHasElement": OptionalHasElement,
- # "OptionalGetElement": OptionalGetElement,
+ "Optional": Optional_,
+ "OptionalHasElement": OptionalHasElement,
+ "OptionalGetElement": OptionalGetElement,
# Binary operators
"Add": Add,
"Sub": Sub,
@@ -4184,7 +4272,7 @@ def _get_convert_map():
"Gemm": Gemm,
"MatMul": MatMul,
"MatMulInteger": MatMulInteger,
- # "MatMulInteger16": MatMulInteger16,
+ "MatMulInteger16": MatMulInteger16,
"Reshape": Reshape,
"Sigmoid": Sigmoid,
"Softmax": Softmax,
@@ -4343,7 +4431,18 @@ class ONNXGraphImporter:
self._check_for_unsupported_ops(graph)
self._construct_nodes(graph)
- outputs = [self._nodes[self._parse_value_proto(i)] for i in
graph.output]
+ # now return the outputs
+ output_names = [self._parse_value_proto(output) for output in
graph.output]
+ outputs = []
+ for output_name in output_names:
+ output_value = self._nodes[output_name]
+ if _is_empty_optional(output_value):
+ raise ValueError(
+ "ONNX graph output "
+ f"{output_name} is an empty optional. Empty
optional graph outputs "
+ "are not supported by the Relax ONNX frontend."
+ )
+ outputs.append(output_value)
outputs = outputs[0] if len(outputs) == 1 else
relax.Tuple(outputs)
if has_if:
@@ -4515,6 +4614,8 @@ class ONNXGraphImporter:
"Squeeze",
]
return_tuple_ops = [
+ "Optional",
+ "OptionalGetElement",
"SequenceConstruct",
"SequenceEmpty",
"SequenceErase",
@@ -4533,7 +4634,8 @@ class ONNXGraphImporter:
try:
op = self._convert_operator(op_name, inputs, attr, self.opset)
# Create struct information for the new operator.
- op = self.bb.normalize(op)
+ if isinstance(op, relax.Expr):
+ op = self.bb.normalize(op)
except TVMError as err:
print(f"Error converting operator {op_name}, with inputs:
{inputs}")
raise err
@@ -4585,11 +4687,11 @@ class ONNXGraphImporter:
if list(getattr(a, f)):
assert a.name not in attrs, "Only one type of attr is
allowed"
attrs[a.name] = tuple(getattr(a, f))
- for f in ["t"]:
- if a.HasField(f):
+ for f in ["t", "tp"]:
+ if hasattr(a, f) and a.HasField(f):
attrs[a.name] = getattr(a, f)
- for f in ["tensors"]:
- if list(getattr(a, f)):
+ for f in ["tensors", "type_protos"]:
+ if hasattr(a, f) and list(getattr(a, f)):
assert a.name not in attrs, "Only one type of attr is
allowed"
attrs[a.name] = tuple(getattr(a, f))
for f in ["graphs"]:
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index d04b0c2f33..c848ef91d6 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -197,6 +197,65 @@ def check_correctness(
_check_output(tvm_out, ort_out)
+def run_in_tvm(
+ model: ModelProto,
+ inputs: dict[str, np.ndarray] | None = None,
+ ir_version: int = 8,
+ opset: int = 14,
+):
+ if ir_version is not None:
+ model.ir_version = ir_version
+ if opset is not None:
+ for opset_import in model.opset_import:
+ if opset_import.domain in ["", "ai.onnx"]:
+ opset_import.version = opset
+ break
+
+ inputs = generate_random_inputs(model, inputs)
+ tvm_model = from_onnx(model, opset=opset, keep_params_in_input=True)
+ tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
+ tvm_model = relax.transform.LegalizeOps()(tvm_model)
+ tvm_model, params = relax.frontend.detach_params(tvm_model)
+
+ with tvm.transform.PassContext(opt_level=3):
+ ex = tvm.compile(tvm_model, target="llvm")
+ vm = relax.VirtualMachine(ex, tvm.cpu())
+
+ input_list = [
+ inputs[key.name_hint] for key in tvm_model["main"].params if
key.name_hint in inputs
+ ]
+ if params:
+ input_list += params["main"]
+
+ vm.set_input("main", *input_list)
+ vm.invoke_stateful("main")
+ return vm.get_outputs("main")
+
+
+def collect_relax_call_ops(func: relax.Function) -> list[str]:
+ op_names = []
+
+ def fvisit(expr):
+ if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op):
+ op_names.append(expr.op.name)
+
+ relax.analysis.post_order_visit(func.body, fvisit)
+ return op_names
+
+
+def collect_scalar_constants(func: relax.Function) -> list[bool | int | float]:
+ values = []
+
+ def fvisit(expr):
+ if isinstance(expr, relax.Constant):
+ value = expr.data.numpy()
+ if value.shape == ():
+ values.append(value.item())
+
+ relax.analysis.post_order_visit(func.body, fvisit)
+ return values
+
+
@pytest.mark.parametrize(
"input_names, expected_names",
[
@@ -374,6 +433,101 @@ def test_matmul(dynamic):
check_correctness(model, inputs)
[email protected](
+ ("a_dtype", "b_dtype", "a_shape", "b_shape"),
+ [
+ (np.int16, np.int16, [2, 3], [3, 4]),
+ (np.uint16, np.uint16, [2, 3], [3, 4]),
+ (np.int16, np.uint16, [2, 1, 3, 5], [1, 2, 5, 4]),
+ ],
+)
+def test_matmulinteger16(a_dtype, b_dtype, a_shape, b_shape):
+ a = np.arange(np.prod(a_shape), dtype=np.int64).reshape(a_shape)
+ b = np.arange(np.prod(b_shape), dtype=np.int64).reshape(b_shape)
+ if np.issubdtype(a_dtype, np.signedinteger):
+ a -= a.size // 2
+ if np.issubdtype(b_dtype, np.signedinteger):
+ b -= b.size // 2
+ a = a.astype(a_dtype)
+ b = b.astype(b_dtype)
+
+ out_dtype = np.uint32 if a_dtype == np.uint16 and b_dtype == np.uint16
else np.int32
+ expected = np.matmul(a.astype(out_dtype), b.astype(out_dtype))
+
+ node = helper.make_node("MatMulInteger16", ["a", "b"], ["y"],
domain="com.microsoft")
+ graph = helper.make_graph(
+ [node],
+ "matmulinteger16_test",
+ inputs=[
+ helper.make_tensor_value_info("a",
helper.np_dtype_to_tensor_dtype(a.dtype), a_shape),
+ helper.make_tensor_value_info("b",
helper.np_dtype_to_tensor_dtype(b.dtype), b_shape),
+ ],
+ outputs=[
+ helper.make_tensor_value_info(
+ "y", helper.np_dtype_to_tensor_dtype(np.dtype(out_dtype)),
expected.shape
+ )
+ ],
+ )
+ model = helper.make_model(
+ graph,
+ producer_name="matmulinteger16_test",
+ opset_imports=[helper.make_opsetid("", 18),
helper.make_opsetid("com.microsoft", 1)],
+ )
+ model.ir_version = 11
+
+ tvm_output = run_in_tvm(model, inputs={"a": a, "b": b}, ir_version=11,
opset=18)
+ assert isinstance(tvm_output, tvm.runtime.Tensor)
+ assert tvm_output.numpy().dtype == out_dtype
+ tvm.testing.assert_allclose(tvm_output.numpy(), expected)
+
+
+def test_matmulinteger16_ir():
+ node = helper.make_node("MatMulInteger16", ["a", "b"], ["y"],
domain="com.microsoft")
+ graph = helper.make_graph(
+ [node],
+ "matmulinteger16_ir_test",
+ inputs=[
+ helper.make_tensor_value_info("a", TensorProto.UINT16, [2, 3]),
+ helper.make_tensor_value_info("b", TensorProto.UINT16, [3, 4]),
+ ],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.UINT32, [2,
4])],
+ )
+ model = helper.make_model(
+ graph,
+ producer_name="matmulinteger16_ir_test",
+ opset_imports=[helper.make_opsetid("", 18),
helper.make_opsetid("com.microsoft", 1)],
+ )
+ model.ir_version = 11
+
+ tvm_model = from_onnx(model, opset=18, keep_params_in_input=True)
+ call_ops = collect_relax_call_ops(tvm_model["main"])
+ assert call_ops.count("relax.astype") == 2
+ assert "relax.matmul" in call_ops
+ assert tvm_model["main"].ret_struct_info.dtype == "uint32"
+
+
+def test_matmulinteger16_invalid_dtype_raises():
+ node = helper.make_node("MatMulInteger16", ["a", "b"], ["y"],
domain="com.microsoft")
+ graph = helper.make_graph(
+ [node],
+ "matmulinteger16_invalid_dtype_test",
+ inputs=[
+ helper.make_tensor_value_info("a", TensorProto.INT8, [2, 3]),
+ helper.make_tensor_value_info("b", TensorProto.UINT16, [3, 4]),
+ ],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.INT32, [2,
4])],
+ )
+ model = helper.make_model(
+ graph,
+ producer_name="matmulinteger16_invalid_dtype_test",
+ opset_imports=[helper.make_opsetid("", 18),
helper.make_opsetid("com.microsoft", 1)],
+ )
+ model.ir_version = 11
+
+ with pytest.raises(ValueError, match="input A"):
+ from_onnx(model, opset=18, keep_params_in_input=True)
+
+
def test_concat():
verify_binary("Concat", [1, 32], [1, 32], [2, 32], attrs={"axis": 0})
@@ -3176,6 +3330,21 @@ def make_constant_node(name: str, data_type: int, dims:
list[int], vals: list[in
)
+def make_optional_tensor_value_info(name: str, elem_type: int, shape:
list[int]):
+ return helper.make_value_info(
+ name,
helper.make_optional_type_proto(helper.make_tensor_type_proto(elem_type, shape))
+ )
+
+
+def make_optional_sequence_value_info(name: str, elem_type: int, shape:
list[int]):
+ return helper.make_value_info(
+ name,
+ helper.make_optional_type_proto(
+
helper.make_sequence_type_proto(helper.make_tensor_type_proto(elem_type, shape))
+ ),
+ )
+
+
def test_sequence_construct():
node, graph_inputs = construct_sequence(input_shape=[32, 32],
num_tensors=2)
graph = helper.make_graph(
@@ -3287,6 +3456,180 @@ def test_sequence_at():
check_correctness(model)
+def test_optional_get_element_tensor():
+ x_shape = [2, 3]
+ optional_node = helper.make_node("Optional", ["x"], ["optional"])
+ get_element_node = helper.make_node("OptionalGetElement", ["optional"],
["output"])
+ graph = helper.make_graph(
+ [optional_node, get_element_node],
+ "test_optional_get_element_tensor",
+ inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT,
x_shape)],
+ outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT,
x_shape)],
+ value_info=[make_optional_tensor_value_info("optional",
TensorProto.FLOAT, x_shape)],
+ )
+ model = helper.make_model(graph,
producer_name="test_optional_get_element_tensor")
+ check_correctness(model, opset=18, ir_version=11)
+
+
+def test_optional_has_element_tensor():
+ x_shape = [2, 3]
+ optional_node = helper.make_node("Optional", ["x"], ["optional"])
+ has_element_node = helper.make_node("OptionalHasElement", ["optional"],
["output"])
+ graph = helper.make_graph(
+ [optional_node, has_element_node],
+ "test_optional_has_element_tensor",
+ inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT,
x_shape)],
+ outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL,
[])],
+ value_info=[make_optional_tensor_value_info("optional",
TensorProto.FLOAT, x_shape)],
+ )
+ model = helper.make_model(graph,
producer_name="test_optional_has_element_tensor")
+ check_correctness(model, opset=18, ir_version=11)
+
+
+def test_optional_has_element_empty():
+ x_shape = [2, 3]
+ tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, x_shape)
+ optional_type = helper.make_optional_type_proto(tensor_type)
+ optional_node = helper.make_node("Optional", [], ["optional"],
type=tensor_type)
+ has_element_node = helper.make_node("OptionalHasElement", ["optional"],
["output"])
+ graph = helper.make_graph(
+ [optional_node, has_element_node],
+ "test_optional_has_element_empty",
+ inputs=[],
+ outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL,
[])],
+ value_info=[helper.make_value_info("optional", optional_type)],
+ )
+ model = helper.make_model(graph,
producer_name="test_optional_has_element_empty")
+ check_correctness(model, opset=18, ir_version=11)
+
+
+def test_optional_has_element_empty_ir():
+ x_shape = [2, 3]
+ tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, x_shape)
+ optional_type = helper.make_optional_type_proto(tensor_type)
+ optional_node = helper.make_node("Optional", [], ["optional"],
type=tensor_type)
+ has_element_node = helper.make_node("OptionalHasElement", ["optional"],
["output"])
+ graph = helper.make_graph(
+ [optional_node, has_element_node],
+ "test_optional_has_element_empty_ir",
+ inputs=[],
+ outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL,
[])],
+ value_info=[helper.make_value_info("optional", optional_type)],
+ )
+ model = helper.make_model(graph,
producer_name="test_optional_has_element_empty_ir")
+ model.ir_version = 11
+ model.opset_import[0].version = 18
+ tvm_model = from_onnx(model, opset=18, keep_params_in_input=True)
+
+ assert collect_relax_call_ops(tvm_model["main"]) == []
+ assert False in collect_scalar_constants(tvm_model["main"])
+
+
+def test_optional_get_element_tensor_ir():
+ x_shape = [2, 3]
+ optional_node = helper.make_node("Optional", ["x"], ["optional"])
+ get_element_node = helper.make_node("OptionalGetElement", ["optional"],
["output"])
+ graph = helper.make_graph(
+ [optional_node, get_element_node],
+ "test_optional_get_element_tensor_ir",
+ inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT,
x_shape)],
+ outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT,
x_shape)],
+ value_info=[make_optional_tensor_value_info("optional",
TensorProto.FLOAT, x_shape)],
+ )
+ model = helper.make_model(graph,
producer_name="test_optional_get_element_tensor_ir")
+ model.ir_version = 11
+ model.opset_import[0].version = 18
+ tvm_model = from_onnx(model, opset=18, keep_params_in_input=True)
+
+ assert collect_relax_call_ops(tvm_model["main"]) == []
+ assert tvm_model["main"].ret_struct_info.dtype == "float32"
+
+
+def test_optional_get_element_sequence():
+ seq_node, graph_inputs = construct_sequence(input_shape=[32, 32],
num_tensors=4)
+ index = make_constant_node("index", TensorProto.INT64, (), [1])
+ optional_node = helper.make_node("Optional", ["sequence"], ["optional"])
+ get_element_node = helper.make_node("OptionalGetElement", ["optional"],
["unwrapped"])
+ sequence_at_node = helper.make_node("SequenceAt", ["unwrapped", "index"],
["output"])
+ graph = helper.make_graph(
+ [index, seq_node, optional_node, get_element_node, sequence_at_node],
+ "test_optional_get_element_sequence",
+ inputs=graph_inputs,
+ outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT,
[32, 32])],
+ value_info=[make_optional_sequence_value_info("optional",
TensorProto.FLOAT, [32, 32])],
+ )
+ model = helper.make_model(graph,
producer_name="test_optional_get_element_sequence")
+ check_correctness(model, opset=18, ir_version=11)
+
+
+def test_optional_without_input_requires_type_attr():
+ tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [2, 3])
+ optional_type = helper.make_optional_type_proto(tensor_type)
+ optional_node = helper.make_node("Optional", [], ["optional"])
+ graph = helper.make_graph(
+ [optional_node],
+ "test_optional_without_input_requires_type_attr",
+ inputs=[],
+ outputs=[helper.make_value_info("optional", optional_type)],
+ )
+ model = helper.make_model(graph,
producer_name="test_optional_without_input_requires_type_attr")
+ model.opset_import[0].version = 18
+
+ with pytest.raises(ValueError, match="type attribute"):
+ from_onnx(model, opset=18, keep_params_in_input=True)
+
+
+def test_empty_optional_graph_output_raises():
+ tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [2, 3])
+ optional_type = helper.make_optional_type_proto(tensor_type)
+ optional_node = helper.make_node("Optional", [], ["optional"],
type=tensor_type)
+ graph = helper.make_graph(
+ [optional_node],
+ "test_empty_optional_graph_output_raises",
+ inputs=[],
+ outputs=[helper.make_value_info("optional", optional_type)],
+ )
+ model = helper.make_model(graph,
producer_name="test_empty_optional_graph_output_raises")
+ model.opset_import[0].version = 18
+
+ with pytest.raises(ValueError, match="Empty optional graph outputs are not
supported"):
+ from_onnx(model, opset=18, keep_params_in_input=True)
+
+
+def test_optional_has_element_requires_one_input():
+ has_element_node = helper.make_node("OptionalHasElement", [], ["output"])
+ graph = helper.make_graph(
+ [has_element_node],
+ "test_optional_has_element_requires_one_input",
+ inputs=[],
+ outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL,
[])],
+ )
+ model = helper.make_model(graph,
producer_name="test_optional_has_element_requires_one_input")
+ model.opset_import[0].version = 18
+
+ with pytest.raises(ValueError, match="expects one input"):
+ from_onnx(model, opset=18, keep_params_in_input=True)
+
+
+def test_optional_get_element_empty_raises():
+ x_shape = [2, 3]
+ tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, x_shape)
+ optional_type = helper.make_optional_type_proto(tensor_type)
+ optional_node = helper.make_node("Optional", [], ["optional"],
type=tensor_type)
+ get_element_node = helper.make_node("OptionalGetElement", ["optional"],
["output"])
+ graph = helper.make_graph(
+ [optional_node, get_element_node],
+ "test_optional_get_element_empty_raises",
+ inputs=[],
+ outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT,
x_shape)],
+ value_info=[helper.make_value_info("optional", optional_type)],
+ )
+ model = helper.make_model(graph,
producer_name="test_optional_get_element_empty_raises")
+ model.opset_import[0].version = 18
+ with pytest.raises(ValueError, match="empty optional"):
+ from_onnx(model, opset=18, keep_params_in_input=True)
+
+
def test_symbolic_shape_deduction():
index_node = helper.make_node(
"Constant",