================
@@ -0,0 +1,349 @@
+//===--- AddPureVirtualOverride.cpp ------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "refactor/Tweak.h"
+
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/DeclCXX.h"
+#include "clang/AST/Type.h"
+#include "clang/AST/TypeLoc.h"
+#include "clang/Basic/LLVM.h"
+#include "clang/Basic/SourceLocation.h"
+#include "clang/Tooling/Core/Replacement.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <algorithm>
+#include <functional>
+#include <map>
+#include <set>
+#include <string>
+#include <vector>
+
+namespace clang {
+namespace clangd {
+namespace {
+
+// Tweak available on a class that has an abstract direct base. It is listed as
+// "Override pure virtual methods" under the refactor code-action kind.
+class OverridePureVirtuals : public Tweak {
+public:
+  const char *id() const final; // defined by REGISTER_TWEAK.
+  bool prepare(const Selection &Sel) override;
+  Expected<Effect> apply(const Selection &Sel) override;
+  std::string title() const override {
+    return "Override pure virtual methods";
+  }
+  llvm::StringLiteral kind() const override {
+    return CodeAction::REFACTOR_KIND;
+  }
+
+private:
+  // The class being modified. Set by prepare() from the selection and narrowed
+  // to the class definition by collectMissingPureVirtuals().
+  const CXXRecordDecl *CurrentDecl = nullptr;
+  // Pure virtual methods that still need an override, grouped by the access
+  // specifier they carry in the class that declares them.
+  std::map<AccessSpecifier, std::vector<const CXXMethodDecl *>>
+      MissingMethodsByAccess;
+  // Colon location of the access specifiers already written in CurrentDecl,
+  // one entry per access kind.
+  std::map<AccessSpecifier, SourceLocation> AccessSpecifierLocations;
+
+  // Fills MissingMethodsByAccess and AccessSpecifierLocations for CurrentDecl.
+  void collectMissingPureVirtuals(const Selection &Sel);
+};
+
+REGISTER_TWEAK(OverridePureVirtuals)
+
+// Collects the pure virtual methods declared by `Decl` and by every class in
+// its inheritance hierarchy. `Decl` is visited first, then each base
+// depth-first in the order `bases()` yields them. Every method is reported
+// once, keyed by its canonical declaration; the pointer stored is the specific
+// declaration that was encountered. These are the methods that *could* make a
+// derived class abstract if left unimplemented: whether a class further down
+// the hierarchy already overrides one of them is not checked here.
+// Classes without a definition (and a null `Decl`) contribute nothing.
+std::vector<const CXXMethodDecl *>
+getAllUniquePureVirtualsFromHierarchy(const CXXRecordDecl *Decl) {
+  std::vector<const CXXMethodDecl *> Result;
+  llvm::DenseSet<const CXXMethodDecl *> VisitedCanonicalMethods;
+  // Definitions that were already traversed. A base reachable through several
+  // inheritance paths adds nothing new on later visits, so skipping it leaves
+  // the result unchanged while avoiding repeated walks of shared ancestors.
+  llvm::DenseSet<const CXXRecordDecl *> VisitedClasses;
+  // Declared as a std::function so that the lambda can call itself.
+  std::function<void(const CXXRecordDecl *)> Collect;
+
+  Collect = [&](const CXXRecordDecl *CurrentClass) {
+    if (!CurrentClass)
+      return;
+    const CXXRecordDecl *Def = CurrentClass->getDefinition();
+    if (!Def || !VisitedClasses.insert(Def).second)
+      return;
+
+    for (const CXXMethodDecl *M : Def->methods()) {
+      // Only the first declaration seen for a canonical method is kept.
+      if (M->isPureVirtual() &&
+          VisitedCanonicalMethods.insert(M->getCanonicalDecl()).second)
+        Result.emplace_back(M);
+    }
+
+    // Bases that are not C++ records yield a null pointer, which Collect
+    // ignores.
+    for (const auto &BaseSpec : Def->bases())
+      Collect(BaseSpec.getType()->getAsCXXRecordDecl());
+  };
+
+  Collect(Decl);
+  return Result;
+}
+
+// Returns the canonical declarations of every method that a method of class D
+// overrides, directly or through a chain of intermediate overrides. The chain
+// is followed explicitly so that a pure virtual declared several levels up is
+// recognised as handled when D overrides an intermediate override of it.
+// Methods of D itself are not part of the result unless another method of D
+// overrides them.
+std::set<const CXXMethodDecl *>
+getImplementedOrOverriddenCanonicals(const CXXRecordDecl *D) {
+  std::set<const CXXMethodDecl *> ImplementedSet;
+
+  // Methods whose overridden_methods() still have to be inspected.
+  std::vector<const CXXMethodDecl *> Worklist;
+  for (const CXXMethodDecl *M : D->methods())
+    Worklist.push_back(M);
+
+  while (!Worklist.empty()) {
+    const CXXMethodDecl *M = Worklist.back();
+    Worklist.pop_back();
+    for (const CXXMethodDecl *OverriddenM : M->overridden_methods()) {
+      const CXXMethodDecl *Canonical = OverriddenM->getCanonicalDecl();
+      // Newly discovered methods are queued once; the set membership test also
+      // guarantees termination.
+      if (ImplementedSet.insert(Canonical).second)
+        Worklist.push_back(Canonical);
+    }
+  }
+  return ImplementedSet;
+}
+
+// Maps each access specifier kind written in `D` to the location of its colon.
+// When a kind is written more than once, the entry holds the colon of the
+// occurrence that comes last in `D->decls()`.
+std::map<AccessSpecifier, SourceLocation>
+getSpecifierLocations(const CXXRecordDecl *D) {
+  std::map<AccessSpecifier, SourceLocation> ColonByAccess;
+  for (const auto *Member : D->decls()) {
+    const auto *Spec = llvm::dyn_cast<AccessSpecDecl>(Member);
+    if (!Spec)
+      continue;
+    // Later occurrences overwrite earlier ones.
+    ColonByAccess[Spec->getAccess()] = Spec->getColonLoc();
+  }
+  return ColonByAccess;
+}
+
+// Decides whether the tweak is offered for the current selection. It is
+// offered when the selected node is a class with a definition, that
+// definition is still abstract (at least one pure virtual has no final
+// overrider), and at least one of its direct bases is an abstract class.
+// A concrete class has nothing left to override, so it is rejected.
+// On success CurrentDecl points at the class definition.
+bool OverridePureVirtuals::prepare(const Selection &Sel) {
+  const SelectionTree::Node *Node = Sel.ASTSelection.commonAncestor();
+  if (!Node)
+    return false;
+
+  // Make sure we have a definition, and work with it from now on.
+  const auto *Selected = Node->ASTNode.get<CXXRecordDecl>();
+  CurrentDecl = Selected ? Selected->getDefinition() : nullptr;
+  if (!CurrentDecl)
+    return false;
+
+  // True for a base specifier naming a defined, abstract C++ class.
+  const auto IsAbstractBase = [](const CXXBaseSpecifier &Base) {
+    const CXXRecordDecl *BaseDecl = Base.getType()->getAsCXXRecordDecl();
+    if (!BaseDecl)
+      return false;
+    const CXXRecordDecl *BaseDef = BaseDecl->getDefinition();
+    return BaseDef && BaseDef->isAbstract();
+  };
+
+  // Cheap flag test first; the base scan only runs for abstract classes.
+  if (!CurrentDecl->isAbstract())
+    return false;
+  const auto Bases = CurrentDecl->bases();
+  return std::any_of(Bases.begin(), Bases.end(), IsAbstractBase);
+}
+
+// Recomputes the state used to build the edit: the colon locations of the
+// access specifiers already present in CurrentDecl, and the pure virtual
+// methods found in its base hierarchy that are absent from the set returned by
+// getImplementedOrOverriddenCanonicals(CurrentDecl), grouped by the access
+// specifier each method has in the class declaring it. Does nothing when
+// CurrentDecl is unset or has no definition. `Sel` is not used.
+void OverridePureVirtuals::collectMissingPureVirtuals(const Selection &Sel) {
+  // Only the definition is of interest from here on.
+  CurrentDecl = CurrentDecl ? CurrentDecl->getDefinition() : nullptr;
+  if (!CurrentDecl)
+    return;
+
+  MissingMethodsByAccess.clear();
+  AccessSpecifierLocations = getSpecifierLocations(CurrentDecl);
+
+  // Canonical declarations already overridden by CurrentDecl's own methods.
+  const std::set<const CXXMethodDecl *> Overridden =
+      getImplementedOrOverriddenCanonicals(CurrentDecl);
+
+  // Keeps a method from being considered twice when several direct bases
+  // share an ancestor.
+  llvm::DenseSet<const CXXMethodDecl *> Considered;
+
+  for (const auto &BaseSpec : CurrentDecl->bases()) {
+    const CXXRecordDecl *BaseRD = BaseSpec.getType()->getAsCXXRecordDecl();
+    const CXXRecordDecl *BaseDef = BaseRD ? BaseRD->getDefinition() : nullptr;
+    if (!BaseDef)
+      continue;
+
+    for (const CXXMethodDecl *Pure :
+         getAllUniquePureVirtualsFromHierarchy(BaseDef)) {
+      const CXXMethodDecl *Canonical = Pure->getCanonicalDecl();
+      if (!Considered.insert(Canonical).second)
+        continue;
+      if (Overridden.count(Canonical) > 0)
+        continue;
+      // Still missing: file it under the access it has where it is declared.
+      MissingMethodsByAccess[Pure->getAccess()].emplace_back(Pure);
+    }
+  }
+}
+
+// Free function to generate the string for a group of method overrides.
+std::string
+generateOverridesStringForGroup(std::vector<const CXXMethodDecl *> Methods,
+                                const LangOptions &LangOpts) {
+  const auto GetParamString = [&LangOpts](const ParmVarDecl *P) {
+    std::string TypeStr = P->getType().getAsString(LangOpts);
+    if (P->getNameAsString().empty()) {
+      // Unnamed parameter.
+      return TypeStr;
+    }
+    return llvm::formatv("{0} {1}", TypeStr, P->getNameAsString()).str();
----------------
marcogmaia wrote:

Added.

https://github.com/llvm/llvm-project/pull/139348
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to