This is the follow-up issue for
https://discuss.tvm.ai/t/rfc-functionality-of-alteroplayout-and-possible-refactoring/
To improve on the current AlterOpLayout pass, I would like to propose four new passes to
replace it:
- [ ] Layout inference pass
Infers the layout of each layer.
Examples:
https://github.com/yzhliu/tvm-1/blob/refactor_alter_layout/src/relay/op/nn/convolution.cc#L144
https://github.com/yzhliu/tvm-1/blob/refactor_alter_layout/tests/python/relay/test_pass_infer_layout.py
- [ ] Rewrite operator pass
This pass rewrites an operator into another operator (or set of operators), while
the shape, dtype, and layout of its inputs and outputs must remain the same. Proposed API:
```python
@conv2d_rewrite_op.register("cpu")
def _rewrite_conv2d(attrs, inputs, tinfo):
    ...  # body omitted: build and return the rewritten operator(s)
```
This can be used to convert (LT denotes a layout transform operator)
(NCHW) -> conv2d -> (NCHW)
to
(NCHW) -> LT(NCHW->NCHW16c) -> conv2d_NCHW16c -> LT(NCHW16c->NCHW) -> (NCHW)
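As a rough illustration of this rewrite, the sketch below models a graph as a flat Python list of operator names rather than Relay expressions. The function `rewrite_conv2d` and the `("LT", src, dst)` tuple encoding are hypothetical and exist only for this example; they are not TVM API.

```python
# Toy model (not TVM API): a linear graph is a list of ops, and a layout
# transform is encoded as the tuple ("LT", src_layout, dst_layout).

def rewrite_conv2d(ops, plain="NCHW", packed="NCHW16c"):
    """Replace each conv2d with LT -> conv2d_<packed> -> LT.

    The surrounding layout stays `plain`, so the input and output
    layout of the rewritten region is unchanged.
    """
    out = []
    for op in ops:
        if op == "conv2d":
            out.append(("LT", plain, packed))   # NCHW -> NCHW16c
            out.append("conv2d_" + packed)      # layout-specific conv2d
            out.append(("LT", packed, plain))   # NCHW16c -> NCHW
        else:
            out.append(op)
    return out


rewritten = rewrite_conv2d(["conv2d", "add"])
print(rewritten)
```

Only the conv2d is touched here; the following `add` still sees NCHW, which is why a separate pass is needed to propagate the packed layout.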
- [ ] Populate layout pass
This pass converts the remaining operators to use the layout of their preceding
operator. It can be used to convert
conv2d_NCHW16c -> LT(NCHW16c->NCHW) -> add
to
conv2d_NCHW16c -> LT(NCHW16c->NCHW) -> LT(NCHW->NCHW16c) -> add ->
LT(NCHW16c->NCHW)
The API could look like the following; the rules can be predefined:
```python
@add_populate_layout()
populate_layout_add(origin=["NCHW", "CHW"], preferred=["NCHW16c", "CHW16c"])
```
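To make the intended transformation concrete, here is a minimal sketch on a toy linear graph. It assumes a simplified rule table that maps an operator name to a single `(origin, preferred)` layout pair (the proposal lists one pair per input); `populate_layout` and the `("LT", src, dst)` encoding are hypothetical names for this example, not TVM API.

```python
# Toy model (not TVM API): ops are strings, layout transforms are
# ("LT", src_layout, dst_layout) tuples, and the graph is a linear list.

def populate_layout(ops, rules):
    """Wrap an op in layout transforms when its producer was converted
    from the op's preferred layout back to its original layout."""
    out = []
    for op in ops:
        rule = rules.get(op)
        prev = out[-1] if out else None
        if rule is not None and prev == ("LT", rule[1], rule[0]):
            origin, preferred = rule
            out.append(("LT", origin, preferred))  # enter preferred layout
            out.append(op)                         # op now runs in preferred
            out.append(("LT", preferred, origin))  # restore original layout
        else:
            out.append(op)
    return out


rules = {"add": ("NCHW", "NCHW16c")}  # simplified: one layout pair per op
graph = ["conv2d_NCHW16c", ("LT", "NCHW16c", "NCHW"), "add"]
populated = populate_layout(graph, rules)
print(populated)
```

The result deliberately contains a back-to-back pair of inverse transforms; this sketch leaves removing them to a later cleanup step.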
- [ ] Peephole pass
This pass removes unnecessary layout transform operators. It can be used to convert
conv2d_NCHW16c -> LT(NCHW16c->NCHW) -> LT(NCHW->NCHW16c) -> add ->
LT(NCHW16c->NCHW)
to
conv2d_NCHW16c -> add -> LT(NCHW16c->NCHW)
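A minimal sketch of this cleanup on a toy linear graph follows. It assumes that a transform immediately followed by its inverse is a lossless round trip and can therefore be dropped; `peephole` and the `("LT", src, dst)` tuple encoding are hypothetical and only serve this example.

```python
# Toy model (not TVM API): ops are strings, layout transforms are
# ("LT", src_layout, dst_layout) tuples, and the graph is a linear list.

def is_lt(op):
    return isinstance(op, tuple) and len(op) == 3 and op[0] == "LT"


def peephole(ops):
    """Cancel adjacent pairs LT(A->B), LT(B->A) using a stack."""
    out = []
    for op in ops:
        prev = out[-1] if out else None
        if is_lt(op) and is_lt(prev) and (op[1], op[2]) == (prev[2], prev[1]):
            out.pop()        # drop the earlier transform; skip this one too
        else:
            out.append(op)
    return out


graph = [
    "conv2d_NCHW16c",
    ("LT", "NCHW16c", "NCHW"),
    ("LT", "NCHW", "NCHW16c"),
    "add",
    ("LT", "NCHW16c", "NCHW"),
]
cleaned = peephole(graph)
print(cleaned)
```

Because the stack exposes the previous surviving op after each cancellation, nested inverse pairs are also removed in a single pass.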
@tqchen @merrymercy @anijain2305