llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: Shilei Tian (shiltian) <details> <summary>Changes</summary> --- Full diff: https://github.com/llvm/llvm-project/pull/99732.diff 7 Files Affected: - (modified) clang/include/clang/AST/OpenMPClause.h (+48-31) - (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+1-1) - (modified) clang/include/clang/Sema/SemaOpenMP.h (+2-1) - (modified) clang/lib/AST/OpenMPClause.cpp (+23-3) - (modified) clang/lib/AST/StmtProfile.cpp (+1-2) - (modified) clang/lib/Parse/ParseOpenMP.cpp (+1-1) - (modified) clang/lib/Sema/SemaOpenMP.cpp (+23-18) ``````````diff diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h index 325a1baa44614..2e82ccac28dc8 100644 --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -6131,43 +6131,54 @@ class OMPMapClause final : public OMPMappableExprListClause<OMPMapClause>, /// \endcode /// In this example directive '#pragma omp teams' has clause 'num_teams' /// with single expression 'n'. -class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit { - friend class OMPClauseReader; +/// +/// When 'ompx_bare' clause exists on a 'target' directive, 'num_teams' clause +/// can accept up to three expressions. +/// +/// \code +/// #pragma omp target teams ompx_bare num_teams(x, y, z) +/// \endcode +class OMPNumTeamsClause final + : public OMPVarListClause<OMPNumTeamsClause>, + public OMPClauseWithPreInit, + private llvm::TrailingObjects<OMPNumTeamsClause, Expr *> { + friend OMPVarListClause; + friend TrailingObjects; /// Location of '('. SourceLocation LParenLoc; - /// NumTeams number. - Stmt *NumTeams = nullptr; + OMPNumTeamsClause(const ASTContext &C, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc, unsigned N) + : OMPVarListClause(llvm::omp::OMPC_num_teams, StartLoc, LParenLoc, EndLoc, + N), + OMPClauseWithPreInit(this) {} - /// Set the NumTeams number. - /// - /// \param E NumTeams number. - void setNumTeams(Expr *E) { NumTeams = E; } + /// Build an empty clause. + OMPNumTeamsClause(unsigned N) + : OMPVarListClause(llvm::omp::OMPC_num_teams, SourceLocation(), + SourceLocation(), SourceLocation(), N), + OMPClauseWithPreInit(this) {} public: - /// Build 'num_teams' clause. + /// Creates clause with a list of variables \a VL. /// - /// \param E Expression associated with this clause. - /// \param HelperE Helper Expression associated with this clause. - /// \param CaptureRegion Innermost OpenMP region where expressions in this - /// clause must be captured. + /// \param C AST context. /// \param StartLoc Starting location of the clause. /// \param LParenLoc Location of '('. /// \param EndLoc Ending location of the clause. - OMPNumTeamsClause(Expr *E, Stmt *HelperE, OpenMPDirectiveKind CaptureRegion, - SourceLocation StartLoc, SourceLocation LParenLoc, - SourceLocation EndLoc) - : OMPClause(llvm::omp::OMPC_num_teams, StartLoc, EndLoc), - OMPClauseWithPreInit(this), LParenLoc(LParenLoc), NumTeams(E) { - setPreInitStmt(HelperE, CaptureRegion); - } + /// \param VL List of references to the variables. + /// \param PreInit + static OMPNumTeamsClause *Create(const ASTContext &C, SourceLocation StartLoc, + SourceLocation LParenLoc, + SourceLocation EndLoc, ArrayRef<Expr *> VL, + Stmt *PreInit); - /// Build an empty clause. - OMPNumTeamsClause() - : OMPClause(llvm::omp::OMPC_num_teams, SourceLocation(), - SourceLocation()), - OMPClauseWithPreInit(this) {} + /// Creates an empty clause with \a N variables. + /// + /// \param C AST context. + /// \param N The number of variables. + static OMPNumTeamsClause *CreateEmpty(const ASTContext &C, unsigned N); /// Sets the location of '('. void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; } @@ -6175,16 +6186,22 @@ class OMPNumTeamsClause : public OMPClause, public OMPClauseWithPreInit { /// Returns the location of '('. SourceLocation getLParenLoc() const { return LParenLoc; } - /// Return NumTeams number. - Expr *getNumTeams() { return cast<Expr>(NumTeams); } + /// Return NumTeams number. By default, we return the first expression. + Expr *getNumTeams() { return getVarRefs().front(); } - /// Return NumTeams number. - Expr *getNumTeams() const { return cast<Expr>(NumTeams); } + /// Return NumTeams number. By default, we return the first expression. + Expr *getNumTeams() const { + return const_cast<OMPNumTeamsClause *>(this)->getNumTeams(); + } - child_range children() { return child_range(&NumTeams, &NumTeams + 1); } + child_range children() { + return child_range(reinterpret_cast<Stmt **>(varlist_begin()), + reinterpret_cast<Stmt **>(varlist_end())); + } const_child_range children() const { - return const_child_range(&NumTeams, &NumTeams + 1); + auto Children = const_cast<OMPNumTeamsClause *>(this)->children(); + return const_child_range(Children.begin(), Children.end()); } child_range used_children() { diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h index e3c0cb46799f7..beb7b3597c2a8 100644 --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -3793,8 +3793,8 @@ bool RecursiveASTVisitor<Derived>::VisitOMPMapClause(OMPMapClause *C) { template <typename Derived> bool RecursiveASTVisitor<Derived>::VisitOMPNumTeamsClause( OMPNumTeamsClause *C) { + TRY_TO(VisitOMPClauseList(C)); TRY_TO(VisitOMPClauseWithPreInit(C)); - TRY_TO(TraverseStmt(C->getNumTeams())); return true; } diff --git a/clang/include/clang/Sema/SemaOpenMP.h b/clang/include/clang/Sema/SemaOpenMP.h index 54d81f91ffebc..bf5fbc670b05c 100644 --- a/clang/include/clang/Sema/SemaOpenMP.h +++ b/clang/include/clang/Sema/SemaOpenMP.h @@ -1227,7 +1227,8 @@ class SemaOpenMP : public SemaBase { const OMPVarListLocTy &Locs, bool NoDiagnose = false, ArrayRef<Expr *> UnresolvedMappers = std::nullopt); /// Called on well-formed 'num_teams' clause. - OMPClause *ActOnOpenMPNumTeamsClause(Expr *NumTeams, SourceLocation StartLoc, + OMPClause *ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList, + SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc); /// Called on well-formed 'thread_limit' clause. diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp index 042a5df5906ca..ee9e9a0d39a92 100644 --- a/clang/lib/AST/OpenMPClause.cpp +++ b/clang/lib/AST/OpenMPClause.cpp @@ -1720,6 +1720,24 @@ const Expr *OMPDoacrossClause::getLoopData(unsigned NumLoop) const { return *It; } +OMPNumTeamsClause * +OMPNumTeamsClause::Create(const ASTContext &C, SourceLocation StartLoc, + SourceLocation LParenLoc, SourceLocation EndLoc, + ArrayRef<Expr *> VL, Stmt *PreInit) { + void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(VL.size())); + OMPNumTeamsClause *Clause = + new (Mem) OMPNumTeamsClause(C, StartLoc, LParenLoc, EndLoc, VL.size()); + Clause->setVarRefs(VL); + Clause->setPreInitStmt(PreInit); + return Clause; +} + +OMPNumTeamsClause *OMPNumTeamsClause::CreateEmpty(const ASTContext &C, + unsigned N) { + void *Mem = C.Allocate(totalSizeToAlloc<Expr *>(N)); + return new (Mem) OMPNumTeamsClause(N); +} + //===----------------------------------------------------------------------===// // OpenMP clauses printing methods //===----------------------------------------------------------------------===// @@ -1977,9 +1995,11 @@ void OMPClausePrinter::VisitOMPDeviceClause(OMPDeviceClause *Node) { } void OMPClausePrinter::VisitOMPNumTeamsClause(OMPNumTeamsClause *Node) { - OS << "num_teams("; - Node->getNumTeams()->printPretty(OS, nullptr, Policy, 0); - OS << ")"; + if (!Node->varlist_empty()) { + OS << "num_teams"; + VisitOMPClauseList(Node, '('); + OS << ")"; + } } void OMPClausePrinter::VisitOMPThreadLimitClause(OMPThreadLimitClause *Node) { diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp index 89d2a422509d8..b782a4ab8367e 100644 --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -843,9 +843,8 @@ void OMPClauseProfiler::VisitOMPAllocateClause(const OMPAllocateClause *C) { VisitOMPClauseList(C); } void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) { + VisitOMPClauseList(C); VistOMPClauseWithPreInit(C); - if (C->getNumTeams()) - Profiler->VisitStmt(C->getNumTeams()); } void OMPClauseProfiler::VisitOMPThreadLimitClause( const OMPThreadLimitClause *C) { diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp index f5b44d210680c..e851bb4ac7fef 100644 --- a/clang/lib/Parse/ParseOpenMP.cpp +++ b/clang/lib/Parse/ParseOpenMP.cpp @@ -3098,7 +3098,6 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind, case OMPC_simdlen: case OMPC_collapse: case OMPC_ordered: - case OMPC_num_teams: case OMPC_thread_limit: case OMPC_priority: case OMPC_grainsize: @@ -3279,6 +3278,7 @@ OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind, case OMPC_affinity: case OMPC_doacross: case OMPC_enter: + case OMPC_num_teams: if (getLangOpts().OpenMP >= 52 && DKind == OMPD_ordered && CKind == OMPC_depend) Diag(Tok, diag::warn_omp_depend_in_ordered_deprecated); diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp index 3bd981cb442aa..a4e0ce730ae05 100644 --- a/clang/lib/Sema/SemaOpenMP.cpp +++ b/clang/lib/Sema/SemaOpenMP.cpp @@ -15041,9 +15041,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, case OMPC_ordered: Res = ActOnOpenMPOrderedClause(StartLoc, EndLoc, LParenLoc, Expr); break; - case OMPC_num_teams: - Res = ActOnOpenMPNumTeamsClause(Expr, StartLoc, LParenLoc, EndLoc); - break; case OMPC_thread_limit: Res = ActOnOpenMPThreadLimitClause(Expr, StartLoc, LParenLoc, EndLoc); break; @@ -15147,6 +15144,7 @@ OMPClause *SemaOpenMP::ActOnOpenMPSingleExprClause(OpenMPClauseKind Kind, case OMPC_affinity: case OMPC_when: case OMPC_bind: + case OMPC_num_teams: default: llvm_unreachable("Clause is not allowed."); } @@ -17010,6 +17008,9 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind, static_cast<OpenMPDoacrossClauseModifier>(ExtraModifier), ExtraModifierLoc, ColonLoc, VarList, StartLoc, LParenLoc, EndLoc); break; + case OMPC_num_teams: + Res = ActOnOpenMPNumTeamsClause(VarList, StartLoc, LParenLoc, EndLoc); + break; case OMPC_if: case OMPC_depobj: case OMPC_final: @@ -17040,7 +17041,6 @@ OMPClause *SemaOpenMP::ActOnOpenMPVarListClause(OpenMPClauseKind Kind, case OMPC_device: case OMPC_threads: case OMPC_simd: - case OMPC_num_teams: case OMPC_thread_limit: case OMPC_priority: case OMPC_grainsize: @@ -21703,32 +21703,37 @@ const ValueDecl *SemaOpenMP::getOpenMPDeclareMapperVarName() const { return cast<DeclRefExpr>(DSAStack->getDeclareMapperVarRef())->getDecl(); } -OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(Expr *NumTeams, +OMPClause *SemaOpenMP::ActOnOpenMPNumTeamsClause(ArrayRef<Expr *> VarList, SourceLocation StartLoc, SourceLocation LParenLoc, SourceLocation EndLoc) { - Expr *ValExpr = NumTeams; - Stmt *HelperValStmt = nullptr; - - // OpenMP [teams Constrcut, Restrictions] - // The num_teams expression must evaluate to a positive integer value. - if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams, - /*StrictlyPositive=*/true)) + if (VarList.empty()) return nullptr; OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective(); OpenMPDirectiveKind CaptureRegion = getOpenMPCaptureRegionForClause( DKind, OMPC_num_teams, getLangOpts().OpenMP); - if (CaptureRegion != OMPD_unknown && - !SemaRef.CurContext->isDependentContext()) { + + if (CaptureRegion == OMPD_unknown || SemaRef.CurContext->isDependentContext()) + return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc, + EndLoc, VarList, /*PreInit=*/nullptr); + + llvm::MapVector<const Expr *, DeclRefExpr *> Captures; + SmallVector<Expr *, 3> Vars; + for (Expr *ValExpr : VarList) { + // OpenMP [teams Constrcut, Restrictions] + // The num_teams expression must evaluate to a positive integer value. + if (!isNonNegativeIntegerValue(ValExpr, SemaRef, OMPC_num_teams, + /*StrictlyPositive=*/true)) + return nullptr; ValExpr = SemaRef.MakeFullExpr(ValExpr).get(); - llvm::MapVector<const Expr *, DeclRefExpr *> Captures; ValExpr = tryBuildCapture(SemaRef, ValExpr, Captures).get(); - HelperValStmt = buildPreInits(getASTContext(), Captures); + Vars.push_back(ValExpr); } - return new (getASTContext()) OMPNumTeamsClause( - ValExpr, HelperValStmt, CaptureRegion, StartLoc, LParenLoc, EndLoc); + Stmt *PreInit = buildPreInits(getASTContext(), Captures); + return OMPNumTeamsClause::Create(getASTContext(), StartLoc, LParenLoc, EndLoc, + Vars, PreInit); } OMPClause *SemaOpenMP::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit, `````````` </details> https://github.com/llvm/llvm-project/pull/99732 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits