https://github.com/farzonl updated https://github.com/llvm/llvm-project/pull/86071
>From db83effd2e9498fd7fd507b748a423390acefd5e Mon Sep 17 00:00:00 2001 From: Farzon Lotfi <farzonlo...@microsoft.com> Date: Tue, 19 Mar 2024 17:29:55 -0400 Subject: [PATCH 1/2] Add Float `Dot` Intrinsic Lowering Completes #83626 - `CGBuiltin.cpp` - modify `getDotProductIntrinsic` to be able to emit `dot2`, `dot3`, and `dot4` intrinsics based on element count - `IntrinsicsDirectX.td` - for floating point add `dot2`, `dot3`, and `dot4` intrinsics - `DXIL.td` - add dxilop intrinsic lowering for `dot2`, `dot3`, & `dot4`. - `DXILOpLowering.cpp` - add vector arg flattening for dot product. - `DXILOpBuilder.h` - modify `createDXILOpCall` to take a `SmallVector` instead of an iterator - `DXILOpBuilder.cpp` - modify `createDXILOpCall` by moving the small vector up to the calling function in `DXILOpLowering.cpp`. Moving one function up gives us access to the `CallInst` and `Function`, which were needed to distinguish the dot product intrinsics and get the operands without using the iterator. --- clang/lib/CodeGen/CGBuiltin.cpp | 25 +++--- clang/test/CodeGenHLSL/builtins/dot.hlsl | 28 +++---- llvm/include/llvm/IR/IntrinsicsDirectX.td | 10 ++- llvm/lib/Target/DirectX/DXIL.td | 9 +++ llvm/lib/Target/DirectX/DXILOpBuilder.cpp | 8 +- llvm/lib/Target/DirectX/DXILOpBuilder.h | 5 +- llvm/lib/Target/DirectX/DXILOpLowering.cpp | 55 ++++++++++++- llvm/test/CodeGen/DirectX/dot2_error.ll | 10 +++ llvm/test/CodeGen/DirectX/dot3_error.ll | 10 +++ llvm/test/CodeGen/DirectX/dot4_error.ll | 10 +++ llvm/test/CodeGen/DirectX/fdot.ll | 94 ++++++++++++++++++++++ 11 files changed, 230 insertions(+), 34 deletions(-) create mode 100644 llvm/test/CodeGen/DirectX/dot2_error.ll create mode 100644 llvm/test/CodeGen/DirectX/dot3_error.ll create mode 100644 llvm/test/CodeGen/DirectX/dot4_error.ll create mode 100644 llvm/test/CodeGen/DirectX/fdot.ll diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index 2eaceeba617700..8f4817258e3b18 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ 
b/clang/lib/CodeGen/CGBuiltin.cpp @@ -18066,15 +18066,22 @@ llvm::Value *CodeGenFunction::EmitScalarOrConstFoldImmArg(unsigned ICEArguments, return Arg; } -Intrinsic::ID getDotProductIntrinsic(QualType QT) { +Intrinsic::ID getDotProductIntrinsic(QualType QT, int elementCount) { + if (QT->hasFloatingRepresentation()) { + switch (elementCount) { + case 2: + return Intrinsic::dx_dot2; + case 3: + return Intrinsic::dx_dot3; + case 4: + return Intrinsic::dx_dot4; + } + } if (QT->hasSignedIntegerRepresentation()) return Intrinsic::dx_sdot; - if (QT->hasUnsignedIntegerRepresentation()) - return Intrinsic::dx_udot; - assert(QT->hasFloatingRepresentation()); - return Intrinsic::dx_dot; - ; + assert(QT->hasUnsignedIntegerRepresentation()); + return Intrinsic::dx_udot; } Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, @@ -18128,8 +18135,7 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, assert(T0->getScalarType() == T1->getScalarType() && "Dot product of vectors need the same element types."); - [[maybe_unused]] auto *VecTy0 = - E->getArg(0)->getType()->getAs<VectorType>(); + auto *VecTy0 = E->getArg(0)->getType()->getAs<VectorType>(); [[maybe_unused]] auto *VecTy1 = E->getArg(1)->getType()->getAs<VectorType>(); // A HLSLVectorTruncation should have happend @@ -18138,7 +18144,8 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, return Builder.CreateIntrinsic( /*ReturnType=*/T0->getScalarType(), - getDotProductIntrinsic(E->getArg(0)->getType()), + getDotProductIntrinsic(E->getArg(0)->getType(), + VecTy0->getNumElements()), ArrayRef<Value *>{Op0, Op1}, nullptr, "dx.dot"); } break; case Builtin::BI__builtin_hlsl_lerp: { diff --git a/clang/test/CodeGenHLSL/builtins/dot.hlsl b/clang/test/CodeGenHLSL/builtins/dot.hlsl index 0f993193c00cce..307d71cce3cb6d 100644 --- a/clang/test/CodeGenHLSL/builtins/dot.hlsl +++ b/clang/test/CodeGenHLSL/builtins/dot.hlsl @@ -110,21 +110,21 @@ uint64_t test_dot_ulong4(uint64_t4 p0, uint64_t4 p1) { 
return dot(p0, p1); } // NO_HALF: ret float %dx.dot half test_dot_half(half p0, half p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v2f16(<2 x half> %0, <2 x half> %1) +// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %0, <2 x half> %1) // NATIVE_HALF: ret half %dx.dot -// NO_HALF: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1) +// NO_HALF: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1) // NO_HALF: ret float %dx.dot half test_dot_half2(half2 p0, half2 p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v3f16(<3 x half> %0, <3 x half> %1) +// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %0, <3 x half> %1) // NATIVE_HALF: ret half %dx.dot -// NO_HALF: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1) +// NO_HALF: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1) // NO_HALF: ret float %dx.dot half test_dot_half3(half3 p0, half3 p1) { return dot(p0, p1); } -// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot.v4f16(<4 x half> %0, <4 x half> %1) +// NATIVE_HALF: %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %0, <4 x half> %1) // NATIVE_HALF: ret half %dx.dot -// NO_HALF: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1) +// NO_HALF: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1) // NO_HALF: ret float %dx.dot half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); } @@ -132,34 +132,34 @@ half test_dot_half4(half4 p0, half4 p1) { return dot(p0, p1); } // CHECK: ret float %dx.dot float test_dot_float(float p0, float p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %1) +// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %1) // CHECK: ret float %dx.dot float test_dot_float2(float2 p0, float2 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float 
@llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %1) +// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %1) // CHECK: ret float %dx.dot float test_dot_float3(float3 p0, float3 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %0, <4 x float> %1) +// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %0, <4 x float> %1) // CHECK: ret float %dx.dot float test_dot_float4(float4 p0, float4 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %splat.splat, <2 x float> %1) +// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %splat.splat, <2 x float> %1) // CHECK: ret float %dx.dot float test_dot_float2_splat(float p0, float2 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %splat.splat, <3 x float> %1) +// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %splat.splat, <3 x float> %1) // CHECK: ret float %dx.dot float test_dot_float3_splat(float p0, float3 p1) { return dot(p0, p1); } -// CHECK: %dx.dot = call float @llvm.dx.dot.v4f32(<4 x float> %splat.splat, <4 x float> %1) +// CHECK: %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %splat.splat, <4 x float> %1) // CHECK: ret float %dx.dot float test_dot_float4_splat(float p0, float4 p1) { return dot(p0, p1); } // CHECK: %conv = sitofp i32 %1 to float // CHECK: %splat.splatinsert = insertelement <2 x float> poison, float %conv, i64 0 // CHECK: %splat.splat = shufflevector <2 x float> %splat.splatinsert, <2 x float> poison, <2 x i32> zeroinitializer -// CHECK: %dx.dot = call float @llvm.dx.dot.v2f32(<2 x float> %0, <2 x float> %splat.splat) +// CHECK: %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %0, <2 x float> %splat.splat) // CHECK: ret float %dx.dot float test_builtin_dot_float2_int_splat(float2 p0, int p1) { return dot(p0, p1); @@ -168,7 +168,7 @@ float test_builtin_dot_float2_int_splat(float2 p0, int p1) { // CHECK: %conv = 
sitofp i32 %1 to float // CHECK: %splat.splatinsert = insertelement <3 x float> poison, float %conv, i64 0 // CHECK: %splat.splat = shufflevector <3 x float> %splat.splatinsert, <3 x float> poison, <3 x i32> zeroinitializer -// CHECK: %dx.dot = call float @llvm.dx.dot.v3f32(<3 x float> %0, <3 x float> %splat.splat) +// CHECK: %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %0, <3 x float> %splat.splat) // CHECK: ret float %dx.dot float test_builtin_dot_float3_int_splat(float3 p0, int p1) { return dot(p0, p1); diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index 1164b241ba7b0d..a871fac46b9fd5 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -24,7 +24,15 @@ def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>; def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; -def int_dx_dot : +def int_dx_dot2 : + Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], + [IntrNoMem, IntrWillReturn, Commutative] >; +def int_dx_dot3 : + Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], + [IntrNoMem, IntrWillReturn, Commutative] >; +def int_dx_dot4 : Intrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], [IntrNoMem, IntrWillReturn, Commutative] >; diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index f7e69ebae15b6c..f95cf22861360c 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -303,6 +303,15 @@ def IMad : DXILOpMapping<48, tertiary, int_dx_imad, "Signed integer arithmetic multiply/add operation. 
imad(m,a,b) = m * a + b.">; def UMad : DXILOpMapping<49, tertiary, int_dx_umad, "Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">; +def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2, + "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1", + [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>; +def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3, + "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2", + [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>; +def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4, + "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3", + [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>; def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id, "Reads the thread ID">; def GroupId : DXILOpMapping<94, groupId, int_dx_group_id, diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp index 0841ae95423c7b..0b3982ea0f438a 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp @@ -254,7 +254,7 @@ namespace dxil { CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy, Type *OverloadTy, - llvm::iterator_range<Use *> Args) { + SmallVector<Value *> Args) { const OpCodeProperty *Prop = getOpCodeProperty(OpCode); OverloadKind Kind = getOverloadKind(OverloadTy); @@ -272,10 +272,8 @@ CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy, FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy); DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT); } - SmallVector<Value *> FullArgs; - 
FullArgs.emplace_back(B.getInt32((int32_t)OpCode)); - FullArgs.append(Args.begin(), Args.end()); - return B.CreateCall(DXILFn, FullArgs); + + return B.CreateCall(DXILFn, Args); } Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) { diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h index f3abcc6e02a4e3..5babeae470178b 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.h +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h @@ -13,7 +13,7 @@ #define LLVM_LIB_TARGET_DIRECTX_DXILOPBUILDER_H #include "DXILConstants.h" -#include "llvm/ADT/iterator_range.h" +#include "llvm/ADT/SmallVector.h" namespace llvm { class Module; @@ -35,8 +35,7 @@ class DXILOpBuilder { /// \param OverloadTy Overload type of the DXIL Op call constructed /// \return DXIL Op call constructed CallInst *createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy, - Type *OverloadTy, - llvm::iterator_range<Use *> Args); + Type *OverloadTy, SmallVector<Value *> Args); Type *getOverloadTy(dxil::OpCode OpCode, FunctionType *FT); static const char *getOpCodeName(dxil::OpCode DXILOp); diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index 3e334b0ec298d3..f09e322f88e1fd 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -30,6 +30,48 @@ using namespace llvm; using namespace llvm::dxil; +static bool isVectorArgExpansion(Function &F) { + switch (F.getIntrinsicID()) { + case Intrinsic::dx_dot2: + case Intrinsic::dx_dot3: + case Intrinsic::dx_dot4: + return true; + } + return false; +} + +static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) { + SmallVector<Value *, 4> ExtractedElements; + auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType()); + for (unsigned I = 0; I < VecArg->getNumElements(); ++I) { + Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I); + Value *ExtractedElement = 
Builder.CreateExtractElement(Arg, Index); + ExtractedElements.push_back(ExtractedElement); + } + return ExtractedElements; +} + +static SmallVector<Value *> argVectorFlatten(CallInst *Orig, + IRBuilder<> &Builder) { + // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening. + unsigned NumOperands = Orig->getNumOperands() - 1; + assert(NumOperands > 0); + Value *Arg0 = Orig->getOperand(0); + [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType()); + assert(VecArg0); + SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder); + for (unsigned I = 1; I < NumOperands; ++I) { + Value *Arg = Orig->getOperand(I); + [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType()); + assert(VecArg); + assert(VecArg0->getElementType() == VecArg->getElementType()); + assert(VecArg0->getNumElements() == VecArg->getNumElements()); + auto NextOperandList = populateOperands(Arg, Builder); + NewOperands.append(NextOperandList.begin(), NextOperandList.end()); + } + return NewOperands; +} + static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) { IRBuilder<> B(M.getContext()); DXILOpBuilder DXILB(M, B); @@ -39,9 +81,18 @@ static void lowerIntrinsic(dxil::OpCode DXILOp, Function &F, Module &M) { if (!CI) continue; + SmallVector<Value *> Args; + Value *DXILOpArg = B.getInt32(static_cast<unsigned>(DXILOp)); + Args.emplace_back(DXILOpArg); B.SetInsertPoint(CI); - CallInst *DXILCI = DXILB.createDXILOpCall(DXILOp, F.getReturnType(), - OverloadTy, CI->args()); + if (isVectorArgExpansion(F)) { + SmallVector<Value *> NewArgs = argVectorFlatten(CI, B); + Args.append(NewArgs.begin(), NewArgs.end()); + } else + Args.append(CI->arg_begin(), CI->arg_end()); + + CallInst *DXILCI = + DXILB.createDXILOpCall(DXILOp, F.getReturnType(), OverloadTy, Args); CI->replaceAllUsesWith(DXILCI); CI->eraseFromParent(); diff --git a/llvm/test/CodeGen/DirectX/dot2_error.ll b/llvm/test/CodeGen/DirectX/dot2_error.ll new file mode 100644 
index 00000000000000..a27bfaedacd573 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/dot2_error.ll @@ -0,0 +1,10 @@ +; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s + +; DXIL operation dot2 does not support double overload type +; CHECK: LLVM ERROR: Invalid Overload + +define noundef double @dot_double2(<2 x double> noundef %a, <2 x double> noundef %b) { +entry: + %dx.dot = call double @llvm.dx.dot2.v2f64(<2 x double> %a, <2 x double> %b) + ret double %dx.dot +} diff --git a/llvm/test/CodeGen/DirectX/dot3_error.ll b/llvm/test/CodeGen/DirectX/dot3_error.ll new file mode 100644 index 00000000000000..eb69fb145038aa --- /dev/null +++ b/llvm/test/CodeGen/DirectX/dot3_error.ll @@ -0,0 +1,10 @@ +; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s + +; DXIL operation dot3 does not support double overload type +; CHECK: LLVM ERROR: Invalid Overload + +define noundef double @dot_double3(<3 x double> noundef %a, <3 x double> noundef %b) { +entry: + %dx.dot = call double @llvm.dx.dot3.v3f64(<3 x double> %a, <3 x double> %b) + ret double %dx.dot +} diff --git a/llvm/test/CodeGen/DirectX/dot4_error.ll b/llvm/test/CodeGen/DirectX/dot4_error.ll new file mode 100644 index 00000000000000..5cd632684c0c01 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/dot4_error.ll @@ -0,0 +1,10 @@ +; RUN: not opt -S -dxil-op-lower %s 2>&1 | FileCheck %s + +; DXIL operation dot4 does not support double overload type +; CHECK: LLVM ERROR: Invalid Overload + +define noundef double @dot_double4(<4 x double> noundef %a, <4 x double> noundef %b) { +entry: + %dx.dot = call double @llvm.dx.dot4.v4f64(<4 x double> %a, <4 x double> %b) + ret double %dx.dot +} diff --git a/llvm/test/CodeGen/DirectX/fdot.ll b/llvm/test/CodeGen/DirectX/fdot.ll new file mode 100644 index 00000000000000..3e13b2ad2650c8 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/fdot.ll @@ -0,0 +1,94 @@ +; RUN: opt -S -dxil-op-lower < %s | FileCheck %s + +; Make sure dxil operation function calls for dot are generated for int/uint 
vectors. + +; CHECK-LABEL: dot_half2 +define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) { +entry: +; CHECK: extractelement <2 x half> %a, i32 0 +; CHECK: extractelement <2 x half> %a, i32 1 +; CHECK: extractelement <2 x half> %b, i32 0 +; CHECK: extractelement <2 x half> %b, i32 1 +; CHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) + %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b) + ret half %dx.dot +} + +; CHECK-LABEL: dot_half3 +define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) { +entry: +; CHECK: extractelement <3 x half> %a, i32 0 +; CHECK: extractelement <3 x half> %a, i32 1 +; CHECK: extractelement <3 x half> %a, i32 2 +; CHECK: extractelement <3 x half> %b, i32 0 +; CHECK: extractelement <3 x half> %b, i32 1 +; CHECK: extractelement <3 x half> %b, i32 2 +; CHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) + %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b) + ret half %dx.dot +} + +; CHECK-LABEL: dot_half4 +define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) { +entry: +; CHECK: extractelement <4 x half> %a, i32 0 +; CHECK: extractelement <4 x half> %a, i32 1 +; CHECK: extractelement <4 x half> %a, i32 2 +; CHECK: extractelement <4 x half> %a, i32 3 +; CHECK: extractelement <4 x half> %b, i32 0 +; CHECK: extractelement <4 x half> %b, i32 1 +; CHECK: extractelement <4 x half> %b, i32 2 +; CHECK: extractelement <4 x half> %b, i32 3 +; CHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) + %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b) + ret half %dx.dot +} + +; CHECK-LABEL: dot_float2 +define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) { +entry: +; CHECK: extractelement <2 
x float> %a, i32 0 +; CHECK: extractelement <2 x float> %a, i32 1 +; CHECK: extractelement <2 x float> %b, i32 0 +; CHECK: extractelement <2 x float> %b, i32 1 +; CHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) + %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b) + ret float %dx.dot +} + +; CHECK-LABEL: dot_float3 +define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) { +entry: +; CHECK: extractelement <3 x float> %a, i32 0 +; CHECK: extractelement <3 x float> %a, i32 1 +; CHECK: extractelement <3 x float> %a, i32 2 +; CHECK: extractelement <3 x float> %b, i32 0 +; CHECK: extractelement <3 x float> %b, i32 1 +; CHECK: extractelement <3 x float> %b, i32 2 +; CHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) + %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b) + ret float %dx.dot +} + +; CHECK-LABEL: dot_float4 +define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) { +entry: +; CHECK: extractelement <4 x float> %a, i32 0 +; CHECK: extractelement <4 x float> %a, i32 1 +; CHECK: extractelement <4 x float> %a, i32 2 +; CHECK: extractelement <4 x float> %a, i32 3 +; CHECK: extractelement <4 x float> %b, i32 0 +; CHECK: extractelement <4 x float> %b, i32 1 +; CHECK: extractelement <4 x float> %b, i32 2 +; CHECK: extractelement <4 x float> %b, i32 3 +; CHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) + %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %a, <4 x float> %b) + ret float %dx.dot +} + +declare half @llvm.dx.dot.v2f16(<2 x half> , <2 x half> ) +declare half @llvm.dx.dot.v3f16(<3 x half> , <3 x half> ) +declare half @llvm.dx.dot.v4f16(<4 x half> , <4 x half> ) +declare float @llvm.dx.dot.v2f32(<2 x float>, <2 x float>) 
+declare float @llvm.dx.dot.v3f32(<3 x float>, <3 x float>) +declare float @llvm.dx.dot.v4f32(<4 x float>, <4 x float>) >From 98fa81636b6f6e408763ecec53b3f8869a2bc096 Mon Sep 17 00:00:00 2001 From: Farzon Lotfi <farzonlo...@microsoft.com> Date: Mon, 25 Mar 2024 16:39:33 -0400 Subject: [PATCH 2/2] address pr comments --- llvm/lib/Target/DirectX/DXIL.td | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index f95cf22861360c..2e6d58e14fd325 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -303,15 +303,15 @@ def IMad : DXILOpMapping<48, tertiary, int_dx_imad, "Signed integer arithmetic multiply/add operation. imad(m,a,b) = m * a + b.">; def UMad : DXILOpMapping<49, tertiary, int_dx_umad, "Unsigned integer arithmetic multiply/add operation. umad(m,a,b) = m * a + b.">; -def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2, - "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 1", - [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>; -def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3, - "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2", - [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>; -def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4, - "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3", - [llvm_halforfloat_ty,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>,LLVMMatchType<0>]>; +let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 4)) in + def Dot2 : DXILOpMapping<54, dot2, int_dx_dot2, + "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... 
+ a[n]*b[n] where n is between 0 and 1">; +let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 6)) in + def Dot3 : DXILOpMapping<55, dot3, int_dx_dot3, + "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 2">; +let OpTypes = !listconcat([llvm_halforfloat_ty], !listsplat(llvm_halforfloat_ty, 8)) in + def Dot4 : DXILOpMapping<56, dot4, int_dx_dot4, + "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + a[n]*b[n] where n is between 0 and 3">; def ThreadId : DXILOpMapping<93, threadId, int_dx_thread_id, "Reads the thread ID">; def GroupId : DXILOpMapping<94, groupId, int_dx_group_id, _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits