https://github.com/rofirrim updated 
https://github.com/llvm/llvm-project/pull/155849

From 365c3c5d18b0c466fee85075d6f6bc3c63267fef Mon Sep 17 00:00:00 2001
From: Roger Ferrer Ibanez <roger.fer...@bsc.es>
Date: Wed, 27 Aug 2025 08:18:02 +0000
Subject: [PATCH] [Clang][OpenMP] Add an additional class to hold data that
 will be shared between all loop transformations

This class is not a statement by itself and its goal is to avoid having
to replicate information between classes.

Also simplify the way we handle the "generated loops" information as we
currently only need to know if it is zero or non-zero.
---
 clang/include/clang/AST/StmtOpenMP.h      | 55 +++++++++++++----------
 clang/lib/AST/StmtOpenMP.cpp              | 21 ++++-----
 clang/lib/Sema/SemaOpenMP.cpp             | 12 ++---
 clang/lib/Serialization/ASTReaderStmt.cpp |  2 +-
 clang/lib/Serialization/ASTWriterStmt.cpp |  2 +-
 5 files changed, 52 insertions(+), 40 deletions(-)

diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h
index a436676113921..d9f87f1e49b40 100644
--- a/clang/include/clang/AST/StmtOpenMP.h
+++ b/clang/include/clang/AST/StmtOpenMP.h
@@ -956,30 +956,46 @@ class OMPLoopBasedDirective : public OMPExecutableDirective {
   }
 };
 
+/// Common class of data shared between
+/// OMPCanonicalLoopNestTransformationDirective and transformations over
+/// canonical loop sequences.
+class OMPLoopTransformationDirective {
+  /// Number of (top-level) generated loops.
+  /// This value is 1 for most transformations as they only map one loop nest
+  /// into another.
+  /// Some loop transformations (like a non-partial 'unroll') may not generate
+  /// a loop nest, so this would be 0.
+  /// Some loop transformations (like 'fuse' with looprange and 'split') may
+  /// generate more than one loop nest, so the value would be >= 1.
+  unsigned NumGeneratedTopLevelLoops = 1;
+
+protected:
+  void setNumGeneratedTopLevelLoops(unsigned N) {
+    NumGeneratedTopLevelLoops = N;
+  }
+
+public:
+  unsigned getNumGeneratedTopLevelLoops() const {
+    return NumGeneratedTopLevelLoops;
+  }
+};
+
 /// The base class for all transformation directives of canonical loop nests.
 class OMPCanonicalLoopNestTransformationDirective
-    : public OMPLoopBasedDirective {
+    : public OMPLoopBasedDirective,
+      public OMPLoopTransformationDirective {
   friend class ASTStmtReader;
 
-  /// Number of loops generated by this loop transformation.
-  unsigned NumGeneratedLoops = 0;
-
 protected:
   explicit OMPCanonicalLoopNestTransformationDirective(
       StmtClass SC, OpenMPDirectiveKind Kind, SourceLocation StartLoc,
       SourceLocation EndLoc, unsigned NumAssociatedLoops)
       : OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {}
 
-  /// Set the number of loops generated by this loop transformation.
-  void setNumGeneratedLoops(unsigned Num) { NumGeneratedLoops = Num; }
-
 public:
   /// Return the number of associated (consumed) loops.
   unsigned getNumAssociatedLoops() const { return getLoopsNumber(); }
 
-  /// Return the number of loops generated by this loop transformation.
-  unsigned getNumGeneratedLoops() const { return NumGeneratedLoops; }
-
   /// Get the de-sugared statements after the loop transformation.
   ///
   /// Might be nullptr if either the directive generates no loops and is handled
@@ -5560,9 +5576,7 @@ class OMPTileDirective final
                             unsigned NumLoops)
       : OMPCanonicalLoopNestTransformationDirective(
             OMPTileDirectiveClass, llvm::omp::OMPD_tile, StartLoc, EndLoc,
-            NumLoops) {
-    setNumGeneratedLoops(2 * NumLoops);
-  }
+            NumLoops) {}
 
   void setPreInits(Stmt *PreInits) {
     Data->getChildren()[PreInitsOffset] = PreInits;
@@ -5638,9 +5652,7 @@ class OMPStripeDirective final
                               unsigned NumLoops)
       : OMPCanonicalLoopNestTransformationDirective(
             OMPStripeDirectiveClass, llvm::omp::OMPD_stripe, StartLoc, EndLoc,
-            NumLoops) {
-    setNumGeneratedLoops(2 * NumLoops);
-  }
+            NumLoops) {}
 
   void setPreInits(Stmt *PreInits) {
     Data->getChildren()[PreInitsOffset] = PreInits;
@@ -5744,7 +5756,8 @@ class OMPUnrollDirective final
   static OMPUnrollDirective *
   Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
          ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
-         unsigned NumGeneratedLoops, Stmt *TransformedStmt, Stmt *PreInits);
+         unsigned NumGeneratedTopLevelLoops, Stmt *TransformedStmt,
+         Stmt *PreInits);
 
   /// Build an empty '#pragma omp unroll' AST node for deserialization.
   ///
