erichkeane updated this revision to Diff 348108.
erichkeane added a comment.
Herald added subscribers: phosek, aheejin, dschuff.

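Replace the device-mangle-context flag (isDeviceMangleContext/setDeviceMangleContext), which selected the DeviceLambdaManglingNumber, with the callback mechanism: a single kernel mangle callback now returns an optional lambda mangling number that overrides the default one.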

Hopefully this is what you had in mind, @rjmccall.


CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D103112/new/

https://reviews.llvm.org/D103112

Files:
  clang/include/clang/AST/ASTContext.h
  clang/include/clang/AST/Mangle.h
  clang/lib/AST/ASTContext.cpp
  clang/lib/AST/Expr.cpp
  clang/lib/AST/ItaniumMangle.cpp
  clang/lib/CodeGen/CGCUDANV.cpp
  clang/lib/Sema/SemaSYCL.cpp

Index: clang/lib/Sema/SemaSYCL.cpp
===================================================================
--- clang/lib/Sema/SemaSYCL.cpp
+++ clang/lib/Sema/SemaSYCL.cpp
@@ -65,16 +65,17 @@
 }
 
 void Sema::AddSYCLKernelLambda(const FunctionDecl *FD) {
-  auto ShouldMangleCallback = [](ASTContext &Ctx, const CXXRecordDecl *RD) {
-    // We ALWAYS want to descend into the lambda mangling for these.
-    return true;
+  auto MangleCallback =
+      [](ASTContext &Ctx, const CXXRecordDecl *RD) -> llvm::Optional<unsigned> {
+    Ctx.AddSYCLKernelNamingDecl(RD);
+    // Always take the lambda mangling path (never the unnamed-struct one) by
+    // returning a value; 1 is the default number, which the mangler omits.
+    return 1;
   };
-  auto MangleCallback = [](ASTContext &Ctx, const CXXRecordDecl *RD,
-                           raw_ostream &) { Ctx.AddSYCLKernelNamingDecl(RD); };
 
   QualType Ty = GetSYCLKernelObjectType(FD);
   std::unique_ptr<MangleContext> Ctx{ItaniumMangleContext::create(
-      Context, Context.getDiagnostics(), ShouldMangleCallback, MangleCallback)};
+      Context, Context.getDiagnostics(), MangleCallback)};
   llvm::raw_null_ostream Out;
   Ctx->mangleTypeName(Ty, Out);
 }
Index: clang/lib/CodeGen/CGCUDANV.cpp
===================================================================
--- clang/lib/CodeGen/CGCUDANV.cpp
+++ clang/lib/CodeGen/CGCUDANV.cpp
@@ -191,12 +191,27 @@
   return ((Twine("__cuda") + Twine(FuncName)).str());
 }
 
+static std::unique_ptr<MangleContext> InitDeviceMC(CodeGenModule &CGM) {
+  // If the host and device have different C++ ABIs, create a device mangle
+  // context so that lambdas are mangled with the additional device lambda
+  // mangling number instead of the regular host one.
+  if (CGM.getContext().getAuxTargetInfo() &&
+      CGM.getContext().getTargetInfo().getCXXABI().isMicrosoft() &&
+      CGM.getContext().getAuxTargetInfo()->getCXXABI().isItaniumFamily()) {
+    return std::unique_ptr<MangleContext>(
+        CGM.getContext().createDeviceMangleContext(
+            *CGM.getContext().getAuxTargetInfo()));
+  }
+
+  return std::unique_ptr<MangleContext>(CGM.getContext().createMangleContext(
+      CGM.getContext().getAuxTargetInfo()));
+}
+
 CGNVCUDARuntime::CGNVCUDARuntime(CodeGenModule &CGM)
     : CGCUDARuntime(CGM), Context(CGM.getLLVMContext()),
       TheModule(CGM.getModule()),
       RelocatableDeviceCode(CGM.getLangOpts().GPURelocatableDeviceCode),
-      DeviceMC(CGM.getContext().createMangleContext(
-          CGM.getContext().getAuxTargetInfo())) {
+      DeviceMC(InitDeviceMC(CGM)) {
   CodeGen::CodeGenTypes &Types = CGM.getTypes();
   ASTContext &Ctx = CGM.getContext();
 
@@ -207,14 +222,6 @@
   CharPtrTy = llvm::PointerType::getUnqual(Types.ConvertType(Ctx.CharTy));
   VoidPtrTy = cast<llvm::PointerType>(Types.ConvertType(Ctx.VoidPtrTy));
   VoidPtrPtrTy = VoidPtrTy->getPointerTo();
-  if (CGM.getContext().getAuxTargetInfo()) {
-    // If the host and device have different C++ ABIs, mark it as the device
-    // mangle context so that the mangling needs to retrieve the additonal
-    // device lambda mangling number instead of the regular host one.
-    DeviceMC->setDeviceMangleContext(
-        CGM.getContext().getTargetInfo().getCXXABI().isMicrosoft() &&
-        CGM.getContext().getAuxTargetInfo()->getCXXABI().isItaniumFamily());
-  }
 }
 
 llvm::FunctionCallee CGNVCUDARuntime::getSetupArgumentFn() const {
Index: clang/lib/AST/ItaniumMangle.cpp
===================================================================
--- clang/lib/AST/ItaniumMangle.cpp
+++ clang/lib/AST/ItaniumMangle.cpp
@@ -125,20 +125,15 @@
   typedef std::pair<const DeclContext*, IdentifierInfo*> DiscriminatorKeyTy;
   llvm::DenseMap<DiscriminatorKeyTy, unsigned> Discriminator;
   llvm::DenseMap<const NamedDecl*, unsigned> Uniquifier;
-  const ShouldCallKernelCallbackTy ShouldCallKernelCallback = nullptr;
   const KernelMangleCallbackTy KernelMangleCallback = nullptr;
 
-  bool IsDevCtx = false;
   bool NeedsUniqueInternalLinkageNames = false;
 
 public:
-  explicit ItaniumMangleContextImpl(
-      ASTContext &Context, DiagnosticsEngine &Diags,
-      ShouldCallKernelCallbackTy ShouldCallKernelCB,
-      KernelMangleCallbackTy KernelCB)
-      : ItaniumMangleContext(Context, Diags),
-        ShouldCallKernelCallback(ShouldCallKernelCB),
-        KernelMangleCallback(KernelCB) {}
+  explicit ItaniumMangleContextImpl(ASTContext &Context,
+                                    DiagnosticsEngine &Diags,
+                                    KernelMangleCallbackTy KernelCB)
+      : ItaniumMangleContext(Context, Diags), KernelMangleCallback(KernelCB) {}
 
   /// @name Mangler Entry Points
   /// @{
@@ -153,9 +148,6 @@
     NeedsUniqueInternalLinkageNames = true;
   }
 
-  bool isDeviceMangleContext() const override { return IsDevCtx; }
-  void setDeviceMangleContext(bool IsDev) override { IsDevCtx = IsDev; }
-
   void mangleCXXName(GlobalDecl GD, raw_ostream &) override;
   void mangleThunk(const CXXMethodDecl *MD, const ThunkInfo &Thunk,
                    raw_ostream &) override;
@@ -252,9 +244,6 @@
     return Name;
   }
 
-  ShouldCallKernelCallbackTy getShouldCallKernelCallback() const override {
-    return ShouldCallKernelCallback;
-  }
   KernelMangleCallbackTy getKernelMangleCallback() const override {
     return KernelMangleCallback;
   }
@@ -1529,7 +1518,7 @@
     //     # Parameter types or 'v' for 'void'.
     if (const CXXRecordDecl *Record = dyn_cast<CXXRecordDecl>(TD)) {
       if (Record->isLambda() && (Record->getLambdaManglingNumber() ||
-                                 Context.getShouldCallKernelCallback()(
+                                 Context.getKernelMangleCallback()(
                                      Context.getASTContext(), Record))) {
         assert(!AdditionalAbiTags &&
                "Lambda type cannot have additional abi tags");
@@ -1968,16 +1957,10 @@
   // if the host-side CXX ABI has different numbering for lambda. In such case,
   // if the mangle context is that device-side one, use the device-side lambda
   // mangling number for this lambda.
-
-  unsigned Number = Context.isDeviceMangleContext()
-                        ? Lambda->getDeviceLambdaManglingNumber()
-                        : Lambda->getLambdaManglingNumber();
-
-  if (Context.getShouldCallKernelCallback()(Context.getASTContext(), Lambda)) {
-    Context.getKernelMangleCallback()(Context.getASTContext(), Lambda, Out);
-    Out << '_';
-    return;
-  }
+  llvm::Optional<unsigned> DeviceNumber =
+      Context.getKernelMangleCallback()(Context.getASTContext(), Lambda);
+  unsigned Number = DeviceNumber.hasValue() ? *DeviceNumber
+                                            : Lambda->getLambdaManglingNumber();
 
   assert(Number > 0 && "Lambda should be mangled as an unnamed class");
   if (Number > 1)
@@ -6414,14 +6397,14 @@
 ItaniumMangleContext *ItaniumMangleContext::create(ASTContext &Context,
                                                    DiagnosticsEngine &Diags) {
   return new ItaniumMangleContextImpl(
-      Context, Diags, [](ASTContext &, const CXXRecordDecl *) { return false; },
-      [](ASTContext &, const CXXRecordDecl *, raw_ostream &) {});
+      Context, Diags,
+      [](ASTContext &, const CXXRecordDecl *) -> llvm::Optional<unsigned> {
+        return llvm::None;
+      });
 }
 
 ItaniumMangleContext *ItaniumMangleContext::create(
     ASTContext &Context, DiagnosticsEngine &Diags,
-    ShouldCallKernelCallbackTy ShouldCallKernelCallback,
     KernelMangleCallbackTy MangleCallback) {
-  return new ItaniumMangleContextImpl(Context, Diags, ShouldCallKernelCallback,
-                                      MangleCallback);
+  return new ItaniumMangleContextImpl(Context, Diags, MangleCallback);
 }
Index: clang/lib/AST/Expr.cpp
===================================================================
--- clang/lib/AST/Expr.cpp
+++ clang/lib/AST/Expr.cpp
@@ -546,12 +546,8 @@
 
 std::string SYCLUniqueStableNameExpr::ComputeName(ASTContext &Context,
                                                   QualType Ty) {
-  auto ShouldMangleCallback = [](ASTContext &Ctx, const CXXRecordDecl *RD) {
-    return Ctx.IsSYCLKernelNamingDecl(RD);
-  };
-  auto MangleCallback = [](ASTContext &Ctx, const CXXRecordDecl *RD,
-                           raw_ostream &OS) {
-    assert(Ctx.IsSYCLKernelNamingDecl(RD) && "Not a sycl kernel?");
+  auto MangleCallback =
+      [](ASTContext &Ctx, const CXXRecordDecl *RD) -> llvm::Optional<unsigned> {
     // This replaces the 'lambda number' in the mangling with a unique number
     // based on its order in the declaration.  To provide some level of visual
     // notability (actual uniqueness from normal lambdas isn't necessary, as
@@ -559,10 +555,15 @@
     // For example:
     // _ZTSZ3foovEUlvE10005_
     // Demangles to: typeinfo name for foo()::'lambda10005'()
-    OS << (10'000 + Ctx.GetSYCLKernelNamingIndex(RD));
+    // Note that the mangler subtracts 2, since with normal lambdas the lambda
+    // mangling number '0' is an anonymous struct mangle, and '1' is omitted.
+    // So 10,002 results in the first number being 10,000.
+    if (Ctx.IsSYCLKernelNamingDecl(RD))
+      return 10'002 + Ctx.GetSYCLKernelNamingIndex(RD);
+    return llvm::None;
   };
   std::unique_ptr<MangleContext> Ctx{ItaniumMangleContext::create(
-      Context, Context.getDiagnostics(), ShouldMangleCallback, MangleCallback)};
+      Context, Context.getDiagnostics(), MangleCallback)};
 
   std::string Buffer;
   Buffer.reserve(128);
Index: clang/lib/AST/ASTContext.cpp
===================================================================
--- clang/lib/AST/ASTContext.cpp
+++ clang/lib/AST/ASTContext.cpp
@@ -2458,7 +2458,7 @@
   // The preferred alignment of member pointers is that of a pointer.
   if (T->isMemberPointerType())
     return getPreferredTypeAlign(getPointerDiffType().getTypePtr());
- 
+
   if (!Target->allowsLargerPreferedTypeAlignment())
     return ABIAlign;
 
@@ -11075,6 +11075,31 @@
   llvm_unreachable("Unsupported ABI");
 }
 
+MangleContext *ASTContext::createDeviceMangleContext(const TargetInfo &T) {
+  assert(T.getCXXABI().getKind() != TargetCXXABI::Microsoft &&
+         "Device mangle context does not support Microsoft mangling.");
+  switch (T.getCXXABI().getKind()) {
+  case TargetCXXABI::AppleARM64:
+  case TargetCXXABI::Fuchsia:
+  case TargetCXXABI::GenericAArch64:
+  case TargetCXXABI::GenericItanium:
+  case TargetCXXABI::GenericARM:
+  case TargetCXXABI::GenericMIPS:
+  case TargetCXXABI::iOS:
+  case TargetCXXABI::WebAssembly:
+  case TargetCXXABI::WatchOS:
+  case TargetCXXABI::XL:
+    return ItaniumMangleContext::create(
+        *this, getDiagnostics(),
+        [](ASTContext &, const CXXRecordDecl *RD) -> llvm::Optional<unsigned> {
+          return RD->getDeviceLambdaManglingNumber();
+        });
+  case TargetCXXABI::Microsoft:
+    return MicrosoftMangleContext::create(*this, getDiagnostics());
+    }
+  llvm_unreachable("Unsupported ABI");
+}
+
 CXXABI::~CXXABI() = default;
 
 size_t ASTContext::getSideTableAllocatedMemory() const {
Index: clang/include/clang/AST/Mangle.h
===================================================================
--- clang/include/clang/AST/Mangle.h
+++ clang/include/clang/AST/Mangle.h
@@ -107,9 +107,6 @@
   virtual bool shouldMangleCXXName(const NamedDecl *D) = 0;
   virtual bool shouldMangleStringLiteral(const StringLiteral *SL) = 0;
 
-  virtual bool isDeviceMangleContext() const { return false; }
-  virtual void setDeviceMangleContext(bool) {}
-
   virtual bool isUniqueInternalLinkageDecl(const NamedDecl *ND) {
     return false;
   }
@@ -173,10 +170,8 @@
 
 class ItaniumMangleContext : public MangleContext {
 public:
-  using ShouldCallKernelCallbackTy = bool (*)(ASTContext &,
-                                              const CXXRecordDecl *);
-  using KernelMangleCallbackTy = void (*)(ASTContext &, const CXXRecordDecl *,
-                                          raw_ostream &);
+  using KernelMangleCallbackTy =
+      llvm::Optional<unsigned> (*)(ASTContext &, const CXXRecordDecl *);
   explicit ItaniumMangleContext(ASTContext &C, DiagnosticsEngine &D)
       : MangleContext(C, D, MK_Itanium) {}
 
@@ -199,21 +194,19 @@
 
   virtual void mangleDynamicStermFinalizer(const VarDecl *D, raw_ostream &) = 0;
 
-  // These have to live here, otherwise the CXXNameMangler won't have access to
-  // them.
-  virtual ShouldCallKernelCallbackTy getShouldCallKernelCallback() const = 0;
-  virtual KernelMangleCallbackTy getKernelMangleCallback() const = 0;
 
+  // This has to live here, otherwise the CXXNameMangler won't have access to
+  // it.
+  virtual KernelMangleCallbackTy getKernelMangleCallback() const = 0;
   static bool classof(const MangleContext *C) {
     return C->getKind() == MK_Itanium;
   }
 
   static ItaniumMangleContext *create(ASTContext &Context,
                                       DiagnosticsEngine &Diags);
-  static ItaniumMangleContext *
-  create(ASTContext &Context, DiagnosticsEngine &Diags,
-         ShouldCallKernelCallbackTy ShouldKernelMangleCB,
-         KernelMangleCallbackTy KernelMangleCB);
+  static ItaniumMangleContext *create(ASTContext &Context,
+                                      DiagnosticsEngine &Diags,
+                                      KernelMangleCallbackTy KernelMangleCB);
 };
 
 class MicrosoftMangleContext : public MangleContext {
Index: clang/include/clang/AST/ASTContext.h
===================================================================
--- clang/include/clang/AST/ASTContext.h
+++ clang/include/clang/AST/ASTContext.h
@@ -2355,6 +2355,12 @@
   /// If \p T is null pointer, assume the target in ASTContext.
   MangleContext *createMangleContext(const TargetInfo *T = nullptr);
 
+  /// Creates a device mangle context to correctly mangle lambdas in a mixed
+  /// architecture compile by setting the lambda mangling number source to the
+  /// DeviceLambdaManglingNumber. Currently this asserts that the TargetInfo
+  /// (from the AuxTargetInfo) is an Itanium-family (non-Microsoft) target.
+  MangleContext *createDeviceMangleContext(const TargetInfo &T);
+
   void DeepCollectObjCIvars(const ObjCInterfaceDecl *OI, bool leafClass,
                             SmallVectorImpl<const ObjCIvarDecl*> &Ivars) const;
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to