================ @@ -1860,25 +1866,54 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, auto destSize = unpackOp.getDestRank(); - if (!inputVectorSizes.empty()) - assert(inputVectorSizes.size() == destSize && + if (!inputVectorSizes.empty()) { + assert(inputVectorSizes.size() == destSize + sourceShape.size() && "Incorrect number of input vector sizes"); + } + + SmallVector<bool> readScalableVectorFlags; + SmallVector<bool> writeScalableVectorFlags; + SmallVector<int64_t> readVectorSizes; + SmallVector<int64_t> writeVectorSizes; - // vectorSizes is the shape of the vector that will be used to do final + // Split input-vector-sizes into vector sizes for the read and write + // operations. + if (!inputVectorSizes.empty()) { + readVectorSizes.append(inputVectorSizes.begin(), + inputVectorSizes.begin() + sourceShape.size()); + writeVectorSizes.append(inputVectorSizes.begin() + sourceShape.size(), + inputVectorSizes.end()); + } + if (!inputScalableVecDims.empty()) { + readScalableVectorFlags.append(inputScalableVecDims.begin(), + inputScalableVecDims.begin() + + sourceShape.size()); + writeScalableVectorFlags.append(inputScalableVecDims.begin() + + sourceShape.size(), + inputScalableVecDims.end()); + } else { + readScalableVectorFlags = SmallVector<bool>(sourceShape.size(), false); + writeScalableVectorFlags = SmallVector<bool>(destSize, false); + } + + // writeVectorSizes is the shape of the vector that will be used to do final // write on the destination tensor. It is set like this: Let's say the // source tensor is rank 'M' and the dest tensor rank 'N', where N <= M. // Thus: - // 1. vectorSizes = sourceShape.take_front(N) - // 2. if outer_dims_perms is present: do that permutation on vectorSizes. + // 1. writeVectorSizes = sourceShape.take_front(N) + // 2. if outer_dims_perms is present: do that permutation on writeVectorSizes. // 3. multiply all the locations in vectorSize pointed by innerDimPos by the // innerTiles attribute value. 
- SmallVector<int64_t> vectorSizes(inputVectorSizes); - if (vectorSizes.empty()) { - llvm::append_range(vectorSizes, sourceShape.take_front(destSize)); + // SmallVector<int64_t> writeVectorSizes(inputVectorSizes); ---------------- banach-space wrote:
> Also, can we add a comment here saying that this is the case that we would be > inferring the write vector sizes from the IR? Thanks! Sure. In fact, I've re-written that comment (and the related logic) in [this commit](https://github.com/llvm/llvm-project/pull/149293/commits/b3894996379080b4a24bebe5cc8c938d42ce4243) (it wasn't clear to me how to expand it cleanly). https://github.com/llvm/llvm-project/pull/149293 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits