https://github.com/rofirrim updated https://github.com/llvm/llvm-project/pull/155849
From 365c3c5d18b0c466fee85075d6f6bc3c63267fef Mon Sep 17 00:00:00 2001 From: Roger Ferrer Ibanez <roger.fer...@bsc.es> Date: Wed, 27 Aug 2025 08:18:02 +0000 Subject: [PATCH] [Clang][OpenMP] Add an additional class to hold data that will be shared between all loop transformations This class is not a statement by itself and its goal is to avoid having to replicate information between classes. Also simplify the way we handle the "generated loops" information as we currently only need to know if it is zero or non-zero. --- clang/include/clang/AST/StmtOpenMP.h | 55 +++++++++++++---------- clang/lib/AST/StmtOpenMP.cpp | 21 ++++----- clang/lib/Sema/SemaOpenMP.cpp | 12 ++--- clang/lib/Serialization/ASTReaderStmt.cpp | 2 +- clang/lib/Serialization/ASTWriterStmt.cpp | 2 +- 5 files changed, 52 insertions(+), 40 deletions(-) diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h index a436676113921..d9f87f1e49b40 100644 --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -956,30 +956,46 @@ class OMPLoopBasedDirective : public OMPExecutableDirective { } }; +/// Common class of data shared between +/// OMPCanonicalLoopNestTransformationDirective and transformations over +/// canonical loop sequences. +class OMPLoopTransformationDirective { + /// Number of (top-level) generated loops. + /// This value is 1 for most transformations as they only map one loop nest + /// into another. + /// Some loop transformations (like a non-partial 'unroll') may not generate + /// a loop nest, so this would be 0. + /// Some loop transformations (like 'fuse' with looprange and 'split') may + /// generate more than one loop nest, so the value would be >= 1. + unsigned NumGeneratedTopLevelLoops = 1; + +protected: + void setNumGeneratedTopLevelLoops(unsigned N) { + NumGeneratedTopLevelLoops = N; + } + +public: + unsigned getNumGeneratedTopLevelLoops() const { + return NumGeneratedTopLevelLoops; + } +}; + /// The base class for all transformation directives of canonical loop nests. class OMPCanonicalLoopNestTransformationDirective - : public OMPLoopBasedDirective { + : public OMPLoopBasedDirective, + public OMPLoopTransformationDirective { friend class ASTStmtReader; - /// Number of loops generated by this loop transformation. - unsigned NumGeneratedLoops = 0; - protected: explicit OMPCanonicalLoopNestTransformationDirective( StmtClass SC, OpenMPDirectiveKind Kind, SourceLocation StartLoc, SourceLocation EndLoc, unsigned NumAssociatedLoops) : OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {} - /// Set the number of loops generated by this loop transformation. - void setNumGeneratedLoops(unsigned Num) { NumGeneratedLoops = Num; } - public: /// Return the number of associated (consumed) loops. unsigned getNumAssociatedLoops() const { return getLoopsNumber(); } - /// Return the number of loops generated by this loop transformation. - unsigned getNumGeneratedLoops() const { return NumGeneratedLoops; } - /// Get the de-sugared statements after the loop transformation. /// /// Might be nullptr if either the directive generates no loops and is handled @@ -5560,9 +5576,7 @@ class OMPTileDirective final unsigned NumLoops) : OMPCanonicalLoopNestTransformationDirective( OMPTileDirectiveClass, llvm::omp::OMPD_tile, StartLoc, EndLoc, - NumLoops) { - setNumGeneratedLoops(2 * NumLoops); - } + NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; @@ -5638,9 +5652,7 @@ class OMPStripeDirective final unsigned NumLoops) : OMPCanonicalLoopNestTransformationDirective( OMPStripeDirectiveClass, llvm::omp::OMPD_stripe, StartLoc, EndLoc, - NumLoops) { - setNumGeneratedLoops(2 * NumLoops); - } + NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; @@ -5744,7 +5756,8 @@ class OMPUnrollDirective final static OMPUnrollDirective * Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, - unsigned NumGeneratedLoops, Stmt *TransformedStmt, Stmt *PreInits); + unsigned NumGeneratedTopLevelLoops, Stmt *TransformedStmt, + Stmt *PreInits); /// Build an empty '#pragma omp unroll' AST node for deserialization. /// @@ -5794,9 +5807,7 @@ class OMPReverseDirective final unsigned NumLoops) : OMPCanonicalLoopNestTransformationDirective( OMPReverseDirectiveClass, llvm::omp::OMPD_reverse, StartLoc, EndLoc, - NumLoops) { - setNumGeneratedLoops(NumLoops); - } + NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; @@ -5867,9 +5878,7 @@ class OMPInterchangeDirective final SourceLocation EndLoc, unsigned NumLoops) : OMPCanonicalLoopNestTransformationDirective( OMPInterchangeDirectiveClass, llvm::omp::OMPD_interchange, StartLoc, - EndLoc, NumLoops) { - setNumGeneratedLoops(NumLoops); - } + EndLoc, NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp index 36ecaf6489ef0..1f6586f95a9f8 100644 --- a/clang/lib/AST/StmtOpenMP.cpp +++ b/clang/lib/AST/StmtOpenMP.cpp @@ -139,13 +139,14 @@ bool OMPLoopBasedDirective::doForAllLoops( Stmt *TransformedStmt = Dir->getTransformedStmt(); if (!TransformedStmt) { - unsigned NumGeneratedLoops = Dir->getNumGeneratedLoops(); - if (NumGeneratedLoops == 0) { + unsigned NumGeneratedTopLevelLoops = + Dir->getNumGeneratedTopLevelLoops(); + if (NumGeneratedTopLevelLoops == 0) { // May happen if the loop transformation does not result in a // generated loop (such as full unrolling). break; } - if (NumGeneratedLoops > 0) { + if (NumGeneratedTopLevelLoops > 0) { // The loop transformation construct has generated loops, but these // may not have been generated yet due to being in a dependent // context. @@ -447,16 +448,16 @@ OMPStripeDirective *OMPStripeDirective::CreateEmpty(const ASTContext &C, SourceLocation(), SourceLocation(), NumLoops); } -OMPUnrollDirective * -OMPUnrollDirective::Create(const ASTContext &C, SourceLocation StartLoc, - SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses, - Stmt *AssociatedStmt, unsigned NumGeneratedLoops, - Stmt *TransformedStmt, Stmt *PreInits) { - assert(NumGeneratedLoops <= 1 && "Unrolling generates at most one loop"); +OMPUnrollDirective *OMPUnrollDirective::Create( + const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc, + ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, + unsigned NumGeneratedTopLevelLoops, Stmt *TransformedStmt, Stmt *PreInits) { + assert(NumGeneratedTopLevelLoops <= 1 && + "Unrolling generates at most one loop"); auto *Dir = createDirective<OMPUnrollDirective>( C, Clauses, AssociatedStmt, TransformedStmtOffset + 1, StartLoc, EndLoc); - Dir->setNumGeneratedLoops(NumGeneratedLoops); + Dir->setNumGeneratedTopLevelLoops(NumGeneratedTopLevelLoops); Dir->setTransformedStmt(TransformedStmt); Dir->setPreInits(PreInits); return Dir; diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index 63a56a6583efc..60f0317020c59 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -14919,12 +14919,13 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses, Body, OriginalInits)) return StmtError(); - unsigned NumGeneratedLoops = PartialClause ? 1 : 0; + unsigned NumGeneratedTopLevelLoops = PartialClause ? 1 : 0; // Delay unrolling to when template is completely instantiated. if (SemaRef.CurContext->isDependentContext()) return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, - NumGeneratedLoops, nullptr, nullptr); + NumGeneratedTopLevelLoops, nullptr, + nullptr); assert(LoopHelpers.size() == NumLoops && "Expecting a single-dimensional loop iteration space"); @@ -14947,9 +14948,10 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses, // The generated loop may only be passed to other loop-associated directive // when a partial clause is specified. Without the requirement it is // sufficient to generate loop unroll metadata at code-generation. - if (NumGeneratedLoops == 0) + if (NumGeneratedTopLevelLoops == 0) return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, - NumGeneratedLoops, nullptr, nullptr); + NumGeneratedTopLevelLoops, nullptr, + nullptr); // Otherwise, we need to provide a de-sugared/transformed AST that can be // associated with another loop directive. @@ -15164,7 +15166,7 @@ StmtResult SemaOpenMP::ActOnOpenMPUnrollDirective(ArrayRef<OMPClause *> Clauses, LoopHelper.Init->getBeginLoc(), LoopHelper.Inc->getEndLoc()); return OMPUnrollDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt, - NumGeneratedLoops, OuterFor, + NumGeneratedTopLevelLoops, OuterFor, buildPreInits(Context, PreInits)); } diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp index 7ec8e450fbaca..213c2c2148f64 100644 --- a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -2450,7 +2450,7 @@ void ASTStmtReader::VisitOMPSimdDirective(OMPSimdDirective *D) { void ASTStmtReader::VisitOMPCanonicalLoopNestTransformationDirective( OMPCanonicalLoopNestTransformationDirective *D) { VisitOMPLoopBasedDirective(D); - D->setNumGeneratedLoops(Record.readUInt32()); + D->setNumGeneratedTopLevelLoops(Record.readUInt32()); } void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) { diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp index 07a5cde47a9a8..21c04ddbc2c7a 100644 --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -2459,7 +2459,7 @@ void ASTStmtWriter::VisitOMPSimdDirective(OMPSimdDirective *D) { void ASTStmtWriter::VisitOMPCanonicalLoopNestTransformationDirective( OMPCanonicalLoopNestTransformationDirective *D) { VisitOMPLoopBasedDirective(D); - Record.writeUInt32(D->getNumGeneratedLoops()); + Record.writeUInt32(D->getNumGeneratedTopLevelLoops()); } void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) { _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits