Author: Abhinav Gaba
Date: 2025-07-07T23:47:02Z
New Revision: 02f60fda3cb28f14681f8a4252bc832392c91fef

URL: https://github.com/llvm/llvm-project/commit/02f60fda3cb28f14681f8a4252bc832392c91fef
DIFF: https://github.com/llvm/llvm-project/commit/02f60fda3cb28f14681f8a4252bc832392c91fef.diff

LOG: [NFC][Clang][OpenMP] Refactor mapinfo generation for captured vars (#146891)

The refactored code would allow creating multiple member-of maps for the
same captured var, which would be useful for changes like
https://github.com/llvm/llvm-project/pull/145454.

Added: 
    

Modified: 
    clang/lib/CodeGen/CGOpenMPRuntime.cpp

Removed: 
    


################################################################################
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 8ccc37ef98a74..a5f2f0efa2c3b 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -6801,6 +6801,11 @@ class MappableExprsHandler {
       llvm::OpenMPIRBuilder::MapNonContiguousArrayTy;
   using MapExprsArrayTy = SmallVector<MappingExprInfo, 4>;
   using MapValueDeclsArrayTy = SmallVector<const ValueDecl *, 4>;
+  using MapData =
+      std::tuple<OMPClauseMappableExprCommon::MappableExprComponentListRef,
+                 OpenMPMapClauseKind, ArrayRef<OpenMPMapModifierKind>,
+                 bool /*IsImplicit*/, const ValueDecl *, const Expr *>;
+  using MapDataArrayTy = SmallVector<MapData, 4>;
 
   /// This structure contains combined information generated for mappable
   /// clauses, including base pointers, pointers, sizes, map types, user-defined
@@ -8496,6 +8501,7 @@ class MappableExprsHandler {
                          const StructRangeInfoTy &PartialStruct, bool IsMapThis,
                          llvm::OpenMPIRBuilder &OMPBuilder,
                          const ValueDecl *VD = nullptr,
+                         unsigned OffsetForMemberOfFlag = 0,
                          bool NotTargetParams = true) const {
     if (CurTypes.size() == 1 &&
         ((CurTypes.back() & OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
@@ -8583,8 +8589,8 @@ class MappableExprsHandler {
     // All other current entries will be MEMBER_OF the combined entry
     // (except for PTR_AND_OBJ entries which do not have a placeholder value
     // 0xFFFF in the MEMBER_OF field).
-    OpenMPOffloadMappingFlags MemberOfFlag =
-        OMPBuilder.getMemberOfFlag(CombinedInfo.BasePointers.size() - 1);
+    OpenMPOffloadMappingFlags MemberOfFlag = OMPBuilder.getMemberOfFlag(
+        OffsetForMemberOfFlag + CombinedInfo.BasePointers.size() - 1);
     for (auto &M : CurTypes)
       OMPBuilder.setCorrectMemberOfFlag(M, MemberOfFlag);
   }
@@ -8727,11 +8733,13 @@ class MappableExprsHandler {
     }
   }
 
-  /// Generate the base pointers, section pointers, sizes, map types, and
-  /// mappers associated to a given capture (all included in \a CombinedInfo).
-  void generateInfoForCapture(const CapturedStmt::Capture *Cap,
-                              llvm::Value *Arg, MapCombinedInfoTy &CombinedInfo,
-                              StructRangeInfoTy &PartialStruct) const {
+  /// For a capture that has an associated clause, generate the base pointers,
+  /// section pointers, sizes, map types, and mappers (all included in
+  /// \a CurCaptureVarInfo).
+  void generateInfoForCaptureFromClauseInfo(
+      const CapturedStmt::Capture *Cap, llvm::Value *Arg,
+      MapCombinedInfoTy &CurCaptureVarInfo, llvm::OpenMPIRBuilder &OMPBuilder,
+      unsigned OffsetForMemberOfFlag) const {
     assert(!Cap->capturesVariableArrayType() &&
            "Not expecting to generate map info for a variable array type!");
 
@@ -8749,26 +8757,22 @@ class MappableExprsHandler {
     // pass the pointer by value. If it is a reference to a declaration, we just
     // pass its value.
     if (VD && (DevPointersMap.count(VD) || HasDevAddrsMap.count(VD))) {
-      CombinedInfo.Exprs.push_back(VD);
-      CombinedInfo.BasePointers.emplace_back(Arg);
-      CombinedInfo.DevicePtrDecls.emplace_back(VD);
-      CombinedInfo.DevicePointers.emplace_back(DeviceInfoTy::Pointer);
-      CombinedInfo.Pointers.push_back(Arg);
-      CombinedInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
+      CurCaptureVarInfo.Exprs.push_back(VD);
+      CurCaptureVarInfo.BasePointers.emplace_back(Arg);
+      CurCaptureVarInfo.DevicePtrDecls.emplace_back(VD);
+      CurCaptureVarInfo.DevicePointers.emplace_back(DeviceInfoTy::Pointer);
+      CurCaptureVarInfo.Pointers.push_back(Arg);
+      CurCaptureVarInfo.Sizes.push_back(CGF.Builder.CreateIntCast(
           CGF.getTypeSize(CGF.getContext().VoidPtrTy), CGF.Int64Ty,
           /*isSigned=*/true));
-      CombinedInfo.Types.push_back(
+      CurCaptureVarInfo.Types.push_back(
           OpenMPOffloadMappingFlags::OMP_MAP_LITERAL |
           OpenMPOffloadMappingFlags::OMP_MAP_TARGET_PARAM);
-      CombinedInfo.Mappers.push_back(nullptr);
+      CurCaptureVarInfo.Mappers.push_back(nullptr);
       return;
     }
 
-    using MapData =
-        std::tuple<OMPClauseMappableExprCommon::MappableExprComponentListRef,
-                   OpenMPMapClauseKind, ArrayRef<OpenMPMapModifierKind>, bool,
-                   const ValueDecl *, const Expr *>;
-    SmallVector<MapData, 4> DeclComponentLists;
+    MapDataArrayTy DeclComponentLists;
     // For member fields list in is_device_ptr, store it in
     // DeclComponentLists for generating components info.
     static const OpenMPMapModifierKind Unknown = OMPC_MAP_MODIFIER_unknown;
@@ -8826,6 +8830,51 @@ class MappableExprsHandler {
       return (HasPresent && !HasPresentR) || (HasAllocs && !HasAllocsR);
     });
 
+    auto GenerateInfoForComponentLists =
+        [&](ArrayRef<MapData> DeclComponentLists,
+            bool IsEligibleForTargetParamFlag) {
+          MapCombinedInfoTy CurInfoForComponentLists;
+          StructRangeInfoTy PartialStruct;
+
+          if (DeclComponentLists.empty())
+            return;
+
+          generateInfoForCaptureFromComponentLists(
+              VD, DeclComponentLists, CurInfoForComponentLists, PartialStruct,
+              IsEligibleForTargetParamFlag,
+              /*AreBothBasePtrAndPteeMapped=*/HasMapBasePtr && HasMapArraySec);
+
+          // If there is an entry in PartialStruct it means we have a
+          // struct with individual members mapped. Emit an extra combined
+          // entry.
+          if (PartialStruct.Base.isValid()) {
+            CurCaptureVarInfo.append(PartialStruct.PreliminaryMapData);
+            emitCombinedEntry(
+                CurCaptureVarInfo, CurInfoForComponentLists.Types,
+                PartialStruct, Cap->capturesThis(), OMPBuilder, nullptr,
+                OffsetForMemberOfFlag,
+                /*NotTargetParams*/ !IsEligibleForTargetParamFlag);
+          }
+
+          // Return if we didn't add any entries.
+          if (CurInfoForComponentLists.BasePointers.empty())
+            return;
+
+          CurCaptureVarInfo.append(CurInfoForComponentLists);
+        };
+
+    GenerateInfoForComponentLists(DeclComponentLists,
+                                  /*IsEligibleForTargetParamFlag=*/true);
+  }
+
+  /// Generate the base pointers, section pointers, sizes, map types, and
+  /// mappers associated to \a DeclComponentLists for a given capture
+  /// \a VD (all included in \a CurComponentListInfo).
+  void generateInfoForCaptureFromComponentLists(
+      const ValueDecl *VD, ArrayRef<MapData> DeclComponentLists,
+      MapCombinedInfoTy &CurComponentListInfo, StructRangeInfoTy &PartialStruct,
+      bool IsListEligibleForTargetParamFlag,
+      bool AreBothBasePtrAndPteeMapped = false) const {
     // Find overlapping elements (including the offset from the base element).
     llvm::SmallDenseMap<
         const MapData *,
@@ -8949,7 +8998,7 @@ class MappableExprsHandler {
 
     // Associated with a capture, because the mapping flags depend on it.
     // Go through all of the elements with the overlapped elements.
-    bool IsFirstComponentList = true;
+    bool AddTargetParamFlag = IsListEligibleForTargetParamFlag;
     MapCombinedInfoTy StructBaseCombinedInfo;
     for (const auto &Pair : OverlappedData) {
       const MapData &L = *Pair.getFirst();
@@ -8964,11 +9013,11 @@ class MappableExprsHandler {
       ArrayRef<OMPClauseMappableExprCommon::MappableExprComponentListRef>
           OverlappedComponents = Pair.getSecond();
       generateInfoForComponentList(
-          MapType, MapModifiers, {}, Components, CombinedInfo,
-          StructBaseCombinedInfo, PartialStruct, IsFirstComponentList,
-          IsImplicit, /*GenerateAllInfoForClauses*/ false, Mapper,
+          MapType, MapModifiers, {}, Components, CurComponentListInfo,
+          StructBaseCombinedInfo, PartialStruct, AddTargetParamFlag, IsImplicit,
+          /*GenerateAllInfoForClauses*/ false, Mapper,
           /*ForDeviceAddr=*/false, VD, VarRef, OverlappedComponents);
-      IsFirstComponentList = false;
+      AddTargetParamFlag = false;
     }
     // Go through other elements without overlapped elements.
     for (const MapData &L : DeclComponentLists) {
@@ -8983,12 +9032,12 @@ class MappableExprsHandler {
       auto It = OverlappedData.find(&L);
       if (It == OverlappedData.end())
         generateInfoForComponentList(
-            MapType, MapModifiers, {}, Components, CombinedInfo,
-            StructBaseCombinedInfo, PartialStruct, IsFirstComponentList,
+            MapType, MapModifiers, {}, Components, CurComponentListInfo,
+            StructBaseCombinedInfo, PartialStruct, AddTargetParamFlag,
             IsImplicit, /*GenerateAllInfoForClauses*/ false, Mapper,
             /*ForDeviceAddr=*/false, VD, VarRef,
-            /*OverlappedElements*/ {}, HasMapBasePtr && HasMapArraySec);
-      IsFirstComponentList = false;
+            /*OverlappedElements*/ {}, AreBothBasePtrAndPteeMapped);
+      AddTargetParamFlag = false;
     }
   }
 
@@ -9467,7 +9516,6 @@ static void genMapInfoForCaptures(
                                             CE = CS.capture_end();
        CI != CE; ++CI, ++RI, ++CV) {
     MappableExprsHandler::MapCombinedInfoTy CurInfo;
-    MappableExprsHandler::StructRangeInfoTy PartialStruct;
 
     // VLA sizes are passed to the outlined region by copy and do not have map
     // information associated.
@@ -9488,13 +9536,18 @@ static void genMapInfoForCaptures(
     } else {
       // If we have any information in the map clause, we use it, otherwise we
       // just do a default mapping.
-      MEHandler.generateInfoForCapture(CI, *CV, CurInfo, PartialStruct);
+      MEHandler.generateInfoForCaptureFromClauseInfo(
+          CI, *CV, CurInfo, OMPBuilder,
+          /*OffsetForMemberOfFlag=*/CombinedInfo.BasePointers.size());
+
       if (!CI->capturesThis())
         MappedVarSet.insert(CI->getCapturedVar());
       else
         MappedVarSet.insert(nullptr);
-      if (CurInfo.BasePointers.empty() && !PartialStruct.Base.isValid())
+
+      if (CurInfo.BasePointers.empty())
         MEHandler.generateDefaultMapInfo(*CI, **RI, *CV, CurInfo);
+
       // Generate correct mapping for variables captured by reference in
       // lambdas.
       if (CI->capturesVariable())
@@ -9502,7 +9555,7 @@ static void genMapInfoForCaptures(
                                                 CurInfo, LambdaPointers);
     }
     // We expect to have at least an element of information for this capture.
-    assert((!CurInfo.BasePointers.empty() || PartialStruct.Base.isValid()) &&
+    assert(!CurInfo.BasePointers.empty() &&
            "Non-existing map pointer for capture!");
     assert(CurInfo.BasePointers.size() == CurInfo.Pointers.size() &&
            CurInfo.BasePointers.size() == CurInfo.Sizes.size() &&
@@ -9510,15 +9563,6 @@ static void genMapInfoForCaptures(
            CurInfo.BasePointers.size() == CurInfo.Mappers.size() &&
            "Inconsistent map information sizes!");
 
-    // If there is an entry in PartialStruct it means we have a struct with
-    // individual members mapped. Emit an extra combined entry.
-    if (PartialStruct.Base.isValid()) {
-      CombinedInfo.append(PartialStruct.PreliminaryMapData);
-      MEHandler.emitCombinedEntry(CombinedInfo, CurInfo.Types, PartialStruct,
-                                  CI->capturesThis(), OMPBuilder, nullptr,
-                                  /*NotTargetParams*/ false);
-    }
-
     // We need to append the results of this capture to what we already have.
     CombinedInfo.append(CurInfo);
   }


        
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to