================
@@ -0,0 +1,349 @@
+//===--- AddPureVirtualOverride.cpp ------------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "refactor/Tweak.h"
+
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/DeclCXX.h"
+#include "clang/AST/Type.h"
+#include "clang/AST/TypeLoc.h"
+#include "clang/Basic/LLVM.h"
+#include "clang/Basic/SourceLocation.h"
+#include "clang/Tooling/Core/Replacement.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <algorithm>
+#include <functional>
+#include <map>
+#include <set>
+#include <string>
+#include <vector>
+
+namespace clang {
+namespace clangd {
+namespace {
+
+/// Refactoring tweak offered as "Override pure virtual methods". It targets
+/// the selected class definition and tracks the pure virtual methods that
+/// class still has to override, grouped by their original access specifier.
+class OverridePureVirtuals : public Tweak {
+public:
+  const char *id() const final; // defined by REGISTER_TWEAK.
+  bool prepare(const Selection &Sel) override;
+  Expected<Effect> apply(const Selection &Sel) override;
+  std::string title() const override {
+    return "Override pure virtual methods";
+  }
+  llvm::StringLiteral kind() const override {
+    return CodeAction::REFACTOR_KIND;
+  }
+
+private:
+  /// Gathers the information the tweak needs before it is applied.
+  void collectMissingPureVirtuals(const Selection &Sel);
+
+  /// The class being modified (assigned in prepare()).
+  const CXXRecordDecl *CurrentDecl = nullptr;
+  /// Pure virtual methods that still need an override, keyed by the access
+  /// specifier they were originally declared under.
+  std::map<AccessSpecifier, std::vector<const CXXMethodDecl *>>
+      MissingMethodsByAccess;
+  /// Source locations of the access specifiers already written in
+  /// CurrentDecl.
+  std::map<AccessSpecifier, SourceLocation> AccessSpecifierLocations;
+};
+
+// Presumably registers this tweak with clangd's tweak registry. As the comment
+// on OverridePureVirtuals::id() states, this macro also defines id().
+REGISTER_TWEAK(OverridePureVirtuals)
+
+// Collects all unique, canonical pure virtual methods from a class and its
+// entire inheritance hierarchy. This function aims to find methods that *could*
+// make a derived class abstract if not implemented.
+//
+// Classes are walked depth-first: first a class's own methods, then each of its
+// bases in the order they are written. For every canonical method, only the
+// first declaration encountered is returned. The result is a superset: a pure
+// virtual that is overridden somewhere further down the hierarchy is still
+// reported, so callers presumably filter it out.
+std::vector<const CXXMethodDecl *>
+getAllUniquePureVirtualsFromHierarchy(const CXXRecordDecl *Decl) {
+  std::vector<const CXXMethodDecl *> Result;
+  llvm::DenseSet<const CXXMethodDecl *> VisitedCanonicalMethods;
+  // Class definitions already walked. A base reachable along several paths
+  // (diamond or virtual inheritance) would otherwise be re-walked once per
+  // path, which grows exponentially with stacked diamonds. Skipping a repeat
+  // visit never changes the result: the first visit already recorded every
+  // pure virtual declared in that class and in all of its bases.
+  llvm::DenseSet<const CXXRecordDecl *> VisitedClasses;
+  // We declare it as a std::function because we are going to call it
+  // recursively.
+  std::function<void(const CXXRecordDecl *)> Collect;
+
+  Collect = [&](const CXXRecordDecl *CurrentClass) {
+    if (!CurrentClass)
+      return;
+    const CXXRecordDecl *Def = CurrentClass->getDefinition();
+    if (!Def || !VisitedClasses.insert(Def).second)
+      return;
+
+    for (const CXXMethodDecl *M : Def->methods()) {
+      // Add if its canonical declaration hasn't been processed yet. This
+      // ensures each distinct pure virtual method is collected once.
+      if (M->isPureVirtual() &&
+          VisitedCanonicalMethods.insert(M->getCanonicalDecl()).second)
+        Result.emplace_back(M); // Store the specific declaration encountered.
+    }
+
+    // Recurse into the bases. Non-record or unresolved base types yield null,
+    // which the guard at the top of Collect ignores.
+    for (const auto &BaseSpec : Def->bases())
+      Collect(BaseSpec.getType()->getAsCXXRecordDecl());
+  };
+
+  Collect(Decl);
+  return Result;
+}
+
+// Gets the canonical declarations of every virtual method that some method
+// declared in class D overrides, either directly or through a chain of
+// intermediate overriders. Every overrider counts, including a pure
+// re-declaration (`= 0`): in each case D already declares that function.
+//
+// Following the chain matters because overridden_methods() reports only the
+// direct edge. For example:
+//   struct A { virtual void f() = 0; };
+//   struct B : A { void f() override = 0; };
+//   struct C : B { void f() override; };
+// C::f lists only B::f as overridden, yet it overrides A::f as well.
+std::set<const CXXMethodDecl *>
+getImplementedOrOverriddenCanonicals(const CXXRecordDecl *D) {
+  std::set<const CXXMethodDecl *> ImplementedSet;
+  // Methods whose overridden methods have not been examined yet.
+  std::vector<const CXXMethodDecl *> Pending;
+  for (const CXXMethodDecl *M : D->methods())
+    Pending.push_back(M);
+
+  while (!Pending.empty()) {
+    const CXXMethodDecl *M = Pending.back();
+    Pending.pop_back();
+    for (const CXXMethodDecl *OverriddenM : M->overridden_methods()) {
+      const CXXMethodDecl *Canonical = OverriddenM->getCanonicalDecl();
+      // Expand each overridden method at most once, even when several of D's
+      // methods lead back to the same ancestor.
+      if (ImplementedSet.insert(Canonical).second)
+        Pending.push_back(Canonical);
+    }
+  }
+  return ImplementedSet;
+}
+
+// Maps each access specifier written explicitly in D's body (`public:`,
+// `protected:`, `private:`) to the location of its colon. If the same specifier
+// is written more than once, the location of the last occurrence is kept.
+std::map<AccessSpecifier, SourceLocation>
+getSpecifierLocations(const CXXRecordDecl *D) {
+  std::map<AccessSpecifier, SourceLocation> ColonByAccess;
+  for (const auto *Member : D->decls()) {
+    const auto *Spec = llvm::dyn_cast<AccessSpecDecl>(Member);
+    if (!Spec)
+      continue;
+    // A later specifier of the same kind overwrites the earlier one.
+    ColonByAccess.insert_or_assign(Spec->getAccess(), Spec->getColonLoc());
+  }
+  return ColonByAccess;
+}
+
+// Check if the current class has any pure virtual method to be implemented.
+bool OverridePureVirtuals::prepare(const Selection &Sel) {
+  const SelectionTree::Node *Node = Sel.ASTSelection.commonAncestor();
+  if (!Node) {
+    return false;
+  }
+
+  // Make sure we have a definition.
+  CurrentDecl = Node->ASTNode.get<CXXRecordDecl>();
+  if (!CurrentDecl || !CurrentDecl->getDefinition()) {
+    return false;
+  }
+
+  // A class needs overrides if it's abstract itself, or derives from abstract
+  // bases and hasn't implemented all necessary methods. A simpler check: if it
+  // has any base that is abstract.
+  bool HasAbstractBase = false;
+  for (const auto &Base : CurrentDecl->bases()) {
+    if (const CXXRecordDecl *BaseDecl = Base.getType()->getAsCXXRecordDecl()) {
+      if (BaseDecl->getDefinition() &&
+          BaseDecl->getDefinition()->isAbstract()) {
+        HasAbstractBase = true;
+        break;
+      }
+    }
+  }
----------------
marcogmaia wrote:

Adopted; I've refactored it into its own function.

https://github.com/llvm/llvm-project/pull/139348
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to