https://github.com/JDevlieghere updated https://github.com/llvm/llvm-project/pull/143628
>From 4b0e386016e3ee0d74a673e4f857b5537309614f Mon Sep 17 00:00:00 2001
From: Jonas Devlieghere <jo...@devlieghere.com>
Date: Wed, 4 Jun 2025 23:56:57 -0700
Subject: [PATCH] [lldb] Add MCP support to LLDB (PoC)

---
 lldb/include/lldb/MCP/MCPError.h              |  33 +++
 lldb/include/lldb/MCP/Protocol.h              | 131 +++++++++
 lldb/include/lldb/MCP/Server.h                |  67 +++++
 lldb/include/lldb/MCP/Tool.h                  |  45 ++++
 lldb/include/lldb/MCP/Transport.h             | 100 +++++++
 lldb/source/CMakeLists.txt                    |   1 +
 lldb/source/Commands/CMakeLists.txt           |   2 +
 lldb/source/Commands/CommandObjectMCP.cpp     | 162 +++++++++++
 lldb/source/Commands/CommandObjectMCP.h       |  24 ++
 .../source/Interpreter/CommandInterpreter.cpp |   2 +
 lldb/source/MCP/CMakeLists.txt                |  14 +
 lldb/source/MCP/MCPError.cpp                  |  31 +++
 lldb/source/MCP/Protocol.cpp                  | 180 +++++++++++++
 lldb/source/MCP/Server.cpp                    | 251 ++++++++++++++++++
 lldb/source/MCP/Tool.cpp                      |  28 ++
 lldb/source/MCP/Transport.cpp                 | 109 ++++++++
 16 files changed, 1180 insertions(+)
 create mode 100644 lldb/include/lldb/MCP/MCPError.h
 create mode 100644 lldb/include/lldb/MCP/Protocol.h
 create mode 100644 lldb/include/lldb/MCP/Server.h
 create mode 100644 lldb/include/lldb/MCP/Tool.h
 create mode 100644 lldb/include/lldb/MCP/Transport.h
 create mode 100644 lldb/source/Commands/CommandObjectMCP.cpp
 create mode 100644 lldb/source/Commands/CommandObjectMCP.h
 create mode 100644 lldb/source/MCP/CMakeLists.txt
 create mode 100644 lldb/source/MCP/MCPError.cpp
 create mode 100644 lldb/source/MCP/Protocol.cpp
 create mode 100644 lldb/source/MCP/Server.cpp
 create mode 100644 lldb/source/MCP/Tool.cpp
 create mode 100644 lldb/source/MCP/Transport.cpp

diff --git a/lldb/include/lldb/MCP/MCPError.h b/lldb/include/lldb/MCP/MCPError.h
new file mode 100644
index 0000000000000..2a76a7b087e20
--- /dev/null
+++ b/lldb/include/lldb/MCP/MCPError.h
@@ -0,0 +1,33 @@
+//===-- MCPError.h --------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLDB_MCP_MCPERROR_H
+#define LLDB_MCP_MCPERROR_H
+
+#include "Protocol.h"
+#include "llvm/Support/Error.h"
+#include <cstdint>
+#include <string>
+
+namespace lldb_private::mcp {
+
+/// An llvm::Error payload that carries a JSON-RPC error code next to its
+/// message, so a failure can be turned into a protocol::Error for the client.
+class MCPError : public llvm::ErrorInfo<MCPError> {
+public:
+  static char ID;
+
+  MCPError(std::string message, int64_t error_code);
+
+  void log(llvm::raw_ostream &OS) const override;
+  std::error_code convertToErrorCode() const override;
+
+  const std::string &getMessage() const { return m_message; }
+
+  /// Build the protocol-level error object (code + message, no data).
+  protocol::Error toProtcolError() const;
+
+private:
+  std::string m_message;
+  int64_t m_error_code;
+};
+
+} // namespace lldb_private::mcp
+
+#endif // LLDB_MCP_MCPERROR_H
diff --git a/lldb/include/lldb/MCP/Protocol.h b/lldb/include/lldb/MCP/Protocol.h
new file mode 100644
index 0000000000000..e661c3f7643af
--- /dev/null
+++ b/lldb/include/lldb/MCP/Protocol.h
@@ -0,0 +1,131 @@
+//===- Protocol.h ---------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLDB_TOOLS_LLDB_MCP_PROTOCOL_H
+#define LLDB_TOOLS_LLDB_MCP_PROTOCOL_H
+
+#include "llvm/Support/JSON.h"
+#include <cstdint>
+#include <optional>
+#include <string>
+#include <vector>
+
+namespace lldb_private::mcp::protocol {
+
+/// The MCP protocol revision advertised by default during "initialize".
+/// `inline constexpr` gives one immutable definition shared by every TU
+/// instead of a separate mutable copy per includer.
+inline constexpr llvm::StringLiteral kProtocolVersion = "2025-03-26";
+
+/// A JSON-RPC 2.0 request.
+/// NOTE: JSON-RPC also permits string ids; only integer ids are modeled.
+struct Request {
+  uint64_t id = 0;
+  std::string method;
+  std::optional<llvm::json::Value> params;
+};
+
+llvm::json::Value toJSON(const Request &);
+bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path);
+
+/// A JSON-RPC 2.0 error object.
+struct Error {
+  int64_t code = 0;
+  std::string message;
+  std::optional<std::string> data;
+};
+
+llvm::json::Value toJSON(const Error &);
+bool fromJSON(const llvm::json::Value &, Error &, llvm::json::Path);
+
+/// A JSON-RPC 2.0 error response: an id paired with an error object.
+struct ProtocolError {
+  uint64_t id = 0;
+  Error error;
+};
+
+llvm::json::Value toJSON(const ProtocolError &);
+bool fromJSON(const llvm::json::Value &, ProtocolError &, llvm::json::Path);
+
+/// A JSON-RPC 2.0 response; exactly one of result/error should be set.
+struct Response {
+  uint64_t id = 0;
+  std::optional<llvm::json::Value> result;
+  std::optional<Error> error;
+};
+
+llvm::json::Value toJSON(const Response &);
+bool fromJSON(const llvm::json::Value &, Response &, llvm::json::Path);
+
+/// A JSON-RPC 2.0 notification: a method call without an id.
+struct Notification {
+  std::string method;
+  std::optional<llvm::json::Value> params;
+};
+
+llvm::json::Value toJSON(const Notification &);
+bool fromJSON(const llvm::json::Value &, Notification &, llvm::json::Path);
+
+/// The "tools" entry of the server capabilities.
+struct ToolCapability {
+  bool listChanged = false;
+};
+
+llvm::json::Value toJSON(const ToolCapability &);
+bool fromJSON(const llvm::json::Value &, ToolCapability &, llvm::json::Path);
+
+/// Server capabilities reported in the "initialize" result.
+struct Capabilities {
+  ToolCapability tools;
+};
+
+llvm::json::Value toJSON(const Capabilities &);
+bool fromJSON(const llvm::json::Value &, Capabilities &, llvm::json::Path);
+
+/// A single text item in a tool result.
+struct TextContent {
+  std::string text;
+};
+
+llvm::json::Value toJSON(const TextContent &);
+bool fromJSON(const llvm::json::Value &, TextContent &, llvm::json::Path);
+
+/// The result of a tool call: text content plus an error flag.
+struct TextResult {
+  std::vector<TextContent> content;
+  bool isError = false;
+};
+
+llvm::json::Value toJSON(const TextResult &);
+bool fromJSON(const llvm::json::Value &, TextResult &, llvm::json::Path);
+
+/// Optional behavioral hints attached to a tool definition.
+struct ToolAnnotations {
+  /// Human-readable title for the tool.
+  std::optional<std::string> title;
+
+  /// If true, the tool does not modify its environment.
+  std::optional<bool> readOnlyHint;
+
+  /// If true, the tool may perform destructive updates.
+  std::optional<bool> destructiveHint;
+
+  /// If true, repeated calls with same args have no additional effect.
+  std::optional<bool> idempotentHint;
+
+  /// If true, tool interacts with external entities.
+  std::optional<bool> openWorldHint;
+};
+
+llvm::json::Value toJSON(const ToolAnnotations &);
+bool fromJSON(const llvm::json::Value &, ToolAnnotations &, llvm::json::Path);
+
+/// A tool as listed in the "tools/list" result.
+struct ToolDefinition {
+  /// Unique identifier for the tool.
+  std::string name;
+
+  /// Human-readable description.
+  std::optional<std::string> description;
+
+  /// JSON Schema for the tool's parameters.
+  std::optional<llvm::json::Value> inputSchema;
+
+  /// Optional hints about tool behavior.
+  std::optional<ToolAnnotations> annotations;
+};
+
+llvm::json::Value toJSON(const ToolDefinition &);
+bool fromJSON(const llvm::json::Value &, ToolDefinition &, llvm::json::Path);
+
+} // namespace lldb_private::mcp::protocol
+
+#endif // LLDB_TOOLS_LLDB_MCP_PROTOCOL_H
diff --git a/lldb/include/lldb/MCP/Server.h b/lldb/include/lldb/MCP/Server.h
new file mode 100644
index 0000000000000..e28db5e36b670
--- /dev/null
+++ b/lldb/include/lldb/MCP/Server.h
@@ -0,0 +1,67 @@
+//===- Server.h -----------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLDB_TOOLS_LLDB_MCP_SERVER_H
+#define LLDB_TOOLS_LLDB_MCP_SERVER_H
+
+#include "Protocol.h"
+#include "Tool.h"
+#include "Transport.h"
+#include "lldb/Host/MainLoop.h"
+#include "lldb/Host/Socket.h"
+#include "llvm/ADT/StringMap.h"
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <thread>
+#include <utility>
+#include <vector>
+
+namespace lldb_private::mcp {
+
+/// A Model Context Protocol server. Connections are accepted through a
+/// MainLoop running on its own thread, and each connected client is served
+/// on a dedicated thread that dispatches JSON-RPC requests by method name.
+class Server {
+public:
+  Server(std::string name, std::string version);
+  ~Server() { Stop(); }
+
+  /// Start listening for clients on \p name using \p protocol.
+  llvm::Error Start(const Socket::SocketProtocol &protocol,
+                    const std::string &name);
+
+  /// Stop accepting connections, close every client and join its thread.
+  void Stop();
+
+  /// Register \p tool, replacing any tool with the same name. A null tool is
+  /// ignored.
+  void AddTool(std::unique_ptr<Tool> tool);
+
+  /// Register \p handler for requests whose method is \p method.
+  void AddHandler(
+      llvm::StringRef method,
+      std::function<protocol::Response(const protocol::Request &)> handler);
+
+private:
+  llvm::Error Run(std::unique_ptr<Transport> transport);
+  llvm::Expected<protocol::Response> Handle(protocol::Request request);
+
+  protocol::Response InitializeHandler(const protocol::Request &);
+  protocol::Response ToolsListHandler(const protocol::Request &);
+  protocol::Response ToolsCallHandler(const protocol::Request &);
+
+  protocol::Capabilities GetCapabilities();
+
+  std::string m_name;
+  std::string m_version;
+
+  /// Read and written under m_server_mutex by Start and Stop.
+  bool m_running = false;
+
+  MainLoop m_loop;
+  std::thread m_loop_thread;
+
+  std::unique_ptr<Socket> m_listener;
+  std::vector<MainLoopBase::ReadHandleUP> m_read_handles;
+  /// Each connected client's I/O object and the thread serving it.
+  std::vector<std::pair<lldb::IOObjectSP, std::thread>> m_clients;
+
+  std::mutex m_server_mutex;
+  llvm::StringMap<std::unique_ptr<Tool>> m_tools;
+  llvm::StringMap<std::function<protocol::Response(const protocol::Request &)>>
+      m_handlers;
+};
+} // namespace lldb_private::mcp
+
+#endif // LLDB_TOOLS_LLDB_MCP_SERVER_H
diff --git a/lldb/include/lldb/MCP/Tool.h b/lldb/include/lldb/MCP/Tool.h
new file mode 100644
index 0000000000000..0fb2db98716ad
--- /dev/null
+++ b/lldb/include/lldb/MCP/Tool.h
@@ -0,0 +1,45 @@
+//===- Tool.h -------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLDB_TOOLS_LLDB_MCP_TOOL_H
+#define LLDB_TOOLS_LLDB_MCP_TOOL_H
+
+#include "Protocol.h"
+#include "llvm/Support/JSON.h"
+#include <optional>
+#include <string>
+#include <utility>
+
+namespace lldb_private::mcp {
+
+/// Base class for a tool that MCP clients discover through "tools/list" and
+/// invoke through "tools/call".
+class Tool {
+public:
+  Tool(std::string name, std::string description);
+  virtual ~Tool() = default;
+
+  /// Run the tool with the client-supplied JSON \p args.
+  virtual protocol::TextResult Call(const llvm::json::Value &args) = 0;
+
+  /// JSON Schema describing \p args for Call, or std::nullopt if none.
+  virtual std::optional<llvm::json::Value> GetSchema() const {
+    return std::nullopt;
+  }
+
+  /// The definition of this tool as advertised to clients.
+  protocol::ToolDefinition GetDefinition() const;
+
+  const std::string &GetName() const { return m_name; }
+
+protected:
+  void SetAnnotations(protocol::ToolAnnotations annotations) {
+    m_annotations.emplace(std::move(annotations));
+  }
+
+private:
+  std::string m_name;
+  std::string m_description;
+  std::optional<protocol::ToolAnnotations> m_annotations;
+};
+} // namespace lldb_private::mcp
+
+#endif // LLDB_TOOLS_LLDB_MCP_TOOL_H
diff --git a/lldb/include/lldb/MCP/Transport.h b/lldb/include/lldb/MCP/Transport.h
new file mode 100644
index 0000000000000..fb42541f067ec
--- /dev/null
+++ b/lldb/include/lldb/MCP/Transport.h
@@ -0,0 +1,100 @@
+//===-- Transport.h -------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLDB_TOOLS_LLDB_MCP_TRANSPORT_H
+#define LLDB_TOOLS_LLDB_MCP_TRANSPORT_H
+
+#include "Protocol.h"
+#include "lldb/Utility/IOObject.h"
+#include "lldb/Utility/Status.h"
+#include "lldb/lldb-forward.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Error.h"
+#include <chrono>
+#include <string>
+#include <system_error>
+
+namespace lldb_private::mcp {
+
+/// Signals that the input stream reached end of file.
+class EndOfFileError : public llvm::ErrorInfo<EndOfFileError> {
+public:
+  static char ID;
+
+  EndOfFileError() = default;
+
+  void log(llvm::raw_ostream &OS) const override {
+    OS << "end of file reached";
+  }
+  std::error_code convertToErrorCode() const override {
+    return llvm::inconvertibleErrorCode();
+  }
+};
+
+/// Signals that the transport's input or output is no longer usable.
+class TransportClosedError : public llvm::ErrorInfo<TransportClosedError> {
+public:
+  static char ID;
+
+  TransportClosedError() = default;
+
+  void log(llvm::raw_ostream &OS) const override {
+    OS << "transport is closed";
+  }
+  std::error_code convertToErrorCode() const override {
+    return llvm::inconvertibleErrorCode();
+  }
+};
+
+/// Signals that an operation did not complete within its timeout.
+class TimeoutError : public llvm::ErrorInfo<TimeoutError> {
+public:
+  static char ID;
+
+  TimeoutError() = default;
+
+  void log(llvm::raw_ostream &OS) const override {
+    OS << "operation timed out";
+  }
+  std::error_code convertToErrorCode() const override {
+    return std::make_error_code(std::errc::timed_out);
+  }
+};
+
+/// Message transport for a single MCP client. Write emits one JSON value per
+/// line ('\n'-terminated).
+class Transport {
+public:
+  /// NOTE(review): only a StringRef to \p client_name is kept, so the caller
+  /// must keep the string alive for the lifetime of the Transport.
+  Transport(llvm::StringRef client_name, lldb::IOObjectSP input,
+            lldb::IOObjectSP output);
+  ~Transport() = default;
+
+  /// Transport is not copyable.
+  /// @{
+  Transport(const Transport &rhs) = delete;
+  void operator=(const Transport &rhs) = delete;
+  /// @}
+
+  /// Serialize \p t with toJSON and write it to the output as one line.
+  /// Returns TransportClosedError if the output is missing, invalid, or
+  /// stops accepting bytes.
+  template <typename T> llvm::Error Write(const T &t) {
+    if (!m_output || !m_output->IsValid())
+      return llvm::make_error<TransportClosedError>();
+
+    std::string Output;
+    llvm::raw_string_ostream OS(Output);
+    OS << toJSON(t) << '\n';
+
+    // IOObject::Write stores the number of bytes actually written back into
+    // num_bytes, which can be fewer than requested. Keep writing until the
+    // whole message is out so a short write cannot truncate a message.
+    llvm::StringRef remaining = Output;
+    while (!remaining.empty()) {
+      size_t num_bytes = remaining.size();
+      if (llvm::Error error =
+              m_output->Write(remaining.data(), num_bytes).takeError())
+        return error;
+      if (num_bytes == 0)
+        return llvm::make_error<TransportClosedError>();
+      remaining = remaining.drop_front(num_bytes);
+    }
+    return llvm::Error::success();
+  }
+
+  /// Read the next request from the input; \p timeout bounds the wait.
+  llvm::Expected<protocol::Request>
+  Read(const std::chrono::microseconds &timeout);
+
+  llvm::StringRef GetClientName() { return m_client_name; }
+
+private:
+  llvm::StringRef m_client_name;
+  lldb::IOObjectSP m_input;
+  lldb::IOObjectSP m_output;
+};
+
+} // namespace lldb_private::mcp
+
+#endif // LLDB_TOOLS_LLDB_MCP_TRANSPORT_H
diff --git a/lldb/source/CMakeLists.txt b/lldb/source/CMakeLists.txt
index 51c9f9c90826e..d35521ffd5c22 100644
--- a/lldb/source/CMakeLists.txt
+++ b/lldb/source/CMakeLists.txt
@@ -9,6 +9,7 @@ add_subdirectory(Expression)
 add_subdirectory(Host)
 add_subdirectory(Initialization)
 add_subdirectory(Interpreter)
+add_subdirectory(MCP)
 add_subdirectory(Plugins)
 add_subdirectory(Symbol)
 add_subdirectory(Target)
diff --git a/lldb/source/Commands/CMakeLists.txt b/lldb/source/Commands/CMakeLists.txt
index 1ea51acec5f15..2d29653971de8 100644
--- a/lldb/source/Commands/CMakeLists.txt
+++ b/lldb/source/Commands/CMakeLists.txt
@@ -17,6 +17,7 @@ add_lldb_library(lldbCommands NO_PLUGIN_DEPENDENCIES
   CommandObjectHelp.cpp
   CommandObjectLanguage.cpp
   CommandObjectLog.cpp
+  CommandObjectMCP.cpp
   CommandObjectMemory.cpp
   CommandObjectMemoryTag.cpp
   CommandObjectMultiword.cpp
@@ -51,6 +52,7 @@ add_lldb_library(lldbCommands NO_PLUGIN_DEPENDENCIES
     lldbDataFormatters
     lldbExpression
     lldbHost
+    lldbMCP
     lldbInterpreter
     lldbSymbol
     lldbTarget
diff --git a/lldb/source/Commands/CommandObjectMCP.cpp b/lldb/source/Commands/CommandObjectMCP.cpp
new file mode 100644
index 0000000000000..781995fcc9318
--- /dev/null
+++ b/lldb/source/Commands/CommandObjectMCP.cpp
@@ -0,0 +1,162 @@
+//===-- CommandObjectMCP.cpp ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "CommandObjectMCP.h"
+#include "lldb/Host/Socket.h"
+#include "lldb/Interpreter/CommandInterpreter.h"
+#include "lldb/Interpreter/CommandReturnObject.h"
+#include "lldb/MCP/Server.h"
+#include "lldb/Utility/UriParser.h"
+#include "llvm/Support/FormatAdapters.h"
+
+using namespace llvm;
+using namespace lldb;
+using namespace lldb_private;
+
+/// The process-wide MCP server, created by the first "mcp start".
+/// NOTE(review): its lldb_command tool stays bound to the Debugger that ran
+/// that first "mcp start", even if a later start comes from another one.
+static std::optional<mcp::Server> g_mcp_server;
+
+/// An MCP tool that runs one LLDB command line through the debugger's
+/// command interpreter and returns its output and error text.
+class LLDBCommandTool : public mcp::Tool {
+public:
+  LLDBCommandTool(std::string name, std::string description, Debugger &debugger)
+      : mcp::Tool(std::move(name), std::move(description)),
+        m_debugger(debugger) {}
+  ~LLDBCommandTool() override = default;
+
+  /// Runs the command held in the "arguments" string of \p args. A missing
+  /// or non-string "arguments" runs an empty command line.
+  mcp::protocol::TextResult Call(const llvm::json::Value &args) override {
+    std::string arguments;
+    if (const json::Object *args_obj = args.getAsObject()) {
+      if (const json::Value *s = args_obj->get("arguments")) {
+        arguments = s->getAsString().value_or("");
+      }
+    }
+
+    CommandReturnObject result(/*colors=*/false);
+    m_debugger.GetCommandInterpreter().HandleCommand(arguments.c_str(),
+                                                     eLazyBoolYes, result);
+
+    std::string output;
+    llvm::StringRef output_str = result.GetOutputString();
+    if (!output_str.empty())
+      output += output_str.str();
+
+    std::string err_str = result.GetErrorString();
+    if (!err_str.empty()) {
+      if (!output.empty())
+        output += '\n';
+      output += err_str;
+    }
+
+    mcp::protocol::TextResult text_result;
+    text_result.content.emplace_back(mcp::protocol::TextContent{{output}});
+    // Report command failure to the client instead of always claiming
+    // success.
+    text_result.isError = !result.Succeeded();
+    return text_result;
+  }
+
+  /// Schema with the single string property "arguments" that Call reads.
+  std::optional<llvm::json::Value> GetSchema() const override {
+    llvm::json::Object str_type{{"type", "string"}};
+    llvm::json::Object properties{{"arguments", std::move(str_type)}};
+    llvm::json::Object schema{{"type", "object"},
+                              {"properties", std::move(properties)}};
+    return schema;
+  }
+
+private:
+  Debugger &m_debugger;
+};
+
+/// Parse a connection URI into a socket protocol and a listen name: a
+/// "[host]:port" string for TCP (an empty host becomes 0.0.0.0), or the URI
+/// path for Unix domain sockets.
+static llvm::Expected<std::pair<Socket::SocketProtocol, std::string>>
+validateConnection(llvm::StringRef conn) {
+  auto uri = lldb_private::URI::Parse(conn);
+
+  if (uri && (uri->scheme == "tcp" || uri->scheme == "connect" ||
+              !uri->hostname.empty() || uri->port)) {
+    return std::make_pair(
+        Socket::ProtocolTcp,
+        formatv("[{0}]:{1}", uri->hostname.empty() ? "0.0.0.0" : uri->hostname,
+                uri->port.value_or(0)));
+  }
+
+  if (uri && (uri->scheme == "unix" || uri->scheme == "unix-connect" ||
+              uri->path != "/")) {
+    return std::make_pair(Socket::ProtocolUnixDomain, uri->path.str());
+  }
+
+  return llvm::createStringError(
+      "Unsupported connection specifier, expected 'unix-connect:///path' or "
+      "'connect://[host]:port', got '%s'.",
+      conn.str().c_str());
+}
+
+/// "mcp start <connection>": create the server if needed and start it.
+class CommandObjectMCPStart : public CommandObjectParsed {
+public:
+  CommandObjectMCPStart(CommandInterpreter &interpreter)
+      : CommandObjectParsed(interpreter, "mcp start", "start MCP server",
+                            "mcp start <connection>") {
+    AddSimpleArgumentList(lldb::eArgTypeConnectURL, eArgRepeatOptional);
+  }
+
+  ~CommandObjectMCPStart() override = default;
+
+protected:
+  void DoExecute(Args &args, CommandReturnObject &result) override {
+    if (args.GetArgumentCount() == 0) {
+      result.AppendError("no connection specified");
+      return;
+    }
+    if (args.GetArgumentCount() != 1) {
+      result.AppendError("expected exactly one connection argument");
+      return;
+    }
+
+    llvm::StringRef connection = args.GetArgumentAtIndex(0);
+    auto maybe_protocol_and_name = validateConnection(connection);
+    if (auto error = maybe_protocol_and_name.takeError()) {
+      result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error)));
+      return;
+    }
+
+    if (!g_mcp_server) {
+      g_mcp_server.emplace("lldb-mcp", "0.1.0");
+      g_mcp_server->AddTool(std::make_unique<LLDBCommandTool>(
+          "lldb_command", "Run an lldb command.", GetDebugger()));
+    }
+
+    auto [protocol, name] = *maybe_protocol_and_name;
+    if (llvm::Error error = g_mcp_server->Start(protocol, name)) {
+      result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error)));
+      return;
+    }
+
+    result.SetStatus(eReturnStatusSuccessFinishNoResult);
+  }
+};
+
+/// "mcp stop": stop the server if one was created.
+class CommandObjectMCPStop : public CommandObjectParsed {
+public:
+  CommandObjectMCPStop(CommandInterpreter &interpreter)
+      : CommandObjectParsed(interpreter, "mcp stop", "stop MCP server",
+                            "mcp stop") {}
+
+  ~CommandObjectMCPStop() override = default;
+
+protected:
+  void DoExecute(Args &args, CommandReturnObject &result) override {
+    if (g_mcp_server)
+      g_mcp_server->Stop();
+    result.SetStatus(eReturnStatusSuccessFinishNoResult);
+  }
+};
+
+CommandObjectMCP::CommandObjectMCP(CommandInterpreter &interpreter)
+    : CommandObjectMultiword(interpreter, "mcp",
+                             "Start and stop an MCP server.", "mcp") {
+  LoadSubCommand("start",
+                 CommandObjectSP(new CommandObjectMCPStart(interpreter)));
+  LoadSubCommand("stop",
+                 CommandObjectSP(new CommandObjectMCPStop(interpreter)));
+}
+
+CommandObjectMCP::~CommandObjectMCP() = default;
diff --git a/lldb/source/Commands/CommandObjectMCP.h b/lldb/source/Commands/CommandObjectMCP.h
new file mode 100644
index 0000000000000..7f5a08452fdfc
--- /dev/null
+++ b/lldb/source/Commands/CommandObjectMCP.h
@@ -0,0 +1,24 @@
+//===-- CommandObjectMCP.h ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLDB_SOURCE_COMMANDS_COMMANDOBJECTMCP_H
+#define LLDB_SOURCE_COMMANDS_COMMANDOBJECTMCP_H
+
+#include "lldb/Interpreter/CommandObjectMultiword.h"
+
+namespace lldb_private {
+
+/// The "mcp" multiword command, which groups the MCP server subcommands.
+class CommandObjectMCP : public CommandObjectMultiword {
+public:
+  explicit CommandObjectMCP(CommandInterpreter &interpreter);
+  ~CommandObjectMCP() override;
+};
+
+} // namespace lldb_private
+
+#endif // LLDB_SOURCE_COMMANDS_COMMANDOBJECTMCP_H
diff --git a/lldb/source/Interpreter/CommandInterpreter.cpp b/lldb/source/Interpreter/CommandInterpreter.cpp
index 4f9ae104dedea..705fe38c39d63 100644
--- a/lldb/source/Interpreter/CommandInterpreter.cpp
+++ b/lldb/source/Interpreter/CommandInterpreter.cpp
@@ -26,6 +26,7 @@
 #include "Commands/CommandObjectHelp.h"
 #include "Commands/CommandObjectLanguage.h"
 #include "Commands/CommandObjectLog.h"
