Issue | 150185
Summary | [mlir] How to simplify this overly complex size calculation after tiling?
Labels | mlir
Assignees |
Reporter | banach-space
Hi folks,
_This is based on https://github.com/iree-org/iree/issues/21393, originally posted by @egebeysel (thanks!). I've rewritten it using "pure" MLIR (i.e. so that it does not require IREE)._
**REPRO**
```mlir
func.func @unpack(%arg0: tensor<512x?x8x?xf32>, %arg1: tensor<4096x4096xf32>, %arg2: tensor<512x?x8x?xf32>) -> tensor<4096x4096xf32> {
%c8 = arith.constant 8 : index
%vscale = vector.vscale
%c8_vscale = arith.muli %vscale, %c8 : index
%0 = tensor.empty() : tensor<4096x4096xf32>
%unpack = linalg.unpack %arg2 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, %c8_vscale] into %0 : tensor<512x?x8x?xf32> -> tensor<4096x4096xf32>
return %unpack : tensor<4096x4096xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module : !transform.any_op {transform.readonly}) {
%unpack = transform.structured.match ops{["linalg.unpack"]} in %module
: (!transform.any_op) -> !transform.any_op
%tiled_unpack, %loops:2 = transform.structured.tile_using_for %unpack tile_sizes [8, [8]]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
transform.yield
}
}
```
After tiling (`bin/mlir-opt --transform-interpreter unpack.mlir -cse -test-transform-dialect-erase-schedule`):
```mlir
#map = affine_map<(d0)[s0] -> (-d0 + 4096, s0)>
#map1 = affine_map<(d0) -> (d0 floordiv 8)>
#map2 = affine_map<(d0)[s0] -> (d0 floordiv s0)>
#map3 = affine_map<(d0)[s0] -> (d0 mod s0)>
#map4 = affine_map<(d0, d1)[s0] -> ((d0 + d1 - 1) floordiv s0 - d0 floordiv s0 + 1)>
// ...
%1 = scf.for %arg3 = %c0 to %c4096 step %c8 iter_args(%arg4 = %0) -> (tensor<4096x4096xf32>) {
%2 = scf.for %arg5 = %c0 to %c4096 step %c8_vscale iter_args(%arg6 = %arg4) -> (tensor<4096x4096xf32>) {
%3 = affine.min #map(%arg5)[%c8_vscale]
%4 = affine.apply #map1(%arg3)
%5 = affine.apply #map2(%arg5)[%c8_vscale]
%7 = affine.apply #map4(%arg5, %3)[%c8_vscale]
%extracted_slice = tensor.extract_slice %arg2[%4, %5, 0, 0] [1, %7, 8, %c8_vscale] [1, 1, 1, 1] : tensor<512x?x8x?xf32> to tensor<1x?x8x?xf32>
%9 = tensor.empty(%8) : tensor<8x?xf32>
%unpack = linalg.unpack %extracted_slice outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [8, %c8_vscale] into %9 : tensor<1x?x8x?xf32> -> tensor<8x?xf32>
// ...
```
**ISSUE**
This _expression_ in the output is overly complex:
```mlir
%7 = affine.apply #map4(%arg5, %3)[%c8_vscale]
// ...
%extracted_slice = tensor.extract_slice %arg2[%4, %5, 0, 0] [1, %7, 8, %c8_vscale] [1, 1, 1, 1] : tensor<512x?x8x?xf32> to tensor<1x?x8x?xf32>
```
Specifically, `%7` should trivially fold to `1` (see below). Because it does not, the slice fed into the tiled unpack has shape `1x?x8x?` instead of `1x1x8x?`, which makes `linalg.unpack` trickier to vectorise.
**WHY SHOULD %7 BE 1?**
For ease of interpretation:
* `%3`: the tile (step) size, or the remaining size in the last iteration
* `%4`: the outer-dimension index into `%arg2` for the loop over M (tile/step size = `8`)
* `%5`: the outer-dimension index into `%arg2` for the loop over N (tile/step size = `8 * vscale`)
* `%7`: the number of source tiles along N (starting at `%5`) that the slice reads from. This is `!= 1` only when the inner tile sizes are **not aligned with the tile sizes**, which is not the case here (both are `8 * vscale`).
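The `%7` bullet can be made concrete with a small Python transcription of `#map4`. This is only a sketch: the name `num_source_tiles` is hypothetical and not an MLIR API. It relies on Python's `//` behaving like affine `floordiv` for the positive divisors involved here. The function counts how many inner tiles a destination range `[offset, offset + size)` touches along one dimension.

```python
def num_source_tiles(offset: int, size: int, inner_tile: int) -> int:
    """Hypothetical helper transcribing #map4:
    (d0, d1)[s0] -> ((d0 + d1 - 1) floordiv s0 - d0 floordiv s0 + 1)
    with d0 = offset, d1 = size, s0 = inner_tile."""
    first = offset // inner_tile              # d0 floordiv s0
    last = (offset + size - 1) // inner_tile  # (d0 + d1 - 1) floordiv s0
    return last - first + 1


# Aligned: the offset is a multiple of the inner tile and size <= inner tile.
assert num_source_tiles(offset=32, size=8, inner_tile=8) == 1
# Misaligned offset: the slice straddles two inner tiles.
assert num_source_tiles(offset=4, size=8, inner_tile=8) == 2
# Tile size larger than the inner tile: several inner tiles are read.
assert num_source_tiles(offset=0, size=16, inner_tile=8) == 2
```

The result exceeds 1 only in the misaligned cases, which matches the claim in the bullet.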
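Specifically, given that `%7 = (%arg5 + %3 - 1) floordiv %c8_vscale - %arg5 floordiv %c8_vscale + 1`, and:
* `%arg5` is a multiple of `%c8_vscale` (the loop starts at `0` and steps by `%c8_vscale`)
* `1 <= %3 <= 8 * vscale` (inside the loop `%arg5 < 4096`) --> `0 <= %3 - 1 < 8 * vscale`
* hence `(%arg5 + %3 - 1) floordiv %c8_vscale` == `%arg5 floordiv %c8_vscale`
* hence `(%arg5 + %3 - 1) floordiv %c8_vscale - %arg5 floordiv %c8_vscale` == `0`
we can safely conclude that `%7 == 1`.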
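As a sanity check of this argument (not a proof, since `vscale` is a runtime value), the inner `scf.for` over N can be emulated in Python and `#map4` evaluated in every iteration. This is a sketch under stated assumptions. The helper name `map4_per_iteration` is hypothetical. The `vscale` range 1..16 is an assumed set of plausible values; it includes values where `8 * vscale` does not divide 4096, so the last tile is partial.

```python
def map4_per_iteration(n: int, vscale: int) -> list[int]:
    """Hypothetical emulation of the N loop from the tiled IR; returns the
    value that %7 = affine.apply #map4(%arg5, %3)[%c8_vscale] takes in each
    iteration."""
    s0 = 8 * vscale                                   # %c8_vscale
    values = []
    for arg5 in range(0, n, s0):                      # scf.for %arg5 = %c0 to %c4096 step %c8_vscale
        tile = min(n - arg5, s0)                      # %3 = affine.min #map(%arg5)[%c8_vscale]
        v7 = (arg5 + tile - 1) // s0 - arg5 // s0 + 1  # %7 via #map4
        values.append(v7)
    return values


# Check a range of assumed vscale values, including non-divisors of 4096 / 8.
for vscale in range(1, 17):
    assert all(v == 1 for v in map4_per_iteration(4096, vscale)), vscale
print("%7 == 1 in every iteration for vscale in 1..16")
```

The check passes because every `%arg5` is a multiple of `%c8_vscale`, which is the same fact the argument above relies on.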