@@ -5794,9 +5807,7 @@ class OMPReverseDirective final
                                unsigned NumLoops)
       : OMPCanonicalLoopNestTransformationDirective(
             OMPReverseDirectiveClass, llvm::omp::OMPD_reverse, StartLoc, EndLoc,
-            NumLoops) {
-    setNumGeneratedLoops(NumLoops);
-  }
+            NumLoops) {}
 
   void setPreInits(Stmt *PreInits) {
     Data->getChildren()[PreInitsOffset] = PreInits;
@@ -5867,9 +5878,7 @@ class OMPInterchangeDirective final
                                    SourceLocation EndLoc, unsigned NumLoops)
       : OMPCanonicalLoopNestTransformationDirective(
             OMPInterchangeDirectiveClass, llvm::omp::OMPD_interchange, StartLoc,
-            EndLoc, NumLoops) {
-    setNumGeneratedLoops(NumLoops);
-  }
+            EndLoc, NumLoops) {}
 
   void setPreInits(Stmt *PreInits) {
     Data->getChildren()[PreInitsOffset] = PreInits;
diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp
index 36ecaf6489ef0..1f6586f95a9f8 100644
--- a/clang/lib/AST/StmtOpenMP.cpp
+++ b/clang/lib/AST/StmtOpenMP.cpp
@@ -139,13 +139,14 @@ bool OMPLoopBasedDirective::doForAllLoops(
 
       Stmt *TransformedStmt = Dir->getTransformedStmt();
       if (!TransformedStmt) {
-        unsigned NumGeneratedLoops = Dir->getNumGeneratedLoops();
-        if (NumGeneratedLoops == 0) {
+        unsigned NumGeneratedTopLevelLoops =
+            Dir->getNumGeneratedTopLevelLoops();
+        if (NumGeneratedTopLevelLoops == 0) {
           // May happen if the loop transformation does not result in a
           // generated loop (such as full unrolling).
           break;
         }
-        if (NumGeneratedLoops > 0) {
+        if (NumGeneratedTopLevelLoops > 0) {
           // The loop transformation construct has generated loops, but these
           // may not have been generated yet due to being in a dependent
           // context.
@@ -447,16 +448,16 @@ OMPStripeDirective *OMPStripeDirective::CreateEmpty(const ASTContext &C,
       SourceLocation(), SourceLocation(), NumLoops);
 }
 
-OMPUnrollDirective *
-OMPUnrollDirective::Create(const ASTContext &C, SourceLocation StartLoc,
-                           SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses,
-                           Stmt *AssociatedStmt, unsigned NumGeneratedLoops,
-                           Stmt *TransformedStmt, Stmt *PreInits) {
-  assert(NumGeneratedLoops <= 1 && "Unrolling generates at most one loop");
+OMPUnrollDirective *OMPUnrollDirective::Create(
+    const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+    ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt,
+    unsigned NumGeneratedTopLevelLoops, Stmt *TransformedStmt, Stmt *PreInits) {
+  assert(NumGeneratedTopLevelLoops <= 1 &&
+         "Unrolling generates at most one loop");
 
   auto *Dir = createDirective<OMPUnrollDirective>(
       C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc);
-  Dir->setNumGeneratedLoops(NumGeneratedLoops);
+  Dir->setNumGeneratedTopLevelLoops(NumGeneratedTopLevelLoops);
   Dir->setTransformedStmt(TransformedStmt);
   Dir->setPreInits(PreInits);
   return Dir;
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 63a56a6583efc..60f0317020c59 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -14919,12 +14919,13 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
                                   Body, OriginalInits))
     return StmtError();
 
-  unsigned NumGeneratedLoops = PartialClause ? 1 : 0;
+  unsigned NumGeneratedTopLevelLoops = PartialClause ? 1 : 0;
 
   // Delay unrolling to when template is completely instantiated.
   if (SemaRef.CurContext->isDependentContext())
     return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
-                                      NumGeneratedLoops, nullptr, nullptr);
+                                      NumGeneratedTopLevelLoops, nullptr,
+                                      nullptr);
 
   assert(LoopHelpers.size() == NumLoops &&
          "Expecting a single-dimensional loop iteration space");
@@ -14947,9 +14948,10 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
   // The generated loop may only be passed to other loop-associated directive
   // when a partial clause is specified. Without the requirement it is
   // sufficient to generate loop unroll metadata at code-generation.
-  if (NumGeneratedLoops == 0)
+  if (NumGeneratedTopLevelLoops == 0)
     return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
-                                      NumGeneratedLoops, nullptr, nullptr);
+                                      NumGeneratedTopLevelLoops, nullptr,
+                                      nullptr);
 
   // Otherwise, we need to provide a de-sugared/transformed AST that can be
   // associated with another loop directive.
@@ -15164,7 +15166,7 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses,
               LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc());
 
   return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
-                                    NumGeneratedLoops, OuterFor,
+                                    NumGeneratedTopLevelLoops, OuterFor,
                                     buildPreInits(Context, PreInits));
 }
 
diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index 7ec8e450fbaca..213c2c2148f64 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2450,7 +2450,7 @@ void ASTStmtReader::VisitOMPSimdDirective(OMPSimdDirective *D) {
 void ASTStmtReader::VisitOMPCanonicalLoopNestTransformationDirective(
     OMPCanonicalLoopNestTransformationDirective *D) {
   VisitOMPLoopBasedDirective(D);
-  D->setNumGeneratedLoops(Record.readUInt32());
+  D->setNumGeneratedTopLevelLoops(Record.readUInt32());
 }
 
 void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) {
diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index 07a5cde47a9a8..21c04ddbc2c7a 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2459,7 +2459,7 @@ void ASTStmtWriter::VisitOMPSimdDirective(OMPSimdDirective *D) {
 void ASTStmtWriter::VisitOMPCanonicalLoopNestTransformationDirective(
     OMPCanonicalLoopNestTransformationDirective *D) {
   VisitOMPLoopBasedDirective(D);
-  Record.writeUInt32(D->getNumGeneratedLoops());
+  Record.writeUInt32(D->getNumGeneratedTopLevelLoops());
 }
 
 void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) {

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to