njames93 updated this revision to Diff 303502.
njames93 added a comment.

Addressed comments.
Now using a MapVector in prepare to record every enum constant and whether the
switch already covers it; apply then just loops over that map.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D90555/new/

https://reviews.llvm.org/D90555

Files:
  clang-tools-extra/clangd/refactor/tweaks/PopulateSwitch.cpp
  clang-tools-extra/clangd/unittests/TweakTests.cpp
  llvm/include/llvm/ADT/DenseMapInfo.h

Index: llvm/include/llvm/ADT/DenseMapInfo.h
===================================================================
--- llvm/include/llvm/ADT/DenseMapInfo.h
+++ llvm/include/llvm/ADT/DenseMapInfo.h
@@ -14,6 +14,7 @@
 #define LLVM_ADT_DENSEMAPINFO_H
 
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/StringRef.h"
@@ -371,6 +372,26 @@
   }
 };
 
+/// Provide DenseMapInfo for APSInt, using the DenseMapInfo for APInt.
+template <> struct DenseMapInfo<APSInt> {
+  static inline APSInt getEmptyKey() {
+    return APSInt(DenseMapInfo<APInt>::getEmptyKey());
+  }
+
+  static inline APSInt getTombstoneKey() {
+    return APSInt(DenseMapInfo<APInt>::getTombstoneKey());
+  }
+
+  static unsigned getHashValue(const APSInt &Key) {
+    return static_cast<unsigned>(hash_value(Key));
+  }
+
+  static bool isEqual(const APSInt &LHS, const APSInt &RHS) {
+    return LHS.getBitWidth() == RHS.getBitWidth() &&
+           LHS.isUnsigned() == RHS.isUnsigned() && LHS == RHS;
+  }
+};
+
 } // end namespace llvm
 
 #endif // LLVM_ADT_DENSEMAPINFO_H
Index: clang-tools-extra/clangd/unittests/TweakTests.cpp
===================================================================
--- clang-tools-extra/clangd/unittests/TweakTests.cpp
+++ clang-tools-extra/clangd/unittests/TweakTests.cpp
@@ -2980,6 +2980,18 @@
             void function() { switch (ns::A) {case ns::A:break;} }
           )"",
       },
+      {
+          // Several enumerator names sharing the same constant value
+          Function,
+          R""(enum Enum {A,B,b=B}; ^switch (A) {})"",
+          R""(enum Enum {A,B,b=B}; switch (A) {case A:case B:break;})"",
+      },
+      {
+          // Enumerators sharing a value, every value already in the switch
+          Function,
+          R""(enum Enum {A,B,b=B}; ^switch (A) {case A:case B:break;})"",
+          "unavailable",
+      },
   };
 
   for (const auto &Case : Cases) {
Index: clang-tools-extra/clangd/refactor/tweaks/PopulateSwitch.cpp
===================================================================
--- clang-tools-extra/clangd/refactor/tweaks/PopulateSwitch.cpp
+++ clang-tools-extra/clangd/refactor/tweaks/PopulateSwitch.cpp
@@ -40,6 +40,8 @@
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Basic/SourceManager.h"
 #include "clang/Tooling/Core/Replacement.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
 #include <cassert>
 #include <string>
@@ -62,6 +64,10 @@
   const CompoundStmt *Body = nullptr;
   const EnumType *EnumT = nullptr;
   const EnumDecl *EnumD = nullptr;
+  // Maps the Enum values to the EnumConstantDecl and a bool signifying if it's
+  // covered in the switch.
+  llvm::MapVector<llvm::APSInt, std::pair<const EnumConstantDecl *, bool>>
+      EnumConstants;
 };
 
 REGISTER_TWEAK(PopulateSwitch)
@@ -112,21 +118,33 @@
   if (!EnumD)
     return false;
 
-  // We trigger if there are fewer cases than enum values (and no case covers
-  // multiple values). This guarantees we'll have at least one case to insert.
-  // We don't yet determine what the cases are, as that means evaluating
-  // expressions.
-  auto I = EnumD->enumerator_begin();
-  auto E = EnumD->enumerator_end();
+  // We trigger if there are any values in the enum that aren't covered by the
+  // switch.
 
-  for (const SwitchCase *CaseList = Switch->getSwitchCaseList();
-       CaseList && I != E; CaseList = CaseList->getNextSwitchCase(), I++) {
+  ASTContext &Ctx = Sel.AST->getASTContext();
+
+  unsigned EnumIntWidth = Ctx.getIntWidth(QualType(EnumT, 0));
+  bool EnumIsSigned = EnumT->isSignedIntegerOrEnumerationType();
+
+  EnumConstants.clear();
+
+  for (auto *EnumConstant : EnumD->enumerators()) {
+    llvm::APSInt Val = EnumConstant->getInitVal();
+    Val = Val.extOrTrunc(EnumIntWidth);
+    Val.setIsSigned(EnumIsSigned);
+    EnumConstants.insert(
+        std::make_pair(Val, std::make_pair(EnumConstant, false)));
+  }
+
+  for (const SwitchCase *CaseList = Switch->getSwitchCaseList(); CaseList;
+       CaseList = CaseList->getNextSwitchCase()) {
     // Default likely intends to cover cases we'd insert.
     if (isa<DefaultStmt>(CaseList))
       return false;
 
     const CaseStmt *CS = cast<CaseStmt>(CaseList);
-    // Case statement covers multiple values, so just counting doesn't work.
+
+    // GNU range cases are rare; we don't support them.
     if (CS->caseStmtIsGNURange())
       return false;
 
@@ -135,48 +153,45 @@
     const ConstantExpr *CE = dyn_cast<ConstantExpr>(CS->getLHS());
     if (!CE || CE->isValueDependent())
       return false;
+
+    // Unsure if this case could ever come up, but this check prevents hitting
+    // an unreachable in getResultAsAPSInt.
+    if (CE->getResultStorageKind() == ConstantExpr::RSK_None)
+      return false;
+    llvm::APSInt Val = CE->getResultAsAPSInt();
+    Val = Val.extOrTrunc(EnumIntWidth);
+    Val.setIsSigned(EnumIsSigned);
+    auto Iter = EnumConstants.find(Val);
+    if (Iter == EnumConstants.end())
+      return false;
+    bool &IsCovered = Iter->second.second;
+    // A case covered twice in a switch is a compile error, so just bail out if
+    // we encounter it.
+    if (IsCovered)
+      return false;
+    IsCovered = true;
   }
 
-  // Only suggest tweak if we have more enumerators than cases.
-  return I != E;
+  return !llvm::all_of(EnumConstants,
+                       [](auto &Pair) { return Pair.second.second; });
 }
 
 Expected<Tweak::Effect> PopulateSwitch::apply(const Selection &Sel) {
   ASTContext &Ctx = Sel.AST->getASTContext();
 
-  // Get the enum's integer width and signedness, for adjusting case literals.
-  unsigned EnumIntWidth = Ctx.getIntWidth(QualType(EnumT, 0));
-  bool EnumIsSigned = EnumT->isSignedIntegerOrEnumerationType();
-
-  llvm::SmallSet<llvm::APSInt, 32> ExistingEnumerators;
-  for (const SwitchCase *CaseList = Switch->getSwitchCaseList(); CaseList;
-       CaseList = CaseList->getNextSwitchCase()) {
-    const CaseStmt *CS = cast<CaseStmt>(CaseList);
-    assert(!CS->caseStmtIsGNURange());
-    const ConstantExpr *CE = cast<ConstantExpr>(CS->getLHS());
-    assert(!CE->isValueDependent());
-    llvm::APSInt Val = CE->getResultAsAPSInt();
-    Val = Val.extOrTrunc(EnumIntWidth);
-    Val.setIsSigned(EnumIsSigned);
-    ExistingEnumerators.insert(Val);
-  }
-
   SourceLocation Loc = Body->getRBracLoc();
   ASTContext &DeclASTCtx = DeclCtx->getParentASTContext();
 
-  std::string Text;
-  for (EnumConstantDecl *Enumerator : EnumD->enumerators()) {
-    if (ExistingEnumerators.contains(Enumerator->getInitVal()))
+  llvm::SmallString<256> Text;
+  for (auto &EnumConstant : EnumConstants) {
+    // Skip any enum constants already covered by the switch.
+    if (EnumConstant.second.second)
       continue;
 
-    Text += "case ";
-    Text += getQualification(DeclASTCtx, DeclCtx, Loc, EnumD);
-    if (EnumD->isScoped()) {
-      Text += EnumD->getName();
-      Text += "::";
-    }
-    Text += Enumerator->getName();
-    Text += ":";
+    Text.append({"case ", getQualification(DeclASTCtx, DeclCtx, Loc, EnumD)});
+    if (EnumD->isScoped())
+      Text.append({EnumD->getName(), "::"});
+    Text.append({EnumConstant.second.first->getName(), ":"});
   }
 
   assert(!Text.empty() && "No enumerators to insert!");
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to