CaprYang updated this revision to Diff 520296.
CaprYang removed reviewers: bollu, ldionne, nicolasvasilache, rafauler, Amir, 
maksfb, NoQ, njames93, libc++, libc++abi, libunwind, rymiel, 
HazardyKnusperkeks, owenpan, MyDeveloperDay.
CaprYang removed projects: clang-format, Flang, clang-tools-extra, MLIR, 
libunwind, libc++abi, libc-project, OpenMP, libc++, LLDB, Sanitizers, clang.

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D150043/new/

https://reviews.llvm.org/D150043

Files:
  llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
  llvm/test/Transforms/InferAddressSpaces/AMDGPU/icmp.ll
  llvm/test/Transforms/InferAddressSpaces/masked-gather-scatter.ll

Index: llvm/test/Transforms/InferAddressSpaces/masked-gather-scatter.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/InferAddressSpaces/masked-gather-scatter.ll
@@ -0,0 +1,27 @@
+; RUN: opt -S -passes=infer-address-spaces -assume-default-is-flat-addrspace %s | FileCheck %s
+
+; CHECK-LABEL: @masked_gather_inferas(
+; CHECK: [[PTRS:%.*]] = getelementptr inbounds i32, ptr addrspace(1) %out, <4 x i64> %index
+; CHECK: tail call <4 x i32> @llvm.masked.gather.v4i32.v4p1(<4 x ptr addrspace(1)> [[PTRS]],
+define <4 x i32> @masked_gather_inferas(ptr addrspace(1) %out, <4 x i64> %index) {
+entry:
+  %out.1 = addrspacecast ptr addrspace(1) %out to ptr
+  %ptrs = getelementptr inbounds i32, ptr %out.1, <4 x i64> %index
+  %value = tail call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %ptrs, i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i32> poison)
+  ret <4 x i32> %value
+}
+
+; CHECK-LABEL: @masked_scatter_inferas(
+; CHECK: [[PTRS:%.*]] = getelementptr inbounds i32, ptr addrspace(1) %out, <4 x i64> %index
+; CHECK: tail call void @llvm.masked.scatter.v4i32.v4p1(<4 x i32> %value, <4 x ptr addrspace(1)> [[PTRS]],
+define void @masked_scatter_inferas(ptr addrspace(1) %out, <4 x i64> %index, <4 x i32> %value) {
+entry:
+  %out.1 = addrspacecast ptr addrspace(1) %out to ptr
+  %ptrs = getelementptr inbounds i32, ptr %out.1, <4 x i64> %index
+  tail call void @llvm.masked.scatter.v4i32.v4p0(<4 x i32> %value, <4 x ptr> %ptrs, i32 4, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
+  ret void
+}
+
+declare <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr>, i32 immarg, <4 x i1>, <4 x i32>)
+
+declare void @llvm.masked.scatter.v4i32.v4p0(<4 x i32>, <4 x ptr>, i32 immarg, <4 x i1>)
Index: llvm/test/Transforms/InferAddressSpaces/AMDGPU/icmp.ll
===================================================================
--- llvm/test/Transforms/InferAddressSpaces/AMDGPU/icmp.ll
+++ llvm/test/Transforms/InferAddressSpaces/AMDGPU/icmp.ll
@@ -147,9 +147,8 @@
   ret i1 %cmp
 }
 
-; TODO: Should be handled
 ; CHECK-LABEL: @icmp_flat_flat_from_group_vector(
-; CHECK: %cmp = icmp eq <2 x ptr> %cast0, %cast1
+; CHECK: %cmp = icmp eq <2 x ptr addrspace(3)> %group.ptr.0, %group.ptr.1
 define <2 x i1> @icmp_flat_flat_from_group_vector(<2 x ptr addrspace(3)> %group.ptr.0, <2 x ptr addrspace(3)> %group.ptr.1) #0 {
   %cast0 = addrspacecast <2 x ptr addrspace(3)> %group.ptr.0 to <2 x ptr>
   %cast1 = addrspacecast <2 x ptr addrspace(3)> %group.ptr.1 to <2 x ptr>
Index: llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -256,6 +256,39 @@
 INITIALIZE_PASS_END(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces",
                     false, false)
 
+// Returns the address space of \p Ty, which must be a pointer or a vector of
+// pointers. Type::getPointerAddressSpace() already looks through vectors via
+// getScalarType(), so no manual unwrapping is needed.
+static unsigned getPtrOrVecOfPtrsAddressSpace(Type *Ty) {
+  assert(Ty->isPtrOrPtrVectorTy());
+  return Ty->getPointerAddressSpace();
+}
+
+// Returns true if \p Ty is a pointer or a vector of pointers.
+static bool isPtrOrVecOfPtrsType(Type *Ty) {
+  return Ty->isPtrOrPtrVectorTy();
+}
+
+// Returns \p Ty with its pointer type (or, for a vector, its pointer element
+// type) moved to \p NewAddrSpace, keeping the pointee type and, for vectors,
+// the element count.
+static Type *getPtrOrVecOfPtrsWithNewAS(Type *Ty, unsigned NewAddrSpace) {
+  assert(Ty->isPtrOrPtrVectorTy());
+  PointerType *NPT = PointerType::getWithSamePointeeType(
+      cast<PointerType>(Ty->getScalarType()), NewAddrSpace);
+  // For a non-vector Ty, getWithNewType simply returns NPT.
+  return Ty->getWithNewType(NPT);
+}
+
+// Returns true if \p Ty1 and \p Ty2 (both pointers, or both vectors of
+// pointers) have pointer element types with the same pointee type.
+static bool hasSameElementOfPtrOrVecPtrs(Type *Ty1, Type *Ty2) {
+  assert(Ty1->isPtrOrPtrVectorTy() && Ty2->isPtrOrPtrVectorTy());
+  assert(Ty1->isVectorTy() == Ty2->isVectorTy());
+  return cast<PointerType>(Ty1->getScalarType())
+      ->hasSameElementTypeAs(cast<PointerType>(Ty2->getScalarType()));
+}
+
 // Check whether that's no-op pointer bicast using a pair of
 // `ptrtoint`/`inttoptr` due to the missing no-op pointer bitcast over
 // different address spaces.
@@ -279,8 +321,9 @@
   // arithmetic may also be undefined after invalid pointer reinterpret cast.
   // However, as we confirm through the target hooks that it's a no-op
   // addrspacecast, it doesn't matter since the bits should be the same.
-  unsigned P2IOp0AS = P2I->getOperand(0)->getType()->getPointerAddressSpace();
-  unsigned I2PAS = I2P->getType()->getPointerAddressSpace();
+  unsigned P2IOp0AS =
+      getPtrOrVecOfPtrsAddressSpace(P2I->getOperand(0)->getType());
+  unsigned I2PAS = getPtrOrVecOfPtrsAddressSpace(I2P->getType());
   return CastInst::isNoopCast(Instruction::CastOps(I2P->getOpcode()),
                               I2P->getOperand(0)->getType(), I2P->getType(),
                               DL) &&
@@ -301,14 +344,14 @@
 
   switch (Op->getOpcode()) {
   case Instruction::PHI:
-    assert(Op->getType()->isPointerTy());
+    assert(isPtrOrVecOfPtrsType(Op->getType()));
     return true;
   case Instruction::BitCast:
   case Instruction::AddrSpaceCast:
   case Instruction::GetElementPtr:
     return true;
   case Instruction::Select:
-    return Op->getType()->isPointerTy();
+    return isPtrOrVecOfPtrsType(Op->getType());
   case Instruction::Call: {
     const IntrinsicInst *II = dyn_cast<IntrinsicInst>(&V);
     return II && II->getIntrinsicID() == Intrinsic::ptrmask;
@@ -373,6 +416,33 @@
   case Intrinsic::ptrmask:
     // This is handled as an address expression, not as a use memory operation.
     return false;
+  case Intrinsic::masked_gather: {
+    // Only the address vector (operand 0) may be replaced. OldV can also be
+    // the pass-through operand of a gather of pointers; returning false leaves
+    // that use to the caller's generic replacement path.
+    if (II->getArgOperand(0) != OldV)
+      return false;
+    Type *RetTy = II->getType();
+    Type *NewPtrTy = NewV->getType();
+    Function *NewDecl =
+        Intrinsic::getDeclaration(M, II->getIntrinsicID(), {RetTy, NewPtrTy});
+    II->setArgOperand(0, NewV);
+    II->setCalledFunction(NewDecl);
+    return true;
+  }
+  case Intrinsic::masked_scatter: {
+    // Only the address vector (operand 1) may be replaced. OldV can also be
+    // the stored value (operand 0) of a scatter of pointers.
+    if (II->getArgOperand(1) != OldV)
+      return false;
+    Type *ValueTy = II->getOperand(0)->getType();
+    Type *NewPtrTy = NewV->getType();
+    Function *NewDecl =
+        Intrinsic::getDeclaration(M, II->getIntrinsicID(), {ValueTy, NewPtrTy});
+    II->setArgOperand(1, NewV);
+    II->setCalledFunction(NewDecl);
+    return true;
+  }
   default: {
     Value *Rewrite = TTI->rewriteIntrinsicWithAddressSpace(II, OldV, NewV);
     if (!Rewrite)
@@ -394,6 +455,14 @@
     appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0),
                                                  PostorderStack, Visited);
     break;
+  case Intrinsic::masked_gather:
+    appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0),
+                                                 PostorderStack, Visited);
+    break;
+  case Intrinsic::masked_scatter:
+    appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(1),
+                                                 PostorderStack, Visited);
+    break;
   default:
     SmallVector<int, 2> OpIndexes;
     if (TTI->collectFlatAddressOperands(OpIndexes, IID)) {
@@ -412,7 +481,7 @@
 void InferAddressSpacesImpl::appendsFlatAddressExpressionToPostorderStack(
     Value *V, PostorderStackTy &PostorderStack,
     DenseSet<Value *> &Visited) const {
-  assert(V->getType()->isPointerTy());
+  assert(isPtrOrVecOfPtrsType(V->getType()));
 
   // Generic addressing expressions may be hidden in nested constant
   // expressions.
@@ -424,7 +493,7 @@
     return;
   }
 
-  if (V->getType()->getPointerAddressSpace() == FlatAddrSpace &&
+  if (getPtrOrVecOfPtrsAddressSpace(V->getType()) == FlatAddrSpace &&
       isAddressExpression(*V, *DL, TTI)) {
     if (Visited.insert(V).second) {
       PostorderStack.emplace_back(V, false);
@@ -460,8 +529,7 @@
   // addressing calculations may also be faster.
   for (Instruction &I : instructions(F)) {
     if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
-      if (!GEP->getType()->isVectorTy())
-        PushPtrOperand(GEP->getPointerOperand());
+      PushPtrOperand(GEP->getPointerOperand());
     } else if (auto *LI = dyn_cast<LoadInst>(&I))
       PushPtrOperand(LI->getPointerOperand());
     else if (auto *SI = dyn_cast<StoreInst>(&I))
@@ -481,13 +549,11 @@
       collectRewritableIntrinsicOperands(II, PostorderStack, Visited);
     else if (ICmpInst *Cmp = dyn_cast<ICmpInst>(&I)) {
-      // FIXME: Handle vectors of pointers
-      if (Cmp->getOperand(0)->getType()->isPointerTy()) {
+      if (isPtrOrVecOfPtrsType(Cmp->getOperand(0)->getType())) {
         PushPtrOperand(Cmp->getOperand(0));
         PushPtrOperand(Cmp->getOperand(1));
       }
     } else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(&I)) {
-      if (!ASC->getType()->isVectorTy())
-        PushPtrOperand(ASC->getPointerOperand());
+      PushPtrOperand(ASC->getPointerOperand());
     } else if (auto *I2P = dyn_cast<IntToPtrInst>(&I)) {
       if (isNoopPtrIntCastPair(cast<Operator>(I2P), *DL, TTI))
         PushPtrOperand(
@@ -501,7 +568,7 @@
     // If the operands of the expression on the top are already explored,
     // adds that expression to the resultant postorder.
     if (PostorderStack.back().getInt()) {
-      if (TopVal->getType()->getPointerAddressSpace() == FlatAddrSpace)
+      if (getPtrOrVecOfPtrsAddressSpace(TopVal->getType()) == FlatAddrSpace)
         Postorder.push_back(TopVal);
       PostorderStack.pop_back();
       continue;
@@ -529,8 +596,7 @@
     SmallVectorImpl<const Use *> *UndefUsesToFix) {
   Value *Operand = OperandUse.get();
 
-  Type *NewPtrTy = PointerType::getWithSamePointeeType(
-      cast<PointerType>(Operand->getType()), NewAddrSpace);
+  Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(Operand->getType(), NewAddrSpace);
 
   if (Constant *C = dyn_cast<Constant>(Operand))
     return ConstantExpr::getAddrSpaceCast(C, NewPtrTy);
@@ -543,8 +609,7 @@
   if (I != PredicatedAS.end()) {
     // Insert an addrspacecast on that operand before the user.
     unsigned NewAS = I->second;
-    Type *NewPtrTy = PointerType::getWithSamePointeeType(
-        cast<PointerType>(Operand->getType()), NewAS);
+    Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(Operand->getType(), NewAS);
     auto *NewI = new AddrSpaceCastInst(Operand, NewPtrTy);
     NewI->insertBefore(Inst);
     NewI->setDebugLoc(Inst->getDebugLoc());
@@ -572,15 +637,14 @@
     const ValueToValueMapTy &ValueWithNewAddrSpace,
     const PredicatedAddrSpaceMapTy &PredicatedAS,
     SmallVectorImpl<const Use *> *UndefUsesToFix) const {
-  Type *NewPtrType = PointerType::getWithSamePointeeType(
-      cast<PointerType>(I->getType()), NewAddrSpace);
+  Type *NewPtrType = getPtrOrVecOfPtrsWithNewAS(I->getType(), NewAddrSpace);
 
   if (I->getOpcode() == Instruction::AddrSpaceCast) {
     Value *Src = I->getOperand(0);
     // Because `I` is flat, the source address space must be specific.
     // Therefore, the inferred address space must be the source space, according
     // to our algorithm.
-    assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace);
+    assert(getPtrOrVecOfPtrsAddressSpace(Src->getType()) == NewAddrSpace);
     if (Src->getType() != NewPtrType)
       return new BitCastInst(Src, NewPtrType);
     return Src;
@@ -607,8 +671,7 @@
   if (AS != UninitializedAddressSpace) {
     // For the assumed address space, insert an `addrspacecast` to make that
     // explicit.
-    Type *NewPtrTy = PointerType::getWithSamePointeeType(
-        cast<PointerType>(I->getType()), AS);
+    Type *NewPtrTy = getPtrOrVecOfPtrsWithNewAS(I->getType(), AS);
     auto *NewI = new AddrSpaceCastInst(I, NewPtrTy);
     NewI->insertAfter(I);
     return NewI;
@@ -617,7 +680,7 @@
   // Computes the converted pointer operands.
   SmallVector<Value *, 4> NewPointerOperands;
   for (const Use &OperandUse : I->operands()) {
-    if (!OperandUse.get()->getType()->isPointerTy())
+    if (!isPtrOrVecOfPtrsType(OperandUse.get()->getType()))
       NewPointerOperands.push_back(nullptr);
     else
       NewPointerOperands.push_back(operandWithNewAddressSpaceOrCreateUndef(
@@ -629,7 +692,7 @@
   case Instruction::BitCast:
     return new BitCastInst(NewPointerOperands[0], NewPtrType);
   case Instruction::PHI: {
-    assert(I->getType()->isPointerTy());
+    assert(isPtrOrVecOfPtrsType(I->getType()));
     PHINode *PHI = cast<PHINode>(I);
     PHINode *NewPHI = PHINode::Create(NewPtrType, PHI->getNumIncomingValues());
     for (unsigned Index = 0; Index < PHI->getNumIncomingValues(); ++Index) {
@@ -648,7 +711,7 @@
     return NewGEP;
   }
   case Instruction::Select:
-    assert(I->getType()->isPointerTy());
+    assert(isPtrOrVecOfPtrsType(I->getType()));
     return SelectInst::Create(I->getOperand(0), NewPointerOperands[1],
                               NewPointerOperands[2], "", nullptr, I);
   case Instruction::IntToPtr: {
@@ -674,16 +737,16 @@
     ConstantExpr *CE, unsigned NewAddrSpace,
     const ValueToValueMapTy &ValueWithNewAddrSpace, const DataLayout *DL,
     const TargetTransformInfo *TTI) {
-  Type *TargetType = CE->getType()->isPointerTy()
-                         ? PointerType::getWithSamePointeeType(
-                               cast<PointerType>(CE->getType()), NewAddrSpace)
-                         : CE->getType();
+  Type *TargetType =
+      isPtrOrVecOfPtrsType(CE->getType())
+          ? getPtrOrVecOfPtrsWithNewAS(CE->getType(), NewAddrSpace)
+          : CE->getType();
 
   if (CE->getOpcode() == Instruction::AddrSpaceCast) {
     // Because CE is flat, the source address space must be specific.
     // Therefore, the inferred address space must be the source space according
     // to our algorithm.
-    assert(CE->getOperand(0)->getType()->getPointerAddressSpace() ==
+    assert(getPtrOrVecOfPtrsAddressSpace(CE->getOperand(0)->getType()) ==
            NewAddrSpace);
     return ConstantExpr::getBitCast(CE->getOperand(0), TargetType);
   }
@@ -697,7 +760,7 @@
   if (CE->getOpcode() == Instruction::IntToPtr) {
     assert(isNoopPtrIntCastPair(cast<Operator>(CE), *DL, TTI));
     Constant *Src = cast<ConstantExpr>(CE->getOperand(0))->getOperand(0);
-    assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace);
+    assert(getPtrOrVecOfPtrsAddressSpace(Src->getType()) == NewAddrSpace);
     return ConstantExpr::getBitCast(Src, TargetType);
   }
 
@@ -753,7 +816,7 @@
     const PredicatedAddrSpaceMapTy &PredicatedAS,
     SmallVectorImpl<const Use *> *UndefUsesToFix) const {
   // All values in Postorder are flat address expressions.
-  assert(V->getType()->getPointerAddressSpace() == FlatAddrSpace &&
+  assert(getPtrOrVecOfPtrsAddressSpace(V->getType()) == FlatAddrSpace &&
          isAddressExpression(*V, *DL, TTI));
 
   if (Instruction *I = dyn_cast<Instruction>(V)) {
@@ -898,12 +961,14 @@
     Value *Src1 = Op.getOperand(2);
 
     auto I = InferredAddrSpace.find(Src0);
-    unsigned Src0AS = (I != InferredAddrSpace.end()) ?
-      I->second : Src0->getType()->getPointerAddressSpace();
+    unsigned Src0AS = (I != InferredAddrSpace.end())
+                          ? I->second
+                          : getPtrOrVecOfPtrsAddressSpace(Src0->getType());
 
     auto J = InferredAddrSpace.find(Src1);
-    unsigned Src1AS = (J != InferredAddrSpace.end()) ?
-      J->second : Src1->getType()->getPointerAddressSpace();
+    unsigned Src1AS = (J != InferredAddrSpace.end())
+                          ? J->second
+                          : getPtrOrVecOfPtrsAddressSpace(Src1->getType());
 
     auto *C0 = dyn_cast<Constant>(Src0);
     auto *C1 = dyn_cast<Constant>(Src1);
@@ -932,7 +997,7 @@
         auto I = InferredAddrSpace.find(PtrOperand);
         unsigned OperandAS;
         if (I == InferredAddrSpace.end()) {
-          OperandAS = PtrOperand->getType()->getPointerAddressSpace();
+          OperandAS = getPtrOrVecOfPtrsAddressSpace(PtrOperand->getType());
           if (OperandAS == FlatAddrSpace) {
             // Check AC for assumption dominating V.
             unsigned AS = getPredicatedAddrSpace(V, PtrOperand);
@@ -1057,7 +1122,7 @@
                                                         unsigned NewAS) const {
   assert(NewAS != UninitializedAddressSpace);
 
-  unsigned SrcAS = C->getType()->getPointerAddressSpace();
+  unsigned SrcAS = getPtrOrVecOfPtrsAddressSpace(C->getType());
   if (SrcAS == NewAS || isa<UndefValue>(C))
     return true;
 
@@ -1075,7 +1140,7 @@
       return isSafeToCastConstAddrSpace(cast<Constant>(Op->getOperand(0)), NewAS);
 
     if (Op->getOpcode() == Instruction::IntToPtr &&
-        Op->getType()->getPointerAddressSpace() == FlatAddrSpace)
+        getPtrOrVecOfPtrsAddressSpace(Op->getType()) == FlatAddrSpace)
       return true;
   }
 
@@ -1111,7 +1176,7 @@
     if (NewAddrSpace == UninitializedAddressSpace)
       continue;
 
-    if (V->getType()->getPointerAddressSpace() != NewAddrSpace) {
+    if (getPtrOrVecOfPtrsAddressSpace(V->getType()) != NewAddrSpace) {
       Value *New =
           cloneValueWithNewAddressSpace(V, NewAddrSpace, ValueWithNewAddrSpace,
                                         PredicatedAS, &UndefUsesToFix);
@@ -1168,7 +1233,7 @@
       I = skipToNextUser(I, E);
 
       if (isSimplePointerUseValidToReplace(
-              *TTI, U, V->getType()->getPointerAddressSpace())) {
+              *TTI, U, getPtrOrVecOfPtrsAddressSpace(V->getType()))) {
         // If V is used as the pointer operand of a compatible memory operation,
         // sets the pointer operand to NewV. This replacement does not change
         // the element type, so the resultant load/store is still valid.
@@ -1199,13 +1264,13 @@
           // into
           //   %cmp = icmp eq float addrspace(3)* %new_p, %new_q
 
-          unsigned NewAS = NewV->getType()->getPointerAddressSpace();
+          unsigned NewAS = getPtrOrVecOfPtrsAddressSpace(NewV->getType());
           int SrcIdx = U.getOperandNo();
           int OtherIdx = (SrcIdx == 0) ? 1 : 0;
           Value *OtherSrc = Cmp->getOperand(OtherIdx);
 
           if (Value *OtherNewV = ValueWithNewAddrSpace.lookup(OtherSrc)) {
-            if (OtherNewV->getType()->getPointerAddressSpace() == NewAS) {
+            if (getPtrOrVecOfPtrsAddressSpace(OtherNewV->getType()) == NewAS) {
               Cmp->setOperand(OtherIdx, OtherNewV);
               Cmp->setOperand(SrcIdx, NewV);
               continue;
@@ -1224,11 +1289,10 @@
         }
 
         if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(CurUser)) {
-          unsigned NewAS = NewV->getType()->getPointerAddressSpace();
-          if (ASC->getDestAddressSpace() == NewAS) {
-            if (!cast<PointerType>(ASC->getType())
-                    ->hasSameElementTypeAs(
-                        cast<PointerType>(NewV->getType()))) {
+          unsigned NewAS = getPtrOrVecOfPtrsAddressSpace(NewV->getType());
+          if (getPtrOrVecOfPtrsAddressSpace(ASC->getType()) == NewAS) {
+            if (!hasSameElementOfPtrOrVecPtrs(ASC->getType(),
+                                              NewV->getType())) {
               BasicBlock::iterator InsertPos;
               if (Instruction *NewVInst = dyn_cast<Instruction>(NewV))
                 InsertPos = std::next(NewVInst->getIterator());
_______________________________________________
lldb-commits mailing list
lldb-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits

Reply via email to