@@ -298,16 +298,139 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
   }
 };
 
+/// Transforms a `transfer_read` operation so it reads a vector of a type that
+/// can be mapped to an LLVM type. This is done by collapsing trailing
+/// dimensions so we obtain a vector type with a single scalable dimension in
+/// the rightmost position.
+///
+/// Example:
+/// ```
+/// %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///   {in_bounds = [false, true, true, true]}
+///   : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+/// %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///   : memref<?x?x2x8xi8> into memref<?x?xi8>
+/// %0 = vector.transfer_read  %collapse_shape[%i, %j], %c0_i8
+///   {in_bounds = [false, true]}
+///   : memref<?x?xi8>, vector<2x[64]xi8>
+/// %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Do not try to transform masked reads. For example, if we have a transfer
+    // to a `vector<[4]x4xi8>` we could have a mask like
+    //    1 1 1 0
+    //    1 1 1 0
+    //    1 1 1 0
+    //    0 0 0 0
+    // Flattening this mask would look like
+    //    1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
+    // and we have not yet figured out an efficient way to build such a mask,
+    // neither from the mask operand, nor from the original `vector.create_mask`
+    // operation (if visible at all).
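+    // For example (hypothetical IR, for illustration only), the pattern
+    // leaves a masked read like the following untouched:
+    //   %v = vector.transfer_read %M[%c0, %c0], %c0_i8, %mask
+    //     {in_bounds = [true, true]} : memref<?x?xi8>, vector<[4]x4xi8>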
+    if (readOp.isMasked() || readOp.getMask())
+      return rewriter.notifyMatchFailure(readOp,
+                                         "masked transfers not supported");
+
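+    // Note: a minor identity map is, e.g., (d0, d1, d2) -> (d1, d2); anything
+    // else would permute or broadcast dimensions, which this pattern does not
+    // handle.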
+    if (!readOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+    // We handle transfers of vectors with rank >= 2 and a single scalable
+    // dimension.
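+    // E.g. vector<[4]xi8> (rank 1) and vector<[4]x[4]xi8> (two scalable
+    // dimensions) are both rejected.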
+    VectorType origVT = readOp.getVectorType();
+    ArrayRef<bool> origScalableDims = origVT.getScalableDims();
+    const int64_t origVRank = origVT.getRank();
+    if (origVRank < 2 || llvm::count(origScalableDims, true) != 1)
+      return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+    // Number of trailing dimensions to collapse, including the scalable
+    // dimension. Nothing to do if the single scalable dimension is already
+    // the last one.
+    const int64_t numCollapseDims = std::distance(
+        llvm::find(origScalableDims, true), origScalableDims.end());
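+    // E.g. for vector<2x[4]x2x8xi8> the scalable dims are
+    // [false, true, false, false], so numCollapseDims is 3: the scalable
+    // dimension plus the two fixed-size dimensions to its right.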
+    if (numCollapseDims < 2)
+      return rewriter.notifyMatchFailure(readOp,
+                                         "scalable dimension is trailing");
+
+    // We want a simple memref (not a tensor) with contiguous elements for at
+    // least all the trailing dimensions up to and including the scalable one.
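+    // E.g. memref<?x?x2x8xi8> with the default (identity) layout has its
+    // three trailing dimensions contiguous, so collapsing them does not
+    // change the underlying memory layout.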
+    auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
+    if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+      return rewriter.notifyMatchFailure(
+          readOp, "non-contiguous memref dimensions to collapse");
+
+    // The collapsed dimensions (excluding the scalable one) of the vector and
+    // the memref must match and the corresponding indices must be in-bounds
+    // (it follows that these indices would be zero). This guarantees that
+    // the operation transfers a contiguous block.
+    if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
+                     origVT.getShape().take_back(numCollapseDims - 1)))
+      return rewriter.notifyMatchFailure(
+          readOp, "memref and vector dimensions do not match");
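+    // E.g. the trailing 2x8 of memref<?x?x2x8xi8> matches the trailing 2x8
+    // of vector<2x[4]x2x8xi8>; together with zero (in-bounds) indices for
+    // those dimensions, each read covers whole [4]x2x8 = [64]-element
+    // contiguous blocks, as in the example above.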
+
+    SmallVector<bool> origInBounds = readOp.getInBoundsValues();
+    if (!llvm::all_of(
+            ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
+            [](bool v) { return v; }))
+      return rewriter.notifyMatchFailure(readOp,
+                                         "out-if-bounds index to collapse");
----------------
banach-space wrote:

Note, it's not really the index that's out-of-bounds, but the corresponding
memory access. So the index could be in-bounds, but we might be reading "more"
than there is available to read (starting at that index). For example:

```mlir
vector.transfer_read %mem[%c5], %c0_i8 : memref<7xi8>, vector<7xi8>
```
```suggestion
                                         "out-of-bounds index to collapse");
```

https://github.com/llvm/llvm-project/pull/143146