Author: Sumit Agarwal Date: 2025-04-03T16:23:09-06:00 New Revision: 996cf5dc6731d00da6ce5dc7a25b399d0ed29d54
URL: https://github.com/llvm/llvm-project/commit/996cf5dc6731d00da6ce5dc7a25b399d0ed29d54 DIFF: https://github.com/llvm/llvm-project/commit/996cf5dc6731d00da6ce5dc7a25b399d0ed29d54.diff LOG: [HLSL] Implement dot2add intrinsic (#131237) Resolves #99221 Key points: For SPIRV backend, it decomposes into a `dot` followed by an `add`. - [x] Implement dot2add clang builtin, - [x] Link dot2add clang builtin with hlsl_intrinsics.h - [x] Add sema checks for dot2add to CheckHLSLBuiltinFunctionCall in SemaHLSL.cpp - [x] Add codegen for dot2add to EmitHLSLBuiltinExpr in CGBuiltin.cpp - [x] Add codegen tests to clang/test/CodeGenHLSL/builtins/dot2add.hlsl - [x] Add sema tests to clang/test/SemaHLSL/BuiltIns/dot2add-errors.hlsl - [x] Create the int_dx_dot2add intrinsic in IntrinsicsDirectX.td - [x] Create the DXILOpMapping of int_dx_dot2add to 162 in DXIL.td - [x] Create the dot2add.ll and dot2add_errors.ll tests in llvm/test/CodeGen/DirectX/ Added: clang/test/CodeGenHLSL/builtins/dot2add.hlsl clang/test/SemaHLSL/BuiltIns/dot2add-errors.hlsl llvm/test/CodeGen/DirectX/dot2add.ll Modified: clang/include/clang/Basic/Builtins.td clang/lib/CodeGen/CGHLSLBuiltins.cpp clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h clang/lib/Headers/hlsl/hlsl_intrinsics.h llvm/include/llvm/IR/IntrinsicsDirectX.td llvm/lib/Target/DirectX/DXIL.td llvm/lib/Target/DirectX/DXILOpLowering.cpp Removed: ################################################################################ diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index b2c7ddb43de55..c7ca607e4b3d2 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -4891,6 +4891,12 @@ def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> { let Prototype = "void(...)"; } +def HLSLDot2Add : LangBuiltin<"HLSL_LANG"> { + let Spellings = ["__builtin_hlsl_dot2add"]; + let Attributes = [NoThrow, Const]; + let Prototype = "float(_ExtVector<2, _Float16>, _ExtVector<2, _Float16>, float)"; +} + 
def HLSLDot4AddI8Packed : LangBuiltin<"HLSL_LANG"> { let Spellings = ["__builtin_hlsl_dot4add_i8packed"]; let Attributes = [NoThrow, Const]; diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp b/clang/lib/CodeGen/CGHLSLBuiltins.cpp index 99c62808c323d..07f6d0953f026 100644 --- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp +++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp @@ -380,6 +380,19 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()), ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot"); } + case Builtin::BI__builtin_hlsl_dot2add: { + llvm::Triple::ArchType Arch = CGM.getTarget().getTriple().getArch(); + assert(Arch == llvm::Triple::dxil && + "Intrinsic dot2add is only allowed for dxil architecture"); + Value *A = EmitScalarExpr(E->getArg(0)); + Value *B = EmitScalarExpr(E->getArg(1)); + Value *C = EmitScalarExpr(E->getArg(2)); + + Intrinsic::ID ID = llvm ::Intrinsic::dx_dot2add; + return Builder.CreateIntrinsic( + /*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr, + "dx.dot2add"); + } case Builtin::BI__builtin_hlsl_dot4add_i8packed: { Value *A = EmitScalarExpr(E->getArg(0)); Value *B = EmitScalarExpr(E->getArg(1)); diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h index 8cdd63d7e07bb..3c15f2b38d80f 100644 --- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h +++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h @@ -45,6 +45,14 @@ distance_vec_impl(vector<T, N> X, vector<T, N> Y) { return length_vec_impl(X - Y); } +constexpr float dot2add_impl(half2 a, half2 b, float c) { +#if defined(__DIRECTX__) + return __builtin_hlsl_dot2add(a, b, c); +#else + return dot(a, b) + c; +#endif +} + template <typename T> constexpr T reflect_impl(T I, T N) { return I - 2 * N * I * N; } diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h index fd799b8d874ae..1a61fdba4fc19 100644 --- 
a/clang/lib/Headers/hlsl/hlsl_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h @@ -175,6 +175,21 @@ const inline float distance(__detail::HLSL_FIXED_VECTOR<float, N> X, return __detail::distance_vec_impl(X, Y); } +//===----------------------------------------------------------------------===// +// dot2add builtins +//===----------------------------------------------------------------------===// + +/// \fn float dot2add(half2 A, half2 B, float C) +/// \brief Dot product of 2 vector of type half and add a float scalar value. +/// \param A The first input value to dot product. +/// \param B The second input value to dot product. +/// \param C The input value added to the dot product. + +_HLSL_AVAILABILITY(shadermodel, 6.4) +const inline float dot2add(half2 A, half2 B, float C) { + return __detail::dot2add_impl(A, B, C); +} + //===----------------------------------------------------------------------===// // fmod builtins //===----------------------------------------------------------------------===// diff --git a/clang/test/CodeGenHLSL/builtins/dot2add.hlsl b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl new file mode 100644 index 0000000000000..2464607dd636c --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl @@ -0,0 +1,135 @@ +// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \ +// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -o - | \ +// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL +// RUN: %clang_cc1 -finclude-default-header -fnative-half-type -triple \ +// RUN: spirv-pc-vulkan-compute %s -emit-llvm -o - | \ +// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV + +// Test basic lowering to runtime function call. 
+ +// CHECK-LABEL: define {{.*}}test_default_parameter_type +float test_default_parameter_type(half2 p1, half2 p2, float p3) { + // CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}) + // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float + // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 + // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK: ret float %[[RES]] + return dot2add(p1, p2, p3); +} + +// CHECK-LABEL: define {{.*}}test_float_arg2_type +float test_float_arg2_type(half2 p1, float2 p2, float p3) { + // CHECK: %conv = fptrunc reassoc nnan ninf nsz arcp afn <2 x float> %{{.*}} to <2 x half> + // CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}) + // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float + // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 + // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK: ret float %[[RES]] + return dot2add(p1, p2, p3); +} + +// CHECK-LABEL: define {{.*}}test_float_arg1_type +float test_float_arg1_type(float2 p1, half2 p2, float p3) { + // CHECK: %conv = fptrunc reassoc nnan ninf nsz arcp afn <2 x float> %{{.*}} to <2 x half> + // CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}) + // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float + // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 + // 
CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK: ret float %[[RES]] + return dot2add(p1, p2, p3); +} + +// CHECK-LABEL: define {{.*}}test_double_arg3_type +float test_double_arg3_type(half2 p1, half2 p2, double p3) { + // CHECK: %conv = fptrunc reassoc nnan ninf nsz arcp afn double %{{.*}} to float + // CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}) + // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float + // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 + // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK: ret float %[[RES]] + return dot2add(p1, p2, p3); +} + +// CHECK-LABEL: define {{.*}}test_float_arg1_arg2_type +float test_float_arg1_arg2_type(float2 p1, float2 p2, float p3) { + // CHECK: %conv = fptrunc reassoc nnan ninf nsz arcp afn <2 x float> %{{.*}} to <2 x half> + // CHECK: %conv1 = fptrunc reassoc nnan ninf nsz arcp afn <2 x float> %{{.*}} to <2 x half> + // CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}) + // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float + // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 + // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK: ret float %[[RES]] + return dot2add(p1, p2, p3); +} + +// CHECK-LABEL: define {{.*}}test_double_arg1_arg2_type 
+float test_double_arg1_arg2_type(double2 p1, double2 p2, float p3) { + // CHECK: %conv = fptrunc reassoc nnan ninf nsz arcp afn <2 x double> %{{.*}} to <2 x half> + // CHECK: %conv1 = fptrunc reassoc nnan ninf nsz arcp afn <2 x double> %{{.*}} to <2 x half> + // CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}) + // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float + // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 + // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK: ret float %[[RES]] + return dot2add(p1, p2, p3); +} + +// CHECK-LABEL: define {{.*}}test_int16_arg1_arg2_type +float test_int16_arg1_arg2_type(int16_t2 p1, int16_t2 p2, float p3) { + // CHECK: %conv = sitofp <2 x i16> %{{.*}} to <2 x half> + // CHECK: %conv1 = sitofp <2 x i16> %{{.*}} to <2 x half> + // CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}) + // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float + // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 + // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK: ret float %[[RES]] + return dot2add(p1, p2, p3); +} + +// CHECK-LABEL: define {{.*}}test_int32_arg1_arg2_type +float test_int32_arg1_arg2_type(int32_t2 p1, int32_t2 p2, float p3) { + // CHECK: %conv = sitofp <2 x i32> %{{.*}} to <2 x half> + // CHECK: %conv1 = sitofp <2 x i32> %{{.*}} to <2 x half> + // CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half 
@llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}) + // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float + // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 + // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK: ret float %[[RES]] + return dot2add(p1, p2, p3); +} + +// CHECK-LABEL: define {{.*}}test_int64_arg1_arg2_type +float test_int64_arg1_arg2_type(int64_t2 p1, int64_t2 p2, float p3) { + // CHECK: %conv = sitofp <2 x i64> %{{.*}} to <2 x half> + // CHECK: %conv1 = sitofp <2 x i64> %{{.*}} to <2 x half> + // CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}) + // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float + // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 + // CHECK-SPIRV: %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK: ret float %[[RES]] + return dot2add(p1, p2, p3); +} + +// CHECK-LABEL: define {{.*}}test_bool_arg1_arg2_type +float test_bool_arg1_arg2_type(bool2 p1, bool2 p2, float p3) { + // CHECK: %loadedv = trunc <2 x i32> %{{.*}} to <2 x i1> + // CHECK: %conv = uitofp <2 x i1> %loadedv to <2 x half> + // CHECK: %loadedv1 = trunc <2 x i32> %{{.*}} to <2 x i1> + // CHECK: %conv2 = uitofp <2 x i1> %loadedv1 to <2 x half> + // CHECK-SPIRV: %[[MUL:.*]] = call reassoc nnan ninf nsz arcp afn half @llvm.spv.fdot.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}) + // CHECK-SPIRV: %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half %[[MUL]] to float + // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr.i, align 4 + // CHECK-SPIRV: 
%[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float %[[CONV]], %[[C]] + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}}) + // CHECK: ret float %[[RES]] + return dot2add(p1, p2, p3); +} diff --git a/clang/test/SemaHLSL/BuiltIns/dot2add-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/dot2add-errors.hlsl new file mode 100644 index 0000000000000..262ceecbf1d90 --- /dev/null +++ b/clang/test/SemaHLSL/BuiltIns/dot2add-errors.hlsl @@ -0,0 +1,13 @@ +// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -fnative-half-type -emit-llvm-only -disable-llvm-passes -verify + +float test_too_few_arg() { + return dot2add(); + // expected-error@-1 {{no matching function for call to 'dot2add'}} + // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 0 were provided}} +} + +float test_too_many_arg(half2 p1, half2 p2, float p3) { + return dot2add(p1, p2, p3, p1); + // expected-error@-1 {{no matching function for call to 'dot2add'}} + // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 4 were provided}} +} diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index ead7286f4311c..775d325feeb14 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -100,6 +100,10 @@ def int_dx_udot : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], [IntrNoMem, Commutative] >; +def int_dx_dot2add : + DefaultAttrsIntrinsic<[llvm_float_ty], + [llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty], + [IntrNoMem, Commutative]>; def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>; def int_dx_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, 
llvm_i32_ty], [IntrNoMem]>; diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 1d8904bdf5514..b1e7406ead675 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -1077,6 +1077,17 @@ def RawBufferStore : DXILOp<140, rawBufferStore> { let stages = [Stages<DXIL1_2, [all_stages]>]; } +def Dot2AddHalf : DXILOp<162, dot2AddHalf> { + let Doc = "dot product of 2 vectors of half having size = 2, returns " + "float"; + let intrinsics = [IntrinSelect<int_dx_dot2add>]; + let arguments = [FloatTy, HalfTy, HalfTy, HalfTy, HalfTy]; + let result = FloatTy; + let overloads = [Overloads<DXIL1_0, []>]; + let stages = [Stages<DXIL1_0, [all_stages]>]; + let attributes = [Attributes<DXIL1_0, [ReadNone]>]; +} + def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> { let Doc = "signed dot product of 4 x i8 vectors packed into i32, with " "accumulate to i32"; diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index dff9f3e03079e..3dcd3d8fd244a 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -54,10 +54,8 @@ static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) { return ExtractedElements; } -static SmallVector<Value *> argVectorFlatten(CallInst *Orig, - IRBuilder<> &Builder) { - // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening. 
- unsigned NumOperands = Orig->getNumOperands() - 1; +static SmallVector<Value *> +argVectorFlatten(CallInst *Orig, IRBuilder<> &Builder, unsigned NumOperands) { assert(NumOperands > 0); Value *Arg0 = Orig->getOperand(0); [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType()); @@ -75,6 +73,12 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig, return NewOperands; } +static SmallVector<Value *> argVectorFlatten(CallInst *Orig, + IRBuilder<> &Builder) { + // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening. + return argVectorFlatten(Orig, Builder, Orig->getNumOperands() - 1); +} + namespace { class OpLowerer { Module &M; @@ -168,6 +172,13 @@ class OpLowerer { } } else if (IsVectorArgExpansion) { Args = argVectorFlatten(CI, OpBuilder.getIRB()); + } else if (F.getIntrinsicID() == Intrinsic::dx_dot2add) { + // arg[NumOperands-1] is a pointer and is not needed by our flattening. + // arg[NumOperands-2] also does not need to be flattened because it is a + // scalar. 
+ unsigned NumOperands = CI->getNumOperands() - 2; + Args.push_back(CI->getArgOperand(NumOperands)); + Args.append(argVectorFlatten(CI, OpBuilder.getIRB(), NumOperands)); } else { Args.append(CI->arg_begin(), CI->arg_end()); } diff --git a/llvm/test/CodeGen/DirectX/dot2add.ll b/llvm/test/CodeGen/DirectX/dot2add.ll new file mode 100644 index 0000000000000..40c6cdafc83da --- /dev/null +++ b/llvm/test/CodeGen/DirectX/dot2add.ll @@ -0,0 +1,8 @@ +; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s + +define noundef float @dot2add_simple(<2 x half> noundef %a, <2 x half> noundef %b, float %c) { +entry: +; CHECK: call float @dx.op.dot2AddHalf(i32 162, float %c, half %0, half %1, half %2, half %3) + %ret = call float @llvm.dx.dot2add(<2 x half> %a, <2 x half> %b, float %c) + ret float %ret +} _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits