yaxunl created this revision. yaxunl added reviewers: tra, rjmccall. If a kernel template has a function as its template parameter, a device function should be allowed as the template argument, since a kernel can call a device function. However, clang currently emits an error when such a kernel template is instantiated in a host function, saying that the device function is an invalid candidate for the template parameter.
This happens because clang checks the reference to the device function while parsing the template arguments. At that point the template has not been instantiated yet, so clang incorrectly assumes that the device function is called by the host function and emits the error. This patch fixes the issue by disabling the check of device-function references while the template arguments are being parsed, and deferring it to the instantiation of the template. By then the template decl is available, so the check can be performed against the instantiated function template decl. https://reviews.llvm.org/D56411 Files: include/clang/Sema/Sema.h lib/Parse/ParseTemplate.cpp lib/Sema/SemaCUDA.cpp lib/Sema/SemaExpr.cpp lib/Sema/SemaTemplate.cpp test/SemaCUDA/kernel-template-with-device-func-arg.cu
Index: test/SemaCUDA/kernel-template-with-device-func-arg.cu =================================================================== --- /dev/null +++ test/SemaCUDA/kernel-template-with-device-func-arg.cu @@ -0,0 +1,49 @@ +// RUN: %clang_cc1 -fsyntax-only -verify %s + +#include "Inputs/cuda.h" + +struct C { + __device__ void devfun() {} + void hostfun() {} + template<class T> __device__ void devtempfun() {} +}; + +__device__ void devfun() {} +__host__ void hostfun() {} +template<class T> __device__ void devtempfun() {} + +template <void (*devF)()> __global__ void kernel() { devF();} +template <typename T, void(T::*devF)()> __global__ void kernel2(T *p) { (p->*devF)(); } + +template<> __global__ void kernel<devfun>(); +template<> __global__ void kernel<hostfun>(); // expected-error {{no function template matches function template specialization 'kernel'}} + // expected-note@-5 {{candidate template ignored: invalid explicitly-specified argument for template parameter 'devF'}} +template<> __global__ void kernel<devtempfun<int> >(); + +template<> __global__ void kernel<&devfun>(); +template<> __global__ void kernel<&hostfun>(); // expected-error {{no function template matches function template specialization 'kernel'}} + // expected-note@-10 {{candidate template ignored: invalid explicitly-specified argument for template parameter 'devF'}} +template<> __global__ void kernel<&devtempfun<int> >(); + +template<> __global__ void kernel2<C, &C::devfun>(C *p); +template<> __global__ void kernel2<C, &C::hostfun>(C *p); // expected-error {{no function template matches function template specialization 'kernel2'}} + // expected-note@-14 {{candidate template ignored: invalid explicitly-specified argument for template parameter 'devF'}} +template<> __global__ void kernel2<C, &C::devtempfun<int> >(C *p); + +void fun() { + kernel<&devfun><<<1,1>>>(); + kernel<&hostfun><<<1,1>>>(); // expected-error {{no matching function for call to 'kernel'}} + // expected-note@-21 {{candidate template 
ignored: invalid explicitly-specified argument for template parameter 'devF'}} + kernel<&devtempfun<int> ><<<1,1>>>(); + + kernel<devfun><<<1,1>>>(); + kernel<hostfun><<<1,1>>>(); // expected-error {{no matching function for call to 'kernel'}} + // expected-note@-26 {{candidate template ignored: invalid explicitly-specified argument for template parameter 'devF'}} + kernel<devtempfun<int> ><<<1,1>>>(); + + C a; + kernel2<C, &C::devfun><<<1,1>>>(&a); + kernel2<C, &C::hostfun><<<1,1>>>(&a); // expected-error {{no matching function for call to 'kernel2'}} + // expected-note@-31 {{candidate template ignored: invalid explicitly-specified argument for template parameter 'devF'}} + kernel2<C, &C::devtempfun<int> ><<<1,1>>>(&a); +} Index: lib/Sema/SemaTemplate.cpp =================================================================== --- lib/Sema/SemaTemplate.cpp +++ lib/Sema/SemaTemplate.cpp @@ -4753,8 +4753,8 @@ TemplateArgument Result; unsigned CurSFINAEErrors = NumSFINAEErrors; ExprResult Res = - CheckTemplateArgument(NTTP, NTTPType, Arg.getArgument().getAsExpr(), - Result, CTAK); + CheckTemplateArgument(NTTP, NTTPType, Arg.getArgument().getAsExpr(), + Result, CTAK, dyn_cast<TemplateDecl>(Template)); if (Res.isInvalid()) return true; // If the current template argument causes an error, give up now. @@ -6123,6 +6123,27 @@ return true; } +namespace { +bool CheckCUDATemplateArgument(Sema &S, Expr *Arg, TemplateDecl *Template) { + if (Template) { + Expr *E = Arg; + if (UnaryOperator *UO = dyn_cast<UnaryOperator>(E)) { + E = UO ? UO->getSubExpr() : nullptr; + } + if (DeclRefExpr *DRE = dyn_cast_or_null<DeclRefExpr>(E)) { + ValueDecl *Entity = DRE ? 
DRE->getDecl() : nullptr; + if (Entity) { + if (auto Callee = dyn_cast<FunctionDecl>(Entity)) + if (auto Caller = + dyn_cast<FunctionDecl>(Template->getTemplatedDecl())) + if (!S.CheckCUDACall(Arg->getBeginLoc(), Callee, Caller)) + return false; + } + } + } + return true; +} +} // namespace /// Check a template argument against its corresponding /// non-type template parameter. /// @@ -6133,7 +6154,8 @@ ExprResult Sema::CheckTemplateArgument(NonTypeTemplateParmDecl *Param, QualType ParamType, Expr *Arg, TemplateArgument &Converted, - CheckTemplateArgumentKind CTAK) { + CheckTemplateArgumentKind CTAK, + TemplateDecl *Template) { SourceLocation StartLoc = Arg->getBeginLoc(); // If the parameter type somehow involves auto, deduce the type now. @@ -6530,7 +6552,11 @@ if (FunctionDecl *Fn = ResolveAddressOfOverloadedFunction(Arg, ParamType, true, FoundResult)) { - if (DiagnoseUseOfDecl(Fn, Arg->getBeginLoc())) + if (DiagnoseUseOfDecl(Fn, Arg->getBeginLoc(), + /*UnknownObjCClass=*/nullptr, + /*ObjCPropertyAccess=*/false, + /*AvoidPartialAvailabilityChecks=*/false, + /*ClassReciever=*/nullptr, Template)) return ExprError(); Arg = FixOverloadedFunctionReference(Arg, FoundResult, Fn); @@ -6539,6 +6565,9 @@ return ExprError(); } + if (!CheckCUDATemplateArgument(*this, Arg, Template)) + return ExprError(); + if (!ParamType->isMemberPointerType()) { if (CheckTemplateArgumentAddressOfObjectOrFunction(*this, Param, ParamType, Index: lib/Sema/SemaExpr.cpp =================================================================== --- lib/Sema/SemaExpr.cpp +++ lib/Sema/SemaExpr.cpp @@ -207,7 +207,8 @@ const ObjCInterfaceDecl *UnknownObjCClass, bool ObjCPropertyAccess, bool AvoidPartialAvailabilityChecks, - ObjCInterfaceDecl *ClassReceiver) { + ObjCInterfaceDecl *ClassReceiver, + TemplateDecl *Template) { SourceLocation Loc = Locs.front(); if (getLangOpts().CPlusPlus && isa<FunctionDecl>(D)) { // If there were any diagnostics suppressed by template argument deduction, @@ -262,7 +263,11 @@ 
DeduceReturnType(FD, Loc)) return true; - if (getLangOpts().CUDA && !CheckCUDACall(Loc, FD)) + if (getLangOpts().CUDA && + !CheckCUDACall( + Loc, FD, + Template ? dyn_cast<FunctionDecl>(Template->getTemplatedDecl()) + : nullptr)) return true; } Index: lib/Sema/SemaCUDA.cpp =================================================================== --- lib/Sema/SemaCUDA.cpp +++ lib/Sema/SemaCUDA.cpp @@ -833,13 +833,15 @@ } } -bool Sema::CheckCUDACall(SourceLocation Loc, FunctionDecl *Callee) { +bool Sema::CheckCUDACall(SourceLocation Loc, FunctionDecl *Callee, + FunctionDecl *Caller) { assert(getLangOpts().CUDA && "Should only be called during CUDA compilation"); assert(Callee && "Callee may not be null."); // FIXME: Is bailing out early correct here? Should we instead assume that // the caller is a global initializer? - FunctionDecl *Caller = dyn_cast<FunctionDecl>(CurContext); if (!Caller) + Caller = dyn_cast<FunctionDecl>(CurContext); + if (!Caller || IsParsingTemplateArgumentList) return true; // If the caller is known-emitted, mark the callee as known-emitted. Index: lib/Parse/ParseTemplate.cpp =================================================================== --- lib/Parse/ParseTemplate.cpp +++ lib/Parse/ParseTemplate.cpp @@ -1296,6 +1296,20 @@ return Tok.isOneOf(tok::greater, tok::comma); } +namespace { +class ParseTemplateArgumentListRAII { + Sema &S; + +public: + ParseTemplateArgumentListRAII(Sema &_S) : S(_S) { + S.setIsParsingTemplateArgumentList(); + } + ~ParseTemplateArgumentListRAII() { + S.setIsParsingTemplateArgumentList(false); + } +}; + +} // namespace /// ParseTemplateArgumentList - Parse a C++ template-argument-list /// (C++ [temp.names]). Returns true if there was an error. 
/// @@ -1306,6 +1320,7 @@ Parser::ParseTemplateArgumentList(TemplateArgList &TemplateArgs) { ColonProtectionRAIIObject ColonProtection(*this, false); + ParseTemplateArgumentListRAII PTAL(Actions); do { ParsedTemplateArgument Arg = ParseTemplateArgument(); Index: include/clang/Sema/Sema.h =================================================================== --- include/clang/Sema/Sema.h +++ include/clang/Sema/Sema.h @@ -3993,11 +3993,15 @@ // Expression Parsing Callbacks: SemaExpr.cpp. bool CanUseDecl(NamedDecl *D, bool TreatUnavailableAsInvalid); + + // \param TemplateReceiver is the receiving instantiated template declaration + // if it is not a null pointer. bool DiagnoseUseOfDecl(NamedDecl *D, ArrayRef<SourceLocation> Locs, const ObjCInterfaceDecl *UnknownObjCClass = nullptr, bool ObjCPropertyAccess = false, bool AvoidPartialAvailabilityChecks = false, - ObjCInterfaceDecl *ClassReciever = nullptr); + ObjCInterfaceDecl *ClassReciever = nullptr, + TemplateDecl *TemplateReceiver = nullptr); void NoteDeletedFunction(FunctionDecl *FD); void NoteDeletedInheritingConstructor(CXXConstructorDecl *CD); std::string getDeletedOrUnavailableSuffix(const FunctionDecl *FD); @@ -6429,10 +6433,12 @@ bool CheckTemplateArgument(TemplateTypeParmDecl *Param, TypeSourceInfo *Arg); - ExprResult CheckTemplateArgument(NonTypeTemplateParmDecl *Param, - QualType InstantiatedParamType, Expr *Arg, - TemplateArgument &Converted, - CheckTemplateArgumentKind CTAK = CTAK_Specified); + ExprResult + CheckTemplateArgument(NonTypeTemplateParmDecl *Param, + QualType InstantiatedParamType, Expr *Arg, + TemplateArgument &Converted, + CheckTemplateArgumentKind CTAK = CTAK_Specified, + TemplateDecl *Template = nullptr); bool CheckTemplateTemplateArgument(TemplateParameterList *Params, TemplateArgumentLoc &Arg); @@ -9975,8 +9981,15 @@ private: unsigned ForceCUDAHostDeviceDepth = 0; + unsigned IsParsingTemplateArgumentList = 0; public: + void setIsParsingTemplateArgumentList(bool Enable = true) { + if 
(Enable) + ++IsParsingTemplateArgumentList; + else + --IsParsingTemplateArgumentList; + } /// Increments our count of the number of times we've seen a pragma forcing /// functions to be __host__ __device__. So long as this count is greater /// than zero, all functions encountered will be __host__ __device__. @@ -10196,7 +10209,8 @@ /// deferred errors. /// /// - Otherwise, returns true without emitting any diagnostics. - bool CheckCUDACall(SourceLocation Loc, FunctionDecl *Callee); + bool CheckCUDACall(SourceLocation Loc, FunctionDecl *Callee, + FunctionDecl *Caller = nullptr); /// Set __device__ or __host__ __device__ attributes on the given lambda /// operator() method.
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits