This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 52b5d55fc4 [Frontend][ONNX] Add If operator support to Relax ONNX 
frontend (#18946)
52b5d55fc4 is described below

commit 52b5d55fc4e629bed002b6e4e0383088a52ad17b
Author: Kryptonite <[email protected]>
AuthorDate: Sat Mar 28 20:25:47 2026 +0300

    [Frontend][ONNX] Add If operator support to Relax ONNX frontend (#18946)
    
    ### Summary
    
    This PR implements the ONNX `If` operator in the Relax ONNX frontend.
    The `If` operator enables conditional branching in ONNX models, where a
    boolean condition selects between two subgraph branches (`then_branch`
    and `else_branch`) at runtime. This is required for any model with
    runtime-dependent execution paths.
    
    Closes #18945 (Tier 1 — `If` operator)
    
    ### Implementation Notes
    
    - The main challenge is that `relax.If` cannot be emitted inside a
    dataflow block, which is how the ONNX frontend normally builds the
    entire graph. To handle this, when the graph contains an `If` node, the
    function body is built as a regular binding block instead — matching the
    approach used by the PyTorch Relax frontend for `torch.cond`.
    
    - Each branch is an ONNX subgraph that can reference values from the
    outer graph. A new `_convert_subgraph` method handles converting these
    subgraphs into Relax expressions, making outer-scope values available to
    the branch while ensuring branch-local bindings don't leak back to the
    parent graph.
    
    ### Why `relax.If` cannot live inside a dataflow block
    
    Dataflow blocks in Relax carry a semantic guarantee: every operation
    inside them must be pure and side-effect-free with no control flow. This
    allows the compiler to treat the entire block as a static computational
    graph for optimizations like operator fusion and constant folding. An
    `If` node breaks this guarantee by introducing runtime-dependent
    branching, so Relax's well-formedness checker explicitly forbids it. I
    discovered this when the checker raised:
    ```
    This IR is not well-formed: If nodes are not allowed to appear in dataflow blocks.
    ```
    
    The fix — skipping the dataflow block when the graph contains an `If`
    node — mirrors exactly how the PyTorch Relax frontend handles
    `torch.cond`.
    
    ### Known Limitations
    
    **Dataflow block**: Models whose top-level graph contains an `If` node
    are built without a dataflow block, which may affect downstream
    optimization passes that rely on dataflow block structure.
    
    ### Tests
    
    Four new tests covering: scalar and tensor conditions, condition
    computed from another op, multiple branch outputs, and a nested `If`
    inside a branch. All verified against onnxruntime via
    `check_correctness`.
    
    ---------
    
    Signed-off-by: OmarAzizi <[email protected]>
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 128 ++++++++++++--
 tests/python/relax/test_frontend_onnx.py        | 213 ++++++++++++++++++++++++
 2 files changed, 325 insertions(+), 16 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index a117317125..e56f975c62 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# ruff: noqa: E501, E731, E741, F841, RUF005
+# ruff: noqa: E501, E731, E741, RUF005
 """ONNX: Open Neural Network Exchange importer for Relax.
 
 This module implements the required functionality to read ONNX models
@@ -36,6 +36,7 @@ Not all TVM kernels currently support dynamic shapes, please 
file an issue on
 github.com/apache/tvm/issues if you hit an error with dynamic kernels.
 """
 
+import contextlib
 import functools
 import math
 import operator
@@ -2421,9 +2422,7 @@ class AffineGrid(OnnxOpConverter):
         align_corners = attr.get("align_corners", 0)
 
         if align_corners != 1:
-            raise NotImplementedError(
-                "AffineGrid with align_corners=0 is not yet supported in TVM"
-            )
+            raise NotImplementedError("AffineGrid with align_corners=0 is not 
yet supported in TVM")
 
         # Extract size values
         if isinstance(size, relax.Constant):
@@ -4178,7 +4177,6 @@ def _get_convert_map():
         "OneHot": OneHot,
         "Unique": Unique,
         "NonZero": NonZero,
-        # "If": If,
         # "MaxRoiPool": MaxRoiPool,
         "RoiAlign": RoiAlign,
         "NonMaxSuppression": NonMaxSuppression,
@@ -4254,25 +4252,30 @@ class ONNXGraphImporter:
         mod : tvm.IRModule
             The returned relax module
         """
+        has_if = any(node.op_type == "If" for node in graph.node)
+
         with self.bb.function("main"):
-            with self.bb.dataflow() as df:  # pylint: disable=invalid-name, 
unused-variable
+            with contextlib.ExitStack() as stack:
+                if not has_if:
+                    stack.enter_context(self.bb.dataflow())
+
                 self.opset = opset
                 self._parse_graph_initializers(graph)
                 self._parse_graph_input(graph)
                 self._check_for_unsupported_ops(graph)
                 self._construct_nodes(graph)
 
-                # now return the outputs
                 outputs = [self._nodes[self._parse_value_proto(i)] for i in 
graph.output]
                 outputs = outputs[0] if len(outputs) == 1 else 
relax.Tuple(outputs)
 
-                output_var = self.bb.emit_output(outputs)
+                if has_if:
+                    output_var = outputs
+                else:
+                    output_var = self.bb.emit_output(outputs)
 
-            # Create function attributes for this module
+            # ExitStack closes here — dataflow block is now closed
             func_attrs = {"num_input": self._num_input}
-            # Create a function from our output expression and all input 
variables.
             input_list = [value for value in self._inputs.values() if 
isinstance(value, relax.Var)]
-            # Attach params if they are available.
             if self._keep_params_in_input and self._params:
                 param_var_list, param_value_list = map(list, 
zip(*self._params.values()))
                 input_list = input_list + param_var_list
@@ -4281,7 +4284,6 @@ class ONNXGraphImporter:
             self.bb.emit_func_output(output_var, params=input_list)
 
         relax_mod = self.bb.get()
-        # Attach attributes.
         relax_mod["main"] = relax_mod["main"].with_attrs(func_attrs)
         return relax_mod
 
@@ -4369,12 +4371,15 @@ class ONNXGraphImporter:
 
     def _check_for_unsupported_ops(self, graph: onnx.onnx_ml_pb2.GraphProto):
         convert_map = _get_convert_map()
+        # Ops handled directly in _construct_nodes rather than via the 
converter map.
+        directly_handled_ops = {"If"}
         unsupported_ops = set()
         for node in graph.node:
             op_name = node.op_type
             if (
-                op_name not in convert_map and op_name != "Constant"
-                # and op_name not in _identity_list
+                op_name not in convert_map
+                and op_name not in directly_handled_ops
+                and op_name != "Constant"
             ):
                 unsupported_ops.add(op_name)
         if unsupported_ops:
@@ -4400,6 +4405,20 @@ class ONNXGraphImporter:
             attr["tvm_custom"]["name"] = i_name
             attr["tvm_custom"]["num_outputs"] = len(outputs)
 
+            if op_name == "If":
+                cond = inputs[0]
+                then_expr = self._convert_subgraph(self.bb, 
attr["then_branch"])
+                else_expr = self._convert_subgraph(self.bb, 
attr["else_branch"])
+                then_seq = relax.SeqExpr(blocks=[], body=then_expr)
+                else_seq = relax.SeqExpr(blocks=[], body=else_expr)
+                if_result = self.bb.emit(relax.If(cond, then_seq, else_seq))
+                if len(outputs) == 1:
+                    self._nodes[outputs[0]] = if_result
+                else:
+                    for i, k in enumerate(outputs):
+                        self._nodes[k] = 
self.bb.emit(relax.TupleGetItem(if_result, i))
+                continue
+
             # Perform special handling for shape expressions. If an input is a
             # shape expr, make sure the current op can handle it, otherwise
             # convert it to a tensor.
@@ -4462,7 +4481,7 @@ class ONNXGraphImporter:
             if outputs_num == 1:
                 self._nodes[outputs[0]] = op
             else:
-                for k, i in zip(list(outputs), range(len(outputs))):
+                for i, k in enumerate(outputs):
                     self._nodes[k] = op[i]
 
     def _parse_value_proto(self, value_proto: onnx.onnx_ml_pb2.GraphProto):
@@ -4497,7 +4516,8 @@ class ONNXGraphImporter:
                     attrs[a.name] = tuple(getattr(a, f))
             for f in ["graphs"]:
                 if list(getattr(a, f)):
-                    raise NotImplementedError(f"Field {f} is not supported in 
relax.")
+                    assert a.name not in attrs, "Only one type of attr is 
allowed"
+                    attrs[a.name] = tuple(getattr(a, f))
             if a.name not in attrs:
                 raise ValueError(f"Cannot parse attribute: \n{a}\n.")
         return attrs
@@ -4537,6 +4557,82 @@ class ONNXGraphImporter:
             raise NotImplementedError(f"Operator {op_name} not implemented.")
         return sym
 
+    def _convert_subgraph(self, bb, graph):
+        """
+        Walk an ONNX GraphProto (a branch body) and return a Relax SeqExpr.
+        Outer-scope nodes are visible because we copy self._nodes into the
+        local lookup table before processing.
+        """
+        outer_nodes = dict(self._nodes)
+
+        try:
+            for init_tensor in graph.initializer:
+                array = self._parse_array(init_tensor)
+                self._nodes[init_tensor.name] = relax.const(array)
+
+            for node in graph.node:
+                op_name = node.op_type
+                attr = self._parse_attr(node.attribute)
+
+                inputs = onnx_input()
+                for i in node.input:
+                    if i != "":
+                        inputs.append(self._nodes.get(i, outer_nodes.get(i)))
+                    else:
+                        inputs.append(None)
+
+                attr["tvm_custom"] = {}
+                attr["tvm_custom"]["name"] = node.name
+                attr["tvm_custom"]["num_outputs"] = len(node.output)
+
+                # Handle nested If recursively.
+                if op_name == "If":
+                    cond = inputs[0]
+                    then_expr = self._convert_subgraph(bb, attr["then_branch"])
+                    else_expr = self._convert_subgraph(bb, attr["else_branch"])
+                    then_seq = relax.SeqExpr(blocks=[], body=then_expr)
+                    else_seq = relax.SeqExpr(blocks=[], body=else_expr)
+                    op = bb.emit(relax.If(cond, then_seq, else_seq))
+                    outputs = node.output
+                    if len(outputs) == 1:
+                        self._nodes[outputs[0]] = op
+                    else:
+                        for i, k in enumerate(outputs):
+                            self._nodes[k] = bb.emit(relax.TupleGetItem(op, i))
+                    continue
+
+                op = self._convert_operator(op_name, inputs, attr, self.opset)
+                try:
+                    _ = op.struct_info
+                    has_struct_info = True
+                except tvm.error.InternalError:
+                    has_struct_info = False
+
+                if not has_struct_info:
+                    op = bb.normalize(op)
+
+                if not isinstance(op, relax.Tuple):
+                    if isinstance(op.struct_info, relax.TupleStructInfo):
+                        tuple_items = [
+                            relax.TupleGetItem(op, i) for i in 
range(len(op.struct_info.fields))
+                        ]
+                        op = relax.Tuple(tuple_items)
+
+                outputs = node.output
+                if len(outputs) == 1:
+                    self._nodes[outputs[0]] = op
+                else:
+                    for i, k in enumerate(outputs):
+                        self._nodes[k] = op[i]
+
+            branch_outputs = [self._nodes[o.name] for o in graph.output]
+            result = branch_outputs[0] if len(branch_outputs) == 1 else 
relax.Tuple(branch_outputs)
+
+            self._nodes = outer_nodes
+            return result
+        finally:
+            self._nodes = outer_nodes
+
 
 def from_onnx(
     model: onnx.onnx_ml_pb2.GraphProto,
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index 887533f261..e2067bad23 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4203,5 +4203,218 @@ def test_roi_align(coordinate_transformation_mode, 
rois):
     check_correctness(model, inputs=inputs, opset=16, rtol=1e-5, atol=1e-5)
 
 
[email protected](
+    "cond_info, cond_true, cond_false",
+    [
+        (
+            helper.make_tensor_value_info("cond", TensorProto.BOOL, []),
+            np.array(True),
+            np.array(False),
+        ),
+        (
+            helper.make_tensor_value_info("cond", TensorProto.BOOL, [1]),
+            np.array([True]),
+            np.array([False]),
+        ),
+    ],
+    ids=["scalar_condition", "tensor_condition"],
+)
+def test_if(cond_info, cond_true, cond_false):
+    """Test ONNX If operator with scalar and tensor bool conditions."""
+
+    x_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3])
+    result_info = helper.make_tensor_value_info("result", TensorProto.FLOAT, 
[3])
+
+    # then branch: x * 2.0
+    two = helper.make_tensor("two", TensorProto.FLOAT, [1], [2.0])
+    then_mul = helper.make_node("Mul", ["x", "two"], ["then_out"])
+    then_out_info = helper.make_tensor_value_info("then_out", 
TensorProto.FLOAT, [3])
+    then_graph = helper.make_graph([then_mul], "then_graph", [], 
[then_out_info], initializer=[two])
+
+    # else branch: x * 3.0
+    three = helper.make_tensor("three", TensorProto.FLOAT, [1], [3.0])
+    else_mul = helper.make_node("Mul", ["x", "three"], ["else_out"])
+    else_out_info = helper.make_tensor_value_info("else_out", 
TensorProto.FLOAT, [3])
+    else_graph = helper.make_graph(
+        [else_mul], "else_graph", [], [else_out_info], initializer=[three]
+    )
+
+    if_node = helper.make_node(
+        "If",
+        inputs=["cond"],
+        outputs=["result"],
+        then_branch=then_graph,
+        else_branch=else_graph,
+    )
+    main_graph = helper.make_graph([if_node], "if_test", [cond_info, x_info], 
[result_info])
+    model = helper.make_model(main_graph, 
opset_imports=[helper.make_opsetid("", 13)])
+
+    x_data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+
+    check_correctness(model, inputs={"cond": cond_true, "x": x_data})
+    check_correctness(model, inputs={"cond": cond_false, "x": x_data})
+
+
+def test_if_computed_condition():
+    """Test If where condition is computed from another op in the main 
graph."""
+    import numpy as np
+    from onnx import TensorProto, helper
+
+    x_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3])
+    result_info = helper.make_tensor_value_info("result", TensorProto.FLOAT, 
[3])
+
+    zero = helper.make_tensor("zero", TensorProto.FLOAT, [], [0.0])
+    reduce_node = helper.make_node(
+        "ReduceSum", ["x"], ["x_sum"], keepdims=0, noop_with_empty_axes=0
+    )
+    greater_node = helper.make_node("Greater", ["x_sum", "zero"], ["cond"])
+
+    two = helper.make_tensor("two", TensorProto.FLOAT, [1], [2.0])
+    then_mul = helper.make_node("Mul", ["x", "two"], ["then_out"])
+    then_out_info = helper.make_tensor_value_info("then_out", 
TensorProto.FLOAT, [3])
+    then_graph = helper.make_graph([then_mul], "then_graph", [], 
[then_out_info], initializer=[two])
+
+    three = helper.make_tensor("three", TensorProto.FLOAT, [1], [3.0])
+    else_mul = helper.make_node("Mul", ["x", "three"], ["else_out"])
+    else_out_info = helper.make_tensor_value_info("else_out", 
TensorProto.FLOAT, [3])
+    else_graph = helper.make_graph(
+        [else_mul], "else_graph", [], [else_out_info], initializer=[three]
+    )
+
+    if_node = helper.make_node(
+        "If", inputs=["cond"], outputs=["result"], then_branch=then_graph, 
else_branch=else_graph
+    )
+
+    main_graph = helper.make_graph(
+        [reduce_node, greater_node, if_node],
+        "if_computed_cond",
+        [x_info],
+        [result_info],
+        initializer=[zero],
+    )
+    model = helper.make_model(main_graph, 
opset_imports=[helper.make_opsetid("", 13)])
+
+    check_correctness(model, inputs={"x": np.array([1.0, 2.0, 3.0], 
dtype=np.float32)})
+    check_correctness(model, inputs={"x": np.array([-1.0, -2.0, -3.0], 
dtype=np.float32)})
+
+
+def test_if_multiple_outputs():
+    """Test If operator where branches return multiple outputs."""
+    import numpy as np
+    from onnx import TensorProto, helper
+
+    cond_info = helper.make_tensor_value_info("cond", TensorProto.BOOL, [])
+    x_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3])
+    out1_info = helper.make_tensor_value_info("out1", TensorProto.FLOAT, [3])
+    out2_info = helper.make_tensor_value_info("out2", TensorProto.FLOAT, [3])
+
+    two = helper.make_tensor("two", TensorProto.FLOAT, [1], [2.0])
+    three = helper.make_tensor("three", TensorProto.FLOAT, [1], [3.0])
+
+    then_mul1 = helper.make_node("Mul", ["x", "two"], ["then_out1"])
+    then_mul2 = helper.make_node("Mul", ["x", "three"], ["then_out2"])
+    then_o1 = helper.make_tensor_value_info("then_out1", TensorProto.FLOAT, 
[3])
+    then_o2 = helper.make_tensor_value_info("then_out2", TensorProto.FLOAT, 
[3])
+    then_graph = helper.make_graph(
+        [then_mul1, then_mul2], "then_graph", [], [then_o1, then_o2], 
initializer=[two, three]
+    )
+
+    four = helper.make_tensor("four", TensorProto.FLOAT, [1], [4.0])
+    five = helper.make_tensor("five", TensorProto.FLOAT, [1], [5.0])
+    else_mul1 = helper.make_node("Mul", ["x", "four"], ["else_out1"])
+    else_mul2 = helper.make_node("Mul", ["x", "five"], ["else_out2"])
+    else_o1 = helper.make_tensor_value_info("else_out1", TensorProto.FLOAT, 
[3])
+    else_o2 = helper.make_tensor_value_info("else_out2", TensorProto.FLOAT, 
[3])
+    else_graph = helper.make_graph(
+        [else_mul1, else_mul2], "else_graph", [], [else_o1, else_o2], 
initializer=[four, five]
+    )
+
+    if_node = helper.make_node(
+        "If",
+        inputs=["cond"],
+        outputs=["out1", "out2"],
+        then_branch=then_graph,
+        else_branch=else_graph,
+    )
+    main_graph = helper.make_graph(
+        [if_node], "if_multi_out", [cond_info, x_info], [out1_info, out2_info]
+    )
+    model = helper.make_model(main_graph, 
opset_imports=[helper.make_opsetid("", 13)])
+
+    x_data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+    check_correctness(model, inputs={"cond": np.array(True), "x": x_data})
+    check_correctness(model, inputs={"cond": np.array(False), "x": x_data})
+
+
+def test_if_nested():
+    """Test nested If operator inside a branch."""
+    import numpy as np
+    from onnx import TensorProto, helper
+
+    cond1_info = helper.make_tensor_value_info("cond1", TensorProto.BOOL, [])
+    cond2_info = helper.make_tensor_value_info("cond2", TensorProto.BOOL, [])
+    x_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3])
+    result_info = helper.make_tensor_value_info("result", TensorProto.FLOAT, 
[3])
+
+    # Inner then: x * 2
+    two = helper.make_tensor("two", TensorProto.FLOAT, [1], [2.0])
+    inner_then_mul = helper.make_node("Mul", ["x", "two"], ["inner_then_out"])
+    inner_then_out_info = helper.make_tensor_value_info("inner_then_out", 
TensorProto.FLOAT, [3])
+    inner_then_graph = helper.make_graph(
+        [inner_then_mul], "inner_then", [], [inner_then_out_info], 
initializer=[two]
+    )
+
+    # Inner else: x * 3
+    three = helper.make_tensor("three", TensorProto.FLOAT, [1], [3.0])
+    inner_else_mul = helper.make_node("Mul", ["x", "three"], 
["inner_else_out"])
+    inner_else_out_info = helper.make_tensor_value_info("inner_else_out", 
TensorProto.FLOAT, [3])
+    inner_else_graph = helper.make_graph(
+        [inner_else_mul], "inner_else", [], [inner_else_out_info], 
initializer=[three]
+    )
+
+    # Outer then: nested If(cond2, x*2, x*3)
+    inner_if = helper.make_node(
+        "If",
+        inputs=["cond2"],
+        outputs=["outer_then_out"],
+        then_branch=inner_then_graph,
+        else_branch=inner_else_graph,
+    )
+    outer_then_out_info = helper.make_tensor_value_info("outer_then_out", 
TensorProto.FLOAT, [3])
+    outer_then_graph = helper.make_graph([inner_if], "outer_then", [], 
[outer_then_out_info])
+
+    # Outer else: x * 4
+    four = helper.make_tensor("four", TensorProto.FLOAT, [1], [4.0])
+    outer_else_mul = helper.make_node("Mul", ["x", "four"], ["outer_else_out"])
+    outer_else_out_info = helper.make_tensor_value_info("outer_else_out", 
TensorProto.FLOAT, [3])
+    outer_else_graph = helper.make_graph(
+        [outer_else_mul], "outer_else", [], [outer_else_out_info], 
initializer=[four]
+    )
+
+    outer_if = helper.make_node(
+        "If",
+        inputs=["cond1"],
+        outputs=["result"],
+        then_branch=outer_then_graph,
+        else_branch=outer_else_graph,
+    )
+    main_graph = helper.make_graph(
+        [outer_if], "nested_if", [cond1_info, cond2_info, x_info], 
[result_info]
+    )
+    model = helper.make_model(main_graph, 
opset_imports=[helper.make_opsetid("", 13)])
+
+    x_data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+    # cond1=True, cond2=True → x * 2
+    check_correctness(model, inputs={"cond1": np.array(True), "cond2": 
np.array(True), "x": x_data})
+    # cond1=True, cond2=False → x * 3
+    check_correctness(
+        model, inputs={"cond1": np.array(True), "cond2": np.array(False), "x": 
x_data}
+    )
+    # cond1=False → x * 4
+    check_correctness(
+        model, inputs={"cond1": np.array(False), "cond2": np.array(True), "x": 
x_data}
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to