https://github.com/JDevlieghere updated https://github.com/llvm/llvm-project/pull/145616
>From 5ed60a3aa5022694a593e2885ad6e563df6ffa37 Mon Sep 17 00:00:00 2001 From: Jonas Devlieghere <jo...@devlieghere.com> Date: Tue, 24 Jun 2025 16:22:46 -0700 Subject: [PATCH 1/3] [lldb] Make MCP server instance global Rather than having one MCP server per debugger, make the MCP server global and pass a debugger id along with tool invocations that require one. This PR also adds a second tool to list the available debuggers with their targets so the model can decide which debugger instance to use. --- lldb/include/lldb/Core/Debugger.h | 6 -- lldb/include/lldb/Core/ProtocolServer.h | 5 +- lldb/include/lldb/lldb-forward.h | 2 +- lldb/include/lldb/lldb-private-interfaces.h | 3 +- .../Commands/CommandObjectProtocolServer.cpp | 51 +++---------- lldb/source/Core/Debugger.cpp | 23 ------ lldb/source/Core/ProtocolServer.cpp | 34 ++++++++- .../Protocol/MCP/ProtocolServerMCP.cpp | 26 +++---- .../Plugins/Protocol/MCP/ProtocolServerMCP.h | 6 +- lldb/source/Plugins/Protocol/MCP/Tool.cpp | 74 +++++++++++++++---- lldb/source/Plugins/Protocol/MCP/Tool.h | 24 +++--- .../Protocol/ProtocolMCPServerTest.cpp | 20 ++--- 12 files changed, 143 insertions(+), 131 deletions(-) diff --git a/lldb/include/lldb/Core/Debugger.h b/lldb/include/lldb/Core/Debugger.h index 9f82466a83417..2087ef2a11562 100644 --- a/lldb/include/lldb/Core/Debugger.h +++ b/lldb/include/lldb/Core/Debugger.h @@ -602,10 +602,6 @@ class Debugger : public std::enable_shared_from_this<Debugger>, void FlushProcessOutput(Process &process, bool flush_stdout, bool flush_stderr); - void AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp); - void RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp); - lldb::ProtocolServerSP GetProtocolServer(llvm::StringRef protocol) const; - SourceManager::SourceFileCache &GetSourceFileCache() { return m_source_file_cache; } @@ -776,8 +772,6 @@ class Debugger : public std::enable_shared_from_this<Debugger>, mutable std::mutex m_progress_reports_mutex; /// @} - llvm::SmallVector<lldb::ProtocolServerSP> m_protocol_servers; - std::mutex m_destroy_callback_mutex; lldb::callback_token_t m_destroy_callback_next_token = 0; struct DestroyCallbackInfo { diff --git a/lldb/include/lldb/Core/ProtocolServer.h b/lldb/include/lldb/Core/ProtocolServer.h index fafe460904323..937256c10aec1 100644 --- a/lldb/include/lldb/Core/ProtocolServer.h +++ b/lldb/include/lldb/Core/ProtocolServer.h @@ -20,8 +20,9 @@ class ProtocolServer : public PluginInterface { ProtocolServer() = default; virtual ~ProtocolServer() = default; - static lldb::ProtocolServerSP Create(llvm::StringRef name, - Debugger &debugger); + static ProtocolServer *GetOrCreate(llvm::StringRef name); + + static std::vector<llvm::StringRef> GetSupportedProtocols(); struct Connection { Socket::SocketProtocol protocol; diff --git a/lldb/include/lldb/lldb-forward.h b/lldb/include/lldb/lldb-forward.h index 558818e8e2309..2bc85a2d2afa6 100644 --- a/lldb/include/lldb/lldb-forward.h +++ b/lldb/include/lldb/lldb-forward.h @@ -391,7 +391,7 @@ typedef std::shared_ptr<lldb_private::Platform> PlatformSP; typedef std::shared_ptr<lldb_private::Process> ProcessSP; typedef std::shared_ptr<lldb_private::ProcessAttachInfo> ProcessAttachInfoSP; typedef std::shared_ptr<lldb_private::ProcessLaunchInfo> ProcessLaunchInfoSP; -typedef std::shared_ptr<lldb_private::ProtocolServer> ProtocolServerSP; +typedef std::unique_ptr<lldb_private::ProtocolServer> ProtocolServerUP; typedef std::weak_ptr<lldb_private::Process> ProcessWP; typedef std::shared_ptr<lldb_private::RegisterCheckpoint> RegisterCheckpointSP; typedef std::shared_ptr<lldb_private::RegisterContext> RegisterContextSP; diff --git a/lldb/include/lldb/lldb-private-interfaces.h b/lldb/include/lldb/lldb-private-interfaces.h index 34eaaa8e581e9..249b25c251ac2 100644 --- a/lldb/include/lldb/lldb-private-interfaces.h +++ b/lldb/include/lldb/lldb-private-interfaces.h @@ -81,8 +81,7 @@ typedef lldb::PlatformSP (*PlatformCreateInstance)(bool force, typedef lldb::ProcessSP (*ProcessCreateInstance)( lldb::TargetSP target_sp, lldb::ListenerSP listener_sp, const FileSpec *crash_file_path, bool can_connect); -typedef lldb::ProtocolServerSP (*ProtocolServerCreateInstance)( - Debugger &debugger); +typedef lldb::ProtocolServerUP (*ProtocolServerCreateInstance)(); typedef lldb::RegisterTypeBuilderSP (*RegisterTypeBuilderCreateInstance)( Target &target); typedef lldb::ScriptInterpreterSP (*ScriptInterpreterCreateInstance)( diff --git a/lldb/source/Commands/CommandObjectProtocolServer.cpp b/lldb/source/Commands/CommandObjectProtocolServer.cpp index 115754769f3e3..55bd42ed1a533 100644 --- a/lldb/source/Commands/CommandObjectProtocolServer.cpp +++ b/lldb/source/Commands/CommandObjectProtocolServer.cpp @@ -23,20 +23,6 @@ using namespace lldb_private; #define LLDB_OPTIONS_mcp #include "CommandOptions.inc" -static std::vector<llvm::StringRef> GetSupportedProtocols() { - std::vector<llvm::StringRef> supported_protocols; - size_t i = 0; - - for (llvm::StringRef protocol_name = - PluginManager::GetProtocolServerPluginNameAtIndex(i++); - !protocol_name.empty(); - protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) { - supported_protocols.push_back(protocol_name); - } - - return supported_protocols; -} - class CommandObjectProtocolServerStart : public CommandObjectParsed { public: CommandObjectProtocolServerStart(CommandInterpreter &interpreter) @@ -57,12 +43,11 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed { } llvm::StringRef protocol = args.GetArgumentAtIndex(0); - std::vector<llvm::StringRef> supported_protocols = GetSupportedProtocols(); - if (llvm::find(supported_protocols, protocol) == - supported_protocols.end()) { + ProtocolServer *server = ProtocolServer::GetOrCreate(protocol); + if (!server) { result.AppendErrorWithFormatv( "unsupported protocol: {0}. Supported protocols are: {1}", protocol, - llvm::join(GetSupportedProtocols(), ", ")); + llvm::join(ProtocolServer::GetSupportedProtocols(), ", ")); return; } @@ -72,10 +57,6 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed { } llvm::StringRef connection_uri = args.GetArgumentAtIndex(1); - ProtocolServerSP server_sp = GetDebugger().GetProtocolServer(protocol); - if (!server_sp) - server_sp = ProtocolServer::Create(protocol, GetDebugger()); - const char *connection_error = "unsupported connection specifier, expected 'accept:///path' or " "'listen://[host]:port', got '{0}'."; @@ -98,14 +79,12 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed { formatv("[{0}]:{1}", uri->hostname.empty() ? "0.0.0.0" : uri->hostname, uri->port.value_or(0)); - if (llvm::Error error = server_sp->Start(connection)) { + if (llvm::Error error = server->Start(connection)) { result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error))); return; } - GetDebugger().AddProtocolServer(server_sp); - - if (Socket *socket = server_sp->GetSocket()) { + if (Socket *socket = server->GetSocket()) { std::string address = llvm::join(socket->GetListeningConnectionURI(), ", "); result.AppendMessageWithFormatv( @@ -134,30 +113,18 @@ class CommandObjectProtocolServerStop : public CommandObjectParsed { } llvm::StringRef protocol = args.GetArgumentAtIndex(0); - std::vector<llvm::StringRef> supported_protocols = GetSupportedProtocols(); - if (llvm::find(supported_protocols, protocol) == - supported_protocols.end()) { + ProtocolServer *server = ProtocolServer::GetOrCreate(protocol); + if (!server) { result.AppendErrorWithFormatv( "unsupported protocol: {0}. Supported protocols are: {1}", protocol, - llvm::join(GetSupportedProtocols(), ", ")); + llvm::join(ProtocolServer::GetSupportedProtocols(), ", ")); return; } - Debugger &debugger = GetDebugger(); - - ProtocolServerSP server_sp = debugger.GetProtocolServer(protocol); - if (!server_sp) { - result.AppendError( - llvm::formatv("no {0} protocol server running", protocol).str()); - return; - } - - if (llvm::Error error = server_sp->Stop()) { + if (llvm::Error error = server->Stop()) { result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error))); return; } - - debugger.RemoveProtocolServer(server_sp); } }; diff --git a/lldb/source/Core/Debugger.cpp b/lldb/source/Core/Debugger.cpp index 33d1053fd8a65..445baf1f63785 100644 --- a/lldb/source/Core/Debugger.cpp +++ b/lldb/source/Core/Debugger.cpp @@ -2376,26 +2376,3 @@ llvm::ThreadPoolInterface &Debugger::GetThreadPool() { "Debugger::GetThreadPool called before Debugger::Initialize"); return *g_thread_pool; } - -void Debugger::AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp) { - assert(protocol_server_sp && - GetProtocolServer(protocol_server_sp->GetPluginName()) == nullptr); - m_protocol_servers.push_back(protocol_server_sp); -} - -void Debugger::RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp) { - auto it = llvm::find(m_protocol_servers, protocol_server_sp); - if (it != m_protocol_servers.end()) - m_protocol_servers.erase(it); -} - -lldb::ProtocolServerSP -Debugger::GetProtocolServer(llvm::StringRef protocol) const { - for (ProtocolServerSP protocol_server_sp : m_protocol_servers) { - if (!protocol_server_sp) - continue; - if (protocol_server_sp->GetPluginName() == protocol) - return protocol_server_sp; - } - return nullptr; -} diff --git a/lldb/source/Core/ProtocolServer.cpp b/lldb/source/Core/ProtocolServer.cpp index d57a047afa7b2..41636cdacdecc 100644 --- a/lldb/source/Core/ProtocolServer.cpp +++ b/lldb/source/Core/ProtocolServer.cpp @@ -12,10 +12,36 @@ using namespace lldb_private; using namespace lldb; -ProtocolServerSP ProtocolServer::Create(llvm::StringRef name, - Debugger &debugger) { +ProtocolServer *ProtocolServer::GetOrCreate(llvm::StringRef name) { + static std::mutex g_mutex; + static llvm::StringMap<ProtocolServerUP> g_protocol_server_instances; + + std::lock_guard<std::mutex> guard(g_mutex); + + auto it = g_protocol_server_instances.find(name); + if (it != g_protocol_server_instances.end()) + return it->second.get(); + if (ProtocolServerCreateInstance create_callback = - PluginManager::GetProtocolCreateCallbackForPluginName(name)) - return create_callback(debugger); + PluginManager::GetProtocolCreateCallbackForPluginName(name)) { + auto pair = + g_protocol_server_instances.try_emplace(name, create_callback()); + return pair.first->second.get(); + } + return nullptr; } + +std::vector<llvm::StringRef> ProtocolServer::GetSupportedProtocols() { + std::vector<llvm::StringRef> supported_protocols; + size_t i = 0; + + for (llvm::StringRef protocol_name = + PluginManager::GetProtocolServerPluginNameAtIndex(i++); + !protocol_name.empty(); + protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) { + supported_protocols.push_back(protocol_name); + } + + return supported_protocols; +} diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index c3cd9a88c20bf..fcc1343b150f5 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -24,8 +24,7 @@ LLDB_PLUGIN_DEFINE(ProtocolServerMCP) static constexpr size_t kChunkSize = 1024; -ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger) - : ProtocolServer(), m_debugger(debugger) { +ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() { AddRequestHandler("initialize", std::bind(&ProtocolServerMCP::InitializeHandler, this, std::placeholders::_1)); @@ -39,8 +38,10 @@ ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger) "notifications/initialized", [](const protocol::Notification &) { LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete"); }); - AddTool(std::make_unique<LLDBCommandTool>( - "lldb_command", "Run an lldb command.", m_debugger)); + AddTool( + std::make_unique<CommandTool>("lldb_command", "Run an lldb command.")); + AddTool(std::make_unique<DebuggerListTool>( + "lldb_debugger_list", "List debugger instances with their debugger_id.")); } ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); } @@ -54,8 +55,8 @@ void ProtocolServerMCP::Terminate() { PluginManager::UnregisterPlugin(CreateInstance); } -lldb::ProtocolServerSP ProtocolServerMCP::CreateInstance(Debugger &debugger) { - return std::make_shared<ProtocolServerMCP>(debugger); +lldb::ProtocolServerUP ProtocolServerMCP::CreateInstance() { + return std::make_unique<ProtocolServerMCP>(); } llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { @@ -145,7 +146,7 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { std::lock_guard<std::mutex> guard(m_server_mutex); if (m_running) - return llvm::createStringError("server already running"); + return llvm::createStringError("the MCP server is already running"); Status status; m_listener = Socket::Create(connection.protocol, status); @@ -162,10 +163,10 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { if (llvm::Error error = handles.takeError()) return error; + m_running = true; m_listen_handlers = std::move(*handles); m_loop_thread = std::thread([=] { - llvm::set_thread_name( - llvm::formatv("debugger-{0}.mcp.runloop", m_debugger.GetID())); + llvm::set_thread_name("protocol-server.mcp"); m_loop.Run(); }); @@ -175,6 +176,8 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { llvm::Error ProtocolServerMCP::Stop() { { std::lock_guard<std::mutex> guard(m_server_mutex); + if (!m_running) + return createStringError("the MCP sever is not running"); m_running = false; } @@ -312,10 +315,7 @@ ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) { return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name)); const json::Value *args = param_obj->get("arguments"); - if (!args) - return llvm::createStringError("no tool arguments"); - - llvm::Expected<protocol::TextResult> text_result = it->second->Call(*args); + llvm::Expected<protocol::TextResult> text_result = it->second->Call(args); if (!text_result) return text_result.takeError(); diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index 52bb92a04a802..d55882cc8ab09 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -21,7 +21,7 @@ namespace lldb_private::mcp { class ProtocolServerMCP : public ProtocolServer { public: - ProtocolServerMCP(Debugger &debugger); + ProtocolServerMCP(); virtual ~ProtocolServerMCP() override; virtual llvm::Error Start(ProtocolServer::Connection connection) override; @@ -33,7 +33,7 @@ class ProtocolServerMCP : public ProtocolServer { static llvm::StringRef GetPluginNameStatic() { return "MCP"; } static llvm::StringRef GetPluginDescriptionStatic(); - static lldb::ProtocolServerSP CreateInstance(Debugger &debugger); + static lldb::ProtocolServerUP CreateInstance(); llvm::StringRef GetPluginName() override { return GetPluginNameStatic(); } @@ -71,8 +71,6 @@ class ProtocolServerMCP : public ProtocolServer { llvm::StringLiteral kName = "lldb-mcp"; llvm::StringLiteral kVersion = "0.1.0"; - Debugger &m_debugger; - bool m_running = false; MainLoop m_loop; diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.cpp b/lldb/source/Plugins/Protocol/MCP/Tool.cpp index de8fcc8f3cb4c..6903a4a160461 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Tool.cpp @@ -7,20 +7,23 @@ //===----------------------------------------------------------------------===// #include "Tool.h" +#include "lldb/Core/Module.h" #include "lldb/Interpreter/CommandInterpreter.h" #include "lldb/Interpreter/CommandReturnObject.h" using namespace lldb_private::mcp; using namespace llvm; -struct LLDBCommandToolArguments { +struct CommandToolArguments { + uint64_t debugger_id; std::string arguments; }; -bool fromJSON(const llvm::json::Value &V, LLDBCommandToolArguments &A, +bool fromJSON(const llvm::json::Value &V, CommandToolArguments &A, llvm::json::Path P) { llvm::json::ObjectMapper O(V, P); - return O && O.map("arguments", A.arguments); + return O && O.map("debugger_id", A.debugger_id) && + O.mapOptional("arguments", A.arguments); } Tool::Tool(std::string name, std::string description) @@ -37,22 +40,27 @@ protocol::ToolDefinition Tool::GetDefinition() const { return definition; } -LLDBCommandTool::LLDBCommandTool(std::string name, std::string description, - Debugger &debugger) - : Tool(std::move(name), std::move(description)), m_debugger(debugger) {} - llvm::Expected<protocol::TextResult> -LLDBCommandTool::Call(const llvm::json::Value &args) { +CommandTool::Call(const llvm::json::Value *args) { + if (!args) + return createStringError("no tool arguments"); + llvm::json::Path::Root root; - LLDBCommandToolArguments arguments; - if (!fromJSON(args, arguments, root)) + CommandToolArguments arguments; + if (!fromJSON(*args, arguments, root)) return root.getError(); + lldb::DebuggerSP debugger_sp = + Debugger::GetDebuggerAtIndex(arguments.debugger_id); + if (!debugger_sp) + return createStringError( + llvm::formatv("no debugger with id {0}", arguments.debugger_id)); + // FIXME: Disallow certain commands and their aliases. CommandReturnObject result(/*colors=*/false); - m_debugger.GetCommandInterpreter().HandleCommand(arguments.arguments.c_str(), - eLazyBoolYes, result); + debugger_sp->GetCommandInterpreter().HandleCommand( + arguments.arguments.c_str(), eLazyBoolYes, result); std::string output; llvm::StringRef output_str = result.GetOutputString(); @@ -72,10 +80,46 @@ LLDBCommandTool::Call(const llvm::json::Value &args) { return text_result; } -std::optional<llvm::json::Value> LLDBCommandTool::GetSchema() const { +std::optional<llvm::json::Value> CommandTool::GetSchema() const { + llvm::json::Object id_type{{"type", "number"}}; llvm::json::Object str_type{{"type", "string"}}; - llvm::json::Object properties{{"arguments", std::move(str_type)}}; + llvm::json::Object properties{{"debugger_id", std::move(id_type)}, + {"arguments", std::move(str_type)}}; + llvm::json::Array required{"debugger_id"}; llvm::json::Object schema{{"type", "object"}, - {"properties", std::move(properties)}}; + {"properties", std::move(properties)}, + {"required", std::move(required)}}; return schema; } + +llvm::Expected<protocol::TextResult> +DebuggerListTool::Call(const llvm::json::Value *args) { + llvm::json::Path::Root root; + + std::string output; + llvm::raw_string_ostream os(output); + + const size_t num_debuggers = Debugger::GetNumDebuggers(); + for (size_t i = 0; i < num_debuggers; ++i) { + lldb::DebuggerSP debugger_sp = Debugger::GetDebuggerAtIndex(i); + if (!debugger_sp) + continue; + + os << "- debugger " << i << '\n'; + + const TargetList &target_list = debugger_sp->GetTargetList(); + const size_t num_targets = target_list.GetNumTargets(); + for (size_t j = 0; j < num_targets; ++j) { + lldb::TargetSP target_sp = target_list.GetTargetAtIndex(i); + if (!target_sp) + continue; + os << " - target " << j; + if (Module *exe_module = target_sp->GetExecutableModulePointer()) + os << " " << exe_module->GetFileSpec().GetPath(); + } + } + + mcp::protocol::TextResult text_result; + text_result.content.emplace_back(mcp::protocol::TextContent{{output}}); + return text_result; +} diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.h b/lldb/source/Plugins/Protocol/MCP/Tool.h index 57a5125813b76..6ca987db30cac 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.h +++ b/lldb/source/Plugins/Protocol/MCP/Tool.h @@ -22,10 +22,10 @@ class Tool { virtual ~Tool() = default; virtual llvm::Expected<protocol::TextResult> - Call(const llvm::json::Value &args) = 0; + Call(const llvm::json::Value *args) = 0; virtual std::optional<llvm::json::Value> GetSchema() const { - return std::nullopt; + return llvm::json::Object{{"type", "object"}}; } protocol::ToolDefinition GetDefinition() const; @@ -37,20 +37,26 @@ class Tool { std::string m_description; }; -class LLDBCommandTool : public mcp::Tool { +class CommandTool : public mcp::Tool { public: - LLDBCommandTool(std::string name, std::string description, - Debugger &debugger); - ~LLDBCommandTool() = default; + using mcp::Tool::Tool; + ~CommandTool() = default; virtual llvm::Expected<protocol::TextResult> - Call(const llvm::json::Value &args) override; + Call(const llvm::json::Value *args) override; virtual std::optional<llvm::json::Value> GetSchema() const override; +}; -private: - Debugger &m_debugger; +class DebuggerListTool : public mcp::Tool { +public: + using mcp::Tool::Tool; + ~DebuggerListTool() = default; + + virtual llvm::Expected<protocol::TextResult> + Call(const llvm::json::Value *args) override; }; + } // namespace lldb_private::mcp #endif diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp index 72b8c7b1fd825..5634718c67cbe 100644 --- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -46,9 +46,9 @@ class TestTool : public mcp::Tool { using mcp::Tool::Tool; virtual llvm::Expected<mcp::protocol::TextResult> - Call(const llvm::json::Value &args) override { + Call(const llvm::json::Value *args) override { std::string argument; - if (const json::Object *args_obj = args.getAsObject()) { + if (const json::Object *args_obj = args->getAsObject()) { if (const json::Value *s = args_obj->get("arguments")) { argument = s->getAsString().value_or(""); } @@ -66,7 +66,7 @@ class ErrorTool : public mcp::Tool { using mcp::Tool::Tool; virtual llvm::Expected<mcp::protocol::TextResult> - Call(const llvm::json::Value &args) override { + Call(const llvm::json::Value *args) override { return llvm::createStringError("error"); } }; @@ -77,7 +77,7 @@ class FailTool : public mcp::Tool { using mcp::Tool::Tool; virtual llvm::Expected<mcp::protocol::TextResult> - Call(const llvm::json::Value &args) override { + Call(const llvm::json::Value *args) override { mcp::protocol::TextResult text_result; text_result.content.emplace_back(mcp::protocol::TextContent{{"failed"}}); text_result.isError = true; @@ -115,7 +115,7 @@ class ProtocolServerMCPTest : public ::testing::Test { ProtocolServer::Connection connection; connection.protocol = Socket::SocketProtocol::ProtocolTcp; connection.name = llvm::formatv("{0}:0", k_localhost).str(); - m_server_up = std::make_unique<TestProtocolServerMCP>(*m_debugger_sp); + m_server_up = std::make_unique<TestProtocolServerMCP>(); m_server_up->AddTool(std::make_unique<TestTool>("test", "test tool")); ASSERT_THAT_ERROR(m_server_up->Start(connection), llvm::Succeeded()); @@ -145,7 +145,7 @@ class ProtocolServerMCPTest : public ::testing::Test { TEST_F(ProtocolServerMCPTest, Intialization) { llvm::StringLiteral request = - R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"claude-ai","version":"0.1.0"}},"jsonrpc":"2.0","id":0})json"; + R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"lldb-unit","version":"0.1.0"}},"jsonrpc":"2.0","id":0})json"; llvm::StringLiteral response = R"json({"jsonrpc":"2.0","id":0,"result":{"capabilities":{"tools":{"listChanged":true}},"protocolVersion":"2024-11-05","serverInfo":{"name":"lldb-mcp","version":"0.1.0"}}})json"; @@ -167,7 +167,7 @@ TEST_F(ProtocolServerMCPTest, ToolsList) { llvm::StringLiteral request = R"json({"method":"tools/list","params":{},"jsonrpc":"2.0","id":1})json"; llvm::StringLiteral response = - R"json({"id":1,"jsonrpc":"2.0","result":{"tools":[{"description":"test tool","name":"test"},{"description":"Run an lldb command.","inputSchema":{"properties":{"arguments":{"type":"string"}},"type":"object"},"name":"lldb_command"}]}})json"; + R"json( {"id":1,"jsonrpc":"2.0","result":{"tools":[{"description":"test tool","inputSchema":{"type":"object"},"name":"test"},{"description":"List debugger instances with their debugger_id.","inputSchema":{"type":"object"},"name":"lldb_debugger_list"},{"description":"Run an lldb command.","inputSchema":{"properties":{"arguments":{"type":"string"},"debugger_id":{"type":"number"}},"required":["debugger_id"],"type":"object"},"name":"lldb_command"}]}})json"; ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); @@ -205,7 +205,7 @@ TEST_F(ProtocolServerMCPTest, ResourcesList) { TEST_F(ProtocolServerMCPTest, ToolsCall) { llvm::StringLiteral request = - R"json({"method":"tools/call","params":{"name":"test","arguments":{"arguments":"foo"}},"jsonrpc":"2.0","id":11})json"; + R"json({"method":"tools/call","params":{"name":"test","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; llvm::StringLiteral response = R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"foo","type":"text"}],"isError":false}})json"; @@ -227,7 +227,7 @@ TEST_F(ProtocolServerMCPTest, ToolsCallError) { m_server_up->AddTool(std::make_unique<ErrorTool>("error", "error tool")); llvm::StringLiteral request = - R"json({"method":"tools/call","params":{"name":"error","arguments":{"arguments":"foo"}},"jsonrpc":"2.0","id":11})json"; + R"json({"method":"tools/call","params":{"name":"error","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; llvm::StringLiteral response = R"json({"error":{"code":-1,"message":"error"},"id":11,"jsonrpc":"2.0"})json"; @@ -249,7 +249,7 @@ TEST_F(ProtocolServerMCPTest, ToolsCallFail) { m_server_up->AddTool(std::make_unique<FailTool>("fail", "fail tool")); llvm::StringLiteral request = - R"json({"method":"tools/call","params":{"name":"fail","arguments":{"arguments":"foo"}},"jsonrpc":"2.0","id":11})json"; + R"json({"method":"tools/call","params":{"name":"fail","arguments":{"arguments":"foo","debugger_id":0}},"jsonrpc":"2.0","id":11})json"; llvm::StringLiteral response = R"json({"id":11,"jsonrpc":"2.0","result":{"content":[{"text":"failed","type":"text"}],"isError":true}})json"; >From 33bba7ea8bb7eeb005527b1ec49e6f23b6981434 Mon Sep 17 00:00:00 2001 From: Jonas Devlieghere <jo...@devlieghere.com> Date: Wed, 25 Jun 2025 09:59:54 -0700 Subject: [PATCH 2/3] Address John's feedback --- lldb/source/Plugins/Protocol/MCP/Protocol.h | 2 ++ .../Protocol/MCP/ProtocolServerMCP.cpp | 8 +++-- lldb/source/Plugins/Protocol/MCP/Tool.cpp | 30 +++++++++++++++---- lldb/source/Plugins/Protocol/MCP/Tool.h | 6 ++-- 4 files changed, 35 insertions(+), 11 deletions(-) diff --git a/lldb/source/Plugins/Protocol/MCP/Protocol.h b/lldb/source/Plugins/Protocol/MCP/Protocol.h index e315899406573..cb790dc4e5596 100644 --- a/lldb/source/Plugins/Protocol/MCP/Protocol.h +++ b/lldb/source/Plugins/Protocol/MCP/Protocol.h @@ -123,6 +123,8 @@ using Message = std::variant<Request, Response, Notification, Error>; bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path); llvm::json::Value toJSON(const Message &); +using ToolArguments = std::variant<std::monostate, llvm::json::Value>; + } // namespace lldb_private::mcp::protocol #endif diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index fcc1343b150f5..3180341b50b91 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -314,8 +314,12 @@ ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) { if (it == m_tools.end()) return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name)); - const json::Value *args = param_obj->get("arguments"); - llvm::Expected<protocol::TextResult> text_result = it->second->Call(args); + protocol::ToolArguments tool_args; + if (const json::Value *args = param_obj->get("arguments")) + tool_args = *args; + + llvm::Expected<protocol::TextResult> text_result = + it->second->Call(tool_args); if (!text_result) return text_result.takeError(); diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.cpp b/lldb/source/Plugins/Protocol/MCP/Tool.cpp index 6903a4a160461..181ecd6cb8fb8 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Tool.cpp @@ -14,6 +14,7 @@ using namespace lldb_private::mcp; using namespace llvm; +namespace { struct CommandToolArguments { uint64_t debugger_id; std::string arguments; @@ -26,6 +27,8 @@ bool fromJSON(const llvm::json::Value &V, CommandToolArguments &A, O.mapOptional("arguments", A.arguments); } +} // namespace + Tool::Tool(std::string name, std::string description) : m_name(std::move(name)), m_description(std::move(description)) {} @@ -41,14 +44,14 @@ protocol::ToolDefinition Tool::GetDefinition() const { } llvm::Expected<protocol::TextResult> -CommandTool::Call(const llvm::json::Value *args) { - if (!args) - return createStringError("no tool arguments"); +CommandTool::Call(const protocol::ToolArguments &args) { + if (!std::holds_alternative<json::Value>(args)) + return createStringError("CommandTool requires arguments"); - llvm::json::Path::Root root; + json::Path::Root root; CommandToolArguments arguments; - if (!fromJSON(*args, arguments, root)) + if (!fromJSON(std::get<json::Value>(args), arguments, root)) return root.getError(); lldb::DebuggerSP debugger_sp = @@ -93,9 +96,22 @@ std::optional<llvm::json::Value> CommandTool::GetSchema() const { } llvm::Expected<protocol::TextResult> -DebuggerListTool::Call(const llvm::json::Value *args) { +DebuggerListTool::Call(const protocol::ToolArguments &args) { + if (!std::holds_alternative<std::monostate>(args)) + return createStringError("DebuggerListTool takes no arguments"); + llvm::json::Path::Root root; + // Return a nested Markdown list with debuggers and target. + // Example output: + // + // - debugger 0 + // - target 0 /path/to/foo + // - target 1 + // - debugger 1 + // - target 0 /path/to/bar + // + // FIXME: Use Structured Content when we adopt protocol version 2025-06-18. std::string output; llvm::raw_string_ostream os(output); @@ -114,8 +130,10 @@ DebuggerListTool::Call(const llvm::json::Value *args) { if (!target_sp) continue; os << " - target " << j; + // Append the module path if we have one. if (Module *exe_module = target_sp->GetExecutableModulePointer()) os << " " << exe_module->GetFileSpec().GetPath(); + os << '\n'; } } diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.h b/lldb/source/Plugins/Protocol/MCP/Tool.h index 6ca987db30cac..74ab04b472522 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.h +++ b/lldb/source/Plugins/Protocol/MCP/Tool.h @@ -22,7 +22,7 @@ class Tool { virtual ~Tool() = default; virtual llvm::Expected<protocol::TextResult> - Call(const llvm::json::Value *args) = 0; + Call(const protocol::ToolArguments &args) = 0; virtual std::optional<llvm::json::Value> GetSchema() const { return llvm::json::Object{{"type", "object"}}; @@ -43,7 +43,7 @@ class CommandTool : public mcp::Tool { ~CommandTool() = default; virtual llvm::Expected<protocol::TextResult> - Call(const llvm::json::Value *args) override; + Call(const protocol::ToolArguments &args) override; virtual std::optional<llvm::json::Value> GetSchema() const override; }; @@ -54,7 +54,7 @@ class DebuggerListTool : public mcp::Tool { ~DebuggerListTool() = default; virtual llvm::Expected<protocol::TextResult> - Call(const llvm::json::Value *args) override; + Call(const protocol::ToolArguments &args) override; }; } // namespace lldb_private::mcp >From a3bfe3854ed3c18f09d955d583d0dbf0576f1c7a Mon Sep 17 00:00:00 2001 From: Jonas Devlieghere <jo...@devlieghere.com> Date: Wed, 25 Jun 2025 10:30:46 -0700 Subject: [PATCH 3/3] Fix small bug, add text result helper --- lldb/source/Plugins/Protocol/MCP/Tool.cpp | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/lldb/source/Plugins/Protocol/MCP/Tool.cpp b/lldb/source/Plugins/Protocol/MCP/Tool.cpp index 181ecd6cb8fb8..7445e3552724f 100644 --- a/lldb/source/Plugins/Protocol/MCP/Tool.cpp +++ b/lldb/source/Plugins/Protocol/MCP/Tool.cpp @@ -27,6 +27,16 @@ bool fromJSON(const llvm::json::Value &V, CommandToolArguments &A, O.mapOptional("arguments", A.arguments); } +/// Helper function to create a TextResult from a string output. +static lldb_private::mcp::protocol::TextResult +createTextResult(std::string output, bool is_error = false) { + lldb_private::mcp::protocol::TextResult text_result; + text_result.content.emplace_back( + lldb_private::mcp::protocol::TextContent{{std::move(output)}}); + text_result.isError = is_error; + return text_result; +} + } // namespace Tool::Tool(std::string name, std::string description) @@ -77,10 +87,7 @@ CommandTool::Call(const protocol::ToolArguments &args) { output += err_str; } - mcp::protocol::TextResult text_result; - text_result.content.emplace_back(mcp::protocol::TextContent{{output}}); - text_result.isError = !result.Succeeded(); - return text_result; + return createTextResult(output, !result.Succeeded()); } std::optional<llvm::json::Value> CommandTool::GetSchema() const { @@ -126,7 +133,7 @@ DebuggerListTool::Call(const protocol::ToolArguments &args) { const TargetList &target_list = debugger_sp->GetTargetList(); const size_t num_targets = target_list.GetNumTargets(); for (size_t j = 0; j < num_targets; ++j) { - lldb::TargetSP target_sp = target_list.GetTargetAtIndex(i); + lldb::TargetSP target_sp = target_list.GetTargetAtIndex(j); if (!target_sp) continue; os << " - target " << j; @@ -137,7 +144,5 @@ DebuggerListTool::Call(const protocol::ToolArguments &args) { } } - mcp::protocol::TextResult text_result; - text_result.content.emplace_back(mcp::protocol::TextContent{{output}}); - return text_result; + return createTextResult(output); } _______________________________________________ lldb-commits mailing list lldb-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits