llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Justin Bogner (bogner)

<details>
<summary>Changes</summary>

The `dx.dot2`, `dx.dot3`, and `dx.dot4` intrinsics exist purely to lower 
`dx.fdot`, and they map exactly to the DXIL ops of the same name. Using vectors 
for their arguments adds unnecessary complexity and leaves us with vector 
operations that are not trivial to lower after the scalarizer has run.

Similarly, the `dx.dot2add` intrinsic is overly generic for something that only 
needs to lower to a single `dot2AddHalf` DXIL op. Update its signature to match 
the operation it lowers to.

Fixes #<!-- -->134569.

---

Patch is 34.40 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/134570.diff


12 Files Affected:

- (modified) clang/lib/CodeGen/CGHLSLBuiltins.cpp (+8-3) 
- (modified) clang/test/CodeGenHLSL/builtins/dot2add.hlsl (+50-10) 
- (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+24-15) 
- (modified) llvm/lib/Target/DirectX/DXIL.td (+1-2) 
- (modified) llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp (+10-3) 
- (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (-58) 
- (modified) llvm/test/CodeGen/DirectX/dot2_error.ll (+3-2) 
- (modified) llvm/test/CodeGen/DirectX/dot2add.ll (+8-3) 
- (modified) llvm/test/CodeGen/DirectX/dot3_error.ll (+4-2) 
- (modified) llvm/test/CodeGen/DirectX/dot4_error.ll (+5-2) 
- (modified) llvm/test/CodeGen/DirectX/fdot.ll (+43-43) 
- (modified) llvm/test/CodeGen/DirectX/normalize.ll (+6-6) 


``````````diff
diff --git a/clang/lib/CodeGen/CGHLSLBuiltins.cpp 
b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
index 27d1c69439944..4e92be0664f71 100644
--- a/clang/lib/CodeGen/CGHLSLBuiltins.cpp
+++ b/clang/lib/CodeGen/CGHLSLBuiltins.cpp
@@ -385,12 +385,17 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned 
BuiltinID,
            "Intrinsic dot2add is only allowed for dxil architecture");
     Value *A = EmitScalarExpr(E->getArg(0));
     Value *B = EmitScalarExpr(E->getArg(1));
-    Value *C = EmitScalarExpr(E->getArg(2));
+    Value *Acc = EmitScalarExpr(E->getArg(2));
+
+    Value *AX = Builder.CreateExtractElement(A, Builder.getSize(0));
+    Value *AY = Builder.CreateExtractElement(A, Builder.getSize(1));
+    Value *BX = Builder.CreateExtractElement(B, Builder.getSize(0));
+    Value *BY = Builder.CreateExtractElement(B, Builder.getSize(1));
 
     Intrinsic::ID ID = llvm ::Intrinsic::dx_dot2add;
     return Builder.CreateIntrinsic(
-        /*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr,
-        "dx.dot2add");
+        /*ReturnType=*/Acc->getType(), ID,
+        ArrayRef<Value *>{Acc, AX, AY, BX, BY}, nullptr, "dx.dot2add");
   }
   case Builtin::BI__builtin_hlsl_dot4add_i8packed: {
     Value *A = EmitScalarExpr(E->getArg(0));
diff --git a/clang/test/CodeGenHLSL/builtins/dot2add.hlsl 
b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl
index 2464607dd636c..c345e17476e08 100644
--- a/clang/test/CodeGenHLSL/builtins/dot2add.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl
@@ -13,7 +13,11 @@ float test_default_parameter_type(half2 p1, half2 p2, float 
p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half 
%[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float 
%[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x 
half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float 
%{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -25,7 +29,11 @@ float test_float_arg2_type(half2 p1, float2 p2, float p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half 
%[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float 
%[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x 
half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float 
%{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -37,7 +45,11 @@ float test_float_arg1_type(float2 p1, half2 p2, float p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half 
%[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float 
%[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x 
half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float 
%{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -49,7 +61,11 @@ float test_double_arg3_type(half2 p1, half2 p2, double p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half 
%[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float 
%[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x 
half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float 
%{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -62,7 +78,11 @@ float test_float_arg1_arg2_type(float2 p1, float2 p2, float 
p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half 
%[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float 
%[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x 
half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float 
%{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -75,7 +95,11 @@ float test_double_arg1_arg2_type(double2 p1, double2 p2, 
float p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half 
%[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float 
%[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x 
half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float 
%{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -88,7 +112,11 @@ float test_int16_arg1_arg2_type(int16_t2 p1, int16_t2 p2, 
float p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half 
%[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float 
%[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x 
half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float 
%{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -101,7 +129,11 @@ float test_int32_arg1_arg2_type(int32_t2 p1, int32_t2 p2, 
float p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half 
%[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float 
%[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x 
half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float 
%{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -114,7 +146,11 @@ float test_int64_arg1_arg2_type(int64_t2 p1, int64_t2 p2, 
float p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half 
%[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float 
%[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x 
half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float 
%{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
@@ -129,7 +165,11 @@ float test_bool_arg1_arg2_type(bool2 p1, bool2 p2, float 
p3) {
   // CHECK-SPIRV:  %[[CONV:.*]] = fpext reassoc nnan ninf nsz arcp afn half 
%[[MUL]] to float
   // CHECK-SPIRV:  %[[C:.*]] = load float, ptr %c.addr.i, align 4
   // CHECK-SPIRV:  %[[RES:.*]] = fadd reassoc nnan ninf nsz arcp afn float 
%[[CONV]], %[[C]]
-  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f16(<2 x 
half> %{{.*}}, <2 x half> %{{.*}}, float %{{.*}})
+  // CHECK-DXIL:  %[[AX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[AY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[BX:.*]] = extractelement <2 x half> %{{.*}}, i32 0
+  // CHECK-DXIL:  %[[BY:.*]] = extractelement <2 x half> %{{.*}}, i32 1
+  // CHECK-DXIL:  %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add(float 
%{{.*}}, half %[[AX]], half %[[AY]], half %[[BX]], half %[[BY]])
   // CHECK:  ret float %[[RES]]
   return dot2add(p1, p2, p3);
 }
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td 
b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 775d325feeb14..b1a27311e2a9c 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -76,18 +76,27 @@ def int_dx_nclamp : 
DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>,
 def int_dx_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], 
[LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>;
 def int_dx_saturate : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], 
[LLVMMatchType<0>], [IntrNoMem]>;
 
-def int_dx_dot2 :
-    DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
-    [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, 
LLVMVectorElementType<0>>],
-    [IntrNoMem, Commutative] >;
-def int_dx_dot3 :
-    DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
-    [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, 
LLVMVectorElementType<0>>],
-    [IntrNoMem, Commutative] >;
-def int_dx_dot4 :
-    DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
-    [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, 
LLVMVectorElementType<0>>],
-    [IntrNoMem, Commutative] >;
+def int_dx_dot2 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
+                                        [
+                                          llvm_anyfloat_ty, LLVMMatchType<0>,
+                                          LLVMMatchType<0>, LLVMMatchType<0>
+                                        ],
+                                        [IntrNoMem, Commutative]>;
+def int_dx_dot3 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
+                                        [
+                                          llvm_anyfloat_ty, LLVMMatchType<0>,
+                                          LLVMMatchType<0>, LLVMMatchType<0>,
+                                          LLVMMatchType<0>, LLVMMatchType<0>
+                                        ],
+                                        [IntrNoMem, Commutative]>;
+def int_dx_dot4 : DefaultAttrsIntrinsic<[LLVMMatchType<0>],
+                                        [
+                                          llvm_anyfloat_ty, LLVMMatchType<0>,
+                                          LLVMMatchType<0>, LLVMMatchType<0>,
+                                          LLVMMatchType<0>, LLVMMatchType<0>,
+                                          LLVMMatchType<0>, LLVMMatchType<0>
+                                        ],
+                                        [IntrNoMem, Commutative]>;
 def int_dx_fdot :
     DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
     [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, 
LLVMVectorElementType<0>>],
@@ -100,9 +109,9 @@ def int_dx_udot :
     DefaultAttrsIntrinsic<[LLVMVectorElementType<0>],
     [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
     [IntrNoMem, Commutative] >;
-def int_dx_dot2add : 
-    DefaultAttrsIntrinsic<[llvm_float_ty], 
-    [llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty], 
+def int_dx_dot2add :
+    DefaultAttrsIntrinsic<[llvm_float_ty],
+    [llvm_float_ty, llvm_half_ty, llvm_half_ty, llvm_half_ty, llvm_half_ty],
     [IntrNoMem, Commutative]>;
 def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], 
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
 def int_dx_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], 
[llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>;
diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td
index b1e7406ead675..645105ade72b6 100644
--- a/llvm/lib/Target/DirectX/DXIL.td
+++ b/llvm/lib/Target/DirectX/DXIL.td
@@ -1078,8 +1078,7 @@ def RawBufferStore : DXILOp<140, rawBufferStore> {
 }
 
 def Dot2AddHalf : DXILOp<162, dot2AddHalf> {
-  let Doc = "dot product of 2 vectors of half having size = 2, returns "
-            "float";
+  let Doc = "2D half dot product with accumulate to float";
   let intrinsics = [IntrinSelect<int_dx_dot2add>];
   let arguments = [FloatTy, HalfTy, HalfTy, HalfTy, HalfTy];
   let result = FloatTy;
diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp 
b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
index e44d3b70eb657..53ffcc3ebbdbe 100644
--- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
+++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
@@ -169,7 +169,8 @@ static Value *expandFloatDotIntrinsic(CallInst *Orig, Value 
*A, Value *B) {
   assert(ATy->getScalarType()->isFloatingPointTy());
 
   Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4;
-  switch (AVec->getNumElements()) {
+  int NumElts = AVec->getNumElements();
+  switch (NumElts) {
   case 2:
     DotIntrinsic = Intrinsic::dx_dot2;
     break;
@@ -185,8 +186,14 @@ static Value *expandFloatDotIntrinsic(CallInst *Orig, 
Value *A, Value *B) {
         /* gen_crash_diag=*/false);
     return nullptr;
   }
-  return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic,
-                                 ArrayRef<Value *>{A, B}, nullptr, "dot");
+
+  SmallVector<Value *> Args;
+  for (int I = 0; I < NumElts; ++I)
+    Args.push_back(Builder.CreateExtractElement(A, Builder.getInt32(I)));
+  for (int I = 0; I < NumElts; ++I)
+    Args.push_back(Builder.CreateExtractElement(B, Builder.getInt32(I)));
+  return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic, Args,
+                                 nullptr, "dot");
 }
 
 // Create the appropriate DXIL float dot intrinsic for the operands of Orig
diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp 
b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
index 41a9426998826..4574e5f7bbd96 100644
--- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp
+++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp
@@ -33,52 +33,6 @@
 using namespace llvm;
 using namespace llvm::dxil;
 
-static bool isVectorArgExpansion(Function &F) {
-  switch (F.getIntrinsicID()) {
-  case Intrinsic::dx_dot2:
-  case Intrinsic::dx_dot3:
-  case Intrinsic::dx_dot4:
-    return true;
-  }
-  return false;
-}
-
-static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) 
{
-  SmallVector<Value *> ExtractedElements;
-  auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
-  for (unsigned I = 0; I < VecArg->getNumElements(); ++I) {
-    Value *Index = ConstantInt::get(Type::getInt32Ty(Arg->getContext()), I);
-    Value *ExtractedElement = Builder.CreateExtractElement(Arg, Index);
-    ExtractedElements.push_back(ExtractedElement);
-  }
-  return ExtractedElements;
-}
-
-static SmallVector<Value *>
-argVectorFlatten(CallInst *Orig, IRBuilder<> &Builder, unsigned NumOperands) {
-  assert(NumOperands > 0);
-  Value *Arg0 = Orig->getOperand(0);
-  [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType());
-  assert(VecArg0);
-  SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder);
-  for (unsigned I = 1; I < NumOperands; ++I) {
-    Value *Arg = Orig->getOperand(I);
-    [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType());
-    assert(VecArg);
-    assert(VecArg0->getElementType() == VecArg->getElementType());
-    assert(VecArg0->getNumElements() == VecArg->getNumElements());
-    auto NextOperandList = populateOperands(Arg, Builder);
-    NewOperands.append(NextOperandList.begin(), NextOperandList.end());
-  }
-  return NewOperands;
-}
-
-static SmallVector<Value *> argVectorFlatten(CallInst *Orig,
-                                             IRBuilder<> &Builder) {
-  // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening.
-  return argVectorFlatten(Orig, Builder, Orig->getNumOperands() - 1);
-}
-
 namespace {
 class OpLowerer {
   Module &M;
@@ -150,9 +104,6 @@ class OpLowerer {
   [[nodiscard]] bool
   replaceFunctionWithOp(Function &F, dxil::OpCode DXILOp,
                         ArrayRef<IntrinArgSelect> ArgSelects) {
-    bool IsVectorArgExpansion = isVectorArgExpansion(F);
-    assert(!(IsVectorArgExpansion && ArgSelects.size()) &&
-           "Cann't do vector arg expansion when using arg selects.");
     return replaceFunction(F, [&](CallInst *CI) -> Error {
       OpBuilder.getIRB().SetInsertPoint(CI);
       SmallVector<Value *> Args;
@@ -170,15 +121,6 @@ class OpLowerer {
             break;
           }
         }
-      } else if (IsVectorArgExpansion) {
-        Args = argVectorFlatten(CI, OpBuilder.getIRB());
-      } else if (F.getIntrinsicID() == Intrinsic::dx_dot2add) {
-        // arg[NumOperands-1] is a pointer and is not needed by our flattening.
-        // arg[NumOperands-2] also does not need to be flattened because it is 
a
-        // scalar.
-        unsigned NumOperands = CI->getNumOperands() - 2;
-        Args.push_back(CI->getArgOperand(NumOperands));
-        Args.append(argVectorFlatten(CI, OpBuilder.getIRB(), NumOperands));
       } else {
         Args.append(CI->arg_begin(), CI->arg_end());
       }
diff --git a/llvm/test/CodeGen/DirectX/dot2_error.ll 
b/llvm/test/CodeGen/DirectX/dot2_error.ll
index 97b025d36f018..f2167aa516057 100644
--- a/llvm/test/CodeGen/DirectX/dot2_error.ll
+++ b/llvm/test/CodeGen/DirectX/dot2_error.ll
@@ -4,8 +4,9 @@
 ; CHECK: in function dot_double2
 ; CHECK-SAME: Cannot create Dot2 operation: Invalid overload type
 
...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/134570
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to