llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-llvm-analysis Author: Matt Arsenault (arsenm) <details> <summary>Changes</summary> Verify the format is valid and the type is one of the expected i32 vectors. Verify the used vector types at least cover the requirements of the corresponding format operand. --- Patch is 22.02 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/117048.diff 3 Files Affected: - (modified) llvm/include/llvm/IR/IntrinsicsAMDGPU.td (+4-6) - (modified) llvm/lib/IR/Verifier.cpp (+49) - (added) llvm/test/Verifier/AMDGPU/mfma-scale.ll (+230) ``````````diff diff --git a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td index 3a5fc86183ca0e..ee7ea8343eacbf 100644 --- a/llvm/include/llvm/IR/IntrinsicsAMDGPU.td +++ b/llvm/include/llvm/IR/IntrinsicsAMDGPU.td @@ -2973,12 +2973,10 @@ class AMDGPUMfmaIntrinsic<LLVMType DestTy, LLVMType SrcABTy> : // blgp. // // These should be <8 x i32> for f8 formats, <6 x i32> for f6 formats, -// and <4 x i32> for f4 formats. If the format control bits imply a -// smaller type than used, the high elements will be truncated. -// -// If the format control bits imply a larger type than used, the high -// elements are padded with undef. - +// and <4 x i32> for f4 formats. It is invalid to use a format that +// requires more registers than the corresponding vector type (e.g. it +// is illegal to use <6 x i32> in operand 0 if cbsz specifies an f8 +// format that requires 8 registers). class AMDGPUMfmaScaleIntrinsic<LLVMType DestTy> : DefaultAttrsIntrinsic<[DestTy], [llvm_anyvector_ty, llvm_anyvector_ty, DestTy, diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp index 5c0ccf734cccbc..32b50e199ae8fc 100644 --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -6383,6 +6383,55 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) { "llvm.amdgcn.s.prefetch.data only supports global or constant memory"); break; } + case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4: + case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: { + Value *Src0 = Call.getArgOperand(0); + Value *Src1 = Call.getArgOperand(1); + + uint64_t CBSZ = cast<ConstantInt>(Call.getArgOperand(3))->getZExtValue(); + uint64_t BLGP = cast<ConstantInt>(Call.getArgOperand(4))->getZExtValue(); + Check(CBSZ <= 4, "invalid value for cbsz format", Call, + Call.getArgOperand(3)); + Check(BLGP <= 4, "invalid value for blgp format", Call, + Call.getArgOperand(4)); + + // AMDGPU::MFMAScaleFormats values + auto getFormatNumRegs = [](unsigned FormatVal) { + switch (FormatVal) { + case 0: + case 1: + return 8u; + case 2: + case 3: + return 6u; + case 4: + return 4u; + default: + llvm_unreachable("invalid format value"); + } + }; + + auto isValidSrcASrcBVector = [](FixedVectorType *Ty) { + if (!Ty || !Ty->getElementType()->isIntegerTy(32)) + return false; + unsigned NumElts = Ty->getNumElements(); + return NumElts == 4 || NumElts == 6 || NumElts == 8; + }; + + auto *Src0Ty = dyn_cast<FixedVectorType>(Src0->getType()); + auto *Src1Ty = dyn_cast<FixedVectorType>(Src1->getType()); + Check(isValidSrcASrcBVector(Src0Ty), + "operand 0 must be 4, 6 or 8 element i32 vector", &Call, Src0); + Check(isValidSrcASrcBVector(Src1Ty), + "operand 1 must be 4, 6 or 8 element i32 vector", &Call, Src1); + + // Permit excess registers for the format. + Check(Src0Ty->getNumElements() >= getFormatNumRegs(CBSZ), + "invalid vector type for format", &Call, Src0, Call.getArgOperand(3)); + Check(Src1Ty->getNumElements() >= getFormatNumRegs(BLGP), + "invalid vector type for format", &Call, Src1, Call.getArgOperand(5)); + break; + } case Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32: case Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32: { Value *V = Call.getArgOperand(0); diff --git a/llvm/test/Verifier/AMDGPU/mfma-scale.ll b/llvm/test/Verifier/AMDGPU/mfma-scale.ll new file mode 100644 index 00000000000000..1e3e8856df3d10 --- /dev/null +++ b/llvm/test/Verifier/AMDGPU/mfma-scale.ll @@ -0,0 +1,230 @@ +; RUN: not llvm-as %s -o /dev/null 2>&1 | FileCheck %s + +; -------------------------------------------------------------------- +; Wrong mangled types +; -------------------------------------------------------------------- + +; CHECK: operand 0 must be 4, 6 or 8 element i32 vector +; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i64.v8i32(<4 x i64> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 0, i32 2, i32 0, i32 %scale0, i32 0, i32 %scale1) +; CHECK-NEXT: <4 x i64> %arg0 +define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v4i64_fp8__v8i32_fp8(<4 x i64> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) { + %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i64.v8i32(<4 x i64> %arg0, <8 x i32> %arg1, <4 x float> %arg2, + i32 0, ; cbsz + i32 2, ; blgp + i32 0, i32 %scale0, i32 0, i32 %scale1) + ret <4 x float> %result +} + +; CHECK: operand 1 must be 4, 6 or 8 element i32 vector +; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v4i64(<8 x i32> %arg0, <4 x i64> %arg1, <4 x float> %arg2, i32 0, i32 2, i32 0, i32 %scale0, i32 0, i32 %scale1) +; CHECK-NEXT: <4 x i64> %arg1 +define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp8v4i64_fp8(<8 x i32> %arg0, <4 x i64> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) { + %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v4i64(<8 x i32> %arg0, <4 x i64> %arg1, <4 x float> %arg2, + i32 0, ; cbsz + i32 2, ; blgp + i32 0, i32 %scale0, i32 0, i32 %scale1) + ret <4 x float> %result +} + +; CHECK: operand 0 must be 4, 6 or 8 element i32 vector +; CHECK: %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i64.v8i32(<4 x i64> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 0, i32 2, i32 0, i32 %scale0, i32 0, i32 %scale1) +; CHECK: <4 x i64> %arg0 +define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v4i64_fp8__v8i32_fp8(<4 x i64> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) { + %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i64.v8i32(<4 x i64> %arg0, <8 x i32> %arg1, <16 x float> %arg2, + i32 0, ; cbsz + i32 2, ; blgp + i32 0, i32 %scale0, i32 0, i32 %scale1) + ret <16 x float> %result +} + +; CHECK: operand 1 must be 4, 6 or 8 element i32 vector +; CHECK-NEXT: %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v4i64(<8 x i32> %arg0, <4 x i64> %arg1, <16 x float> %arg2, i32 0, i32 2, i32 0, i32 %scale0, i32 0, i32 %scale1) +; CHECK-NEXT: <4 x i64> %arg1 +define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_fp8v4i64_fp8(<8 x i32> %arg0, <4 x i64> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) { + %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v4i64(<8 x i32> %arg0, <4 x i64> %arg1, <16 x float> %arg2, + i32 0, ; cbsz + i32 2, ; blgp + i32 0, i32 %scale0, i32 0, i32 %scale1) + ret <16 x float> %result +} + +; -------------------------------------------------------------------- +; Impossible vector types +; -------------------------------------------------------------------- + +; CHECK: operand 0 must be 4, 6 or 8 element i32 vector +; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v5i32.v8i32(<5 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 4, i32 4, i32 0, i32 %scale0, i32 0, i32 %scale1) +; CHECK-NEXT: <5 x i32> %arg0 +define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v5i32_fp4__v8i32_fp4(<5 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) { + %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i64.v8i32(<5 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, + i32 4, ; cbsz + i32 4, ; blgp + i32 0, i32 %scale0, i32 0, i32 %scale1) + ret <4 x float> %result +} + +; CHECK: operand 1 must be 4, 6 or 8 element i32 vector +; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v5i32(<8 x i32> %arg0, <5 x i32> %arg1, <4 x float> %arg2, i32 4, i32 4, i32 0, i32 %scale0, i32 0, i32 %scale1) +; CHECK-NEXT: <5 x i32> %arg1 +define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp4__v5i32_fp4(<8 x i32> %arg0, <5 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) { + %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v5i32(<8 x i32> %arg0, <5 x i32> %arg1, <4 x float> %arg2, + i32 4, ; cbsz + i32 4, ; blgp + i32 0, i32 %scale0, i32 0, i32 %scale1) + ret <4 x float> %result +} + +; CHECK: operand 0 must be 4, 6 or 8 element i32 vector +; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v7i32.v8i32(<7 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 4, i32 4, i32 0, i32 %scale0, i32 0, i32 %scale1) +; CHECK-NEXT: <7 x i32> %arg0 +define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v7i32_fp4__v8i32_fp4(<7 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) { + %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i64.v8i32(<7 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, + i32 4, ; cbsz + i32 4, ; blgp + i32 0, i32 %scale0, i32 0, i32 %scale1) + ret <4 x float> %result +} + +; CHECK: operand 1 must be 4, 6 or 8 element i32 vector +; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v7i32(<8 x i32> %arg0, <7 x i32> %arg1, <4 x float> %arg2, i32 4, i32 4, i32 0, i32 %scale0, i32 0, i32 %scale1) +; CHECK-NEXT: <7 x i32> %arg1 +define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp4__v7i32_fp4(<8 x i32> %arg0, <7 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) { + %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v7i32(<8 x i32> %arg0, <7 x i32> %arg1, <4 x float> %arg2, + i32 4, ; cbsz + i32 4, ; blgp + i32 0, i32 %scale0, i32 0, i32 %scale1) + ret <4 x float> %result +} + +; -------------------------------------------------------------------- +; Out of bounds format +; -------------------------------------------------------------------- + +; CHECK: invalid value for cbsz format +; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 9999, i32 2, i32 0, i32 %scale0, i32 0, i32 %scale1) +; CHECK-NEXT: i32 9999 +define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_invalid0__v8i32_fp6(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) { + %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, + i32 9999, ; cbsz + i32 2, ; blgp + i32 0, i32 %scale0, i32 0, i32 %scale1) + ret <4 x float> %result +} + +; CHECK: invalid value for blgp format +; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 0, i32 9999, i32 0, i32 %scale0, i32 0, i32 %scale1) +; CHECK-NEXT: i32 9999 +define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp8__v8i32_invalid0(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) { + %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, + i32 0, ; cbsz + i32 9999, ; blgp + i32 0, i32 %scale0, i32 0, i32 %scale1) + ret <4 x float> %result +} + +; CHECK: invalid value for cbsz format +; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 5, i32 2, i32 0, i32 %scale0, i32 0, i32 %scale1) +; CHECK-NEXT: i32 5 +define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_invalid1__v8i32_fp6(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) { + %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, + i32 5, ; cbsz + i32 2, ; blgp + i32 0, i32 %scale0, i32 0, i32 %scale1) + ret <4 x float> %result +} + +; CHECK: invalid value for blgp format +; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 0, i32 5, i32 0, i32 %scale0, i32 0, i32 %scale1) +; CHECK-NEXT: i32 5 +define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v8i32_fp8__v8i321_invalid(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) { + %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <4 x float> %arg2, + i32 0, ; cbsz + i32 5, ; blgp + i32 0, i32 %scale0, i32 0, i32 %scale1) + ret <4 x float> %result +} + +; CHECK: invalid value for cbsz format +; CHECK-NEXT: %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 5, i32 5, i32 0, i32 %scale0, i32 0, i32 %scale1) +; CHECK-NEXT: i32 5 +define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v8i32_invalid__v8i32_invalid(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) { + %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v8i32(<8 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2, + i32 5, ; cbsz + i32 5, ; blgp + i32 0, i32 %scale0, i32 0, i32 %scale1) + ret <16 x float> %result +} + +; -------------------------------------------------------------------- +; Incorrect signature for format cases (IR vector too small) +; -------------------------------------------------------------------- + +; CHECK: invalid vector type for format +; CHECK-NEXT: %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v8i32(<4 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 0, i32 0, i32 0, i32 %scale0, i32 0, i32 %scale1) +; CHECK-NEXT: <4 x i32> %arg0 +; CHECK-NEXT: i32 0 +define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4___v4i32_fp8__v8i32_fp8(<4 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) { + %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v4i32.v8i32(<4 x i32> %arg0, <8 x i32> %arg1, <16 x float> %arg2, + i32 0, ; cbsz + i32 0, ; blgp + i32 0, i32 %scale0, i32 0, i32 %scale1) + ret <16 x float> %result +} + +; CHECK: invalid vector type for format +; CHECK-NEXT: %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v4i32(<8 x i32> %arg0, <4 x i32> %arg1, <16 x float> %arg2, i32 0, i32 0, i32 0, i32 %scale0, i32 0, i32 %scale1) +; CHECK-NEXT: <4 x i32> %arg1 +; CHECK-NEXT: i32 0 +define <16 x float> @test_mfma_scale_f32_32x32x64_f8f6f4__v8i32_fp8___v4i32_fp8(<8 x i32> %arg0, <4 x i32> %arg1, <16 x float> %arg2, i32 %scale0, i32 %scale1) { + %result = call <16 x float> @llvm.amdgcn.mfma.scale.f32.32x32x64.f8f6f4.v8i32.v4i32(<8 x i32> %arg0, <4 x i32> %arg1, <16 x float> %arg2, + i32 0, ; cbsz + i32 0, ; blgp + i32 0, i32 %scale0, i32 0, i32 %scale1) + ret <16 x float> %result +} + +; CHECK: invalid vector type for format +; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v4i32(<4 x i32> %arg0, <4 x i32> %arg1, <4 x float> %arg2, i32 0, i32 0, i32 0, i32 %scale0, i32 0, i32 %scale1) +; CHECK-NEXT: <4 x i32> %arg0 +define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v4i32_fp8__v4i32_fp8(<4 x i32> %arg0, <4 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) { + %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v4i32.v4i32(<4 x i32> %arg0, <4 x i32> %arg1, <4 x float> %arg2, + i32 0, ; cbsz + i32 0, ; blgp + i32 0, i32 %scale0, i32 0, i32 %scale1) + ret <4 x float> %result +} + +; CHECK: invalid vector type for format +; CHECK-NEXT: %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> %arg0, <6 x i32> %arg1, <4 x float> %arg2, i32 0, i32 0, i32 0, i32 %scale0, i32 0, i32 %scale1) +; CHECK-NEXT: <6 x i32> %arg0 +define <4 x float> @test_mfma_scale_f32_16x16x128_f8f6f4___v6i32_fp8__v6i32_fp8(<6 x i32> %arg0, <6 x i32> %arg1, <4 x float> %arg2, i32 %scale0, i32 %scale1) { + %result = call <4 x float> @llvm.amdgcn.mfma.scale.f32.16x16x128.f8f6f4.v6i32.v6i32(<6 x i32> %arg0, <6 x i32> %arg1, <4 x float> %arg2, + i32 0, ; cbsz + i32 0, ; blgp + ... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/117048 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits