================
@@ -2366,6 +2373,164 @@ bool 
AMDGPUCodeGenPrepareImpl::visitMbcntHi(IntrinsicInst &I) const {
   return tryReplaceWithWorkitemId(I, Wave);
 }
 
+/// Match the dot4 pattern `mul(ext <4 x i8> A, ext <4 x i8> B)` where the
+/// multiply produces <4 x i32> and both extensions are of the same kind
+/// (both zext or both sext).
+///
+/// \param MulOp    Candidate value; must be the multiply for a match.
+/// \param A        Receives the <4 x i8> source of the first multiply operand.
+/// \param B        Receives the <4 x i8> source of the second multiply operand.
+/// \param IsSigned Receives true if both operands are sext, false if zext.
+/// \returns true if the pattern matches. The output parameters are written
+///          only when the function returns true.
+static bool matchDot4Pattern(Value *MulOp, Value *&A, Value *&B,
+                             bool &IsSigned) {
+  auto *Mul = dyn_cast<BinaryOperator>(MulOp);
+  if (!Mul || Mul->getOpcode() != Instruction::Mul)
+    return false;
+
+  // The multiply must produce <4 x i32>.
+  auto *MulTy = dyn_cast<FixedVectorType>(Mul->getType());
+  if (!MulTy || MulTy->getNumElements() != 4 ||
+      !MulTy->getElementType()->isIntegerTy(32))
+    return false;
+
+  // Match a zext or sext from <4 x i8>. On success, Src receives the
+  // unextended operand and Signed records whether the extension was a sext.
+  auto MatchExtend = [](Value *V, Value *&Src, bool &Signed) -> bool {
+    if (!isa<ZExtInst, SExtInst>(V))
+      return false;
+    auto *Ext = cast<CastInst>(V);
+    auto *SrcTy = dyn_cast<FixedVectorType>(Ext->getSrcTy());
+    if (!SrcTy || SrcTy->getNumElements() != 4 ||
+        !SrcTy->getElementType()->isIntegerTy(8))
+      return false;
+    Src = Ext->getOperand(0);
+    Signed = isa<SExtInst>(Ext);
+    return true;
+  };
+
+  // Match into locals so the caller's out-parameters stay untouched on failure.
+  Value *Src0 = nullptr, *Src1 = nullptr;
+  bool Signed0 = false, Signed1 = false;
+  if (!MatchExtend(Mul->getOperand(0), Src0, Signed0) ||
+      !MatchExtend(Mul->getOperand(1), Src1, Signed1))
+    return false;
+
+  // The emitted sdot4/udot4 intrinsics use a single signedness for both
+  // operands, so mixed zext/sext operands are rejected.
+  if (Signed0 != Signed1)
+    return false;
+
+  A = Src0;
+  B = Src1;
+  IsSigned = Signed0;
+  return true;
+}
+
+/// Try to fold `vector.reduce.add(mul(ext <4 x i8>, ext <4 x i8>))` into a
+/// single amdgcn.udot4 / amdgcn.sdot4 call with a zero accumulator and clamp
+/// disabled (the non-saturating form).
+///
+/// Only the subtarget feature needed for the matched signedness is required:
+/// udot4 needs Dot7Insts, sdot4 needs Dot1Insts or Dot8Insts.
+///
+/// \returns true if \p I was replaced (it is then queued in DeadVals).
+bool AMDGPUCodeGenPrepareImpl::visitVectorReduceAdd(IntrinsicInst &I) {
+  // Cheap early exit for subtargets without any dot4 instruction.
+  if (!ST.hasDot7Insts() && !ST.hasDot1Insts() && !ST.hasDot8Insts())
+    return false;
+
+  Value *A = nullptr, *B = nullptr;
+  bool IsSigned = false;
+  if (!matchDot4Pattern(I.getArgOperand(0), A, B, IsSigned))
+    return false;
+
+  // Require only the instruction the chosen intrinsic lowers to:
+  // v_dot4_u32_u8 (Dot7) for udot4; v_dot4_i32_i8 (Dot1) or v_dot4_i32_iu8
+  // (Dot8) for sdot4.
+  bool HasDot4 = IsSigned ? (ST.hasDot1Insts() || ST.hasDot8Insts())
+                          : ST.hasDot7Insts();
+  if (!HasDot4)
+    return false;
+
+  LLVMContext &Ctx = I.getContext();
+  Type *I32Ty = Type::getInt32Ty(Ctx);
+  IRBuilder<> Builder(&I);
+
+  // Reinterpret each <4 x i8> as the packed i32 operand the intrinsic takes.
+  Value *ASrc = Builder.CreateBitCast(A, I32Ty);
+  Value *BSrc = Builder.CreateBitCast(B, I32Ty);
+
+  // Non-saturating case: accumulator is 0 and clamp is false. The sum of four
+  // 8-bit products fits in i32 (|sum| <= 4 * 255 * 255), so the wrapping
+  // reduction and the unclamped dot4 produce the same value.
+  Value *Acc = ConstantInt::get(I32Ty, 0);
+  Value *Clamp = ConstantInt::getFalse(Ctx);
+
+  Intrinsic::ID DotIID =
+      IsSigned ? Intrinsic::amdgcn_sdot4 : Intrinsic::amdgcn_udot4;
+  Value *Dot = Builder.CreateIntrinsic(DotIID, {}, {ASrc, BSrc, Acc, Clamp});
+  // Move the name rather than copying it: copying while I is still alive
+  // would give the new call a uniquified name.
+  Dot->takeName(&I);
+
+  I.replaceAllUsesWith(Dot);
+  DeadVals.push_back(&I);
+  return true;
+}
+
+/// Try to convert uadd.sat/sadd.sat(vector.reduce.add(mul(...)), c) to a
+/// saturating dot4 intrinsic. This combine starts at the root (saturating add)
+/// and looks at its operands.
+bool AMDGPUCodeGenPrepareImpl::visitSaturatingAdd(IntrinsicInst &I) {
+  // Check if we have dot4 instructions available
+  if (!ST.hasDot7Insts() || (!ST.hasDot1Insts() && !ST.hasDot8Insts()))
+    return false;
+
+  Intrinsic::ID IID = I.getIntrinsicID();
+  bool IsSigned = (IID == Intrinsic::sadd_sat);
+
+  // Look for vector.reduce.add as one of the operands
+  Value *ReduceOp = nullptr;
+  Value *Accum = nullptr;
+
+  for (int Swap = 0; Swap < 2; ++Swap) {
+    Value *Op0 = I.getArgOperand(Swap);
+    Value *Op1 = I.getArgOperand(1 - Swap);
+
+    if (auto *ReduceInst = dyn_cast<IntrinsicInst>(Op0)) {
+      if (ReduceInst->getIntrinsicID() == Intrinsic::vector_reduce_add) {
+        ReduceOp = Op0;
+        Accum = Op1;
+        break;
+      }
+    }
+  }
+
+  if (!ReduceOp)
+    return false;
+
+  auto *ReduceInst = cast<IntrinsicInst>(ReduceOp);
+
+  Value *A = nullptr, *B = nullptr;
+  bool PatternSigned = false;
+
+  if (!matchDot4Pattern(ReduceInst->getArgOperand(0), A, B, PatternSigned))
+    return false;
+
+  // Signedness of the pattern must match the saturating add type
+  if (PatternSigned != IsSigned)
+    return false;
+
+  LLVMContext &Ctx = I.getContext();
+  Type *I32Ty = Type::getInt32Ty(Ctx);
+  IRBuilder<> Builder(&I);
+
+  // Bitcast <4 x i8> to i32
+  Value *ASrc = Builder.CreateBitCast(A, I32Ty);
+  Value *BSrc = Builder.CreateBitCast(B, I32Ty);
+
+  // Saturating case: use the accumulator and set clamp to true
+  Value *Clamp = ConstantInt::getTrue(Ctx);
+
+  Intrinsic::ID DotIID =
+      IsSigned ? Intrinsic::amdgcn_sdot4 : Intrinsic::amdgcn_udot4;
+
+  Value *Dot = Builder.CreateIntrinsic(DotIID, {}, {ASrc, BSrc, Accum, Clamp},
+                                       nullptr, I.getName());
----------------
arsenm wrote:

takeName 

https://github.com/llvm/llvm-project/pull/187945
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to