================ @@ -0,0 +1,296 @@ +//===--- OverridePureVirtuals.cpp --------------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "refactor/Tweak.h" + +#include "clang/AST/ASTContext.h" +#include "clang/AST/DeclCXX.h" +#include "clang/AST/Type.h" +#include "clang/AST/TypeLoc.h" +#include "clang/Basic/LLVM.h" +#include "clang/Basic/SourceLocation.h" +#include "clang/Tooling/Core/Replacement.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/Support/FormatVariadic.h" +#include <string> + +namespace clang { +namespace clangd { +namespace { + +class OverridePureVirtuals : public Tweak { +public: + const char *id() const final; // defined by REGISTER_TWEAK. + bool prepare(const Selection &Sel) override; + Expected<Effect> apply(const Selection &Sel) override; + std::string title() const override { return "Override pure virtual methods"; } + llvm::StringLiteral kind() const override { + return CodeAction::QUICKFIX_KIND; + } + +private: + // Stores the CXXRecordDecl of the class being modified. + const CXXRecordDecl *CurrentDeclDef = nullptr; + // Stores pure virtual methods that need overriding, grouped by their original + // access specifier. + llvm::MapVector<AccessSpecifier, llvm::SmallVector<const CXXMethodDecl *>> + MissingMethodsByAccess; + // Stores the source locations of existing access specifiers in CurrentDecl. + llvm::MapVector<AccessSpecifier, SourceLocation> AccessSpecifierLocations; + + // Helper function to gather information before applying the tweak. + void collectMissingPureVirtuals(const Selection &Sel); +}; + +REGISTER_TWEAK(OverridePureVirtuals) + +// Function to get all unique pure virtual methods from the entire +// base class hierarchy of CurrentDeclDef. 
+llvm::SmallVector<const clang::CXXMethodDecl *> +getAllUniquePureVirtualsFromBaseHierarchy( + const clang::CXXRecordDecl *CurrentDeclDef) { + llvm::SmallVector<const clang::CXXMethodDecl *> AllPureVirtualsInHierarchy; + llvm::DenseSet<const clang::CXXMethodDecl *> CanonicalPureVirtualsSeen; + + if (!CurrentDeclDef || !CurrentDeclDef->getDefinition()) + return AllPureVirtualsInHierarchy; + + const clang::CXXRecordDecl *Def = CurrentDeclDef->getDefinition(); + + Def->forallBases([&](const clang::CXXRecordDecl *BaseDefinition) { + for (const clang::CXXMethodDecl *Method : BaseDefinition->methods()) { + if (Method->isPureVirtual() && + CanonicalPureVirtualsSeen.insert(Method->getCanonicalDecl()).second) + AllPureVirtualsInHierarchy.emplace_back(Method); + } + // Continue iterating through all bases. + return true; + }); + + return AllPureVirtualsInHierarchy; +} + +// Gets canonical declarations of methods already overridden or implemented in +// class D. +llvm::SetVector<const CXXMethodDecl *> +getImplementedOrOverriddenCanonicals(const CXXRecordDecl *D) { + llvm::SetVector<const CXXMethodDecl *> ImplementedSet; + for (const CXXMethodDecl *M : D->methods()) { + // If M provides an implementation for any virtual method it overrides. + // A method is an "implementation" if it's virtual and not pure. + // Or if it directly overrides a base method. + for (const CXXMethodDecl *OverriddenM : M->overridden_methods()) + ImplementedSet.insert(OverriddenM->getCanonicalDecl()); + } + return ImplementedSet; +} + +// Get the location of every colon of the `AccessSpecifier`. 
+llvm::MapVector<AccessSpecifier, SourceLocation> +getSpecifierLocations(const CXXRecordDecl *D) { + llvm::MapVector<AccessSpecifier, SourceLocation> Locs; + for (auto *DeclNode : D->decls()) { + if (const auto *ASD = llvm::dyn_cast<AccessSpecDecl>(DeclNode)) + Locs[ASD->getAccess()] = ASD->getColonLoc(); + } + return Locs; +} + +bool hasAbstractBaseAncestor(const clang::CXXRecordDecl *CurrentDecl) { + if (!CurrentDecl || !CurrentDecl->getDefinition()) + return false; + + return llvm::any_of( + CurrentDecl->getDefinition()->bases(), [](CXXBaseSpecifier BaseSpec) { + const auto *D = BaseSpec.getType()->getAsCXXRecordDecl(); + const auto *Def = D ? D->getDefinition() : nullptr; + return Def && Def->isAbstract(); + }); +} + +// Check if the current class has any pure virtual method to be implemented. +bool OverridePureVirtuals::prepare(const Selection &Sel) { + const SelectionTree::Node *Node = Sel.ASTSelection.commonAncestor(); + if (!Node) + return false; + + // Make sure we have a definition. + CurrentDeclDef = Node->ASTNode.get<CXXRecordDecl>(); + if (!CurrentDeclDef || !CurrentDeclDef->getDefinition()) + return false; + + // From now on, we should work with the definition. + CurrentDeclDef = CurrentDeclDef->getDefinition(); + + // Only offer for abstract classes with abstract bases. + return CurrentDeclDef->isAbstract() && + hasAbstractBaseAncestor(CurrentDeclDef); +} + +// Collects all pure virtual methods that are missing an override in +// CurrentDecl, grouped by their original access specifier. +void OverridePureVirtuals::collectMissingPureVirtuals(const Selection &Sel) { + if (!CurrentDeclDef) + return; + + AccessSpecifierLocations = getSpecifierLocations(CurrentDeclDef); + MissingMethodsByAccess.clear(); + + // Get all unique pure virtual methods from the entire base class hierarchy. 
+ llvm::SmallVector<const CXXMethodDecl *> AllPureVirtualsInHierarchy = + getAllUniquePureVirtualsFromBaseHierarchy(CurrentDeclDef); + + // Get methods already implemented or overridden in CurrentDecl. + const auto ImplementedOrOverriddenSet = + getImplementedOrOverriddenCanonicals(CurrentDeclDef); + + // Filter AllPureVirtualsInHierarchy to find those not in + // ImplementedOrOverriddenSet, which needs to be overriden. + for (const CXXMethodDecl *BaseMethod : AllPureVirtualsInHierarchy) { + bool AlreadyHandled = ImplementedOrOverriddenSet.contains(BaseMethod); + if (!AlreadyHandled) + MissingMethodsByAccess[BaseMethod->getAccess()].emplace_back(BaseMethod); + } +} + +// Free function to generate the string for a group of method overrides. +std::string generateOverridesStringForGroup( + llvm::SmallVector<const CXXMethodDecl *> Methods, + const LangOptions &LangOpts) { + const auto GetParamString = [&LangOpts](const ParmVarDecl *P) { + std::string TypeStr = P->getType().getAsString(LangOpts); + // Unnamed parameter. + if (P->getNameAsString().empty()) + return TypeStr; + + return llvm::formatv("{0} {1}", std::move(TypeStr), P->getNameAsString()) + .str(); + }; + + std::string MethodsString; + for (const auto *Method : Methods) { + llvm::SmallVector<std::string> ParamsAsString; + ParamsAsString.reserve(Method->parameters().size()); + llvm::transform(Method->parameters(), std::back_inserter(ParamsAsString), + GetParamString); + auto Params = llvm::join(ParamsAsString, ", "); + + MethodsString += + llvm::formatv( + " {0} {1}({2}) {3}override {{\n" ---------------- zwuis wrote:
We need to handle the following as well:

1. `constexpr`/`consteval` specifiers
2. ref-qualifiers
3. trailing return types (you can refer to the implementation of the clang-tidy check modernize-use-trailing-return-type, or leave a FIXME comment if it's beyond your capabilities)

https://github.com/llvm/llvm-project/pull/139348
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits