Hello,

I recently stumbled over the fact that `reshape` is typically hard for TVM's 
common subexpression elimination pass to work with. This is because the target 
shape (which is also recorded in the attrs) comes in as a second input that can 
be a distinct (even if equal) tensor. In particular, when converting reshapes 
from, say, PyTorch, every reshape gets its own shape tensor.
Merging these second inputs to reshape (provided they are constant and have 
the same shape, the same values, and, presumably, the same device) lets CSE 
eliminate the duplicated reshapes.
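
To illustrate (a minimal sketch, with made-up shape values): two shape tensors 
with equal values are still distinct `Constant` nodes, so the reshapes feeding 
on them do not look like common subexpressions to the CSE pass:

```python
import numpy as np
import tvm

# Two equal-valued shape tensors, as a PyTorch conversion typically
# produces them: separate Constant nodes despite identical values.
c1 = tvm.relay.const(np.array([16, 128, 768], dtype="int64"))
c2 = tvm.relay.const(np.array([16, 128, 768], dtype="int64"))
print(c1 is c2)  # False: the two reshapes keep distinct shape inputs
```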

My main use case is self-attention in transformers.

The pass I came up with in Python looks like this:

```python
import tvm
import tvm.relay


class ShapeConstDedupMutator(tvm.relay.ExprMutator):
    """Replaces equal constant shape inputs of reshape calls with a
    single shared Constant node, so CSE can merge the reshapes."""

    def __init__(self):
        super().__init__()
        self.shape_consts = {}  # maps shape values -> canonical Constant

    def visit_call(self, call):
        if (isinstance(call.op, tvm.ir.Op) and call.op.name == "reshape"
                and isinstance(call.args[1], tvm.relay.Constant)):
            # The shape tensor must agree with the newshape attribute.
            assert list(call.attrs.newshape) == list(call.args[1].data.asnumpy())
            new_fn = self.visit(call.op)
            new_args = [self.visit(arg) for arg in call.args]
            const = new_args[1]
            assert (const.data.dtype.startswith('int')
                    and len(const.data.shape) == 1)
            # Key by the tensor values and reuse the first constant
            # seen for each distinct shape.
            key = tuple(const.data.asnumpy())
            if key in self.shape_consts:
                new_args[1] = self.shape_consts[key]
            else:
                self.shape_consts[key] = new_args[1]
            return tvm.relay.Call(new_fn, new_args, call.attrs)
        return super().visit_call(call)


@tvm.relay.transform.function_pass(opt_level=1)
def ShapeConstDedup(fn, mod, ctx):
    return ShapeConstDedupMutator().visit(fn)


# new_mod: a Relay module, e.g. converted from PyTorch
new_mod = ShapeConstDedup(new_mod)
new_mod = tvm.relay.transform.EliminateCommonSubexpr()(new_mod)
```
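
To check the effect, here is a quick sketch (the helper name is mine, and 
`new_mod` is the module from the snippet above) that counts the distinct 
`Constant` nodes feeding reshape; run it before and after the passes and the 
count should drop to the number of distinct shapes:

```python
def count_distinct_reshape_shape_consts(mod):
    seen = set()

    def fvisit(expr):
        if (isinstance(expr, tvm.relay.Call)
                and isinstance(expr.op, tvm.ir.Op)
                and expr.op.name == "reshape"
                and len(expr.args) > 1
                and isinstance(expr.args[1], tvm.relay.Constant)):
            # Track node identity: duplicates have equal values but stay
            # distinct objects until the pass merges them.
            seen.add(id(expr.args[1]))

    tvm.relay.analysis.post_order_visit(mod["main"], fvisit)
    return len(seen)


print(count_distinct_reshape_shape_consts(new_mod))
```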

Before I convert this to C++ and submit a PR: would this be of enough general 
interest to add to the standard TVM passes?

An alternative to a separate pass would be to adjust the common subexpression 
elimination logic itself so that calls whose constant shape inputs have equal 
values can be merged. (Maybe that is even preferable; I would love to hear 
your input on it.)
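
I haven't tried changing the CSE pass itself, but on the Relay level the 
value-based merging it would need might look like this sketch (the op 
whitelist and all names are my assumptions, not the existing C++ CSE logic):

```python
SHAPE_CONSUMERS = {"reshape", "broadcast_to", "tile"}  # assumed op whitelist


class ConstShapeArgDedupMutator(tvm.relay.ExprMutator):
    """Dedups 1-D integer constant arguments by value for a configurable
    set of shape-consuming ops, then lets the existing CSE do the rest."""

    def __init__(self, op_names=SHAPE_CONSUMERS):
        super().__init__()
        self.op_names = op_names
        self.consts = {}

    def visit_call(self, call):
        if isinstance(call.op, tvm.ir.Op) and call.op.name in self.op_names:
            new_args = [self.visit(arg) for arg in call.args]
            for i, arg in enumerate(new_args):
                if (isinstance(arg, tvm.relay.Constant)
                        and arg.data.dtype.startswith("int")
                        and len(arg.data.shape) == 1):
                    # Merge by dtype and values; restricting to 1-D integer
                    # tensors unifies only shape-like constants, not weights.
                    key = (arg.data.dtype, tuple(arg.data.asnumpy()))
                    new_args[i] = self.consts.setdefault(key, arg)
            return tvm.relay.Call(self.visit(call.op), new_args, call.attrs)
        return super().visit_call(call)
```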

One of the reasons I'm not proposing to merge all same-valued constants is 
that, in my experience, other parts of the stack can be touchy if suddenly 
all constants "1" are the same object.

Best regards

Thomas
