llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang-modules @llvm/pr-subscribers-clang Author: Erich Keane (erichkeane) <details> <summary>Changes</summary> Like with the 'default' clause, this is being applied to only Compute Constructs for now. The 'if' clause takes a condition expression which is used as a runtime value. This is not a particularly complex semantic implementation, as there isn't much to this clause, other than its interactions with 'self', which will be managed in the patch to implement that. --- Patch is 35.28 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88411.diff 17 Files Affected: - (modified) clang/include/clang/AST/ASTNodeTraverser.h (+2-1) - (modified) clang/include/clang/AST/OpenACCClause.h (+79-4) - (added) clang/include/clang/Basic/OpenACCClauses.def (+21) - (modified) clang/include/clang/Parse/Parser.h (+5-1) - (modified) clang/include/clang/Sema/SemaOpenACC.h (+27-1) - (modified) clang/lib/AST/OpenACCClause.cpp (+41) - (modified) clang/lib/AST/StmtProfile.cpp (+16-3) - (modified) clang/lib/AST/TextNodeDumper.cpp (+5) - (modified) clang/lib/Parse/ParseOpenACC.cpp (+22-10) - (modified) clang/lib/Sema/SemaOpenACC.cpp (+63-10) - (modified) clang/lib/Sema/TreeTransform.h (+13) - (modified) clang/lib/Serialization/ASTReader.cpp (+6-1) - (modified) clang/lib/Serialization/ASTWriter.cpp (+6-1) - (modified) clang/test/ParserOpenACC/parse-clauses.c (-2) - (modified) clang/test/SemaOpenACC/compute-construct-clause-ast.cpp (+116-4) - (added) clang/test/SemaOpenACC/compute-construct-if-clause.c (+62) - (added) clang/test/SemaOpenACC/compute-construct-if-clause.cpp (+33) ``````````diff diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h index 94e7dd817809dd..37fe030fb8e5a3 100644 --- a/clang/include/clang/AST/ASTNodeTraverser.h +++ b/clang/include/clang/AST/ASTNodeTraverser.h @@ -243,7 +243,8 @@ class ASTNodeTraverser void Visit(const OpenACCClause *C) { getNodeDelegate().AddChild([=] { getNodeDelegate().Visit(C); - // TODO OpenACC: Switch on clauses that have children, and add them. + for (const auto *S : C->children()) + Visit(S); }); } diff --git a/clang/include/clang/AST/OpenACCClause.h b/clang/include/clang/AST/OpenACCClause.h index 27e4e1a12c9837..6e3c00614168e7 100644 --- a/clang/include/clang/AST/OpenACCClause.h +++ b/clang/include/clang/AST/OpenACCClause.h @@ -15,6 +15,7 @@ #define LLVM_CLANG_AST_OPENACCCLAUSE_H #include "clang/AST/ASTContext.h" #include "clang/Basic/OpenACCKinds.h" +#include "clang/AST/StmtIterator.h" namespace clang { /// This is the base type for all OpenACC Clauses. @@ -34,6 +35,17 @@ class OpenACCClause { static bool classof(const OpenACCClause *) { return true; } + using child_iterator = StmtIterator; + using const_child_iterator = ConstStmtIterator; + using child_range = llvm::iterator_range<child_iterator>; + using const_child_range = llvm::iterator_range<const_child_iterator>; + + child_range children(); + const_child_range children() const { + auto Children = const_cast<OpenACCClause *>(this)->children(); + return const_child_range(Children.begin(), Children.end()); + } + virtual ~OpenACCClause() = default; }; @@ -49,6 +61,14 @@ class OpenACCClauseWithParams : public OpenACCClause { public: SourceLocation getLParenLoc() const { return LParenLoc; } + + child_range children() { + return child_range(child_iterator(), child_iterator()); + } + const_child_range children() const { + return const_child_range(const_child_iterator(), const_child_iterator()); + } + }; /// A 'default' clause, has the optional 'none' or 'present' argument. @@ -81,6 +101,52 @@ class OpenACCDefaultClause : public OpenACCClauseWithParams { SourceLocation EndLoc); }; +/// Represents one of the handful of classes that has an optional/required +/// 'condition' expression as an argument. +class OpenACCClauseWithCondition : public OpenACCClauseWithParams { + Expr *ConditionExpr; + + protected: + OpenACCClauseWithCondition(OpenACCClauseKind K, SourceLocation BeginLoc, + SourceLocation LParenLoc, + Expr *ConditionExpr, SourceLocation EndLoc) + : OpenACCClauseWithParams(K, BeginLoc, LParenLoc, EndLoc), + ConditionExpr(ConditionExpr) {} + + public: + bool hasConditionExpr() const { return ConditionExpr; } + const Expr *getConditionExpr() const { return ConditionExpr; } + Expr *getConditionExpr() { return ConditionExpr; } + + child_range children() { + if (ConditionExpr) + return child_range(reinterpret_cast<Stmt **>(&ConditionExpr), + reinterpret_cast<Stmt **>(&ConditionExpr + 1)); + return child_range(child_iterator(), child_iterator()); + } + + const_child_range children() const { + if (ConditionExpr) + return const_child_range( + reinterpret_cast<Stmt *const *>(&ConditionExpr), + reinterpret_cast<Stmt *const *>(&ConditionExpr + 1)); + return const_child_range(const_child_iterator(), const_child_iterator()); + } +}; + +/// An 'if' clause, which has a required condition expression. +class OpenACCIfClause : public OpenACCClauseWithCondition { +protected: + OpenACCIfClause(SourceLocation BeginLoc, SourceLocation LParenLoc, + Expr *ConditionExpr, SourceLocation EndLoc); + +public: + static OpenACCIfClause *Create(const ASTContext &C, SourceLocation BeginLoc, + SourceLocation LParenLoc, + Expr *ConditionExpr, + SourceLocation EndLoc); +}; + template <class Impl> class OpenACCClauseVisitor { Impl &getDerived() { return static_cast<Impl &>(*this); } @@ -98,6 +164,9 @@ template <class Impl> class OpenACCClauseVisitor { case OpenACCClauseKind::Default: VisitOpenACCDefaultClause(*cast<OpenACCDefaultClause>(C)); return; + case OpenACCClauseKind::If: + VisitOpenACCIfClause(*cast<OpenACCIfClause>(C)); + return; case OpenACCClauseKind::Finalize: case OpenACCClauseKind::IfPresent: case OpenACCClauseKind::Seq: @@ -106,7 +175,6 @@ template <class Impl> class OpenACCClauseVisitor { case OpenACCClauseKind::Worker: case OpenACCClauseKind::Vector: case OpenACCClauseKind::NoHost: - case OpenACCClauseKind::If: case OpenACCClauseKind::Self: case OpenACCClauseKind::Copy: case OpenACCClauseKind::UseDevice: @@ -145,9 +213,13 @@ template <class Impl> class OpenACCClauseVisitor { llvm_unreachable("Invalid Clause kind"); } - void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause) { - return getDerived().VisitOpenACCDefaultClause(Clause); +#define VISIT_CLAUSE(CLAUSE_NAME) \ + void VisitOpenACC##CLAUSE_NAME##Clause( \ + const OpenACC##CLAUSE_NAME##Clause &Clause) {\ + return getDerived().VisitOpenACC##CLAUSE_NAME##Clause(Clause); \ } + +#include "clang/Basic/OpenACCClauses.def" }; class OpenACCClausePrinter final @@ -165,7 +237,10 @@ class OpenACCClausePrinter final } OpenACCClausePrinter(raw_ostream &OS) : OS(OS) {} - void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause); +#define VISIT_CLAUSE(CLAUSE_NAME) \ + void VisitOpenACC##CLAUSE_NAME##Clause( \ + const OpenACC##CLAUSE_NAME##Clause &Clause); +#include "clang/Basic/OpenACCClauses.def" }; } // namespace clang diff --git a/clang/include/clang/Basic/OpenACCClauses.def b/clang/include/clang/Basic/OpenACCClauses.def new file mode 100644 index 00000000000000..7fd2720e02ce22 --- /dev/null +++ b/clang/include/clang/Basic/OpenACCClauses.def @@ -0,0 +1,21 @@ +//===-- OpenACCClauses.def - List of implemented OpenACC Clauses -- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a list of currently implemented OpenACC Clauses (and +// eventually, the entire list) in a way that makes generating 'visitor' and +// other lists easier. +// +// The primary macro is a single-argument version taking the name of the Clause +// as used in Clang source (so `Default` instead of `default`). +// +// VISIT_CLAUSE(CLAUSE_NAME) + +VISIT_CLAUSE(Default) +VISIT_CLAUSE(If) + +#undef VISIT_CLAUSE diff --git a/clang/include/clang/Parse/Parser.h b/clang/include/clang/Parse/Parser.h index 3a055c10ffb387..9d83a52929789e 100644 --- a/clang/include/clang/Parse/Parser.h +++ b/clang/include/clang/Parse/Parser.h @@ -3611,6 +3611,9 @@ class Parser : public CodeCompletionHandler { OpenACCClauseParseResult OpenACCCannotContinue(); OpenACCClauseParseResult OpenACCSuccess(OpenACCClause *Clause); + using OpenACCConditionExprParseResult = + std::pair<ExprResult, OpenACCParseCanContinue>; + /// Parses the OpenACC directive (the entire pragma) including the clause /// list, but does not produce the main AST node. OpenACCDirectiveParseInfo ParseOpenACCDirective(); @@ -3657,7 +3660,8 @@ class Parser : public CodeCompletionHandler { bool ParseOpenACCGangArgList(); /// Parses a 'gang-arg', used for the 'gang' clause. bool ParseOpenACCGangArg(); - + /// Parses a 'condition' expr, ensuring it results in a + ExprResult ParseOpenACCConditionExpr(); private: //===--------------------------------------------------------------------===// // C++ 14: Templates [temp] diff --git a/clang/include/clang/Sema/SemaOpenACC.h b/clang/include/clang/Sema/SemaOpenACC.h index 27aaee164a2880..c1fe0f5b9c0f6b 100644 --- a/clang/include/clang/Sema/SemaOpenACC.h +++ b/clang/include/clang/Sema/SemaOpenACC.h @@ -40,7 +40,11 @@ class SemaOpenACC : public SemaBase { OpenACCDefaultClauseKind DefaultClauseKind; }; - std::variant<DefaultDetails> Details; + struct ConditionDetails { + Expr *ConditionExpr; + }; + + std::variant<DefaultDetails, ConditionDetails> Details; public: OpenACCParsedClause(OpenACCDirectiveKind DirKind, @@ -63,6 +67,16 @@ class SemaOpenACC : public SemaBase { return std::get<DefaultDetails>(Details).DefaultClauseKind; } + const Expr *getConditionExpr() const { + return const_cast<OpenACCParsedClause *>(this)->getConditionExpr(); + } + + Expr *getConditionExpr() { + assert(ClauseKind == OpenACCClauseKind::If && + "Parsed clause kind does not have a condition expr"); + return std::get<ConditionDetails>(Details).ConditionExpr; + } + void setLParenLoc(SourceLocation EndLoc) { LParenLoc = EndLoc; } void setEndLoc(SourceLocation EndLoc) { ClauseRange.setEnd(EndLoc); } @@ -71,6 +85,18 @@ class SemaOpenACC : public SemaBase { "Parsed clause is not a default clause"); Details = DefaultDetails{DefKind}; } + + void setConditionDetails(Expr *ConditionExpr) { + assert(ClauseKind == OpenACCClauseKind::If && + "Parsed clause kind does not have a condition expr"); + // In C++ we can count on this being a 'bool', but in C this gets left as + // some sort of scalar that codegen will have to take care of converting. + assert((!ConditionExpr || ConditionExpr->isInstantiationDependent() || + ConditionExpr->getType()->isScalarType()) && + "Condition expression type not scalar/dependent"); + + Details = ConditionDetails{ConditionExpr}; + } }; SemaOpenACC(Sema &S); diff --git a/clang/lib/AST/OpenACCClause.cpp b/clang/lib/AST/OpenACCClause.cpp index c83128b60e3acc..0a512d48253a8c 100644 --- a/clang/lib/AST/OpenACCClause.cpp +++ b/clang/lib/AST/OpenACCClause.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "clang/AST/OpenACCClause.h" +#include "clang/AST/Expr.h" #include "clang/AST/ASTContext.h" using namespace clang; @@ -27,6 +28,41 @@ OpenACCDefaultClause *OpenACCDefaultClause::Create(const ASTContext &C, return new (Mem) OpenACCDefaultClause(K, BeginLoc, LParenLoc, EndLoc); } +OpenACCIfClause *OpenACCIfClause::Create(const ASTContext &C, + SourceLocation BeginLoc, + SourceLocation LParenLoc, + Expr *ConditionExpr, + SourceLocation EndLoc) { + void *Mem = C.Allocate(sizeof(OpenACCIfClause), alignof(OpenACCIfClause)); + return new (Mem) OpenACCIfClause(BeginLoc, LParenLoc, ConditionExpr, EndLoc); +} + +OpenACCIfClause::OpenACCIfClause(SourceLocation BeginLoc, + SourceLocation LParenLoc, + Expr *ConditionExpr, + SourceLocation EndLoc) + : OpenACCClauseWithCondition(OpenACCClauseKind::If, BeginLoc, LParenLoc, + ConditionExpr, EndLoc) { + assert(ConditionExpr && "if clause requires condition expr"); + assert((ConditionExpr->isInstantiationDependent() || + ConditionExpr->getType()->isScalarType()) && + "Condition expression type not scalar/dependent"); +} + +OpenACCClause::child_range OpenACCClause::children() { + switch (getClauseKind()) { + default: + assert(false && "Clause children function not implemented"); + break; +#define VISIT_CLAUSE(CLAUSE_NAME) \ + case OpenACCClauseKind::CLAUSE_NAME: \ + return cast<OpenACC##CLAUSE_NAME##Clause>(this)->children(); + +#include "clang/Basic/OpenACCClauses.def" + } + return child_range(child_iterator(), child_iterator()); +} + //===----------------------------------------------------------------------===// // OpenACC clauses printing methods //===----------------------------------------------------------------------===// @@ -34,3 +70,8 @@ void OpenACCClausePrinter::VisitOpenACCDefaultClause( const OpenACCDefaultClause &C) { OS << "default(" << C.getDefaultClauseKind() << ")"; } + +void OpenACCClausePrinter::VisitOpenACCIfClause( + const OpenACCIfClause &C) { + OS << "if(" << C.getConditionExpr() << ")"; +} diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp index 01e1d1cc8289bf..24593fd2f4d405 100644 --- a/clang/lib/AST/StmtProfile.cpp +++ b/clang/lib/AST/StmtProfile.cpp @@ -2445,9 +2445,10 @@ void StmtProfiler::VisitTemplateArgument(const TemplateArgument &Arg) { namespace { class OpenACCClauseProfiler : public OpenACCClauseVisitor<OpenACCClauseProfiler> { + StmtProfiler &Profiler; public: - OpenACCClauseProfiler() = default; + OpenACCClauseProfiler(StmtProfiler &P) :Profiler(P) {} void VisitOpenACCClauseList(ArrayRef<const OpenACCClause *> Clauses) { for (const OpenACCClause *Clause : Clauses) { @@ -2456,12 +2457,24 @@ class OpenACCClauseProfiler Visit(Clause); } } - void VisitOpenACCDefaultClause(const OpenACCDefaultClause &Clause); + +#define VISIT_CLAUSE(CLAUSE_NAME) \ + void VisitOpenACC##CLAUSE_NAME##Clause( \ + const OpenACC##CLAUSE_NAME##Clause &Clause); + +#include "clang/Basic/OpenACCClauses.def" }; /// Nothing to do here, there are no sub-statements. void OpenACCClauseProfiler::VisitOpenACCDefaultClause( const OpenACCDefaultClause &Clause) {} + +void OpenACCClauseProfiler::VisitOpenACCIfClause( + const OpenACCIfClause &Clause) { + assert(Clause.hasConditionExpr() && + "if clause requires a valid condition expr"); + Profiler.VisitStmt(Clause.getConditionExpr()); + } } // namespace void StmtProfiler::VisitOpenACCComputeConstruct( @@ -2469,7 +2482,7 @@ void StmtProfiler::VisitOpenACCComputeConstruct( // VisitStmt handles children, so the AssociatedStmt is handled. VisitStmt(S); - OpenACCClauseProfiler P; + OpenACCClauseProfiler P{*this}; P.VisitOpenACCClauseList(S->clauses()); } diff --git a/clang/lib/AST/TextNodeDumper.cpp b/clang/lib/AST/TextNodeDumper.cpp index 085a7f51ce99ad..56650f99134d45 100644 --- a/clang/lib/AST/TextNodeDumper.cpp +++ b/clang/lib/AST/TextNodeDumper.cpp @@ -397,6 +397,11 @@ void TextNodeDumper::Visit(const OpenACCClause *C) { case OpenACCClauseKind::Default: OS << '(' << cast<OpenACCDefaultClause>(C)->getDefaultClauseKind() << ')'; break; + case OpenACCClauseKind::If: + // The condition expression will be printed as a part of the 'children', + // but print 'clause' here so it is clear what is happening from the dump. + OS << " clause"; + break; default: // Nothing to do here. break; diff --git a/clang/lib/Parse/ParseOpenACC.cpp b/clang/lib/Parse/ParseOpenACC.cpp index b487a1968d1ec8..6192afa8541cad 100644 --- a/clang/lib/Parse/ParseOpenACC.cpp +++ b/clang/lib/Parse/ParseOpenACC.cpp @@ -535,14 +535,6 @@ bool ClauseHasRequiredParens(OpenACCDirectiveKind DirKind, return getClauseParensKind(DirKind, Kind) == ClauseParensKind::Required; } -ExprResult ParseOpenACCConditionalExpr(Parser &P) { - // FIXME: It isn't clear if the spec saying 'condition' means the same as - // it does in an if/while/etc (See ParseCXXCondition), however as it was - // written with Fortran/C in mind, we're going to assume it just means an - // 'expression evaluating to boolean'. - return P.getActions().CorrectDelayedTyposInExpr(P.ParseExpression()); -} - // Skip until we see the end of pragma token, but don't consume it. This is us // just giving up on the rest of the pragma so we can continue executing. We // have to do this because 'SkipUntil' considers paren balancing, which isn't @@ -595,6 +587,23 @@ Parser::OpenACCClauseParseResult Parser::OpenACCSuccess(OpenACCClause *Clause) { return {Clause, OpenACCParseCanContinue::Can}; } +ExprResult Parser::ParseOpenACCConditionExpr() { + // FIXME: It isn't clear if the spec saying 'condition' means the same as + // it does in an if/while/etc (See ParseCXXCondition), however as it was + // written with Fortran/C in mind, we're going to assume it just means an + // 'expression evaluating to boolean'. + ExprResult ER = getActions().CorrectDelayedTyposInExpr(ParseExpression()); + + if (!ER.isUsable()) + return ER; + + Sema::ConditionResult R = + getActions().ActOnCondition(getCurScope(), ER.get()->getExprLoc(), + ER.get(), Sema::ConditionKind::Boolean); + + return R.isInvalid() ? ExprError () : R.get().second; +} + // OpenACC 3.3, section 1.7: // To simplify the specification and convey appropriate constraint information, // a pqr-list is a comma-separated list of pdr items. The one exception is a @@ -842,12 +851,15 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams( break; } case OpenACCClauseKind::If: { - ExprResult CondExpr = ParseOpenACCConditionalExpr(*this); + ExprResult CondExpr = ParseOpenACCConditionExpr(); + ParsedClause.setConditionDetails( + CondExpr.isUsable() ? CondExpr.get() : nullptr); if (CondExpr.isInvalid()) { Parens.skipToEnd(); return OpenACCCanContinue(); } + break; } case OpenACCClauseKind::CopyIn: @@ -964,7 +976,7 @@ Parser::OpenACCClauseParseResult Parser::ParseOpenACCClauseParams( switch (ClauseKind) { case OpenACCClauseKind::Self: { assert(DirKind != OpenACCDirectiveKind::Update); - ExprResult CondExpr = ParseOpenACCConditionalExpr(*this); + ExprResult CondExpr = ParseOpenACCConditionExpr(); if (CondExpr.isInvalid()) { Parens.skipToEnd(); diff --git a/clang/lib/Sema/SemaOpenACC.cpp b/clang/lib/Sema/SemaOpenACC.cpp index a6f4453e525d01..8e98f3ae913325 100644 --- a/clang/lib/Sema/SemaOpenACC.cpp +++ b/clang/lib/Sema/SemaOpenACC.cpp @@ -55,12 +55,49 @@ bool doesClauseApplyToDirective(OpenACCDirectiveKind DirectiveKind, default: return false; } + case OpenACCClauseKind::If: + switch (DirectiveKind) { + case OpenACCDirectiveKind::Parallel: + case OpenACCDirectiveKind::Serial: + case OpenACCDirectiveKind::Kernels: + case OpenACCDirectiveKind::Data: + case OpenACCDirectiveKind::EnterData: + case OpenACCDirectiveKind::ExitData: + case OpenACCDirectiveKind::HostData: + case OpenACCDirectiveKind::Init: + case OpenACCDirectiveKind::Shutdown: + case OpenACCDirectiveKind::Set: + case OpenACCDirectiveKind::Update: + case OpenACCDirectiveKind::Wait: + case OpenACCDirectiveKind::ParallelLoop: + case OpenACCDirectiveKind::SerialLoop: + case OpenACCDirectiveKind::KernelsLoop: + return true; + default: + return false; + } default: // Do nothing so we can go to the 'unimplemented' diagnostic instead. return true; } llvm_unreachable("Invalid clause kind"); } + +bool checkAlreadyHasClauseOfKind( + SemaOpenACC &S, ArrayRef<const OpenACCClause *> ExistingClauses, + SemaOpenACC::OpenACCParsedClause &Clause) { + auto Itr = llvm::find_if(ExistingClauses, [&](const OpenACCClause *C) { + return C->getClauseKind() == Clause.getClauseKind(); + }); + if (Itr != ExistingClauses.end()) { + S.Diag(Clause.getBeginLoc(), diag::err_acc_duplicate_clause_disallowed) + << Clause.getDirectiveKind() << Clause.getClauseKind(); + S.Diag((*Itr)->getBeginLoc(), diag::note_acc_previous_clause_here); + return true; + } + return false; +} + } // namespace SemaOpenACC::SemaOpenACC(Sema &S) : SemaBase(S) {} @@ -97,22 +134,38 @@ SemaOpenACC::ActOnClause(ArrayRef<const OpenACCClause *> ExistingClause... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/88411 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits