llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Simon Pilgrim (RKSimon)

<details>
<summary>Changes</summary>

Create an EvaluateBinOpExpr helper that each related group of elementwise binop
builtins can call with its own custom callback. This reduces duplication and
avoids excessive code bloat as more builtins are added.

The helper also handles builtins that have an elementwise (vector) LHS operand
and a scalar RHS operand.

This is similar to #155891, which did the same thing for the new bytecode constant evaluator.

---
Full diff: https://github.com/llvm/llvm-project/pull/157137.diff


1 Files Affected:

- (modified) clang/lib/AST/ExprConstant.cpp (+86-195) 


``````````diff
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 100e944f9b48c..0060e80e8e309 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -11623,6 +11623,38 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr 
*E) {
   if (!IsConstantEvaluatedBuiltinCall(E))
     return ExprEvaluatorBaseTy::VisitCallExpr(E);
 
+  auto EvaluateBinOpExpr =
+      [&](llvm::function_ref<APInt(const APSInt &, const APSInt &)> Fn) {
+        APValue SourceLHS, SourceRHS;
+        if (!EvaluateAsRValue(Info, E->getArg(0), SourceLHS) ||
+            !EvaluateAsRValue(Info, E->getArg(1), SourceRHS))
+          return false;
+
+        auto *DestTy = E->getType()->castAs<VectorType>();
+        QualType DestEltTy = DestTy->getElementType();
+        bool DestUnsigned = DestEltTy->isUnsignedIntegerOrEnumerationType();
+        unsigned SourceLen = SourceLHS.getVectorLength();
+        SmallVector<APValue, 4> ResultElements;
+        ResultElements.reserve(SourceLen);
+
+        if (SourceRHS.isInt()) {
+          const APSInt &RHS = SourceRHS.getInt();
+          for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
+            const APSInt &LHS = SourceLHS.getVectorElt(EltNum).getInt();
+            ResultElements.push_back(
+                APValue(APSInt(Fn(LHS, RHS), DestUnsigned)));
+          }
+        } else {
+          for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
+            const APSInt &LHS = SourceLHS.getVectorElt(EltNum).getInt();
+            const APSInt &RHS = SourceRHS.getVectorElt(EltNum).getInt();
+            ResultElements.push_back(
+                APValue(APSInt(Fn(LHS, RHS), DestUnsigned)));
+          }
+        }
+        return Success(APValue(ResultElements.data(), SourceLen), E);
+      };
+
   switch (E->getBuiltinCallee()) {
   default:
     return false;
@@ -11679,27 +11711,30 @@ bool VectorExprEvaluator::VisitCallExpr(const 
CallExpr *E) {
   }
 
   case Builtin::BI__builtin_elementwise_add_sat:
+    return EvaluateBinOpExpr([](const APSInt &LHS, const APSInt &RHS) {
+      return LHS.isSigned() ? LHS.sadd_sat(RHS) : LHS.uadd_sat(RHS);
+    });
+
   case Builtin::BI__builtin_elementwise_sub_sat:
+    return EvaluateBinOpExpr([](const APSInt &LHS, const APSInt &RHS) {
+      return LHS.isSigned() ? LHS.ssub_sat(RHS) : LHS.usub_sat(RHS);
+    });
+
   case clang::X86::BI__builtin_ia32_pmulhuw128:
   case clang::X86::BI__builtin_ia32_pmulhuw256:
   case clang::X86::BI__builtin_ia32_pmulhuw512:
+    return EvaluateBinOpExpr(llvm::APIntOps::mulhu);
+
   case clang::X86::BI__builtin_ia32_pmulhw128:
   case clang::X86::BI__builtin_ia32_pmulhw256:
   case clang::X86::BI__builtin_ia32_pmulhw512:
+    return EvaluateBinOpExpr(llvm::APIntOps::mulhs);
+
   case clang::X86::BI__builtin_ia32_psllv2di:
   case clang::X86::BI__builtin_ia32_psllv4di:
   case clang::X86::BI__builtin_ia32_psllv4si:
   case clang::X86::BI__builtin_ia32_psllv8si:
   case clang::X86::BI__builtin_ia32_psllv16si:
-  case clang::X86::BI__builtin_ia32_psrav4si:
-  case clang::X86::BI__builtin_ia32_psrav8si:
-  case clang::X86::BI__builtin_ia32_psrav16si:
-  case clang::X86::BI__builtin_ia32_psrlv2di:
-  case clang::X86::BI__builtin_ia32_psrlv4di:
-  case clang::X86::BI__builtin_ia32_psrlv4si:
-  case clang::X86::BI__builtin_ia32_psrlv8si:
-  case clang::X86::BI__builtin_ia32_psrlv16si:
-
   case clang::X86::BI__builtin_ia32_psllwi128:
   case clang::X86::BI__builtin_ia32_pslldi128:
   case clang::X86::BI__builtin_ia32_psllqi128:
@@ -11709,17 +11744,16 @@ bool VectorExprEvaluator::VisitCallExpr(const 
CallExpr *E) {
   case clang::X86::BI__builtin_ia32_psllwi512:
   case clang::X86::BI__builtin_ia32_pslldi512:
   case clang::X86::BI__builtin_ia32_psllqi512:
+    return EvaluateBinOpExpr([](const APSInt &LHS, const APSInt &RHS) {
+      if (RHS.uge(LHS.getBitWidth())) {
+        return APInt::getZero(LHS.getBitWidth());
+      }
+      return LHS.shl(RHS.getZExtValue());
+    });
 
-  case clang::X86::BI__builtin_ia32_psrlwi128:
-  case clang::X86::BI__builtin_ia32_psrldi128:
-  case clang::X86::BI__builtin_ia32_psrlqi128:
-  case clang::X86::BI__builtin_ia32_psrlwi256:
-  case clang::X86::BI__builtin_ia32_psrldi256:
-  case clang::X86::BI__builtin_ia32_psrlqi256:
-  case clang::X86::BI__builtin_ia32_psrlwi512:
-  case clang::X86::BI__builtin_ia32_psrldi512:
-  case clang::X86::BI__builtin_ia32_psrlqi512:
-
+  case clang::X86::BI__builtin_ia32_psrav4si:
+  case clang::X86::BI__builtin_ia32_psrav8si:
+  case clang::X86::BI__builtin_ia32_psrav16si:
   case clang::X86::BI__builtin_ia32_psrawi128:
   case clang::X86::BI__builtin_ia32_psradi128:
   case clang::X86::BI__builtin_ia32_psraqi128:
@@ -11728,145 +11762,35 @@ bool VectorExprEvaluator::VisitCallExpr(const 
CallExpr *E) {
   case clang::X86::BI__builtin_ia32_psraqi256:
   case clang::X86::BI__builtin_ia32_psrawi512:
   case clang::X86::BI__builtin_ia32_psradi512:
-  case clang::X86::BI__builtin_ia32_psraqi512: {
-
-    APValue SourceLHS, SourceRHS;
-    if (!EvaluateAsRValue(Info, E->getArg(0), SourceLHS) ||
-        !EvaluateAsRValue(Info, E->getArg(1), SourceRHS))
-      return false;
-
-    QualType DestEltTy = E->getType()->castAs<VectorType>()->getElementType();
-    bool DestUnsigned = DestEltTy->isUnsignedIntegerOrEnumerationType();
-    unsigned SourceLen = SourceLHS.getVectorLength();
-    SmallVector<APValue, 4> ResultElements;
-    ResultElements.reserve(SourceLen);
-
-    for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
-      APSInt LHS = SourceLHS.getVectorElt(EltNum).getInt();
-
-      if (SourceRHS.isInt()) {
-        const unsigned LaneBitWidth = LHS.getBitWidth();
-        const unsigned ShiftAmount = SourceRHS.getInt().getZExtValue();
-
-        switch (E->getBuiltinCallee()) {
-        case clang::X86::BI__builtin_ia32_psllwi128:
-        case clang::X86::BI__builtin_ia32_psllwi256:
-        case clang::X86::BI__builtin_ia32_psllwi512:
-        case clang::X86::BI__builtin_ia32_pslldi128:
-        case clang::X86::BI__builtin_ia32_pslldi256:
-        case clang::X86::BI__builtin_ia32_pslldi512:
-        case clang::X86::BI__builtin_ia32_psllqi128:
-        case clang::X86::BI__builtin_ia32_psllqi256:
-        case clang::X86::BI__builtin_ia32_psllqi512:
-          if (ShiftAmount >= LaneBitWidth) {
-            ResultElements.push_back(
-                APValue(APSInt(APInt::getZero(LaneBitWidth), DestUnsigned)));
-          } else {
-            ResultElements.push_back(
-                APValue(APSInt(LHS.shl(ShiftAmount), DestUnsigned)));
-          }
-          break;
-        case clang::X86::BI__builtin_ia32_psrlwi128:
-        case clang::X86::BI__builtin_ia32_psrlwi256:
-        case clang::X86::BI__builtin_ia32_psrlwi512:
-        case clang::X86::BI__builtin_ia32_psrldi128:
-        case clang::X86::BI__builtin_ia32_psrldi256:
-        case clang::X86::BI__builtin_ia32_psrldi512:
-        case clang::X86::BI__builtin_ia32_psrlqi128:
-        case clang::X86::BI__builtin_ia32_psrlqi256:
-        case clang::X86::BI__builtin_ia32_psrlqi512:
-          if (ShiftAmount >= LaneBitWidth) {
-            ResultElements.push_back(
-                APValue(APSInt(APInt::getZero(LaneBitWidth), DestUnsigned)));
-          } else {
-            ResultElements.push_back(
-                APValue(APSInt(LHS.lshr(ShiftAmount), DestUnsigned)));
-          }
-          break;
-        case clang::X86::BI__builtin_ia32_psrawi128:
-        case clang::X86::BI__builtin_ia32_psrawi256:
-        case clang::X86::BI__builtin_ia32_psrawi512:
-        case clang::X86::BI__builtin_ia32_psradi128:
-        case clang::X86::BI__builtin_ia32_psradi256:
-        case clang::X86::BI__builtin_ia32_psradi512:
-        case clang::X86::BI__builtin_ia32_psraqi128:
-        case clang::X86::BI__builtin_ia32_psraqi256:
-        case clang::X86::BI__builtin_ia32_psraqi512:
-          ResultElements.push_back(
-              APValue(APSInt(LHS.ashr(std::min(ShiftAmount, LaneBitWidth - 1)),
-                             DestUnsigned)));
-          break;
-        default:
-          llvm_unreachable("Unexpected builtin callee");
-        }
-        continue;
+  case clang::X86::BI__builtin_ia32_psraqi512:
+    return EvaluateBinOpExpr([](const APSInt &LHS, const APSInt &RHS) {
+      if (RHS.uge(LHS.getBitWidth())) {
+        return LHS.ashr(LHS.getBitWidth() - 1);
       }
-      APSInt RHS = SourceRHS.getVectorElt(EltNum).getInt();
-      switch (E->getBuiltinCallee()) {
-      case Builtin::BI__builtin_elementwise_add_sat:
-        ResultElements.push_back(APValue(
-            APSInt(LHS.isSigned() ? LHS.sadd_sat(RHS) : LHS.uadd_sat(RHS),
-                   DestUnsigned)));
-        break;
-      case Builtin::BI__builtin_elementwise_sub_sat:
-        ResultElements.push_back(APValue(
-            APSInt(LHS.isSigned() ? LHS.ssub_sat(RHS) : LHS.usub_sat(RHS),
-                   DestUnsigned)));
-        break;
-      case clang::X86::BI__builtin_ia32_pmulhuw128:
-      case clang::X86::BI__builtin_ia32_pmulhuw256:
-      case clang::X86::BI__builtin_ia32_pmulhuw512:
-        ResultElements.push_back(APValue(APSInt(llvm::APIntOps::mulhu(LHS, 
RHS),
-                                                /*isUnsigned=*/true)));
-        break;
-      case clang::X86::BI__builtin_ia32_pmulhw128:
-      case clang::X86::BI__builtin_ia32_pmulhw256:
-      case clang::X86::BI__builtin_ia32_pmulhw512:
-        ResultElements.push_back(APValue(APSInt(llvm::APIntOps::mulhs(LHS, 
RHS),
-                                                /*isUnsigned=*/false)));
-        break;
-      case clang::X86::BI__builtin_ia32_psllv2di:
-      case clang::X86::BI__builtin_ia32_psllv4di:
-      case clang::X86::BI__builtin_ia32_psllv4si:
-      case clang::X86::BI__builtin_ia32_psllv8si:
-      case clang::X86::BI__builtin_ia32_psllv16si:
-        if (RHS.uge(RHS.getBitWidth())) {
-          ResultElements.push_back(
-              APValue(APSInt(APInt::getZero(RHS.getBitWidth()), 
DestUnsigned)));
-          break;
-        }
-        ResultElements.push_back(
-            APValue(APSInt(LHS.shl(RHS.getZExtValue()), DestUnsigned)));
-        break;
-      case clang::X86::BI__builtin_ia32_psrav4si:
-      case clang::X86::BI__builtin_ia32_psrav8si:
-      case clang::X86::BI__builtin_ia32_psrav16si:
-        if (RHS.uge(RHS.getBitWidth())) {
-          ResultElements.push_back(
-              APValue(APSInt(LHS.ashr(RHS.getBitWidth() - 1), DestUnsigned)));
-          break;
-        }
-        ResultElements.push_back(
-            APValue(APSInt(LHS.ashr(RHS.getZExtValue()), DestUnsigned)));
-        break;
-      case clang::X86::BI__builtin_ia32_psrlv2di:
-      case clang::X86::BI__builtin_ia32_psrlv4di:
-      case clang::X86::BI__builtin_ia32_psrlv4si:
-      case clang::X86::BI__builtin_ia32_psrlv8si:
-      case clang::X86::BI__builtin_ia32_psrlv16si:
-        if (RHS.uge(RHS.getBitWidth())) {
-          ResultElements.push_back(
-              APValue(APSInt(APInt::getZero(RHS.getBitWidth()), 
DestUnsigned)));
-          break;
-        }
-        ResultElements.push_back(
-            APValue(APSInt(LHS.lshr(RHS.getZExtValue()), DestUnsigned)));
-        break;
+      return LHS.ashr(RHS.getZExtValue());
+    });
+
+  case clang::X86::BI__builtin_ia32_psrlv2di:
+  case clang::X86::BI__builtin_ia32_psrlv4di:
+  case clang::X86::BI__builtin_ia32_psrlv4si:
+  case clang::X86::BI__builtin_ia32_psrlv8si:
+  case clang::X86::BI__builtin_ia32_psrlv16si:
+  case clang::X86::BI__builtin_ia32_psrlwi128:
+  case clang::X86::BI__builtin_ia32_psrldi128:
+  case clang::X86::BI__builtin_ia32_psrlqi128:
+  case clang::X86::BI__builtin_ia32_psrlwi256:
+  case clang::X86::BI__builtin_ia32_psrldi256:
+  case clang::X86::BI__builtin_ia32_psrlqi256:
+  case clang::X86::BI__builtin_ia32_psrlwi512:
+  case clang::X86::BI__builtin_ia32_psrldi512:
+  case clang::X86::BI__builtin_ia32_psrlqi512:
+    return EvaluateBinOpExpr([](const APSInt &LHS, const APSInt &RHS) {
+      if (RHS.uge(LHS.getBitWidth())) {
+        return APInt::getZero(LHS.getBitWidth());
       }
-    }
+      return LHS.lshr(RHS.getZExtValue());
+    });
 
-    return Success(APValue(ResultElements.data(), ResultElements.size()), E);
-  }
   case clang::X86::BI__builtin_ia32_pmuldq128:
   case clang::X86::BI__builtin_ia32_pmuldq256:
   case clang::X86::BI__builtin_ia32_pmuldq512:
@@ -11904,6 +11828,7 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr 
*E) {
 
     return Success(APValue(ResultElements.data(), ResultElements.size()), E);
   }
+
   case clang::X86::BI__builtin_ia32_vprotbi:
   case clang::X86::BI__builtin_ia32_vprotdi:
   case clang::X86::BI__builtin_ia32_vprotqi:
@@ -11913,53 +11838,19 @@ bool VectorExprEvaluator::VisitCallExpr(const 
CallExpr *E) {
   case clang::X86::BI__builtin_ia32_prold512:
   case clang::X86::BI__builtin_ia32_prolq128:
   case clang::X86::BI__builtin_ia32_prolq256:
-  case clang::X86::BI__builtin_ia32_prolq512: {
-    APValue SourceLHS, SourceRHS;
-    if (!EvaluateAsRValue(Info, E->getArg(0), SourceLHS) ||
-        !EvaluateAsRValue(Info, E->getArg(1), SourceRHS))
-      return false;
-
-    QualType DestEltTy = E->getType()->castAs<VectorType>()->getElementType();
-    bool DestUnsigned = DestEltTy->isUnsignedIntegerOrEnumerationType();
-    unsigned SourceLen = SourceLHS.getVectorLength();
-    SmallVector<APValue, 4> ResultElements;
-    ResultElements.reserve(SourceLen);
-
-    APSInt RHS = SourceRHS.getInt();
-
-    for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
-      const APSInt &LHS = SourceLHS.getVectorElt(EltNum).getInt();
-      ResultElements.push_back(APValue(APSInt(LHS.rotl(RHS), DestUnsigned)));
-    }
+  case clang::X86::BI__builtin_ia32_prolq512:
+    return EvaluateBinOpExpr(
+        [](const APSInt &LHS, const APSInt &RHS) { return LHS.rotl(RHS); });
 
-    return Success(APValue(ResultElements.data(), ResultElements.size()), E);
-  }
   case clang::X86::BI__builtin_ia32_prord128:
   case clang::X86::BI__builtin_ia32_prord256:
   case clang::X86::BI__builtin_ia32_prord512:
   case clang::X86::BI__builtin_ia32_prorq128:
   case clang::X86::BI__builtin_ia32_prorq256:
-  case clang::X86::BI__builtin_ia32_prorq512: {
-    APValue SourceLHS, SourceRHS;
-    if (!EvaluateAsRValue(Info, E->getArg(0), SourceLHS) ||
-        !EvaluateAsRValue(Info, E->getArg(1), SourceRHS))
-      return false;
+  case clang::X86::BI__builtin_ia32_prorq512:
+    return EvaluateBinOpExpr(
+        [](const APSInt &LHS, const APSInt &RHS) { return LHS.rotr(RHS); });
 
-    QualType DestEltTy = E->getType()->castAs<VectorType>()->getElementType();
-    bool DestUnsigned = DestEltTy->isUnsignedIntegerOrEnumerationType();
-    unsigned SourceLen = SourceLHS.getVectorLength();
-    SmallVector<APValue, 4> ResultElements;
-    ResultElements.reserve(SourceLen);
-
-    APSInt RHS = SourceRHS.getInt();
-
-    for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
-      const APSInt &LHS = SourceLHS.getVectorElt(EltNum).getInt();
-      ResultElements.push_back(APValue(APSInt(LHS.rotr(RHS), DestUnsigned)));
-    }
-
-    return Success(APValue(ResultElements.data(), ResultElements.size()), E);
-  }
   case Builtin::BI__builtin_elementwise_max:
   case Builtin::BI__builtin_elementwise_min: {
     APValue SourceLHS, SourceRHS;

``````````

</details>


https://github.com/llvm/llvm-project/pull/157137
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to