================ @@ -403,142 +369,183 @@ enum ArmSMEState : unsigned { ArmZT0Mask = 0b11 << 2 }; +bool SemaARM::CheckImmediateArg(CallExpr *TheCall, unsigned CheckTy, + unsigned ArgIdx, unsigned EltBitWidth, + unsigned VecBitWidth) { + + typedef bool (*OptionSetCheckFnTy)(int64_t Value); + + // Function that checks whether the operand (ArgIdx) is an immediate + // that is one of the predefined values. + auto CheckImmediateInSet = [&](OptionSetCheckFnTy CheckImm, + int ErrDiag) -> bool { + // We can't check the value of a dependent argument. + Expr *Arg = TheCall->getArg(ArgIdx); + if (Arg->isTypeDependent() || Arg->isValueDependent()) + return false; + + // Check constant-ness first. + llvm::APSInt Imm; + if (SemaRef.BuiltinConstantArg(TheCall, ArgIdx, Imm)) + return true; + + if (!CheckImm(Imm.getSExtValue())) + return Diag(TheCall->getBeginLoc(), ErrDiag) << Arg->getSourceRange(); + return false; + }; + + switch ((ImmCheckType)CheckTy) { + case ImmCheckType::ImmCheck0_31: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 31)) + return true; + break; + case ImmCheckType::ImmCheck0_13: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 13)) + return true; + break; + case ImmCheckType::ImmCheck0_63: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 63)) + return true; + break; + case ImmCheckType::ImmCheck1_16: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 1, 16)) + return true; + break; + case ImmCheckType::ImmCheck0_7: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 7)) + return true; + break; + case ImmCheckType::ImmCheck1_1: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 1, 1)) + return true; + break; + case ImmCheckType::ImmCheck1_3: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 1, 3)) + return true; + break; + case ImmCheckType::ImmCheck1_7: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 1, 7)) + return true; + break; + case ImmCheckType::ImmCheckExtract: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, + (2048 / EltBitWidth) - 1)) + return true; + break; + case ImmCheckType::ImmCheckCvt: + case ImmCheckType::ImmCheckShiftRight: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 1, EltBitWidth)) + return true; + break; + case ImmCheckType::ImmCheckShiftRightNarrow: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 1, EltBitWidth / 2)) + return true; + break; + case ImmCheckType::ImmCheckShiftLeft: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, EltBitWidth - 1)) + return true; + break; + case ImmCheckType::ImmCheckLaneIndex: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, + (VecBitWidth / EltBitWidth) - 1)) + return true; + break; + case ImmCheckType::ImmCheckLaneIndexCompRotate: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, + (VecBitWidth / (2 * EltBitWidth)) - 1)) + return true; + break; + case ImmCheckType::ImmCheckLaneIndexDot: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, + (VecBitWidth / (4 * EltBitWidth)) - 1)) + return true; + break; + case ImmCheckType::ImmCheckComplexRot90_270: + if (CheckImmediateInSet([](int64_t V) { return V == 90 || V == 270; }, + diag::err_rotation_argument_to_cadd)) + return true; + break; + case ImmCheckType::ImmCheckComplexRotAll90: + if (CheckImmediateInSet( + [](int64_t V) { return V == 0 || V == 90 || V == 180 || V == 270; }, + diag::err_rotation_argument_to_cmla)) + return true; + break; + case ImmCheckType::ImmCheck0_1: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 1)) + return true; + break; + case ImmCheckType::ImmCheck0_2: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 2)) + return true; + break; + case ImmCheckType::ImmCheck0_3: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 3)) + return true; + break; + case ImmCheckType::ImmCheck0_0: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 0)) + return true; + break; + case ImmCheckType::ImmCheck0_15: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 15)) + return true; + break; + case ImmCheckType::ImmCheck0_255: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 0, 255)) + return true; + break; + case ImmCheckType::ImmCheck1_32: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 1, 32)) + return true; + break; + case ImmCheckType::ImmCheck1_64: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 1, 64)) + return true; + break; + case ImmCheckType::ImmCheck2_4_Mul2: + if (SemaRef.BuiltinConstantArgRange(TheCall, ArgIdx, 2, 4) || + SemaRef.BuiltinConstantArgMultiple(TheCall, ArgIdx, 2)) + return true; + break; + default: + llvm_unreachable("Invalid immediate range typeflag!"); + break; + } + return false; +} + +bool SemaARM::ParseNeonImmChecks( + CallExpr *TheCall, + SmallVector<std::tuple<int, int, int, int>, 2> &ImmChecks, + int OverloadType = -1) { + unsigned CheckTy; + unsigned ArgIdx, ElementSizeInBits, VecSizeInBits; + bool HasError = false; + + for (const auto &I : ImmChecks) { + std::tie(ArgIdx, CheckTy, ElementSizeInBits, VecSizeInBits) = I; + + if (OverloadType >= 0) + ElementSizeInBits = NeonTypeFlags(OverloadType).getEltSizeInBits(); + + HasError |= CheckImmediateArg(TheCall, CheckTy, ArgIdx, ElementSizeInBits, + VecSizeInBits); + } + + return HasError; +} + bool SemaARM::ParseSVEImmChecks( CallExpr *TheCall, SmallVector<std::tuple<int, int, int>, 3> &ImmChecks) { - // Perform all the immediate checks for this builtin call. - bool HasError = false; - for (auto &I : ImmChecks) { - int ArgNum, CheckTy, ElementSizeInBits; - std::tie(ArgNum, CheckTy, ElementSizeInBits) = I; - - typedef bool (*OptionSetCheckFnTy)(int64_t Value); - - // Function that checks whether the operand (ArgNum) is an immediate - // that is one of the predefined values. - auto CheckImmediateInSet = [&](OptionSetCheckFnTy CheckImm, - int ErrDiag) -> bool { - // We can't check the value of a dependent argument. - Expr *Arg = TheCall->getArg(ArgNum); - if (Arg->isTypeDependent() || Arg->isValueDependent()) - return false; - - // Check constant-ness first. - llvm::APSInt Imm; - if (SemaRef.BuiltinConstantArg(TheCall, ArgNum, Imm)) - return true; - if (!CheckImm(Imm.getSExtValue())) - return Diag(TheCall->getBeginLoc(), ErrDiag) << Arg->getSourceRange(); - return false; - }; + bool HasError = false; + unsigned CheckTy, ArgIdx, ElementSizeInBits; - switch ((SVETypeFlags::ImmCheckType)CheckTy) { - case SVETypeFlags::ImmCheck0_31: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 31)) - HasError = true; - break; - case SVETypeFlags::ImmCheck0_13: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 13)) - HasError = true; - break; - case SVETypeFlags::ImmCheck1_16: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1, 16)) - HasError = true; - break; - case SVETypeFlags::ImmCheck0_7: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 7)) - HasError = true; - break; - case SVETypeFlags::ImmCheck1_1: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1, 1)) - HasError = true; - break; - case SVETypeFlags::ImmCheck1_3: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1, 3)) - HasError = true; - break; - case SVETypeFlags::ImmCheck1_7: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1, 7)) - HasError = true; - break; - case SVETypeFlags::ImmCheckExtract: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, - (2048 / ElementSizeInBits) - 1)) - HasError = true; - break; - case SVETypeFlags::ImmCheckShiftRight: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1, - ElementSizeInBits)) - HasError = true; - break; - case SVETypeFlags::ImmCheckShiftRightNarrow: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 1, - ElementSizeInBits / 2)) - HasError = true; - break; - case SVETypeFlags::ImmCheckShiftLeft: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, - ElementSizeInBits - 1)) - HasError = true; - break; - case SVETypeFlags::ImmCheckLaneIndex: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, - (128 / (1 * ElementSizeInBits)) - 1)) - HasError = true; - break; - case SVETypeFlags::ImmCheckLaneIndexCompRotate: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, - (128 / (2 * ElementSizeInBits)) - 1)) - HasError = true; - break; - case SVETypeFlags::ImmCheckLaneIndexDot: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, - (128 / (4 * ElementSizeInBits)) - 1)) - HasError = true; - break; - case SVETypeFlags::ImmCheckComplexRot90_270: - if (CheckImmediateInSet([](int64_t V) { return V == 90 || V == 270; }, - diag::err_rotation_argument_to_cadd)) - HasError = true; - break; - case SVETypeFlags::ImmCheckComplexRotAll90: - if (CheckImmediateInSet( - [](int64_t V) { - return V == 0 || V == 90 || V == 180 || V == 270; - }, - diag::err_rotation_argument_to_cmla)) - HasError = true; - break; - case SVETypeFlags::ImmCheck0_1: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 1)) - HasError = true; - break; - case SVETypeFlags::ImmCheck0_2: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 2)) - HasError = true; - break; - case SVETypeFlags::ImmCheck0_3: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 3)) - HasError = true; - break; - case SVETypeFlags::ImmCheck0_0: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 0)) - HasError = true; - break; - case SVETypeFlags::ImmCheck0_15: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 15)) - HasError = true; - break; - case SVETypeFlags::ImmCheck0_255: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 0, 255)) - HasError = true; - break; - case SVETypeFlags::ImmCheck2_4_Mul2: - if (SemaRef.BuiltinConstantArgRange(TheCall, ArgNum, 2, 4) || - SemaRef.BuiltinConstantArgMultiple(TheCall, ArgNum, 2)) - HasError = true; - break; - } + for (const auto &I : ImmChecks) { + std::tie(ArgIdx, CheckTy, ElementSizeInBits) = I; + HasError |= ---------------- SpencerAbson wrote:
Done. https://github.com/llvm/llvm-project/pull/100278 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits