Hi All,

Hi @comaniac, I want to follow up on my post above. I removed the `relay.If` node, and now it works.
Does that mean `MergeCompilerRegions` does not fully support `If` yet?

This is the code that works:
```python
import tvm
from tvm import relay

# this is test case for graph type 1
print("Graph type 1")

# graph 1: true branch
x1 = relay.var('x1', shape=(10, 1))
y1 = relay.var('y1', shape=(10, 1))
f1 = relay.op.multiply(x1, y1)

x3 = relay.var('x3', shape=(10, 1))
y3 = relay.var('y3', shape=(10, 1))
f3 = relay.op.multiply(x3, y3)

true_branch = relay.op.add(f1, f3)

# graph 2: false branch
x2 = relay.var('x2', shape=(10, 1))
y2 = relay.var('y2', shape=(10, 1))
f2 = relay.op.add(x2, y2)

x4 = relay.var('x4', shape=(10, 1))
y4 = relay.var('y4', shape=(10, 1))
f4 = relay.op.add(x4, y4)

false_branch = relay.op.add(f2, f4)

# scalar boolean condition (unused here because the If is commented out)
cond = relay.var('c', shape=(), dtype='bool')
#result = relay.If(cond, true_branch=true_branch, false_branch=false_branch)
result = true_branch
#f = relay.Function([], result)
f = relay.Function(relay.analysis.free_vars(result), result)

mod = tvm.IRModule({"main": f})
# AnnotateTarget only marks ops that have a "target.special" attribute;
# that registration is assumed to happen elsewhere in the script
mod = relay.transform.AnnotateTarget(["special"])(mod)  # Output: Figure 2
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.PartitionGraph()(mod)  # Output: Figure 4
```

This is the code that does not work (the only change is that `result` is now the `relay.If` node):

```python
import tvm
from tvm import relay

# this is test case for graph type 1
print("Graph type 1")

# graph 1: true branch
x1 = relay.var('x1', shape=(10, 1))
y1 = relay.var('y1', shape=(10, 1))
f1 = relay.op.multiply(x1, y1)

x3 = relay.var('x3', shape=(10, 1))
y3 = relay.var('y3', shape=(10, 1))
f3 = relay.op.multiply(x3, y3)

true_branch = relay.op.add(f1, f3)

# graph 2: false branch
x2 = relay.var('x2', shape=(10, 1))
y2 = relay.var('y2', shape=(10, 1))
f2 = relay.op.add(x2, y2)

x4 = relay.var('x4', shape=(10, 1))
y4 = relay.var('y4', shape=(10, 1))
f4 = relay.op.add(x4, y4)

false_branch = relay.op.add(f2, f4)

# relay.If requires a scalar boolean condition
cond = relay.var('c', shape=(), dtype='bool')
result = relay.If(cond, true_branch=true_branch, false_branch=false_branch)
#result = true_branch
#f = relay.Function([], result)
f = relay.Function(relay.analysis.free_vars(result), result)

mod = tvm.IRModule({"main": f})
# AnnotateTarget only marks ops that have a "target.special" attribute;
# that registration is assumed to happen elsewhere in the script
mod = relay.transform.AnnotateTarget(["special"])(mod)  # Output: Figure 2
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.PartitionGraph()(mod)  # Output: Figure 4
```
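Since the two snippets differ only in the `relay.If` node, one way to check whether the failure really comes from `MergeCompilerRegions` (rather than from `AnnotateTarget` or `PartitionGraph`) is to apply the passes one at a time. Below is a minimal sketch of such a helper. `run_passes_stepwise` and its `after_each` callback are hypothetical names, not part of TVM; the code is plain Python and only assumes that each pass is a callable that takes a module and returns a new one, which is how the Relay pass objects are used in the snippets above.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple

# A "pass" here is any callable that takes a module and returns a new one,
# e.g. relay.transform.MergeCompilerRegions() (assumption: used with TVM).
PassFn = Callable[[Any], Any]


def run_passes_stepwise(
    mod: Any,
    passes: Sequence[Tuple[str, PassFn]],
    after_each: Optional[Callable[[str, Any], None]] = None,
) -> Tuple[Any, List[str], Optional[Tuple[str, Exception]]]:
    """Hypothetical helper: apply (name, pass) pairs in order.

    Returns the last module that was produced successfully, the names of
    the passes that completed, and (name, exception) for the first pass
    that raised, or None if every pass completed.
    """
    completed: List[str] = []
    for name, apply_pass in passes:
        try:
            mod = apply_pass(mod)
        except Exception as err:  # stop at the first failing pass
            return mod, completed, (name, err)
        completed.append(name)
        if after_each is not None:
            after_each(name, mod)  # e.g. print the IR after this pass
    return mod, completed, None
```

With the failing snippet, `passes` would be `[("AnnotateTarget", relay.transform.AnnotateTarget(["special"])), ("MergeCompilerRegions", relay.transform.MergeCompilerRegions()), ("PartitionGraph", relay.transform.PartitionGraph())]`, and passing `after_each=lambda name, m: print(name, m)` prints the IR after each pass that succeeds, so the annotated module can be compared against Figure 2 before merging is attempted.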

---
[Visit Topic](https://discuss.tvm.apache.org/t/understanding-tvm-relays-partitiongraph-mod-function/8290/9) to respond.