https://github.com/philnik777 created https://github.com/llvm/llvm-project/pull/133587

Fixes #132672



>From 1c0a267544c43235d0004edb9beb127a124abd7a Mon Sep 17 00:00:00 2001
From: Nikolas Klauser <nikolasklau...@berlin.de>
Date: Sat, 29 Mar 2025 15:21:10 +0100
Subject: [PATCH] [Clang] Make enums trivially equality comparable

---
 clang/lib/Sema/SemaExprCXX.cpp     | 83 +++++++++++++++++-------------
 clang/test/SemaCXX/type-traits.cpp | 12 +++++
 2 files changed, 60 insertions(+), 35 deletions(-)

diff --git a/clang/lib/Sema/SemaExprCXX.cpp b/clang/lib/Sema/SemaExprCXX.cpp
index 46895db4a0756..d4a9900d3fa8a 100644
--- a/clang/lib/Sema/SemaExprCXX.cpp
+++ b/clang/lib/Sema/SemaExprCXX.cpp
@@ -5174,6 +5174,48 @@ static bool HasNoThrowOperator(const RecordType *RT, OverloadedOperatorKind Op,
   return false;
 }
 
+/// Returns true if `obj == obj`, for a const lvalue `obj` of the type declared
+/// by \p Decl, resolves either to a built-in comparison (accepted for enums
+/// only) or to a defaulted operator== whose parameter has that same type.
+/// \p Decl is expected to be a CXXRecordDecl or an EnumDecl.
+static bool EqualityComparisonIsDefaulted(Sema &S, const TypeDecl *Decl,
+                                          SourceLocation KeyLoc) {
+  EnterExpressionEvaluationContext UnevaluatedContext(
+      S, Sema::ExpressionEvaluationContext::Unevaluated);
+  Sema::SFINAETrap SFINAE(S, /*AccessCheckingSFINAE=*/true);
+  Sema::ContextRAII TUContext(S, S.Context.getTranslationUnitDecl());
+
+  // const T& obj;
+  OpaqueValueExpr Operand(
+      KeyLoc, Decl->getTypeForDecl()->getCanonicalTypeUnqualified().withConst(),
+      ExprValueKind::VK_LValue);
+  UnresolvedSet<16> Functions;
+  // obj == obj;
+  S.LookupBinOp(S.TUScope, {}, BinaryOperatorKind::BO_EQ, Functions);
+
+  auto Result = S.CreateOverloadedBinOp(KeyLoc, BinaryOperatorKind::BO_EQ,
+                                        Functions, &Operand, &Operand);
+  if (Result.isInvalid() || SFINAE.hasErrorOccurred())
+    return false;
+
+  const auto *OpCall = dyn_cast<CXXOperatorCallExpr>(Result.get());
+  // Anything other than an operator call (normally a built-in comparison) is
+  // only accepted for enums.
+  if (!OpCall)
+    return isa<EnumDecl>(Decl);
+  const auto *Callee = OpCall->getDirectCallee();
+  if (!Callee || !Callee->isDefaulted())
+    return false;
+  QualType ParamT = Callee->getParamDecl(0)->getType();
+  // Taking the parameter by value copies the operand, so a record must be
+  // trivially copyable. Decl may be an EnumDecl, so RD can be null here.
+  if (const auto *RD = dyn_cast<CXXRecordDecl>(Decl);
+      RD && !ParamT->isReferenceType() && !RD->isTriviallyCopyable())
+    return false;
+  return ParamT.getNonReferenceType()->getUnqualifiedDesugaredType() ==
+         Decl->getTypeForDecl();
+}
+
 static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
                                                      const CXXRecordDecl *Decl,
                                                      SourceLocation KeyLoc) {
@@ -5182,39 +5219,8 @@ static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
   if (Decl->isLambda())
     return Decl->isCapturelessLambda();
 
-  {
-    EnterExpressionEvaluationContext UnevaluatedContext(
-        S, Sema::ExpressionEvaluationContext::Unevaluated);
-    Sema::SFINAETrap SFINAE(S, /*AccessCheckingSFINAE=*/true);
-    Sema::ContextRAII TUContext(S, S.Context.getTranslationUnitDecl());
-
-    // const ClassT& obj;
-    OpaqueValueExpr Operand(
-        KeyLoc,
-        Decl->getTypeForDecl()->getCanonicalTypeUnqualified().withConst(),
-        ExprValueKind::VK_LValue);
-    UnresolvedSet<16> Functions;
-    // obj == obj;
-    S.LookupBinOp(S.TUScope, {}, BinaryOperatorKind::BO_EQ, Functions);
-
-    auto Result = S.CreateOverloadedBinOp(KeyLoc, BinaryOperatorKind::BO_EQ,
-                                          Functions, &Operand, &Operand);
-    if (Result.isInvalid() || SFINAE.hasErrorOccurred())
-      return false;
-
-    const auto *CallExpr = dyn_cast<CXXOperatorCallExpr>(Result.get());
-    if (!CallExpr)
-      return false;
-    const auto *Callee = CallExpr->getDirectCallee();
-    auto ParamT = Callee->getParamDecl(0)->getType();
-    if (!Callee->isDefaulted())
-      return false;
-    if (!ParamT->isReferenceType() && !Decl->isTriviallyCopyable())
-      return false;
-    if (ParamT.getNonReferenceType()->getUnqualifiedDesugaredType() !=
-        Decl->getTypeForDecl())
-      return false;
-  }
+  if (!EqualityComparisonIsDefaulted(S, Decl, KeyLoc))
+    return false;
 
   return llvm::all_of(Decl->bases(),
                       [&](const CXXBaseSpecifier &BS) {
@@ -5229,7 +5235,12 @@ static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
              Type = Type->getBaseElementTypeUnsafe()
                         ->getCanonicalTypeUnqualified();
 
-           if (Type->isReferenceType() || Type->isEnumeralType())
+           // Enum members qualify only if `==` on them is not resolved to a
+           // user-provided operator== (see EqualityComparisonIsDefaulted).
+           if (Type->isReferenceType() ||
+               (Type->isEnumeralType() &&
+                !EqualityComparisonIsDefaulted(
+                    S, cast<EnumDecl>(Type->getAsTagDecl()), KeyLoc)))
              return false;
            if (const auto *RD = Type->getAsCXXRecordDecl())
              return HasNonDeletedDefaultedEqualityComparison(S, RD, KeyLoc);
@@ -5240,9 +5249,17 @@ static bool HasNonDeletedDefaultedEqualityComparison(Sema &S,
 static bool isTriviallyEqualityComparableType(Sema &S, QualType Type, SourceLocation KeyLoc) {
   QualType CanonicalType = Type.getCanonicalType();
   if (CanonicalType->isIncompleteType() || CanonicalType->isDependentType() ||
-      CanonicalType->isEnumeralType() || CanonicalType->isArrayType())
+      CanonicalType->isArrayType())
     return false;
 
+  // Enums qualify only if `==` on them is not a user-provided operator==.
+  // Don't return early on success: enums fall through to the remaining checks
+  // below, like every other type.
+  if (CanonicalType->isEnumeralType() &&
+      !EqualityComparisonIsDefaulted(
+          S, cast<EnumDecl>(CanonicalType->getAsTagDecl()), KeyLoc))
+    return false;
+
   if (const auto *RD = CanonicalType->getAsCXXRecordDecl()) {
     if (!HasNonDeletedDefaultedEqualityComparison(S, RD, KeyLoc))
       return false;
diff --git a/clang/test/SemaCXX/type-traits.cpp b/clang/test/SemaCXX/type-traits.cpp
index b130024503101..657d5bcf07343 100644
--- a/clang/test/SemaCXX/type-traits.cpp
+++ b/clang/test/SemaCXX/type-traits.cpp
@@ -3873,6 +3873,23 @@ static_assert(!__is_trivially_equality_comparable(NonTriviallyEqualityComparable
 
 #if __cplusplus >= 202002L
 
+enum TriviallyEqualityComparableEnum {
+  x, y
+};
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableEnum));
+
+enum class TriviallyEqualityComparableScopedEnum {
+  x, y
+};
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableScopedEnum));
+
+// A user-provided operator== makes the enum's comparison non-trivial.
+enum class NonTriviallyEqualityComparableEnum {
+  x, y
+};
+bool operator==(NonTriviallyEqualityComparableEnum, NonTriviallyEqualityComparableEnum);
+static_assert(!__is_trivially_equality_comparable(NonTriviallyEqualityComparableEnum));
+
 struct TriviallyEqualityComparable {
   int i;
   int j;
@@ -3891,6 +3908,20 @@ struct TriviallyEqualityComparableContainsArray {
 };
 static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContainsArray));
 
+struct TriviallyEqualityComparableContainsEnum {
+  TriviallyEqualityComparableEnum e;
+
+  bool operator==(const TriviallyEqualityComparableContainsEnum&) const = default;
+};
+static_assert(__is_trivially_equality_comparable(TriviallyEqualityComparableContainsEnum));
+
+struct NonTriviallyEqualityComparableContainsEnum {
+  NonTriviallyEqualityComparableEnum e;
+
+  bool operator==(const NonTriviallyEqualityComparableContainsEnum&) const = default;
+};
+static_assert(!__is_trivially_equality_comparable(NonTriviallyEqualityComparableContainsEnum));
+
 struct TriviallyEqualityComparableContainsMultiDimensionArray {
   int a[4][4];
 

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to