================ @@ -16,88 +16,90 @@ namespace clang { SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {} +/// Checks if the first `NumArgsToCheck` arguments of a function call are of +/// vector type. If any of the arguments is not a vector type, it emits a +/// diagnostic error and returns `true`. Otherwise, it returns `false`. +/// +/// \param TheCall The function call expression to check. +/// \param NumArgsToCheck The number of arguments to check for vector type. +/// \return `true` if any of the arguments is not a vector type, `false` +/// otherwise. + +bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) { + for (unsigned i = 0; i < NumArgsToCheck; ++i) { + ExprResult Arg = TheCall->getArg(i); + QualType ArgTy = Arg.get()->getType(); + auto *VTy = ArgTy->getAs<VectorType>(); + if (VTy == nullptr) { + SemaRef.Diag(Arg.get()->getBeginLoc(), + diag::err_typecheck_convert_incompatible) + << ArgTy + << SemaRef.Context.getVectorType(ArgTy, 2, VectorKind::Generic) << 1 + << 0 << 0; + return true; + } + } + return false; +} + bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { switch (BuiltinID) { case SPIRV::BI__builtin_spirv_distance: { if (SemaRef.checkArgCount(TheCall, 2)) return true; - ExprResult A = TheCall->getArg(0); - QualType ArgTyA = A.get()->getType(); - auto *VTyA = ArgTyA->getAs<VectorType>(); - if (VTyA == nullptr) { - SemaRef.Diag(A.get()->getBeginLoc(), - diag::err_typecheck_convert_incompatible) - << ArgTyA - << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 - << 0 << 0; - return true; - } - - ExprResult B = TheCall->getArg(1); - QualType ArgTyB = B.get()->getType(); - auto *VTyB = ArgTyB->getAs<VectorType>(); - if (VTyB == nullptr) { - SemaRef.Diag(A.get()->getBeginLoc(), - diag::err_typecheck_convert_incompatible) - << ArgTyB - << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1 - << 0 << 0; + // Use the helper function to check both arguments + if 
(CheckVectorArgs(TheCall, 2)) return true; - } - QualType RetTy = VTyA->getElementType(); + QualType RetTy = + TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType(); TheCall->setType(RetTy); break; } case SPIRV::BI__builtin_spirv_length: { if (SemaRef.checkArgCount(TheCall, 1)) return true; - ExprResult A = TheCall->getArg(0); - QualType ArgTyA = A.get()->getType(); - auto *VTy = ArgTyA->getAs<VectorType>(); - if (VTy == nullptr) { - SemaRef.Diag(A.get()->getBeginLoc(), - diag::err_typecheck_convert_incompatible) - << ArgTyA - << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 - << 0 << 0; + + // Use the helper function to check the argument + if (CheckVectorArgs(TheCall, 1)) ---------------- spall wrote:
Same question here about whether you should also be checking that the element type is a float. And the same comment applies here about following the style used in SemaHLSL. https://github.com/llvm/llvm-project/pull/136026 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits