https://github.com/rofirrim created https://github.com/llvm/llvm-project/pull/155849
This is preparatory work for the implementation of `#pragma omp fuse` in https://github.com/llvm/llvm-project/pull/139293. **Note**: this change builds on top of https://github.com/llvm/llvm-project/pull/155848. This change adds an additional class to hold data that will be shared between all loop transformations: those that apply to canonical loop nests (the majority) and those that apply to canonical loop sequences (`fuse` in OpenMP 6.0). This class is not a statement by itself and its goal is to avoid having to replicate information between classes. Also simplify the way we handle the "generated loops" information as we currently only need to know if it is zero or non-zero. From c0a9364dde8eb90d10a8971d2f4598b96c05ac76 Mon Sep 17 00:00:00 2001 From: Roger Ferrer Ibanez <roger.fer...@bsc.es> Date: Tue, 26 Aug 2025 13:18:19 +0000 Subject: [PATCH 1/2] [Clang][NFC] Rename OMPLoopTransformationDirective to OMPCanonicalLoopNestTransformationDirective Not all loop transformations make sense as an OMPLoopBasedDirective; in particular, the OpenMP 6.0 'fuse' construct (to be implemented later) is a transformation of a canonical loop sequence, not of a canonical loop nest. 
--- clang/include/clang/AST/StmtOpenMP.h | 77 ++++++++++++----------- clang/include/clang/Basic/OpenMPKinds.h | 7 +++ clang/include/clang/Basic/StmtNodes.td | 14 +++-- clang/lib/AST/StmtOpenMP.cpp | 13 ++-- clang/lib/AST/StmtProfile.cpp | 14 ++--- clang/lib/Basic/OpenMPKinds.cpp | 8 ++- clang/lib/CodeGen/CGStmtOpenMP.cpp | 3 +- clang/lib/Sema/SemaOpenMP.cpp | 6 +- clang/lib/Serialization/ASTReaderStmt.cpp | 14 ++--- clang/lib/Serialization/ASTWriterStmt.cpp | 14 ++--- clang/tools/libclang/CIndex.cpp | 18 +++--- 11 files changed, 106 insertions(+), 82 deletions(-) diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h index 2fb33d3036bca..a436676113921 100644 --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -889,23 +889,24 @@ class OMPLoopBasedDirective : public OMPExecutableDirective { /// Calls the specified callback function for all the loops in \p CurStmt, /// from the outermost to the innermost. - static bool - doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops, - unsigned NumLoops, - llvm::function_ref<bool(unsigned, Stmt *)> Callback, - llvm::function_ref<void(OMPLoopTransformationDirective *)> - OnTransformationCallback); + static bool doForAllLoops( + Stmt *CurStmt, bool TryImperfectlyNestedLoops, unsigned NumLoops, + llvm::function_ref<bool(unsigned, Stmt *)> Callback, + llvm::function_ref<void(OMPCanonicalLoopNestTransformationDirective *)> + OnTransformationCallback); static bool doForAllLoops(const Stmt *CurStmt, bool TryImperfectlyNestedLoops, unsigned NumLoops, llvm::function_ref<bool(unsigned, const Stmt *)> Callback, - llvm::function_ref<void(const OMPLoopTransformationDirective *)> + llvm::function_ref< + void(const OMPCanonicalLoopNestTransformationDirective *)> OnTransformationCallback) { auto &&NewCallback = [Callback](unsigned Cnt, Stmt *CurStmt) { return Callback(Cnt, CurStmt); }; auto &&NewTransformCb = - 
[OnTransformationCallback](OMPLoopTransformationDirective *A) { + [OnTransformationCallback]( + OMPCanonicalLoopNestTransformationDirective *A) { OnTransformationCallback(A); }; return doForAllLoops(const_cast<Stmt *>(CurStmt), TryImperfectlyNestedLoops, @@ -918,7 +919,7 @@ class OMPLoopBasedDirective : public OMPExecutableDirective { doForAllLoops(Stmt *CurStmt, bool TryImperfectlyNestedLoops, unsigned NumLoops, llvm::function_ref<bool(unsigned, Stmt *)> Callback) { - auto &&TransformCb = [](OMPLoopTransformationDirective *) {}; + auto &&TransformCb = [](OMPCanonicalLoopNestTransformationDirective *) {}; return doForAllLoops(CurStmt, TryImperfectlyNestedLoops, NumLoops, Callback, TransformCb); } @@ -955,19 +956,18 @@ class OMPLoopBasedDirective : public OMPExecutableDirective { } }; -/// The base class for all loop transformation directives. -class OMPLoopTransformationDirective : public OMPLoopBasedDirective { +/// The base class for all transformation directives of canonical loop nests. +class OMPCanonicalLoopNestTransformationDirective + : public OMPLoopBasedDirective { friend class ASTStmtReader; /// Number of loops generated by this loop transformation. unsigned NumGeneratedLoops = 0; protected: - explicit OMPLoopTransformationDirective(StmtClass SC, - OpenMPDirectiveKind Kind, - SourceLocation StartLoc, - SourceLocation EndLoc, - unsigned NumAssociatedLoops) + explicit OMPCanonicalLoopNestTransformationDirective( + StmtClass SC, OpenMPDirectiveKind Kind, SourceLocation StartLoc, + SourceLocation EndLoc, unsigned NumAssociatedLoops) : OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {} /// Set the number of loops generated by this loop transformation. @@ -5545,7 +5545,8 @@ class OMPTargetTeamsDistributeSimdDirective final : public OMPLoopDirective { }; /// This represents the '#pragma omp tile' loop transformation directive. 
-class OMPTileDirective final : public OMPLoopTransformationDirective { +class OMPTileDirective final + : public OMPCanonicalLoopNestTransformationDirective { friend class ASTStmtReader; friend class OMPExecutableDirective; @@ -5557,9 +5558,9 @@ class OMPTileDirective final : public OMPLoopTransformationDirective { explicit OMPTileDirective(SourceLocation StartLoc, SourceLocation EndLoc, unsigned NumLoops) - : OMPLoopTransformationDirective(OMPTileDirectiveClass, - llvm::omp::OMPD_tile, StartLoc, EndLoc, - NumLoops) { + : OMPCanonicalLoopNestTransformationDirective( + OMPTileDirectiveClass, llvm::omp::OMPD_tile, StartLoc, EndLoc, + NumLoops) { setNumGeneratedLoops(2 * NumLoops); } @@ -5622,7 +5623,8 @@ class OMPTileDirective final : public OMPLoopTransformationDirective { }; /// This represents the '#pragma omp stripe' loop transformation directive. -class OMPStripeDirective final : public OMPLoopTransformationDirective { +class OMPStripeDirective final + : public OMPCanonicalLoopNestTransformationDirective { friend class ASTStmtReader; friend class OMPExecutableDirective; @@ -5634,9 +5636,9 @@ class OMPStripeDirective final : public OMPLoopTransformationDirective { explicit OMPStripeDirective(SourceLocation StartLoc, SourceLocation EndLoc, unsigned NumLoops) - : OMPLoopTransformationDirective(OMPStripeDirectiveClass, - llvm::omp::OMPD_stripe, StartLoc, EndLoc, - NumLoops) { + : OMPCanonicalLoopNestTransformationDirective( + OMPStripeDirectiveClass, llvm::omp::OMPD_stripe, StartLoc, EndLoc, + NumLoops) { setNumGeneratedLoops(2 * NumLoops); } @@ -5702,7 +5704,8 @@ class OMPStripeDirective final : public OMPLoopTransformationDirective { /// #pragma omp unroll /// for (int i = 0; i < 64; ++i) /// \endcode -class OMPUnrollDirective final : public OMPLoopTransformationDirective { +class OMPUnrollDirective final + : public OMPCanonicalLoopNestTransformationDirective { friend class ASTStmtReader; friend class OMPExecutableDirective; @@ -5713,9 +5716,9 @@ class 
OMPUnrollDirective final : public OMPLoopTransformationDirective { }; explicit OMPUnrollDirective(SourceLocation StartLoc, SourceLocation EndLoc) - : OMPLoopTransformationDirective(OMPUnrollDirectiveClass, - llvm::omp::OMPD_unroll, StartLoc, EndLoc, - 1) {} + : OMPCanonicalLoopNestTransformationDirective(OMPUnrollDirectiveClass, + llvm::omp::OMPD_unroll, + StartLoc, EndLoc, 1) {} /// Set the pre-init statements. void setPreInits(Stmt *PreInits) { @@ -5776,7 +5779,8 @@ class OMPUnrollDirective final : public OMPLoopTransformationDirective { /// for (int i = 0; i < n; ++i) /// ... /// \endcode -class OMPReverseDirective final : public OMPLoopTransformationDirective { +class OMPReverseDirective final + : public OMPCanonicalLoopNestTransformationDirective { friend class ASTStmtReader; friend class OMPExecutableDirective; @@ -5788,9 +5792,9 @@ class OMPReverseDirective final : public OMPLoopTransformationDirective { explicit OMPReverseDirective(SourceLocation StartLoc, SourceLocation EndLoc, unsigned NumLoops) - : OMPLoopTransformationDirective(OMPReverseDirectiveClass, - llvm::omp::OMPD_reverse, StartLoc, - EndLoc, NumLoops) { + : OMPCanonicalLoopNestTransformationDirective( + OMPReverseDirectiveClass, llvm::omp::OMPD_reverse, StartLoc, EndLoc, + NumLoops) { setNumGeneratedLoops(NumLoops); } @@ -5848,7 +5852,8 @@ class OMPReverseDirective final : public OMPLoopTransformationDirective { /// for (int j = 0; j < n; ++j) /// .. 
/// \endcode -class OMPInterchangeDirective final : public OMPLoopTransformationDirective { +class OMPInterchangeDirective final + : public OMPCanonicalLoopNestTransformationDirective { friend class ASTStmtReader; friend class OMPExecutableDirective; @@ -5860,9 +5865,9 @@ class OMPInterchangeDirective final : public OMPLoopTransformationDirective { explicit OMPInterchangeDirective(SourceLocation StartLoc, SourceLocation EndLoc, unsigned NumLoops) - : OMPLoopTransformationDirective(OMPInterchangeDirectiveClass, - llvm::omp::OMPD_interchange, StartLoc, - EndLoc, NumLoops) { + : OMPCanonicalLoopNestTransformationDirective( + OMPInterchangeDirectiveClass, llvm::omp::OMPD_interchange, StartLoc, + EndLoc, NumLoops) { setNumGeneratedLoops(NumLoops); } diff --git a/clang/include/clang/Basic/OpenMPKinds.h b/clang/include/clang/Basic/OpenMPKinds.h index f40db4c13c55a..d3285cd9c6a14 100644 --- a/clang/include/clang/Basic/OpenMPKinds.h +++ b/clang/include/clang/Basic/OpenMPKinds.h @@ -365,6 +365,13 @@ bool isOpenMPTaskingDirective(OpenMPDirectiveKind Kind); /// functions bool isOpenMPLoopBoundSharingDirective(OpenMPDirectiveKind Kind); +/// Checks if the specified directive is a loop transformation directive that +/// applies to a canonical loop nest. +/// \param DKind Specified directive. +/// \return True iff the directive transforms a canonical loop nest. +bool isOpenMPCanonicalLoopNestTransformationDirective( + OpenMPDirectiveKind DKind); + /// Checks if the specified directive is a loop transformation directive. /// \param DKind Specified directive. /// \return True iff the directive is a loop transformation.
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td index c9c173f5c7469..781577549573d 100644 --- a/clang/include/clang/Basic/StmtNodes.td +++ b/clang/include/clang/Basic/StmtNodes.td @@ -227,12 +227,14 @@ def OMPLoopBasedDirective : StmtNode<OMPExecutableDirective, 1>; def OMPLoopDirective : StmtNode<OMPLoopBasedDirective, 1>; def OMPParallelDirective : StmtNode<OMPExecutableDirective>; def OMPSimdDirective : StmtNode<OMPLoopDirective>; -def OMPLoopTransformationDirective : StmtNode<OMPLoopBasedDirective, 1>; -def OMPTileDirective : StmtNode<OMPLoopTransformationDirective>; -def OMPStripeDirective : StmtNode<OMPLoopTransformationDirective>; -def OMPUnrollDirective : StmtNode<OMPLoopTransformationDirective>; -def OMPReverseDirective : StmtNode<OMPLoopTransformationDirective>; -def OMPInterchangeDirective : StmtNode<OMPLoopTransformationDirective>; +def OMPCanonicalLoopNestTransformationDirective + : StmtNode<OMPLoopBasedDirective, 1>; +def OMPTileDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>; +def OMPStripeDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>; +def OMPUnrollDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>; +def OMPReverseDirective : StmtNode<OMPCanonicalLoopNestTransformationDirective>; +def OMPInterchangeDirective + : StmtNode<OMPCanonicalLoopNestTransformationDirective>; def OMPForDirective : StmtNode<OMPLoopDirective>; def OMPForSimdDirective : StmtNode<OMPLoopDirective>; def OMPSectionsDirective : StmtNode<OMPExecutableDirective>; diff --git a/clang/lib/AST/StmtOpenMP.cpp b/clang/lib/AST/StmtOpenMP.cpp index 2eeb5e45ab511..36ecaf6489ef0 100644 --- a/clang/lib/AST/StmtOpenMP.cpp +++ b/clang/lib/AST/StmtOpenMP.cpp @@ -125,12 +125,13 @@ OMPLoopBasedDirective::tryToFindNextInnerLoop(Stmt *CurStmt, bool OMPLoopBasedDirective::doForAllLoops( Stmt *CurStmt, bool TryImperfectlyNestedLoops, unsigned NumLoops, llvm::function_ref<bool(unsigned, Stmt *)> Callback, - 
llvm::function_ref<void(OMPLoopTransformationDirective *)> + llvm::function_ref<void(OMPCanonicalLoopNestTransformationDirective *)> OnTransformationCallback) { CurStmt = CurStmt->IgnoreContainers(); for (unsigned Cnt = 0; Cnt < NumLoops; ++Cnt) { while (true) { - auto *Dir = dyn_cast<OMPLoopTransformationDirective>(CurStmt); + auto *Dir = + dyn_cast<OMPCanonicalLoopNestTransformationDirective>(CurStmt); if (!Dir) break; @@ -369,11 +370,11 @@ OMPForDirective *OMPForDirective::Create( return Dir; } -Stmt *OMPLoopTransformationDirective::getTransformedStmt() const { +Stmt *OMPCanonicalLoopNestTransformationDirective::getTransformedStmt() const { switch (getStmtClass()) { #define STMT(CLASS, PARENT) #define ABSTRACT_STMT(CLASS) -#define OMPLOOPTRANSFORMATIONDIRECTIVE(CLASS, PARENT) \ +#define OMPCANONICALLOOPNESTTRANSFORMATIONDIRECTIVE(CLASS, PARENT) \ case Stmt::CLASS##Class: \ return static_cast<const CLASS *>(this)->getTransformedStmt(); #include "clang/AST/StmtNodes.inc" @@ -382,11 +383,11 @@ Stmt *OMPLoopTransformationDirective::getTransformedStmt() const { } } -Stmt *OMPLoopTransformationDirective::getPreInits() const { +Stmt *OMPCanonicalLoopNestTransformationDirective::getPreInits() const { switch (getStmtClass()) { #define STMT(CLASS, PARENT) #define ABSTRACT_STMT(CLASS) -#define OMPLOOPTRANSFORMATIONDIRECTIVE(CLASS, PARENT) \ +#define OMPCANONICALLOOPNESTTRANSFORMATIONDIRECTIVE(CLASS, PARENT) \ case Stmt::CLASS##Class: \ return static_cast<const CLASS *>(this)->getPreInits(); #include "clang/AST/StmtNodes.inc" diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp index 2035fa7635f2a..7a9b7fb431099 100644 --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -999,30 +999,30 @@ void StmtProfiler::VisitOMPSimdDirective(const OMPSimdDirective *S) { VisitOMPLoopDirective(S); } -void StmtProfiler::VisitOMPLoopTransformationDirective( - const OMPLoopTransformationDirective *S) { +void 
StmtProfiler::VisitOMPCanonicalLoopNestTransformationDirective( + const OMPCanonicalLoopNestTransformationDirective *S) { VisitOMPLoopBasedDirective(S); } void StmtProfiler::VisitOMPTileDirective(const OMPTileDirective *S) { - VisitOMPLoopTransformationDirective(S); + VisitOMPCanonicalLoopNestTransformationDirective(S); } void StmtProfiler::VisitOMPStripeDirective(const OMPStripeDirective *S) { - VisitOMPLoopTransformationDirective(S); + VisitOMPCanonicalLoopNestTransformationDirective(S); } void StmtProfiler::VisitOMPUnrollDirective(const OMPUnrollDirective *S) { - VisitOMPLoopTransformationDirective(S); + VisitOMPCanonicalLoopNestTransformationDirective(S); } void StmtProfiler::VisitOMPReverseDirective(const OMPReverseDirective *S) { - VisitOMPLoopTransformationDirective(S); + VisitOMPCanonicalLoopNestTransformationDirective(S); } void StmtProfiler::VisitOMPInterchangeDirective( const OMPInterchangeDirective *S) { - VisitOMPLoopTransformationDirective(S); + VisitOMPCanonicalLoopNestTransformationDirective(S); } void StmtProfiler::VisitOMPForDirective(const OMPForDirective *S) { diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp index 220b31b0f19bc..3f8f64df8702e 100644 --- a/clang/lib/Basic/OpenMPKinds.cpp +++ b/clang/lib/Basic/OpenMPKinds.cpp @@ -717,11 +717,17 @@ bool clang::isOpenMPLoopBoundSharingDirective(OpenMPDirectiveKind Kind) { Kind == OMPD_teams_loop || Kind == OMPD_target_teams_loop; } -bool clang::isOpenMPLoopTransformationDirective(OpenMPDirectiveKind DKind) { +bool clang::isOpenMPCanonicalLoopNestTransformationDirective( + OpenMPDirectiveKind DKind) { return DKind == OMPD_tile || DKind == OMPD_unroll || DKind == OMPD_reverse || DKind == OMPD_interchange || DKind == OMPD_stripe; } +bool clang::isOpenMPLoopTransformationDirective(OpenMPDirectiveKind DKind) { + // FIXME: There will be more cases when we implement 'fuse'. 
+ return isOpenMPCanonicalLoopNestTransformationDirective(DKind); +} + bool clang::isOpenMPCombinedParallelADirective(OpenMPDirectiveKind DKind) { return DKind == OMPD_parallel_for || DKind == OMPD_parallel_for_simd || DKind == OMPD_parallel_master || diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp index f6a0ca574a191..6f795b45bc381 100644 --- a/clang/lib/CodeGen/CGStmtOpenMP.cpp +++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp @@ -1927,7 +1927,8 @@ static void emitBody(CodeGenFunction &CGF, const Stmt *S, const Stmt *NextLoop, return; } if (SimplifiedS == NextLoop) { - if (auto *Dir = dyn_cast<OMPLoopTransformationDirective>(SimplifiedS)) + if (auto *Dir = + dyn_cast<OMPCanonicalLoopNestTransformationDirective>(SimplifiedS)) SimplifiedS = Dir->getTransformedStmt(); if (const auto *CanonLoop = dyn_cast<OMPCanonicalLoop>(SimplifiedS)) SimplifiedS = CanonLoop->getLoopStmt(); diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index 7d800c446b595..a02850c66b4fe 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -4145,7 +4145,8 @@ class DSAAttrChecker final : public StmtVisitor<DSAAttrChecker, void> { VisitSubCaptures(S); } - void VisitOMPLoopTransformationDirective(OMPLoopTransformationDirective *S) { + void VisitOMPCanonicalLoopNestTransformationDirective( + OMPCanonicalLoopNestTransformationDirective *S) { // Loop transformation directives do not introduce data sharing VisitStmt(S); } @@ -9748,7 +9749,8 @@ checkOpenMPLoop(OpenMPDirectiveKind DKind, Expr *CollapseLoopCountExpr, } return false; }, - [&SemaRef, &Captures](OMPLoopTransformationDirective *Transform) { + [&SemaRef, + &Captures](OMPCanonicalLoopNestTransformationDirective *Transform) { Stmt *DependentPreInits = Transform->getPreInits(); if (!DependentPreInits) return; diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp index 3f37dfbc3dea9..13618b4a03d1e 100644 --- 
a/clang/lib/Serialization/ASTReaderStmt.cpp +++ b/clang/lib/Serialization/ASTReaderStmt.cpp @@ -2442,30 +2442,30 @@ void ASTStmtReader::VisitOMPSimdDirective(OMPSimdDirective *D) { VisitOMPLoopDirective(D); } -void ASTStmtReader::VisitOMPLoopTransformationDirective( - OMPLoopTransformationDirective *D) { +void ASTStmtReader::VisitOMPCanonicalLoopNestTransformationDirective( + OMPCanonicalLoopNestTransformationDirective *D) { VisitOMPLoopBasedDirective(D); D->setNumGeneratedLoops(Record.readUInt32()); } void ASTStmtReader::VisitOMPTileDirective(OMPTileDirective *D) { - VisitOMPLoopTransformationDirective(D); + VisitOMPCanonicalLoopNestTransformationDirective(D); } void ASTStmtReader::VisitOMPStripeDirective(OMPStripeDirective *D) { - VisitOMPLoopTransformationDirective(D); + VisitOMPCanonicalLoopNestTransformationDirective(D); } void ASTStmtReader::VisitOMPUnrollDirective(OMPUnrollDirective *D) { - VisitOMPLoopTransformationDirective(D); + VisitOMPCanonicalLoopNestTransformationDirective(D); } void ASTStmtReader::VisitOMPReverseDirective(OMPReverseDirective *D) { - VisitOMPLoopTransformationDirective(D); + VisitOMPCanonicalLoopNestTransformationDirective(D); } void ASTStmtReader::VisitOMPInterchangeDirective(OMPInterchangeDirective *D) { - VisitOMPLoopTransformationDirective(D); + VisitOMPCanonicalLoopNestTransformationDirective(D); } void ASTStmtReader::VisitOMPForDirective(OMPForDirective *D) { diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp index be9bad9e96cc1..36b022cc9d371 100644 --- a/clang/lib/Serialization/ASTWriterStmt.cpp +++ b/clang/lib/Serialization/ASTWriterStmt.cpp @@ -2445,34 +2445,34 @@ void ASTStmtWriter::VisitOMPSimdDirective(OMPSimdDirective *D) { Code = serialization::STMT_OMP_SIMD_DIRECTIVE; } -void ASTStmtWriter::VisitOMPLoopTransformationDirective( - OMPLoopTransformationDirective *D) { +void ASTStmtWriter::VisitOMPCanonicalLoopNestTransformationDirective( + 
OMPCanonicalLoopNestTransformationDirective *D) { VisitOMPLoopBasedDirective(D); Record.writeUInt32(D->getNumGeneratedLoops()); } void ASTStmtWriter::VisitOMPTileDirective(OMPTileDirective *D) { - VisitOMPLoopTransformationDirective(D); + VisitOMPCanonicalLoopNestTransformationDirective(D); Code = serialization::STMT_OMP_TILE_DIRECTIVE; } void ASTStmtWriter::VisitOMPStripeDirective(OMPStripeDirective *D) { - VisitOMPLoopTransformationDirective(D); + VisitOMPCanonicalLoopNestTransformationDirective(D); Code = serialization::STMP_OMP_STRIPE_DIRECTIVE; } void ASTStmtWriter::VisitOMPUnrollDirective(OMPUnrollDirective *D) { - VisitOMPLoopTransformationDirective(D); + VisitOMPCanonicalLoopNestTransformationDirective(D); Code = serialization::STMT_OMP_UNROLL_DIRECTIVE; } void ASTStmtWriter::VisitOMPReverseDirective(OMPReverseDirective *D) { - VisitOMPLoopTransformationDirective(D); + VisitOMPCanonicalLoopNestTransformationDirective(D); Code = serialization::STMT_OMP_REVERSE_DIRECTIVE; } void ASTStmtWriter::VisitOMPInterchangeDirective(OMPInterchangeDirective *D) { - VisitOMPLoopTransformationDirective(D); + VisitOMPCanonicalLoopNestTransformationDirective(D); Code = serialization::STMT_OMP_INTERCHANGE_DIRECTIVE; } diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp index 858423a06576a..b12b1f07c2f70 100644 --- a/clang/tools/libclang/CIndex.cpp +++ b/clang/tools/libclang/CIndex.cpp @@ -2154,8 +2154,8 @@ class EnqueueVisitor : public ConstStmtVisitor<EnqueueVisitor, void>, void VisitOMPLoopDirective(const OMPLoopDirective *D); void VisitOMPParallelDirective(const OMPParallelDirective *D); void VisitOMPSimdDirective(const OMPSimdDirective *D); - void - VisitOMPLoopTransformationDirective(const OMPLoopTransformationDirective *D); + void VisitOMPCanonicalLoopNestTransformationDirective( + const OMPCanonicalLoopNestTransformationDirective *D); void VisitOMPTileDirective(const OMPTileDirective *D); void VisitOMPStripeDirective(const OMPStripeDirective 
*D); void VisitOMPUnrollDirective(const OMPUnrollDirective *D); @@ -3301,30 +3301,30 @@ void EnqueueVisitor::VisitOMPSimdDirective(const OMPSimdDirective *D) { VisitOMPLoopDirective(D); } -void EnqueueVisitor::VisitOMPLoopTransformationDirective( - const OMPLoopTransformationDirective *D) { +void EnqueueVisitor::VisitOMPCanonicalLoopNestTransformationDirective( + const OMPCanonicalLoopNestTransformationDirective *D) { VisitOMPLoopBasedDirective(D); } void EnqueueVisitor::VisitOMPTileDirective(const OMPTileDirective *D) { - VisitOMPLoopTransformationDirective(D); + VisitOMPCanonicalLoopNestTransformationDirective(D); } void EnqueueVisitor::VisitOMPStripeDirective(const OMPStripeDirective *D) { - VisitOMPLoopTransformationDirective(D); + VisitOMPCanonicalLoopNestTransformationDirective(D); } void EnqueueVisitor::VisitOMPUnrollDirective(const OMPUnrollDirective *D) { - VisitOMPLoopTransformationDirective(D); + VisitOMPCanonicalLoopNestTransformationDirective(D); } void EnqueueVisitor::VisitOMPReverseDirective(const OMPReverseDirective *D) { - VisitOMPLoopTransformationDirective(D); + VisitOMPCanonicalLoopNestTransformationDirective(D); } void EnqueueVisitor::VisitOMPInterchangeDirective( const OMPInterchangeDirective *D) { - VisitOMPLoopTransformationDirective(D); + VisitOMPCanonicalLoopNestTransformationDirective(D); } void EnqueueVisitor::VisitOMPForDirective(const OMPForDirective *D) { From 280f7435a15900085ccf2731b1d4f26297455f71 Mon Sep 17 00:00:00 2001 From: Roger Ferrer Ibanez <roger.fer...@bsc.es> Date: Wed, 27 Aug 2025 08:18:02 +0000 Subject: [PATCH 2/2] [Clang][OpenMP] Add an additional class to hold data that will be shared between all loop transformations This class is not a statement by itself and its goal is to avoid having to replicate information between classes. Also simplify the way we handle the "generated loops" information as we currently only need to know if it is zero or non-zero. 
--- clang/include/clang/AST/StmtOpenMP.h | 48 +++++++++++++++------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/clang/include/clang/AST/StmtOpenMP.h b/clang/include/clang/AST/StmtOpenMP.h index a436676113921..602a516c0d43f 100644 --- a/clang/include/clang/AST/StmtOpenMP.h +++ b/clang/include/clang/AST/StmtOpenMP.h @@ -956,30 +956,42 @@ class OMPLoopBasedDirective : public OMPExecutableDirective { } }; +/// Common class of data shared between +/// OMPCanonicalLoopNestTransformationDirective and transformations over +/// canonical loop sequences. +class OMPLoopTransformationDirective { + /// Number of (top-level) generated loops. + /// This value is 1 for most transformations as they only map one loop nest + /// into another. + /// Some loop transformations (like a non-partial 'unroll') may not generate + /// a loop nest, so this would be 0. + /// Some loop transformations (like 'fuse' with looprange and 'split') may + /// generate more than one loop nest, so the value would be > 1. + unsigned NumGeneratedLoops = 1; + +protected: + void setNumGeneratedLoops(unsigned N) { NumGeneratedLoops = N; } + +public: + unsigned getNumGeneratedLoops() const { return NumGeneratedLoops; } +}; + /// The base class for all transformation directives of canonical loop nests. class OMPCanonicalLoopNestTransformationDirective - : public OMPLoopBasedDirective { + : public OMPLoopBasedDirective, + public OMPLoopTransformationDirective { friend class ASTStmtReader; - /// Number of loops generated by this loop transformation. - unsigned NumGeneratedLoops = 0; - protected: explicit OMPCanonicalLoopNestTransformationDirective( StmtClass SC, OpenMPDirectiveKind Kind, SourceLocation StartLoc, SourceLocation EndLoc, unsigned NumAssociatedLoops) : OMPLoopBasedDirective(SC, Kind, StartLoc, EndLoc, NumAssociatedLoops) {} - /// Set the number of loops generated by this loop transformation.
- void setNumGeneratedLoops(unsigned Num) { NumGeneratedLoops = Num; } - public: /// Return the number of associated (consumed) loops. unsigned getNumAssociatedLoops() const { return getLoopsNumber(); } - /// Return the number of loops generated by this loop transformation. - unsigned getNumGeneratedLoops() const { return NumGeneratedLoops; } - /// Get the de-sugared statements after the loop transformation. /// /// Might be nullptr if either the directive generates no loops and is handled @@ -5560,9 +5572,7 @@ class OMPTileDirective final unsigned NumLoops) : OMPCanonicalLoopNestTransformationDirective( OMPTileDirectiveClass, llvm::omp::OMPD_tile, StartLoc, EndLoc, - NumLoops) { - setNumGeneratedLoops(2 * NumLoops); - } + NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; @@ -5638,9 +5648,7 @@ class OMPStripeDirective final unsigned NumLoops) : OMPCanonicalLoopNestTransformationDirective( OMPStripeDirectiveClass, llvm::omp::OMPD_stripe, StartLoc, EndLoc, - NumLoops) { - setNumGeneratedLoops(2 * NumLoops); - } + NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; @@ -5794,9 +5802,7 @@ class OMPReverseDirective final unsigned NumLoops) : OMPCanonicalLoopNestTransformationDirective( OMPReverseDirectiveClass, llvm::omp::OMPD_reverse, StartLoc, EndLoc, - NumLoops) { - setNumGeneratedLoops(NumLoops); - } + NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; @@ -5867,9 +5873,7 @@ class OMPInterchangeDirective final SourceLocation EndLoc, unsigned NumLoops) : OMPCanonicalLoopNestTransformationDirective( OMPInterchangeDirectiveClass, llvm::omp::OMPD_interchange, StartLoc, - EndLoc, NumLoops) { - setNumGeneratedLoops(NumLoops); - } + EndLoc, NumLoops) {} void setPreInits(Stmt *PreInits) { Data->getChildren()[PreInitsOffset] = PreInits; _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org 
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits