llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-hlsl @llvm/pr-subscribers-clang Author: Muhammad Bassiouni (bassiounix) <details> <summary>Changes</summary> Move helper functions out of `clang/lib/Sema/SemaHLSL.cpp` into a common location for `clang/lib/Sema/SemaSPIRV.cpp` to use. The moved functions are `CheckArgTypeIsCorrect` and `CheckAllArgTypesAreCorrect`. This is a contribution to issue #123831. --- Full diff: https://github.com/llvm/llvm-project/pull/125045.diff 5 Files Affected: - (added) clang/include/clang/Sema/Common.h (+22) - (modified) clang/lib/Sema/CMakeLists.txt (+1) - (added) clang/lib/Sema/Common.cpp (+65) - (modified) clang/lib/Sema/SemaHLSL.cpp (+1-27) - (modified) clang/lib/Sema/SemaSPIRV.cpp (+6-46) ``````````diff diff --git a/clang/include/clang/Sema/Common.h b/clang/include/clang/Sema/Common.h new file mode 100644 index 00000000000000..3f775df8bddb64 --- /dev/null +++ b/clang/include/clang/Sema/Common.h @@ -0,0 +1,22 @@ +#ifndef LLVM_CLANG_SEMA_COMMON_H +#define LLVM_CLANG_SEMA_COMMON_H + +#include "clang/Sema/Sema.h" + +namespace clang { + +using LLVMFnRef = llvm::function_ref<bool(clang::QualType PassedType)>; +using PairParam = std::pair<unsigned int, unsigned int>; +using CheckParam = std::variant<PairParam, LLVMFnRef>; + +bool CheckArgTypeIsCorrect( + Sema *S, Expr *Arg, QualType ExpectedType, + llvm::function_ref<bool(clang::QualType PassedType)> Check); + +bool CheckAllArgTypesAreCorrect( + Sema *SemaPtr, CallExpr *TheCall, + std::variant<QualType, std::nullopt_t> ExpectedType, CheckParam Check); + +} // namespace clang + +#endif diff --git a/clang/lib/Sema/CMakeLists.txt b/clang/lib/Sema/CMakeLists.txt index 19cf3a2db00fdc..ddc340a51a3b2d 100644 --- a/clang/lib/Sema/CMakeLists.txt +++ b/clang/lib/Sema/CMakeLists.txt @@ -17,6 +17,7 @@ add_clang_library(clangSema AnalysisBasedWarnings.cpp CheckExprLifetime.cpp CodeCompleteConsumer.cpp + Common.cpp DeclSpec.cpp DelayedDiagnostic.cpp HeuristicResolver.cpp diff --git 
a/clang/lib/Sema/Common.cpp b/clang/lib/Sema/Common.cpp new file mode 100644 index 00000000000000..72a9e4a2c99ae1 --- /dev/null +++ b/clang/lib/Sema/Common.cpp @@ -0,0 +1,65 @@ +#include "clang/Sema/Common.h" + +namespace clang { + +bool CheckArgTypeIsCorrect( + Sema *S, Expr *Arg, QualType ExpectedType, + llvm::function_ref<bool(clang::QualType PassedType)> Check) { + QualType PassedType = Arg->getType(); + if (Check(PassedType)) { + if (auto *VecTyA = PassedType->getAs<VectorType>()) + ExpectedType = S->Context.getVectorType( + ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind()); + S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible) + << PassedType << ExpectedType << 1 << 0 << 0; + return true; + } + return false; +} + +bool CheckAllArgTypesAreCorrect( + Sema *SemaPtr, CallExpr *TheCall, + std::variant<QualType, std::nullopt_t> ExpectedType, CheckParam Check) { + unsigned int NumElts; + unsigned int expected; + if (auto *n = std::get_if<PairParam>(&Check)) { + if (SemaPtr->checkArgCount(TheCall, n->first)) { + return true; + } + NumElts = n->first; + expected = n->second; + } else { + NumElts = TheCall->getNumArgs(); + } + + for (unsigned i = 0; i < NumElts; i++) { + Expr *localArg = TheCall->getArg(i); + if (auto *val = std::get_if<QualType>(&ExpectedType)) { + if (auto *fn = std::get_if<LLVMFnRef>(&Check)) { + return CheckArgTypeIsCorrect(SemaPtr, localArg, *val, *fn); + } + } + + QualType PassedType = localArg->getType(); + if (PassedType->getAs<VectorType>() == nullptr) { + SemaPtr->Diag(localArg->getBeginLoc(), + diag::err_typecheck_convert_incompatible) + << PassedType + << SemaPtr->Context.getVectorType(PassedType, expected, + VectorKind::Generic) + << 1 << 0 << 0; + return true; + } + } + + if (std::get_if<PairParam>(&Check)) { + if (auto *localArgVecTy = + TheCall->getArg(0)->getType()->getAs<VectorType>()) { + TheCall->setType(localArgVecTy->getElementType()); + } + } + + return false; +} + +} // namespace clang diff --git 
a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index d748c10455289b..0cc71e4122666c 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -27,6 +27,7 @@ #include "clang/Basic/SourceLocation.h" #include "clang/Basic/Specifiers.h" #include "clang/Basic/TargetInfo.h" +#include "clang/Sema/Common.h" #include "clang/Sema/Initialization.h" #include "clang/Sema/ParsedAttr.h" #include "clang/Sema/Sema.h" @@ -1996,33 +1997,6 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) { return false; } -static bool CheckArgTypeIsCorrect( - Sema *S, Expr *Arg, QualType ExpectedType, - llvm::function_ref<bool(clang::QualType PassedType)> Check) { - QualType PassedType = Arg->getType(); - if (Check(PassedType)) { - if (auto *VecTyA = PassedType->getAs<VectorType>()) - ExpectedType = S->Context.getVectorType( - ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind()); - S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible) - << PassedType << ExpectedType << 1 << 0 << 0; - return true; - } - return false; -} - -static bool CheckAllArgTypesAreCorrect( - Sema *S, CallExpr *TheCall, QualType ExpectedType, - llvm::function_ref<bool(clang::QualType PassedType)> Check) { - for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) { - Expr *Arg = TheCall->getArg(i); - if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) { - return true; - } - } - return false; -} - static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) { auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool { return !PassedType->hasFloatingRepresentation(); diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp index dc49fc79073572..df6a3d61056f5e 100644 --- a/clang/lib/Sema/SemaSPIRV.cpp +++ b/clang/lib/Sema/SemaSPIRV.cpp @@ -10,7 +10,9 @@ #include "clang/Sema/SemaSPIRV.h" #include "clang/Basic/TargetBuiltins.h" +#include "clang/Sema/Common.h" #include "clang/Sema/Sema.h" +#include <utility> 
namespace clang { @@ -20,54 +22,12 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { switch (BuiltinID) { case SPIRV::BI__builtin_spirv_distance: { - if (SemaRef.checkArgCount(TheCall, 2)) - return true; - - ExprResult A = TheCall->getArg(0); - QualType ArgTyA = A.get()->getType(); - auto *VTyA = ArgTyA->getAs<VectorType>(); - if (VTyA == nullptr) { - SemaRef.Diag(A.get()->getBeginLoc(), - diag::err_typecheck_convert_incompatible) - << ArgTyA - << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 - << 0 << 0; - return true; - } - - ExprResult B = TheCall->getArg(1); - QualType ArgTyB = B.get()->getType(); - auto *VTyB = ArgTyB->getAs<VectorType>(); - if (VTyB == nullptr) { - SemaRef.Diag(A.get()->getBeginLoc(), - diag::err_typecheck_convert_incompatible) - << ArgTyB - << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1 - << 0 << 0; - return true; - } - - QualType RetTy = VTyA->getElementType(); - TheCall->setType(RetTy); - break; + return CheckAllArgTypesAreCorrect(&SemaRef, TheCall, std::nullopt, + std::make_pair(2, 2)); } case SPIRV::BI__builtin_spirv_length: { - if (SemaRef.checkArgCount(TheCall, 1)) - return true; - ExprResult A = TheCall->getArg(0); - QualType ArgTyA = A.get()->getType(); - auto *VTy = ArgTyA->getAs<VectorType>(); - if (VTy == nullptr) { - SemaRef.Diag(A.get()->getBeginLoc(), - diag::err_typecheck_convert_incompatible) - << ArgTyA - << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 - << 0 << 0; - return true; - } - QualType RetTy = VTy->getElementType(); - TheCall->setType(RetTy); - break; + return CheckAllArgTypesAreCorrect(&SemaRef, TheCall, std::nullopt, + std::make_pair(1, 2)); } } return false; `````````` </details> https://github.com/llvm/llvm-project/pull/125045 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits