This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new cb5e290931 [Relax][ONNX] Add Optional and MatMulInteger16 frontend 
support (#18950)
cb5e290931 is described below

commit cb5e290931fa403110c047618b3aad0e9df60607
Author: HoYi <[email protected]>
AuthorDate: Mon Mar 30 19:33:35 2026 +0800

    [Relax][ONNX] Add Optional and MatMulInteger16 frontend support (#18950)
    
    ## Summary
    
    This PR adds Relax ONNX frontend support for:
    - `Optional`
    - `OptionalHasElement`
    - `OptionalGetElement`
    - `MatMulInteger16` from the `com.microsoft` domain
    
    The implementation follows existing TVM ONNX frontend patterns. A
    non-empty Optional is erased to the value it wraps, while an empty
    Optional is represented explicitly during import by an `_EmptyOptional`
    sentinel object.
    
    ## Changes
    
    - add ONNX frontend converters for `Optional`, `OptionalHasElement`, and
    `OptionalGetElement`
    - add ONNX frontend converter for `MatMulInteger16`
    - extend ONNX attribute parsing to handle `TYPE_PROTO` and `TYPE_PROTOS`
    attributes
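    - preserve empty Optional values during import, unwrap non-empty
    Optionals to their contained value, and reject empty Optionals where a
    value is required (`OptionalGetElement` and graph outputs)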
    - register Optional-related ops and `MatMulInteger16` in the ONNX
    converter map
    - handle Optional outputs correctly in importer output counting and
    normalization
    - tighten converter docstrings and input validation for better
    consistency with nearby TVM code
    
    ## Tests
    
    Added or updated tests in `tests/python/relax/test_frontend_onnx.py` to
    cover:
    - numerical correctness for `MatMulInteger16`
    - structural IR checks for `MatMulInteger16`
    - invalid dtype rejection for `MatMulInteger16`
    - tensor and sequence Optional round-trips
    - empty Optional behavior for `OptionalHasElement`
    - structural IR checks ensuring Optional ops are erased as expected
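    - rejection of an input-less `Optional` that lacks the `type` attribute
    - rejection of `OptionalGetElement` applied to an empty Optional
    - rejection of an empty Optional used as a graph output
    - input-count validation for `OptionalHasElement`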
    
    ## Validation
    
    Validated with:
    - `python -m ruff check python/tvm/relax/frontend/onnx/onnx_frontend.py
    tests/python/relax/test_frontend_onnx.py`
    - `python -m pytest -n 1 tests/python/relax/test_frontend_onnx.py -k
    "optional or matmulinteger16" -v`
    
    Result:
    - `13 passed`
    
    This PR completes the ONNX `MatMulInteger16` and `Optional` work tracked
    in https://github.com/apache/tvm/issues/18945.
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 122 ++++++++-
 tests/python/relax/test_frontend_onnx.py        | 343 ++++++++++++++++++++++++
 2 files changed, 455 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 4cc4e99b7b..4af7115e5c 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -318,6 +318,37 @@ class MatMul(OnnxOpConverter):
         return relax.op.matmul(inputs[0], inputs[1])
 
 
+class MatMulInteger16(OnnxOpConverter):
+    """Converts an ONNX MatMulInteger16 node into an equivalent Relax 
expression."""
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
+        if len(inputs) != 2:
+            raise ValueError(f"MatMulInteger16 expects two inputs, but got 
{len(inputs)}")
+        a, b = inputs
+        valid_types = ["int16", "uint16"]
+        if a.struct_info.dtype not in valid_types:
+            raise ValueError(
+                "MatMulInteger16 expects input A to have int16 or uint16 
dtype, "
+                f"but got {a.struct_info.dtype}"
+            )
+        if b.struct_info.dtype not in valid_types:
+            raise ValueError(
+                "MatMulInteger16 expects input B to have int16 or uint16 
dtype, "
+                f"but got {b.struct_info.dtype}"
+            )
+
+        out_dtype = (
+            "uint32"
+            if a.struct_info.dtype == "uint16" and b.struct_info.dtype == 
"uint16"
+            else "int32"
+        )
+        return relax.op.matmul(
+            relax.op.astype(a, out_dtype),
+            relax.op.astype(b, out_dtype),
+        )
+
+
 def _to_numpy(x):
     if isinstance(x, relax.PrimValue):
         x = x.value
@@ -328,6 +359,19 @@ def _to_numpy(x):
         return x.data.numpy()
 
 
+class _EmptyOptional:
+    """Sentinel object that preserves an empty ONNX Optional during import."""
+
+    def __init__(self, type_proto: onnx.onnx_ml_pb2.TypeProto):
+        self.type_proto = type_proto
+
+
+def _is_empty_optional(value: Any) -> bool:
+    """Returns whether the given value represents an empty ONNX Optional."""
+
+    return isinstance(value, _EmptyOptional)
+
+
 class BinaryBase(OnnxOpConverter):
     """Converts an onnx BinaryBase node into an equivalent Relax expression."""
 
@@ -3686,6 +3730,50 @@ class SpaceToDepth(OnnxOpConverter):
         )
 
 
+class Optional_(OnnxOpConverter):
+    """Converts an ONNX Optional node into an erased or empty Optional 
representation."""
+
+    @classmethod
+    def _impl_v15(cls, bb, inputs, attr, params):
+        if len(inputs) > 1:
+            raise ValueError(f"Optional accepts at most one input, but got 
{len(inputs)}")
+        if len(inputs) == 0 or inputs[0] is None:
+            if "type" not in attr:
+                raise ValueError("Optional without an input must specify the 
type attribute.")
+            return _EmptyOptional(attr["type"])
+        return inputs[0]
+
+    _impl_v18 = _impl_v15
+
+
+class OptionalHasElement(OnnxOpConverter):
+    """Converts an ONNX OptionalHasElement node into a boolean constant."""
+
+    @classmethod
+    def _impl_v15(cls, bb, inputs, attr, params):
+        if len(inputs) != 1:
+            raise ValueError(f"OptionalHasElement expects one input, but got 
{len(inputs)}")
+        if inputs[0] is None or _is_empty_optional(inputs[0]):
+            return relax.const(False, dtype="bool")
+        return relax.const(True, dtype="bool")
+
+    _impl_v18 = _impl_v15
+
+
+class OptionalGetElement(OnnxOpConverter):
+    """Converts an ONNX OptionalGetElement node by unwrapping a non-empty 
Optional."""
+
+    @classmethod
+    def _impl_v15(cls, bb, inputs, attr, params):
+        if len(inputs) != 1:
+            raise ValueError(f"OptionalGetElement expects one input, but got 
{len(inputs)}")
+        if inputs[0] is None or _is_empty_optional(inputs[0]):
+            raise ValueError("OptionalGetElement cannot access an empty 
optional.")
+        return inputs[0]
+
+    _impl_v18 = _impl_v15
+
+
 class SequenceConstruct(OnnxOpConverter):
     """Operator converter for sequence construction op."""
 
@@ -4111,9 +4199,9 @@ class MatMulInteger(OnnxOpConverter):
 def _get_convert_map():
     return {
         # defs/experimental
-        # "Optional": Optional_,
-        # "OptionalHasElement": OptionalHasElement,
-        # "OptionalGetElement": OptionalGetElement,
+        "Optional": Optional_,
+        "OptionalHasElement": OptionalHasElement,
+        "OptionalGetElement": OptionalGetElement,
         # Binary operators
         "Add": Add,
         "Sub": Sub,
@@ -4184,7 +4272,7 @@ def _get_convert_map():
         "Gemm": Gemm,
         "MatMul": MatMul,
         "MatMulInteger": MatMulInteger,
-        # "MatMulInteger16": MatMulInteger16,
+        "MatMulInteger16": MatMulInteger16,
         "Reshape": Reshape,
         "Sigmoid": Sigmoid,
         "Softmax": Softmax,
@@ -4343,7 +4431,18 @@ class ONNXGraphImporter:
                 self._check_for_unsupported_ops(graph)
                 self._construct_nodes(graph)
 
-                outputs = [self._nodes[self._parse_value_proto(i)] for i in 
graph.output]
+                # now return the outputs
+                output_names = [self._parse_value_proto(output) for output in 
graph.output]
+                outputs = []
+                for output_name in output_names:
+                    output_value = self._nodes[output_name]
+                    if _is_empty_optional(output_value):
+                        raise ValueError(
+                            "ONNX graph output "
+                            f"{output_name} is an empty optional. Empty 
optional graph outputs "
+                            "are not supported by the Relax ONNX frontend."
+                        )
+                    outputs.append(output_value)
                 outputs = outputs[0] if len(outputs) == 1 else 
relax.Tuple(outputs)
 
                 if has_if:
@@ -4515,6 +4614,8 @@ class ONNXGraphImporter:
                 "Squeeze",
             ]
             return_tuple_ops = [
+                "Optional",
+                "OptionalGetElement",
                 "SequenceConstruct",
                 "SequenceEmpty",
                 "SequenceErase",
@@ -4533,7 +4634,8 @@ class ONNXGraphImporter:
             try:
                 op = self._convert_operator(op_name, inputs, attr, self.opset)
                 # Create struct information for the new operator.
-                op = self.bb.normalize(op)
+                if isinstance(op, relax.Expr):
+                    op = self.bb.normalize(op)
             except TVMError as err:
                 print(f"Error converting operator {op_name}, with inputs: 
{inputs}")
                 raise err
@@ -4585,11 +4687,11 @@ class ONNXGraphImporter:
                 if list(getattr(a, f)):
                     assert a.name not in attrs, "Only one type of attr is 
allowed"
                     attrs[a.name] = tuple(getattr(a, f))
-            for f in ["t"]:
-                if a.HasField(f):
+            for f in ["t", "tp"]:
+                if hasattr(a, f) and a.HasField(f):
                     attrs[a.name] = getattr(a, f)
-            for f in ["tensors"]:
-                if list(getattr(a, f)):
+            for f in ["tensors", "type_protos"]:
+                if hasattr(a, f) and list(getattr(a, f)):
                     assert a.name not in attrs, "Only one type of attr is 
allowed"
                     attrs[a.name] = tuple(getattr(a, f))
             for f in ["graphs"]:
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index d04b0c2f33..c848ef91d6 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -197,6 +197,65 @@ def check_correctness(
             _check_output(tvm_out, ort_out)
 
 
+def run_in_tvm(
+    model: ModelProto,
+    inputs: dict[str, np.ndarray] | None = None,
+    ir_version: int = 8,
+    opset: int = 14,
+):
+    if ir_version is not None:
+        model.ir_version = ir_version
+    if opset is not None:
+        for opset_import in model.opset_import:
+            if opset_import.domain in ["", "ai.onnx"]:
+                opset_import.version = opset
+                break
+
+    inputs = generate_random_inputs(model, inputs)
+    tvm_model = from_onnx(model, opset=opset, keep_params_in_input=True)
+    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
+    tvm_model = relax.transform.LegalizeOps()(tvm_model)
+    tvm_model, params = relax.frontend.detach_params(tvm_model)
+
+    with tvm.transform.PassContext(opt_level=3):
+        ex = tvm.compile(tvm_model, target="llvm")
+        vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    input_list = [
+        inputs[key.name_hint] for key in tvm_model["main"].params if 
key.name_hint in inputs
+    ]
+    if params:
+        input_list += params["main"]
+
+    vm.set_input("main", *input_list)
+    vm.invoke_stateful("main")
+    return vm.get_outputs("main")
+
+
+def collect_relax_call_ops(func: relax.Function) -> list[str]:
+    op_names = []
+
+    def fvisit(expr):
+        if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op):
+            op_names.append(expr.op.name)
+
+    relax.analysis.post_order_visit(func.body, fvisit)
+    return op_names
+
+
+def collect_scalar_constants(func: relax.Function) -> list[bool | int | float]:
+    values = []
+
+    def fvisit(expr):
+        if isinstance(expr, relax.Constant):
+            value = expr.data.numpy()
+            if value.shape == ():
+                values.append(value.item())
+
+    relax.analysis.post_order_visit(func.body, fvisit)
+    return values
+
+
 @pytest.mark.parametrize(
     "input_names, expected_names",
     [
@@ -374,6 +433,101 @@ def test_matmul(dynamic):
     check_correctness(model, inputs)
 
 
[email protected](
+    ("a_dtype", "b_dtype", "a_shape", "b_shape"),
+    [
+        (np.int16, np.int16, [2, 3], [3, 4]),
+        (np.uint16, np.uint16, [2, 3], [3, 4]),
+        (np.int16, np.uint16, [2, 1, 3, 5], [1, 2, 5, 4]),
+    ],
+)
+def test_matmulinteger16(a_dtype, b_dtype, a_shape, b_shape):
+    a = np.arange(np.prod(a_shape), dtype=np.int64).reshape(a_shape)
+    b = np.arange(np.prod(b_shape), dtype=np.int64).reshape(b_shape)
+    if np.issubdtype(a_dtype, np.signedinteger):
+        a -= a.size // 2
+    if np.issubdtype(b_dtype, np.signedinteger):
+        b -= b.size // 2
+    a = a.astype(a_dtype)
+    b = b.astype(b_dtype)
+
+    out_dtype = np.uint32 if a_dtype == np.uint16 and b_dtype == np.uint16 
else np.int32
+    expected = np.matmul(a.astype(out_dtype), b.astype(out_dtype))
+
+    node = helper.make_node("MatMulInteger16", ["a", "b"], ["y"], 
domain="com.microsoft")
+    graph = helper.make_graph(
+        [node],
+        "matmulinteger16_test",
+        inputs=[
+            helper.make_tensor_value_info("a", 
helper.np_dtype_to_tensor_dtype(a.dtype), a_shape),
+            helper.make_tensor_value_info("b", 
helper.np_dtype_to_tensor_dtype(b.dtype), b_shape),
+        ],
+        outputs=[
+            helper.make_tensor_value_info(
+                "y", helper.np_dtype_to_tensor_dtype(np.dtype(out_dtype)), 
expected.shape
+            )
+        ],
+    )
+    model = helper.make_model(
+        graph,
+        producer_name="matmulinteger16_test",
+        opset_imports=[helper.make_opsetid("", 18), 
helper.make_opsetid("com.microsoft", 1)],
+    )
+    model.ir_version = 11
+
+    tvm_output = run_in_tvm(model, inputs={"a": a, "b": b}, ir_version=11, 
opset=18)
+    assert isinstance(tvm_output, tvm.runtime.Tensor)
+    assert tvm_output.numpy().dtype == out_dtype
+    tvm.testing.assert_allclose(tvm_output.numpy(), expected)
+
+
+def test_matmulinteger16_ir():
+    node = helper.make_node("MatMulInteger16", ["a", "b"], ["y"], 
domain="com.microsoft")
+    graph = helper.make_graph(
+        [node],
+        "matmulinteger16_ir_test",
+        inputs=[
+            helper.make_tensor_value_info("a", TensorProto.UINT16, [2, 3]),
+            helper.make_tensor_value_info("b", TensorProto.UINT16, [3, 4]),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.UINT32, [2, 
4])],
+    )
+    model = helper.make_model(
+        graph,
+        producer_name="matmulinteger16_ir_test",
+        opset_imports=[helper.make_opsetid("", 18), 
helper.make_opsetid("com.microsoft", 1)],
+    )
+    model.ir_version = 11
+
+    tvm_model = from_onnx(model, opset=18, keep_params_in_input=True)
+    call_ops = collect_relax_call_ops(tvm_model["main"])
+    assert call_ops.count("relax.astype") == 2
+    assert "relax.matmul" in call_ops
+    assert tvm_model["main"].ret_struct_info.dtype == "uint32"
+
+
+def test_matmulinteger16_invalid_dtype_raises():
+    node = helper.make_node("MatMulInteger16", ["a", "b"], ["y"], 
domain="com.microsoft")
+    graph = helper.make_graph(
+        [node],
+        "matmulinteger16_invalid_dtype_test",
+        inputs=[
+            helper.make_tensor_value_info("a", TensorProto.INT8, [2, 3]),
+            helper.make_tensor_value_info("b", TensorProto.UINT16, [3, 4]),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.INT32, [2, 
4])],
+    )
+    model = helper.make_model(
+        graph,
+        producer_name="matmulinteger16_invalid_dtype_test",
+        opset_imports=[helper.make_opsetid("", 18), 
helper.make_opsetid("com.microsoft", 1)],
+    )
+    model.ir_version = 11
+
+    with pytest.raises(ValueError, match="input A"):
+        from_onnx(model, opset=18, keep_params_in_input=True)
+
+
 def test_concat():
     verify_binary("Concat", [1, 32], [1, 32], [2, 32], attrs={"axis": 0})
 
@@ -3176,6 +3330,21 @@ def make_constant_node(name: str, data_type: int, dims: 
list[int], vals: list[in
     )
 
 
+def make_optional_tensor_value_info(name: str, elem_type: int, shape: 
list[int]):
+    return helper.make_value_info(
+        name, 
helper.make_optional_type_proto(helper.make_tensor_type_proto(elem_type, shape))
+    )
+
+
+def make_optional_sequence_value_info(name: str, elem_type: int, shape: 
list[int]):
+    return helper.make_value_info(
+        name,
+        helper.make_optional_type_proto(
+            
helper.make_sequence_type_proto(helper.make_tensor_type_proto(elem_type, shape))
+        ),
+    )
+
+
 def test_sequence_construct():
     node, graph_inputs = construct_sequence(input_shape=[32, 32], 
num_tensors=2)
     graph = helper.make_graph(
@@ -3287,6 +3456,180 @@ def test_sequence_at():
     check_correctness(model)
 
 
+def test_optional_get_element_tensor():
+    x_shape = [2, 3]
+    optional_node = helper.make_node("Optional", ["x"], ["optional"])
+    get_element_node = helper.make_node("OptionalGetElement", ["optional"], 
["output"])
+    graph = helper.make_graph(
+        [optional_node, get_element_node],
+        "test_optional_get_element_tensor",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, 
x_shape)],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, 
x_shape)],
+        value_info=[make_optional_tensor_value_info("optional", 
TensorProto.FLOAT, x_shape)],
+    )
+    model = helper.make_model(graph, 
producer_name="test_optional_get_element_tensor")
+    check_correctness(model, opset=18, ir_version=11)
+
+
+def test_optional_has_element_tensor():
+    x_shape = [2, 3]
+    optional_node = helper.make_node("Optional", ["x"], ["optional"])
+    has_element_node = helper.make_node("OptionalHasElement", ["optional"], 
["output"])
+    graph = helper.make_graph(
+        [optional_node, has_element_node],
+        "test_optional_has_element_tensor",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, 
x_shape)],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, 
[])],
+        value_info=[make_optional_tensor_value_info("optional", 
TensorProto.FLOAT, x_shape)],
+    )
+    model = helper.make_model(graph, 
producer_name="test_optional_has_element_tensor")
+    check_correctness(model, opset=18, ir_version=11)
+
+
+def test_optional_has_element_empty():
+    x_shape = [2, 3]
+    tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, x_shape)
+    optional_type = helper.make_optional_type_proto(tensor_type)
+    optional_node = helper.make_node("Optional", [], ["optional"], 
type=tensor_type)
+    has_element_node = helper.make_node("OptionalHasElement", ["optional"], 
["output"])
+    graph = helper.make_graph(
+        [optional_node, has_element_node],
+        "test_optional_has_element_empty",
+        inputs=[],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, 
[])],
+        value_info=[helper.make_value_info("optional", optional_type)],
+    )
+    model = helper.make_model(graph, 
producer_name="test_optional_has_element_empty")
+    check_correctness(model, opset=18, ir_version=11)
+
+
+def test_optional_has_element_empty_ir():
+    x_shape = [2, 3]
+    tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, x_shape)
+    optional_type = helper.make_optional_type_proto(tensor_type)
+    optional_node = helper.make_node("Optional", [], ["optional"], 
type=tensor_type)
+    has_element_node = helper.make_node("OptionalHasElement", ["optional"], 
["output"])
+    graph = helper.make_graph(
+        [optional_node, has_element_node],
+        "test_optional_has_element_empty_ir",
+        inputs=[],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, 
[])],
+        value_info=[helper.make_value_info("optional", optional_type)],
+    )
+    model = helper.make_model(graph, 
producer_name="test_optional_has_element_empty_ir")
+    model.ir_version = 11
+    model.opset_import[0].version = 18
+    tvm_model = from_onnx(model, opset=18, keep_params_in_input=True)
+
+    assert collect_relax_call_ops(tvm_model["main"]) == []
+    assert False in collect_scalar_constants(tvm_model["main"])
+
+
+def test_optional_get_element_tensor_ir():
+    x_shape = [2, 3]
+    optional_node = helper.make_node("Optional", ["x"], ["optional"])
+    get_element_node = helper.make_node("OptionalGetElement", ["optional"], 
["output"])
+    graph = helper.make_graph(
+        [optional_node, get_element_node],
+        "test_optional_get_element_tensor_ir",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, 
x_shape)],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, 
x_shape)],
+        value_info=[make_optional_tensor_value_info("optional", 
TensorProto.FLOAT, x_shape)],
+    )
+    model = helper.make_model(graph, 
producer_name="test_optional_get_element_tensor_ir")
+    model.ir_version = 11
+    model.opset_import[0].version = 18
+    tvm_model = from_onnx(model, opset=18, keep_params_in_input=True)
+
+    assert collect_relax_call_ops(tvm_model["main"]) == []
+    assert tvm_model["main"].ret_struct_info.dtype == "float32"
+
+
+def test_optional_get_element_sequence():
+    seq_node, graph_inputs = construct_sequence(input_shape=[32, 32], 
num_tensors=4)
+    index = make_constant_node("index", TensorProto.INT64, (), [1])
+    optional_node = helper.make_node("Optional", ["sequence"], ["optional"])
+    get_element_node = helper.make_node("OptionalGetElement", ["optional"], 
["unwrapped"])
+    sequence_at_node = helper.make_node("SequenceAt", ["unwrapped", "index"], 
["output"])
+    graph = helper.make_graph(
+        [index, seq_node, optional_node, get_element_node, sequence_at_node],
+        "test_optional_get_element_sequence",
+        inputs=graph_inputs,
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, 
[32, 32])],
+        value_info=[make_optional_sequence_value_info("optional", 
TensorProto.FLOAT, [32, 32])],
+    )
+    model = helper.make_model(graph, 
producer_name="test_optional_get_element_sequence")
+    check_correctness(model, opset=18, ir_version=11)
+
+
+def test_optional_without_input_requires_type_attr():
+    tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [2, 3])
+    optional_type = helper.make_optional_type_proto(tensor_type)
+    optional_node = helper.make_node("Optional", [], ["optional"])
+    graph = helper.make_graph(
+        [optional_node],
+        "test_optional_without_input_requires_type_attr",
+        inputs=[],
+        outputs=[helper.make_value_info("optional", optional_type)],
+    )
+    model = helper.make_model(graph, 
producer_name="test_optional_without_input_requires_type_attr")
+    model.opset_import[0].version = 18
+
+    with pytest.raises(ValueError, match="type attribute"):
+        from_onnx(model, opset=18, keep_params_in_input=True)
+
+
+def test_empty_optional_graph_output_raises():
+    tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, [2, 3])
+    optional_type = helper.make_optional_type_proto(tensor_type)
+    optional_node = helper.make_node("Optional", [], ["optional"], 
type=tensor_type)
+    graph = helper.make_graph(
+        [optional_node],
+        "test_empty_optional_graph_output_raises",
+        inputs=[],
+        outputs=[helper.make_value_info("optional", optional_type)],
+    )
+    model = helper.make_model(graph, 
producer_name="test_empty_optional_graph_output_raises")
+    model.opset_import[0].version = 18
+
+    with pytest.raises(ValueError, match="Empty optional graph outputs are not 
supported"):
+        from_onnx(model, opset=18, keep_params_in_input=True)
+
+
+def test_optional_has_element_requires_one_input():
+    has_element_node = helper.make_node("OptionalHasElement", [], ["output"])
+    graph = helper.make_graph(
+        [has_element_node],
+        "test_optional_has_element_requires_one_input",
+        inputs=[],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.BOOL, 
[])],
+    )
+    model = helper.make_model(graph, 
producer_name="test_optional_has_element_requires_one_input")
+    model.opset_import[0].version = 18
+
+    with pytest.raises(ValueError, match="expects one input"):
+        from_onnx(model, opset=18, keep_params_in_input=True)
+
+
+def test_optional_get_element_empty_raises():
+    x_shape = [2, 3]
+    tensor_type = helper.make_tensor_type_proto(TensorProto.FLOAT, x_shape)
+    optional_type = helper.make_optional_type_proto(tensor_type)
+    optional_node = helper.make_node("Optional", [], ["optional"], 
type=tensor_type)
+    get_element_node = helper.make_node("OptionalGetElement", ["optional"], 
["output"])
+    graph = helper.make_graph(
+        [optional_node, get_element_node],
+        "test_optional_get_element_empty_raises",
+        inputs=[],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, 
x_shape)],
+        value_info=[helper.make_value_info("optional", optional_type)],
+    )
+    model = helper.make_model(graph, 
producer_name="test_optional_get_element_empty_raises")
+    model.opset_import[0].version = 18
+    with pytest.raises(ValueError, match="empty optional"):
+        from_onnx(model, opset=18, keep_params_in_input=True)
+
+
 def test_symbolic_shape_deduction():
     index_node = helper.make_node(
         "Constant",

Reply via email to