A colleague of mine was working on the `If` node, and this should already be fixed. Have you tried the main branch at the latest commit?
This is the script I used:
```python
import tvm
from tvm import relay
def _register_external_op_helper(op_name, supported=True):
    """Mark op_name as supported by the external "special" codegen."""
    @tvm.ir.register_op_attr(op_name, "target.special")
    def _func_wrapper(expr):
        return supported
    return _func_wrapper
_register_external_op_helper("add")
_register_external_op_helper("subtract")
# graph 1: true branch
x1 = relay.var('x1', shape=(10, 1))
y1 = relay.var('y1', shape=(10, 1))
f1 = relay.op.multiply(x1, y1)
x3 = relay.var('x3', shape=(10, 1))
y3 = relay.var('y3', shape=(10, 1))
f3 = relay.op.multiply(x3, y3)
true_branch = relay.op.add(f1, f3)
# graph 2: false branch
x2 = relay.var('x2', shape=(10, 1))
y2 = relay.var('y2', shape=(10, 1))
f2 = relay.op.add(x2, y2)
x4 = relay.var('x4', shape=(10, 1))
y4 = relay.var('y4', shape=(10, 1))
f4 = relay.op.add(x4, y4)
false_branch = relay.op.add(f2, f4)
cond = relay.var('c')
result = relay.If(cond, true_branch=true_branch, false_branch=false_branch)
f = relay.Function(relay.analysis.free_vars(result), result)
mod = tvm.IRModule({"main": f})
mod = relay.transform.AnnotateTarget(["special"])(mod)  # annotate supported ops
mod = relay.transform.MergeCompilerRegions()(mod)       # merge adjacent annotated regions
mod = relay.transform.PartitionGraph()(mod)             # lift regions into @special_* functions
print(mod)
```
And this is the output, which looks good to me.
```
def @main(%c: bool, %x1: Tensor[(10, 1), float32], %y1: Tensor[(10, 1), float32], %x3: Tensor[(10, 1), float32], %y3: Tensor[(10, 1), float32], %x2: Tensor[(10, 1), float32], %y2: Tensor[(10, 1), float32], %x4: Tensor[(10, 1), float32], %y4: Tensor[(10, 1), float32]) -> Tensor[(10, 1), float32] {
  if (%c) {
    %0 = multiply(%x1, %y1) /* ty=Tensor[(10, 1), float32] */;
    %1 = multiply(%x3, %y3) /* ty=Tensor[(10, 1), float32] */;
    @special_0(%0, %1) /* ty=Tensor[(10, 1), float32] */
  } else {
    @special_2(%x2, %y2, %x4, %y4) /* ty=Tensor[(10, 1), float32] */
  }
}

def @special_0(%special_0_i0: Tensor[(10, 1), float32], %special_0_i1: Tensor[(10, 1), float32], global_symbol="special_0", Primitive=1, Compiler="special", Inline=1) -> Tensor[(10, 1), float32] {
  add(%special_0_i0, %special_0_i1) /* ty=Tensor[(10, 1), float32] */
}

def @special_2(%special_2_i0: Tensor[(10, 1), float32], %special_2_i1: Tensor[(10, 1), float32], %special_2_i2: Tensor[(10, 1), float32], %special_2_i3: Tensor[(10, 1), float32], global_symbol="special_2", Primitive=1, Compiler="special", Inline=1) -> Tensor[(10, 1), float32] {
  %2 = add(%special_2_i0, %special_2_i1) /* ty=Tensor[(10, 1), float32] */;
  %3 = add(%special_2_i2, %special_2_i3) /* ty=Tensor[(10, 1), float32] */;
  add(%2, %3) /* ty=Tensor[(10, 1), float32] */
}
```
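If you want to confirm programmatically which functions were offloaded, a minimal sketch along these lines should work after `PartitionGraph` (it only reads the `Compiler` attribute that the pass attaches; nothing here is specific to this example):

```python
# Sanity check: list each function in the module and the external codegen
# (if any) it was partitioned to. Functions without a "Compiler" attribute
# stay on the host, e.g. @main.
for gvar, func in mod.functions.items():
    attrs = func.attrs
    compiler = attrs["Compiler"] if attrs and "Compiler" in attrs else None
    print(gvar.name_hint, "->", compiler if compiler else "host")
```

On the module above this should report `special_0` and `special_2` as `special` and `main` as `host`.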