+#include "Commands/CommandObjectMCP.h"
 #include "Commands/CommandObjectMemory.h"
 #include "Commands/CommandObjectPlatform.h"
 #include "Commands/CommandObjectPlugin.h"
@@ -570,6 +571,7 @@ void CommandInterpreter::LoadCommandDictionary() {
   REGISTER_COMMAND_OBJECT("gui", CommandObjectGUI);
   REGISTER_COMMAND_OBJECT("help", CommandObjectHelp);
   REGISTER_COMMAND_OBJECT("log", CommandObjectLog);
+  REGISTER_COMMAND_OBJECT("mcp", CommandObjectMCP);
   REGISTER_COMMAND_OBJECT("memory", CommandObjectMemory);
   REGISTER_COMMAND_OBJECT("platform", CommandObjectPlatform);
   REGISTER_COMMAND_OBJECT("plugin", CommandObjectPlugin);
diff --git a/lldb/source/MCP/CMakeLists.txt b/lldb/source/MCP/CMakeLists.txt
new file mode 100644
index 0000000000000..7f23dae4bc9c7
--- /dev/null
+++ b/lldb/source/MCP/CMakeLists.txt
@@ -0,0 +1,14 @@
+add_lldb_library(lldbMCP
+  MCPError.cpp
+  Protocol.cpp
+  Server.cpp
+  Tool.cpp
+  Transport.cpp
+
+  LINK_COMPONENTS
+    Support
+
+  LINK_LIBS
+    lldbHost
+    lldbUtility
+)
diff --git a/lldb/source/MCP/MCPError.cpp b/lldb/source/MCP/MCPError.cpp
new file mode 100644
index 0000000000000..b4768b75a4617
--- /dev/null
+++ b/lldb/source/MCP/MCPError.cpp
@@ -0,0 +1,31 @@
+//===-- MCPError.cpp ------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "lldb/MCP/MCPError.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/raw_ostream.h"
+#include <system_error>
+#include <utility>
+
+namespace lldb_private::mcp {
+
+char MCPError::ID;
+
+// \p message is taken by value and moved into place (sink parameter).
+MCPError::MCPError(std::string message, int64_t error_code)
+    : m_message(std::move(message)), m_error_code(error_code) {}
+
+void MCPError::log(llvm::raw_ostream &OS) const { OS << m_message; }
+
+std::error_code MCPError::convertToErrorCode() const {
+  return llvm::inconvertibleErrorCode();
+}
+
+protocol::Error MCPError::toProtcolError() const {
+  return protocol::Error{m_error_code, m_message, std::nullopt};
+}
+
+} // namespace lldb_private::mcp
diff --git a/lldb/source/MCP/Protocol.cpp b/lldb/source/MCP/Protocol.cpp
new file mode 100644
index 0000000000000..1a74746400a30
--- /dev/null
+++ b/lldb/source/MCP/Protocol.cpp
@@ -0,0 +1,180 @@
+//===- Protocol.cpp -------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "lldb/MCP/Protocol.h"
+#include "llvm/Support/JSON.h"
+
+using namespace llvm;
+
+namespace lldb_private::mcp::protocol {
+
+/// Copy the raw value of property \p Prop of object \p Params into \p V,
+/// leaving \p V untouched if the property is absent. Used for free-form
+/// fields (params, result, schemas) that have no fixed C++ type.
+static bool mapRaw(const json::Value &Params, StringLiteral Prop,
+                   std::optional<json::Value> &V, json::Path P) {
+  const auto *O = Params.getAsObject();
+  if (!O) {
+    P.report("expected object");
+    return false;
+  }
+  // The source is const, so this is necessarily a copy.
+  if (const json::Value *E = O->get(Prop))
+    V = *E;
+  return true;
+}
+
+llvm::json::Value toJSON(const Request &R) {
+  json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}, {"method", R.method}};
+  if (R.params)
+    Result.insert({"params", R.params});
+  return Result;
+}
+
+bool fromJSON(const llvm::json::Value &V, Request &R, llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  if (!O || !O.map("id", R.id) || !O.map("method", R.method))
+    return false;
+  return mapRaw(V, "params", R.params, P);
+}
+
+llvm::json::Value toJSON(const Error &E) {
+  llvm::json::Object Result{{"code", E.code}, {"message", E.message}};
+  // JSON-RPC lets "data" be omitted; don't send an explicit null.
+  if (E.data)
+    Result.insert({"data", *E.data});
+  return Result;
+}
+
+bool fromJSON(const llvm::json::Value &V, Error &E, llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  return O && O.map("code", E.code) && O.map("message", E.message) &&
+         O.map("data", E.data);
+}
+
+llvm::json::Value toJSON(const ProtocolError &PE) {
+  // Every JSON-RPC 2.0 message carries the version tag, like the other
+  // message serializers in this file.
+  return llvm::json::Object{
+      {"jsonrpc", "2.0"}, {"id", PE.id}, {"error", toJSON(PE.error)}};
+}
+
+bool fromJSON(const llvm::json::Value &V, ProtocolError &PE,
+              llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  return O && O.map("id", PE.id) && O.map("error", PE.error);
+}
+
+llvm::json::Value toJSON(const Response &R) {
+  llvm::json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}};
+  if (R.result)
+    Result.insert({"result", R.result});
+  if (R.error)
+    Result.insert({"error", R.error});
+  return Result;
+}
+
+/// Parses "id", an optional "error" object and a free-form "result".
+bool fromJSON(const llvm::json::Value &V, Response &R, llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  if (!O || !O.map("id", R.id) || !O.map("error", R.error))
+    return false;
+  return mapRaw(V, "result", R.result, P);
+}
+
+llvm::json::Value toJSON(const Notification &N) {
+  llvm::json::Object Result{{"jsonrpc", "2.0"}, {"method", N.method}};
+  if (N.params)
+    Result.insert({"params", N.params});
+  return Result;
+}
+
+bool fromJSON(const llvm::json::Value &V, Notification &N, llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  if (!O || !O.map("method", N.method))
+    return false;
+  return mapRaw(V, "params", N.params, P);
+}
+
+llvm::json::Value toJSON(const ToolCapability &TC) {
+  return llvm::json::Object{{"listChanged", TC.listChanged}};
+}
+
+bool fromJSON(const llvm::json::Value &V, ToolCapability &TC,
+              llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  return O && O.mapOptional("listChanged", TC.listChanged);
+}
+
+llvm::json::Value toJSON(const Capabilities &C) {
+  return llvm::json::Object{{"tools", C.tools}};
+}
+
+bool fromJSON(const llvm::json::Value &V, Capabilities &C, llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  // Must match the "tools" key written by toJSON above.
+  return O && O.mapOptional("tools", C.tools);
+}
+
+llvm::json::Value toJSON(const TextContent &TC) {
+  return llvm::json::Object{{"type", "text"}, {"text", TC.text}};
+}
+
+bool fromJSON(const llvm::json::Value &V, TextContent &TC, llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  return O && O.map("text", TC.text);
+}
+
+llvm::json::Value toJSON(const TextResult &TR) {
+  return llvm::json::Object{{"content", TR.content}, {"isError", TR.isError}};
+}
+
+bool fromJSON(const llvm::json::Value &V, TextResult &TR, llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  // "isError" may be absent; keep the default (false) in that case.
+  return O && O.map("content", TR.content) &&
+         O.mapOptional("isError", TR.isError);
+}
+
+llvm::json::Value toJSON(const ToolAnnotations &TA) {
+  llvm::json::Object Result;
+  if (TA.title)
+    Result.insert({"title", TA.title});
+  if (TA.readOnlyHint)
+    Result.insert({"readOnlyHint", TA.readOnlyHint});
+  if (TA.destructiveHint)
+    Result.insert({"destructiveHint", TA.destructiveHint});
+  if (TA.idempotentHint)
+    Result.insert({"idempotentHint", TA.idempotentHint});
+  if (TA.openWorldHint)
+    Result.insert({"openWorldHint", TA.openWorldHint});
+  return Result;
+}
+
+bool fromJSON(const llvm::json::Value &V, ToolAnnotations &TA,
+              llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  return O && O.mapOptional("title", TA.title) &&
+         O.mapOptional("readOnlyHint", TA.readOnlyHint) &&
+         O.mapOptional("destructiveHint", TA.destructiveHint) &&
+         O.mapOptional("idempotentHint", TA.idempotentHint) &&
+         O.mapOptional("openWorldHint", TA.openWorldHint);
+}
+
+llvm::json::Value toJSON(const ToolDefinition &TD) {
+  llvm::json::Object Result{{"name", TD.name}};
+  if (TD.description)
+    Result.insert({"description", TD.description});
+  if (TD.inputSchema)
+    Result.insert({"inputSchema", TD.inputSchema});
+  if (TD.annotations)
+    Result.insert({"annotations", TD.annotations});
+  return Result;
+}
+
+bool fromJSON(const llvm::json::Value &V, ToolDefinition &TD,
+              llvm::json::Path P) {
+  llvm::json::ObjectMapper O(V, P);
+  if (!O || !O.map("name", TD.name) ||
+      !O.mapOptional("description", TD.description) ||
+      !O.mapOptional("annotations", TD.annotations))
+    return false;
+  return mapRaw(V, "inputSchema", TD.inputSchema, P);
+}
+
+} // namespace lldb_private::mcp::protocol
diff --git a/lldb/source/MCP/Server.cpp b/lldb/source/MCP/Server.cpp
new file mode 100644
index 0000000000000..0d744da0810cc
--- /dev/null
+++ b/lldb/source/MCP/Server.cpp
@@ -0,0 +1,251 @@
+//===- Server.cpp ---------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "lldb/MCP/Server.h"
+#include "lldb/MCP/MCPError.h"
+#include "lldb/Utility/LLDBLog.h"
+#include "lldb/Utility/Log.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/Support/Threading.h"
+#include <thread>
+
+using namespace lldb_private::mcp;
+using namespace llvm;
+
+/// JSON-RPC 2.0 reserved error codes used by this server.
+static constexpr int64_t kMethodNotFound = -32601;
+static constexpr int64_t kInvalidParams = -32602;
+static constexpr int64_t kInternalError = -32603;
+
+/// Build a response that carries only a JSON-RPC error (no result).
+static protocol::Response MakeErrorResponse(int64_t code,
+                                            std::string message) {
+  protocol::Response response;
+  response.error.emplace(
+      protocol::Error{code, std::move(message), std::nullopt});
+  return response;
+}
+
+Server::Server(std::string name, std::string version)
+    : m_name(std::move(name)), m_version(std::move(version)) {
+  AddHandler("initialize", [this](const protocol::Request &request) {
+    return InitializeHandler(request);
+  });
+  AddHandler("tools/list", [this](const protocol::Request &request) {
+    return ToolsListHandler(request);
+  });
+  AddHandler("tools/call", [this](const protocol::Request &request) {
+    return ToolsCallHandler(request);
+  });
+}
+
+llvm::Expected<protocol::Response> Server::Handle(protocol::Request request) {
+  auto it = m_handlers.find(request.method);
+  if (it != m_handlers.end())
+    return it->second(request);
+
+  return make_error<MCPError>(
+      llvm::formatv("no handler for request: {0}", request.method).str(),
+      kMethodNotFound);
+}
+
+llvm::Error Server::Start(const Socket::SocketProtocol &protocol,
+                          const std::string &name) {
+  std::lock_guard<std::mutex> guard(m_server_mutex);
+
+  if (m_running)
+    return llvm::createStringError("server already running");
+
+  Status status;
+  m_listener = Socket::Create(protocol, status);
+  if (status.Fail())
+    return status.takeError();
+
+  status = m_listener->Listen(name, /*backlog=*/5);
+  if (status.Fail())
+    return status.takeError();
+
+  std::string address =
+      llvm::join(m_listener->GetListeningConnectionURI(), ", ");
+  Log *log = GetLog(LLDBLog::Host);
+  LLDB_LOG(log, "MCP server started with connection listeners: {0}", address);
+
+  // Runs on the main loop thread for every accepted connection and spawns a
+  // dedicated thread that serves that client until it disconnects.
+  auto handles =
+      m_listener->Accept(m_loop, [this, log](std::unique_ptr<Socket> sock) {
+        std::lock_guard<std::mutex> guard(m_server_mutex);
+
+        const std::string client_name =
+            llvm::formatv("client-{0}", m_clients.size() + 1).str();
+        LLDB_LOG(log, "client {0} connected", client_name);
+
+        lldb::IOObjectSP io(std::move(sock));
+
+        m_clients.emplace_back(io, [this, client_name, io]() {
+          llvm::set_thread_name(client_name + "-runloop");
+          if (auto Err = Run(std::make_unique<Transport>(client_name, io, io)))
+            LLDB_LOG_ERROR(GetLog(LLDBLog::Host), std::move(Err),
+                           "MCP Error: {0}");
+        });
+      });
+  if (llvm::Error error = handles.takeError())
+    return error;
+
+  m_read_handles = std::move(*handles);
+  m_loop_thread = std::thread([this] {
+    llvm::set_thread_name("mcp-runloop");
+    m_loop.Run();
+  });
+
+  // Without this the "already running" check above never fires, and a
+  // second Start would assign to a still-joinable std::thread, which calls
+  // std::terminate.
+  m_running = true;
+
+  return llvm::Error::success();
+}
+
+void Server::Stop() {
+  {
+    std::lock_guard<std::mutex> guard(m_server_mutex);
+    // Nothing to tear down if Start never succeeded or Stop already ran;
+    // this also avoids queuing a termination callback on an idle loop.
+    if (!m_running)
+      return;
+    m_running = false;
+  }
+
+  // Stop accepting new connections.
+  m_loop.AddPendingCallback(
+      [](MainLoopBase &loop) { loop.RequestTermination(); });
+
+  // Wait for the main loop to exit.
+  if (m_loop_thread.joinable())
+    m_loop_thread.join();
+
+  // Wait for all our clients to exit. The loop thread has been joined, so
+  // no new clients can be added while we iterate.
+  for (auto &client : m_clients) {
+    client.first->Close();
+    if (client.second.joinable())
+      client.second.join();
+  }
+
+  {
+    std::lock_guard<std::mutex> guard(m_server_mutex);
+    m_listener.reset();
+    m_read_handles.clear();
+    m_clients.clear();
+  }
+}
+
+llvm::Error Server::Run(std::unique_ptr<Transport> transport) {
+  Log *log = GetLog(LLDBLog::Host);
+
+  while (true) {
+    llvm::Expected<protocol::Request> request =
+        transport->Read(std::chrono::seconds(1));
+    if (request.errorIsA<EndOfFileError>() ||
+        request.errorIsA<TransportClosedError>()) {
+      consumeError(request.takeError());
+      break;
+    }
+
+    // Timeouts only bound each Read; keep waiting for the next request.
+    if (request.errorIsA<TimeoutError>()) {
+      consumeError(request.takeError());
+      continue;
+    }
+
+    if (llvm::Error err = request.takeError()) {
+      LLDB_LOG_ERROR(log, std::move(err), "{0}");
+      continue;
+    }
+
+    protocol::Response response;
+    llvm::Expected<protocol::Response> maybe_response = Handle(*request);
+    if (!maybe_response) {
+      llvm::handleAllErrors(
+          maybe_response.takeError(),
+          [&](const MCPError &err) {
+            response.error.emplace(err.toProtcolError());
+          },
+          [&](const llvm::ErrorInfoBase &err) {
+            response.error.emplace();
+            response.error->code = kInternalError;
+            response.error->message = err.message();
+          });
+    } else {
+      response = std::move(*maybe_response);
+    }
+
+    response.id = request->id;
+
+    if (llvm::Error err = transport->Write(response))
+      return err;
+  }
+
+  return llvm::Error::success();
+}
+
+protocol::Capabilities Server::GetCapabilities() {
+  protocol::Capabilities capabilities;
+  capabilities.tools.listChanged = true;
+  return capabilities;
+}
+
+void Server::AddTool(std::unique_ptr<Tool> tool) {
+  std::lock_guard<std::mutex> guard(m_server_mutex);
+
+  if (!tool)
+    return;
+  m_tools[tool->GetName()] = std::move(tool);
+}
+
+void Server::AddHandler(
+    llvm::StringRef method,
+    std::function<protocol::Response(const protocol::Request &)> handler) {
+  std::lock_guard<std::mutex> guard(m_server_mutex);
+
+  m_handlers[method] = std::move(handler);
+}
+
+/// Answers "initialize", echoing the client's protocolVersion if it sent a
+/// string one and otherwise advertising kProtocolVersion.
+protocol::Response Server::InitializeHandler(const protocol::Request &request) {
+  protocol::Response response;
+
+  std::string protocol_version = protocol::kProtocolVersion.str();
+  if (request.params) {
+    if (const json::Object *param_obj = request.params->getAsObject()) {
+      if (const json::Value *val = param_obj->get("protocolVersion")) {
+        if (auto protocol_version_str = val->getAsString()) {
+          protocol_version = *protocol_version_str;
+        }
+      }
+    }
+  }
+
+  response.result.emplace(llvm::json::Object{
+      {"protocolVersion", protocol_version},
+      {"capabilities", GetCapabilities()},
+      {"serverInfo",
+       llvm::json::Object{{"name", m_name}, {"version", m_version}}}});
+  return response;
+}
+
+protocol::Response Server::ToolsListHandler(const protocol::Request &request) {
+  protocol::Response response;
+
+  llvm::json::Array tools;
+  for (const auto &tool : m_tools)
+    tools.emplace_back(toJSON(tool.second->GetDefinition()));
+
+  response.result.emplace(llvm::json::Object{{"tools", std::move(tools)}});
+
+  return response;
+}
+
+/// Answers "tools/call". Every failure becomes a JSON-RPC error: a response
+/// carrying neither "result" nor "error" is not valid JSON-RPC.
+protocol::Response Server::ToolsCallHandler(const protocol::Request &request) {
+  if (!request.params)
+    return MakeErrorResponse(kInvalidParams, "tools/call requires params");
+
+  const json::Object *param_obj = request.params->getAsObject();
+  if (!param_obj)
+    return MakeErrorResponse(kInvalidParams,
+                             "tools/call params must be an object");
+
+  const json::Value *name = param_obj->get("name");
+  if (!name)
+    return MakeErrorResponse(kInvalidParams,
+                             "tools/call requires a tool name");
+
+  llvm::StringRef tool_name = name->getAsString().value_or("");
+  if (tool_name.empty())
+    return MakeErrorResponse(kInvalidParams,
+                             "tool name must be a non-empty string");
+
+  auto it = m_tools.find(tool_name);
+  if (it == m_tools.end())
+    return MakeErrorResponse(
+        kInvalidParams, llvm::formatv("unknown tool: {0}", tool_name).str());
+
+  // "arguments" is optional in tools/call; treat a missing value as {}.
+  const json::Value *args = param_obj->get("arguments");
+  const json::Value empty_args = json::Object{};
+  protocol::TextResult text_result =
+      it->second->Call(args ? *args : empty_args);
+
+  protocol::Response response;
+  response.result.emplace(toJSON(text_result));
+  return response;
+}
diff --git a/lldb/source/MCP/Tool.cpp b/lldb/source/MCP/Tool.cpp
new file mode 100644
index 0000000000000..9d2e82723c2f5
--- /dev/null
+++ b/lldb/source/MCP/Tool.cpp
@@ -0,0 +1,28 @@
+//===- Tool.cpp -----------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "lldb/MCP/Tool.h"
+
+using namespace lldb_private::mcp;
+
+Tool::Tool(std::string name, std::string description)
+    : m_name(std::move(name)), m_description(std::move(description)) {}
+
+protocol::ToolDefinition Tool::GetDefinition() const {
+  protocol::ToolDefinition definition;
+  definition.name = m_name;
+  definition.description.emplace(m_description);
+
+  if (std::optional<llvm::json::Value> input_schema = GetSchema())
+    definition.inputSchema = *input_schema;
+
+  if (m_annotations)
+    definition.annotations = m_annotations;
+
+  return definition;
+}
diff --git a/lldb/source/MCP/Transport.cpp b/lldb/source/MCP/Transport.cpp
new file mode 100644
index 0000000000000..9070453bc69bb
--- /dev/null
+++ b/lldb/source/MCP/Transport.cpp
@@ -0,0 +1,109 @@
+//===-- Transport.cpp -----------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "lldb/MCP/Transport.h"
+#include "lldb/Utility/SelectHelper.h"
+#include "lldb/Utility/Status.h"
+#include "lldb/lldb-forward.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/raw_ostream.h"
+#include <optional>
+#include <string>
+#include <utility>
+
+using namespace llvm;
+using namespace lldb;
+using namespace lldb_private;
+using namespace lldb_private::mcp;
+
+static constexpr StringLiteral kMessageSeparator = "\n";
+
+/// ReadFull performs a single read of at most `length` bytes. It returns
+/// EndOfFileError on EOF and TimeoutError if no data arrives within `timeout`.
+static Expected<std::string>
+ReadFull(IOObject &descriptor, size_t length,
+         std::optional<std::chrono::microseconds> timeout = std::nullopt) {
+  if (!descriptor.IsValid())
+    return make_error<TransportClosedError>();
+
+  bool timeout_supported = true;
+  // FIXME: SelectHelper does not work with NativeFile on Win32.
+#if _WIN32
+  timeout_supported = descriptor.GetFdType() == IOObject::eFDTypeSocket;
+#endif
+
+  if (timeout && timeout_supported) {
+    SelectHelper sh;
+    sh.SetTimeout(*timeout);
+    sh.FDSetRead(descriptor.GetWaitableHandle());
+    Status status = sh.Select();
+    if (status.Fail()) {
+      // Convert timeouts into a specific error.
+      if (status.GetType() == lldb::eErrorTypePOSIX &&
+          status.GetError() == ETIMEDOUT)
+        return make_error<TimeoutError>();
+      return status.takeError();
+    }
+  }
+
+  std::string data(length, '\0');
+  // Read() updates `length` to the number of bytes actually read.
+  Status status = descriptor.Read(data.data(), length);
+  if (status.Fail())
+    return status.takeError();
+
+  // Read returns '' on EOF.
+  if (length == 0)
+    return make_error<EndOfFileError>();
+
+  // Shrink to the actual number of bytes read; no copy needed.
+  data.resize(length);
+  return data;
+}
+
+/// ReadUntil reads from `descriptor` until `delimiter` is seen and returns the
+/// data preceding it. The optional `timeout` only bounds the wait for the first
+/// bytes of a message: once a message has started, the remainder is read
+/// without a timeout so that a TimeoutError never discards a partial message.
+static Expected<std::string>
+ReadUntil(IOObject &descriptor, StringRef delimiter,
+          std::optional<std::chrono::microseconds> timeout = std::nullopt) {
+  std::string buffer;
+  buffer.reserve(delimiter.size() + 1);
+  while (!llvm::StringRef(buffer).ends_with(delimiter)) {
+    Expected<std::string> next =
+        buffer.empty() ? ReadFull(descriptor, delimiter.size(), timeout)
+                       : ReadFull(descriptor, 1);
+    if (auto Err = next.takeError())
+      return std::move(Err);
+    buffer += *next;
+  }
+  return buffer.substr(0, buffer.size() - delimiter.size());
+}
+
+namespace lldb_private::mcp {
+
+char EndOfFileError::ID;
+char TimeoutError::ID;
+char TransportClosedError::ID;
+
+Transport::Transport(StringRef client_name, IOObjectSP input, IOObjectSP output)
+    : m_client_name(client_name), m_input(std::move(input)),
+      m_output(std::move(output)) {}
+
+Expected<protocol::Request>
+Transport::Read(const std::chrono::microseconds &timeout) {
+  if (!m_input || !m_input->IsValid())
+    return make_error<TransportClosedError>();
+
+  IOObject *input = m_input.get();
+  // Honor the caller's timeout so an idle client cannot block this thread
+  // forever; Server::Run treats TimeoutError as "poll again", which lets it
+  // notice a transport that was closed in the meantime.
+  Expected<std::string> raw_json =
+      ReadUntil(*input, kMessageSeparator, timeout);
+  if (!raw_json)
+    return raw_json.takeError();
+
+  return json::parse<protocol::Request>(/*JSON=*/*raw_json,
+                                        /*RootName=*/"protocol_request");
+}
+
+} // namespace lldb_private::mcp
_______________________________________________
lldb-commits mailing list
lldb-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits