Hello, I'm trying to partition a Relay graph into functions and then rewrite 
those functions, but the rewrite fails. Here's a minimal working example:

```python
import tvm
import tvm.relay as relay
from tvm.relay.dataflow_pattern import (
    wildcard, is_op, rewrite, DFPatternCallback, FunctionPattern)

class TestCallback(DFPatternCallback):
    def __init__(self):
        super(TestCallback, self).__init__()
        self.x = wildcard()
        self.y = wildcard()

        pattern = is_op('add')(self.x, self.y)
        pattern = FunctionPattern([wildcard(), wildcard()], pattern)
        self.pattern = pattern

    def callback(self, pre, post, node_map):
        print('here')
        x = node_map[self.x][0]
        y = node_map[self.y][0]
        return x - y

x = relay.var('x')
y = relay.var('y')
z = relay.var('z')
expr = (x + y) * z

p = wildcard() + wildcard()
fp = FunctionPattern([wildcard(), wildcard()], p)
print(expr)
expr_p = p.partition(expr)
print(expr_p)
expr_r = rewrite(TestCallback(), expr_p)
print(expr_r)
```

The third print statement prints the same output as the second one, because 
TestCallback fails to match the `add` op. Can anyone help?

Thanks in advance!




