Hi All,
Hi @comaniac, I want to follow up on my post above. I removed the `If` expression, and now it works. Does that mean `MergeCompilerRegions` does not fully support `If` yet? This is the code that works:

```
import tvm
from tvm import relay

# This is the test case for graph type 1
print("Graph type 1")

# graph 1: true branch
x1 = relay.var('x1', shape=(10, 1))
y1 = relay.var('y1', shape=(10, 1))
f1 = relay.op.multiply(x1, y1)
x3 = relay.var('x3', shape=(10, 1))
y3 = relay.var('y3', shape=(10, 1))
f3 = relay.op.multiply(x3, y3)
true_branch = relay.op.add(f1, f3)

# graph 2: false branch
x2 = relay.var('x2', shape=(10, 1))
y2 = relay.var('y2', shape=(10, 1))
f2 = relay.op.add(x2, y2)
x4 = relay.var('x4', shape=(10, 1))
y4 = relay.var('y4', shape=(10, 1))
f4 = relay.op.add(x4, y4)
false_branch = relay.op.add(f2, f4)

cond = relay.var('c')
# result = relay.If(cond, true_branch=true_branch, false_branch=false_branch)
result = true_branch

# f = relay.Function([], result)
f = relay.Function(relay.analysis.free_vars(result), result)
mod = tvm.IRModule({"main": f})
mod = relay.transform.AnnotateTarget(["special"])(mod)  # Output: Figure 2
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.PartitionGraph()(mod)  # Output: Figure 4
```

This is the code that does **not** work (the only change is that `result` is now the `If` expression instead of `true_branch`):
```
import tvm
from tvm import relay

# This is the test case for graph type 1
print("Graph type 1")

# graph 1: true branch
x1 = relay.var('x1', shape=(10, 1))
y1 = relay.var('y1', shape=(10, 1))
f1 = relay.op.multiply(x1, y1)
x3 = relay.var('x3', shape=(10, 1))
y3 = relay.var('y3', shape=(10, 1))
f3 = relay.op.multiply(x3, y3)
true_branch = relay.op.add(f1, f3)

# graph 2: false branch
x2 = relay.var('x2', shape=(10, 1))
y2 = relay.var('y2', shape=(10, 1))
f2 = relay.op.add(x2, y2)
x4 = relay.var('x4', shape=(10, 1))
y4 = relay.var('y4', shape=(10, 1))
f4 = relay.op.add(x4, y4)
false_branch = relay.op.add(f2, f4)

cond = relay.var('c')
result = relay.If(cond, true_branch=true_branch, false_branch=false_branch)
# result = true_branch

# f = relay.Function([], result)
f = relay.Function(relay.analysis.free_vars(result), result)
mod = tvm.IRModule({"main": f})
mod = relay.transform.AnnotateTarget(["special"])(mod)  # Output: Figure 2
mod = relay.transform.MergeCompilerRegions()(mod)
mod = relay.transform.PartitionGraph()(mod)  # Output: Figure 4
```
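To tell whether the failure really comes from `MergeCompilerRegions` or from a later pass such as `PartitionGraph`, one option is to apply the passes one at a time and record which one raises first. Below is a minimal sketch, assuming each pass is a callable that takes a module and returns a module (which is how the `relay.transform` pass objects are used in the code above). The helper name `run_passes_stepwise` is hypothetical, not part of TVM.

```python
from typing import Any, Callable, List, Optional, Tuple


def run_passes_stepwise(
    mod: Any, passes: List[Tuple[str, Callable[[Any], Any]]]
) -> Tuple[Any, Optional[str], Optional[BaseException]]:
    """Apply (name, pass) pairs in order and stop at the first pass that raises.

    Hypothetical debugging helper. Returns:
      - the module produced by the last pass that succeeded,
      - the name of the failing pass (None if every pass succeeded),
      - the exception that pass raised (None if every pass succeeded).
    """
    for name, transform in passes:
        try:
            mod = transform(mod)
        except Exception as err:  # assumes the failure surfaces as a Python exception
            return mod, name, err
    return mod, None, None
```

For the failing case, the list would hold `("AnnotateTarget", relay.transform.AnnotateTarget(["special"]))`, `("MergeCompilerRegions", relay.transform.MergeCompilerRegions())` and `("PartitionGraph", relay.transform.PartitionGraph())`, applied to `tvm.IRModule({"main": f})`. The returned name shows which pass rejected the `If` expression, and the returned module is the last state before the failure, which can be printed for comparison. If the process crashes outright instead of raising, this helper will not catch it.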