This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 52b5d55fc4 [Frontend][ONNX] Add If operator support to Relax ONNX
frontend (#18946)
52b5d55fc4 is described below
commit 52b5d55fc4e629bed002b6e4e0383088a52ad17b
Author: Kryptonite <[email protected]>
AuthorDate: Sat Mar 28 20:25:47 2026 +0300
[Frontend][ONNX] Add If operator support to Relax ONNX frontend (#18946)
### Summary
This PR implements the ONNX `If` operator in the Relax ONNX frontend.
The `If` operator enables conditional branching in ONNX models, where a
boolean condition selects between two subgraph branches (`then_branch`
and `else_branch`) at runtime. This is required for any model with
runtime-dependent execution paths.
Closes #18945 (Tier 1 — `If` operator)
### Implementation Notes
- The main challenge is that `relax.If` cannot be emitted inside a
dataflow block, which is how the ONNX frontend normally builds the
entire graph. To handle this, when the graph contains an `If` node, the
function body is built as a regular binding block instead — matching the
approach used by the PyTorch Relax frontend for `torch.cond`.
- Each branch is an ONNX subgraph that can reference values from the
outer graph. A new `_convert_subgraph` method handles converting these
subgraphs into Relax expressions, making outer-scope values available to
the branch while ensuring branch-local bindings don't leak back to the
parent graph.
### Why `relax.If` cannot live inside a dataflow block
Dataflow blocks in Relax carry a semantic guarantee: every operation
inside them must be pure and side-effect-free with no control flow. This
allows the compiler to treat the entire block as a static computational
graph for optimizations like operator fusion and constant folding. An
`If` node breaks this guarantee by introducing runtime-dependent
branching, so Relax's well-formedness checker explicitly forbids it. I
discovered this when the checker raised:
```
This IR is not well-formed: If nodes are not allowed to appear in dataflow
blocks.
```
The fix — skipping the dataflow block when the graph contains an `If`
node — mirrors exactly how the PyTorch Relax frontend handles
`torch.cond`.
### Known Limitations
**Dataflow block**: Models whose top-level graph contains an `If` node
are built without a dataflow block, which may affect downstream
optimization passes that rely on dataflow block structure.
### Tests
Four new tests covering: scalar and tensor conditions, a condition
computed from another op, multiple branch outputs, and a nested `If`
inside a branch. All verified against onnxruntime via `check_correctness`.
---------
Signed-off-by: OmarAzizi <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 128 ++++++++++++--
tests/python/relax/test_frontend_onnx.py | 213 ++++++++++++++++++++++++
2 files changed, 325 insertions(+), 16 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index a117317125..e56f975c62 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# ruff: noqa: E501, E731, E741, F841, RUF005
+# ruff: noqa: E501, E731, E741, RUF005
"""ONNX: Open Neural Network Exchange importer for Relax.
This module implements the required functionality to read ONNX models
@@ -36,6 +36,7 @@ Not all TVM kernels currently support dynamic shapes, please
file an issue on
github.com/apache/tvm/issues if you hit an error with dynamic kernels.
"""
+import contextlib
import functools
import math
import operator
@@ -2421,9 +2422,7 @@ class AffineGrid(OnnxOpConverter):
align_corners = attr.get("align_corners", 0)
if align_corners != 1:
- raise NotImplementedError(
- "AffineGrid with align_corners=0 is not yet supported in TVM"
- )
+ raise NotImplementedError("AffineGrid with align_corners=0 is not
yet supported in TVM")
# Extract size values
if isinstance(size, relax.Constant):
@@ -4178,7 +4177,6 @@ def _get_convert_map():
"OneHot": OneHot,
"Unique": Unique,
"NonZero": NonZero,
- # "If": If,
# "MaxRoiPool": MaxRoiPool,
"RoiAlign": RoiAlign,
"NonMaxSuppression": NonMaxSuppression,
@@ -4254,25 +4252,30 @@ class ONNXGraphImporter:
mod : tvm.IRModule
The returned relax module
"""
+ has_if = any(node.op_type == "If" for node in graph.node)
+
with self.bb.function("main"):
- with self.bb.dataflow() as df: # pylint: disable=invalid-name,
unused-variable
+ with contextlib.ExitStack() as stack:
+ if not has_if:
+ stack.enter_context(self.bb.dataflow())
+
self.opset = opset
self._parse_graph_initializers(graph)
self._parse_graph_input(graph)
self._check_for_unsupported_ops(graph)
self._construct_nodes(graph)
- # now return the outputs
outputs = [self._nodes[self._parse_value_proto(i)] for i in
graph.output]
outputs = outputs[0] if len(outputs) == 1 else
relax.Tuple(outputs)
- output_var = self.bb.emit_output(outputs)
+ if has_if:
+ output_var = outputs
+ else:
+ output_var = self.bb.emit_output(outputs)
- # Create function attributes for this module
+ # ExitStack closes here — dataflow block is now closed
func_attrs = {"num_input": self._num_input}
- # Create a function from our output expression and all input
variables.
input_list = [value for value in self._inputs.values() if
isinstance(value, relax.Var)]
- # Attach params if they are available.
if self._keep_params_in_input and self._params:
param_var_list, param_value_list = map(list,
zip(*self._params.values()))
input_list = input_list + param_var_list
@@ -4281,7 +4284,6 @@ class ONNXGraphImporter:
self.bb.emit_func_output(output_var, params=input_list)
relax_mod = self.bb.get()
- # Attach attributes.
relax_mod["main"] = relax_mod["main"].with_attrs(func_attrs)
return relax_mod
@@ -4369,12 +4371,15 @@ class ONNXGraphImporter:
def _check_for_unsupported_ops(self, graph: onnx.onnx_ml_pb2.GraphProto):
convert_map = _get_convert_map()
+ # Ops handled directly in _construct_nodes rather than via the
converter map.
+ directly_handled_ops = {"If"}
unsupported_ops = set()
for node in graph.node:
op_name = node.op_type
if (
- op_name not in convert_map and op_name != "Constant"
- # and op_name not in _identity_list
+ op_name not in convert_map
+ and op_name not in directly_handled_ops
+ and op_name != "Constant"
):
unsupported_ops.add(op_name)
if unsupported_ops:
@@ -4400,6 +4405,20 @@ class ONNXGraphImporter:
attr["tvm_custom"]["name"] = i_name
attr["tvm_custom"]["num_outputs"] = len(outputs)
+ if op_name == "If":
+ cond = inputs[0]
+ then_expr = self._convert_subgraph(self.bb,
attr["then_branch"])
+ else_expr = self._convert_subgraph(self.bb,
attr["else_branch"])
+ then_seq = relax.SeqExpr(blocks=[], body=then_expr)
+ else_seq = relax.SeqExpr(blocks=[], body=else_expr)
+ if_result = self.bb.emit(relax.If(cond, then_seq, else_seq))
+ if len(outputs) == 1:
+ self._nodes[outputs[0]] = if_result
+ else:
+ for i, k in enumerate(outputs):
+ self._nodes[k] =
self.bb.emit(relax.TupleGetItem(if_result, i))
+ continue
+
# Perform special handling for shape expressions. If an input is a
# shape expr, make sure the current op can handle it, otherwise
# convert it to a tensor.
@@ -4462,7 +4481,7 @@ class ONNXGraphImporter:
if outputs_num == 1:
self._nodes[outputs[0]] = op
else:
- for k, i in zip(list(outputs), range(len(outputs))):
+ for i, k in enumerate(outputs):
self._nodes[k] = op[i]
def _parse_value_proto(self, value_proto: onnx.onnx_ml_pb2.GraphProto):
@@ -4497,7 +4516,8 @@ class ONNXGraphImporter:
attrs[a.name] = tuple(getattr(a, f))
for f in ["graphs"]:
if list(getattr(a, f)):
- raise NotImplementedError(f"Field {f} is not supported in
relax.")
+ assert a.name not in attrs, "Only one type of attr is
allowed"
+ attrs[a.name] = tuple(getattr(a, f))
if a.name not in attrs:
raise ValueError(f"Cannot parse attribute: \n{a}\n.")
return attrs
@@ -4537,6 +4557,82 @@ class ONNXGraphImporter:
raise NotImplementedError(f"Operator {op_name} not implemented.")
return sym
+ def _convert_subgraph(self, bb, graph):
+ """
+ Walk an ONNX GraphProto (a branch body) and return a Relax SeqExpr.
+ Outer-scope nodes are visible because we copy self._nodes into the
+ local lookup table before processing.
+ """
+ outer_nodes = dict(self._nodes)
+
+ try:
+ for init_tensor in graph.initializer:
+ array = self._parse_array(init_tensor)
+ self._nodes[init_tensor.name] = relax.const(array)
+
+ for node in graph.node:
+ op_name = node.op_type
+ attr = self._parse_attr(node.attribute)
+
+ inputs = onnx_input()
+ for i in node.input:
+ if i != "":
+ inputs.append(self._nodes.get(i, outer_nodes.get(i)))
+ else:
+ inputs.append(None)
+
+ attr["tvm_custom"] = {}
+ attr["tvm_custom"]["name"] = node.name
+ attr["tvm_custom"]["num_outputs"] = len(node.output)
+
+ # Handle nested If recursively.
+ if op_name == "If":
+ cond = inputs[0]
+ then_expr = self._convert_subgraph(bb, attr["then_branch"])
+ else_expr = self._convert_subgraph(bb, attr["else_branch"])
+ then_seq = relax.SeqExpr(blocks=[], body=then_expr)
+ else_seq = relax.SeqExpr(blocks=[], body=else_expr)
+ op = bb.emit(relax.If(cond, then_seq, else_seq))
+ outputs = node.output
+ if len(outputs) == 1:
+ self._nodes[outputs[0]] = op
+ else:
+ for i, k in enumerate(outputs):
+ self._nodes[k] = bb.emit(relax.TupleGetItem(op, i))
+ continue
+
+ op = self._convert_operator(op_name, inputs, attr, self.opset)
+ try:
+ _ = op.struct_info
+ has_struct_info = True
+ except tvm.error.InternalError:
+ has_struct_info = False
+
+ if not has_struct_info:
+ op = bb.normalize(op)
+
+ if not isinstance(op, relax.Tuple):
+ if isinstance(op.struct_info, relax.TupleStructInfo):
+ tuple_items = [
+ relax.TupleGetItem(op, i) for i in
range(len(op.struct_info.fields))
+ ]
+ op = relax.Tuple(tuple_items)
+
+ outputs = node.output
+ if len(outputs) == 1:
+ self._nodes[outputs[0]] = op
+ else:
+ for i, k in enumerate(outputs):
+ self._nodes[k] = op[i]
+
+ branch_outputs = [self._nodes[o.name] for o in graph.output]
+ result = branch_outputs[0] if len(branch_outputs) == 1 else
relax.Tuple(branch_outputs)
+
+ self._nodes = outer_nodes
+ return result
+ finally:
+ self._nodes = outer_nodes
+
def from_onnx(
model: onnx.onnx_ml_pb2.GraphProto,
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 887533f261..e2067bad23 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -4203,5 +4203,218 @@ def test_roi_align(coordinate_transformation_mode,
rois):
check_correctness(model, inputs=inputs, opset=16, rtol=1e-5, atol=1e-5)
[email protected](
+ "cond_info, cond_true, cond_false",
+ [
+ (
+ helper.make_tensor_value_info("cond", TensorProto.BOOL, []),
+ np.array(True),
+ np.array(False),
+ ),
+ (
+ helper.make_tensor_value_info("cond", TensorProto.BOOL, [1]),
+ np.array([True]),
+ np.array([False]),
+ ),
+ ],
+ ids=["scalar_condition", "tensor_condition"],
+)
+def test_if(cond_info, cond_true, cond_false):
+ """Test ONNX If operator with scalar and tensor bool conditions."""
+
+ x_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3])
+ result_info = helper.make_tensor_value_info("result", TensorProto.FLOAT,
[3])
+
+ # then branch: x * 2.0
+ two = helper.make_tensor("two", TensorProto.FLOAT, [1], [2.0])
+ then_mul = helper.make_node("Mul", ["x", "two"], ["then_out"])
+ then_out_info = helper.make_tensor_value_info("then_out",
TensorProto.FLOAT, [3])
+ then_graph = helper.make_graph([then_mul], "then_graph", [],
[then_out_info], initializer=[two])
+
+ # else branch: x * 3.0
+ three = helper.make_tensor("three", TensorProto.FLOAT, [1], [3.0])
+ else_mul = helper.make_node("Mul", ["x", "three"], ["else_out"])
+ else_out_info = helper.make_tensor_value_info("else_out",
TensorProto.FLOAT, [3])
+ else_graph = helper.make_graph(
+ [else_mul], "else_graph", [], [else_out_info], initializer=[three]
+ )
+
+ if_node = helper.make_node(
+ "If",
+ inputs=["cond"],
+ outputs=["result"],
+ then_branch=then_graph,
+ else_branch=else_graph,
+ )
+ main_graph = helper.make_graph([if_node], "if_test", [cond_info, x_info],
[result_info])
+ model = helper.make_model(main_graph,
opset_imports=[helper.make_opsetid("", 13)])
+
+ x_data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+
+ check_correctness(model, inputs={"cond": cond_true, "x": x_data})
+ check_correctness(model, inputs={"cond": cond_false, "x": x_data})
+
+
+def test_if_computed_condition():
+ """Test If where condition is computed from another op in the main
graph."""
+ import numpy as np
+ from onnx import TensorProto, helper
+
+ x_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3])
+ result_info = helper.make_tensor_value_info("result", TensorProto.FLOAT,
[3])
+
+ zero = helper.make_tensor("zero", TensorProto.FLOAT, [], [0.0])
+ reduce_node = helper.make_node(
+ "ReduceSum", ["x"], ["x_sum"], keepdims=0, noop_with_empty_axes=0
+ )
+ greater_node = helper.make_node("Greater", ["x_sum", "zero"], ["cond"])
+
+ two = helper.make_tensor("two", TensorProto.FLOAT, [1], [2.0])
+ then_mul = helper.make_node("Mul", ["x", "two"], ["then_out"])
+ then_out_info = helper.make_tensor_value_info("then_out",
TensorProto.FLOAT, [3])
+ then_graph = helper.make_graph([then_mul], "then_graph", [],
[then_out_info], initializer=[two])
+
+ three = helper.make_tensor("three", TensorProto.FLOAT, [1], [3.0])
+ else_mul = helper.make_node("Mul", ["x", "three"], ["else_out"])
+ else_out_info = helper.make_tensor_value_info("else_out",
TensorProto.FLOAT, [3])
+ else_graph = helper.make_graph(
+ [else_mul], "else_graph", [], [else_out_info], initializer=[three]
+ )
+
+ if_node = helper.make_node(
+ "If", inputs=["cond"], outputs=["result"], then_branch=then_graph,
else_branch=else_graph
+ )
+
+ main_graph = helper.make_graph(
+ [reduce_node, greater_node, if_node],
+ "if_computed_cond",
+ [x_info],
+ [result_info],
+ initializer=[zero],
+ )
+ model = helper.make_model(main_graph,
opset_imports=[helper.make_opsetid("", 13)])
+
+ check_correctness(model, inputs={"x": np.array([1.0, 2.0, 3.0],
dtype=np.float32)})
+ check_correctness(model, inputs={"x": np.array([-1.0, -2.0, -3.0],
dtype=np.float32)})
+
+
+def test_if_multiple_outputs():
+ """Test If operator where branches return multiple outputs."""
+ import numpy as np
+ from onnx import TensorProto, helper
+
+ cond_info = helper.make_tensor_value_info("cond", TensorProto.BOOL, [])
+ x_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3])
+ out1_info = helper.make_tensor_value_info("out1", TensorProto.FLOAT, [3])
+ out2_info = helper.make_tensor_value_info("out2", TensorProto.FLOAT, [3])
+
+ two = helper.make_tensor("two", TensorProto.FLOAT, [1], [2.0])
+ three = helper.make_tensor("three", TensorProto.FLOAT, [1], [3.0])
+
+ then_mul1 = helper.make_node("Mul", ["x", "two"], ["then_out1"])
+ then_mul2 = helper.make_node("Mul", ["x", "three"], ["then_out2"])
+ then_o1 = helper.make_tensor_value_info("then_out1", TensorProto.FLOAT,
[3])
+ then_o2 = helper.make_tensor_value_info("then_out2", TensorProto.FLOAT,
[3])
+ then_graph = helper.make_graph(
+ [then_mul1, then_mul2], "then_graph", [], [then_o1, then_o2],
initializer=[two, three]
+ )
+
+ four = helper.make_tensor("four", TensorProto.FLOAT, [1], [4.0])
+ five = helper.make_tensor("five", TensorProto.FLOAT, [1], [5.0])
+ else_mul1 = helper.make_node("Mul", ["x", "four"], ["else_out1"])
+ else_mul2 = helper.make_node("Mul", ["x", "five"], ["else_out2"])
+ else_o1 = helper.make_tensor_value_info("else_out1", TensorProto.FLOAT,
[3])
+ else_o2 = helper.make_tensor_value_info("else_out2", TensorProto.FLOAT,
[3])
+ else_graph = helper.make_graph(
+ [else_mul1, else_mul2], "else_graph", [], [else_o1, else_o2],
initializer=[four, five]
+ )
+
+ if_node = helper.make_node(
+ "If",
+ inputs=["cond"],
+ outputs=["out1", "out2"],
+ then_branch=then_graph,
+ else_branch=else_graph,
+ )
+ main_graph = helper.make_graph(
+ [if_node], "if_multi_out", [cond_info, x_info], [out1_info, out2_info]
+ )
+ model = helper.make_model(main_graph,
opset_imports=[helper.make_opsetid("", 13)])
+
+ x_data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+ check_correctness(model, inputs={"cond": np.array(True), "x": x_data})
+ check_correctness(model, inputs={"cond": np.array(False), "x": x_data})
+
+
+def test_if_nested():
+ """Test nested If operator inside a branch."""
+ import numpy as np
+ from onnx import TensorProto, helper
+
+ cond1_info = helper.make_tensor_value_info("cond1", TensorProto.BOOL, [])
+ cond2_info = helper.make_tensor_value_info("cond2", TensorProto.BOOL, [])
+ x_info = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3])
+ result_info = helper.make_tensor_value_info("result", TensorProto.FLOAT,
[3])
+
+ # Inner then: x * 2
+ two = helper.make_tensor("two", TensorProto.FLOAT, [1], [2.0])
+ inner_then_mul = helper.make_node("Mul", ["x", "two"], ["inner_then_out"])
+ inner_then_out_info = helper.make_tensor_value_info("inner_then_out",
TensorProto.FLOAT, [3])
+ inner_then_graph = helper.make_graph(
+ [inner_then_mul], "inner_then", [], [inner_then_out_info],
initializer=[two]
+ )
+
+ # Inner else: x * 3
+ three = helper.make_tensor("three", TensorProto.FLOAT, [1], [3.0])
+ inner_else_mul = helper.make_node("Mul", ["x", "three"],
["inner_else_out"])
+ inner_else_out_info = helper.make_tensor_value_info("inner_else_out",
TensorProto.FLOAT, [3])
+ inner_else_graph = helper.make_graph(
+ [inner_else_mul], "inner_else", [], [inner_else_out_info],
initializer=[three]
+ )
+
+ # Outer then: nested If(cond2, x*2, x*3)
+ inner_if = helper.make_node(
+ "If",
+ inputs=["cond2"],
+ outputs=["outer_then_out"],
+ then_branch=inner_then_graph,
+ else_branch=inner_else_graph,
+ )
+ outer_then_out_info = helper.make_tensor_value_info("outer_then_out",
TensorProto.FLOAT, [3])
+ outer_then_graph = helper.make_graph([inner_if], "outer_then", [],
[outer_then_out_info])
+
+ # Outer else: x * 4
+ four = helper.make_tensor("four", TensorProto.FLOAT, [1], [4.0])
+ outer_else_mul = helper.make_node("Mul", ["x", "four"], ["outer_else_out"])
+ outer_else_out_info = helper.make_tensor_value_info("outer_else_out",
TensorProto.FLOAT, [3])
+ outer_else_graph = helper.make_graph(
+ [outer_else_mul], "outer_else", [], [outer_else_out_info],
initializer=[four]
+ )
+
+ outer_if = helper.make_node(
+ "If",
+ inputs=["cond1"],
+ outputs=["result"],
+ then_branch=outer_then_graph,
+ else_branch=outer_else_graph,
+ )
+ main_graph = helper.make_graph(
+ [outer_if], "nested_if", [cond1_info, cond2_info, x_info],
[result_info]
+ )
+ model = helper.make_model(main_graph,
opset_imports=[helper.make_opsetid("", 13)])
+
+ x_data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
+ # cond1=True, cond2=True → x * 2
+ check_correctness(model, inputs={"cond1": np.array(True), "cond2":
np.array(True), "x": x_data})
+ # cond1=True, cond2=False → x * 3
+ check_correctness(
+ model, inputs={"cond1": np.array(True), "cond2": np.array(False), "x":
x_data}
+ )
+ # cond1=False → x * 4
+ check_correctness(
+ model, inputs={"cond1": np.array(False), "cond2": np.array(True), "x":
x_data}
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()