This revision was automatically updated to reflect the committed changes.
Closed by commit rG243ebfba17da: [hip][cuda] Fix the extended lambda name 
mangling issue. (authored by hliao).

Changed prior to commit:
  https://reviews.llvm.org/D68818?vs=225678&id=225720#toc

Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D68818/new/

https://reviews.llvm.org/D68818

Files:
  clang/include/clang/AST/DeclCXX.h
  clang/include/clang/Sema/Sema.h
  clang/lib/AST/ASTImporter.cpp
  clang/lib/AST/Decl.cpp
  clang/lib/Sema/SemaLambda.cpp
  clang/lib/Sema/TreeTransform.h
  clang/lib/Serialization/ASTReaderDecl.cpp
  clang/lib/Serialization/ASTWriter.cpp
  clang/test/CodeGenCUDA/unnamed-types.cu

Index: clang/test/CodeGenCUDA/unnamed-types.cu
===================================================================
--- /dev/null
+++ clang/test/CodeGenCUDA/unnamed-types.cu
@@ -0,0 +1,39 @@
+// RUN: %clang_cc1 -std=c++11 -x hip -triple x86_64-linux-gnu -aux-triple amdgcn-amd-amdhsa -emit-llvm %s -o - | FileCheck %s --check-prefix=HOST
+// RUN: %clang_cc1 -std=c++11 -x hip -triple amdgcn-amd-amdhsa -fcuda-is-device -emit-llvm %s -o - | FileCheck %s --check-prefix=DEVICE
+
+#include "Inputs/cuda.h"
+
+// HOST: @0 = private unnamed_addr constant [43 x i8] c"_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_\00", align 1
+
+__device__ float d0(float x) {
+  return [](float x) { return x + 2.f; }(x);
+}
+
+__device__ float d1(float x) {
+  return [](float x) { return x * 2.f; }(x);
+}
+
+// DEVICE: amdgpu_kernel void @_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_(
+template <typename F>
+__global__ void k0(float *p, F f) {
+  p[0] = f(p[0]) + d0(p[1]) + d1(p[2]);
+}
+
+void f0(float *p) {
+  [](float *p) {
+    *p = 1.f;
+  }(p);
+}
+
+// The inner/outer lambdas must be mangled following the ODR, but they must
+// still keep their original `internal` linkage.
+
+// HOST: define internal void @_ZZ2f1PfENKUlS_E_clES_(
+// DEVICE: define internal float @_ZZZ2f1PfENKUlS_E_clES_ENKUlfE_clEf(
+void f1(float *p) {
+  [](float *p) {
+    k0<<<1,1>>>(p, [] __device__ (float x) { return x + 1.f; });
+  }(p);
+}
+// HOST: @__hip_register_globals
+// HOST: __hipRegisterFunction{{.*}}@_Z2k0IZZ2f1PfENKUlS0_E_clES0_EUlfE_EvS0_T_{{.*}}@0
Index: clang/lib/Serialization/ASTWriter.cpp
===================================================================
--- clang/lib/Serialization/ASTWriter.cpp
+++ clang/lib/Serialization/ASTWriter.cpp
@@ -6224,6 +6224,7 @@
     Record->push_back(Lambda.CaptureDefault);
     Record->push_back(Lambda.NumCaptures);
     Record->push_back(Lambda.NumExplicitCaptures);
+    Record->push_back(Lambda.HasKnownInternalLinkage);
     Record->push_back(Lambda.ManglingNumber);
     AddDeclRef(D->getLambdaContextDecl());
     AddTypeSourceInfo(Lambda.MethodTyInfo);
Index: clang/lib/Serialization/ASTReaderDecl.cpp
===================================================================
--- clang/lib/Serialization/ASTReaderDecl.cpp
+++ clang/lib/Serialization/ASTReaderDecl.cpp
@@ -1690,6 +1690,7 @@
     Lambda.CaptureDefault = Record.readInt();
     Lambda.NumCaptures = Record.readInt();
     Lambda.NumExplicitCaptures = Record.readInt();
+    Lambda.HasKnownInternalLinkage = Record.readInt();
     Lambda.ManglingNumber = Record.readInt();
     Lambda.ContextDecl = ReadDeclID();
     Lambda.Captures = (Capture *)Reader.getContext().Allocate(
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -11497,17 +11497,18 @@
                                         E->getCaptureDefault());
   getDerived().transformedLocalDecl(OldClass, {Class});
 
-  Optional<std::pair<unsigned, Decl*>> Mangling;
+  Optional<std::tuple<unsigned, bool, Decl *>> Mangling;
   if (getDerived().ReplacingOriginal())
-    Mangling = std::make_pair(OldClass->getLambdaManglingNumber(),
-                              OldClass->getLambdaContextDecl());
+    Mangling = std::make_tuple(OldClass->getLambdaManglingNumber(),
+                               OldClass->hasKnownLambdaInternalLinkage(),
+                               OldClass->getLambdaContextDecl());
 
   // Build the call operator.
   CXXMethodDecl *NewCallOperator = getSema().startLambdaDefinition(
       Class, E->getIntroducerRange(), NewCallOpTSI,
       E->getCallOperator()->getEndLoc(),
       NewCallOpTSI->getTypeLoc().castAs<FunctionProtoTypeLoc>().getParams(),
-      E->getCallOperator()->getConstexprKind(), Mangling);
+      E->getCallOperator()->getConstexprKind());
 
   LSI->CallOperator = NewCallOperator;
 
@@ -11527,6 +11528,9 @@
   getDerived().transformAttrs(E->getCallOperator(), NewCallOperator);
   getDerived().transformedLocalDecl(E->getCallOperator(), {NewCallOperator});
 
+  // Number the lambda for linkage purposes if necessary.
+  getSema().handleLambdaNumbering(Class, NewCallOperator, Mangling);
+
   // Introduce the context of the call operator.
   Sema::ContextRAII SavedContext(getSema(), NewCallOperator,
                                  /*NewThisContext*/false);
Index: clang/lib/Sema/SemaLambda.cpp
===================================================================
--- clang/lib/Sema/SemaLambda.cpp
+++ clang/lib/Sema/SemaLambda.cpp
@@ -335,7 +335,7 @@
   case StaticDataMember:
     //  -- the initializers of nonspecialized static members of template classes
     if (!IsInNonspecializedTemplate)
-      return std::make_tuple(nullptr, nullptr);
+      return std::make_tuple(nullptr, ManglingContextDecl);
     // Fall through to get the current context.
     LLVM_FALLTHROUGH;
 
@@ -356,14 +356,15 @@
   llvm_unreachable("unexpected context");
 }
 
-CXXMethodDecl *Sema::startLambdaDefinition(
-    CXXRecordDecl *Class, SourceRange IntroducerRange,
-    TypeSourceInfo *MethodTypeInfo, SourceLocation EndLoc,
-    ArrayRef<ParmVarDecl *> Params, ConstexprSpecKind ConstexprKind,
-    Optional<std::pair<unsigned, Decl *>> Mangling) {
+CXXMethodDecl *Sema::startLambdaDefinition(CXXRecordDecl *Class,
+                                           SourceRange IntroducerRange,
+                                           TypeSourceInfo *MethodTypeInfo,
+                                           SourceLocation EndLoc,
+                                           ArrayRef<ParmVarDecl *> Params,
+                                           ConstexprSpecKind ConstexprKind) {
   QualType MethodType = MethodTypeInfo->getType();
   TemplateParameterList *TemplateParams =
-            getGenericLambdaTemplateParameterList(getCurLambda(), *this);
+      getGenericLambdaTemplateParameterList(getCurLambda(), *this);
   // If a lambda appears in a dependent context or is a generic lambda (has
   // template parameters) and has an 'auto' return type, deduce it to a
   // dependent type.
@@ -425,20 +426,55 @@
       P->setOwningFunction(Method);
   }
 
+  return Method;
+}
+
+void Sema::handleLambdaNumbering(
+    CXXRecordDecl *Class, CXXMethodDecl *Method,
+    Optional<std::tuple<unsigned, bool, Decl *>> Mangling) {
   if (Mangling) {
-    Class->setLambdaMangling(Mangling->first, Mangling->second);
-  } else {
-    MangleNumberingContext *MCtx;
+    unsigned ManglingNumber;
+    bool HasKnownInternalLinkage;
     Decl *ManglingContextDecl;
-    std::tie(MCtx, ManglingContextDecl) =
-        getCurrentMangleNumberContext(Class->getDeclContext());
-    if (MCtx) {
-      unsigned ManglingNumber = MCtx->getManglingNumber(Method);
-      Class->setLambdaMangling(ManglingNumber, ManglingContextDecl);
-    }
+    std::tie(ManglingNumber, HasKnownInternalLinkage, ManglingContextDecl) =
+        Mangling.getValue();
+    Class->setLambdaMangling(ManglingNumber, ManglingContextDecl,
+                             HasKnownInternalLinkage);
+    return;
   }
 
-  return Method;
+  auto getMangleNumberingContext =
+      [this](CXXRecordDecl *Class, Decl *ManglingContextDecl) -> MangleNumberingContext * {
+    // Use the extra mangling context decl's numbering context, if there is one.
+    if (ManglingContextDecl)
+      return &Context.getManglingNumberContext(
+          ASTContext::NeedExtraManglingDecl, ManglingContextDecl);
+    // Otherwise, use the lambda's decl context (skipping CapturedDecls).
+    auto DC = Class->getDeclContext();
+    while (auto *CD = dyn_cast<CapturedDecl>(DC))
+      DC = CD->getParent();
+    return &Context.getManglingNumberContext(DC);
+  };
+
+  MangleNumberingContext *MCtx;
+  Decl *ManglingContextDecl;
+  std::tie(MCtx, ManglingContextDecl) =
+      getCurrentMangleNumberContext(Class->getDeclContext());
+  bool HasKnownInternalLinkage = false;
+  if (!MCtx && getLangOpts().CUDA) {
+    // Force lambda numbering in CUDA/HIP, as lambdas must be named following
+    // the ODR: host and device compilation must agree on kernel function
+    // names, and lambdas may be part of those `__global__` function names, so
+    // they need ODR-conforming numbering.
+    MCtx = getMangleNumberingContext(Class, ManglingContextDecl);
+    assert(MCtx && "Retrieving mangle numbering context failed!");
+    HasKnownInternalLinkage = true;
+  }
+  if (MCtx) {
+    unsigned ManglingNumber = MCtx->getManglingNumber(Method);
+    Class->setLambdaMangling(ManglingNumber, ManglingContextDecl,
+                             HasKnownInternalLinkage);
+  }
 }
 
 void Sema::buildLambdaScope(LambdaScopeInfo *LSI,
@@ -951,6 +987,9 @@
   if (getLangOpts().CUDA)
     CUDASetLambdaAttrs(Method);
 
+  // Number the lambda for linkage purposes if necessary.
+  handleLambdaNumbering(Class, Method);
+
   // Introduce the function call operator as the current declaration context.
   PushDeclContext(CurScope, Method);
 
Index: clang/lib/AST/Decl.cpp
===================================================================
--- clang/lib/AST/Decl.cpp
+++ clang/lib/AST/Decl.cpp
@@ -1385,7 +1385,8 @@
     case Decl::CXXRecord: {
       const auto *Record = cast<CXXRecordDecl>(D);
       if (Record->isLambda()) {
-        if (!Record->getLambdaManglingNumber()) {
+        if (Record->hasKnownLambdaInternalLinkage() ||
+            !Record->getLambdaManglingNumber()) {
           // This lambda has no mangling number, so it's internal.
           return getInternalLinkageFor(D);
         }
@@ -1402,7 +1403,8 @@
         //  };
         const CXXRecordDecl *OuterMostLambda =
             getOutermostEnclosingLambda(Record);
-        if (!OuterMostLambda->getLambdaManglingNumber())
+        if (OuterMostLambda->hasKnownLambdaInternalLinkage() ||
+            !OuterMostLambda->getLambdaManglingNumber())
           return getInternalLinkageFor(D);
 
         return getLVForClosure(
Index: clang/lib/AST/ASTImporter.cpp
===================================================================
--- clang/lib/AST/ASTImporter.cpp
+++ clang/lib/AST/ASTImporter.cpp
@@ -2694,7 +2694,8 @@
       ExpectedDecl CDeclOrErr = import(DCXX->getLambdaContextDecl());
       if (!CDeclOrErr)
         return CDeclOrErr.takeError();
-      D2CXX->setLambdaMangling(DCXX->getLambdaManglingNumber(), *CDeclOrErr);
+      D2CXX->setLambdaMangling(DCXX->getLambdaManglingNumber(), *CDeclOrErr,
+                               DCXX->hasKnownLambdaInternalLinkage());
     } else if (DCXX->isInjectedClassName()) {
       // We have to be careful to do a similar dance to the one in
       // Sema::ActOnStartCXXMemberDeclarations
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -5926,12 +5926,17 @@
                                          LambdaCaptureDefault CaptureDefault);
 
   /// Start the definition of a lambda expression.
-  CXXMethodDecl *
-  startLambdaDefinition(CXXRecordDecl *Class, SourceRange IntroducerRange,
-                        TypeSourceInfo *MethodType, SourceLocation EndLoc,
-                        ArrayRef<ParmVarDecl *> Params,
-                        ConstexprSpecKind ConstexprKind,
-                        Optional<std::pair<unsigned, Decl *>> Mangling = None);
+  CXXMethodDecl *startLambdaDefinition(CXXRecordDecl *Class,
+                                       SourceRange IntroducerRange,
+                                       TypeSourceInfo *MethodType,
+                                       SourceLocation EndLoc,
+                                       ArrayRef<ParmVarDecl *> Params,
+                                       ConstexprSpecKind ConstexprKind);
+
+  /// Number the lambda for linkage purposes if necessary.
+  void handleLambdaNumbering(
+      CXXRecordDecl *Class, CXXMethodDecl *Method,
+      Optional<std::tuple<unsigned, bool, Decl *>> Mangling = None);
 
   /// Endow the lambda scope info with the relevant properties.
   void buildLambdaScope(sema::LambdaScopeInfo *LSI,
Index: clang/include/clang/AST/DeclCXX.h
===================================================================
--- clang/include/clang/AST/DeclCXX.h
+++ clang/include/clang/AST/DeclCXX.h
@@ -389,9 +389,12 @@
     /// The number of explicit captures in this lambda.
     unsigned NumExplicitCaptures : 13;
 
+    /// Whether this lambda is known to have `internal` linkage.
+    unsigned HasKnownInternalLinkage : 1;
+
     /// The number used to indicate this lambda expression for name
     /// mangling in the Itanium C++ ABI.
-    unsigned ManglingNumber = 0;
+    unsigned ManglingNumber : 31;
 
     /// The declaration that provides context for this lambda, if the
     /// actual DeclContext does not suffice. This is used for lambdas that
@@ -406,12 +409,12 @@
     /// The type of the call method.
     TypeSourceInfo *MethodTyInfo;
 
-    LambdaDefinitionData(CXXRecordDecl *D, TypeSourceInfo *Info,
-                         bool Dependent, bool IsGeneric,
-                         LambdaCaptureDefault CaptureDefault)
-      : DefinitionData(D), Dependent(Dependent), IsGenericLambda(IsGeneric),
-        CaptureDefault(CaptureDefault), NumCaptures(0), NumExplicitCaptures(0),
-        MethodTyInfo(Info) {
+    LambdaDefinitionData(CXXRecordDecl *D, TypeSourceInfo *Info, bool Dependent,
+                         bool IsGeneric, LambdaCaptureDefault CaptureDefault)
+        : DefinitionData(D), Dependent(Dependent), IsGenericLambda(IsGeneric),
+          CaptureDefault(CaptureDefault), NumCaptures(0),
+          NumExplicitCaptures(0), HasKnownInternalLinkage(0), ManglingNumber(0),
+          MethodTyInfo(Info) {
       IsLambda = true;
 
       // C++1z [expr.prim.lambda]p4:
@@ -1705,6 +1708,13 @@
     return getLambdaData().ManglingNumber;
   }
 
+  /// The lambda is known to have internal linkage, regardless of whether it
+  /// has a name mangling number.
+  bool hasKnownLambdaInternalLinkage() const {
+    assert(isLambda() && "Not a lambda closure type!");
+    return getLambdaData().HasKnownInternalLinkage;
+  }
+
   /// Retrieve the declaration that provides additional context for a
   /// lambda, when the normal declaration context is not specific enough.
   ///
@@ -1718,9 +1728,12 @@
 
   /// Set the mangling number and context declaration for a lambda
   /// class.
-  void setLambdaMangling(unsigned ManglingNumber, Decl *ContextDecl) {
+  void setLambdaMangling(unsigned ManglingNumber, Decl *ContextDecl,
+                         bool HasKnownInternalLinkage = false) {
+    assert(isLambda() && "Not a lambda closure type!");
     getLambdaData().ManglingNumber = ManglingNumber;
     getLambdaData().ContextDecl = ContextDecl;
+    getLambdaData().HasKnownInternalLinkage = HasKnownInternalLinkage;
   }
 
   /// Returns the inheritance model used for this record.
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to