simoll updated this revision to Diff 296135.
simoll added a comment.
Herald added subscribers: llvm-commits, nikic, pengfei, hiraditya, mgorny.
Herald added a project: LLVM.

Fixed for privatized ElementCount members.
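
For reviewers, a hand-written sketch of what the expansion produces for a fixed-width
llvm.vp.sdiv (illustrative only: the value names and the i32 lane type of the EVL compare
are made up here; the pass picks the lane width via TTI, and scalable types use
llvm.get_active_lane_mask instead of the explicit step-vector compare):

  ; %r = call <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32> %x, <8 x i32> %y, <8 x i1> %m, i32 %evl)
  ; becomes roughly:
  define <8 x i32> @expanded_vp_sdiv(<8 x i32> %x, <8 x i32> %y, <8 x i1> %m, i32 %evl) {
    ; Fold %evl into the mask: lanes with index u< %evl stay active.
    %splatins = insertelement <8 x i32> undef, i32 %evl, i32 0
    %evlsplat = shufflevector <8 x i32> %splatins, <8 x i32> undef, <8 x i32> zeroinitializer
    %evlmask = icmp ult <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>, %evlsplat
    %mask = and <8 x i1> %evlmask, %m
    ; Divisions get a safe divisor (1) on masked-off lanes before the plain sdiv.
    %safediv = select <8 x i1> %mask, <8 x i32> %y, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>
    %r = sdiv <8 x i32> %x, %safediv
    ret <8 x i32> %r
  }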


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D81083/new/

https://reviews.llvm.org/D81083

Files:
  llvm/include/llvm/Analysis/TargetTransformInfo.h
  llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
  llvm/include/llvm/CodeGen/ExpandVectorPredication.h
  llvm/include/llvm/CodeGen/Passes.h
  llvm/include/llvm/IR/IntrinsicInst.h
  llvm/include/llvm/InitializePasses.h
  llvm/lib/Analysis/TargetTransformInfo.cpp
  llvm/lib/CodeGen/CMakeLists.txt
  llvm/lib/CodeGen/ExpandVectorPredication.cpp
  llvm/lib/CodeGen/TargetPassConfig.cpp
  llvm/lib/IR/IntrinsicInst.cpp
  llvm/test/CodeGen/AArch64/O0-pipeline.ll
  llvm/test/CodeGen/AArch64/O3-pipeline.ll
  llvm/test/CodeGen/ARM/O3-pipeline.ll
  llvm/test/CodeGen/Generic/expand-vp.ll
  llvm/test/CodeGen/X86/O0-pipeline.ll
  llvm/tools/llc/llc.cpp
  llvm/tools/opt/opt.cpp

Index: llvm/tools/opt/opt.cpp
===================================================================
--- llvm/tools/opt/opt.cpp
+++ llvm/tools/opt/opt.cpp
@@ -578,6 +578,7 @@
   initializePostInlineEntryExitInstrumenterPass(Registry);
   initializeUnreachableBlockElimLegacyPassPass(Registry);
   initializeExpandReductionsPass(Registry);
+  initializeExpandVectorPredicationPass(Registry);
   initializeWasmEHPreparePass(Registry);
   initializeWriteBitcodePassPass(Registry);
   initializeHardwareLoopsPass(Registry);
Index: llvm/tools/llc/llc.cpp
===================================================================
--- llvm/tools/llc/llc.cpp
+++ llvm/tools/llc/llc.cpp
@@ -318,6 +318,7 @@
   initializeVectorization(*Registry);
   initializeScalarizeMaskedMemIntrinPass(*Registry);
   initializeExpandReductionsPass(*Registry);
+  initializeExpandVectorPredicationPass(*Registry);
   initializeHardwareLoopsPass(*Registry);
   initializeTransformUtils(*Registry);
 
Index: llvm/test/CodeGen/X86/O0-pipeline.ll
===================================================================
--- llvm/test/CodeGen/X86/O0-pipeline.ll
+++ llvm/test/CodeGen/X86/O0-pipeline.ll
@@ -24,6 +24,7 @@
 ; CHECK-NEXT:       Lower constant intrinsics
 ; CHECK-NEXT:       Remove unreachable blocks from the CFG
 ; CHECK-NEXT:       Instrument function entry/exit with calls to e.g. mcount() (post inlining)
+; CHECK-NEXT:       Expand vector predication intrinsics
 ; CHECK-NEXT:       Scalarize Masked Memory Intrinsics
 ; CHECK-NEXT:       Expand reduction intrinsics
 ; CHECK-NEXT:       Expand indirectbr instructions
Index: llvm/test/CodeGen/Generic/expand-vp.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/Generic/expand-vp.ll
@@ -0,0 +1,84 @@
+; RUN: opt --expand-vec-pred -S < %s | FileCheck %s
+
+; All VP intrinsics have to be lowered into non-VP ops
+; CHECK-NOT: {{call.* @llvm.vp.add}}
+; CHECK-NOT: {{call.* @llvm.vp.sub}}
+; CHECK-NOT: {{call.* @llvm.vp.mul}}
+; CHECK-NOT: {{call.* @llvm.vp.sdiv}}
+; CHECK-NOT: {{call.* @llvm.vp.srem}}
+; CHECK-NOT: {{call.* @llvm.vp.udiv}}
+; CHECK-NOT: {{call.* @llvm.vp.urem}}
+; CHECK-NOT: {{call.* @llvm.vp.and}}
+; CHECK-NOT: {{call.* @llvm.vp.or}}
+; CHECK-NOT: {{call.* @llvm.vp.xor}}
+; CHECK-NOT: {{call.* @llvm.vp.ashr}}
+; CHECK-NOT: {{call.* @llvm.vp.lshr}}
+; CHECK-NOT: {{call.* @llvm.vp.shl}}
+
+define void @test_vp_int_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x i32> %i2, <8 x i32> %f3, <8 x i1> %m, i32 %n) {
+  %r0 = call <8 x i32> @llvm.vp.add.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r1 = call <8 x i32> @llvm.vp.sub.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r2 = call <8 x i32> @llvm.vp.mul.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r3 = call <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r4 = call <8 x i32> @llvm.vp.srem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r5 = call <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r6 = call <8 x i32> @llvm.vp.urem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r7 = call <8 x i32> @llvm.vp.and.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r8 = call <8 x i32> @llvm.vp.or.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %r9 = call <8 x i32> @llvm.vp.xor.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %rA = call <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %rB = call <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  %rC = call <8 x i32> @llvm.vp.shl.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n)
+  ret void
+}
+
+; fixed-width vectors
+; integer arith
+declare <8 x i32> @llvm.vp.add.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.sub.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.mul.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.srem.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.urem.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+; bit arith
+declare <8 x i32> @llvm.vp.and.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.xor.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.or.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+declare <8 x i32> @llvm.vp.shl.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32)
+
+define void @test_vp_int_vscale(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i32> %i2, <vscale x 4 x i32> %f3, <vscale x 4 x i1> %m, i32 %n) {
+  %r0 = call <vscale x 4 x i32> @llvm.vp.add.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+  %r1 = call <vscale x 4 x i32> @llvm.vp.sub.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+  %r2 = call <vscale x 4 x i32> @llvm.vp.mul.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+  %r3 = call <vscale x 4 x i32> @llvm.vp.sdiv.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+  %r4 = call <vscale x 4 x i32> @llvm.vp.srem.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+  %r5 = call <vscale x 4 x i32> @llvm.vp.udiv.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+  %r6 = call <vscale x 4 x i32> @llvm.vp.urem.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+  %r7 = call <vscale x 4 x i32> @llvm.vp.and.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+  %r8 = call <vscale x 4 x i32> @llvm.vp.or.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+  %r9 = call <vscale x 4 x i32> @llvm.vp.xor.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+  %rA = call <vscale x 4 x i32> @llvm.vp.ashr.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+  %rB = call <vscale x 4 x i32> @llvm.vp.lshr.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+  %rC = call <vscale x 4 x i32> @llvm.vp.shl.nxv4i32(<vscale x 4 x i32> %i0, <vscale x 4 x i32> %i1, <vscale x 4 x i1> %m, i32 %n)
+  ret void
+}
+
+; scalable-width vectors
+; integer arith
+declare <vscale x 4 x i32> @llvm.vp.add.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.sub.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.mul.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.sdiv.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.srem.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.udiv.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.urem.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+; bit arith
+declare <vscale x 4 x i32> @llvm.vp.and.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.xor.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.or.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.ashr.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.lshr.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+declare <vscale x 4 x i32> @llvm.vp.shl.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
Index: llvm/test/CodeGen/ARM/O3-pipeline.ll
===================================================================
--- llvm/test/CodeGen/ARM/O3-pipeline.ll
+++ llvm/test/CodeGen/ARM/O3-pipeline.ll
@@ -37,6 +37,7 @@
 ; CHECK-NEXT:      Constant Hoisting
 ; CHECK-NEXT:      Partially inline calls to library functions
 ; CHECK-NEXT:      Instrument function entry/exit with calls to e.g. mcount() (post inlining)
+; CHECK-NEXT:      Expand vector predication intrinsics
 ; CHECK-NEXT:      Scalarize Masked Memory Intrinsics
 ; CHECK-NEXT:      Expand reduction intrinsics
 ; CHECK-NEXT:      Dominator Tree Construction
Index: llvm/test/CodeGen/AArch64/O3-pipeline.ll
===================================================================
--- llvm/test/CodeGen/AArch64/O3-pipeline.ll
+++ llvm/test/CodeGen/AArch64/O3-pipeline.ll
@@ -57,6 +57,7 @@
 ; CHECK-NEXT:       Constant Hoisting
 ; CHECK-NEXT:       Partially inline calls to library functions
 ; CHECK-NEXT:       Instrument function entry/exit with calls to e.g. mcount() (post inlining)
+; CHECK-NEXT:       Expand vector predication intrinsics
 ; CHECK-NEXT:       Scalarize Masked Memory Intrinsics
 ; CHECK-NEXT:       Expand reduction intrinsics
 ; CHECK-NEXT:     Stack Safety Analysis
Index: llvm/test/CodeGen/AArch64/O0-pipeline.ll
===================================================================
--- llvm/test/CodeGen/AArch64/O0-pipeline.ll
+++ llvm/test/CodeGen/AArch64/O0-pipeline.ll
@@ -22,6 +22,7 @@
 ; CHECK-NEXT:       Lower constant intrinsics
 ; CHECK-NEXT:       Remove unreachable blocks from the CFG
 ; CHECK-NEXT:       Instrument function entry/exit with calls to e.g. mcount() (post inlining)
+; CHECK-NEXT:       Expand vector predication intrinsics
 ; CHECK-NEXT:       Scalarize Masked Memory Intrinsics
 ; CHECK-NEXT:       Expand reduction intrinsics
 ; CHECK-NEXT:       AArch64 Stack Tagging
Index: llvm/lib/IR/IntrinsicInst.cpp
===================================================================
--- llvm/lib/IR/IntrinsicInst.cpp
+++ llvm/lib/IR/IntrinsicInst.cpp
@@ -196,6 +196,12 @@
   return nullptr;
 }
 
+void VPIntrinsic::setMaskParam(Value *NewMask) {
+  auto MaskPos = GetMaskParamPos(getIntrinsicID());
+  assert(MaskPos.hasValue());
+  setArgOperand(MaskPos.getValue(), NewMask);
+}
+
 Value *VPIntrinsic::getVectorLengthParam() const {
   auto vlenPos = GetVectorLengthParamPos(getIntrinsicID());
   if (vlenPos)
@@ -203,6 +209,12 @@
   return nullptr;
 }
 
+void VPIntrinsic::setVectorLengthParam(Value *NewEVL) {
+  auto EVLPos = GetVectorLengthParamPos(getIntrinsicID());
+  assert(EVLPos.hasValue());
+  setArgOperand(EVLPos.getValue(), NewEVL);
+}
+
 Optional<int> VPIntrinsic::GetMaskParamPos(Intrinsic::ID IntrinsicID) {
   switch (IntrinsicID) {
   default:
Index: llvm/lib/CodeGen/TargetPassConfig.cpp
===================================================================
--- llvm/lib/CodeGen/TargetPassConfig.cpp
+++ llvm/lib/CodeGen/TargetPassConfig.cpp
@@ -702,6 +702,11 @@
   // Instrument function entry and exit, e.g. with calls to mcount().
   addPass(createPostInlineEntryExitInstrumenterPass());
 
+  // Expand vector predication intrinsics into standard IR instructions.
+  // This pass has to run before ScalarizeMaskedMemIntrin and ExpandReduction
+  // passes since it emits those kinds of intrinsics.
+  addPass(createExpandVectorPredicationPass());
+
   // Add scalarization of target's unsupported masked memory intrinsics pass.
   // the unsupported intrinsic will be replaced with a chain of basic blocks,
   // that stores/loads element one-by-one if the appropriate mask bit is set.
Index: llvm/lib/CodeGen/ExpandVectorPredication.cpp
===================================================================
--- /dev/null
+++ llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -0,0 +1,456 @@
+//===--- CodeGen/ExpandVectorPredication.cpp - Expand VP intrinsics -===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass implements IR expansion for vector predication intrinsics, allowing
+// targets to enable vector predication until just before codegen.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/CodeGen/ExpandVectorPredication.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace llvm;
+
+using VPLegalization = TargetTransformInfo::VPLegalization;
+
+#define DEBUG_TYPE "expand-vec-pred"
+
+STATISTIC(NumFoldedVL, "Number of folded vector length params");
+STATISTIC(NumLoweredVPOps, "Number of lowered vector predication operations");
+
+///// Helpers {
+
+/// \returns Whether the vector mask \p MaskVal has all lane bits set.
+static bool isAllTrueMask(Value *MaskVal) {
+  auto *ConstVec = dyn_cast<ConstantVector>(MaskVal);
+  if (!ConstVec)
+    return false;
+  return ConstVec->isAllOnesValue();
+}
+
+/// Computes the smallest integer bit width to hold the step vector <0, ..,
+/// NumVectorElements - 1>
+static unsigned getLeastLaneBitsForStepVector(unsigned NumVectorElements) {
+  unsigned MostSignificantOne =
+      llvm::countLeadingZeros<uint64_t>(NumVectorElements, ZB_Undefined);
+  return std::max<unsigned>(IntegerType::MIN_INT_BITS, 64 - MostSignificantOne);
+}
+
+/// \returns A non-excepting divisor constant for this type.
+static Constant *getSafeDivisor(Type *DivTy) {
+  assert(DivTy->isIntOrIntVectorTy());
+  return ConstantInt::get(DivTy, 1u, false);
+}
+
+/// Transfer operation properties from \p VPI to \p NewVal.
+static void transferDecorations(Value &NewVal, VPIntrinsic &VPI) {
+  auto *NewInst = dyn_cast<Instruction>(&NewVal);
+  if (!NewInst || !isa<FPMathOperator>(NewVal))
+    return;
+
+  auto *OldFMOp = dyn_cast<FPMathOperator>(&VPI);
+  if (!OldFMOp)
+    return;
+
+  NewInst->setFastMathFlags(OldFMOp->getFastMathFlags());
+}
+
+/// Transfer all properties from \p OldOp to \p NewOp and replace all uses.
+/// \p OldOp gets erased.
+static void replaceOperation(Value &NewOp, VPIntrinsic &OldOp) {
+  transferDecorations(NewOp, OldOp);
+  OldOp.replaceAllUsesWith(&NewOp);
+  OldOp.eraseFromParent();
+}
+
+//// } Helpers
+
+namespace {
+
+// Expansion pass state at function scope.
+struct CachingVPExpander {
+  Function &F;
+  const TargetTransformInfo &TTI;
+
+  /// \returns A (fixed length) vector with ascending integer indices
+  /// (<0, 1, ..., NumElems-1>).
+  Value *createStepVector(IRBuilder<> &Builder, int32_t ElemBits,
+                          int32_t NumElems);
+
+  /// \returns A bitmask that is true where the lane position is less-than \p
+  /// EVLParam
+  ///
+  /// \p Builder
+  ///    Used for instruction creation.
+  /// \p EVLParam
+  ///    The explicit vector length parameter to test against the lane
+  ///    positions.
+  /// \p ElemCount
+  ///    Static (potentially scalable) number of vector elements
+  Value *convertEVLToMask(IRBuilder<> &Builder, Value *EVLParam,
+                          ElementCount ElemCount);
+
+  Value *foldEVLIntoMask(VPIntrinsic &VPI);
+
+  /// "Remove" the %evl parameter of \p PI by setting it to the static vector
+  /// length of the operation.
+  void discardEVLParameter(VPIntrinsic &PI);
+
+  /// \brief Lower this VP binary operator to a non-VP binary operator.
+  Value *expandPredicationInBinaryOperator(IRBuilder<> &Builder,
+                                           VPIntrinsic &PI);
+
+  /// \brief Query TTI and expand the vector predication in \p PI accordingly.
+  Value *expandPredication(VPIntrinsic &PI);
+
+  // Cache of getLaneBitsForEVLCompare results, keyed by static vector length.
+  std::map<unsigned, unsigned> StaticVLToBitsCache; // TODO 'SmallMap' class
+
+  /// \returns A good (for fast icmp) integer bit width in which to expand
+  /// the EVL comparison against the step vector.
+  unsigned getLaneBitsForEVLCompare(unsigned StaticVL);
+
+public:
+  CachingVPExpander(Function &F, const TargetTransformInfo &TTI)
+      : F(F), TTI(TTI) {}
+
+  // expand VP ops in \p F according to \p TTI.
+  bool expandVectorPredication();
+};
+
+//// CachingVPExpander {
+
+unsigned CachingVPExpander::getLaneBitsForEVLCompare(unsigned StaticVL) {
+  auto ItCached = StaticVLToBitsCache.find(StaticVL);
+  if (ItCached != StaticVLToBitsCache.end())
+    return ItCached->second;
+
+  // The smallest integer bit width that can hold the step vector <0, .., StaticVL - 1>.
+  // Choosing fewer bits than this would make the expansion invalid.
+  unsigned MinLaneBits = getLeastLaneBitsForStepVector(StaticVL);
+  LLVM_DEBUG(dbgs() << "Least lane bits for " << StaticVL << " is "
+                    << MinLaneBits << "\n";);
+
+  // If the EVL compare will be expanded into scalar code, choose the
+  // smallest integer type.
+  if (TTI.getRegisterBitWidth(/* Vector */ true) == 0)
+    return MinLaneBits;
+
+  // Otherwise, the generated vector operation will likely map to vector
+  // instructions.
+  // The largest bit width to fit the EVL expansion in one vector register.
+  unsigned MaxLaneBits = std::min<unsigned>(
+      IntegerType::MAX_INT_BITS, TTI.getRegisterBitWidth(true) / StaticVL);
+
+  // Many SIMD instructions are restricted in their supported lane bit widths.
+  // We choose the bit width that gives us the cheapest vector compare.
+  int Cheapest = std::numeric_limits<int>::max();
+  auto &Ctx = F.getContext();
+  unsigned CheapestLaneBits = MinLaneBits;
+  for (auto LaneBits = MinLaneBits; LaneBits < MaxLaneBits; ++LaneBits) {
+    int VecCmpCost = TTI.getCmpSelInstrCost(
+        Instruction::ICmp, VectorType::get(Type::getIntNTy(Ctx, LaneBits),
+                                           StaticVL, /* Scalable */ false));
+    if (VecCmpCost < Cheapest) {
+      Cheapest = VecCmpCost;
+      CheapestLaneBits = LaneBits;
+    }
+  }
+
+  StaticVLToBitsCache[StaticVL] = CheapestLaneBits;
+  return CheapestLaneBits;
+}
+
+Value *CachingVPExpander::createStepVector(IRBuilder<> &Builder,
+                                           int32_t ElemBits, int32_t NumElems) {
+  // TODO add caching
+  SmallVector<Constant *, 16> ConstElems;
+
+  Type *LaneTy = Builder.getIntNTy(ElemBits);
+
+  for (int32_t Idx = 0; Idx < NumElems; ++Idx) {
+    ConstElems.push_back(ConstantInt::get(LaneTy, Idx, false));
+  }
+
+  return ConstantVector::get(ConstElems);
+}
+
+Value *CachingVPExpander::convertEVLToMask(IRBuilder<> &Builder,
+                                           Value *EVLParam,
+                                           ElementCount ElemCount) {
+  // TODO add caching
+  if (ElemCount.isScalable()) {
+    auto *M = Builder.GetInsertBlock()->getModule();
+    auto *BoolVecTy = VectorType::get(Builder.getInt1Ty(), ElemCount);
+    auto *ActiveMaskFunc = Intrinsic::getDeclaration(
+        M, Intrinsic::get_active_lane_mask, {BoolVecTy, EVLParam->getType()});
+    // `get_active_lane_mask` performs an implicit less-than comparison.
+    auto *ConstZero = Builder.getInt32(0);
+    return Builder.CreateCall(ActiveMaskFunc, {ConstZero, EVLParam});
+  }
+
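+  // Fixed-width vectors: compare the step vector <0, 1, .., NumElems - 1>
+  // against the splatted EVL; lanes with an index below EVL become active.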
+  unsigned NumElems = ElemCount.getFixedValue();
+  unsigned ElemBits = getLaneBitsForEVLCompare(NumElems);
+
+  Type *LaneTy = Builder.getIntNTy(ElemBits);
+
+  auto *ExtVLParam = Builder.CreateZExtOrTrunc(EVLParam, LaneTy);
+  auto *VLSplat = Builder.CreateVectorSplat(NumElems, ExtVLParam);
+
+  auto *IdxVec = createStepVector(Builder, ElemBits, NumElems);
+
+  return Builder.CreateICmp(CmpInst::ICMP_ULT, IdxVec, VLSplat);
+}
+
+Value *
+CachingVPExpander::expandPredicationInBinaryOperator(IRBuilder<> &Builder,
+                                                     VPIntrinsic &VPI) {
+  assert(VPI.canIgnoreVectorLengthParam());
+
+  auto OC = static_cast<Instruction::BinaryOps>(VPI.getFunctionalOpcode());
+  assert(Instruction::isBinaryOp(OC));
+
+  auto *FirstOp = VPI.getOperand(0);
+  auto *SndOp = VPI.getOperand(1);
+
+  auto *Mask = VPI.getMaskParam();
+
+  // Blend in safe operands
+  if (Mask && !isAllTrueMask(Mask)) {
+    switch (OC) {
+    default:
+      // The predicate can safely be ignored for this opcode.
+      break;
+
+    // Division operators need a safe divisor on masked-off lanes (we use 1).
+    case Instruction::UDiv:
+    case Instruction::SDiv:
+    case Instruction::URem:
+    case Instruction::SRem:
+      // The second operand (the divisor) must not be zero.
+      auto *SafeDivisor = getSafeDivisor(VPI.getType());
+      SndOp = Builder.CreateSelect(Mask, SndOp, SafeDivisor);
+    }
+  }
+
+  auto *NewBinOp = Builder.CreateBinOp(OC, FirstOp, SndOp, VPI.getName());
+
+  replaceOperation(*NewBinOp, VPI);
+  return NewBinOp;
+}
+
+void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
+  LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n");
+
+  if (VPI.canIgnoreVectorLengthParam())
+    return;
+
+  Value *EVLParam = VPI.getVectorLengthParam();
+  if (!EVLParam)
+    return;
+
+  ElementCount StaticElemCount = VPI.getStaticVectorLength();
+  Value *MaxEVL = nullptr;
+  auto *Int32Ty = Type::getInt32Ty(VPI.getContext());
+  if (StaticElemCount.isScalable()) {
+    // TODO add caching
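+    // For scalable vectors, the static element count is vscale * KnownMin.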
+    auto *M = VPI.getModule();
+    auto *VScaleFunc = Intrinsic::getDeclaration(M, Intrinsic::vscale, Int32Ty);
+    IRBuilder<> Builder(VPI.getParent(), VPI.getIterator());
+    auto *FactorConst = Builder.getInt32(StaticElemCount.getKnownMinValue());
+    auto *VScale = Builder.CreateCall(VScaleFunc, {}, "vscale");
+    MaxEVL = Builder.CreateMul(VScale, FactorConst, "scalable_size",
+                               /*NUW*/ true, /*NSW*/ false);
+  } else {
+    MaxEVL = ConstantInt::get(Int32Ty, StaticElemCount.getFixedValue(), false);
+  }
+  VPI.setVectorLengthParam(MaxEVL);
+}
+
+Value *CachingVPExpander::foldEVLIntoMask(VPIntrinsic &VPI) {
+  LLVM_DEBUG(dbgs() << "Folding vlen for " << VPI << '\n');
+
+  IRBuilder<> Builder(&VPI);
+
+  // No %evl parameter and so nothing to do here
+  if (VPI.canIgnoreVectorLengthParam()) {
+    return &VPI;
+  }
+
+  // Only VP intrinsics can have a %evl parameter
+  Value *OldMaskParam = VPI.getMaskParam();
+  Value *OldEVLParam = VPI.getVectorLengthParam();
+  assert(OldMaskParam && "no mask param to fold the vl param into");
+  assert(OldEVLParam && "no EVL param to fold away");
+
+  LLVM_DEBUG(dbgs() << "OLD evl: " << *OldEVLParam << '\n');
+  LLVM_DEBUG(dbgs() << "OLD mask: " << *OldMaskParam << '\n');
+
+  // Convert the %evl predication into vector mask predication.
+  ElementCount ElemCount = VPI.getStaticVectorLength();
+  auto *VLMask = convertEVLToMask(Builder, OldEVLParam, ElemCount);
+  auto *NewMaskParam = Builder.CreateAnd(VLMask, OldMaskParam);
+  VPI.setMaskParam(NewMaskParam);
+
+  // Drop the EVL parameter.
+  discardEVLParameter(VPI);
+  assert(VPI.canIgnoreVectorLengthParam() &&
+         "transformation did not render the evl param ineffective!");
+
+  // Reassess the modified instruction.
+  return &VPI;
+}
+
+Value *CachingVPExpander::expandPredication(VPIntrinsic &VPI) {
+  LLVM_DEBUG(dbgs() << "Lowering to unpredicated op: " << VPI << '\n');
+
+  IRBuilder<> Builder(&VPI);
+
+  // Try lowering to an LLVM instruction first.
+  unsigned OC = VPI.getFunctionalOpcode();
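+  // Pull the first/last binary opcode values out of Instruction.def so the
+  // range check below recognizes all binary operators.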
+#define FIRST_BINARY_INST(X) unsigned FirstBinOp = X;
+#define LAST_BINARY_INST(X) unsigned LastBinOp = X;
+#include "llvm/IR/Instruction.def"
+
+  if (FirstBinOp <= OC && OC <= LastBinOp) {
+    return expandPredicationInBinaryOperator(Builder, VPI);
+  }
+
+  return &VPI;
+}
+
+//// } CachingVPExpander
+
+struct TransformJob {
+  VPIntrinsic *PI;
+  TargetTransformInfo::VPLegalization Strategy;
+  TransformJob(VPIntrinsic *PI, TargetTransformInfo::VPLegalization InitStrat)
+      : PI(PI), Strategy(InitStrat) {}
+
+  bool isDone() const { return Strategy.doNothing(); }
+};
+
+void sanitizeStrategy(Instruction &I, VPLegalization &LegalizeStrat) {
+  // Speculatable instructions do not strictly need predication.
+  if (isSafeToSpeculativelyExecute(&I))
+    return;
+
+  // Preserve the predication effect of the EVL parameter by folding
+  // it into the predicate.
+  if (LegalizeStrat.EVLParamStrategy == VPLegalization::Discard) {
+    LegalizeStrat.EVLParamStrategy = VPLegalization::Convert;
+  }
+}
+
+/// \brief Expand llvm.vp.* intrinsics as requested by \p TTI.
+bool CachingVPExpander::expandVectorPredication() {
+  // Holds all vector-predicated ops whose EVL param and/or operator needs to
+  // be legalized.
+  SmallVector<TransformJob, 16> Worklist;
+
+  for (auto &I : instructions(F)) {
+    auto *VPI = dyn_cast<VPIntrinsic>(&I);
+    if (!VPI)
+      continue;
+    auto VPStrat = TTI.getVPLegalizationStrategy(*VPI);
+    sanitizeStrategy(I, VPStrat);
+    if (!VPStrat.doNothing()) {
+      Worklist.emplace_back(VPI, VPStrat);
+    }
+  }
+  if (Worklist.empty())
+    return false;
+
+  LLVM_DEBUG(dbgs() << "\n:::: Transforming instructions. ::::\n");
+  for (TransformJob Job : Worklist) {
+    // Transform the EVL parameter
+    switch (Job.Strategy.EVLParamStrategy) {
+    case VPLegalization::Legal:
+      break;
+    case VPLegalization::Discard: {
+      discardEVLParameter(*Job.PI);
+    } break;
+    case VPLegalization::Convert: {
+      if (foldEVLIntoMask(*Job.PI)) {
+        ++NumFoldedVL;
+      }
+    } break;
+    }
+    Job.Strategy.EVLParamStrategy = VPLegalization::Legal;
+
+    // Replace the operator
+    switch (Job.Strategy.OpStrategy) {
+    case VPLegalization::Legal:
+      break;
+    case VPLegalization::Discard:
+      llvm_unreachable("Invalid strategy for operators.");
+    case VPLegalization::Convert: {
+      expandPredication(*Job.PI);
+      ++NumLoweredVPOps;
+    } break;
+    }
+    Job.Strategy.OpStrategy = VPLegalization::Legal;
+
+    assert(Job.isDone() && "incomplete transformation");
+  }
+
+  return true;
+}
+
+class ExpandVectorPredication : public FunctionPass {
+public:
+  static char ID;
+  ExpandVectorPredication() : FunctionPass(ID) {
+    initializeExpandVectorPredicationPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+    CachingVPExpander VPExpander(F, *TTI);
+    return VPExpander.expandVectorPredication();
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+    AU.setPreservesCFG();
+  }
+};
+} // namespace
+
+char ExpandVectorPredication::ID;
+INITIALIZE_PASS_BEGIN(ExpandVectorPredication, "expand-vec-pred",
+                      "Expand vector predication intrinsics", false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_END(ExpandVectorPredication, "expand-vec-pred",
+                    "Expand vector predication intrinsics", false, false)
+
+FunctionPass *llvm::createExpandVectorPredicationPass() {
+  return new ExpandVectorPredication();
+}
+
+PreservedAnalyses
+ExpandVectorPredicationPass::run(Function &F, FunctionAnalysisManager &AM) {
+  const auto &TTI = AM.getResult<TargetIRAnalysis>(F);
+  CachingVPExpander VPExpander(F, TTI);
+  if (!VPExpander.expandVectorPredication())
+    return PreservedAnalyses::all();
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  return PA;
+}
Index: llvm/lib/CodeGen/CMakeLists.txt
===================================================================
--- llvm/lib/CodeGen/CMakeLists.txt
+++ llvm/lib/CodeGen/CMakeLists.txt
@@ -27,6 +27,7 @@
   ExpandMemCmp.cpp
   ExpandPostRAPseudos.cpp
   ExpandReductions.cpp
+  ExpandVectorPredication.cpp
   FaultMaps.cpp
   FEntryInserter.cpp
   FinalizeISel.cpp
Index: llvm/lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Analysis/TargetTransformInfo.cpp
+++ llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1024,6 +1024,11 @@
   return TTIImpl->preferPredicatedReductionSelect(Opcode, Ty, Flags);
 }
 
+TargetTransformInfo::VPLegalization
+TargetTransformInfo::getVPLegalizationStrategy(const VPIntrinsic &VPI) const {
+  return TTIImpl->getVPLegalizationStrategy(VPI);
+}
+
 bool TargetTransformInfo::shouldExpandReduction(const IntrinsicInst *II) const {
   return TTIImpl->shouldExpandReduction(II);
 }
Index: llvm/include/llvm/InitializePasses.h
===================================================================
--- llvm/include/llvm/InitializePasses.h
+++ llvm/include/llvm/InitializePasses.h
@@ -150,6 +150,7 @@
 void initializeExpandMemCmpPassPass(PassRegistry&);
 void initializeExpandPostRAPass(PassRegistry&);
 void initializeExpandReductionsPass(PassRegistry&);
+void initializeExpandVectorPredicationPass(PassRegistry&);
 void initializeMakeGuardsExplicitLegacyPassPass(PassRegistry&);
 void initializeExternalAAWrapperPassPass(PassRegistry&);
 void initializeFEntryInserterPass(PassRegistry&);
Index: llvm/include/llvm/IR/IntrinsicInst.h
===================================================================
--- llvm/include/llvm/IR/IntrinsicInst.h
+++ llvm/include/llvm/IR/IntrinsicInst.h
@@ -255,9 +255,11 @@
 
   /// \return the mask parameter or nullptr.
   Value *getMaskParam() const;
+  void setMaskParam(Value *);
 
   /// \return the vector length parameter or nullptr.
   Value *getVectorLengthParam() const;
+  void setVectorLengthParam(Value *);
 
   /// \return whether the vector length param can be ignored.
   bool canIgnoreVectorLengthParam() const;
Index: llvm/include/llvm/CodeGen/Passes.h
===================================================================
--- llvm/include/llvm/CodeGen/Passes.h
+++ llvm/include/llvm/CodeGen/Passes.h
@@ -456,6 +456,11 @@
   /// shuffles.
   FunctionPass *createExpandReductionsPass();
 
+  /// This pass expands vector predication intrinsics into unpredicated
+  /// instructions (using selects where needed), or folds just the explicit
+  /// vector length into the predicate mask.
+  FunctionPass *createExpandVectorPredicationPass();
+
   // This pass expands memcmp() to load/stores.
   FunctionPass *createExpandMemCmpPass();
 
Index: llvm/include/llvm/CodeGen/ExpandVectorPredication.h
===================================================================
--- /dev/null
+++ llvm/include/llvm/CodeGen/ExpandVectorPredication.h
@@ -0,0 +1,23 @@
+//===-- ExpandVectorPredication.h - Expand vector predication ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_EXPANDVECTORPREDICATION_H
+#define LLVM_CODEGEN_EXPANDVECTORPREDICATION_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class ExpandVectorPredicationPass
+    : public PassInfoMixin<ExpandVectorPredicationPass> {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+} // end namespace llvm
+
+#endif // LLVM_CODEGEN_EXPANDVECTORPREDICATION_H
Index: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -677,6 +677,13 @@
 
   bool hasActiveVectorLength() const { return false; }
 
+  TargetTransformInfo::VPLegalization
+  getVPLegalizationStrategy(const VPIntrinsic &PI) const {
+    return TargetTransformInfo::VPLegalization(
+        /* EVLParamStrategy */ TargetTransformInfo::VPLegalization::Discard,
+        /* OperatorStrategy */ TargetTransformInfo::VPLegalization::Convert);
+  }
+
 protected:
   // Obtain the minimum required size to hold the value (without the sign)
   // In case of a vector it returns the min required size for one element.
Index: llvm/include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -57,6 +57,7 @@
 class Type;
 class User;
 class Value;
+class VPIntrinsic;
 struct KnownBits;
 template <typename T> class Optional;
 
@@ -1326,6 +1327,38 @@
   /// Intrinsics") Use of %evl is discouraged when that is not the case.
   bool hasActiveVectorLength() const;
 
+  struct VPLegalization {
+    enum VPTransform {
+      // keep the predicating parameter
+      Legal = 0,
+      // where legal, discard the predicate parameter
+      Discard = 1,
+      // transform into something else that is also predicating
+      Convert = 2
+    };
+
+    // How to transform the EVL parameter.
+    // Legal:   keep the EVL parameter as it is.
+    // Discard: Ignore the EVL parameter where it is safe to do so.
+    // Convert: Fold the EVL into the mask parameter.
+    VPTransform EVLParamStrategy;
+
+    // How to transform the operator.
+    // Legal:   The target supports this operator.
+    // Convert: Convert this to a non-VP operation.
+    // The 'Discard' strategy is invalid.
+    VPTransform OpStrategy;
+
+    bool doNothing() const {
+      return (EVLParamStrategy == Legal) && (OpStrategy == Legal);
+    }
+    VPLegalization(VPTransform EVLParamStrategy, VPTransform OpStrategy)
+        : EVLParamStrategy(EVLParamStrategy), OpStrategy(OpStrategy) {}
+  };
+
+  /// \returns How the target needs this vector-predicated operation to be
+  /// transformed.
+  VPLegalization getVPLegalizationStrategy(const VPIntrinsic &PI) const;
   /// @}
 
   /// @}
@@ -1609,6 +1642,8 @@
   virtual bool shouldExpandReduction(const IntrinsicInst *II) const = 0;
   virtual unsigned getGISelRematGlobalCost() const = 0;
   virtual bool hasActiveVectorLength() const = 0;
+  virtual VPLegalization
+  getVPLegalizationStrategy(const VPIntrinsic &PI) const = 0;
   virtual int getInstructionLatency(const Instruction *I) = 0;
 };
 
@@ -2127,6 +2162,11 @@
     return Impl.hasActiveVectorLength();
   }
 
+  VPLegalization
+  getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
+    return Impl.getVPLegalizationStrategy(PI);
+  }
+
   int getInstructionLatency(const Instruction *I) override {
     return Impl.getInstructionLatency(I);
   }