================ @@ -68,28 +69,65 @@ static Value *expandAbs(CallInst *Orig) { "dx.max"); } -static Value *expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) { +// Create DXIL dot intrinsics for floating point dot operations +static Value *expandFloatDotIntrinsic(CallInst *Orig) { + Value *A = Orig->getOperand(0); + Value *B = Orig->getOperand(1); + Type *ATy = A->getType(); + [[maybe_unused]] Type *BTy = B->getType(); + assert(ATy->isVectorTy() && BTy->isVectorTy()); + + IRBuilder<> Builder(Orig); + + auto *AVec = dyn_cast<FixedVectorType>(ATy); + + assert(ATy->getScalarType()->isFloatingPointTy()); + + Intrinsic::ID DotIntrinsic = Intrinsic::dx_dot4; + switch (AVec->getNumElements()) { + case 2: + DotIntrinsic = Intrinsic::dx_dot2; + break; + case 3: + DotIntrinsic = Intrinsic::dx_dot3; + break; + case 4: + DotIntrinsic = Intrinsic::dx_dot4; + break; + default: + llvm_unreachable("dot product with vector outside 2-4 range"); + } + return Builder.CreateIntrinsic(ATy->getScalarType(), DotIntrinsic, + ArrayRef<Value *>{A, B}, nullptr, "dot"); +} + +// Expand integer dot product to multiply and add ops +static Value *expandIntegerDotIntrinsic(CallInst *Orig, + Intrinsic::ID DotIntrinsic) { assert(DotIntrinsic == Intrinsic::dx_sdot || DotIntrinsic == Intrinsic::dx_udot); - Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot - ? Intrinsic::dx_imad - : Intrinsic::dx_umad; Value *A = Orig->getOperand(0); Value *B = Orig->getOperand(1); - [[maybe_unused]] Type *ATy = A->getType(); + Type *ATy = A->getType(); [[maybe_unused]] Type *BTy = B->getType(); assert(ATy->isVectorTy() && BTy->isVectorTy()); - IRBuilder<> Builder(Orig->getParent()); - Builder.SetInsertPoint(Orig); + IRBuilder<> Builder(Orig); + + auto *AVec = dyn_cast<FixedVectorType>(ATy); - auto *AVec = dyn_cast<FixedVectorType>(A->getType()); + assert(ATy->getScalarType()->isIntegerTy()); + + Value *Result; + Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot + ? 
Intrinsic::dx_imad + : Intrinsic::dx_umad; Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0); Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0); - Value *Result = Builder.CreateMul(Elt0, Elt1); - for (unsigned I = 1; I < AVec->getNumElements(); I++) { - Elt0 = Builder.CreateExtractElement(A, I); - Elt1 = Builder.CreateExtractElement(B, I); + Result = Builder.CreateMul(Elt0, Elt1); + for (unsigned i = 1; i < AVec->getNumElements(); i++) { ---------------- pow2clk wrote:
I try so hard to be agnostic about style guides, but stuff like this challenges my lack of creed. 😣 https://github.com/llvm/llvm-project/pull/104656 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits