================ @@ -16,88 +16,90 @@ namespace clang { SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {} +/// Checks if the first `NumArgsToCheck` arguments of a function call are of +/// vector type. If any of the arguments is not a vector type, it emits a +/// diagnostic error and returns `true`. Otherwise, it returns `false`. +/// +/// \param TheCall The function call expression to check. +/// \param NumArgsToCheck The number of arguments to check for vector type. +/// \return `true` if any of the arguments is not a vector type, `false` +/// otherwise. + +bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) { + for (unsigned i = 0; i < NumArgsToCheck; ++i) { + ExprResult Arg = TheCall->getArg(i); + QualType ArgTy = Arg.get()->getType(); + auto *VTy = ArgTy->getAs<VectorType>(); + if (VTy == nullptr) { + SemaRef.Diag(Arg.get()->getBeginLoc(), + diag::err_typecheck_convert_incompatible) + << ArgTy + << SemaRef.Context.getVectorType(ArgTy, 2, VectorKind::Generic) << 1 + << 0 << 0; + return true; + } + } + return false; +} + bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { switch (BuiltinID) { case SPIRV::BI__builtin_spirv_distance: { if (SemaRef.checkArgCount(TheCall, 2)) return true; - ExprResult A = TheCall->getArg(0); - QualType ArgTyA = A.get()->getType(); - auto *VTyA = ArgTyA->getAs<VectorType>(); - if (VTyA == nullptr) { - SemaRef.Diag(A.get()->getBeginLoc(), - diag::err_typecheck_convert_incompatible) - << ArgTyA - << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 - << 0 << 0; - return true; - } - - ExprResult B = TheCall->getArg(1); - QualType ArgTyB = B.get()->getType(); - auto *VTyB = ArgTyB->getAs<VectorType>(); - if (VTyB == nullptr) { - SemaRef.Diag(A.get()->getBeginLoc(), - diag::err_typecheck_convert_incompatible) - << ArgTyB - << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1 - << 0 << 0; + // Use the helper function to check both arguments + if 
(CheckVectorArgs(TheCall, 2)) return true; - } - QualType RetTy = VTyA->getElementType(); + QualType RetTy = + TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType(); TheCall->setType(RetTy); break; } case SPIRV::BI__builtin_spirv_length: { if (SemaRef.checkArgCount(TheCall, 1)) return true; - ExprResult A = TheCall->getArg(0); - QualType ArgTyA = A.get()->getType(); - auto *VTy = ArgTyA->getAs<VectorType>(); - if (VTy == nullptr) { - SemaRef.Diag(A.get()->getBeginLoc(), - diag::err_typecheck_convert_incompatible) - << ArgTyA - << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 - << 0 << 0; + + // Use the helper function to check the argument + if (CheckVectorArgs(TheCall, 1)) return true; - } - QualType RetTy = VTy->getElementType(); + + QualType RetTy = + TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType(); TheCall->setType(RetTy); break; } - case SPIRV::BI__builtin_spirv_reflect: { - if (SemaRef.checkArgCount(TheCall, 2)) + case SPIRV::BI__builtin_spirv_refract: { + if (SemaRef.checkArgCount(TheCall, 3)) return true; - ExprResult A = TheCall->getArg(0); - QualType ArgTyA = A.get()->getType(); - auto *VTyA = ArgTyA->getAs<VectorType>(); - if (VTyA == nullptr) { - SemaRef.Diag(A.get()->getBeginLoc(), - diag::err_typecheck_convert_incompatible) - << ArgTyA - << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 - << 0 << 0; + // Use the helper function to check the first two arguments + if (CheckVectorArgs(TheCall, 2)) return true; - } - ExprResult B = TheCall->getArg(1); - QualType ArgTyB = B.get()->getType(); - auto *VTyB = ArgTyB->getAs<VectorType>(); - if (VTyB == nullptr) { - SemaRef.Diag(A.get()->getBeginLoc(), - diag::err_typecheck_convert_incompatible) - << ArgTyB - << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1 - << 0 << 0; + ExprResult C = TheCall->getArg(2); + QualType ArgTyC = C.get()->getType(); + if (!ArgTyC->isFloatingType()) { + 
SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type) + << 3 << /* scalar*/ 5 << /* no int */ 0 << /* fp */ 1 << ArgTyC; return true; } - QualType RetTy = ArgTyA; + QualType RetTy = TheCall->getArg(0)->getType(); + TheCall->setType(RetTy); + break; + } + case SPIRV::BI__builtin_spirv_reflect: { + if (SemaRef.checkArgCount(TheCall, 2)) + return true; + + // Use the helper function to check both arguments ---------------- spall wrote:
Same question here: do you also need to check that the arguments are floating-point? And the same comment about following the SemaHLSL style. https://github.com/llvm/llvm-project/pull/136026 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits