https://github.com/JDevlieghere updated https://github.com/llvm/llvm-project/pull/143628
>From ecc6e9e88e2f5786e7ebd0a454567cd933e60b7f Mon Sep 17 00:00:00 2001 From: Jonas Devlieghere <jo...@devlieghere.com> Date: Thu, 12 Jun 2025 16:33:55 -0700 Subject: [PATCH] [lldb] Add MCP support to LLDB https://discourse.llvm.org/t/rfc-adding-mcp-support-to-lldb/86798 --- lldb/include/lldb/Core/Debugger.h | 6 + lldb/include/lldb/Core/PluginManager.h | 9 + lldb/include/lldb/Core/ProtocolServer.h | 37 +++ .../Interpreter/CommandOptionArgumentTable.h | 1 + lldb/include/lldb/lldb-enumerations.h | 1 + lldb/include/lldb/lldb-forward.h | 3 +- lldb/include/lldb/lldb-private-interfaces.h | 2 + lldb/source/Commands/CMakeLists.txt | 1 + .../Commands/CommandObjectProtocolServer.cpp | 145 +++++++++ .../Commands/CommandObjectProtocolServer.h | 25 ++ lldb/source/Core/CMakeLists.txt | 1 + lldb/source/Core/Debugger.cpp | 24 ++ lldb/source/Core/PluginManager.cpp | 27 ++ lldb/source/Core/ProtocolServer.cpp | 21 ++ .../source/Interpreter/CommandInterpreter.cpp | 2 + lldb/source/Plugins/CMakeLists.txt | 1 + lldb/source/Plugins/Protocol/CMakeLists.txt | 1 + .../Plugins/Protocol/MCP/CMakeLists.txt | 13 + lldb/source/Plugins/Protocol/MCP/MCPError.cpp | 31 ++ lldb/source/Plugins/Protocol/MCP/MCPError.h | 33 +++ lldb/source/Plugins/Protocol/MCP/Protocol.cpp | 180 +++++++++++ lldb/source/Plugins/Protocol/MCP/Protocol.h | 131 ++++++++ .../Protocol/MCP/ProtocolServerMCP.cpp | 280 ++++++++++++++++++ .../Plugins/Protocol/MCP/ProtocolServerMCP.h | 77 +++++ lldb/source/Plugins/Protocol/MCP/Tool.cpp | 72 +++++ lldb/source/Plugins/Protocol/MCP/Tool.h | 61 ++++ 26 files changed, 1184 insertions(+), 1 deletion(-) create mode 100644 lldb/include/lldb/Core/ProtocolServer.h create mode 100644 lldb/source/Commands/CommandObjectProtocolServer.cpp create mode 100644 lldb/source/Commands/CommandObjectProtocolServer.h create mode 100644 lldb/source/Core/ProtocolServer.cpp create mode 100644 lldb/source/Plugins/Protocol/CMakeLists.txt create mode 100644 lldb/source/Plugins/Protocol/MCP/CMakeLists.txt create 
mode 100644 lldb/source/Plugins/Protocol/MCP/MCPError.cpp
 create mode 100644 lldb/source/Plugins/Protocol/MCP/MCPError.h
 create mode 100644 lldb/source/Plugins/Protocol/MCP/Protocol.cpp
 create mode 100644 lldb/source/Plugins/Protocol/MCP/Protocol.h
 create mode 100644 lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp
 create mode 100644 lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h
 create mode 100644 lldb/source/Plugins/Protocol/MCP/Tool.cpp
 create mode 100644 lldb/source/Plugins/Protocol/MCP/Tool.h

diff --git a/lldb/include/lldb/Core/Debugger.h b/lldb/include/lldb/Core/Debugger.h
index d73aba1e3ce58..0f6659d1a0bf7 100644
--- a/lldb/include/lldb/Core/Debugger.h
+++ b/lldb/include/lldb/Core/Debugger.h
@@ -598,6 +598,15 @@ class Debugger : public std::enable_shared_from_this<Debugger>,
   void FlushProcessOutput(Process &process, bool flush_stdout,
                           bool flush_stderr);
 
+  /// Register \p protocol_server_sp with this debugger. At most one server
+  /// may be registered per protocol (plugin name).
+  void AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp);
+  /// Unregister \p protocol_server_sp; a no-op if it is not registered.
+  void RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp);
+  /// Return the registered server whose plugin name is \p protocol, or
+  /// nullptr if there is none.
+  lldb::ProtocolServerSP GetProtocolServer(llvm::StringRef protocol) const;
+
   SourceManager::SourceFileCache &GetSourceFileCache() {
     return m_source_file_cache;
   }
@@ -768,6 +777,9 @@ class Debugger : public std::enable_shared_from_this<Debugger>,
   mutable std::mutex m_progress_reports_mutex;
   /// @}
 
+  /// Protocol servers registered through AddProtocolServer.
+  llvm::SmallVector<lldb::ProtocolServerSP> m_protocol_servers;
+
   std::mutex m_destroy_callback_mutex;
   lldb::callback_token_t m_destroy_callback_next_token = 0;
   struct DestroyCallbackInfo {
diff --git a/lldb/include/lldb/Core/PluginManager.h b/lldb/include/lldb/Core/PluginManager.h
index e7b1691031111..967af598c40ff 100644
--- a/lldb/include/lldb/Core/PluginManager.h
+++ b/lldb/include/lldb/Core/PluginManager.h
@@ -327,6 +327,15 @@ class PluginManager {
   static void AutoCompleteProcessName(llvm::StringRef partial_name,
                                       CompletionRequest &request);
 
+  // ProtocolServer
+  static bool RegisterPlugin(llvm::StringRef name, llvm::StringRef description,
+                             ProtocolServerCreateInstance create_callback);
+
+  static bool UnregisterPlugin(ProtocolServerCreateInstance create_callback);
+
+  static ProtocolServerCreateInstance
+  GetProtocolCreateCallbackForPluginName(llvm::StringRef name);
+
   // Register Type Provider
   static bool RegisterPlugin(llvm::StringRef name, llvm::StringRef description,
                              RegisterTypeBuilderCreateInstance create_callback);
diff --git a/lldb/include/lldb/Core/ProtocolServer.h b/lldb/include/lldb/Core/ProtocolServer.h
new file mode 100644
index 0000000000000..ca0210a0bbe72
--- /dev/null
+++ b/lldb/include/lldb/Core/ProtocolServer.h
@@ -0,0 +1,47 @@
+//===-- ProtocolServer.h --------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLDB_CORE_PROTOCOLSERVER_H
+#define LLDB_CORE_PROTOCOLSERVER_H
+
+#include "lldb/Core/PluginInterface.h"
+#include "lldb/Host/Socket.h"
+#include "lldb/lldb-private-interfaces.h"
+#include "llvm/Support/Error.h"
+#include <string>
+
+namespace lldb_private {
+
+/// Base class for plugins that serve a debugger over a wire protocol (e.g. MCP).
+class ProtocolServer : public PluginInterface {
+public:
+  ProtocolServer() = default;
+  ~ProtocolServer() override = default;
+
+  /// Create a server using the protocol plugin named \p name. Returns nullptr
+  /// if no plugin with that name is registered.
+  static lldb::ProtocolServerSP Create(llvm::StringRef name,
+                                       Debugger &debugger);
+
+  /// Where a server listens.
+  struct Connection {
+    /// Socket type, e.g. TCP or Unix domain socket.
+    Socket::SocketProtocol protocol;
+    /// Address: "[host]:port" for TCP or a path for a Unix domain socket.
+    std::string name;
+  };
+
+  /// Start serving on \p connection. Implemented by each protocol plugin.
+  virtual llvm::Error Start(Connection connection) = 0;
+  /// Stop serving. Implemented by each protocol plugin.
+  virtual llvm::Error Stop() = 0;
+};
+
+} // namespace lldb_private
+
+#endif // LLDB_CORE_PROTOCOLSERVER_H
diff --git a/lldb/include/lldb/Interpreter/CommandOptionArgumentTable.h b/lldb/include/lldb/Interpreter/CommandOptionArgumentTable.h
index 8535dfcf46da5..4face717531b1 100644
--- a/lldb/include/lldb/Interpreter/CommandOptionArgumentTable.h
+++ b/lldb/include/lldb/Interpreter/CommandOptionArgumentTable.h
@@ -315,6 +315,7 @@ static constexpr CommandObject::ArgumentTableEntry g_argument_table[] = {
     { lldb::eArgTypeCPUName, "cpu-name", lldb::CompletionType::eNoCompletion, {}, { nullptr, false }, "The name of a CPU." },
     { lldb::eArgTypeCPUFeatures, "cpu-features", lldb::CompletionType::eNoCompletion, {}, { nullptr, false }, "The CPU feature string." },
     { lldb::eArgTypeManagedPlugin, "managed-plugin", lldb::CompletionType::eNoCompletion, {}, { nullptr, false }, "Plugins managed by the PluginManager" },
+    { lldb::eArgTypeProtocol, "protocol", lldb::CompletionType::eNoCompletion, {}, { nullptr, false }, "The name of the protocol." },
     // clang-format on
 };
 
diff --git a/lldb/include/lldb/lldb-enumerations.h b/lldb/include/lldb/lldb-enumerations.h
index eeb7299a354e1..69e8671b6e21b 100644
--- a/lldb/include/lldb/lldb-enumerations.h
+++ b/lldb/include/lldb/lldb-enumerations.h
@@ -664,6 +664,7 @@ enum CommandArgumentType {
   eArgTypeCPUName,
   eArgTypeCPUFeatures,
   eArgTypeManagedPlugin,
+  eArgTypeProtocol,
   eArgTypeLastArg // Always keep this entry as the last entry in this
                   // enumeration!!
 };
diff --git a/lldb/include/lldb/lldb-forward.h b/lldb/include/lldb/lldb-forward.h
index c664d1398f74d..558818e8e2309 100644
--- a/lldb/include/lldb/lldb-forward.h
+++ b/lldb/include/lldb/lldb-forward.h
@@ -164,13 +164,13 @@ class PersistentExpressionState;
 class Platform;
 class Process;
 class ProcessAttachInfo;
-class ProcessLaunchInfo;
 class ProcessInfo;
 class ProcessInstanceInfo;
 class ProcessInstanceInfoMatch;
 class ProcessLaunchInfo;
 class ProcessModID;
 class Property;
+class ProtocolServer;
 class Queue;
 class QueueImpl;
 class QueueItem;
@@ -391,6 +391,7 @@ typedef std::shared_ptr<lldb_private::Platform> PlatformSP;
 typedef std::shared_ptr<lldb_private::Process> ProcessSP;
 typedef std::shared_ptr<lldb_private::ProcessAttachInfo> ProcessAttachInfoSP;
 typedef std::shared_ptr<lldb_private::ProcessLaunchInfo> ProcessLaunchInfoSP;
+typedef std::shared_ptr<lldb_private::ProtocolServer> ProtocolServerSP;
 typedef std::weak_ptr<lldb_private::Process> ProcessWP;
 typedef std::shared_ptr<lldb_private::RegisterCheckpoint> RegisterCheckpointSP;
 typedef std::shared_ptr<lldb_private::RegisterContext> RegisterContextSP;
diff --git a/lldb/include/lldb/lldb-private-interfaces.h b/lldb/include/lldb/lldb-private-interfaces.h
index d366dbd1d7832..34eaaa8e581e9 100644
--- a/lldb/include/lldb/lldb-private-interfaces.h
+++ b/lldb/include/lldb/lldb-private-interfaces.h
@@ -81,6 +81,8 @@ typedef lldb::PlatformSP (*PlatformCreateInstance)(bool force,
 typedef lldb::ProcessSP (*ProcessCreateInstance)(
     lldb::TargetSP target_sp, lldb::ListenerSP listener_sp,
     const FileSpec *crash_file_path, bool can_connect);
+typedef lldb::ProtocolServerSP (*ProtocolServerCreateInstance)(
+    Debugger &debugger);
 typedef lldb::RegisterTypeBuilderSP (*RegisterTypeBuilderCreateInstance)(
     Target &target);
 typedef lldb::ScriptInterpreterSP (*ScriptInterpreterCreateInstance)(
diff --git a/lldb/source/Commands/CMakeLists.txt b/lldb/source/Commands/CMakeLists.txt
index 1ea51acec5f15..69e4c45f0b8e5 100644
---
a/lldb/source/Commands/CMakeLists.txt
+++ b/lldb/source/Commands/CMakeLists.txt
@@ -23,6 +23,7 @@ add_lldb_library(lldbCommands NO_PLUGIN_DEPENDENCIES
   CommandObjectPlatform.cpp
   CommandObjectPlugin.cpp
   CommandObjectProcess.cpp
+  CommandObjectProtocolServer.cpp
   CommandObjectQuit.cpp
   CommandObjectRegexCommand.cpp
   CommandObjectRegister.cpp
diff --git a/lldb/source/Commands/CommandObjectProtocolServer.cpp b/lldb/source/Commands/CommandObjectProtocolServer.cpp
new file mode 100644
index 0000000000000..bdb237cf010f4
--- /dev/null
+++ b/lldb/source/Commands/CommandObjectProtocolServer.cpp
@@ -0,0 +1,169 @@
+//===-- CommandObjectProtocolServer.cpp -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "CommandObjectProtocolServer.h"
+#include "lldb/Core/ProtocolServer.h"
+#include "lldb/Host/Socket.h"
+#include "lldb/Interpreter/CommandInterpreter.h"
+#include "lldb/Interpreter/CommandReturnObject.h"
+#include "lldb/Utility/UriParser.h"
+#include "llvm/Support/FormatAdapters.h"
+#include "llvm/Support/FormatVariadic.h"
+
+using namespace llvm;
+using namespace lldb;
+using namespace lldb_private;
+
+/// Translate a connection URI into the socket protocol and address that a
+/// ProtocolServer should listen on.
+///
+/// URIs with a "tcp" or "connect" scheme, a hostname, or a port are TCP and
+/// yield "[host]:port" (host defaults to 0.0.0.0 and port to 0). Otherwise,
+/// URIs with a "unix" or "unix-connect" scheme or a path other than "/" are
+/// Unix domain sockets and yield the path. Anything else is an error.
+static llvm::Expected<std::pair<Socket::SocketProtocol, std::string>>
+validateConnection(llvm::StringRef conn) {
+  auto uri = lldb_private::URI::Parse(conn);
+
+  if (uri && (uri->scheme == "tcp" || uri->scheme == "connect" ||
+              !uri->hostname.empty() || uri->port)) {
+    return std::make_pair(
+        Socket::ProtocolTcp,
+        formatv("[{0}]:{1}", uri->hostname.empty() ? "0.0.0.0" : uri->hostname,
+                uri->port.value_or(0)));
+  }
+
+  if (uri && (uri->scheme == "unix" || uri->scheme == "unix-connect" ||
+              uri->path != "/")) {
+    return std::make_pair(Socket::ProtocolUnixDomain, uri->path.str());
+  }
+
+  return llvm::createStringError(
+      "Unsupported connection specifier, expected 'unix-connect:///path' or "
+      "'connect://[host]:port', got '%s'.",
+      conn.str().c_str());
+}
+
+/// "protocol-server start <protocol> <connection>": start the server for
+/// <protocol> on <connection> and register it with the debugger.
+class CommandObjectProtocolServerStart : public CommandObjectParsed {
+public:
+  CommandObjectProtocolServerStart(CommandInterpreter &interpreter)
+      : CommandObjectParsed(interpreter, "protocol-server start",
+                            "start protocol server",
+                            "protocol-server start <protocol> <connection>") {
+    AddSimpleArgumentList(lldb::eArgTypeProtocol, eArgRepeatPlain);
+    AddSimpleArgumentList(lldb::eArgTypeConnectURL, eArgRepeatPlain);
+  }
+
+  ~CommandObjectProtocolServerStart() override = default;
+
+protected:
+  void DoExecute(Args &args, CommandReturnObject &result) override {
+    if (args.GetArgumentCount() < 1) {
+      result.AppendError("no protocol specified");
+      return;
+    }
+
+    if (args.GetArgumentCount() < 2) {
+      result.AppendError("no connection specified");
+      return;
+    }
+
+    llvm::StringRef protocol = args.GetArgumentAtIndex(0);
+    llvm::StringRef connection_uri = args.GetArgumentAtIndex(1);
+
+    auto maybe_protocol_and_name = validateConnection(connection_uri);
+    if (auto error = maybe_protocol_and_name.takeError()) {
+      result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error)));
+      return;
+    }
+
+    ProtocolServer::Connection connection;
+    std::tie(connection.protocol, connection.name) = *maybe_protocol_and_name;
+
+    Debugger &debugger = GetDebugger();
+
+    // Reuse the server already registered for this protocol, if any. Only a
+    // newly created server is added to the debugger below, because
+    // Debugger::AddProtocolServer expects one server per protocol.
+    ProtocolServerSP server_sp = debugger.GetProtocolServer(protocol);
+    const bool is_new_server = !server_sp;
+    if (is_new_server)
+      server_sp = ProtocolServer::Create(protocol, debugger);
+
+    // Create returns nullptr when no plugin implements the protocol.
+    if (!server_sp) {
+      result.AppendErrorWithFormatv("unsupported protocol: {0}", protocol);
+      return;
+    }
+
+    if (llvm::Error error = server_sp->Start(connection)) {
+      result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error)));
+      return;
+    }
+
+    if (is_new_server)
+      debugger.AddProtocolServer(server_sp);
+    result.SetStatus(eReturnStatusSuccessFinishNoResult);
+  }
+};
+
+/// "protocol-server stop <protocol>": stop the server registered for
+/// <protocol> and unregister it from the debugger.
+class CommandObjectProtocolServerStop : public CommandObjectParsed {
+public:
+  CommandObjectProtocolServerStop(CommandInterpreter &interpreter)
+      : CommandObjectParsed(interpreter, "protocol-server stop",
+                            "stop protocol server",
+                            "protocol-server stop <protocol>") {
+    AddSimpleArgumentList(lldb::eArgTypeProtocol, eArgRepeatPlain);
+  }
+
+  ~CommandObjectProtocolServerStop() override = default;
+
+protected:
+  void DoExecute(Args &args, CommandReturnObject &result) override {
+    if (args.GetArgumentCount() < 1) {
+      result.AppendError("no protocol specified");
+      return;
+    }
+
+    llvm::StringRef protocol = args.GetArgumentAtIndex(0);
+
+    Debugger &debugger = GetDebugger();
+
+    ProtocolServerSP server_sp = debugger.GetProtocolServer(protocol);
+    if (!server_sp) {
+      result.AppendError(
+          llvm::formatv("no {0} protocol server running", protocol).str());
+      return;
+    }
+
+    if (llvm::Error error = server_sp->Stop()) {
+      result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error)));
+      return;
+    }
+
+    debugger.RemoveProtocolServer(server_sp);
+    result.SetStatus(eReturnStatusSuccessFinishNoResult);
+  }
+};
+
+CommandObjectProtocolServer::CommandObjectProtocolServer(
+    CommandInterpreter &interpreter)
+    : CommandObjectMultiword(interpreter, "protocol-server",
+                             "Start and stop a protocol server.",
+                             "protocol-server") {
+  LoadSubCommand("start", CommandObjectSP(new CommandObjectProtocolServerStart(
+                              interpreter)));
+  LoadSubCommand("stop", CommandObjectSP(
+                             new CommandObjectProtocolServerStop(interpreter)));
+}
+
+CommandObjectProtocolServer::~CommandObjectProtocolServer() = default;
diff --git a/lldb/source/Commands/CommandObjectProtocolServer.h b/lldb/source/Commands/CommandObjectProtocolServer.h
new file mode 100644
index 0000000000000..3591216b014cb
--- /dev/null
+++ b/lldb/source/Commands/CommandObjectProtocolServer.h
@@ -0,0 +1,24 @@
+//===-- CommandObjectProtocolServer.h -------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLDB_SOURCE_COMMANDS_COMMANDOBJECTPROTOCOLSERVER_H
+#define LLDB_SOURCE_COMMANDS_COMMANDOBJECTPROTOCOLSERVER_H
+
+#include "lldb/Interpreter/CommandObjectMultiword.h"
+
+namespace lldb_private {
+
+class CommandObjectProtocolServer : public CommandObjectMultiword {
+public:
+  CommandObjectProtocolServer(CommandInterpreter &interpreter);
+  ~CommandObjectProtocolServer() override;
+};
+
+} // namespace lldb_private
+
+#endif // LLDB_SOURCE_COMMANDS_COMMANDOBJECTPROTOCOLSERVER_H
diff --git a/lldb/source/Core/CMakeLists.txt b/lldb/source/Core/CMakeLists.txt
index d6b75bca7f2d6..df35bd5c025f3 100644
--- a/lldb/source/Core/CMakeLists.txt
+++ b/lldb/source/Core/CMakeLists.txt
@@ -46,6 +46,7 @@ add_lldb_library(lldbCore NO_PLUGIN_DEPENDENCIES
   Opcode.cpp
   PluginManager.cpp
   Progress.cpp
+  ProtocolServer.cpp
   Statusline.cpp
   RichManglingContext.cpp
   SearchFilter.cpp
diff --git a/lldb/source/Core/Debugger.cpp b/lldb/source/Core/Debugger.cpp
index 81037d3def811..2bc9c7ead79d3 100644
--- a/lldb/source/Core/Debugger.cpp
+++ b/lldb/source/Core/Debugger.cpp
@@ -16,6 +16,7 @@
 #include "lldb/Core/ModuleSpec.h"
 #include "lldb/Core/PluginManager.h"
 #include "lldb/Core/Progress.h"
+#include "lldb/Core/ProtocolServer.h"
 #include "lldb/Core/StreamAsynchronousIO.h"
 #include "lldb/Core/Telemetry.h"
 #include "lldb/DataFormatters/DataVisualization.h"
@@ -2363,3 +2364,28 @@ llvm::ThreadPoolInterface &Debugger::GetThreadPool() {
          "Debugger::GetThreadPool called before Debugger::Initialize");
   return *g_thread_pool;
 }
+
+// Register a protocol server. At most one server per protocol (plugin name)
+// may be registered at a time; this is asserted.
+void Debugger::AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp) {
+  assert(protocol_server_sp &&
+         GetProtocolServer(protocol_server_sp->GetPluginName()) == nullptr);
+  m_protocol_servers.push_back(std::move(protocol_server_sp));
+}
+
+// Unregister a protocol server; does nothing if it is not registered.
+void Debugger::RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp) {
+  auto it = llvm::find(m_protocol_servers, protocol_server_sp);
+  if (it != m_protocol_servers.end())
+    m_protocol_servers.erase(it);
+}
+
+// Return the registered server whose plugin name is \p protocol, or nullptr.
+lldb::ProtocolServerSP
+Debugger::GetProtocolServer(llvm::StringRef protocol) const {
+  for (const ProtocolServerSP &protocol_server_sp : m_protocol_servers) {
+    if (protocol_server_sp && protocol_server_sp->GetPluginName() == protocol)
+      return protocol_server_sp;
+  }
+  return nullptr;
+}
diff --git a/lldb/source/Core/PluginManager.cpp b/lldb/source/Core/PluginManager.cpp
index 5d44434033c55..11085a8463803 100644
--- a/lldb/source/Core/PluginManager.cpp
+++ b/lldb/source/Core/PluginManager.cpp
@@ -1006,6 +1006,33 @@ void PluginManager::AutoCompleteProcessName(llvm::StringRef name,
   }
 }
 
+#pragma mark ProtocolServer
+
+typedef PluginInstance<ProtocolServerCreateInstance> ProtocolServerInstance;
+typedef PluginInstances<ProtocolServerInstance> ProtocolServerInstances;
+
+static ProtocolServerInstances &GetProtocolServerInstances() {
+  static ProtocolServerInstances g_instances;
+  return g_instances;
+}
+
+bool PluginManager::RegisterPlugin(
+    llvm::StringRef name, llvm::StringRef description,
+    ProtocolServerCreateInstance create_callback) {
+  return GetProtocolServerInstances().RegisterPlugin(name, description,
+                                                     create_callback);
+}
+
+bool PluginManager::UnregisterPlugin(
+    ProtocolServerCreateInstance create_callback) {
+  return GetProtocolServerInstances().UnregisterPlugin(create_callback);
+}
+
+ProtocolServerCreateInstance
+PluginManager::GetProtocolCreateCallbackForPluginName(llvm::StringRef name) {
+  return GetProtocolServerInstances().GetCallbackForName(name);
+}
+
 #pragma mark RegisterTypeBuilder
 
 struct RegisterTypeBuilderInstance
diff --git a/lldb/source/Core/ProtocolServer.cpp b/lldb/source/Core/ProtocolServer.cpp
new file mode 100644
index 0000000000000..d57a047afa7b2
--- /dev/null
+++ b/lldb/source/Core/ProtocolServer.cpp
@@ -0,0
+1,23 @@
+//===-- ProtocolServer.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "lldb/Core/ProtocolServer.h"
+#include "lldb/Core/PluginManager.h"
+
+using namespace lldb_private;
+using namespace lldb;
+
+// Instantiate the protocol server plugin registered under \p name. Returns
+// nullptr if no such plugin has been registered with the PluginManager.
+ProtocolServerSP ProtocolServer::Create(llvm::StringRef name,
+                                        Debugger &debugger) {
+  if (ProtocolServerCreateInstance create_callback =
+          PluginManager::GetProtocolCreateCallbackForPluginName(name))
+    return create_callback(debugger);
+  return nullptr;
+}
diff --git a/lldb/source/Interpreter/CommandInterpreter.cpp b/lldb/source/Interpreter/CommandInterpreter.cpp
index 4f9ae104dedea..00c3472444d2e 100644
--- a/lldb/source/Interpreter/CommandInterpreter.cpp
+++ b/lldb/source/Interpreter/CommandInterpreter.cpp
@@ -30,6 +30,7 @@
 #include "Commands/CommandObjectPlatform.h"
 #include "Commands/CommandObjectPlugin.h"
 #include "Commands/CommandObjectProcess.h"
+#include "Commands/CommandObjectProtocolServer.h"
 #include "Commands/CommandObjectQuit.h"
 #include "Commands/CommandObjectRegexCommand.h"
 #include "Commands/CommandObjectRegister.h"
@@ -574,6 +575,7 @@ void CommandInterpreter::LoadCommandDictionary() {
   REGISTER_COMMAND_OBJECT("platform", CommandObjectPlatform);
   REGISTER_COMMAND_OBJECT("plugin", CommandObjectPlugin);
   REGISTER_COMMAND_OBJECT("process", CommandObjectMultiwordProcess);
+  REGISTER_COMMAND_OBJECT("protocol-server", CommandObjectProtocolServer);
   REGISTER_COMMAND_OBJECT("quit", CommandObjectQuit);
   REGISTER_COMMAND_OBJECT("register", CommandObjectRegister);
   REGISTER_COMMAND_OBJECT("scripting", CommandObjectMultiwordScripting);
diff --git a/lldb/source/Plugins/CMakeLists.txt b/lldb/source/Plugins/CMakeLists.txt
index 854f589f45ae0..cf2da73e38931 100644
--- a/lldb/source/Plugins/CMakeLists.txt
+++ b/lldb/source/Plugins/CMakeLists.txt
@@ -14,6 +14,7 @@ add_subdirectory(ObjectFile)
 add_subdirectory(OperatingSystem)
 add_subdirectory(Platform)
 add_subdirectory(Process)
+add_subdirectory(Protocol)
 add_subdirectory(REPL)
 add_subdirectory(RegisterTypeBuilder)
 add_subdirectory(ScriptInterpreter)
diff --git a/lldb/source/Plugins/Protocol/CMakeLists.txt b/lldb/source/Plugins/Protocol/CMakeLists.txt
new file mode 100644
index 0000000000000..93b347d4cc9d8
--- /dev/null
+++ b/lldb/source/Plugins/Protocol/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(MCP)
diff --git a/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt b/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt
new file mode 100644
index 0000000000000..db31a7a69cb33
--- /dev/null
+++ b/lldb/source/Plugins/Protocol/MCP/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_lldb_library(lldbPluginProtocolServerMCP PLUGIN
+  MCPError.cpp
+  Protocol.cpp
+  ProtocolServerMCP.cpp
+  Tool.cpp
+
+  LINK_COMPONENTS
+    Support
+
+  LINK_LIBS
+    lldbCore
+    lldbHost
+    lldbUtility
+)
diff --git a/lldb/source/Plugins/Protocol/MCP/MCPError.cpp b/lldb/source/Plugins/Protocol/MCP/MCPError.cpp
new file mode 100644
index 0000000000000..69e1b5371af6f
--- /dev/null
+++ b/lldb/source/Plugins/Protocol/MCP/MCPError.cpp
@@ -0,0 +1,35 @@
+//===-- MCPError.cpp ------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "MCPError.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/raw_ostream.h"
+#include <system_error>
+#include <utility>
+
+namespace lldb_private::mcp {
+
+// Only the address of ID matters: llvm::ErrorInfo uses it as the class id.
+char MCPError::ID;
+
+MCPError::MCPError(std::string message, int64_t error_code)
+    : m_message(std::move(message)), m_error_code(error_code) {}
+
+void MCPError::log(llvm::raw_ostream &OS) const { OS << m_message; }
+
+// MCP errors have no std::error_code equivalent.
+std::error_code MCPError::convertToErrorCode() const {
+  return llvm::inconvertibleErrorCode();
+}
+
+// Report the stored code and message; no "data" payload is attached.
+protocol::Error MCPError::toProtcolError() const {
+  return protocol::Error{m_error_code, m_message, std::nullopt};
+}
+
+} // namespace lldb_private::mcp
diff --git a/lldb/source/Plugins/Protocol/MCP/MCPError.h b/lldb/source/Plugins/Protocol/MCP/MCPError.h
new file mode 100644
index 0000000000000..2a76a7b087e20
--- /dev/null
+++ b/lldb/source/Plugins/Protocol/MCP/MCPError.h
@@ -0,0 +1,41 @@
+//===-- MCPError.h --------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLDB_PLUGINS_PROTOCOL_MCP_MCPERROR_H
+#define LLDB_PLUGINS_PROTOCOL_MCP_MCPERROR_H
+
+#include "Protocol.h"
+#include "llvm/Support/Error.h"
+#include <string>
+
+namespace lldb_private::mcp {
+
+/// An llvm::Error carrying a message and an integer error code, which
+/// toProtcolError() turns into a protocol::Error.
+class MCPError : public llvm::ErrorInfo<MCPError> {
+public:
+  static char ID;
+
+  MCPError(std::string message, int64_t error_code);
+
+  void log(llvm::raw_ostream &OS) const override;
+  std::error_code convertToErrorCode() const override;
+
+  const std::string &getMessage() const { return m_message; }
+
+  /// Convert to a protocol::Error without a "data" payload.
+  protocol::Error toProtcolError() const;
+
+private:
+  std::string m_message;
+  int64_t m_error_code;
+};
+
+} // namespace lldb_private::mcp
+
+#endif // LLDB_PLUGINS_PROTOCOL_MCP_MCPERROR_H
diff --git a/lldb/source/Plugins/Protocol/MCP/Protocol.cpp b/lldb/source/Plugins/Protocol/MCP/Protocol.cpp
new file mode 100644
index 0000000000000..eb2e856fe5c66
--- /dev/null
+++ b/lldb/source/Plugins/Protocol/MCP/Protocol.cpp
@@ -0,0 +1,180 @@
+//===- Protocol.cpp -------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "Protocol.h"
+#include "llvm/Support/JSON.h"
+
+using namespace llvm;
+
+namespace lldb_private::mcp::protocol {
+
+/// Copy property \p Prop of object \p Params into \p V verbatim, leaving \p V
+/// unchanged if it is absent. Fails if \p Params is not an object.
+static bool mapRaw(const json::Value &Params, StringLiteral Prop,
+                   std::optional<json::Value> &V, json::Path P) {
+  const auto *O = Params.getAsObject();
+  if (!O) {
+    P.report("expected object");
+    return false;
+  }
+  if (const json::Value *E = O->get(Prop))
+    V = *E;
+  return true;
+}
+
+llvm::json::Value toJSON(const Request &R) {
+  json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}, {"method", R.method}};
+  if (R.params)
+    Result.insert({"params", R.params});
+  return Result;
+}
+
+bool fromJSON(const llvm::json::Value &V, Request &R, llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  if (!O || !O.map("id", R.id) || !O.map("method", R.method))
+    return false;
+  return mapRaw(V, "params", R.params, P);
+}
+
+llvm::json::Value toJSON(const Error &E) {
+  return llvm::json::Object{
+      {"code", E.code}, {"message", E.message}, {"data", E.data}};
+}
+
+bool fromJSON(const llvm::json::Value &V, Error &E, llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  return O && O.map("code", E.code) && O.map("message", E.message) &&
+         O.map("data", E.data);
+}
+
+llvm::json::Value toJSON(const ProtocolError &PE) {
+  return llvm::json::Object{
+      {"jsonrpc", "2.0"}, {"id", PE.id}, {"error", toJSON(PE.error)}};
+}
+
+bool fromJSON(const llvm::json::Value &V, ProtocolError &PE,
+              llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  return O && O.map("id", PE.id) && O.map("error", PE.error);
+}
+
+llvm::json::Value toJSON(const Response &R) {
+  llvm::json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}};
+  if (R.result)
+    Result.insert({"result", R.result});
+  if (R.error)
+    Result.insert({"error", R.error});
+  return Result;
+}
+
+bool fromJSON(const llvm::json::Value &V, Response &R, llvm::json::Path P) {
+  // Mirrors toJSON(Response): "result" and "error" are both optional.
+  llvm::json::ObjectMapper O(V, P);
+  if (!O || !O.map("id", R.id) || !O.map("error", R.error))
+    return false;
+  return mapRaw(V, "result", R.result, P);
+}
+
+llvm::json::Value toJSON(const Notification &N) {
+  llvm::json::Object Result{{"jsonrpc", "2.0"}, {"method", N.method}};
+  if (N.params)
+    Result.insert({"params", N.params});
+  return Result;
+}
+
+bool fromJSON(const llvm::json::Value &V, Notification &N, llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  if (!O || !O.map("method", N.method))
+    return false;
+  return mapRaw(V, "params", N.params, P);
+}
+
+llvm::json::Value toJSON(const ToolCapability &TC) {
+  return llvm::json::Object{{"listChanged", TC.listChanged}};
+}
+
+bool fromJSON(const llvm::json::Value &V, ToolCapability &TC,
+              llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  return O && O.map("listChanged", TC.listChanged);
+}
+
+llvm::json::Value toJSON(const Capabilities &C) {
+  return llvm::json::Object{{"tools", C.tools}};
+}
+
+bool fromJSON(const llvm::json::Value &V, Capabilities &C, llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  return O && O.map("tools", C.tools);
+}
+
+llvm::json::Value toJSON(const TextContent &TC) {
+  return llvm::json::Object{{"type", "text"}, {"text", TC.text}};
+}
+
+bool fromJSON(const llvm::json::Value &V, TextContent &TC, llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  return O && O.map("text", TC.text);
+}
+
+llvm::json::Value toJSON(const TextResult &TR) {
+  return llvm::json::Object{{"content", TR.content}, {"isError", TR.isError}};
+}
+
+bool fromJSON(const llvm::json::Value &V, TextResult &TR, llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  return O && O.map("content", TR.content) && O.map("isError", TR.isError);
+}
+
+llvm::json::Value toJSON(const ToolAnnotations &TA) {
+  llvm::json::Object Result;
+  if (TA.title)
+    Result.insert({"title", TA.title});
+  if (TA.readOnlyHint)
+    Result.insert({"readOnlyHint", TA.readOnlyHint});
+  if (TA.destructiveHint)
+    Result.insert({"destructiveHint", TA.destructiveHint});
+  if (TA.idempotentHint)
+    Result.insert({"idempotentHint", TA.idempotentHint});
+  if (TA.openWorldHint)
+    Result.insert({"openWorldHint", TA.openWorldHint});
+  return Result;
+}
+
+bool fromJSON(const llvm::json::Value &V, ToolAnnotations &TA,
+              llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  return O && O.mapOptional("title", TA.title) &&
+         O.mapOptional("readOnlyHint", TA.readOnlyHint) &&
+         O.mapOptional("destructiveHint", TA.destructiveHint) &&
+         O.mapOptional("idempotentHint", TA.idempotentHint) &&
+         O.mapOptional("openWorldHint", TA.openWorldHint);
+}
+
+llvm::json::Value toJSON(const ToolDefinition &TD) {
+  llvm::json::Object Result{{"name", TD.name}};
+  if (TD.description)
+    Result.insert({"description", TD.description});
+  if (TD.inputSchema)
+    Result.insert({"inputSchema", TD.inputSchema});
+  if (TD.annotations)
+    Result.insert({"annotations", TD.annotations});
+  return Result;
+}
+
+bool fromJSON(const llvm::json::Value &V, ToolDefinition &TD,
+              llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  if (!O || !O.map("name", TD.name) ||
+      !O.mapOptional("description", TD.description) ||
+      !O.mapOptional("annotations", TD.annotations))
+    return false;
+  return mapRaw(V, "inputSchema", TD.inputSchema, P);
+}
+
+} // namespace lldb_private::mcp::protocol
diff --git a/lldb/source/Plugins/Protocol/MCP/Protocol.h b/lldb/source/Plugins/Protocol/MCP/Protocol.h
new file mode 100644
index 0000000000000..9e52a47957b85
--- /dev/null
+++ b/lldb/source/Plugins/Protocol/MCP/Protocol.h
@@ -0,0 +1,145 @@
+//===- Protocol.h ---------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOL_H
+#define LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOL_H
+
+#include "llvm/Support/JSON.h"
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <vector>
+
+namespace lldb_private::mcp::protocol {
+
+/// Version string of the MCP protocol revision this code targets.
+inline constexpr llvm::StringLiteral kProtocolVersion = "2025-03-26";
+
+/// A JSON-RPC request.
+struct Request {
+  uint64_t id = 0;
+  std::string method;
+  std::optional<llvm::json::Value> params;
+};
+
+llvm::json::Value toJSON(const Request &);
+bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path);
+
+/// The error object of a JSON-RPC error reply.
+struct Error {
+  int64_t code = 0;
+  std::string message;
+  std::optional<std::string> data;
+};
+
+llvm::json::Value toJSON(const Error &);
+bool fromJSON(const llvm::json::Value &, Error &, llvm::json::Path);
+
+/// A JSON-RPC message made of an id and an error object.
+struct ProtocolError {
+  uint64_t id = 0;
+  Error error;
+};
+
+llvm::json::Value toJSON(const ProtocolError &);
+bool fromJSON(const llvm::json::Value &, ProtocolError &, llvm::json::Path);
+
+/// A JSON-RPC reply; "result" and "error" are each optional.
+struct Response {
+  uint64_t id = 0;
+  std::optional<llvm::json::Value> result;
+  std::optional<Error> error;
+};
+
+llvm::json::Value toJSON(const Response &);
+bool fromJSON(const llvm::json::Value &, Response &, llvm::json::Path);
+
+/// A JSON-RPC notification (a message with a method but no id).
+struct Notification {
+  std::string method;
+  std::optional<llvm::json::Value> params;
+};
+
+llvm::json::Value toJSON(const Notification &);
+bool fromJSON(const llvm::json::Value &, Notification &, llvm::json::Path);
+
+/// Tool-related capability flags.
+struct ToolCapability {
+  bool listChanged = false;
+};
+
+llvm::json::Value toJSON(const ToolCapability &);
+bool fromJSON(const llvm::json::Value &, ToolCapability &, llvm::json::Path);
+
+/// Capability set; serialized as {"tools": ...}.
+struct Capabilities {
+  ToolCapability tools;
+};
+
+llvm::json::Value toJSON(const Capabilities &);
+bool fromJSON(const llvm::json::Value &, Capabilities &, llvm::json::Path);
+
+/// A text content item; serialized with "type": "text".
+struct TextContent {
+  std::string text;
+};
+
+llvm::json::Value toJSON(const TextContent &);
+bool fromJSON(const llvm::json::Value &, TextContent &, llvm::json::Path);
+
+/// A result made of text content items plus an error flag.
+struct TextResult {
+  std::vector<TextContent> content;
+  bool isError = false;
+};
+
+llvm::json::Value toJSON(const TextResult &);
+bool fromJSON(const llvm::json::Value &, TextResult &, llvm::json::Path);
+
+/// Optional behavioral hints attached to a ToolDefinition.
+struct ToolAnnotations {
+  /// Human-readable title for the tool.
+  std::optional<std::string> title;
+
+  /// If true, the tool does not modify its environment.
+  std::optional<bool> readOnlyHint;
+
+  /// If true, the tool may perform destructive updates.
+  std::optional<bool> destructiveHint;
+
+  /// If true, repeated calls with same args have no additional effect.
+  std::optional<bool> idempotentHint;
+
+  /// If true, tool interacts with external entities.
+  std::optional<bool> openWorldHint;
+};
+
+llvm::json::Value toJSON(const ToolAnnotations &);
+bool fromJSON(const llvm::json::Value &, ToolAnnotations &, llvm::json::Path);
+
+/// Describes a tool: its name, description, input schema and annotations.
+struct ToolDefinition {
+  /// Unique identifier for the tool.
+  std::string name;
+
+  /// Human-readable description.
+  std::optional<std::string> description;
+
+  /// JSON Schema for the tool's parameters.
+  std::optional<llvm::json::Value> inputSchema;
+
+  /// Optional hints about tool behavior.
+  std::optional<ToolAnnotations> annotations;
+};
+
+llvm::json::Value toJSON(const ToolDefinition &);
+bool fromJSON(const llvm::json::Value &, ToolDefinition &, llvm::json::Path);
+
+} // namespace lldb_private::mcp::protocol
+
+#endif // LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOL_H
diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp
new file mode 100644
index 0000000000000..042c7ac8b76fa
--- /dev/null
+++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp
@@ -0,0 +1,280 @@
+//===- ProtocolServerMCP.cpp ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "ProtocolServerMCP.h" +#include "MCPError.h" +#include "lldb/Core/PluginManager.h" +#include "lldb/Utility/LLDBLog.h" +#include "lldb/Utility/Log.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/Threading.h" +#include <thread> + +using namespace lldb_private; +using namespace lldb_private::mcp; +using namespace llvm; + +LLDB_PLUGIN_DEFINE(ProtocolServerMCP) + +ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger) + : ProtocolServer(), m_debugger(debugger) { + AddHandler("initialize", std::bind(&ProtocolServerMCP::InitializeHandler, + this, std::placeholders::_1)); + AddHandler("tools/list", std::bind(&ProtocolServerMCP::ToolsListHandler, this, + std::placeholders::_1)); + AddHandler("tools/call", std::bind(&ProtocolServerMCP::ToolsCallHandler, this, + std::placeholders::_1)); + AddTool(std::make_unique<LLDBCommandTool>( + "lldb_command", "Run an lldb command.", m_debugger)); +} + +ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); } + +void ProtocolServerMCP::Initialize() { + PluginManager::RegisterPlugin(GetPluginNameStatic(), + GetPluginDescriptionStatic(), CreateInstance); +} + +void ProtocolServerMCP::Terminate() { + PluginManager::UnregisterPlugin(CreateInstance); +} + +lldb::ProtocolServerSP ProtocolServerMCP::CreateInstance(Debugger &debugger) { + return std::make_shared<ProtocolServerMCP>(debugger); +} + +llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { + return "MCP Server."; +} + +llvm::Expected<protocol::Response> +ProtocolServerMCP::Handle(protocol::Request request) { + auto it = m_handlers.find(request.method); + if (it != m_handlers.end()) + return it->second(request); + + return make_error<MCPError>( + llvm::formatv("no handler for request: {0}", request.method).str(), 1); +} + +llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection 
connection) { + std::lock_guard<std::mutex> guard(m_server_mutex); + + if (m_running) + return llvm::createStringError("server already running"); + + Status status; + m_listener = Socket::Create(connection.protocol, status); + if (status.Fail()) + return status.takeError(); + + status = m_listener->Listen(connection.name, /*backlog=*/5); + if (status.Fail()) + return status.takeError(); + + std::string address = + llvm::join(m_listener->GetListeningConnectionURI(), ", "); + Log *log = GetLog(LLDBLog::Host); + LLDB_LOG(log, "MCP server started with connection listeners: {0}", address); + + auto handles = m_listener->Accept(m_loop, [=](std::unique_ptr<Socket> sock) { + std::lock_guard<std::mutex> guard(m_server_mutex); + + const std::string client_name = + llvm::formatv("client-{0}", m_clients.size() + 1).str(); + LLDB_LOG(log, "client {0} connected", client_name); + + lldb::IOObjectSP io(std::move(sock)); + + m_clients.emplace_back(io, [=]() { + llvm::set_thread_name(client_name + "-runloop"); + if (auto Err = Run(std::make_unique<JSONRPCTransport>(io, io))) + LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(Err), "MCP Error: {0}"); + }); + }); + if (llvm::Error error = handles.takeError()) + return error; + + m_read_handles = std::move(*handles); + m_loop_thread = std::thread([=] { + llvm::set_thread_name("mcp-runloop"); + m_loop.Run(); + }); + + return llvm::Error::success(); +} + +llvm::Error ProtocolServerMCP::Stop() { + { + std::lock_guard<std::mutex> guard(m_server_mutex); + m_running = false; + } + + // Stop accepting new connections. + m_loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + + // Wait for the main loop to exit. + if (m_loop_thread.joinable()) + m_loop_thread.join(); + + // Wait for all our clients to exit. 
+ for (auto &client : m_clients) { + client.first->Close(); + if (client.second.joinable()) + client.second.join(); + } + + { + std::lock_guard<std::mutex> guard(m_server_mutex); + m_listener.reset(); + m_read_handles.clear(); + m_clients.clear(); + } + + return llvm::Error::success(); +} + +llvm::Error ProtocolServerMCP::Run(std::unique_ptr<JSONTransport> transport) { + Log *log = GetLog(LLDBLog::Host); + + while (true) { + llvm::Expected<protocol::Request> request = + transport->Read<protocol::Request>(std::chrono::seconds(1)); + if (request.errorIsA<TransportEOFError>() || + request.errorIsA<TransportInvalidError>()) { + consumeError(request.takeError()); + break; + } + + if (request.errorIsA<TransportTimeoutError>()) { + consumeError(request.takeError()); + continue; + } + + if (llvm::Error err = request.takeError()) { + LLDB_LOG_ERROR(log, std::move(err), "{0}"); + continue; + } + + protocol::Response response; + llvm::Expected<protocol::Response> maybe_response = Handle(*request); + if (!maybe_response) { + llvm::handleAllErrors( + maybe_response.takeError(), + [&](const MCPError &err) { + response.error.emplace(err.toProtcolError()); + }, + [&](const llvm::ErrorInfoBase &err) { + response.error.emplace(); + response.error->code = -1; + response.error->message = err.message(); + }); + } else { + response = *maybe_response; + } + + response.id = request->id; + + if (llvm::Error err = transport->Write(response)) + return err; + } + + return llvm::Error::success(); +} + +protocol::Capabilities ProtocolServerMCP::GetCapabilities() { + protocol::Capabilities capabilities; + capabilities.tools.listChanged = true; + return capabilities; +} + +void ProtocolServerMCP::AddTool(std::unique_ptr<Tool> tool) { + std::lock_guard<std::mutex> guard(m_server_mutex); + + if (!tool) + return; + m_tools[tool->GetName()] = std::move(tool); +} + +void ProtocolServerMCP::AddHandler( + llvm::StringRef method, + std::function<protocol::Response(const protocol::Request &)> handler) { + 
std::lock_guard<std::mutex> guard(m_server_mutex); + + m_handlers[method] = std::move(handler); +} + +protocol::Response +ProtocolServerMCP::InitializeHandler(const protocol::Request &request) { + protocol::Response response; + + std::string protocol_version = protocol::kProtocolVersion.str(); + if (request.params) { + if (const json::Object *param_obj = request.params->getAsObject()) { + if (const json::Value *val = param_obj->get("protocolVersion")) { + if (auto protocol_version_str = val->getAsString()) { + protocol_version = *protocol_version_str; + } + } + } + } + + response.result.emplace(llvm::json::Object{ + {"protocolVersion", protocol_version}, + {"capabilities", GetCapabilities()}, + {"serverInfo", + llvm::json::Object{{"name", kName}, {"version", kVersion}}}}); + return response; +} + +protocol::Response +ProtocolServerMCP::ToolsListHandler(const protocol::Request &request) { + protocol::Response response; + + llvm::json::Array tools; + for (const auto &tool : m_tools) + tools.emplace_back(toJSON(tool.second->GetDefinition())); + + response.result.emplace(llvm::json::Object{{"tools", std::move(tools)}}); + + return response; +} + +protocol::Response +ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) { + protocol::Response response; + + if (!request.params) + return response; + + const json::Object *param_obj = request.params->getAsObject(); + if (!param_obj) + return response; + + const json::Value *name = param_obj->get("name"); + if (!name) + return response; + + llvm::StringRef tool_name = name->getAsString().value_or(""); + if (tool_name.empty()) + return response; + + auto it = m_tools.find(tool_name); + if (it == m_tools.end()) + return response; + + const json::Value *args = param_obj->get("arguments"); + if (!args) + return response; + + protocol::TextResult text_result = it->second->Call(*args); + response.result.emplace(toJSON(text_result)); + + return response; +} diff --git 
a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h new file mode 100644 index 0000000000000..c194019940d2c --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -0,0 +1,77 @@ +//===- ProtocolServerMCP.h ------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOLSERVERMCP_H +#define LLDB_PLUGINS_PROTOCOL_MCP_PROTOCOLSERVERMCP_H + +#include "Protocol.h" +#include "Tool.h" +#include "lldb/Core/ProtocolServer.h" +#include "lldb/Host/JSONTransport.h" +#include "lldb/Host/MainLoop.h" +#include "lldb/Host/Socket.h" +#include "llvm/ADT/StringMap.h" +#include <thread> + +namespace lldb_private::mcp { + +class ProtocolServerMCP : public ProtocolServer { +public: + ProtocolServerMCP(Debugger &debugger); + virtual ~ProtocolServerMCP() override; + + virtual llvm::Error Start(ProtocolServer::Connection connection) override; + virtual llvm::Error Stop() override; + + void AddTool(std::unique_ptr<Tool> tool); + void AddHandler( + llvm::StringRef method, + std::function<protocol::Response(const protocol::Request &)> handler); + + static void Initialize(); + static void Terminate(); + + static llvm::StringRef GetPluginNameStatic() { return "MCP"; } + static llvm::StringRef GetPluginDescriptionStatic(); + + static lldb::ProtocolServerSP CreateInstance(Debugger &debugger); + + llvm::StringRef GetPluginName() override { return GetPluginNameStatic(); } + +private: + llvm::Error Run(std::unique_ptr<JSONTransport> transport); + llvm::Expected<protocol::Response> Handle(protocol::Request request); + + protocol::Response InitializeHandler(const protocol::Request &); + 
protocol::Response ToolsListHandler(const protocol::Request &); + protocol::Response ToolsCallHandler(const protocol::Request &); + + protocol::Capabilities GetCapabilities(); + + llvm::StringLiteral kName = "lldb-mcp"; + llvm::StringLiteral kVersion = "0.1.0"; + + Debugger &m_debugger; + + bool m_running = false; + + MainLoop m_loop; + std::thread m_loop_thread; + + std::unique_ptr<Socket> m_listener; + std::vector<MainLoopBase::ReadHandleUP> m_read_handles; + std::vector<std::pair<lldb::IOObjectSP, std::thread>> m_clients; + + std::mutex m_server_mutex; + llvm::StringMap<std::unique_ptr<Tool>> m_tools; + llvm::StringMap<std::function<protocol::Response(const protocol::Request &)>> + m_handlers; +}; +} // namespace lldb_private::mcp + +#endif diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.cpp b/lldb/source/Plugins/Protocol/MCP/Tool.cpp new file mode 100644 index 0000000000000..8a83ab907409e --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Tool.cpp @@ -0,0 +1,72 @@ +//===- Tool.cpp -----------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Tool.h" +#include "lldb/Interpreter/CommandInterpreter.h" +#include "lldb/Interpreter/CommandReturnObject.h" + +using namespace lldb_private::mcp; +using namespace llvm; + +Tool::Tool(std::string name, std::string description) + : m_name(std::move(name)), m_description(std::move(description)) {} + +protocol::ToolDefinition Tool::GetDefinition() const { + protocol::ToolDefinition definition; + definition.name = m_name; + definition.description.emplace(m_description); + + if (std::optional<llvm::json::Value> input_schema = GetSchema()) + definition.inputSchema = *input_schema; + + if (m_annotations) + definition.annotations = m_annotations; + + return definition; +} + +LLDBCommandTool::LLDBCommandTool(std::string name, std::string description, + Debugger &debugger) + : Tool(std::move(name), std::move(description)), m_debugger(debugger) {} + +protocol::TextResult LLDBCommandTool::Call(const llvm::json::Value &args) { + std::string arguments; + if (const json::Object *args_obj = args.getAsObject()) { + if (const json::Value *s = args_obj->get("arguments")) { + arguments = s->getAsString().value_or(""); + } + } + + CommandReturnObject result(/*colors=*/false); + m_debugger.GetCommandInterpreter().HandleCommand(arguments.c_str(), + eLazyBoolYes, result); + + std::string output; + llvm::StringRef output_str = result.GetOutputString(); + if (!output_str.empty()) + output += output_str.str(); + + std::string err_str = result.GetErrorString(); + if (!err_str.empty()) { + if (!output.empty()) + output += '\n'; + output += err_str; + } + + mcp::protocol::TextResult text_result; + text_result.content.emplace_back(mcp::protocol::TextContent{{output}}); + return text_result; +} + +std::optional<llvm::json::Value> LLDBCommandTool::GetSchema() const { + llvm::json::Object str_type{{"type", "string"}}; + llvm::json::Object 
properties{{"arguments", std::move(str_type)}}; + llvm::json::Object schema{{"type", "object"}, + {"properties", std::move(properties)}}; + return schema; +} diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.h b/lldb/source/Plugins/Protocol/MCP/Tool.h new file mode 100644 index 0000000000000..1b233d592397b --- /dev/null +++ b/lldb/source/Plugins/Protocol/MCP/Tool.h @@ -0,0 +1,61 @@ +//===- Tool.h -------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef LLDB_PLUGINS_PROTOCOL_MCP_TOOL_H +#define LLDB_PLUGINS_PROTOCOL_MCP_TOOL_H + +#include "Protocol.h" +#include "lldb/Core/Debugger.h" +#include "llvm/Support/JSON.h" +#include <string> + +namespace lldb_private::mcp { + +class Tool { +public: + Tool(std::string name, std::string description); + virtual ~Tool() = default; + + virtual protocol::TextResult Call(const llvm::json::Value &args) = 0; + + virtual std::optional<llvm::json::Value> GetSchema() const { + return std::nullopt; + } + + protocol::ToolDefinition GetDefinition() const; + + const std::string &GetName() { return m_name; } + +protected: + void SetAnnotations(protocol::ToolAnnotations annotations) { + m_annotations.emplace(std::move(annotations)); + } + +private: + std::string m_name; + std::string m_description; + std::optional<protocol::ToolAnnotations> m_annotations; +}; + +class LLDBCommandTool : public mcp::Tool { +public: + LLDBCommandTool(std::string name, std::string description, + Debugger &debugger); + ~LLDBCommandTool() = default; + + virtual mcp::protocol::TextResult + Call(const llvm::json::Value &args) override; + + virtual std::optional<llvm::json::Value> GetSchema() const override; + +private: + Debugger &m_debugger; +}; +} // 
namespace lldb_private::mcp + +#endif _______________________________________________ lldb-commits mailing list lldb-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits