================
@@ -3483,6 +3488,125 @@ bool 
VectorCombine::foldInterleaveIntrinsics(Instruction &I) {
   return true;
 }
 
+// Attempt to shrink loads that are only used by shufflevector instructions.
+bool VectorCombine::shrinkLoadForShuffles(Instruction &I) {
+  auto *OldLoad = dyn_cast<LoadInst>(&I);
+  if (!OldLoad || !OldLoad->isSimple())
+    return false;
+
+  auto *OldLoadTy = dyn_cast<FixedVectorType>(OldLoad->getType());
+  if (!OldLoadTy)
+    return false;
+
+  unsigned const OldNumElements = OldLoadTy->getNumElements();
+
+  // Search all uses of load. If all uses are shufflevector instructions, and
+  // the second operands are all poison values, find the minimum and maximum
+  // indices of the vector elements referenced by all shuffle masks.
+  // Otherwise return `std::nullopt`.
+  using IndexRange = std::pair<int, int>;
+  auto GetIndexRangeInShuffles = [&]() -> std::optional<IndexRange> {
+    IndexRange OutputRange = IndexRange(OldNumElements, -1);
+    for (llvm::Use &Use : I.uses()) {
+      // Ensure all uses match the required pattern.
+      User *Shuffle = Use.getUser();
+      ArrayRef<int> Mask;
+
+      if (!match(Shuffle,
+                 m_Shuffle(m_Specific(OldLoad), m_Undef(), m_Mask(Mask))))
+        return std::nullopt;
+
+      // Ignore shufflevector instructions that have no uses.
+      if (Shuffle->use_empty())
+        continue;
+
+      // Find the min and max indices used by the shufflevector instruction.
+      for (int Index : Mask) {
+        if (Index >= 0 && Index < static_cast<int>(OldNumElements)) {
+          OutputRange.first = std::min(Index, OutputRange.first);
+          OutputRange.second = std::max(Index, OutputRange.second);
+        }
+      }
+    }
+
+    if (OutputRange.second < OutputRange.first)
+      return std::nullopt;
+
+    return OutputRange;
+  };
+
+  // Get the range of vector elements used by shufflevector instructions.
+  if (std::optional<IndexRange> Indices = GetIndexRangeInShuffles()) {
+    unsigned const NewNumElements = Indices->second + 1u;
+
+    // If the range of vector elements is smaller than the full load, attempt
+    // to create a smaller load.
+    if (NewNumElements < OldNumElements) {
+      IRBuilder Builder(&I);
+      Builder.SetCurrentDebugLocation(I.getDebugLoc());
+
+      // Calculate costs of old and new ops.
+      Type *ElemTy = OldLoadTy->getElementType();
+      FixedVectorType *NewLoadTy = FixedVectorType::get(ElemTy, 
NewNumElements);
+      Value *PtrOp = OldLoad->getPointerOperand();
+
+      InstructionCost OldCost = TTI.getMemoryOpCost(
+          Instruction::Load, OldLoad->getType(), OldLoad->getAlign(),
+          OldLoad->getPointerAddressSpace(), CostKind);
+      InstructionCost NewCost =
+          TTI.getMemoryOpCost(Instruction::Load, NewLoadTy, 
OldLoad->getAlign(),
+                              OldLoad->getPointerAddressSpace(), CostKind);
+
+      using UseEntry = std::pair<ShuffleVectorInst *, std::vector<int>>;
+      SmallVector<UseEntry, 4u> NewUses;
+      unsigned const SizeDiff = OldNumElements - NewNumElements;
+
+      for (llvm::Use &Use : I.uses()) {
+        auto *Shuffle = cast<ShuffleVectorInst>(Use.getUser());
+        ArrayRef<int> OldMask = Shuffle->getShuffleMask();
+
+        // Create entry for new use.
+        NewUses.push_back({Shuffle, {}});
+        std::vector<int> &NewMask = NewUses.back().second;
+        for (int Index : OldMask)
+          NewMask.push_back(Index >= static_cast<int>(OldNumElements)
+                                ? Index - SizeDiff
+                                : Index);
+
+        // Update costs.
+        OldCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, OldLoadTy,
+                                      OldMask, CostKind);
+        NewCost += TTI.getShuffleCost(TTI::SK_PermuteSingleSrc, NewLoadTy,
----------------
RKSimon wrote:

getShuffleCost has since been changed to take separate destination and source
types (Dst/Src), so these calls need updating - you will need to merge against trunk

https://github.com/llvm/llvm-project/pull/128938
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to