https://github.com/bviyer updated
https://github.com/llvm/llvm-project/pull/76087
From d97a5729f496fb603f4fb9cf2977b015d8e37ed6 Mon Sep 17 00:00:00 2001
From: "Balaji V. Iyer"
Date: Thu, 30 Nov 2023 20:39:55 +
Subject: [PATCH 1/3] [mlir][Vectorizer] Vectorize `tensor.unpack`
This patch allows vectorization of a `tensor.unpack` operation.
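
For illustration, a hand-written sketch (not taken from this patch's tests) of
the rewrite on a small `tensor.unpack` with no outer_dims_perm; masking and
index operands are elided:

  %0 = tensor.unpack %src inner_dims_pos = [0] inner_tiles = [2]
      into %dest : tensor<4x4x2xf32> -> tensor<8x4xf32>

  // is expected to lower roughly to:
  //   %r = vector.transfer_read %src[...]    : tensor<4x4x2xf32>, vector<4x4x2xf32>
  //   %t = vector.transpose %r, [0, 2, 1]    : vector<4x4x2xf32> to vector<4x2x4xf32>
  //   %c = vector.shape_cast %t              : vector<4x2x4xf32> to vector<8x4xf32>
  //   %w = vector.transfer_write %c, %e[...] : vector<8x4xf32>, tensor<8x4xf32>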
---
.../Linalg/Transforms/Vectorization.cpp | 96 +++
1 file changed, 96 insertions(+)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index f9a53a8451a60..7a9846154bf34 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
@@ -1385,6 +1386,88 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
return success();
}
+// Vectorize a `tensor::UnPackOp` without an outer_dims_perm into these 4 ops:
+//   vector::TransferReadOp  - reads a vector from the source tensor.
+//   vector::TransposeOp     - transposes the read vector.
+//   vector::ShapeCastOp     - reshapes the data based on the target shape.
+//   vector::TransferWriteOp - writes the result vector back.
+
+static LogicalResult vectorizeAsUnpackOp(RewriterBase &rewriter,
+                                         tensor::UnPackOp unpackOp,
+                                         ArrayRef<int64_t> inputVectorSizes,
+                                         SmallVectorImpl<Value> &newResults) {
+
+  if (!unpackOp.getOuterDimsPerm().empty()) {
+    LDBG("outer dimensions perms NYI for: " << unpackOp);
+    return failure();
+  }
+
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(unpackOp);
+
+  RankedTensorType packTensorType = unpackOp.getSourceType();
+  auto maskType =
+      VectorType::get(packTensorType.getShape(), rewriter.getI1Type());
+  auto vectorType = VectorType::get(packTensorType.getShape(),
+                                    packTensorType.getElementType());
+  ReifiedRankedShapedTypeDims reifiedRetShapes;
+  LogicalResult status =
+      cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
+          .reifyResultShapes(rewriter, reifiedRetShapes);
+  if (status.failed()) {
+    LDBG("Unable to reify result shapes of " << unpackOp);
+    return failure();
+  }
+
+  arith::ConstantIndexOp zeroOp =
+      rewriter.create<arith::ConstantIndexOp>(unpackOp->getLoc(), 0);
+  Value mask = rewriter.create<vector::CreateMaskOp>(
+      unpackOp.getLoc(), maskType,
+      tensor::getMixedSizes(rewriter, unpackOp.getLoc(),
+                            unpackOp.getSource()));
+
+  vector::TransferReadOp readOp = rewriter.create<vector::TransferReadOp>(
+      unpackOp.getLoc(), vectorType, unpackOp.getSource(),
+      SmallVector<Value>(packTensorType.getRank(), zeroOp),
+      rewriter.getMultiDimIdentityMap(packTensorType.getRank()));
+
+  vector::MaskOp maskedOp = cast<vector::MaskOp>(
+      mlir::vector::maskOperation(rewriter, readOp, mask));
+
+  int64_t numPackedDim = unpackOp.getInnerDimsPos().size();
+  int64_t packRank = packTensorType.getRank();
+  auto lastDims =
+      llvm::to_vector(llvm::seq<int64_t>(packRank - numPackedDim, packRank));
+  PackingMetadata packMetadata =
+      computePackingMetadata(packRank, unpackOp.getInnerDimsPos());
+  SmallVector<int64_t> lastDimToInsertPosPerm = computePermutationVector(
+      packRank, lastDims, packMetadata.insertPositions);
+  SmallVector<int64_t> stripMineShape(packTensorType.getShape());
+  applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
+
+  RankedTensorType stripMineTensorType =
+      RankedTensorType::Builder(packTensorType).setShape(stripMineShape);
+
+  RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
+      stripMineTensorType, packMetadata.reassociations);
+  auto vecCollapsedType = VectorType::get(collapsedType.getShape(),
+                                          collapsedType.getElementType());
+
+  vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>(
+      unpackOp.getLoc(), maskedOp.getResult(0), lastDimToInsertPosPerm);
+
+  vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>(
+      unpackOp.getLoc(), vecCollapsedType, transposeOp->getResult(0));
+  tensor::EmptyOp emptyOp = rewriter.create<tensor::EmptyOp>(
+      unpackOp.getLoc(), reifiedRetShapes[0], packTensorType.getElementType());
+
+  vector::TransferWriteOp writeOp = rewriter.create<vector::TransferWriteOp>(
+      unpackOp.getLoc(), shapeCastOp->getResult(0), emptyOp,
+      SmallVector<Value>(lastDims.size(), zeroOp),
+      SmallVector<bool>(lastDims.size(), true));
+
+  newResults.push_back(writeOp->getResult(0));
+  return success();
+}
/// Vectorize a `padOp` with (1) static result type, (2) constant padding value
/// and (3) all-zero lowPad to
@@ -1578,6 +1661,12 @@ vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
return success