abidmalikwaterloo updated this revision to Diff 442661.
abidmalikwaterloo marked 14 inline comments as done.
abidmalikwaterloo added a comment.
Updated and cleaned up the code based on the reviewer's comments.
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D122255/new/
https://reviews.llvm.org/D122255
Files:
clang/include/clang/AST/OpenMPClause.h
clang/include/clang/AST/RecursiveASTVisitor.h
clang/include/clang/AST/StmtOpenMP.h
clang/include/clang/Basic/DiagnosticSemaKinds.td
clang/include/clang/Parse/Parser.h
clang/include/clang/Sema/Sema.h
clang/lib/AST/OpenMPClause.cpp
clang/lib/AST/StmtPrinter.cpp
clang/lib/Parse/ParseOpenMP.cpp
clang/lib/Sema/SemaOpenMP.cpp
clang/lib/Sema/SemaStmt.cpp
clang/test/OpenMP/metadirective_ast_print_new_1.cpp
clang/test/OpenMP/metadirective_ast_print_new_2.cpp
clang/test/OpenMP/metadirective_ast_print_new_3.cpp
llvm/include/llvm/Frontend/OpenMP/OMPContext.h
llvm/lib/Frontend/OpenMP/OMPContext.cpp
Index: llvm/lib/Frontend/OpenMP/OMPContext.cpp
===================================================================
--- llvm/lib/Frontend/OpenMP/OMPContext.cpp
+++ llvm/lib/Frontend/OpenMP/OMPContext.cpp
@@ -20,6 +20,7 @@
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
+#include <map>
#define DEBUG_TYPE "openmp-ir-builder"
using namespace llvm;
@@ -339,6 +340,45 @@
return Score;
}
+/// Score every variant in \p VMIs against the context \p Ctx and append one
+/// (index into \p VMIs, score) pair per variant to \p VectorOfClauses.
+/// Variants that are not applicable in \p Ctx get a score of zero. The vector
+/// is then sorted by descending score; variants with equal scores keep their
+/// relative order so that the result is deterministic.
+void llvm::omp::getArrayVariantMatchForContext(
+    const SmallVectorImpl<VariantMatchInfo> &VMIs, const OMPContext &Ctx,
+    SmallVector<std::pair<unsigned, APInt>> &VectorOfClauses) {
+  VectorOfClauses.reserve(VectorOfClauses.size() + VMIs.size());
+
+  for (unsigned Idx = 0, E = VMIs.size(); Idx < E; ++Idx) {
+    const VariantMatchInfo &VMI = VMIs[Idx];
+
+    SmallVector<unsigned, 8> ConstructMatches;
+    // A variant that is not applicable in this context scores zero.
+    if (!isVariantApplicableInContextHelper(VMI, Ctx, &ConstructMatches,
+                                            /* DeviceSetOnly */ false)) {
+      VectorOfClauses.emplace_back(Idx, APInt(64, 0));
+      continue;
+    }
+    VectorOfClauses.emplace_back(
+        Idx, getVariantMatchScore(VMI, Ctx, ConstructMatches));
+  }
+
+  // Highest score first. A stable sort breaks ties by source order, so equal
+  // scores never produce an unspecified clause order.
+  std::stable_sort(VectorOfClauses.begin(), VectorOfClauses.end(),
+                   [](const std::pair<unsigned, APInt> &A,
+                      const std::pair<unsigned, APInt> &B) {
+                     return A.second.ugt(B.second);
+                   });
+}
+
+
+
int llvm::omp::getBestVariantMatchForContext(
const SmallVectorImpl<VariantMatchInfo> &VMIs, const OMPContext &Ctx) {
Index: llvm/include/llvm/Frontend/OpenMP/OMPContext.h
===================================================================
--- llvm/include/llvm/Frontend/OpenMP/OMPContext.h
+++ llvm/include/llvm/Frontend/OpenMP/OMPContext.h
@@ -189,6 +189,15 @@
int getBestVariantMatchForContext(const SmallVectorImpl<VariantMatchInfo> &VMIs,
const OMPContext &Ctx);
+/// Score each variant in \p VMIs against the context \p Ctx and append
+/// (index into \p VMIs, score) pairs to \p VectorOfClauses, sorted by
+/// descending score. Variants that are not applicable in \p Ctx get a score
+/// of zero. Used to emit the clauses of a metadirective in order of
+/// preference.
+void getArrayVariantMatchForContext(
+    const SmallVectorImpl<VariantMatchInfo> &VMIs, const OMPContext &Ctx,
+    SmallVector<std::pair<unsigned, APInt>> &VectorOfClauses);
+
} // namespace omp
template <> struct DenseMapInfo<omp::TraitProperty> {
Index: clang/test/OpenMP/metadirective_ast_print_new_3.cpp
===================================================================
--- /dev/null
+++ clang/test/OpenMP/metadirective_ast_print_new_3.cpp
@@ -0,0 +1,22 @@
+// RUN: %clang_cc1 -verify -fopenmp -ast-print %s -o - | FileCheck %s
+// expected-no-diagnostics
+
+// Two metadirectives whose 'default' clause names a directive; the ast-print
+// output of both is verified at the end of this file.
+int main() {
+ int N = 15;
+// A single user={condition(...)} selector.
+#pragma omp metadirective when(user = {condition(N > 10)} : parallel for)\
+ default(target teams)
+ for (int i = 0; i < N; i++)
+ ;
+
+
+// Two trait sets (device and user) in a single 'when' clause.
+#pragma omp metadirective when(device = {arch("nvptx64")}, user = {condition(N >= 100)} : parallel for)\
+ default(target parallel for)
+ for (int i = 0; i < N; i++)
+ ;
+ return 0;
+}
+
+
+
+// CHECK: #pragma omp metadirective when(user={condition(N > 10)}: parallel for) default(target teams)
+// CHECK: #pragma omp metadirective when(device={arch(nvptx64)}, user={condition(N >= 100)}: parallel for) default(target parallel for)
Index: clang/test/OpenMP/metadirective_ast_print_new_2.cpp
===================================================================
--- /dev/null
+++ clang/test/OpenMP/metadirective_ast_print_new_2.cpp
@@ -0,0 +1,29 @@
+// RUN: %clang_cc1 -verify -fopenmp -ast-print %s -o - | FileCheck %s
+// expected-no-diagnostics
+
+// Called from the loop body of the metadirective in myfoo().
+void bar(){
+ int i=0;
+}
+
+// Five 'when' clauses (user conditions and device arch selectors) followed by
+// 'default'. NOTE(review): the expected output at the end of this file names
+// only three of the 'when' clauses and omits the num_teams/thread_limit
+// clauses of the nvptx variant -- confirm that is the intended printing.
+void myfoo(void){
+
+ int N = 13;
+ int b,n;
+ int a[100];
+
+
+ #pragma omp metadirective when (user = {condition(N>10)}: target teams distribute parallel for ) \
+ when (user = {condition(N==10)}: parallel for )\
+ when (user = {condition(N==13)}: parallel for simd) \
+ when ( device={arch("arm")}: target teams num_teams(512) thread_limit(32))\
+ when ( device={arch("nvptx")}: target teams num_teams(512) thread_limit(32))\
+ default ( parallel for)\
+
+ { for (int i = 0; i < N; i++)
+ bar();
+ }
+}
+
+// CHECK: bar()
+// CHECK: myfoo
+// CHECK: #pragma omp metadirective when(user={condition(N > 10)}: target teams distribute parallel for) when(user={condition(N == 13)}: parallel for simd) when(device={arch(nvptx)}: target teams)
Index: clang/test/OpenMP/metadirective_ast_print_new_1.cpp
===================================================================
--- /dev/null
+++ clang/test/OpenMP/metadirective_ast_print_new_1.cpp
@@ -0,0 +1,20 @@
+// RUN: %clang_cc1 -verify -fopenmp -ast-print %s -o - | FileCheck %s
+// expected-no-diagnostics
+// Called from the loop body of the metadirective in myfoo().
+void bar(){
+ int i=0;
+}
+
+// Simplest case: one user-condition 'when' clause plus a 'default' clause,
+// printed back by -ast-print and verified at the end of this file.
+void myfoo(void){
+
+ int N = 13;
+ int b,n;
+ int a[100];
+
+ #pragma omp metadirective when(user={condition(N>10)}: target teams ) default(parallel for)
+ for (int i = 0; i < N; i++)
+ bar();
+
+}
+
+// CHECK: void bar()
+// CHECK: #pragma omp metadirective when(user={condition(N > 10)}: target teams) default(parallel for)
Index: clang/lib/Sema/SemaStmt.cpp
===================================================================
--- clang/lib/Sema/SemaStmt.cpp
+++ clang/lib/Sema/SemaStmt.cpp
@@ -4791,8 +4791,8 @@
CapturedStmt *Res = CapturedStmt::Create(
getASTContext(), S, static_cast<CapturedRegionKind>(RSI->CapRegionKind),
Captures, CaptureInits, CD, RD);
-
- CD->setBody(Res->getCapturedStmt());
+
+ CD->setBody(Res->getCapturedStmt());
RD->completeDefinition();
return Res;
Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -37,6 +37,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Frontend/OpenMP/OMPContext.h"
#include <set>
using namespace clang;
@@ -3930,6 +3931,7 @@
void Sema::ActOnOpenMPRegionStart(OpenMPDirectiveKind DKind, Scope *CurScope) {
switch (DKind) {
+ case OMPD_metadirective:
case OMPD_parallel:
case OMPD_parallel_for:
case OMPD_parallel_for_simd:
@@ -4339,7 +4341,6 @@
case OMPD_declare_variant:
case OMPD_begin_declare_variant:
case OMPD_end_declare_variant:
- case OMPD_metadirective:
llvm_unreachable("OpenMP Directive is not allowed");
case OMPD_unknown:
default:
@@ -4521,7 +4522,7 @@
}
StmtResult Sema::ActOnOpenMPRegionEnd(StmtResult S,
- ArrayRef<OMPClause *> Clauses) {
+ ArrayRef<OMPClause *> Clauses){
handleDeclareVariantConstructTrait(DSAStack, DSAStack->getCurrentDirective(),
/* ScopeEntry */ false);
if (DSAStack->getCurrentDirective() == OMPD_atomic ||
@@ -4611,6 +4612,7 @@
<< SourceRange(OC->getBeginLoc(), OC->getEndLoc());
ErrorFound = true;
}
+
// OpenMP 5.0, 2.9.2 Worksharing-Loop Construct, Restrictions.
// If an order(concurrent) clause is present, an ordered clause may not appear
// on the same directive.
@@ -4623,6 +4625,7 @@
}
ErrorFound = true;
}
+
if (isOpenMPWorksharingDirective(DSAStack->getCurrentDirective()) &&
isOpenMPSimdDirective(DSAStack->getCurrentDirective()) && OC &&
OC->getNumForLoops()) {
@@ -4635,7 +4638,9 @@
}
StmtResult SR = S;
unsigned CompletedRegions = 0;
+
for (OpenMPDirectiveKind ThisCaptureRegion : llvm::reverse(CaptureRegions)) {
+
// Mark all variables in private list clauses as used in inner region.
// Required for proper codegen of combined directives.
// TODO: add processing for other clauses.
@@ -4656,6 +4661,7 @@
}
}
}
+
if (ThisCaptureRegion == OMPD_target) {
// Capture allocator traits in the target region. They are used implicitly
// and, thus, are not captured by default.
@@ -4671,6 +4677,7 @@
}
}
}
+
if (ThisCaptureRegion == OMPD_parallel) {
// Capture temp arrays for inscan reductions and locals in aligned
// clauses.
@@ -4687,10 +4694,14 @@
}
}
}
+
if (++CompletedRegions == CaptureRegions.size())
DSAStack->setBodyComplete();
+
SR = ActOnCapturedRegionEnd(SR.get());
+
}
+
return SR;
}
@@ -5963,6 +5974,12 @@
llvm::SmallVector<OpenMPDirectiveKind, 4> AllowedNameModifiers;
switch (Kind) {
+
+ case OMPD_metadirective:
+ Res = ActOnOpenMPMetaDirective(ClausesWithImplicit, AStmt, StartLoc,
+ EndLoc);
+ AllowedNameModifiers.push_back(OMPD_metadirective);
+ break;
case OMPD_parallel:
Res = ActOnOpenMPParallelDirective(ClausesWithImplicit, AStmt, StartLoc,
EndLoc);
@@ -7342,10 +7359,118 @@
FD->addAttr(NewAttr);
}
+/// Build the AST node for '#pragma omp metadirective'.
+///
+/// The clauses are folded, from last to first, into a chain of if/else
+/// statements: each 'when' clause contributes
+///   if (<conjunction of its evaluable selectors>) <variant> else <rest>
+/// and a clause without an evaluable condition (normally 'default') must be
+/// the last one and becomes the innermost else-branch. The variant is the
+/// clause's directive, or the metadirective's own associated statement when
+/// the clause names no directive. device={arch(...)} and
+/// implementation={vendor(...)} are evaluated at compile time against the
+/// OpenMP target triples; user={condition(...)} is used as written.
+StmtResult Sema::ActOnOpenMPMetaDirective(ArrayRef<OMPClause *> Clauses,
+                                          Stmt *AStmt,
+                                          SourceLocation StartLoc,
+                                          SourceLocation EndLoc) {
+  if (!AStmt)
+    return StmtError();
+
+  auto *CS = cast<CapturedStmt>(AStmt);
+  CS->getCapturedDecl()->setNothrow();
+
+  // True if any property of Selector equals the arch (or, with MatchVendor,
+  // the vendor) component of one of the OpenMP target triples.
+  auto MatchesTargetTriple = [this](const OMPTraitSelector &Selector,
+                                    bool MatchVendor) {
+    return llvm::any_of(Selector.Properties, [&](const OMPTraitProperty &P) {
+      return llvm::any_of(getLangOpts().OMPTargetTriples, [&](const auto &T) {
+        StringRef Component =
+            MatchVendor ? T.getVendorName() : T.getArchName();
+        return Component == P.RawString;
+      });
+    });
+  };
+
+  // The if/else chain built so far (innermost part first).
+  Stmt *ElseStmt = nullptr;
+
+  for (OMPClause *Clause : llvm::reverse(Clauses)) {
+    // 'when' and 'default' are both represented as OMPWhenClause; any other
+    // clause does not select a variant.
+    auto *WhenClause = dyn_cast_or_null<OMPWhenClause>(Clause);
+    if (!WhenClause)
+      continue;
+
+    OpenMPDirectiveKind DKind = WhenClause->getDKind();
+    Stmt *ThenStmt = nullptr;
+    if (DKind != OMPD_unknown)
+      ThenStmt = CompoundStmt::Create(Context, {WhenClause->getDirective()},
+                                      SourceLocation(), SourceLocation());
+
+    // All selectors of a clause must match, so their conditions are combined
+    // with '&&'. Stays null if the clause has no evaluable selector.
+    Expr *WhenCondExpr = nullptr;
+    for (const OMPTraitSet &Set : WhenClause->getTI().Sets) {
+      for (const OMPTraitSelector &Selector : Set.Selectors) {
+        Expr *SelectorCond = nullptr;
+        switch (Selector.Kind) {
+        case TraitSelector::device_arch:
+          SelectorCond = new (Context) CXXBoolLiteralExpr(
+              MatchesTargetTriple(Selector, /*MatchVendor=*/false),
+              Context.BoolTy, StartLoc);
+          break;
+        case TraitSelector::implementation_vendor:
+          SelectorCond = new (Context) CXXBoolLiteralExpr(
+              MatchesTargetTriple(Selector, /*MatchVendor=*/true),
+              Context.BoolTy, StartLoc);
+          break;
+        case TraitSelector::user_condition:
+          assert(Selector.ScoreOrCondition &&
+                 "Ill-formed user condition, expected condition expression!");
+          SelectorCond = Selector.ScoreOrCondition;
+          break;
+        case TraitSelector::device_isa:
+        case TraitSelector::device_kind:
+        case TraitSelector::implementation_extension:
+        default:
+          // TODO: these selectors are not evaluated yet and do not constrain
+          // the clause.
+          break;
+        }
+
+        if (!SelectorCond)
+          continue;
+        if (!WhenCondExpr) {
+          WhenCondExpr = SelectorCond;
+          continue;
+        }
+        ExprResult Combined = BuildBinOp(getCurScope(), StartLoc, BO_LAnd,
+                                         WhenCondExpr, SelectorCond);
+        if (Combined.isInvalid())
+          return StmtError();
+        WhenCondExpr = Combined.get();
+      }
+    }
+
+    if (!WhenCondExpr) {
+      // No condition: this is the default variant. It must be the last
+      // clause, and only one is allowed.
+      if (ElseStmt) {
+        Diag(WhenClause->getBeginLoc(), diag::err_omp_misplaced_default_clause);
+        return StmtError();
+      }
+      if (ThenStmt)
+        ElseStmt = ThenStmt;
+      else
+        ElseStmt = CompoundStmt::Create(Context, {CS->getCapturedStmt()},
+                                        SourceLocation(), SourceLocation());
+      continue;
+    }
+
+    // A 'when' clause without a directive runs the associated statement.
+    if (!ThenStmt)
+      ThenStmt = CompoundStmt::Create(Context, {CS->getCapturedStmt()},
+                                      SourceLocation(), SourceLocation());
+
+    StmtResult IfStmt = ActOnIfStmt(
+        SourceLocation(), IfStatementKind::Ordinary, SourceLocation(),
+        /*InitStmt=*/nullptr,
+        ActOnCondition(getCurScope(), SourceLocation(), WhenCondExpr,
+                       Sema::ConditionKind::Boolean),
+        SourceLocation(), ThenStmt, SourceLocation(), ElseStmt);
+    if (IfStmt.isInvalid())
+      return StmtError();
+    ElseStmt = IfStmt.get();
+  }
+
+  // ElseStmt is now the whole if/else chain, or just the default variant
+  // when no clause has a condition.
+  return OMPMetaDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
+                                  ElseStmt);
+}
+
StmtResult Sema::ActOnOpenMPParallelDirective(ArrayRef<OMPClause *> Clauses,
Stmt *AStmt,
SourceLocation StartLoc,
- SourceLocation EndLoc) {
+ SourceLocation EndLoc){
if (!AStmt)
return StmtError();
@@ -14837,6 +14962,17 @@
return std::string(Out.str());
}
+/// Build the OMPWhenClause for a metadirective 'when' or 'default' clause.
+/// The variant directive statement is taken from \p Directive via get().
+OMPClause *Sema::ActOnOpenMPWhenClause(OMPTraitInfo &TI,
+                                       OpenMPDirectiveKind DKind,
+                                       StmtResult Directive,
+                                       SourceLocation StartLoc,
+                                       SourceLocation LParenLoc,
+                                       SourceLocation EndLoc) {
+  Stmt *VariantDirective = Directive.get();
+  auto *WhenClause = new (Context) OMPWhenClause(
+      TI, DKind, VariantDirective, StartLoc, LParenLoc, EndLoc);
+  return WhenClause;
+}
+
+
+
+
OMPClause *Sema::ActOnOpenMPDefaultClause(DefaultKind Kind,
SourceLocation KindKwLoc,
SourceLocation StartLoc,
Index: clang/lib/Parse/ParseOpenMP.cpp
===================================================================
--- clang/lib/Parse/ParseOpenMP.cpp
+++ clang/lib/Parse/ParseOpenMP.cpp
@@ -2458,7 +2458,7 @@
bool HasAssociatedStatement = true;
switch (DKind) {
- case OMPD_metadirective: {
+ case OMPD_metadirective:{
ConsumeToken();
SmallVector<VariantMatchInfo, 4> VMIs;
@@ -2470,10 +2470,12 @@
BalancedDelimiterTracker T(*this, tok::l_paren,
tok::annot_pragma_openmp_end);
- while (Tok.isNot(tok::annot_pragma_openmp_end)) {
+
+ while (Tok.isNot(tok::annot_pragma_openmp_end)){
OpenMPClauseKind CKind = Tok.isAnnotation()
? OMPC_unknown
: getOpenMPClauseKind(PP.getSpelling(Tok));
+
SourceLocation Loc = ConsumeToken();
// Parse '('.
@@ -2491,7 +2493,7 @@
return Directive;
}
- // Parse ':'
+ // Parse ':' // You have parsed the OpenMP Context in the meta directive
if (Tok.is(tok::colon))
ConsumeAnyToken();
else {
@@ -2500,6 +2502,7 @@
return Directive;
}
}
+
// Skip Directive for now. We will parse directive in the second iteration
int paren = 0;
while (Tok.isNot(tok::r_paren) || paren != 0) {
@@ -2513,86 +2516,112 @@
TPA.Commit();
return Directive;
}
- ConsumeAnyToken();
- }
+ ConsumeAnyToken();
+ }
+
// Parse ')'
if (Tok.is(tok::r_paren))
T.consumeClose();
-
+
VariantMatchInfo VMI;
TI.getAsVariantMatchInfo(ASTContext, VMI);
-
- VMIs.push_back(VMI);
- }
-
+
+ if (CKind == OMPC_when )
+ VMIs.push_back(VMI);
+ }
+
+ // This is the end of the first iteration
+ // The pointer is moved back
TPA.Revert();
// End of the first iteration. Parser is reset to the start of metadirective
-
+
TargetOMPContext OMPCtx(ASTContext, /* DiagUnknownTrait */ nullptr,
/* CurrentFunctionDecl */ nullptr,
ArrayRef<llvm::omp::TraitProperty>());
-
- // A single match is returned for OpenMP 5.0
- int BestIdx = getBestVariantMatchForContext(VMIs, OMPCtx);
-
- int Idx = 0;
- // In OpenMP 5.0 metadirective is either replaced by another directive or
- // ignored.
- // TODO: In OpenMP 5.1 generate multiple directives based upon the matches
- // found by getBestWhenMatchForContext.
- while (Tok.isNot(tok::annot_pragma_openmp_end)) {
- // OpenMP 5.0 implementation - Skip to the best index found.
- if (Idx++ != BestIdx) {
- ConsumeToken(); // Consume clause name
- T.consumeOpen(); // Consume '('
- int paren = 0;
- // Skip everything inside the clause
- while (Tok.isNot(tok::r_paren) || paren != 0) {
- if (Tok.is(tok::l_paren))
- paren++;
- if (Tok.is(tok::r_paren))
- paren--;
- ConsumeAnyToken();
- }
- // Parse ')'
- if (Tok.is(tok::r_paren))
- T.consumeClose();
- continue;
- }
-
+
+ // Array SortedClauses will be used for sorting clauses
+ // based on the context selector score
+ SmallVector<std::pair<unsigned, llvm::APInt>> SortedCluases;
+
+ // The function will get the score for each clause and sort it
+ // based on the score number
+
+ getArrayVariantMatchForContext(VMIs, OMPCtx, SortedCluases) ;
+
+ ParseScope OMPDirectiveScope(this, ScopeFlags);
+ Actions.StartOpenMPDSABlock(DKind, DirName, Actions.getCurScope(), Loc);
+
+ while(Tok.isNot(tok::annot_pragma_openmp_end)){
+
OpenMPClauseKind CKind = Tok.isAnnotation()
? OMPC_unknown
: getOpenMPClauseKind(PP.getSpelling(Tok));
- SourceLocation Loc = ConsumeToken();
-
- // Parse '('.
- T.consumeOpen();
-
- // Skip ContextSelectors for when clause
- if (CKind == OMPC_when) {
- OMPTraitInfo &TI = Actions.getASTContext().getNewOMPTraitInfo();
- // parse and skip the ContextSelectors
- parseOMPContextSelectors(Loc, TI);
-
- // Parse ':'
- ConsumeAnyToken();
- }
-
- // If no directive is passed, skip in OpenMP 5.0.
- // TODO: Generate nothing directive from OpenMP 5.1.
- if (Tok.is(tok::r_paren)) {
- SkipUntil(tok::annot_pragma_openmp_end);
- break;
- }
-
- // Parse Directive
- ReadDirectiveWithinMetadirective = true;
- Directive = ParseOpenMPDeclarativeOrExecutableDirective(StmtCtx);
- ReadDirectiveWithinMetadirective = false;
- break;
+
+ Actions.StartOpenMPClause(CKind);
+ OMPClause *Clause = ParseOpenMPMetaDirectiveClause( DKind, CKind);
+
+ FirstClauses[(unsigned) CKind].setInt(true);
+ if (Clause) {
+ FirstClauses[(unsigned) CKind].setPointer(Clause);
+ Clauses.push_back(Clause);
+ }// end of if statement
+
+ if (Tok.is(tok::comma))
+ ConsumeToken();
+
+ Actions.EndOpenMPClause();
+
+ if (Tok.is(tok::r_paren))
+ ConsumeAnyToken();
+
}
- break;
- }
+
+ // End location of the directive
+ EndLoc = Tok.getLocation();
+
+ //Consume final annot_pragma_openmp_end
+ ConsumeAnnotationToken();
+
+ SmallVector<OMPClause *, 5> Clauses_new;
+ unsigned count = 0;
+
+ // SortedClauses has index and score, and are sorted with respect to the
+ // the context score. The first iteration will take each element. The
+ // first element will have the highiest score. The element will have the
+ // index of the cluase for the best score. The second iteration tries to
+ // find that specific clause by checking the count numder with the
+ // index (Iteration1.first)
+ for ( auto &Iteration1 : SortedCluases){
+ count = 0;
+ for ( auto &Iteration2 : Clauses){
+ if ( count == Iteration1.first ){
+ Clauses_new.push_back(Iteration2);
+ break;
+ } else count++;
+ }
+ }
+ // Adding the default clasue at the end
+ Clauses_new.push_back(Clauses.back());
+
+ // Parsing the OpenMP region which will take the
+ // metadirective
+
+ Actions.ActOnOpenMPRegionStart(DKind, getCurScope());
+ ParsingOpenMPDirectiveRAII NormalScope(*this, /*value=*/ false);
+ // This is parsing the region
+ StmtResult AStmt = ParseStatement();
+
+ StmtResult AssociatedStmt = (Sema::CompoundScopeRAII(Actions), AStmt);
+ // Ending of the parallel region
+ AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AssociatedStmt, Clauses_new);
+ Directive = Actions.ActOnOpenMPExecutableDirective(
+ DKind, DirName, CancelRegion, Clauses_new, AssociatedStmt.get(), Loc,
+ EndLoc);
+ // Exit scope
+ Actions.EndOpenMPDSABlock(Directive.get());
+ OMPDirectiveScope.Exit();
+ break;
+ } // end of case OMPD_metadirective:
case OMPD_threadprivate: {
// FIXME: Should this be permitted in C++?
if ((StmtCtx & ParsedStmtContext::AllowDeclarationsInC) ==
@@ -3050,6 +3079,164 @@
return Actions.ActOnOpenMPUsesAllocatorClause(Loc, T.getOpenLocation(),
T.getCloseLocation(), Data);
}
+/// Parse a clause of '#pragma omp metadirective'.
+///
+/// Handles 'when(<context-selectors>: [directive [clauses]])' and
+/// 'default([directive [clauses]])'; both are returned as an OMPWhenClause
+/// ('default' has no context selectors). If a variant directive is present,
+/// the statement following the pragma is parsed tentatively to serve as its
+/// associated statement and the token stream is then rewound, so the caller
+/// still sees that statement.
+///
+/// \param DKind Kind of the enclosing directive (expected: metadirective).
+/// \param CKind Kind of the clause at the current token.
+/// \returns The parsed clause, or nullptr on error. Clauses other than
+///          'when'/'default' are diagnosed and consumed so that the caller,
+///          which loops until the end of the pragma, always makes progress.
+OMPClause *Parser::ParseOpenMPMetaDirectiveClause(OpenMPDirectiveKind DKind,
+                                                  OpenMPClauseKind CKind) {
+  bool ErrorFound = false;
+  bool WrongDirective = false;
+  // Clauses already seen on the variant directive.
+  SmallVector<llvm::PointerIntPair<OMPClause *, 1, bool>,
+              llvm::omp::Clause_enumSize + 1>
+      FirstClauses(llvm::omp::Clause_enumSize + 1);
+
+  // Check if it is called from metadirective.
+  if (DKind != OMPD_metadirective) {
+    Diag(Tok, diag::err_omp_unexpected_clause)
+        << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind);
+    ErrorFound = true;
+  }
+
+  // Check if clause is allowed for the given directive.
+  if (CKind != OMPC_unknown &&
+      !isAllowedClauseForDirective(DKind, CKind, getLangOpts().OpenMP)) {
+    Diag(Tok, diag::err_omp_unexpected_clause)
+        << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind);
+    ErrorFound = true;
+    WrongDirective = true;
+  }
+
+  // Unknown clauses are never allowed.
+  if (CKind == OMPC_unknown) {
+    Diag(Tok, diag::err_omp_unexpected_clause)
+        << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind);
+    ErrorFound = true;
+    WrongDirective = true;
+  }
+
+  if (CKind != OMPC_default && CKind != OMPC_when) {
+    // Report the clause unless one of the checks above already did.
+    if (!ErrorFound)
+      Diag(Tok, diag::err_omp_unexpected_clause)
+          << getOpenMPClauseName(CKind) << getOpenMPDirectiveName(DKind);
+    // Consume the clause name and any parenthesized argument; otherwise the
+    // caller would call us again on the same token forever.
+    if (Tok.isNot(tok::annot_pragma_openmp_end))
+      ConsumeAnyToken();
+    if (Tok.is(tok::l_paren)) {
+      ConsumeParen();
+      SkipUntil(tok::r_paren, tok::annot_pragma_openmp_end, StopBeforeMatch);
+      if (Tok.is(tok::r_paren))
+        ConsumeParen();
+    }
+    return nullptr;
+  }
+
+  SourceLocation Loc = ConsumeToken();
+  // Parse '('.
+  BalancedDelimiterTracker T(*this, tok::l_paren,
+                             tok::annot_pragma_openmp_end);
+  if (T.expectAndConsume(diag::err_expected_lparen_after,
+                         getOpenMPClauseName(CKind).data()))
+    return nullptr;
+  // Location of the clause's '('.
+  SourceLocation DelimLoc = T.getOpenLocation();
+
+  OMPTraitInfo &TI = Actions.getASTContext().getNewOMPTraitInfo();
+  if (CKind == OMPC_when) {
+    // Parse the context selectors that form the clause's condition.
+    parseOMPContextSelectors(Loc, TI);
+
+    // Parse ':'
+    if (Tok.is(tok::colon)) {
+      ConsumeAnyToken();
+    } else {
+      Diag(Tok, diag::warn_pragma_expected_colon) << "when clause";
+      return nullptr;
+    }
+  }
+
+  // Parse the optional variant directive.
+  OpenMPDirectiveKind DirKind = OMPD_unknown;
+  SmallVector<OMPClause *, 5> Clauses;
+  StmtResult Directive = StmtError();
+
+  if (Tok.isNot(tok::r_paren)) {
+    ParsingOpenMPDirectiveRAII DirScope(*this);
+    ParenBraceBracketBalancer BalancerRAIIObj(*this);
+    DeclarationNameInfo DirName;
+    unsigned ScopeFlags = Scope::FnScope | Scope::DeclScope |
+                          Scope::CompoundStmtScope |
+                          Scope::OpenMPDirectiveScope;
+
+    DirKind = parseOpenMPDirectiveKind(*this);
+    ConsumeToken();
+    ParseScope OMPDirectiveScope(this, ScopeFlags);
+    Actions.StartOpenMPDSABlock(DirKind, DirName, Actions.getCurScope(), Loc);
+
+    // Parse the variant directive's clauses up to the ')' that closes the
+    // metadirective clause. Also stop at the end of the pragma so that a
+    // missing ')' cannot make this loop spin without consuming tokens.
+    while (Tok.isNot(tok::r_paren) &&
+           Tok.isNot(tok::annot_pragma_openmp_end)) {
+      OpenMPClauseKind DirCKind =
+          Tok.isAnnotation() ? OMPC_unknown
+                             : getOpenMPClauseKind(PP.getSpelling(Tok));
+
+      if (DirCKind == OMPC_unknown &&
+          !isAllowedClauseForDirective(DirKind, DirCKind,
+                                       getLangOpts().OpenMP)) {
+        Diag(Tok, diag::err_omp_unexpected_clause)
+            << getOpenMPClauseName(DirCKind) << getOpenMPDirectiveName(DirKind);
+        ErrorFound = true;
+        WrongDirective = true;
+      }
+
+      Actions.StartOpenMPClause(DirCKind);
+      OMPClause *DirClause = ParseOpenMPClause(
+          DirKind, DirCKind, !FirstClauses[(unsigned)DirCKind].getInt());
+      FirstClauses[(unsigned)DirCKind].setInt(true);
+      if (DirClause) {
+        FirstClauses[(unsigned)DirCKind].setPointer(DirClause);
+        Clauses.push_back(DirClause);
+      }
+
+      // Skip ',' if any.
+      if (Tok.is(tok::comma))
+        ConsumeToken();
+      Actions.EndOpenMPClause();
+    }
+
+    Actions.ActOnOpenMPRegionStart(DirKind, getCurScope());
+    ParsingOpenMPDirectiveRAII NormalScope(*this, /*Value=*/false);
+
+    // Parse the statement that follows the pragma as the variant's associated
+    // statement, then rewind so the caller can parse it again.
+    TentativeParsingAction TPA(*this);
+    while (Tok.isNot(tok::annot_pragma_openmp_end))
+      ConsumeAnyToken();
+    ConsumeAnnotationToken();
+
+    ParseScope InnerStmtScope(this, Scope::DeclScope,
+                              getLangOpts().C99 || getLangOpts().CPlusPlus,
+                              Tok.is(tok::l_brace));
+    StmtResult AStmt;
+    {
+      // Keep a compound-statement scope open in Sema while the statement is
+      // parsed. A CompoundScopeRAII temporary created only after parsing is
+      // destroyed at once and never covers the statement.
+      Sema::CompoundScopeRAII CompoundScope(Actions);
+      AStmt = ParseStatement();
+    }
+    InnerStmtScope.Exit();
+    TPA.Revert();
+
+    StmtResult AssociatedStmt = Actions.ActOnOpenMPRegionEnd(AStmt, Clauses);
+    Directive = Actions.ActOnOpenMPExecutableDirective(
+        DirKind, DirName, OMPD_unknown, llvm::makeArrayRef(Clauses),
+        AssociatedStmt.get(), Loc, Tok.getLocation());
+
+    Actions.EndOpenMPDSABlock(Directive.get());
+    OMPDirectiveScope.Exit();
+  }
+
+  // Parse ')'
+  T.consumeClose();
+
+  if (WrongDirective)
+    return nullptr;
+
+  OMPClause *Clause = Actions.ActOnOpenMPWhenClause(
+      TI, DirKind, Directive, Loc, DelimLoc, T.getCloseLocation());
+  return ErrorFound ? nullptr : Clause;
+}
/// Parsing of OpenMP clauses.
///
Index: clang/lib/AST/StmtPrinter.cpp
===================================================================
--- clang/lib/AST/StmtPrinter.cpp
+++ clang/lib/AST/StmtPrinter.cpp
@@ -657,11 +657,23 @@
bool ForceNoStmt) {
OMPClausePrinter Printer(OS, Policy);
ArrayRef<OMPClause *> Clauses = S->clauses();
- for (auto *Clause : Clauses)
- if (Clause && !Clause->isImplicit()) {
+
+ for (auto *Clause : Clauses){
+ if (Clause && !Clause->isImplicit()){
OS << ' ';
Printer.Visit(Clause);
- }
+ if (isa<OMPMetaDirective>(S)){
+ OMPWhenClause *WhenClause = dyn_cast<OMPWhenClause>(Clause);
+ if (WhenClause!=nullptr){
+ if (WhenClause->getDKind() != llvm::omp::OMPD_unknown){
+ Printer.VisitOMPWhenClause(WhenClause);
+ OS << llvm::omp::getOpenMPDirectiveName(WhenClause->getDKind());
+ }
+ OS << ")";
+ }
+ }
+ }
+ }
OS << NL;
if (!ForceNoStmt && S->hasAssociatedStmt())
PrintStmt(S->getRawStmt());
Index: clang/lib/AST/OpenMPClause.cpp
===================================================================
--- clang/lib/AST/OpenMPClause.cpp
+++ clang/lib/AST/OpenMPClause.cpp
@@ -1609,6 +1609,75 @@
// OpenMP clauses printing methods
//===----------------------------------------------------------------------===//
+/// Print the selector part of a metadirective clause: "when(<selectors>: "
+/// for a 'when' clause, or "default(" for a clause without context
+/// selectors. The caller prints the variant directive and the closing ')'.
+void OMPClausePrinter::VisitOMPWhenClause(OMPWhenClause *Node) {
+  const OMPTraitInfo &TI = Node->getTI();
+  if (TI.Sets.empty()) {
+    OS << "default(";
+    return;
+  }
+  OS << "when(";
+  // Reuse the generic context-selector printer (as used for 'declare variant
+  // match(...)'). It separates sets, selectors and properties with ", ",
+  // prints selector scores, and handles every trait set, including
+  // 'construct'.
+  TI.print(OS, Policy);
+  OS << ": ";
+}
+
void OMPClausePrinter::VisitOMPIfClause(OMPIfClause *Node) {
OS << "if(";
if (Node->getNameModifier() != OMPD_unknown)
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -66,6 +66,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TinyPtrVector.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Frontend/OpenMP/OMPContext.h"
#include <deque>
#include <memory>
#include <string>
@@ -10678,10 +10679,18 @@
///
/// \returns Statement for finished OpenMP region.
StmtResult ActOnOpenMPRegionEnd(StmtResult S, ArrayRef<OMPClause *> Clauses);
+
+  /// Called on a well-formed OpenMP executable directive after parsing of
+  /// its clauses and associated statement.
StmtResult ActOnOpenMPExecutableDirective(
OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName,
OpenMPDirectiveKind CancelRegion, ArrayRef<OMPClause *> Clauses,
Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc);
+
+  /// Called on '\#pragma omp metadirective'.
+  /// NOTE(review): no definition of this function appears in this patch;
+  /// metadirectives are routed through ActOnOpenMPExecutableDirective and
+  /// ActOnOpenMPMetaDirective. Confirm it is needed or drop the declaration.
+ StmtResult ActOnOpenMPExecutableMetaDirective(
+ OpenMPDirectiveKind Kind, const DeclarationNameInfo &DirName,
+ OpenMPDirectiveKind CancelRegion, ArrayRef<OMPClause *> Clauses,
+ Stmt *AStmt, SourceLocation StartLoc, SourceLocation EndLoc);
/// Called on well-formed '\#pragma omp parallel' after parsing
/// of the associated statement.
StmtResult ActOnOpenMPParallelDirective(ArrayRef<OMPClause *> Clauses,
@@ -11127,7 +11136,9 @@
SourceLocation LParenLoc,
SourceLocation EndLoc);
/// Called on well-formed 'when' clause.
- OMPClause *ActOnOpenMPWhenClause(OMPTraitInfo &TI, SourceLocation StartLoc,
+  /// Also used for the 'default' clause of a metadirective, which has no
+  /// context selectors.
+  ///
+  /// \param TI Context selectors of the clause.
+  /// \param DKind Kind of the variant directive, or OMPD_unknown if none.
+  /// \param Directive The parsed variant directive, if any.
+ OMPClause *ActOnOpenMPWhenClause(OMPTraitInfo &TI, OpenMPDirectiveKind DKind,
+ StmtResult Directive,
+ SourceLocation StartLoc,
SourceLocation LParenLoc,
SourceLocation EndLoc);
/// Called on well-formed 'default' clause.
Index: clang/include/clang/Parse/Parser.h
===================================================================
--- clang/include/clang/Parse/Parser.h
+++ clang/include/clang/Parse/Parser.h
@@ -3291,6 +3291,13 @@
/// \param StmtCtx The context in which we're parsing the directive.
StmtResult
ParseOpenMPDeclarativeOrExecutableDirective(ParsedStmtContext StmtCtx);
+  /// Parse a 'when' or 'default' clause of '#pragma omp metadirective',
+  /// including the variant directive it names.
+  ///
+  /// \param DKind Kind of current directive.
+  /// \param CKind Kind of current clause.
+  ///
+  /// \returns The parsed clause, or nullptr on error.
+ OMPClause *ParseOpenMPMetaDirectiveClause(OpenMPDirectiveKind DKind,
+ OpenMPClauseKind CKind);
/// Parses clause of kind \a CKind for directive of a kind \a Kind.
///
/// \param DKind Kind of current directive.
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10848,6 +10848,8 @@
"'%0' clause requires 'dispatch' context selector">;
def err_omp_append_args_with_varargs : Error<
"'append_args' is not allowed with varargs functions">;
+def err_omp_misplaced_default_clause : Error<
+  "'default' clause must be the last clause of '#pragma omp metadirective' "
+  "and may appear only once">;
} // end of OpenMP category
let CategoryName = "Related Result Type Issue" in {
Index: clang/include/clang/AST/StmtOpenMP.h
===================================================================
--- clang/include/clang/AST/StmtOpenMP.h
+++ clang/include/clang/AST/StmtOpenMP.h
@@ -5475,7 +5475,7 @@
ArrayRef<OMPClause *> Clauses,
Stmt *AssociatedStmt, Stmt *IfStmt);
static OMPMetaDirective *CreateEmpty(const ASTContext &C, unsigned NumClauses,
- EmptyShell);
+ EmptyShell);
Stmt *getIfStmt() const { return IfStmt; }
static bool classof(const Stmt *T) {
Index: clang/include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- clang/include/clang/AST/RecursiveASTVisitor.h
+++ clang/include/clang/AST/RecursiveASTVisitor.h
@@ -501,7 +501,8 @@
/// Process clauses with pre-initis.
bool VisitOMPClauseWithPreInit(OMPClauseWithPreInit *Node);
bool VisitOMPClauseWithPostUpdate(OMPClauseWithPostUpdate *Node);
-
+ bool VisitOMPWhenClause(OMPWhenClause *C);
+
bool PostVisitStmt(Stmt *S);
};
@@ -3136,6 +3137,18 @@
return true;
}
+/// Traverse the expressions referenced by a metadirective 'when' clause:
+/// currently the user={condition(...)} expressions only.
+/// NOTE(review): the variant directive (getDirective()) is not traversed.
+/// Also confirm this is reached by the generic clause traversal, which
+/// dispatches on the clause classes generated from OMP.inc, whereas this
+/// function is declared by hand.
+template <typename Derived>
+bool RecursiveASTVisitor<Derived>::VisitOMPWhenClause(OMPWhenClause *C) {
+ for (const OMPTraitSet &Set : C->getTI().Sets) {
+ for (const OMPTraitSelector &Selector : Set.Selectors) {
+ if (Selector.Kind == llvm::omp::TraitSelector::user_condition &&
+ Selector.ScoreOrCondition)
+ TRY_TO(TraverseStmt(Selector.ScoreOrCondition));
+ }
+ }
+ return true;
+}
+
template <typename Derived>
bool RecursiveASTVisitor<Derived>::VisitOMPDefaultClause(OMPDefaultClause *) {
return true;
Index: clang/include/clang/AST/OpenMPClause.h
===================================================================
--- clang/include/clang/AST/OpenMPClause.h
+++ clang/include/clang/AST/OpenMPClause.h
@@ -8611,25 +8611,7 @@
template<class ImplClass, typename RetTy = void>
class ConstOMPClauseVisitor :
public OMPClauseVisitorBase <ImplClass, const_ptr, RetTy> {};
-
-class OMPClausePrinter final : public OMPClauseVisitor<OMPClausePrinter> {
- raw_ostream &OS;
- const PrintingPolicy &Policy;
-
- /// Process clauses with list of variables.
- template <typename T> void VisitOMPClauseList(T *Node, char StartSym);
- /// Process motion clauses.
- template <typename T> void VisitOMPMotionClause(T *Node);
-
-public:
- OMPClausePrinter(raw_ostream &OS, const PrintingPolicy &Policy)
- : OS(OS), Policy(Policy) {}
-
-#define GEN_CLANG_CLAUSE_CLASS
-#define CLAUSE_CLASS(Enum, Str, Class) void Visit##Class(Class *S);
-#include "llvm/Frontend/OpenMP/OMP.inc"
-};
-
+
struct OMPTraitProperty {
llvm::omp::TraitProperty Kind = llvm::omp::TraitProperty::invalid;
@@ -8872,6 +8854,98 @@
}
};
+/// This represents a 'when' clause (and, when it has no context selectors,
+/// the 'default' clause) of '#pragma omp metadirective'.
+/// \code
+/// #pragma omp metadirective when(user={condition(N<100)}:parallel for)
+/// \endcode
+/// In this example the 'when' clause has the context selector
+/// 'user={condition(N<100)}' and the variant directive 'parallel for', which
+/// applies to the code associated with the metadirective when the condition
+/// is satisfied.
+class OMPWhenClause final : public OMPClause {
+  friend class OMPClauseReader;
+
+  /// Context selectors of the clause; empty for 'default'.
+  OMPTraitInfo *TI = nullptr;
+  /// Kind of the variant directive, or OMPD_unknown if none was given.
+  OpenMPDirectiveKind DKind = llvm::omp::OMPD_unknown;
+  /// The variant directive statement (may be null).
+  Stmt *Directive = nullptr;
+
+  /// Location of '('.
+  SourceLocation LParenLoc;
+
+  /// Sets the location of '('.
+  ///
+  /// \param Loc Location of '('.
+  void setLParenLoc(SourceLocation Loc) { LParenLoc = Loc; }
+
+public:
+  /// Build a 'when' clause.
+  ///
+  /// \param T Trait info holding the clause's context selectors.
+  /// \param dKind Kind of the variant directive associated with the clause.
+  /// \param D The variant directive statement associated with the clause.
+  /// \param StartLoc Starting location of the clause.
+  /// \param LParenLoc Location of '('.
+  /// \param EndLoc Ending location of the clause.
+  OMPWhenClause(OMPTraitInfo &T, OpenMPDirectiveKind dKind, Stmt *D,
+                SourceLocation StartLoc, SourceLocation LParenLoc,
+                SourceLocation EndLoc)
+      : OMPClause(llvm::omp::OMPC_when, StartLoc, EndLoc), TI(&T),
+        DKind(dKind), Directive(D), LParenLoc(LParenLoc) {}
+
+  /// Build an empty clause. The members keep their default initializers, so
+  /// an empty clause never holds indeterminate pointers.
+  OMPWhenClause()
+      : OMPClause(llvm::omp::OMPC_when, SourceLocation(), SourceLocation()) {}
+
+  /// Returns the location of '('.
+  SourceLocation getLParenLoc() const { return LParenLoc; }
+
+  /// Returns the kind of the variant directive.
+  OpenMPDirectiveKind getDKind() const { return DKind; }
+
+  /// Returns the variant directive statement, or null if there is none.
+  Stmt *getDirective() const { return Directive; }
+
+  /// Returns the OMPTraitInfo holding the context selectors.
+  OMPTraitInfo &getTI() const {
+    assert(TI && "'when' clause has no trait info");
+    return *TI;
+  }
+
+  // NOTE(review): the variant directive is not exposed through children(),
+  // so generic child walkers do not visit it -- confirm this is intended.
+  child_range children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+
+  const_child_range children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+  child_range used_children() {
+    return child_range(child_iterator(), child_iterator());
+  }
+  const_child_range used_children() const {
+    return const_child_range(const_child_iterator(), const_child_iterator());
+  }
+
+  static bool classof(const OMPClause *T) {
+    return T->getClauseKind() == llvm::omp::OMPC_when;
+  }
+};
+
+/// Clause visitor that prints the source form of OpenMP clauses to \p OS
+/// using \p Policy. One Visit method is declared per clause class listed in
+/// OMP.inc; VisitOMPWhenClause is declared separately (presumably because
+/// 'when' has no clause class there -- confirm against OMP.td).
+class OMPClausePrinter final : public OMPClauseVisitor<OMPClausePrinter> {
+ raw_ostream &OS;
+ const PrintingPolicy &Policy;
+
+ /// Process clauses with list of variables.
+ template <typename T> void VisitOMPClauseList(T *Node, char StartSym);
+ /// Process motion clauses.
+ template <typename T> void VisitOMPMotionClause(T *Node);
+
+public:
+ OMPClausePrinter(raw_ostream &OS, const PrintingPolicy &Policy)
+ : OS(OS), Policy(Policy) {}
+
+  /// Prints the selector part of a metadirective 'when'/'default' clause;
+  /// the caller prints the variant directive and the closing ')'.
+ void VisitOMPWhenClause(OMPWhenClause *Node);
+
+#define GEN_CLANG_CLAUSE_CLASS
+#define CLAUSE_CLASS(Enum, Str, Class) void Visit##Class(Class *S);
+#include "llvm/Frontend/OpenMP/OMP.inc"
+};
} // namespace clang
#endif // LLVM_CLANG_AST_OPENMPCLAUSE_H
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits