My colleague worked on handling `If` nodes, so this should already be fixed.
Have you tried the latest commit on the main branch?

This is the script I used:

```python
import tvm
from tvm import relay

def _register_external_op_helper(op_name, supported=True):
    """Attach a ``target.special`` support predicate to a Relay operator.

    ``relay.transform.AnnotateTarget(["special"])`` looks up this op
    attribute to decide whether calls to ``op_name`` may be offloaded to
    the "special" external target.

    Parameters
    ----------
    op_name : str
        Name of the Relay operator to register, e.g. ``"add"``.
    supported : bool, optional
        Constant answer the predicate gives for every call of ``op_name``.

    Returns
    -------
    callable
        The registered predicate, exactly as handed back by
        ``tvm.ir.register_op_attr``.
    """
    register = tvm.ir.register_op_attr(op_name, "target.special")

    def _is_supported(expr):
        # The call expression is ignored: support is decided per operator only.
        return supported

    return register(_is_supported)


# Offload only "add" and "subtract" to the "special" external target.
# "multiply" is deliberately left unregistered, so it stays in @main.
_register_external_op_helper("add")
_register_external_op_helper("subtract")

# Every tensor input in this example has the same shape.
shape = (10, 1)

# graph 1: true branch -> (x1 * y1) + (x3 * y3)
x1 = relay.var('x1', shape=shape)
y1 = relay.var('y1', shape=shape)
f1 = relay.op.multiply(x1, y1)

x3 = relay.var('x3', shape=shape)
y3 = relay.var('y3', shape=shape)
f3 = relay.op.multiply(x3, y3)

true_branch = relay.op.add(f1, f3)

# graph 2: false branch -> (x2 + y2) + (x4 + y4)
x2 = relay.var('x2', shape=shape)
y2 = relay.var('y2', shape=shape)
f2 = relay.op.add(x2, y2)

x4 = relay.var('x4', shape=shape)
y4 = relay.var('y4', shape=shape)
f4 = relay.op.add(x4, y4)

false_branch = relay.op.add(f2, f4)

# The If condition must be a scalar boolean tensor. Annotate it explicitly
# instead of leaving the type of `c` to be filled in by type inference.
cond = relay.var('c', shape=(), dtype='bool')
result = relay.If(cond, true_branch=true_branch, false_branch=false_branch)
f = relay.Function(relay.analysis.free_vars(result), result)

# Annotate supported ops, merge neighbouring regions of the same target,
# then lift each region into its own external "special" function.
mod = tvm.IRModule({"main": f})
mod = relay.transform.AnnotateTarget(["special"])(mod)
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.PartitionGraph()(mod)
print(mod)
```

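And this is the output, which looks correct to me. In both branches the `add` ops were moved into external `special` functions, while the unsupported `multiply` ops stayed in `@main`.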

```
def @main(%c: bool, %x1: Tensor[(10, 1), float32], %y1: Tensor[(10, 1), 
float32], %x3: Tensor[(10, 1), float32], %y3: Tensor[(10, 1), float32], %x2: 
Tensor[(10, 1), float32], %y2: Tensor[(10, 1), float32], %x4: Tensor[(10, 1), 
float32], %y4: Tensor[(10, 1), float32]) -> Tensor[(10, 1), float32] {
  if (%c) {
    %0 = multiply(%x1, %y1) /* ty=Tensor[(10, 1), float32] */;
    %1 = multiply(%x3, %y3) /* ty=Tensor[(10, 1), float32] */;
    @special_0(%0, %1) /* ty=Tensor[(10, 1), float32] */
  } else {
    @special_2(%x2, %y2, %x4, %y4) /* ty=Tensor[(10, 1), float32] */
  }
}

def @special_0(%special_0_i0: Tensor[(10, 1), float32], %special_0_i1: 
Tensor[(10, 1), float32], global_symbol="special_0", Primitive=1, 
Compiler="special", Inline=1) -> Tensor[(10, 1), float32] {
  add(%special_0_i0, %special_0_i1) /* ty=Tensor[(10, 1), float32] */
}

def @special_2(%special_2_i0: Tensor[(10, 1), float32], %special_2_i1: 
Tensor[(10, 1), float32], %special_2_i2: Tensor[(10, 1), float32], 
%special_2_i3: Tensor[(10, 1), float32], global_symbol="special_2", 
Primitive=1, Compiler="special", Inline=1) -> Tensor[(10, 1), float32] {
  %2 = add(%special_2_i0, %special_2_i1) /* ty=Tensor[(10, 1), float32] */;
  %3 = add(%special_2_i2, %special_2_i3) /* ty=Tensor[(10, 1), float32] */;
  add(%2, %3) /* ty=Tensor[(10, 1), float32] */
}

```





---
[Visit 
Topic](https://discuss.tvm.apache.org/t/understanding-tvm-relays-partitiongraph-mod-function/8290/10)
 to respond.

You are receiving this because you enabled mailing list mode.

To unsubscribe from these emails, [click 
here](https://discuss.tvm.apache.org/email/unsubscribe/76a742fa681e7f65b46d9419e074f2983526f094a531c8b937629ea1fd09edcf).

Reply via email to