So, with the following rewrites and passes:
```python
class ZeroZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    """Rewrite ``zeros(...) + x`` and ``x + zeros(...)`` to ``x``.

    Matches an addition where one operand is a call to the ``zeros`` op with
    exactly one argument.  If the addition's result type differs from the type
    of ``x`` (i.e. ``x`` is broadcast by the add), the replacement is
    ``broadcast_to(x, result_shape)`` instead of plain ``x``.

    NOTE: this snippet defines a second class named ``ZeroZapp`` right after
    this one, which replaces this definition before it is used.
    """

    def __init__(self):
        # Initialise base-class state that ``rewrite`` may read from the callback.
        super().__init__()
        self.zeros = tvm.relay.dataflow_pattern.is_op("zeros")(
            tvm.relay.dataflow_pattern.wildcard()
        )
        self.other_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern = (self.zeros + self.other_tensor) | (
            self.other_tensor + self.zeros
        )

    def callback(self, pre, post, node_map):
        """Return the simplified expression, or the match unchanged.

        Returns ``x`` when it already has the add's result type,
        ``broadcast_to(x, shape)`` when only the result type is known, and the
        original match when the result type is unknown.
        """
        rt = node_map[self.pattern][0]
        ot = node_map[self.other_tensor][0]
        rt_type = rt._checked_type_
        ot_type = ot._checked_type_
        # Without the result type we cannot tell whether dropping the add
        # changes the shape (and ``None == None`` would wrongly look like a
        # match), so leave the expression untouched.
        if rt_type is None:
            return rt
        if ot_type is not None and ot_type == rt_type:
            return ot
        return tvm.relay.broadcast_to(ot, list(rt_type.shape))
class ZeroZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    """Rewrite ``zero + x`` and ``x + zero`` to ``x``.

    "zero" is either a call to the ``zeros`` op with one argument, or a
    constant whose elements are all 0.  If the addition's result type differs
    from the type of ``x`` (``x`` is broadcast by the add), the replacement is
    ``broadcast_to(x, result_shape)``.  When the needed type information is
    missing, the match is returned unchanged.
    """

    def __init__(self):
        # Initialise base-class state that ``rewrite`` may read from the callback.
        super().__init__()
        self.zeros = tvm.relay.dataflow_pattern.is_op("zeros")(
            tvm.relay.dataflow_pattern.wildcard()
        ) | tvm.relay.dataflow_pattern.is_constant()
        # Backward-compatible alias: this pattern used to be stored as ``ones``.
        self.ones = self.zeros
        self.other_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern = (self.zeros + self.other_tensor) | (
            self.other_tensor + self.zeros
        )

    def callback(self, pre, post, node_map):
        """Return the simplified expression, or the match unchanged."""
        rt = node_map[self.pattern][0]
        zero = node_map[self.zeros][0]
        ot = node_map[self.other_tensor][0]
        # ``is_constant()`` matches ANY constant, so the matched "zero" side
        # itself must be checked for all-zero values (checking ``ot`` would let
        # ``x + 5`` be rewritten to ``x``).  ``ndarray.all()`` is used because
        # builtin ``all`` fails on 0-d and multi-dimensional arrays.
        if isinstance(zero, tvm.relay.Constant) and not (
            zero.data.asnumpy() == 0
        ).all():
            return rt
        rt_type = rt._checked_type_
        ot_type = ot._checked_type_
        # Checked types can be missing on nodes rebuilt during rewriting, so
        # only compare them when both are present.
        same_type = rt_type is not None and ot_type is not None and ot_type == rt_type
        # Fallback: if the call's type arguments are populated (presumably the
        # operand types from type inference) and the two operands agree, the
        # add does not broadcast ``x``.  Guard against an empty ``type_args``.
        same_args = len(rt.type_args) >= 2 and rt.type_args[0] == rt.type_args[1]
        if same_type or same_args:
            return ot
        if rt_type is not None:
            return tvm.relay.broadcast_to(ot, list(rt_type.shape))
        return rt
class OneZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    """Rewrite ``one * x`` and ``x * one`` to ``x``.

    "one" is either a call to the ``ones`` op with one argument, or a constant
    whose elements are all 1.  If the multiplication's result type differs
    from the type of ``x`` (``x`` is broadcast), the replacement is
    ``broadcast_to(x, result_shape)``.  When the needed type information is
    missing, the match is returned unchanged.
    """

    def __init__(self):
        # Initialise base-class state that ``rewrite`` may read from the callback.
        super().__init__()
        self.ones = tvm.relay.dataflow_pattern.is_op("ones")(
            tvm.relay.dataflow_pattern.wildcard()
        ) | tvm.relay.dataflow_pattern.is_constant()
        self.other_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern = (self.ones * self.other_tensor) | (
            self.other_tensor * self.ones
        )

    def callback(self, pre, post, node_map):
        """Return the simplified expression, or the match unchanged."""
        rt = node_map[self.pattern][0]
        ones = node_map[self.ones][0]
        ot = node_map[self.other_tensor][0]
        # ``is_constant()`` matches ANY constant, so the matched "one" side
        # itself must be checked for all-one values (checking ``ot`` would let
        # ``x * 2`` be rewritten to ``x``).  ``ndarray.all()`` is used because
        # builtin ``all`` fails on 0-d and multi-dimensional arrays.
        if isinstance(ones, tvm.relay.Constant) and not (
            ones.data.asnumpy() == 1
        ).all():
            return rt
        rt_type = rt._checked_type_
        ot_type = ot._checked_type_
        # Without the result type we cannot tell whether dropping the multiply
        # changes the shape (and ``None == None`` would wrongly look like a
        # match), so leave the expression untouched.
        if rt_type is None:
            return rt
        if ot_type is not None and ot_type == rt_type:
            return ot
        return tvm.relay.broadcast_to(ot, list(rt_type.shape))
class LikeZapp(tvm.relay.dataflow_pattern.DFPatternCallback):
    """Simplify ``*_like`` ops using the matched call's checked result type.

    * ``zeros_like(x)`` / ``ones_like(x)`` become ``zeros(shape, dtype)`` /
      ``ones(shape, dtype)`` built from the result type.
    * ``collapse_sum_like(x, y)`` / ``broadcast_to_like(x, y)`` become ``x``
      when ``x`` already has the result type.
    * ``broadcast_to_like(x, y)`` otherwise becomes
      ``broadcast_to(x, result_shape)``.

    Matches whose result type is unknown are returned unchanged.
    """

    def __init__(self):
        # Initialise base-class state that ``rewrite`` may read from the callback.
        super().__init__()
        # Maps a unary ``*_like`` op name to the constructor taking (shape, dtype).
        self.translations_with_dt = {
            "zeros_like": tvm.relay.zeros,
            "ones_like": tvm.relay.ones,
        }
        self.data_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern_tensor = tvm.relay.dataflow_pattern.wildcard()
        self.pattern = (
            (
                tvm.relay.dataflow_pattern.is_op("zeros_like")
                | tvm.relay.dataflow_pattern.is_op("ones_like")
            )(self.data_tensor)
        ) | (
            (
                tvm.relay.dataflow_pattern.is_op("collapse_sum_like")
                | tvm.relay.dataflow_pattern.is_op("broadcast_to_like")
            )(self.data_tensor, self.pattern_tensor)
        )

    def callback(self, pre, post, node_map):
        """Return the simplified expression, or the match unchanged."""
        data = node_map[self.data_tensor][0]
        res = node_map[self.pattern][0]
        res_type = res._checked_type_
        # Every replacement below is derived from the result type; nodes
        # rebuilt during rewriting may not carry one, so leave those alone
        # (otherwise ``None == None`` would wrongly drop a broadcast).
        if res_type is None:
            return res
        make = self.translations_with_dt.get(res.op.name)
        if make is not None:
            return make(list(res_type.shape), res_type.dtype)
        data_type = data._checked_type_
        if data_type is not None and data_type == res_type:
            return data
        if res.op.name == "broadcast_to_like":
            return tvm.relay.broadcast_to(data, list(res_type.shape))
        return res
# Apply the simplifications in sequence.  Expressions created by a rewrite are
# not type-checked until InferType runs again, which is a likely reason the
# callbacks above sometimes see ``_checked_type_`` as None.  Re-infer types
# before every rewrite so each one starts from a fully typed module.
grmod = tvm.relay.transform.InferType()(grmod)
grmod["main"] = tvm.relay.dataflow_pattern.rewrite(LikeZapp(), grmod["main"])
grmod = tvm.relay.transform.FoldConstant()(grmod)
grmod = tvm.relay.transform.InferType()(grmod)
grmod["main"] = tvm.relay.dataflow_pattern.rewrite(ZeroZapp(), grmod["main"])
grmod = tvm.relay.transform.InferType()(grmod)
grmod["main"] = tvm.relay.dataflow_pattern.rewrite(OneZapp(), grmod["main"])
# Leave the final module type-checked for whatever consumes it next.
grmod = tvm.relay.transform.InferType()(grmod)
```
I get a result that looks realistic:

But this is just a trivial case, so if you have any hint as to whether some of
these patterns are already available, I would be most grateful.
Also, I have no idea why I don't reliably get `_checked_type_` attributes
in ZeroZapp... If you have an idea, please let me know.
Best regards
Thomas
---
[Visit Topic](https://discuss.tvm.ai/t/same-shape-pattern/7012/4) to respond.
You are receiving this because you enabled mailing list mode.
To unsubscribe from these emails, [click
here](https://discuss.tvm.ai/email/unsubscribe/edf1a3df04c4e401b4a9bae6d2453278776471a7e08ef58e6b9932a4d4cbb1f8).