llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: Erich Keane (erichkeane) <details> <summary>Changes</summary> As in the last few patches, branching into or out of a compute construct is not valid. This patch implements checking to ensure that a 'case' or 'default' statement cannot jump into a Compute Construct (in the style of a Duff's device!). --- Full diff: https://github.com/llvm/llvm-project/pull/83460.diff 4 Files Affected: - (modified) clang/include/clang/Sema/Scope.h (+19) - (modified) clang/lib/Sema/SemaStmt.cpp (+14) - (modified) clang/test/SemaOpenACC/no-branch-in-out.c (+27) - (modified) clang/test/SemaOpenACC/no-branch-in-out.cpp (+29-1) ``````````diff diff --git a/clang/include/clang/Sema/Scope.h b/clang/include/clang/Sema/Scope.h index b6b5a1f3479a25..1cb2fa83e0bb33 100644 --- a/clang/include/clang/Sema/Scope.h +++ b/clang/include/clang/Sema/Scope.h @@ -534,6 +534,25 @@ class Scope { return false; } + /// Determine if this scope (or its parents) are a compute construct inside of + /// the nearest 'switch' scope. This is needed to check whether we are inside + /// of a 'duffs' device, which is an illegal branch into a compute construct. + bool isInOpenACCComputeConstructBeforeSwitch() const { + for (const Scope *S = this; S; S = S->getParent()) { + if (S->getFlags() & Scope::OpenACCComputeConstructScope) + return true; + if (S->getFlags() & Scope::SwitchScope) + return false; + + if (S->getFlags() & + (Scope::FnScope | Scope::ClassScope | Scope::BlockScope | + Scope::TemplateParamScope | Scope::FunctionPrototypeScope | + Scope::AtCatchScope | Scope::ObjCMethodScope)) + return false; + } + return false; + } + /// Determine whether this scope is a while/do/for statement, which can have /// continue statements embedded into it. 
bool isContinueScope() const { diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp index ca2d206752744c..4a15a8f6effd31 100644 --- a/clang/lib/Sema/SemaStmt.cpp +++ b/clang/lib/Sema/SemaStmt.cpp @@ -527,6 +527,13 @@ Sema::ActOnCaseStmt(SourceLocation CaseLoc, ExprResult LHSVal, return StmtError(); } + if (LangOpts.OpenACC && + getCurScope()->isInOpenACCComputeConstructBeforeSwitch()) { + Diag(CaseLoc, diag::err_acc_branch_in_out_compute_construct) + << /*branch*/ 0 << /*into*/ 1; + return StmtError(); + } + auto *CS = CaseStmt::Create(Context, LHSVal.get(), RHSVal.get(), CaseLoc, DotDotDotLoc, ColonLoc); getCurFunction()->SwitchStack.back().getPointer()->addSwitchCase(CS); @@ -546,6 +553,13 @@ Sema::ActOnDefaultStmt(SourceLocation DefaultLoc, SourceLocation ColonLoc, return SubStmt; } + if (LangOpts.OpenACC && + getCurScope()->isInOpenACCComputeConstructBeforeSwitch()) { + Diag(DefaultLoc, diag::err_acc_branch_in_out_compute_construct) + << /*branch*/ 0 << /*into*/ 1; + return StmtError(); + } + DefaultStmt *DS = new (Context) DefaultStmt(DefaultLoc, ColonLoc, SubStmt); getCurFunction()->SwitchStack.back().getPointer()->addSwitchCase(DS); return DS; diff --git a/clang/test/SemaOpenACC/no-branch-in-out.c b/clang/test/SemaOpenACC/no-branch-in-out.c index d070247fa65b86..eccc6432450045 100644 --- a/clang/test/SemaOpenACC/no-branch-in-out.c +++ b/clang/test/SemaOpenACC/no-branch-in-out.c @@ -310,3 +310,30 @@ LABEL4:{} ptr=&&LABEL5; } + +void DuffsDevice() { + int j; + switch (j) { +#pragma acc parallel + for(int i =0; i < 5; ++i) { + case 0: // expected-error{{invalid branch into OpenACC Compute Construct}} + {} + } + } + + switch (j) { +#pragma acc parallel + for(int i =0; i < 5; ++i) { + default: // expected-error{{invalid branch into OpenACC Compute Construct}} + {} + } + } + + switch (j) { +#pragma acc parallel + for(int i =0; i < 5; ++i) { + case 'a' ... 
'z': // expected-error{{invalid branch into OpenACC Compute Construct}} + {} + } + } +} diff --git a/clang/test/SemaOpenACC/no-branch-in-out.cpp b/clang/test/SemaOpenACC/no-branch-in-out.cpp index 9affdf733ace8d..e7d5683f9bc78b 100644 --- a/clang/test/SemaOpenACC/no-branch-in-out.cpp +++ b/clang/test/SemaOpenACC/no-branch-in-out.cpp @@ -18,7 +18,6 @@ void ReturnTest() { template<typename T> void BreakContinue() { - #pragma acc parallel for(int i =0; i < 5; ++i) { switch(i) { @@ -109,6 +108,35 @@ void BreakContinue() { } while (j ); } +template<typename T> +void DuffsDevice() { + int j; + switch (j) { +#pragma acc parallel + for(int i =0; i < 5; ++i) { + case 0: // expected-error{{invalid branch into OpenACC Compute Construct}} + {} + } + } + + switch (j) { +#pragma acc parallel + for(int i =0; i < 5; ++i) { + default: // expected-error{{invalid branch into OpenACC Compute Construct}} + {} + } + } + + switch (j) { +#pragma acc parallel + for(int i =0; i < 5; ++i) { + case 'a' ... 'z': // expected-error{{invalid branch into OpenACC Compute Construct}} + {} + } + } +} + void Instantiate() { BreakContinue<int>(); + DuffsDevice<int>(); } `````````` </details> https://github.com/llvm/llvm-project/pull/83460 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits