================
@@ -0,0 +1,296 @@
+//===--- OverridePureVirtuals.cpp -------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "refactor/Tweak.h"
+
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/CXXInheritance.h"
+#include "clang/AST/DeclCXX.h"
+#include "clang/AST/Type.h"
+#include "clang/AST/TypeLoc.h"
+#include "clang/Basic/LLVM.h"
+#include "clang/Basic/SourceLocation.h"
+#include "clang/Tooling/Core/Replacement.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <iterator>
+#include <string>
+
+namespace clang {
+namespace clangd {
+namespace {
+
+class OverridePureVirtuals : public Tweak {
+public:
+  // Defined by the REGISTER_TWEAK macro below.
+  const char *id() const final;
+  bool prepare(const Selection &Sel) override;
+  Expected<Effect> apply(const Selection &Sel) override;
+
+  std::string title() const override {
+    return "Override pure virtual methods";
+  }
+
+  llvm::StringLiteral kind() const override {
+    return CodeAction::QUICKFIX_KIND;
+  }
+
+private:
+  // Gathers the information needed before applying the tweak: fills
+  // MissingMethodsByAccess and AccessSpecifierLocations from CurrentDeclDef.
+  void collectMissingPureVirtuals(const Selection &Sel);
+
+  // Definition of the class being modified; set by prepare().
+  const CXXRecordDecl *CurrentDeclDef{};
+
+  // Pure virtual methods that still need an override, keyed by the access
+  // specifier they were originally declared under.
+  llvm::MapVector<AccessSpecifier, llvm::SmallVector<const CXXMethodDecl *>>
+      MissingMethodsByAccess;
+
+  // Colon location of every access specifier already written in
+  // CurrentDeclDef.
+  llvm::MapVector<AccessSpecifier, SourceLocation> AccessSpecifierLocations;
+};
+
+REGISTER_TWEAK(OverridePureVirtuals)
+
+// Returns the pure virtual methods that CurrentDeclDef inherits without an
+// implementation. These are the pure virtuals, declared anywhere in its base
+// class hierarchy, that are still the final overrider of some base
+// subobject. A pure virtual that an intermediate base already implements is
+// NOT returned. Each method appears at most once (deduplicated by canonical
+// declaration), in base-traversal order. Returns an empty vector when
+// CurrentDeclDef is null or has no definition.
+llvm::SmallVector<const clang::CXXMethodDecl *>
+getAllUniquePureVirtualsFromBaseHierarchy(
+    const clang::CXXRecordDecl *CurrentDeclDef) {
+  llvm::SmallVector<const clang::CXXMethodDecl *> AllPureVirtualsInHierarchy;
+  llvm::DenseSet<const clang::CXXMethodDecl *> CanonicalPureVirtualsSeen;
+
+  if (!CurrentDeclDef || !CurrentDeclDef->getDefinition())
+    return AllPureVirtualsInHierarchy;
+
+  const clang::CXXRecordDecl *Def = CurrentDeclDef->getDefinition();
+
+  // C++ [class.abstract]p4: a class is abstract if it inherits a pure virtual
+  // whose final overrider is pure. Only such methods still need overriding;
+  // collecting every pure virtual declared in any base would also report
+  // ones an intermediate base has already implemented.
+  clang::CXXFinalOverriderMap FinalOverriders;
+  Def->getFinalOverriders(FinalOverriders);
+  llvm::DenseSet<const clang::CXXMethodDecl *> CanonicalPureFinalOverriders;
+  for (const auto &OverriddenAndSubobjects : FinalOverriders) {
+    for (const auto &SubobjectAndOverriders : OverriddenAndSubobjects.second) {
+      for (const clang::UniqueVirtualMethod &Overrider :
+           SubobjectAndOverriders.second)
+        if (Overrider.Method->isPureVirtual())
+          CanonicalPureFinalOverriders.insert(
+              Overrider.Method->getCanonicalDecl());
+    }
+  }
+
+  // Walk the bases (rather than the overrider map) so the result keeps the
+  // base-traversal order callers already see.
+  Def->forallBases([&](const clang::CXXRecordDecl *BaseDefinition) {
+    for (const clang::CXXMethodDecl *Method : BaseDefinition->methods()) {
+      const clang::CXXMethodDecl *Canonical = Method->getCanonicalDecl();
+      if (Method->isPureVirtual() &&
+          CanonicalPureFinalOverriders.contains(Canonical) &&
+          CanonicalPureVirtualsSeen.insert(Canonical).second)
+        AllPureVirtualsInHierarchy.emplace_back(Method);
+    }
+    // Continue iterating through all bases.
+    return true;
+  });
+
+  return AllPureVirtualsInHierarchy;
+}
+
+// Returns the canonical declarations of every base-class method that some
+// method declared in class D overrides, directly or transitively. For
+// example, if D::f overrides B::f, which in turn overrides A::f, both B::f
+// and A::f are returned. Whether D's own method is pure is not checked: a
+// base method that D already redeclares needs no generated override.
+llvm::SetVector<const CXXMethodDecl *>
+getImplementedOrOverriddenCanonicals(const CXXRecordDecl *D) {
+  llvm::SetVector<const CXXMethodDecl *> ImplementedSet;
+  llvm::SmallVector<const CXXMethodDecl *> Worklist;
+  for (const CXXMethodDecl *M : D->methods())
+    for (const CXXMethodDecl *OverriddenM : M->overridden_methods())
+      Worklist.push_back(OverriddenM);
+
+  // Follow override chains upwards so methods overridden only indirectly are
+  // recorded as well. SetVector::insert returns false for an already-seen
+  // method, which stops re-walking chains reachable through several paths.
+  while (!Worklist.empty()) {
+    const CXXMethodDecl *Canonical =
+        Worklist.pop_back_val()->getCanonicalDecl();
+    if (!ImplementedSet.insert(Canonical))
+      continue;
+    for (const CXXMethodDecl *Next : Canonical->overridden_methods())
+      Worklist.push_back(Next);
+  }
+  return ImplementedSet;
+}
+
+// Maps each access specifier written explicitly inside D (`public:`,
+// `protected:`, `private:`) to the source location of its colon. If the same
+// specifier is written more than once, the location of the last occurrence
+// is kept; map order follows the first occurrence.
+llvm::MapVector<AccessSpecifier, SourceLocation>
+getSpecifierLocations(const CXXRecordDecl *D) {
+  llvm::MapVector<AccessSpecifier, SourceLocation> ColonByAccess;
+  for (const Decl *Member : D->decls()) {
+    const auto *Spec = llvm::dyn_cast<AccessSpecDecl>(Member);
+    if (!Spec)
+      continue;
+    ColonByAccess[Spec->getAccess()] = Spec->getColonLoc();
+  }
+  return ColonByAccess;
+}
+
+// Returns true if at least one *direct* base of CurrentDecl has a definition
+// that is abstract. Returns false if CurrentDecl is null or has no
+// definition. Looking at direct bases is sufficient: a base that leaves an
+// inherited pure virtual unimplemented is itself abstract, so any pure
+// virtual inherited from further up shows up as an abstract direct base.
+bool hasAbstractBaseAncestor(const clang::CXXRecordDecl *CurrentDecl) {
+  if (!CurrentDecl || !CurrentDecl->getDefinition())
+    return false;
+
+  // Take the base specifier by reference to avoid copying it per base.
+  return llvm::any_of(
+      CurrentDecl->getDefinition()->bases(),
+      [](const CXXBaseSpecifier &BaseSpec) {
+        const auto *D = BaseSpec.getType()->getAsCXXRecordDecl();
+        const auto *Def = D ? D->getDefinition() : nullptr;
+        return Def && Def->isAbstract();
+      });
+}
+
+// The tweak is available only when the selection is a class with a
+// definition, that class is abstract, and at least one of its direct bases
+// is abstract. On success CurrentDeclDef points at the class definition.
+bool OverridePureVirtuals::prepare(const Selection &Sel) {
+  const SelectionTree::Node *Selected = Sel.ASTSelection.commonAncestor();
+  if (Selected == nullptr)
+    return false;
+
+  CurrentDeclDef = Selected->ASTNode.get<CXXRecordDecl>();
+  if (CurrentDeclDef == nullptr)
+    return false;
+
+  const CXXRecordDecl *Definition = CurrentDeclDef->getDefinition();
+  if (Definition == nullptr)
+    return false;
+
+  // Everything after this point (including apply()) uses the definition.
+  CurrentDeclDef = Definition;
+  if (!CurrentDeclDef->isAbstract())
+    return false;
+  return hasAbstractBaseAncestor(CurrentDeclDef);
+}
+
+// Collects all pure virtual methods that are missing an override in
+// CurrentDeclDef, grouped by their original access specifier, and records
+// the colon location of each access specifier CurrentDeclDef already
+// contains. Does nothing if CurrentDeclDef is null. (Sel is currently
+// unused.)
+void OverridePureVirtuals::collectMissingPureVirtuals(const Selection &Sel) {
+  if (!CurrentDeclDef)
+    return;
+
+  AccessSpecifierLocations = getSpecifierLocations(CurrentDeclDef);
+  MissingMethodsByAccess.clear();
+
+  // Get all unique pure virtual methods from the entire base class hierarchy.
+  llvm::SmallVector<const CXXMethodDecl *> AllPureVirtualsInHierarchy =
+      getAllUniquePureVirtualsFromBaseHierarchy(CurrentDeclDef);
+
+  // Canonical declarations of methods already implemented or overridden in
+  // CurrentDeclDef.
+  const auto ImplementedOrOverriddenSet =
+      getImplementedOrOverriddenCanonicals(CurrentDeclDef);
+
+  // Keep only the methods that still need to be overridden. The set above is
+  // keyed by canonical declaration, so look BaseMethod up by its canonical
+  // declaration too rather than by whichever redeclaration we were given.
+  for (const CXXMethodDecl *BaseMethod : AllPureVirtualsInHierarchy) {
+    if (ImplementedOrOverriddenSet.contains(BaseMethod->getCanonicalDecl()))
+      continue;
+    MissingMethodsByAccess[BaseMethod->getAccess()].emplace_back(BaseMethod);
+  }
+}
+
+// Free function to generate the string for a group of method overrides.
+std::string generateOverridesStringForGroup(
+    llvm::SmallVector<const CXXMethodDecl *> Methods,
+    const LangOptions &LangOpts) {
+  const auto GetParamString = [&LangOpts](const ParmVarDecl *P) {
+    std::string TypeStr = P->getType().getAsString(LangOpts);
+    // Unnamed parameter.
+    if (P->getNameAsString().empty())
+      return TypeStr;
+
+    return llvm::formatv("{0} {1}", std::move(TypeStr), P->getNameAsString())
+        .str();
+  };
+
+  std::string MethodsString;
+  for (const auto *Method : Methods) {
+    llvm::SmallVector<std::string> ParamsAsString;
+    ParamsAsString.reserve(Method->parameters().size());
+    llvm::transform(Method->parameters(), std::back_inserter(ParamsAsString),
+                    GetParamString);
+    auto Params = llvm::join(ParamsAsString, ", ");
+
+    MethodsString +=
+        llvm::formatv(
+            "  {0} {1}({2}) {3}override {{\n"
----------------
zwuis wrote:

We need to handle the following things as well:
1. `constexpr`/`consteval` specifier
2. ref-qualifier
3. trailing return type (You can refer to the implementation of clang-tidy 
check modernize-use-trailing-return-type, or leave a FIXME comment if it's 
beyond your capabilities)

https://github.com/llvm/llvm-project/pull/139348
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to