================
@@ -1860,25 +1866,54 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, 
linalg::UnPackOp unpackOp,
 
   auto destSize = unpackOp.getDestRank();
 
-  if (!inputVectorSizes.empty())
-    assert(inputVectorSizes.size() == destSize &&
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() == destSize + sourceShape.size() &&
            "Incorrect number of input vector sizes");
+  }
+
+  SmallVector<bool> readScalableVectorFlags;
+  SmallVector<bool> writeScalableVectorFlags;
+  SmallVector<int64_t> readVectorSizes;
+  SmallVector<int64_t> writeVectorSizes;
 
-  // vectorSizes is the shape of the vector that will be used to do final
+  // Split input-vector-sizes into vector sizes for the read and write
+  // operations.
+  if (!inputVectorSizes.empty()) {
+    readVectorSizes.append(inputVectorSizes.begin(),
+                           inputVectorSizes.begin() + sourceShape.size());
+    writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(),
+                            inputVectorSizes.end());
+  }
+  if (!inputScalableVecDims.empty()) {
+    readScalableVectorFlags.append(inputScalableVecDims.begin(),
+                                   inputScalableVecDims.begin() +
+                                       sourceShape.size());
+    writeScalableVectorFlags.append(inputScalableVecDims.begin() +
+                                        sourceShape.size(),
+                                    inputScalableVecDims.end());
+  } else {
+    readScalableVectorFlags = SmallVector<bool>(sourceShape.size(), false);
+    writeScalableVectorFlags = SmallVector<bool>(destSize, false);
+  }
+
+  // writeVectorSizes is the shape of the vector that will be used to do final
   // write on the destination tensor. It is set like this: Let's say the
   // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M.
   // Thus:
-  // 1. vectorSizes = sourceShape.take_front(N)
-  // 2. if outer_dims_perms is present: do that permutation on vectorSizes.
+  // 1. writeVectorSizes = sourceShape.take_front(N)
+  // 2. if outer_dims_perms is present: do that permutation on 
writeVectorSizes.
   // 3. multiply all the locations in vectorSize pointed by innerDimPos by the
   //    innerTiles attribute value.
-  SmallVector<int64_t> vectorSizes(inputVectorSizes);
-  if (vectorSizes.empty()) {
-    llvm::append_range(vectorSizes, sourceShape.take_front(destSize));
+  // SmallVector<int64_t> writeVectorSizes(inputVectorSizes);
----------------
banach-space wrote:

> Also, can we add a comment here saying that this is the case that we would be 
> inferring the write vector sizes from the IR? Thanks!

Sure. In fact, I've rewritten that comment (and the related logic) in [this 
commit](https://github.com/llvm/llvm-project/pull/149293/commits/b3894996379080b4a24bebe5cc8c938d42ce4243)
 (it wasn't clear to me how to expand it cleanly).

https://github.com/llvm/llvm-project/pull/149293
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to