================ @@ -1512,6 +1512,83 @@ void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall, TheCall->setType(ReturnType); } +static bool CheckScalarOrVector(Sema *S, CallExpr *TheCall, QualType Scalar, + unsigned ArgIndex) { + assert(TheCall->getNumArgs() >= ArgIndex); + QualType ArgType = TheCall->getArg(ArgIndex)->getType(); + auto *VTy = ArgType->getAs<VectorType>(); + // not the scalar or vector<scalar> + if (!(S->Context.hasSameUnqualifiedType(ArgType, Scalar) || + (VTy && S->Context.hasSameUnqualifiedType(VTy->getElementType(), + Scalar)))) { + S->Diag(TheCall->getArg(0)->getBeginLoc(), + diag::err_typecheck_expect_scalar_or_vector) + << ArgType << Scalar; + return true; + } + return false; +} + +static bool CheckBoolSelect(Sema *S, CallExpr *TheCall) { + assert(TheCall->getNumArgs() == 3); + Expr *Arg1 = TheCall->getArg(1); + Expr *Arg2 = TheCall->getArg(2); + if(!S->Context.hasSameUnqualifiedType(Arg1->getType(), + Arg2->getType())) { + S->Diag(TheCall->getBeginLoc(), + diag::err_typecheck_call_different_arg_types) + << Arg1->getType() << Arg2->getType() + << Arg1->getSourceRange() << Arg2->getSourceRange(); + return true; + } + + TheCall->setType(Arg1->getType()); + return false; +} + +static bool CheckVectorSelect(Sema *S, CallExpr *TheCall) { + assert(TheCall->getNumArgs() == 3); + Expr *Arg1 = TheCall->getArg(1); + Expr *Arg2 = TheCall->getArg(2); + if (!Arg1->getType()->isVectorType()) { + S->Diag(Arg1->getBeginLoc(), + diag::err_builtin_non_vector_type) + << "Second" << "__builtin_hlsl_select" << Arg1->getType() + << Arg1->getSourceRange(); + return true; + } + + if (!Arg2->getType()->isVectorType()) { + S->Diag(Arg2->getBeginLoc(), + diag::err_builtin_non_vector_type) + << "Third" << "__builtin_hlsl_select" << Arg2->getType() + << Arg2->getSourceRange(); + return true; + } + + if (!S->Context.hasSameUnqualifiedType(Arg1->getType(), + Arg2->getType())) { + S->Diag(TheCall->getBeginLoc(), + diag::err_typecheck_call_different_arg_types) + << 
Arg1->getType() << Arg2->getType() + << Arg1->getSourceRange() << Arg2->getSourceRange(); + return true; + } + + // caller has checked that Arg0 is a vector. + // check all three args have the same length. + if (TheCall->getArg(0)->getType()->getAs<VectorType>()->getNumElements() != ---------------- farzonl wrote:
This block (the check that all three arguments have the same vector length) seems to duplicate what `CheckVectorElementCallArgs` already does. https://github.com/llvm/llvm-project/pull/107129 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits