hliao updated this revision to Diff 217901.
hliao added a comment.
Herald added subscribers: cfe-commits, erik.pilkington, mgorny.
Herald added a project: clang.

Revise `isLifeTimeMarker` following review comment.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D66950/new/

https://reviews.llvm.org/D66950

Files:
  clang/include/clang/AST/DeclCXX.h
  clang/include/clang/AST/Mangle.h
  clang/include/clang/AST/MangleNumberingContext.h
  clang/include/clang/Basic/LangOptions.def
  clang/include/clang/Driver/CC1Options.td
  clang/include/clang/Sema/Sema.h
  clang/lib/AST/Decl.cpp
  clang/lib/AST/ItaniumMangle.cpp
  clang/lib/AST/MicrosoftCXXABI.cpp
  clang/lib/CodeGen/CGCUDANV.cpp
  clang/lib/Frontend/CompilerInvocation.cpp
  clang/lib/Sema/SemaLambda.cpp
  clang/test/CodeGenCUDA/unnamed-types.cu
  llvm/CMakeLists.txt
  llvm/cmake/modules/AddLLVM.cmake
  llvm/cmake/modules/HandleLLVMOptions.cmake
  llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
  llvm/lib/Transforms/Utils/SimplifyCFG.cpp
  llvm/test/Transforms/SimplifyCFG/sink-common-code.ll

Index: llvm/test/Transforms/SimplifyCFG/sink-common-code.ll
===================================================================
--- llvm/test/Transforms/SimplifyCFG/sink-common-code.ll
+++ llvm/test/Transforms/SimplifyCFG/sink-common-code.ll
@@ -886,6 +886,33 @@
 ; CHECK: ret
 }
 
+; CHECK-LABEL: @test_not_sink_lifetime_marker
+; CHECK-NOT: select
+; CHECK: call void @llvm.lifetime.end
+; CHECK: call void @llvm.lifetime.end
+define i32 @test_not_sink_lifetime_marker(i1 zeroext %flag, i32 %x) {
+entry:
+  %y = alloca i32
+  %z = alloca i32
+  br i1 %flag, label %if.then, label %if.else
+
+if.then:
+  %y.cast = bitcast i32* %y to i8*
+  call void @llvm.lifetime.end.p0i8(i64 4, i8* %y.cast)
+  br label %if.end
+
+if.else:
+  %z.cast = bitcast i32* %z to i8*
+  call void @llvm.lifetime.end.p0i8(i64 4, i8* %z.cast)
+  br label %if.end
+
+if.end:
+  ret i32 1
+}
+
+declare void @llvm.lifetime.start.p0i8(i64, i8* nocapture)
+declare void @llvm.lifetime.end.p0i8(i64, i8* nocapture)
+
 
 ; CHECK: ![[$TBAA]] = !{![[TYPE:[0-9]]], ![[TYPE]], i64 0}
 ; CHECK: ![[TYPE]] = !{!"float", ![[TEXT:[0-9]]]}
Index: llvm/lib/Transforms/Utils/SimplifyCFG.cpp
===================================================================
--- llvm/lib/Transforms/Utils/SimplifyCFG.cpp
+++ llvm/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -1421,6 +1421,20 @@
   return true;
 }
 
+// Return true if I is an llvm.lifetime.start or llvm.lifetime.end intrinsic.
+static bool isLifeTimeMarker(const Instruction *I) {
+  if (auto II = dyn_cast<IntrinsicInst>(I)) {
+    switch (II->getIntrinsicID()) {
+    default:
+      break;
+    case Intrinsic::lifetime_start:
+    case Intrinsic::lifetime_end:
+      return true;
+    }
+  }
+  return false;
+}
+
 // All instructions in Insts belong to different blocks that all unconditionally
 // branch to a common successor. Analyze each instruction and return true if it
 // would be possible to sink them into their successor, creating one common
@@ -1475,20 +1489,25 @@
       return false;
   }
 
-  // Because SROA can't handle speculating stores of selects, try not
-  // to sink loads or stores of allocas when we'd have to create a PHI for
-  // the address operand. Also, because it is likely that loads or stores
-  // of allocas will disappear when Mem2Reg/SROA is run, don't sink them.
+  // Because SROA can't handle speculating stores of selects, try not to sink
+  // loads, stores or lifetime markers of allocas when we'd have to create a
+  // PHI for the address operand. Also, because it is likely that loads or
+  // stores of allocas will disappear when Mem2Reg/SROA is run, don't sink
+  // them.
   // This can cause code churn which can have unintended consequences down
   // the line - see https://llvm.org/bugs/show_bug.cgi?id=30244.
   // FIXME: This is a workaround for a deficiency in SROA - see
   // https://llvm.org/bugs/show_bug.cgi?id=30188
   if (isa<StoreInst>(I0) && any_of(Insts, [](const Instruction *I) {
-        return isa<AllocaInst>(I->getOperand(1));
+        return isa<AllocaInst>(I->getOperand(1)->stripPointerCasts());
       }))
     return false;
   if (isa<LoadInst>(I0) && any_of(Insts, [](const Instruction *I) {
-        return isa<AllocaInst>(I->getOperand(0));
+        return isa<AllocaInst>(I->getOperand(0)->stripPointerCasts());
+      }))
+    return false;
+  if (isLifeTimeMarker(I0) && any_of(Insts, [](const Instruction *I) {
+        return isa<AllocaInst>(I->getOperand(1)->stripPointerCasts());
       }))
     return false;
 
Index: llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
+++ llvm/lib/CodeGen/SelectionDAG/InstrEmitter.cpp
@@ -384,6 +384,8 @@
                                   Op.getNode()->isDivergent() ||
                                       (IIRC && TRI->isDivergentRegClass(IIRC)))
             : nullptr;
+    assert(!II || IIOpNum < II->getNumOperands() || !IIRC);
+    IIRC = TRI->getAllocatableClass(IIRC);
 
     if (OpRC && IIRC && OpRC != IIRC && Register::isVirtualRegister(VReg)) {
       Register NewVReg = MRI->createVirtualRegister(IIRC);
Index: llvm/cmake/modules/HandleLLVMOptions.cmake
===================================================================
--- llvm/cmake/modules/HandleLLVMOptions.cmake
+++ llvm/cmake/modules/HandleLLVMOptions.cmake
@@ -85,7 +85,7 @@
 
 if(LLVM_ENABLE_EXPENSIVE_CHECKS)
   add_definitions(-DEXPENSIVE_CHECKS)
-  add_definitions(-D_GLIBCXX_DEBUG)
+  #  add_definitions(-D_GLIBCXX_DEBUG)
 endif()
 
 string(TOUPPER "${LLVM_ABI_BREAKING_CHECKS}" uppercase_LLVM_ABI_BREAKING_CHECKS)
Index: llvm/cmake/modules/AddLLVM.cmake
===================================================================
--- llvm/cmake/modules/AddLLVM.cmake
+++ llvm/cmake/modules/AddLLVM.cmake
@@ -199,6 +199,10 @@
     set(LLVM_LINKER_DETECTED NO)
     message(STATUS "Linker detection: unknown")
   endif()
+  if(LLVM_LINKER_IS_GOLD AND LLVM_USE_GDBINDEX)
+    append("-Wl,--gdb-index" CMAKE_EXE_LINKER_FLAGS CMAKE_SHARED_LINKER_FLAGS
+      CMAKE_MODULE_LINKER_FLAGS)
+  endif()
 endif()
 
 function(add_link_opts target_name)
Index: llvm/CMakeLists.txt
===================================================================
--- llvm/CMakeLists.txt
+++ llvm/CMakeLists.txt
@@ -460,6 +460,8 @@
 
 option(LLVM_USE_SPLIT_DWARF
   "Use -gsplit-dwarf when compiling llvm." OFF)
+option(LLVM_USE_GDBINDEX
+  "Use -Wl,--gdb-index to index the debug info." OFF)
 
 option(LLVM_POLLY_LINK_INTO_TOOLS "Statically link Polly into tools (if available)" ON)
 option(LLVM_POLLY_BUILD "Build LLVM with Polly" ON)
Index: clang/test/CodeGenCUDA/unnamed-types.cu
===================================================================
--- /dev/null
+++ clang/test/CodeGenCUDA/unnamed-types.cu
@@ -0,0 +1,52 @@
+// RUN: %clang_cc1 -std=c++11 -x hip -triple x86_64-linux-gnu -aux-triple amdgcn-amd-amdhsa -fcuda-force-lambda-odr -emit-llvm %s -o - | FileCheck %s --check-prefix=HOST
+// RUN: %clang_cc1 -std=c++11 -x hip -triple x86_64-pc-windows-msvc -aux-triple amdgcn-amd-amdhsa -fcuda-force-lambda-odr -emit-llvm %s -o - | FileCheck %s --check-prefix=MSVC
+// RUN: %clang_cc1 -std=c++11 -x hip -triple amdgcn-amd-amdhsa -fcuda-force-lambda-odr -fcuda-is-device -emit-llvm %s -o - | FileCheck %s --check-prefix=DEVICE
+
+#include "Inputs/cuda.h"
+
+// HOST: @0 = private unnamed_addr constant [43 x i8] c"_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_\00", align 1
+// HOST: @1 = private unnamed_addr constant [60 x i8] c"_Z2k1IZ2f1PfEUlfE_Z2f1S0_EUlffE_Z2f1S0_EUlfE0_EvS0_T_T0_T1_\00", align 1
+// Check that, on MSVC, the same device kernel mangling name is generated.
+// MSVC: @0 = private unnamed_addr constant [43 x i8] c"_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_\00", align 1
+// MSVC: @1 = private unnamed_addr constant [60 x i8] c"_Z2k1IZ2f1PfEUlfE_Z2f1S0_EUlffE_Z2f1S0_EUlfE0_EvS0_T_T0_T1_\00", align 1
+
+__device__ float d0(float x) {
+  return [](float x) { return x + 1.f; }(x);
+}
+
+__device__ float d1(float x) {
+  return [](float x) { return x * 2.f; }(x);
+}
+
+// DEVICE: amdgpu_kernel void @_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_(
+template <typename F>
+__global__ void k0(float *p, F f) {
+  p[0] = f(p[0]) + d0(p[1]) + d1(p[2]);
+}
+
+// DEVICE: amdgpu_kernel void @_Z2k1IZ2f1PfEUlfE_Z2f1S0_EUlffE_Z2f1S0_EUlfE0_EvS0_T_T0_T1_(
+template <typename F0, typename F1, typename F2>
+__global__ void k1(float *p, F0 f0, F1 f1, F2 f2) {
+  p[0] = f0(p[0]) + f1(p[1], p[2]) + f2(p[3]);
+}
+
+void f0(float *p) {
+  [](float *p) {
+    *p = 1.f;
+  }(p);
+}
+
+void f1(float *p) {
+  [](float *p) {
+    k0<<<1,1>>>(p, [] __device__ (float x) { return x + 3.f; });
+  }(p);
+  k1<<<1,1>>>(p,
+              [] __device__ (float x) { return x + 4.f; },
+              [] __device__ (float x, float y) { return x * y; },
+              [] __device__ (float x) { return x + 5.f; });
+}
+// HOST: @__hip_register_globals
+// HOST: __hipRegisterFunction{{.*}}@_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_{{.*}}@0
+// HOST: __hipRegisterFunction{{.*}}@_Z2k1IZ2f1PfEUlfE_Z2f1S0_EUlffE_Z2f1S0_EUlfE0_EvS0_T_T0_T1_{{.*}}@1
+// MSVC: __hipRegisterFunction{{.*}}@"??$k0@V<lambda_1>@?0???R1?0??f1@@YAXPEAM@Z@QEBA@0@Z@@@YAXPEAMV<lambda_1>@?0???R0?0??f1@@YAX0@Z@QEBA@0@Z@@Z{{.*}}@0
+// MSVC: __hipRegisterFunction{{.*}}@"??$k1@V<lambda_2>@?0??f1@@YAXPEAM@Z@V<lambda_3>@?0??2@YAX0@Z@V<lambda_4>@?0??2@YAX0@Z@@@YAXPEAMV<lambda_2>@?0??f1@@YAX0@Z@V<lambda_3>@?0??1@YAX0@Z@V<lambda_4>@?0??1@YAX0@Z@@Z{{.*}}@1
Index: clang/lib/Sema/SemaLambda.cpp
===================================================================
--- clang/lib/Sema/SemaLambda.cpp
+++ clang/lib/Sema/SemaLambda.cpp
@@ -274,7 +274,8 @@
 
 MangleNumberingContext *
 Sema::getCurrentMangleNumberContext(const DeclContext *DC,
-                                    Decl *&ManglingContextDecl) {
+                                    Decl *&ManglingContextDecl,
+                                    bool SkpNoODRChk, bool *Forced) {
   // Compute the context for allocating mangling numbers in the current
   // expression, if the ABI requires them.
   ManglingContextDecl = ExprEvalContexts.back().ManglingContextDecl;
@@ -322,9 +323,14 @@
   case Normal: {
     //  -- the bodies of non-exported nonspecialized template functions
     //  -- the bodies of inline functions
-    if ((IsInNonspecializedTemplate &&
+    bool NeedODR =
+        (IsInNonspecializedTemplate &&
          !(ManglingContextDecl && isa<ParmVarDecl>(ManglingContextDecl))) ||
-        isInInlineFunction(CurContext)) {
+        isInInlineFunction(CurContext);
+    if (NeedODR || SkpNoODRChk) {
+      // Mark the numbering as forced if ODR did not originally require it.
+      if (SkpNoODRChk && Forced)
+        *Forced = !NeedODR;
       ManglingContextDecl = nullptr;
       while (auto *CD = dyn_cast<CapturedDecl>(DC))
         DC = CD->getParent();
@@ -337,10 +343,13 @@
 
   case StaticDataMember:
     //  -- the initializers of nonspecialized static members of template classes
-    if (!IsInNonspecializedTemplate) {
+    if (!SkpNoODRChk && !IsInNonspecializedTemplate) {
       ManglingContextDecl = nullptr;
       return nullptr;
     }
+    // Mark the numbering as forced if ODR did not originally require it.
+    if (SkpNoODRChk && Forced)
+      *Forced = !IsInNonspecializedTemplate;
     // Fall through to get the current context.
     LLVM_FALLTHROUGH;
 
@@ -437,11 +446,16 @@
     Class->setLambdaMangling(Mangling->first, Mangling->second);
   } else {
     Decl *ManglingContextDecl;
-    if (MangleNumberingContext *MCtx =
-            getCurrentMangleNumberContext(Class->getDeclContext(),
-                                          ManglingContextDecl)) {
+    bool Forced = false;
+    if (MangleNumberingContext *MCtx = getCurrentMangleNumberContext(
+            Class->getDeclContext(), ManglingContextDecl,
+            getLangOpts().CUDAForceLambdaODR, &Forced)) {
       unsigned ManglingNumber = MCtx->getManglingNumber(Method);
-      Class->setLambdaMangling(ManglingNumber, ManglingContextDecl);
+      Class->setLambdaMangling(ManglingNumber, ManglingContextDecl, Forced);
+      if (MCtx->hasDeviceMangleNumberingContext()) {
+        Class->setDeviceLambdaManglingNumber(
+            MCtx->getDeviceManglingNumber(Method));
+      }
     }
   }
 
Index: clang/lib/Frontend/CompilerInvocation.cpp
===================================================================
--- clang/lib/Frontend/CompilerInvocation.cpp
+++ clang/lib/Frontend/CompilerInvocation.cpp
@@ -2513,6 +2513,9 @@
   if (Args.hasArg(OPT_fno_cuda_host_device_constexpr))
     Opts.CUDAHostDeviceConstexpr = 0;
 
+  if (Args.hasArg(OPT_fcuda_force_lambda_odr))
+    Opts.CUDAForceLambdaODR = 1;
+
   if (Opts.CUDAIsDevice && Args.hasArg(OPT_fcuda_approx_transcendentals))
     Opts.CUDADeviceApproxTranscendentals = 1;
 
Index: clang/lib/CodeGen/CGCUDANV.cpp
===================================================================
--- clang/lib/CodeGen/CGCUDANV.cpp
+++ clang/lib/CodeGen/CGCUDANV.cpp
@@ -166,6 +166,10 @@
   CharPtrTy = llvm::PointerType::getUnqual(Types.ConvertType(Ctx.CharTy));
   VoidPtrTy = cast<llvm::PointerType>(Types.ConvertType(Ctx.VoidPtrTy));
   VoidPtrPtrTy = VoidPtrTy->getPointerTo();
+
+  DeviceMC->setDeviceMangleContext(
+      CGM.getContext().getTargetInfo().getCXXABI().isMicrosoft() &&
+      CGM.getContext().getAuxTargetInfo()->getCXXABI().isItaniumFamily());
 }
 
 llvm::FunctionCallee CGNVCUDARuntime::getSetupArgumentFn() const {
Index: clang/lib/AST/MicrosoftCXXABI.cpp
===================================================================
--- clang/lib/AST/MicrosoftCXXABI.cpp
+++ clang/lib/AST/MicrosoftCXXABI.cpp
@@ -22,6 +22,138 @@
 
 using namespace clang;
 
+// Pending an interface revision, this is a clone of `ItaniumNumberingContext`
+// from `lib/AST/ItaniumCXXABI.cpp`.
+// {{{ BEGIN CLONE
+namespace {
+
+/// According to Itanium C++ ABI 5.1.2:
+/// the name of an anonymous union is considered to be
+/// the name of the first named data member found by a pre-order,
+/// depth-first, declaration-order walk of the data members of
+/// the anonymous union.
+/// If there is no such data member (i.e., if all of the data members
+/// in the union are unnamed), then there is no way for a program to
+/// refer to the anonymous union, and there is therefore no need to mangle its name.
+///
+/// Returns the name of anonymous union VarDecl or nullptr if it is not found.
+static const IdentifierInfo *findAnonymousUnionVarDeclName(const VarDecl& VD) {
+  const RecordType *RT = VD.getType()->getAs<RecordType>();
+  assert(RT && "type of VarDecl is expected to be RecordType.");
+  assert(RT->getDecl()->isUnion() && "RecordType is expected to be a union.");
+  if (const FieldDecl *FD = RT->getDecl()->findFirstNamedDataMember()) {
+    return FD->getIdentifier();
+  }
+
+  return nullptr;
+}
+
+/// The name of a decomposition declaration.
+struct DecompositionDeclName {
+  using BindingArray = ArrayRef<const BindingDecl*>;
+
+  /// Representative example of a set of bindings with these names.
+  BindingArray Bindings;
+
+  /// Iterators over the sequence of identifiers in the name.
+  struct Iterator
+      : llvm::iterator_adaptor_base<Iterator, BindingArray::const_iterator,
+                                    std::random_access_iterator_tag,
+                                    const IdentifierInfo *> {
+    Iterator(BindingArray::const_iterator It) : iterator_adaptor_base(It) {}
+    const IdentifierInfo *operator*() const {
+      return (*this->I)->getIdentifier();
+    }
+  };
+  Iterator begin() const { return Iterator(Bindings.begin()); }
+  Iterator end() const { return Iterator(Bindings.end()); }
+};
+}
+
+namespace llvm {
+template<>
+struct DenseMapInfo<DecompositionDeclName> {
+  using ArrayInfo = llvm::DenseMapInfo<ArrayRef<const BindingDecl*>>;
+  using IdentInfo = llvm::DenseMapInfo<const IdentifierInfo*>;
+  static DecompositionDeclName getEmptyKey() {
+    return {ArrayInfo::getEmptyKey()};
+  }
+  static DecompositionDeclName getTombstoneKey() {
+    return {ArrayInfo::getTombstoneKey()};
+  }
+  static unsigned getHashValue(DecompositionDeclName Key) {
+    assert(!isEqual(Key, getEmptyKey()) && !isEqual(Key, getTombstoneKey()));
+    return llvm::hash_combine_range(Key.begin(), Key.end());
+  }
+  static bool isEqual(DecompositionDeclName LHS, DecompositionDeclName RHS) {
+    if (ArrayInfo::isEqual(LHS.Bindings, ArrayInfo::getEmptyKey()))
+      return ArrayInfo::isEqual(RHS.Bindings, ArrayInfo::getEmptyKey());
+    if (ArrayInfo::isEqual(LHS.Bindings, ArrayInfo::getTombstoneKey()))
+      return ArrayInfo::isEqual(RHS.Bindings, ArrayInfo::getTombstoneKey());
+    return LHS.Bindings.size() == RHS.Bindings.size() &&
+           std::equal(LHS.begin(), LHS.end(), RHS.begin());
+  }
+};
+}
+
+namespace {
+
+/// Keeps track of the mangled names of lambda expressions and block
+/// literals within a particular context.
+class ItaniumNumberingContext : public MangleNumberingContext {
+  llvm::DenseMap<const Type *, unsigned> ManglingNumbers;
+  llvm::DenseMap<const IdentifierInfo *, unsigned> VarManglingNumbers;
+  llvm::DenseMap<const IdentifierInfo *, unsigned> TagManglingNumbers;
+  llvm::DenseMap<DecompositionDeclName, unsigned>
+      DecompsitionDeclManglingNumbers;
+
+public:
+  unsigned getManglingNumber(const CXXMethodDecl *CallOperator) override {
+    const FunctionProtoType *Proto =
+        CallOperator->getType()->getAs<FunctionProtoType>();
+    ASTContext &Context = CallOperator->getASTContext();
+
+    FunctionProtoType::ExtProtoInfo EPI;
+    EPI.Variadic = Proto->isVariadic();
+    QualType Key =
+        Context.getFunctionType(Context.VoidTy, Proto->getParamTypes(), EPI);
+    Key = Context.getCanonicalType(Key);
+    return ++ManglingNumbers[Key->castAs<FunctionProtoType>()];
+  }
+
+  unsigned getManglingNumber(const BlockDecl *BD) override {
+    const Type *Ty = nullptr;
+    return ++ManglingNumbers[Ty];
+  }
+
+  unsigned getStaticLocalNumber(const VarDecl *VD) override {
+    return 0;
+  }
+
+  /// Variable decls are numbered by identifier.
+  unsigned getManglingNumber(const VarDecl *VD, unsigned) override {
+    if (auto *DD = dyn_cast<DecompositionDecl>(VD)) {
+      DecompositionDeclName Name{DD->bindings()};
+      return ++DecompsitionDeclManglingNumbers[Name];
+    }
+
+    const IdentifierInfo *Identifier = VD->getIdentifier();
+    if (!Identifier) {
+      // VarDecl without an identifier represents an anonymous union
+      // declaration.
+      Identifier = findAnonymousUnionVarDeclName(*VD);
+    }
+    return ++VarManglingNumbers[Identifier];
+  }
+
+  unsigned getManglingNumber(const TagDecl *TD, unsigned) override {
+    return ++TagManglingNumbers[TD->getIdentifier()];
+  }
+};
+
+} // End anonymous namespace
+// END CLONE }}}
+
 namespace {
 
 /// Numbers things which need to correspond across multiple TUs.
@@ -63,6 +195,41 @@
   }
 };
 
+class MSHIPNumberingContext : public MangleNumberingContext {
+  MicrosoftNumberingContext HostCtx;
+  ItaniumNumberingContext DeviceCtx;
+
+public:
+
+  unsigned getManglingNumber(const CXXMethodDecl *CallOperator) override {
+    return HostCtx.getManglingNumber(CallOperator);
+  }
+
+  unsigned getManglingNumber(const BlockDecl *BD) override {
+    return HostCtx.getManglingNumber(BD);
+  }
+
+  unsigned getStaticLocalNumber(const VarDecl *VD) override {
+    return HostCtx.getStaticLocalNumber(VD);
+  }
+
+  unsigned getManglingNumber(const VarDecl *VD,
+                             unsigned MSLocalManglingNumber) override {
+    return HostCtx.getManglingNumber(VD, MSLocalManglingNumber);
+  }
+
+  unsigned getManglingNumber(const TagDecl *TD,
+                             unsigned MSLocalManglingNumber) override {
+    return HostCtx.getManglingNumber(TD, MSLocalManglingNumber);
+  }
+
+  bool hasDeviceMangleNumberingContext() override { return true; }
+
+  unsigned getDeviceManglingNumber(const CXXMethodDecl *CallOperator) override {
+    return DeviceCtx.getManglingNumber(CallOperator);
+  }
+};
+
 class MicrosoftCXXABI : public CXXABI {
   ASTContext &Context;
   llvm::SmallDenseMap<CXXRecordDecl *, CXXConstructorDecl *> RecordToCopyCtor;
@@ -132,6 +299,8 @@
 
   std::unique_ptr<MangleNumberingContext>
   createMangleNumberingContext() const override {
+    if (Context.getLangOpts().CUDA)
+      return std::make_unique<MSHIPNumberingContext>();
     return std::make_unique<MicrosoftNumberingContext>();
   }
 };
Index: clang/lib/AST/ItaniumMangle.cpp
===================================================================
--- clang/lib/AST/ItaniumMangle.cpp
+++ clang/lib/AST/ItaniumMangle.cpp
@@ -122,6 +122,8 @@
   llvm::DenseMap<DiscriminatorKeyTy, unsigned> Discriminator;
   llvm::DenseMap<const NamedDecl*, unsigned> Uniquifier;
 
+  bool IsDevCtx = false;
+
 public:
   explicit ItaniumMangleContextImpl(ASTContext &Context,
                                     DiagnosticsEngine &Diags)
@@ -134,6 +136,10 @@
   bool shouldMangleStringLiteral(const StringLiteral *) override {
     return false;
   }
+
+  bool isDeviceMangleContext() const override { return IsDevCtx; }
+  void setDeviceMangleContext(bool IsDev) override { IsDevCtx = IsDev;}
+
   void mangleCXXName(const NamedDecl *D, raw_ostream &) override;
   void mangleThunk(const CXXMethodDecl *MD, const ThunkInfo &Thunk,
                    raw_ostream &) override;
@@ -1739,7 +1745,9 @@
   // (in lexical order) with that same <lambda-sig> and context.
   //
   // The AST keeps track of the number for us.
-  unsigned Number = Lambda->getLambdaManglingNumber();
+  unsigned Number = Context.isDeviceMangleContext()
+                        ? Lambda->getDeviceLambdaManglingNumber()
+                        : Lambda->getLambdaManglingNumber();
   assert(Number > 0 && "Lambda should be mangled as an unnamed class");
   if (Number > 1)
     mangleNumber(Number - 2);
Index: clang/lib/AST/Decl.cpp
===================================================================
--- clang/lib/AST/Decl.cpp
+++ clang/lib/AST/Decl.cpp
@@ -1385,8 +1385,10 @@
     case Decl::CXXRecord: {
       const auto *Record = cast<CXXRecordDecl>(D);
       if (Record->isLambda()) {
-        if (!Record->getLambdaManglingNumber()) {
-          // This lambda has no mangling number, so it's internal.
+        if (!Record->getLambdaManglingNumber() ||
+            Record->hasForcedLambdaManglingNumber()) {
+          // This lambda has no mangling number or that number is forced, so
+          // it's internal.
           return getInternalLinkageFor(D);
         }
 
@@ -1402,7 +1404,8 @@
         //  };
         const CXXRecordDecl *OuterMostLambda =
             getOutermostEnclosingLambda(Record);
-        if (!OuterMostLambda->getLambdaManglingNumber())
+        if (!OuterMostLambda->getLambdaManglingNumber() ||
+            OuterMostLambda->hasForcedLambdaManglingNumber())
           return getInternalLinkageFor(D);
 
         return getLVForClosure(
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -1105,10 +1105,11 @@
   /// block literal.
   /// \param[out] ManglingContextDecl - Returns the ManglingContextDecl
   /// associated with the context, if relevant.
-  MangleNumberingContext *getCurrentMangleNumberContext(
-    const DeclContext *DC,
-    Decl *&ManglingContextDecl);
-
+  MangleNumberingContext *
+  getCurrentMangleNumberContext(const DeclContext *DC,
+                                Decl *&ManglingContextDecl,
+                                bool SkpNoODRChk = false,
+                                bool *Forced = nullptr);
 
   /// SpecialMemberOverloadResult - The overloading result for a special member
   /// function.
Index: clang/include/clang/Driver/CC1Options.td
===================================================================
--- clang/include/clang/Driver/CC1Options.td
+++ clang/include/clang/Driver/CC1Options.td
@@ -866,6 +866,8 @@
   HelpText<"Allow variadic functions in CUDA device code.">;
 def fno_cuda_host_device_constexpr : Flag<["-"], "fno-cuda-host-device-constexpr">,
   HelpText<"Don't treat unattributed constexpr functions as __host__ __device__.">;
+def fcuda_force_lambda_odr : Flag<["-"], "fcuda-force-lambda-odr">,
+  HelpText<"Force lambda naming following one definition rule (ODR).">;
 
 //===----------------------------------------------------------------------===//
 // OpenMP Options
Index: clang/include/clang/Basic/LangOptions.def
===================================================================
--- clang/include/clang/Basic/LangOptions.def
+++ clang/include/clang/Basic/LangOptions.def
@@ -221,6 +221,7 @@
 LANGOPT(CUDAHostDeviceConstexpr, 1, 1, "treating unattributed constexpr functions as __host__ __device__")
 LANGOPT(CUDADeviceApproxTranscendentals, 1, 0, "using approximate transcendental functions")
 LANGOPT(GPURelocatableDeviceCode, 1, 0, "generate relocatable device code")
+LANGOPT(CUDAForceLambdaODR, 1, 0, "force lambda naming following one definition rule (ODR)")
 
 LANGOPT(SYCLIsDevice      , 1, 0, "Generate code for SYCL device")
 
Index: clang/include/clang/AST/MangleNumberingContext.h
===================================================================
--- clang/include/clang/AST/MangleNumberingContext.h
+++ clang/include/clang/AST/MangleNumberingContext.h
@@ -16,6 +16,7 @@
 
 #include "clang/Basic/LLVM.h"
 #include "llvm/ADT/IntrusiveRefCntPtr.h"
+#include "llvm/Support/ErrorHandling.h"
 
 namespace clang {
 
@@ -52,6 +53,15 @@
   /// this context.
   virtual unsigned getManglingNumber(const TagDecl *TD,
                                      unsigned MSLocalManglingNumber) = 0;
+
+  /// Whether this context also provides device-side mangling numbers.
+  virtual bool hasDeviceMangleNumberingContext() { return false; }
+
+  /// Retrieve the mangling number of a new lambda expression with the
+  /// given call operator within the device context.
+  virtual unsigned getDeviceManglingNumber(const CXXMethodDecl *) {
+    llvm_unreachable("There's no device context associated!");
+  }
 };
 
 } // end namespace clang
Index: clang/include/clang/AST/Mangle.h
===================================================================
--- clang/include/clang/AST/Mangle.h
+++ clang/include/clang/AST/Mangle.h
@@ -95,6 +95,9 @@
   virtual bool shouldMangleCXXName(const NamedDecl *D) = 0;
   virtual bool shouldMangleStringLiteral(const StringLiteral *SL) = 0;
 
+  virtual bool isDeviceMangleContext() const { return false; }
+  virtual void setDeviceMangleContext(bool) {}
+
   // FIXME: consider replacing raw_ostream & with something like SmallString &.
   void mangleName(const NamedDecl *D, raw_ostream &);
   virtual void mangleCXXName(const NamedDecl *D, raw_ostream &) = 0;
Index: clang/include/clang/AST/DeclCXX.h
===================================================================
--- clang/include/clang/AST/DeclCXX.h
+++ clang/include/clang/AST/DeclCXX.h
@@ -586,6 +586,13 @@
     /// mangling in the Itanium C++ ABI.
     unsigned ManglingNumber = 0;
 
+    /// The device side name mangling number.
+    unsigned DeviceManglingNumber = 0;
+
+    /// True if the mangling number was forced in order to ensure ODR naming.
+    // FIXME: Save bit from `NumCaptures` to minimize `LambdaDefinitionData`.
+    bool ForcedNumbering = false;
+
     /// The declaration that provides context for this lambda, if the
     /// actual DeclContext does not suffice. This is used for lambdas that
     /// occur within default arguments of function parameters within the class
@@ -1881,6 +1888,11 @@
     return getLambdaData().ManglingNumber;
   }
 
+  bool hasForcedLambdaManglingNumber() const {
+    assert(isLambda() && "Not a lambda closure type!");
+    return getLambdaData().ForcedNumbering;
+  }
+
   /// Retrieve the declaration that provides additional context for a
   /// lambda, when the normal declaration context is not specific enough.
   ///
@@ -1894,11 +1906,23 @@
 
   /// Set the mangling number and context declaration for a lambda
   /// class.
-  void setLambdaMangling(unsigned ManglingNumber, Decl *ContextDecl) {
+  void setLambdaMangling(unsigned ManglingNumber, Decl *ContextDecl,
+                         bool Forced = false) {
     getLambdaData().ManglingNumber = ManglingNumber;
+    getLambdaData().ForcedNumbering = Forced;
     getLambdaData().ContextDecl = ContextDecl;
   }
 
+  /// Set the device side mangling number.
+  void setDeviceLambdaManglingNumber(unsigned Num) {
+    getLambdaData().DeviceManglingNumber = Num;
+  }
+
+  unsigned getDeviceLambdaManglingNumber() const {
+    assert(isLambda() && "Not a lambda closure type!");
+    return getLambdaData().DeviceManglingNumber;
+  }
+
   /// Returns the inheritance model used for this record.
   MSInheritanceAttr::Spelling getMSInheritanceModel() const;
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to