https://github.com/ashgti updated https://github.com/llvm/llvm-project/pull/159160
>From 5472b257f704288060ce3bad408c958c2a538e0d Mon Sep 17 00:00:00 2001 From: John Harrison <[email protected]> Date: Wed, 10 Sep 2025 10:42:56 -0700 Subject: [PATCH] [lldb] Adding A new Binding helper for JSONTransport. This adds a new Binding helper class to allow mapping of incoming and outgoing requests / events to specific handlers. This should make it easier to create new protocol implementations and allow us to create a relay in the lldb-mcp binary. --- lldb/include/lldb/Host/JSONTransport.h | 376 ++++++++++++++++-- lldb/include/lldb/Protocol/MCP/Protocol.h | 8 + lldb/include/lldb/Protocol/MCP/Server.h | 73 ++-- lldb/include/lldb/Protocol/MCP/Transport.h | 77 +++- lldb/source/Host/common/JSONTransport.cpp | 10 + .../Protocol/MCP/ProtocolServerMCP.cpp | 42 +- .../Plugins/Protocol/MCP/ProtocolServerMCP.h | 20 +- lldb/source/Protocol/MCP/Server.cpp | 212 +++------- lldb/tools/lldb-dap/DAP.h | 6 +- lldb/tools/lldb-dap/Protocol/ProtocolBase.h | 6 +- lldb/tools/lldb-dap/Transport.h | 6 +- lldb/unittests/DAP/DAPTest.cpp | 20 +- lldb/unittests/DAP/Handler/DisconnectTest.cpp | 4 +- lldb/unittests/DAP/TestBase.cpp | 42 +- lldb/unittests/DAP/TestBase.h | 122 +++--- lldb/unittests/Host/JSONTransportTest.cpp | 332 ++++++++++++---- .../Protocol/ProtocolMCPServerTest.cpp | 282 +++++++------ .../Host/JSONTransportTestUtilities.h | 96 ++++- 18 files changed, 1156 insertions(+), 578 deletions(-) diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index 210f33edace6e..080dce96ef3c4 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -18,6 +18,7 @@ #include "lldb/Utility/IOObject.h" #include "lldb/Utility/Status.h" #include "lldb/lldb-forward.h" +#include "llvm/ADT/FunctionExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" @@ -25,8 +26,11 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/JSON.h" #include 
"llvm/Support/raw_ostream.h" +#include <mutex> +#include <optional> #include <string> #include <system_error> +#include <type_traits> #include <variant> #include <vector> @@ -50,17 +54,70 @@ class TransportUnhandledContentsError std::string m_unhandled_contents; }; +class InvalidParams : public llvm::ErrorInfo<InvalidParams> { +public: + static char ID; + + explicit InvalidParams(std::string method, std::string context) + : m_method(std::move(method)), m_context(std::move(context)) {} + + void log(llvm::raw_ostream &OS) const override; + std::error_code convertToErrorCode() const override; + +private: + std::string m_method; + std::string m_context; +}; + +// Value for tracking functions that have a void param or result. +using VoidT = std::monostate; + +template <typename T> using Callback = llvm::unique_function<T>; + +template <typename T> +using Reply = typename std::conditional< + std::is_same_v<T, VoidT> == true, llvm::unique_function<void(llvm::Error)>, + llvm::unique_function<void(llvm::Expected<T>)>>::type; + +template <typename Result, typename Params> +using OutgoingRequest = typename std::conditional< + std::is_same_v<Params, VoidT> == true, + llvm::unique_function<void(Reply<Result>)>, + llvm::unique_function<void(const Params &, Reply<Result>)>>::type; + +template <typename Params> +using OutgoingEvent = typename std::conditional< + std::is_same_v<Params, VoidT> == true, llvm::unique_function<void()>, + llvm::unique_function<void(const Params &)>>::type; + +template <typename Id, typename Req> +Req make_request(Id id, llvm::StringRef method, + std::optional<llvm::json::Value> params = std::nullopt); +template <typename Req, typename Resp> +Resp make_response(const Req &req, llvm::Error error); +template <typename Req, typename Resp> +Resp make_response(const Req &req, llvm::json::Value result); +template <typename Evt> +Evt make_event(llvm::StringRef method, + std::optional<llvm::json::Value> params = std::nullopt); +template <typename Resp> 
+llvm::Expected<llvm::json::Value> get_result(const Resp &resp); +template <typename Id, typename T> Id get_id(const T &); +template <typename T> llvm::StringRef get_method(const T &); +template <typename T> llvm::json::Value get_params(const T &); + /// A transport is responsible for maintaining the connection to a client /// application, and reading/writing structured messages to it. /// /// Transports have limited thread safety requirements: /// - Messages will not be sent concurrently. /// - Messages MAY be sent while Run() is reading, or its callback is active. -template <typename Req, typename Resp, typename Evt> class Transport { +template <typename Id, typename Req, typename Resp, typename Evt> +class JSONTransport { public: using Message = std::variant<Req, Resp, Evt>; - virtual ~Transport() = default; + virtual ~JSONTransport() = default; /// Sends an event, a message that does not require a response. virtual llvm::Error Send(const Evt &) = 0; @@ -90,8 +147,6 @@ template <typename Req, typename Resp, typename Evt> class Transport { virtual void OnClosed() = 0; }; - using MessageHandlerSP = std::shared_ptr<MessageHandler>; - /// RegisterMessageHandler registers the Transport with the given MainLoop and /// handles any incoming messages using the given MessageHandler. /// @@ -100,22 +155,293 @@ template <typename Req, typename Resp, typename Evt> class Transport { virtual llvm::Expected<MainLoop::ReadHandleUP> RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) = 0; - // FIXME: Refactor mcp::Server to not directly access log on the transport. - // protected: +protected: template <typename... Ts> inline auto Logv(const char *Fmt, Ts &&...Vals) { Log(llvm::formatv(Fmt, std::forward<Ts>(Vals)...).str()); } virtual void Log(llvm::StringRef message) = 0; + + /// Function object to reply to a call. 
+ /// Each instance must be called exactly once, otherwise: + /// - the bug is logged, and (in debug mode) an assert will fire + /// - if there was no reply, an error reply is sent + /// - if there were multiple replies, only the first is sent + class ReplyOnce { + std::atomic<bool> replied = {false}; + const Req req; + JSONTransport *transport; // Null when moved-from. + JSONTransport::MessageHandler *handler; // Null when moved-from. + + public: + ReplyOnce(const Req req, JSONTransport *transport, + JSONTransport::MessageHandler *handler) + : req(req), transport(transport), handler(handler) { + assert(handler); + } + ReplyOnce(ReplyOnce &&other) + : replied(other.replied.load()), req(other.req), + transport(other.transport), handler(other.handler) { + other.transport = nullptr; + other.handler = nullptr; + } + ReplyOnce &operator=(ReplyOnce &&) = delete; + ReplyOnce(const ReplyOnce &) = delete; + ReplyOnce &operator=(const ReplyOnce &) = delete; + + ~ReplyOnce() { + if (transport && handler && !replied) { + assert(false && "must reply to all calls!"); + (*this)(make_response<Req, Resp>( + req, llvm::createStringError("failed to reply"))); + } + } + + void operator()(const Resp &resp) { + assert(transport && handler && "moved-from!"); + if (replied.exchange(true)) { + assert(false && "must reply to each call only once!"); + return; + } + + if (llvm::Error error = transport->Send(resp)) + handler->OnError(std::move(error)); + } + }; + +public: + class Binder; + using BinderUP = std::unique_ptr<Binder>; + + /// Binder collects a table of functions that handle calls. + /// + /// The wrapper takes care of parsing/serializing responses. + class Binder : public JSONTransport::MessageHandler { + public: + explicit Binder(JSONTransport &transport) + : m_transport(transport), m_seq(0) {} + + Binder(const Binder &) = delete; + Binder &operator=(const Binder &) = delete; + + /// Bind a handler on transport disconnect. + template <typename Fn, typename... 
Args> + void disconnected(Fn &&fn, Args &&...args) { + m_disconnect_handler = + std::bind(std::forward<Fn>(fn), std::forward<Args>(args)...); + } + + /// Bind a handler on error when communicating with the transport. + template <typename Fn, typename... Args> + void error(Fn &&fn, Args &&...args) { + m_error_handler = + std::bind(std::forward<Fn>(fn), std::forward<Args>(args)...); + } + + template <typename T> + static llvm::Expected<T> parse(const llvm::json::Value &raw, + llvm::StringRef method) { + T result; + llvm::json::Path::Root root; + if (!fromJSON(raw, result, root)) { + // Dump the relevant parts of the broken message. + std::string context; + llvm::raw_string_ostream OS(context); + root.printErrorContext(raw, OS); + return llvm::make_error<InvalidParams>(method.str(), context); + } + return std::move(result); + } + + /// Bind a handler for a request. + /// e.g. bind<PeekResult, PeekParams>("peek", &ThisModule::peek, this, _1); + /// Handler should be e.g. Expected<PeekResult> peek(const PeekParams&); + /// PeekParams must be JSON parsable and PeekResult must be serializable. + template <typename Result, typename Params, typename Fn, typename...
Args> + void bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) { + assert(m_request_handlers.find(method) == m_request_handlers.end() && + "request already bound"); + if constexpr (std::is_void_v<Params> || std::is_same_v<VoidT, Params>) { + std::function<llvm::Expected<Result>()> handler = + std::bind(std::forward<Fn>(fn), std::forward<Args>(args)...); + m_request_handlers[method] = + [handler](const Req &req, + llvm::unique_function<void(const Resp &)> reply) { + llvm::Expected<Result> result = handler(); + if (!result) + return reply(make_response<Req, Resp>(req, result.takeError())); + reply(make_response<Req, Resp>(req, toJSON(*result))); + }; + } else { + std::function<llvm::Expected<Result>(const Params &)> handler = + std::bind(std::forward<Fn>(fn), std::forward<Args>(args)...); + m_request_handlers[method] = + [method, handler](const Req &req, + llvm::unique_function<void(const Resp &)> reply) { + llvm::Expected<Params> params = + parse<Params>(get_params<Req>(req), method); + if (!params) + return reply(make_response<Req, Resp>(req, params.takeError())); + + llvm::Expected<Result> result = handler(*params); + if (!result) + return reply(make_response<Req, Resp>(req, result.takeError())); + + reply(make_response<Req, Resp>(req, toJSON(*result))); + }; + } + } + + /// Bind a handler for a event. + /// e.g. bind("peek", &ThisModule::peek, this, std::placeholders::_1); + /// Handler should be e.g. void peek(const PeekParams&); + /// PeekParams must be JSON parsable. + template <typename Params, typename Fn, typename... 
Args> + void bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) { + assert(m_event_handlers.find(method) == m_event_handlers.end() && + "event already bound"); + std::function<void(const Params &)> handler = + std::bind(std::forward<Fn>(fn), std::forward<Args>(args)...); + m_event_handlers[method] = [this, method, handler](const Evt &evt) { + llvm::Expected<Params> params = + parse<Params>(get_params<Evt>(evt), method); + if (!params) + return OnError(params.takeError()); + handler(*params); + }; + } + + /// Bind a function object to be used for outgoing requests. + /// e.g. OutgoingRequest<R, P> Edit = bind<R, P>("edit"); where P must be + /// JSON-serializable and R must be parsable. + template <typename Result, typename Params> + OutgoingRequest<Result, Params> bind(llvm::StringLiteral method) { + if constexpr (std::is_void_v<Params> || std::is_same_v<VoidT, Params>) { + return [this, method](Reply<Result> fn) { + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + Id id = ++m_seq; + Req req = make_request<Id, Req>(id, method, std::nullopt); + m_pending_responses[id] = [fn = std::move(fn), + method](const Resp &resp) mutable { + llvm::Expected<llvm::json::Value> result = get_result<Resp>(resp); + if (!result) + return fn(result.takeError()); + fn(parse<Result>(*result, method)); + }; + if (llvm::Error error = m_transport.Send(req)) + OnError(std::move(error)); + }; + } else { + return [this, method](const Params &params, Reply<Result> fn) { + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + Id id = ++m_seq; + Req req = + make_request<Id, Req>(id, method, llvm::json::Value(params)); + m_pending_responses[id] = [fn = std::move(fn), + method](const Resp &resp) mutable { + llvm::Expected<llvm::json::Value> result = get_result<Resp>(resp); + if (llvm::Error err = result.takeError()) + return fn(std::move(err)); + fn(parse<Result>(*result, method)); + }; + if (llvm::Error error = m_transport.Send(req)) + OnError(std::move(error)); + }; + } + } +
+ /// Bind a function object to be used for outgoing events. + /// e.g. OutgoingEvent<LogParams> Log = bind<LogParams>("log"); + /// LogParams must be JSON-serializable. + template <typename Params> + OutgoingEvent<Params> bind(llvm::StringLiteral method) { + if constexpr (std::is_void_v<Params> || std::is_same_v<VoidT, Params>) { + return [this, method]() { + if (llvm::Error error = + m_transport.Send(make_event<Evt>(method, std::nullopt))) + OnError(std::move(error)); + }; + } else { + return [this, method](const Params &params) { + if (llvm::Error error = + m_transport.Send(make_event<Evt>(method, toJSON(params)))) + OnError(std::move(error)); + }; + } + } + + void Received(const Evt &evt) override { + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + auto it = m_event_handlers.find(get_method<Evt>(evt)); + if (it == m_event_handlers.end()) { + OnError(llvm::createStringError( + llvm::formatv("no handler for event {0}", toJSON(evt)))); + return; + } + it->second(evt); + } + + void Received(const Req &req) override { + ReplyOnce reply(req, &m_transport, this); + + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + auto it = m_request_handlers.find(get_method<Req>(req)); + if (it == m_request_handlers.end()) { + reply(make_response<Req, Resp>( + req, llvm::createStringError("method not found"))); + return; + } + + it->second(req, std::move(reply)); + } + + void Received(const Resp &resp) override { + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + auto it = m_pending_responses.find(get_id<Id, Resp>(resp)); + if (it == m_pending_responses.end()) { + OnError(llvm::createStringError( + llvm::formatv("no pending request for {0}", toJSON(resp)))); + return; + } + + it->second(resp); + m_pending_responses.erase(it); + } + + void OnError(llvm::Error err) override { + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + if (m_error_handler) + m_error_handler(std::move(err)); + } + + void OnClosed() override { + std::scoped_lock<std::recursive_mutex>
guard(m_mutex); + if (m_disconnect_handler) + m_disconnect_handler(); + } + + private: + std::recursive_mutex m_mutex; + JSONTransport &m_transport; + Id m_seq; + std::map<Id, Callback<void(const Resp &)>> m_pending_responses; + llvm::StringMap<Callback<void(const Req &, Callback<void(const Resp &)>)>> + m_request_handlers; + llvm::StringMap<Callback<void(const Evt &)>> m_event_handlers; + Callback<void()> m_disconnect_handler; + Callback<void(llvm::Error)> m_error_handler; + }; }; -/// A JSONTransport will encode and decode messages using JSON. -template <typename Req, typename Resp, typename Evt> -class JSONTransport : public Transport<Req, Resp, Evt> { +/// A IOTransport will encode and decode messages using an IOObject like a +/// file or a socket. +template <typename Id, typename Req, typename Resp, typename Evt> +class IOTransport : public JSONTransport<Id, Req, Resp, Evt> { public: - using Transport<Req, Resp, Evt>::Transport; - using MessageHandler = typename Transport<Req, Resp, Evt>::MessageHandler; + using Message = typename JSONTransport<Id, Req, Resp, Evt>::Message; + using MessageHandler = + typename JSONTransport<Id, Req, Resp, Evt>::MessageHandler; - JSONTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) + IOTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) : m_in(in), m_out(out) {} llvm::Error Send(const Evt &evt) override { return Write(evt); } @@ -127,7 +453,7 @@ class JSONTransport : public Transport<Req, Resp, Evt> { Status status; MainLoop::ReadHandleUP read_handle = loop.RegisterReadObject( m_in, - std::bind(&JSONTransport::OnRead, this, std::placeholders::_1, + std::bind(&IOTransport::OnRead, this, std::placeholders::_1, std::ref(handler)), status); if (status.Fail()) { @@ -140,7 +466,7 @@ class JSONTransport : public Transport<Req, Resp, Evt> { /// detail. static constexpr size_t kReadBufferSize = 1024; - // FIXME: Write should be protected. 
+protected: llvm::Error Write(const llvm::json::Value &message) { this->Logv("<-- {0}", message); std::string output = Encode(message); @@ -148,7 +474,6 @@ class JSONTransport : public Transport<Req, Resp, Evt> { return m_out->Write(output.data(), bytes_written).takeError(); } -protected: virtual llvm::Expected<std::vector<std::string>> Parse() = 0; virtual std::string Encode(const llvm::json::Value &message) = 0; @@ -175,9 +500,8 @@ class JSONTransport : public Transport<Req, Resp, Evt> { } for (const std::string &raw_message : *raw_messages) { - llvm::Expected<typename Transport<Req, Resp, Evt>::Message> message = - llvm::json::parse<typename Transport<Req, Resp, Evt>::Message>( - raw_message); + llvm::Expected<Message> message = + llvm::json::parse<Message>(raw_message); if (!message) { handler.OnError(message.takeError()); return; @@ -202,10 +526,10 @@ class JSONTransport : public Transport<Req, Resp, Evt> { }; /// A transport class for JSON with a HTTP header. -template <typename Req, typename Resp, typename Evt> -class HTTPDelimitedJSONTransport : public JSONTransport<Req, Resp, Evt> { +template <typename Id, typename Req, typename Resp, typename Evt> +class HTTPDelimitedJSONTransport : public IOTransport<Id, Req, Resp, Evt> { public: - using JSONTransport<Req, Resp, Evt>::JSONTransport; + using IOTransport<Id, Req, Resp, Evt>::IOTransport; protected: /// Encodes messages based on @@ -231,8 +555,8 @@ class HTTPDelimitedJSONTransport : public JSONTransport<Req, Resp, Evt> { for (const llvm::StringRef &header : llvm::split(headers, kHeaderSeparator)) { auto [key, value] = header.split(kHeaderFieldSeparator); - // 'Content-Length' is the only meaningful key at the moment. Others are - // ignored. + // 'Content-Length' is the only meaningful key at the moment. Others + // are ignored. 
if (!key.equals_insensitive(kHeaderContentLength)) continue; @@ -269,10 +593,10 @@ class HTTPDelimitedJSONTransport : public JSONTransport<Req, Resp, Evt> { }; /// A transport class for JSON RPC. -template <typename Req, typename Resp, typename Evt> -class JSONRPCTransport : public JSONTransport<Req, Resp, Evt> { +template <typename Id, typename Req, typename Resp, typename Evt> +class JSONRPCTransport : public IOTransport<Id, Req, Resp, Evt> { public: - using JSONTransport<Req, Resp, Evt>::JSONTransport; + using IOTransport<Id, Req, Resp, Evt>::IOTransport; protected: std::string Encode(const llvm::json::Value &message) override { diff --git a/lldb/include/lldb/Protocol/MCP/Protocol.h b/lldb/include/lldb/Protocol/MCP/Protocol.h index 6e1ffcbe1f3e3..1e0816110b80a 100644 --- a/lldb/include/lldb/Protocol/MCP/Protocol.h +++ b/lldb/include/lldb/Protocol/MCP/Protocol.h @@ -14,6 +14,7 @@ #ifndef LLDB_PROTOCOL_MCP_PROTOCOL_H #define LLDB_PROTOCOL_MCP_PROTOCOL_H +#include "llvm/ADT/StringRef.h" #include "llvm/Support/JSON.h" #include <optional> #include <string> @@ -324,4 +325,11 @@ bool fromJSON(const llvm::json::Value &, CallToolResult &, llvm::json::Path); } // namespace lldb_protocol::mcp +namespace llvm::json { +inline Value toJSON(const lldb_protocol::mcp::Void &) { return Object(); } +inline bool fromJSON(const Value &, lldb_protocol::mcp::Void &, Path) { + return true; +} +} // namespace llvm::json + #endif diff --git a/lldb/include/lldb/Protocol/MCP/Server.h b/lldb/include/lldb/Protocol/MCP/Server.h index 1f916ae525b5c..df2a4810ce620 100644 --- a/lldb/include/lldb/Protocol/MCP/Server.h +++ b/lldb/include/lldb/Protocol/MCP/Server.h @@ -9,7 +9,6 @@ #ifndef LLDB_PROTOCOL_MCP_SERVER_H #define LLDB_PROTOCOL_MCP_SERVER_H -#include "lldb/Host/JSONTransport.h" #include "lldb/Host/MainLoop.h" #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Resource.h" @@ -19,74 +18,66 @@ #include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" #include 
"llvm/Support/Error.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/JSON.h" #include "llvm/Support/Signals.h" -#include <functional> #include <memory> #include <string> #include <vector> namespace lldb_protocol::mcp { -class Server : public MCPTransport::MessageHandler { +class Server { + + using MCPTransportUP = std::unique_ptr<lldb_protocol::mcp::MCPTransport>; + + using ReadHandleUP = lldb_private::MainLoop::ReadHandleUP; + public: - Server(std::string name, std::string version, - std::unique_ptr<MCPTransport> transport_up, - lldb_private::MainLoop &loop); + Server(std::string name, std::string version, LogCallback log_callback = {}); ~Server() = default; - using NotificationHandler = std::function<void(const Notification &)>; - void AddTool(std::unique_ptr<Tool> tool); void AddResourceProvider(std::unique_ptr<ResourceProvider> resource_provider); - void AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler); - llvm::Error Run(); + llvm::Error Accept(lldb_private::MainLoop &, MCPTransportUP); protected: - ServerCapabilities GetCapabilities(); - - using RequestHandler = - std::function<llvm::Expected<Response>(const Request &)>; + MCPTransport::BinderUP Bind(MCPTransport &); - void AddRequestHandlers(); - - void AddRequestHandler(llvm::StringRef method, RequestHandler handler); - - llvm::Expected<std::optional<Message>> HandleData(llvm::StringRef data); - - llvm::Expected<Response> Handle(const Request &request); - void Handle(const Notification ¬ification); - - llvm::Expected<Response> InitializeHandler(const Request &); + ServerCapabilities GetCapabilities(); - llvm::Expected<Response> ToolsListHandler(const Request &); - llvm::Expected<Response> ToolsCallHandler(const Request &); + llvm::Expected<InitializeResult> InitializeHandler(const InitializeParams &); - llvm::Expected<Response> ResourcesListHandler(const Request &); - llvm::Expected<Response> ResourcesReadHandler(const Request &); + 
llvm::Expected<ListToolsResult> ToolsListHandler(); + llvm::Expected<CallToolResult> ToolsCallHandler(const CallToolParams &); - void Received(const Request &) override; - void Received(const Response &) override; - void Received(const Notification &) override; - void OnError(llvm::Error) override; - void OnClosed() override; + llvm::Expected<ListResourcesResult> ResourcesListHandler(); + llvm::Expected<ReadResourceResult> + ResourcesReadHandler(const ReadResourceParams &); - void TerminateLoop(); + template <typename... Ts> inline auto Logv(const char *Fmt, Ts &&...Vals) { + Log(llvm::formatv(Fmt, std::forward<Ts>(Vals)...).str()); + } + void Log(llvm::StringRef message) { + if (m_log_callback) + m_log_callback(message); + } private: const std::string m_name; const std::string m_version; - std::unique_ptr<MCPTransport> m_transport_up; - lldb_private::MainLoop &m_loop; + LogCallback m_log_callback; + struct Client { + ReadHandleUP handle; + MCPTransportUP transport; + MCPTransport::BinderUP binder; + }; + std::map<MCPTransport *, Client> m_instances; llvm::StringMap<std::unique_ptr<Tool>> m_tools; std::vector<std::unique_ptr<ResourceProvider>> m_resource_providers; - - llvm::StringMap<RequestHandler> m_request_handlers; - llvm::StringMap<NotificationHandler> m_notification_handlers; }; class ServerInfoHandle; @@ -120,7 +111,7 @@ class ServerInfoHandle { ServerInfoHandle &operator=(const ServerInfoHandle &) = delete; /// @} - /// Remove the file. + /// Remove the file on disk, if one is tracked. 
void Remove(); private: diff --git a/lldb/include/lldb/Protocol/MCP/Transport.h b/lldb/include/lldb/Protocol/MCP/Transport.h index 47c2ccfc44dfe..55b2e8fa0a7f2 100644 --- a/lldb/include/lldb/Protocol/MCP/Transport.h +++ b/lldb/include/lldb/Protocol/MCP/Transport.h @@ -10,22 +10,95 @@ #define LLDB_PROTOCOL_MCP_TRANSPORT_H #include "lldb/Host/JSONTransport.h" +#include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Protocol/MCP/Protocol.h" #include "lldb/lldb-forward.h" #include "llvm/ADT/FunctionExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Error.h" + +namespace lldb_private { +/// Specializations of the JSONTransport protocol functions for MCP. +/// @{ +template <> +inline lldb_protocol::mcp::Request +make_request(int64_t id, llvm::StringRef method, + std::optional<llvm::json::Value> params) { + return lldb_protocol::mcp::Request{id, method.str(), params}; +} +template <> +inline lldb_protocol::mcp::Response +make_response(const lldb_protocol::mcp::Request &req, llvm::Error error) { + lldb_protocol::mcp::Error protocol_error; + llvm::handleAllErrors( + std::move(error), + [&](const lldb_protocol::mcp::MCPError &err) { + protocol_error = err.toProtocolError(); + }, + [&](const llvm::ErrorInfoBase &err) { + protocol_error.code = lldb_protocol::mcp::MCPError::kInternalError; + protocol_error.message = err.message(); + }); + + return lldb_protocol::mcp::Response{req.id, std::move(protocol_error)}; +} +template <> +inline lldb_protocol::mcp::Response +make_response(const lldb_protocol::mcp::Request &req, + llvm::json::Value result) { + return lldb_protocol::mcp::Response{req.id, std::move(result)}; +} +template <> +inline lldb_protocol::mcp::Notification +make_event(llvm::StringRef method, std::optional<llvm::json::Value> params) { + return lldb_protocol::mcp::Notification{method.str(), params}; +} +template <> +inline llvm::Expected<llvm::json::Value> +get_result(const lldb_protocol::mcp::Response &resp) { + if (const lldb_protocol::mcp::Error *error 
= + std::get_if<lldb_protocol::mcp::Error>(&resp.result)) + return llvm::make_error<lldb_protocol::mcp::MCPError>(error->message, + error->code); + return std::get<llvm::json::Value>(resp.result); +} +template <> inline int64_t get_id(const lldb_protocol::mcp::Response &resp) { + return std::get<int64_t>(resp.id); +} +template <> +inline llvm::StringRef get_method(const lldb_protocol::mcp::Request &req) { + return req.method; +} +template <> +inline llvm::StringRef get_method(const lldb_protocol::mcp::Notification &evt) { + return evt.method; +} +template <> +inline llvm::json::Value get_params(const lldb_protocol::mcp::Request &req) { + return req.params; +} +template <> +inline llvm::json::Value +get_params(const lldb_protocol::mcp::Notification &evt) { + return evt.params; +} +/// @} + +} // end namespace lldb_private namespace lldb_protocol::mcp { /// Generic transport that uses the MCP protocol. -using MCPTransport = lldb_private::Transport<Request, Response, Notification>; +using MCPTransport = + lldb_private::JSONTransport<int64_t, Request, Response, Notification>; /// Generic logging callback, to allow the MCP server / client / transport layer /// to be independent of the lldb log implementation. 
using LogCallback = llvm::unique_function<void(llvm::StringRef message)>; class Transport final - : public lldb_private::JSONRPCTransport<Request, Response, Notification> { + : public lldb_private::JSONRPCTransport<int64_t, Request, Response, + Notification> { public: Transport(lldb::IOObjectSP in, lldb::IOObjectSP out, LogCallback log_callback = {}); diff --git a/lldb/source/Host/common/JSONTransport.cpp b/lldb/source/Host/common/JSONTransport.cpp index c4b42eafc85d3..f809ef478c8f7 100644 --- a/lldb/source/Host/common/JSONTransport.cpp +++ b/lldb/source/Host/common/JSONTransport.cpp @@ -30,3 +30,13 @@ void TransportUnhandledContentsError::log(llvm::raw_ostream &OS) const { std::error_code TransportUnhandledContentsError::convertToErrorCode() const { return std::make_error_code(std::errc::bad_message); } + +char InvalidParams::ID; + +void InvalidParams::log(llvm::raw_ostream &OS) const { + OS << "invalid parameters for method '" << m_method << "': '" << m_context + << "'"; +} +std::error_code InvalidParams::convertToErrorCode() const { + return std::make_error_code(std::errc::invalid_argument); +} diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp index d3af3cf25c4a1..46a7a96cc5fc0 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp @@ -52,11 +52,6 @@ llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() { } void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const { - server.AddNotificationHandler("notifications/initialized", - [](const lldb_protocol::mcp::Notification &) { - LLDB_LOG(GetLog(LLDBLog::Host), - "MCP initialization complete"); - }); server.AddTool( std::make_unique<CommandTool>("command", "Run an lldb command.")); server.AddTool(std::make_unique<DebuggerListTool>( @@ -66,7 +61,7 @@ void ProtocolServerMCP::Extend(lldb_protocol::mcp::Server &server) const { void 
ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) { Log *log = GetLog(LLDBLog::Host); - std::string client_name = llvm::formatv("client_{0}", m_instances.size() + 1); + std::string client_name = llvm::formatv("client_{0}", m_client_count++); LLDB_LOG(log, "New MCP client connected: {0}", client_name); lldb::IOObjectSP io_sp = std::move(socket); @@ -74,16 +69,9 @@ void ProtocolServerMCP::AcceptCallback(std::unique_ptr<Socket> socket) { io_sp, io_sp, [client_name](llvm::StringRef message) { LLDB_LOG(GetLog(LLDBLog::Host), "{0}: {1}", client_name, message); }); - auto instance_up = std::make_unique<lldb_protocol::mcp::Server>( - std::string(kName), std::string(kVersion), std::move(transport_up), - m_loop); - Extend(*instance_up); - llvm::Error error = instance_up->Run(); - if (error) { - LLDB_LOG_ERROR(log, std::move(error), "Failed to run MCP server: {0}"); - return; - } - m_instances.push_back(std::move(instance_up)); + + if (auto error = m_server->Accept(m_loop, std::move(transport_up))) + LLDB_LOG_ERROR(log, std::move(error), "{0}:"); } llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { @@ -114,13 +102,20 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) { llvm::join(m_listener->GetListeningConnectionURI(), ", "); ServerInfo info{listening_uris[0]}; - llvm::Expected<ServerInfoHandle> handle = ServerInfo::Write(info); - if (!handle) - return handle.takeError(); + llvm::Expected<ServerInfoHandle> server_info_handle = ServerInfo::Write(info); + if (!server_info_handle) + return server_info_handle.takeError(); + + m_client_count = 0; + m_server = std::make_unique<lldb_protocol::mcp::Server>( + std::string(kName), std::string(kVersion), [](StringRef message) { + LLDB_LOG(GetLog(LLDBLog::Host), "MCP Server: {0}", message); + }); + Extend(*m_server); m_running = true; - m_server_info_handle = std::move(*handle); - m_listen_handlers = std::move(*handles); + m_server_info_handle = 
std::move(*server_info_handle); + m_accept_handles = std::move(*handles); m_loop_thread = std::thread([=] { llvm::set_thread_name("protocol-server.mcp"); m_loop.Run(); @@ -145,9 +140,10 @@ llvm::Error ProtocolServerMCP::Stop() { if (m_loop_thread.joinable()) m_loop_thread.join(); + m_accept_handles.clear(); + + m_server.reset(nullptr); m_server_info_handle.Remove(); - m_listen_handlers.clear(); - m_instances.clear(); return llvm::Error::success(); } diff --git a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h index 0251664a2acc4..d34b22e29765f 100644 --- a/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h +++ b/lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h @@ -12,19 +12,23 @@ #include "lldb/Core/ProtocolServer.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/Socket.h" -#include "lldb/Protocol/MCP/Protocol.h" #include "lldb/Protocol/MCP/Server.h" #include <thread> namespace lldb_private::mcp { class ProtocolServerMCP : public ProtocolServer { + + using ServerUP = std::unique_ptr<lldb_protocol::mcp::Server>; + + using ReadHandleUP = MainLoop::ReadHandleUP; + public: ProtocolServerMCP(); - virtual ~ProtocolServerMCP() override; + ~ProtocolServerMCP() override; - virtual llvm::Error Start(ProtocolServer::Connection connection) override; - virtual llvm::Error Stop() override; + llvm::Error Start(ProtocolServer::Connection connection) override; + llvm::Error Stop() override; static void Initialize(); static void Terminate(); @@ -48,16 +52,18 @@ class ProtocolServerMCP : public ProtocolServer { bool m_running = false; - lldb_protocol::mcp::ServerInfoHandle m_server_info_handle; lldb_private::MainLoop m_loop; std::thread m_loop_thread; + unsigned m_client_count = 0; std::mutex m_mutex; std::unique_ptr<Socket> m_listener; + std::vector<ReadHandleUP> m_accept_handles; - std::vector<MainLoopBase::ReadHandleUP> m_listen_handlers; - std::vector<std::unique_ptr<lldb_protocol::mcp::Server>> m_instances; + 
ServerUP m_server; + lldb_protocol::mcp::ServerInfoHandle m_server_info_handle; }; + } // namespace lldb_private::mcp #endif diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index a08874e7321af..7af0e0c85f7a9 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -13,6 +13,7 @@ #include "lldb/Host/JSONTransport.h" #include "lldb/Protocol/MCP/MCPError.h" #include "lldb/Protocol/MCP/Protocol.h" +#include "lldb/Protocol/MCP/Transport.h" #include "llvm/ADT/SmallString.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/JSON.h" @@ -109,48 +110,9 @@ Expected<std::vector<ServerInfo>> ServerInfo::Load() { return infos; } -Server::Server(std::string name, std::string version, - std::unique_ptr<MCPTransport> transport_up, - lldb_private::MainLoop &loop) +Server::Server(std::string name, std::string version, LogCallback log_callback) : m_name(std::move(name)), m_version(std::move(version)), - m_transport_up(std::move(transport_up)), m_loop(loop) { - AddRequestHandlers(); -} - -void Server::AddRequestHandlers() { - AddRequestHandler("initialize", std::bind(&Server::InitializeHandler, this, - std::placeholders::_1)); - AddRequestHandler("tools/list", std::bind(&Server::ToolsListHandler, this, - std::placeholders::_1)); - AddRequestHandler("tools/call", std::bind(&Server::ToolsCallHandler, this, - std::placeholders::_1)); - AddRequestHandler("resources/list", std::bind(&Server::ResourcesListHandler, - this, std::placeholders::_1)); - AddRequestHandler("resources/read", std::bind(&Server::ResourcesReadHandler, - this, std::placeholders::_1)); -} - -llvm::Expected<Response> Server::Handle(const Request &request) { - auto it = m_request_handlers.find(request.method); - if (it != m_request_handlers.end()) { - llvm::Expected<Response> response = it->second(request); - if (!response) - return response; - response->id = request.id; - return *response; - } - - return llvm::make_error<MCPError>( - 
llvm::formatv("no handler for request: {0}", request.method).str()); -} - -void Server::Handle(const Notification &notification) { - auto it = m_notification_handlers.find(notification.method); - if (it != m_notification_handlers.end()) { - it->second(notification); - return; - } -} + m_log_callback(std::move(log_callback)) {} void Server::AddTool(std::unique_ptr<Tool> tool) { if (!tool) @@ -165,48 +127,63 @@ void Server::AddResourceProvider( m_resource_providers.push_back(std::move(resource_provider)); } -void Server::AddRequestHandler(llvm::StringRef method, RequestHandler handler) { - m_request_handlers[method] = std::move(handler); -} +MCPTransport::BinderUP Server::Bind(MCPTransport &transport) { + MCPTransport::BinderUP binder = + std::make_unique<MCPTransport::Binder>(transport); + binder->bind<InitializeResult, InitializeParams>( + "initialize", &Server::InitializeHandler, this, std::placeholders::_1); + binder->bind<ListToolsResult, lldb_private::VoidT>( + "tools/list", &Server::ToolsListHandler, this); + binder->bind<CallToolResult, CallToolParams>( + "tools/call", &Server::ToolsCallHandler, this, std::placeholders::_1); + binder->bind<ListResourcesResult, lldb_private::VoidT>( + "resources/list", &Server::ResourcesListHandler, this); + binder->bind<ReadResourceResult, ReadResourceParams>( + "resources/read", &Server::ResourcesReadHandler, this, + std::placeholders::_1); + binder->bind<VoidT>("notifications/initialized", + [this]() { Log("MCP initialization complete"); }); + return binder; +} + +llvm::Error Server::Accept(MainLoop &loop, MCPTransportUP transport) { + MCPTransport::BinderUP binder = Bind(*transport); + MCPTransport *transport_ptr = transport.get(); + binder->disconnected([this, transport_ptr]() { + assert(m_instances.find(transport_ptr) != m_instances.end() && + "Client not found in m_instances"); + m_instances.erase(transport_ptr); + }); + + auto handle = transport->RegisterMessageHandler(loop, *binder); + if (!handle) + return
handle.takeError(); -void Server::AddNotificationHandler(llvm::StringRef method, - NotificationHandler handler) { - m_notification_handlers[method] = std::move(handler); + m_instances[transport_ptr] = + Client{std::move(*handle), std::move(transport), std::move(binder)}; + return llvm::Error::success(); } -llvm::Expected<Response> Server::InitializeHandler(const Request &request) { - Response response; +Expected<InitializeResult> +Server::InitializeHandler(const InitializeParams &request) { InitializeResult result; result.protocolVersion = mcp::kProtocolVersion; result.capabilities = GetCapabilities(); result.serverInfo.name = m_name; result.serverInfo.version = m_version; - response.result = std::move(result); - return response; + return result; } -llvm::Expected<Response> Server::ToolsListHandler(const Request &request) { - Response response; - +llvm::Expected<ListToolsResult> Server::ToolsListHandler() { ListToolsResult result; for (const auto &tool : m_tools) result.tools.emplace_back(tool.second->GetDefinition()); - response.result = std::move(result); - - return response; + return result; } -llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) { - Response response; - - if (!request.params) - return llvm::createStringError("no tool parameters"); - CallToolParams params; - json::Path::Root root("params"); - if (!fromJSON(request.params, params, root)) - return root.getError(); - +llvm::Expected<CallToolResult> +Server::ToolsCallHandler(const CallToolParams &params) { llvm::StringRef tool_name = params.name; if (tool_name.empty()) return llvm::createStringError("no tool name"); @@ -223,125 +200,50 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) { if (!text_result) return text_result.takeError(); - response.result = toJSON(*text_result); - - return response; + return text_result; } -llvm::Expected<Response> Server::ResourcesListHandler(const Request &request) { - Response response; -
+llvm::Expected<ListResourcesResult> Server::ResourcesListHandler() { ListResourcesResult result; for (std::unique_ptr<ResourceProvider> &resource_provider_up : m_resource_providers) for (const Resource &resource : resource_provider_up->GetResources()) result.resources.push_back(resource); - response.result = std::move(result); - - return response; + return result; } -llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) { - Response response; - - if (!request.params) - return llvm::createStringError("no resource parameters"); - - ReadResourceParams params; - json::Path::Root root("params"); - if (!fromJSON(request.params, params, root)) - return root.getError(); - - llvm::StringRef uri_str = params.uri; +Expected<ReadResourceResult> +Server::ResourcesReadHandler(const ReadResourceParams &params) { + StringRef uri_str = params.uri; if (uri_str.empty()) - return llvm::createStringError("no resource uri"); + return createStringError("no resource uri"); for (std::unique_ptr<ResourceProvider> &resource_provider_up : m_resource_providers) { - llvm::Expected<ReadResourceResult> result = + Expected<ReadResourceResult> result = resource_provider_up->ReadResource(uri_str); if (result.errorIsA<UnsupportedURI>()) { - llvm::consumeError(result.takeError()); + consumeError(result.takeError()); continue; } if (!result) return result.takeError(); - Response response; - response.result = std::move(*result); - return response; + return *result; } return make_error<MCPError>( - llvm::formatv("no resource handler for uri: {0}", uri_str).str(), + formatv("no resource handler for uri: {0}", uri_str).str(), MCPError::kResourceNotFound); } ServerCapabilities Server::GetCapabilities() { lldb_protocol::mcp::ServerCapabilities capabilities; capabilities.supportsToolsList = true; + capabilities.supportsResourcesList = true; // FIXME: Support sending notifications when a debugger/target are // added/removed.
- capabilities.supportsResourcesList = false; + capabilities.supportsResourcesSubscribe = false; return capabilities; } - -llvm::Error Server::Run() { - auto handle = m_transport_up->RegisterMessageHandler(m_loop, *this); - if (!handle) - return handle.takeError(); - - lldb_private::Status status = m_loop.Run(); - if (status.Fail()) - return status.takeError(); - - return llvm::Error::success(); -} - -void Server::Received(const Request &request) { - auto SendResponse = [this](const Response &response) { - if (llvm::Error error = m_transport_up->Send(response)) - m_transport_up->Log(llvm::toString(std::move(error))); - }; - - llvm::Expected<Response> response = Handle(request); - if (response) - return SendResponse(*response); - - lldb_protocol::mcp::Error protocol_error; - llvm::handleAllErrors( - response.takeError(), - [&](const MCPError &err) { protocol_error = err.toProtocolError(); }, - [&](const llvm::ErrorInfoBase &err) { - protocol_error.code = MCPError::kInternalError; - protocol_error.message = err.message(); - }); - Response error_response; - error_response.id = request.id; - error_response.result = std::move(protocol_error); - SendResponse(error_response); -} - -void Server::Received(const Response &response) { - m_transport_up->Log("unexpected MCP message: response"); -} - -void Server::Received(const Notification &notification) { - Handle(notification); -} - -void Server::OnError(llvm::Error error) { - m_transport_up->Log(llvm::toString(std::move(error))); - TerminateLoop(); -} - -void Server::OnClosed() { - m_transport_up->Log("EOF"); - TerminateLoop(); -} - -void Server::TerminateLoop() { - m_loop.AddPendingCallback( - [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); -} diff --git a/lldb/tools/lldb-dap/DAP.h b/lldb/tools/lldb-dap/DAP.h index 71681fd4b51ed..0c921e5b72d74 100644 --- a/lldb/tools/lldb-dap/DAP.h +++ b/lldb/tools/lldb-dap/DAP.h @@ -79,10 +79,10 @@ enum DAPBroadcasterBits { enum class ReplMode { Variable = 0, Command,
Auto }; using DAPTransport = - lldb_private::Transport<protocol::Request, protocol::Response, - protocol::Event>; + lldb_private::JSONTransport<protocol::Id, protocol::Request, + protocol::Response, protocol::Event>; -struct DAP final : private DAPTransport::MessageHandler { +struct DAP final : public DAPTransport::MessageHandler { /// Path to the lldb-dap binary itself. static llvm::StringRef debug_adapter_path; diff --git a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h index 0a9ef538a7398..92e41b1dbf595 100644 --- a/lldb/tools/lldb-dap/Protocol/ProtocolBase.h +++ b/lldb/tools/lldb-dap/Protocol/ProtocolBase.h @@ -30,6 +30,8 @@ namespace lldb_dap::protocol { // MARK: Base Protocol +using Id = int64_t; + /// A client or debug adapter initiated request. struct Request { /// Sequence number of the message (also known as message ID). The `seq` for @@ -39,7 +41,7 @@ struct Request { /// associate requests with their corresponding responses. For protocol /// messages of type `request` the sequence number can be used to cancel the /// request. - int64_t seq; + Id seq; /// The command to execute. std::string command; @@ -76,7 +78,7 @@ enum ResponseMessage : unsigned { /// Response for a request. struct Response { /// Sequence number of the corresponding request. - int64_t request_seq; + Id request_seq; /// The command requested. std::string command; diff --git a/lldb/tools/lldb-dap/Transport.h b/lldb/tools/lldb-dap/Transport.h index 4a9dd76c2303e..6462c155eb9af 100644 --- a/lldb/tools/lldb-dap/Transport.h +++ b/lldb/tools/lldb-dap/Transport.h @@ -24,9 +24,9 @@ namespace lldb_dap { /// A transport class that performs the Debug Adapter Protocol communication /// with the client. 
-class Transport final - : public lldb_private::HTTPDelimitedJSONTransport< - protocol::Request, protocol::Response, protocol::Event> { +class Transport final : public lldb_private::HTTPDelimitedJSONTransport< + protocol::Id, protocol::Request, protocol::Response, + protocol::Event> { public: Transport(llvm::StringRef client_name, lldb_dap::Log *log, lldb::IOObjectSP input, lldb::IOObjectSP output); diff --git a/lldb/unittests/DAP/DAPTest.cpp b/lldb/unittests/DAP/DAPTest.cpp index 2090fe6896d6b..4fd6cd546e6fa 100644 --- a/lldb/unittests/DAP/DAPTest.cpp +++ b/lldb/unittests/DAP/DAPTest.cpp @@ -9,13 +9,10 @@ #include "DAP.h" #include "Protocol/ProtocolBase.h" #include "TestBase.h" -#include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include <optional> -using namespace llvm; -using namespace lldb; using namespace lldb_dap; using namespace lldb_dap_tests; using namespace lldb_dap::protocol; @@ -24,18 +21,7 @@ using namespace testing; class DAPTest : public TransportBase {}; TEST_F(DAPTest, SendProtocolMessages) { - DAP dap{ - /*log=*/nullptr, - /*default_repl_mode=*/ReplMode::Auto, - /*pre_init_commands=*/{}, - /*no_lldbinit=*/false, - /*client_name=*/"test_client", - /*transport=*/*transport, - /*loop=*/loop, - }; - dap.Send(Event{/*event=*/"my-event", /*body=*/std::nullopt}); - loop.AddPendingCallback( - [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); - EXPECT_CALL(client, Received(IsEvent("my-event", std::nullopt))); - ASSERT_THAT_ERROR(dap.Loop(), llvm::Succeeded()); + dap->Send(Event{/*event=*/"my-event", /*body=*/std::nullopt}); + EXPECT_CALL(client, Received(IsEvent("my-event"))); + Run(); } diff --git a/lldb/unittests/DAP/Handler/DisconnectTest.cpp b/lldb/unittests/DAP/Handler/DisconnectTest.cpp index c6ff1f90b01d5..88d6e9a69eca3 100644 --- a/lldb/unittests/DAP/Handler/DisconnectTest.cpp +++ b/lldb/unittests/DAP/Handler/DisconnectTest.cpp @@ -31,7 +31,7 @@ TEST_F(DisconnectRequestHandlerTest, 
DisconnectTriggersTerminated) { DisconnectRequestHandler handler(*dap); ASSERT_THAT_ERROR(handler.Run(std::nullopt), Succeeded()); EXPECT_CALL(client, Received(IsEvent("terminated", _))); - RunOnce(); + Run(); } TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) { @@ -53,5 +53,5 @@ TEST_F(DisconnectRequestHandlerTest, DisconnectTriggersTerminateCommands) { EXPECT_CALL(client, Received(Output("(lldb) script print(2)\n"))); EXPECT_CALL(client, Received(Output("Running terminateCommands:\n"))); EXPECT_CALL(client, Received(IsEvent("terminated", _))); - RunOnce(); + Run(); } diff --git a/lldb/unittests/DAP/TestBase.cpp b/lldb/unittests/DAP/TestBase.cpp index ba7baf2103799..3721e09d8b699 100644 --- a/lldb/unittests/DAP/TestBase.cpp +++ b/lldb/unittests/DAP/TestBase.cpp @@ -32,23 +32,9 @@ using lldb_private::FileSystem; using lldb_private::MainLoop; using lldb_private::Pipe; -Expected<MainLoop::ReadHandleUP> -TestTransport::RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) { - Expected<lldb::FileUP> dummy_file = FileSystem::Instance().Open( - FileSpec(FileSystem::DEV_NULL), File::eOpenOptionReadWrite); - if (!dummy_file) - return dummy_file.takeError(); - m_dummy_file = std::move(*dummy_file); - lldb_private::Status status; - auto handle = loop.RegisterReadObject( - m_dummy_file, [](lldb_private::MainLoopBase &) {}, status); - if (status.Fail()) - return status.takeError(); - return handle; -} +void TransportBase::SetUp() { + std::tie(to_client, to_server) = TestDAPTransport::createPair(); -void DAPTestBase::SetUp() { - TransportBase::SetUp(); std::error_code EC; log = std::make_unique<Log>("-", EC); dap = std::make_unique<DAP>( @@ -57,16 +43,30 @@ void DAPTestBase::SetUp() { /*pre_init_commands=*/std::vector<std::string>(), /*no_lldbinit=*/false, /*client_name=*/"test_client", - /*transport=*/*transport, /*loop=*/loop); + /*transport=*/*to_client, /*loop=*/loop); + + auto server_handle = to_server->RegisterMessageHandler(loop, 
*dap.get()); + EXPECT_THAT_EXPECTED(server_handle, Succeeded()); + handles[0] = std::move(*server_handle); + + auto client_handle = to_client->RegisterMessageHandler(loop, client); + EXPECT_THAT_EXPECTED(client_handle, Succeeded()); + handles[1] = std::move(*client_handle); } +void TransportBase::Run() { + loop.AddPendingCallback( + [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); + EXPECT_THAT_ERROR(loop.Run().takeError(), llvm::Succeeded()); +} + +void DAPTestBase::SetUp() { TransportBase::SetUp(); } + void DAPTestBase::TearDown() { - if (core) { + if (core) ASSERT_THAT_ERROR(core->discard(), Succeeded()); - } - if (binary) { + if (binary) ASSERT_THAT_ERROR(binary->discard(), Succeeded()); - } } void DAPTestBase::SetUpTestSuite() { diff --git a/lldb/unittests/DAP/TestBase.h b/lldb/unittests/DAP/TestBase.h index c19eead4e37e7..aaeab3b3d2cd9 100644 --- a/lldb/unittests/DAP/TestBase.h +++ b/lldb/unittests/DAP/TestBase.h @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "DAP.h" +#include "DAPLog.h" #include "Protocol/ProtocolBase.h" #include "TestingSupport/Host/JSONTransportTestUtilities.h" #include "TestingSupport/SubsystemRAII.h" @@ -14,66 +15,41 @@ #include "lldb/Host/HostInfo.h" #include "lldb/Host/MainLoop.h" #include "lldb/Host/MainLoopBase.h" -#include "lldb/lldb-forward.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Support/Error.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/JSON.h" -#include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include <memory> +#include <optional> + +/// Helpers for gtest printing. 
+namespace lldb_dap::protocol { + +inline void PrintTo(const Request &req, std::ostream *os) { + *os << llvm::formatv("{0}", toJSON(req)).str(); +} + +inline void PrintTo(const Response &resp, std::ostream *os) { + *os << llvm::formatv("{0}", toJSON(resp)).str(); +} + +inline void PrintTo(const Event &evt, std::ostream *os) { + *os << llvm::formatv("{0}", toJSON(evt)).str(); +} + +inline void PrintTo(const Message &message, std::ostream *os) { + return std::visit([os](auto &&message) { return PrintTo(message, os); }, + message); +} + +} // namespace lldb_dap::protocol namespace lldb_dap_tests { -class TestTransport final - : public lldb_private::Transport<lldb_dap::protocol::Request, - lldb_dap::protocol::Response, - lldb_dap::protocol::Event> { -public: - using Message = lldb_private::Transport<lldb_dap::protocol::Request, - lldb_dap::protocol::Response, - lldb_dap::protocol::Event>::Message; - - TestTransport(lldb_private::MainLoop &loop, MessageHandler &handler) - : m_loop(loop), m_handler(handler) {} - - llvm::Error Send(const lldb_dap::protocol::Event &e) override { - m_loop.AddPendingCallback([this, e](lldb_private::MainLoopBase &) { - this->m_handler.Received(e); - }); - return llvm::Error::success(); - } - - llvm::Error Send(const lldb_dap::protocol::Request &r) override { - m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) { - this->m_handler.Received(r); - }); - return llvm::Error::success(); - } - - llvm::Error Send(const lldb_dap::protocol::Response &r) override { - m_loop.AddPendingCallback([this, r](lldb_private::MainLoopBase &) { - this->m_handler.Received(r); - }); - return llvm::Error::success(); - } - - llvm::Expected<lldb_private::MainLoop::ReadHandleUP> - RegisterMessageHandler(lldb_private::MainLoop &loop, - MessageHandler &handler) override; - - void Log(llvm::StringRef message) override { - log_messages.emplace_back(message); - } - - std::vector<std::string> log_messages; - -private: - lldb_private::MainLoop &m_loop; - 
MessageHandler &m_handler; - lldb::FileSP m_dummy_file; -}; +using TestDAPTransport = + TestTransport<int64_t, lldb_dap::protocol::Request, + lldb_dap::protocol::Response, lldb_dap::protocol::Event>; /// A base class for tests that need transport configured for communicating DAP /// messages. @@ -82,22 +58,38 @@ class TransportBase : public testing::Test { lldb_private::SubsystemRAII<lldb_private::FileSystem, lldb_private::HostInfo> subsystems; lldb_private::MainLoop loop; - std::unique_ptr<TestTransport> transport; - MockMessageHandler<lldb_dap::protocol::Request, lldb_dap::protocol::Response, - lldb_dap::protocol::Event> + lldb_private::MainLoop::ReadHandleUP handles[2]; + + std::unique_ptr<lldb_dap::Log> log; + + std::unique_ptr<TestDAPTransport> to_client; + MockMessageHandler<int64_t, lldb_dap::protocol::Request, + lldb_dap::protocol::Response, lldb_dap::protocol::Event> client; - void SetUp() override { - transport = std::make_unique<TestTransport>(loop, client); - } + std::unique_ptr<TestDAPTransport> to_server; + std::unique_ptr<lldb_dap::DAP> dap; + + void SetUp() override; + + void Run(); }; /// A matcher for a DAP event. 
-template <typename M1, typename M2> +template <typename EventMatcher, typename BodyMatcher> inline testing::Matcher<const lldb_dap::protocol::Event &> -IsEvent(const M1 &m1, const M2 &m2) { - return testing::AllOf(testing::Field(&lldb_dap::protocol::Event::event, m1), - testing::Field(&lldb_dap::protocol::Event::body, m2)); +IsEvent(const EventMatcher &event_matcher, const BodyMatcher &body_matcher) { + return testing::AllOf( + testing::Field(&lldb_dap::protocol::Event::event, event_matcher), + testing::Field(&lldb_dap::protocol::Event::body, body_matcher)); +} + +template <typename EventMatcher> +inline testing::Matcher<const lldb_dap::protocol::Event &> +IsEvent(const EventMatcher &event_matcher) { + return testing::AllOf( + testing::Field(&lldb_dap::protocol::Event::event, event_matcher), + testing::Field(&lldb_dap::protocol::Event::body, std::nullopt)); } /// Matches an "output" event. @@ -110,8 +102,6 @@ inline auto Output(llvm::StringRef o, llvm::StringRef cat = "console") { /// A base class for tests that interact with a `lldb_dap::DAP` instance. 
class DAPTestBase : public TransportBase { protected: - std::unique_ptr<lldb_dap::Log> log; - std::unique_ptr<lldb_dap::DAP> dap; std::optional<llvm::sys::fs::TempFile> core; std::optional<llvm::sys::fs::TempFile> binary; @@ -126,12 +116,6 @@ class DAPTestBase : public TransportBase { bool GetDebuggerSupportsTarget(llvm::StringRef platform); void CreateDebugger(); void LoadCore(); - - void RunOnce() { - loop.AddPendingCallback( - [](lldb_private::MainLoopBase &loop) { loop.RequestTermination(); }); - ASSERT_THAT_ERROR(dap->Loop(), llvm::Succeeded()); - } }; } // namespace lldb_dap_tests diff --git a/lldb/unittests/Host/JSONTransportTest.cpp b/lldb/unittests/Host/JSONTransportTest.cpp index 445674f402252..b2853cfc7d73e 100644 --- a/lldb/unittests/Host/JSONTransportTest.cpp +++ b/lldb/unittests/Host/JSONTransportTest.cpp @@ -25,6 +25,7 @@ #include <chrono> #include <cstddef> #include <memory> +#include <optional> #include <string> using namespace llvm; @@ -32,20 +33,35 @@ using namespace lldb_private; using testing::_; using testing::HasSubstr; using testing::InSequence; +using testing::Ref; + +namespace llvm::json { +static bool fromJSON(const Value &V, Value &T, Path P) { + T = V; + return true; +} +} // namespace llvm::json namespace { namespace test_protocol { struct Req { + int id = 0; std::string name; + std::optional<json::Value> params; }; -json::Value toJSON(const Req &T) { return json::Object{{"req", T.name}}; } +json::Value toJSON(const Req &T) { + return json::Object{{"name", T.name}, {"id", T.id}, {"params", T.params}}; +} bool fromJSON(const json::Value &V, Req &T, json::Path P) { json::ObjectMapper O(V, P); - return O && O.map("req", T.name); + return O && O.map("name", T.name) && O.map("id", T.id) && + O.map("params", T.params); +} +bool operator==(const Req &a, const Req &b) { + return a.name == b.name && a.id == b.id && a.params == b.params; } -bool operator==(const Req &a, const Req &b) { return a.name == b.name; } inline llvm::raw_ostream 
&operator<<(llvm::raw_ostream &OS, const Req &V) { OS << toJSON(V); return OS; @@ -58,14 +74,19 @@ void PrintTo(const Req &message, std::ostream *os) { } struct Resp { - std::string name; + int id = 0; + std::optional<json::Value> result; }; -json::Value toJSON(const Resp &T) { return json::Object{{"resp", T.name}}; } +json::Value toJSON(const Resp &T) { + return json::Object{{"id", T.id}, {"result", T.result}}; +} bool fromJSON(const json::Value &V, Resp &T, json::Path P) { json::ObjectMapper O(V, P); - return O && O.map("resp", T.name); + return O && O.map("id", T.id) && O.map("result", T.result); +} +bool operator==(const Resp &a, const Resp &b) { + return a.id == b.id && a.result == b.result; } -bool operator==(const Resp &a, const Resp &b) { return a.name == b.name; } inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Resp &V) { OS << toJSON(V); return OS; @@ -79,11 +100,14 @@ void PrintTo(const Resp &message, std::ostream *os) { struct Evt { std::string name; + std::optional<json::Value> params; }; -json::Value toJSON(const Evt &T) { return json::Object{{"evt", T.name}}; } +json::Value toJSON(const Evt &T) { + return json::Object{{"name", T.name}, {"params", T.params}}; +} bool fromJSON(const json::Value &V, Evt &T, json::Path P) { json::ObjectMapper O(V, P); - return O && O.map("evt", T.name); + return O && O.map("name", T.name) && O.map("params", T.params); } bool operator==(const Evt &a, const Evt &b) { return a.name == b.name; } inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, const Evt &V) { @@ -107,41 +131,62 @@ bool fromJSON(const json::Value &V, Message &msg, json::Path P) { P.report("expected object"); return false; } - if (O->get("req")) { - Req R; - if (!fromJSON(V, R, P)) + + if (O->find("id") == O->end()) { + Evt E; + if (!fromJSON(V, E, P)) return false; - msg = std::move(R); + msg = std::move(E); return true; } - if (O->get("resp")) { - Resp R; + + if (O->get("name")) { + Req R; if (!fromJSON(V, R, P)) return false; 
msg = std::move(R); return true; } - if (O->get("evt")) { - Evt E; - if (!fromJSON(V, E, P)) - return false; - msg = std::move(E); - return true; - } - P.report("unknown message type"); - return false; + Resp R; + if (!fromJSON(V, R, P)) + return false; + + msg = std::move(R); + return true; } -} // namespace test_protocol +struct MyFnParams { + int a = 0; + int b = 0; +}; +json::Value toJSON(const MyFnParams &T) { + return json::Object{{"a", T.a}, {"b", T.b}}; +} +bool fromJSON(const json::Value &V, MyFnParams &T, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("a", T.a) && O.map("b", T.b); +} + +struct MyFnResult { + int c = 0; +}; +json::Value toJSON(const MyFnResult &T) { return json::Object{{"c", T.c}}; } +bool fromJSON(const json::Value &V, MyFnResult &T, json::Path P) { + json::ObjectMapper O(V, P); + return O && O.map("c", T.c); +} -template <typename T, typename Req, typename Resp, typename Evt> -class JSONTransportTest : public PipePairTest { +using Transport = TestTransport<int, Req, Resp, Evt>; +using MessageHandler = MockMessageHandler<int, Req, Resp, Evt>; +} // namespace test_protocol + +template <typename T> class JSONTransportTest : public PipePairTest { protected: - MockMessageHandler<Req, Resp, Evt> message_handler; + test_protocol::MessageHandler message_handler; std::unique_ptr<T> transport; MainLoop loop; @@ -191,8 +236,8 @@ class JSONTransportTest : public PipePairTest { }; class TestHTTPDelimitedJSONTransport final - : public HTTPDelimitedJSONTransport<test_protocol::Req, test_protocol::Resp, - test_protocol::Evt> { + : public HTTPDelimitedJSONTransport< + int, test_protocol::Req, test_protocol::Resp, test_protocol::Evt> { public: using HTTPDelimitedJSONTransport::HTTPDelimitedJSONTransport; @@ -204,9 +249,7 @@ class TestHTTPDelimitedJSONTransport final }; class HTTPDelimitedJSONTransportTest - : public JSONTransportTest<TestHTTPDelimitedJSONTransport, - test_protocol::Req, test_protocol::Resp, - test_protocol::Evt> { + : 
public JSONTransportTest<TestHTTPDelimitedJSONTransport> { public: using JSONTransportTest::JSONTransportTest; @@ -222,7 +265,7 @@ class HTTPDelimitedJSONTransportTest }; class TestJSONRPCTransport final - : public JSONRPCTransport<test_protocol::Req, test_protocol::Resp, + : public JSONRPCTransport<int, test_protocol::Req, test_protocol::Resp, test_protocol::Evt> { public: using JSONRPCTransport::JSONRPCTransport; @@ -234,9 +277,7 @@ class TestJSONRPCTransport final std::vector<std::string> log_messages; }; -class JSONRPCTransportTest - : public JSONTransportTest<TestJSONRPCTransport, test_protocol::Req, - test_protocol::Resp, test_protocol::Evt> { +class JSONRPCTransportTest : public JSONTransportTest<TestJSONRPCTransport> { public: using JSONTransportTest::JSONTransportTest; @@ -248,8 +289,69 @@ class JSONRPCTransportTest } }; +class TestTransportBinder : public testing::Test { +protected: + std::unique_ptr<test_protocol::Transport> to_remote; + std::unique_ptr<test_protocol::Transport> from_remote; + std::unique_ptr<test_protocol::Transport::Binder> binder; + test_protocol::MessageHandler remote; + MainLoop loop; + + void SetUp() override { + std::tie(to_remote, from_remote) = test_protocol::Transport::createPair(); + binder = std::make_unique<test_protocol::Transport::Binder>(*to_remote); + + auto binder_handle = to_remote->RegisterMessageHandler(loop, remote); + EXPECT_THAT_EXPECTED(binder_handle, Succeeded()); + + auto remote_handle = from_remote->RegisterMessageHandler(loop, *binder); + EXPECT_THAT_EXPECTED(remote_handle, Succeeded()); + } + + void Run() { + loop.AddPendingCallback([](auto &loop) { loop.RequestTermination(); }); + EXPECT_THAT_ERROR(loop.Run().takeError(), Succeeded()); + } +}; + } // namespace +namespace lldb_private { +using namespace test_protocol; +template <> +inline test_protocol::Req make_request(int id, llvm::StringRef method, + std::optional<json::Value> params) { + return test_protocol::Req{id, method.str(), params}; +} +template 
<> inline Resp make_response(const Req &req, llvm::Error error) { + llvm::consumeError(std::move(error)); + return Resp{req.id, std::nullopt}; +} +template <> inline Resp make_response(const Req &req, json::Value result) { + return Resp{req.id, std::move(result)}; +} +template <> +inline Evt make_event(llvm::StringRef method, + std::optional<json::Value> params) { + return Evt{method.str(), params}; +} + +template <> inline llvm::Expected<json::Value> get_result(const Resp &resp) { + return resp.result; +} + +template <> inline int get_id(const Resp &resp) { return resp.id; } +template <> inline llvm::StringRef get_method(const Req &req) { + return req.name; +} +template <> inline llvm::StringRef get_method(const Evt &evt) { + return evt.name; +} +template <> inline json::Value get_params(const Req &req) { return req.params; } +template <> inline json::Value get_params(const Evt &evt) { return evt.params; } + +} // namespace lldb_private + // Failing on Windows, see https://github.com/llvm/llvm-project/issues/153446. 
#ifndef _WIN32 using namespace test_protocol; @@ -269,35 +371,47 @@ TEST_F(HTTPDelimitedJSONTransportTest, MalformedRequests) { } TEST_F(HTTPDelimitedJSONTransportTest, Read) { - Write(Req{"foo"}); - EXPECT_CALL(message_handler, Received(Req{"foo"})); + Write(Req{6, "foo", std::nullopt}); + EXPECT_CALL(message_handler, Received(Req{6, "foo", std::nullopt})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadMultipleMessagesInSingleWrite) { InSequence seq; - Write(Message{Req{"one"}}, Message{Evt{"two"}}, Message{Resp{"three"}}); - EXPECT_CALL(message_handler, Received(Req{"one"})); - EXPECT_CALL(message_handler, Received(Evt{"two"})); - EXPECT_CALL(message_handler, Received(Resp{"three"})); + Write( + Message{ + Req{6, "one", std::nullopt}, + }, + Message{ + Evt{"two", std::nullopt}, + }, + Message{ + Resp{2, std::nullopt}, + }); + EXPECT_CALL(message_handler, Received(Req{6, "one", std::nullopt})); + EXPECT_CALL(message_handler, Received(Evt{"two", std::nullopt})); + EXPECT_CALL(message_handler, Received(Resp{2, std::nullopt})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadAcrossMultipleChunks) { std::string long_str = std::string( - HTTPDelimitedJSONTransport<Req, Resp, Evt>::kReadBufferSize * 2, 'x'); - Write(Req{long_str}); - EXPECT_CALL(message_handler, Received(Req{long_str})); + HTTPDelimitedJSONTransport<int, test_protocol::Req, test_protocol::Resp, + test_protocol::Evt>::kReadBufferSize * + 2, + 'x'); + Write(Req{5, long_str, std::nullopt}); + EXPECT_CALL(message_handler, Received(Req{5, long_str, std::nullopt})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) { - std::string message = Encode(Req{"foo"}); + std::string message = Encode(Req{5, "foo", std::nullopt}); auto split_at = message.size() / 2; std::string part1 = message.substr(0, split_at); std::string part2 = message.substr(split_at); - EXPECT_CALL(message_handler, 
Received(Req{"foo"})); + EXPECT_CALL(message_handler, Received(Req{5, "foo", std::nullopt})); ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); loop.AddPendingCallback( @@ -309,12 +423,12 @@ TEST_F(HTTPDelimitedJSONTransportTest, ReadPartialMessage) { } TEST_F(HTTPDelimitedJSONTransportTest, ReadWithZeroByteWrites) { - std::string message = Encode(Req{"foo"}); + std::string message = Encode(Req{6, "foo", std::nullopt}); auto split_at = message.size() / 2; std::string part1 = message.substr(0, split_at); std::string part2 = message.substr(split_at); - EXPECT_CALL(message_handler, Received(Req{"foo"})); + EXPECT_CALL(message_handler, Received(Req{6, "foo", std::nullopt})); ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); @@ -366,20 +480,21 @@ TEST_F(HTTPDelimitedJSONTransportTest, InvalidTransport) { } TEST_F(HTTPDelimitedJSONTransportTest, Write) { - ASSERT_THAT_ERROR(transport->Send(Req{"foo"}), Succeeded()); - ASSERT_THAT_ERROR(transport->Send(Resp{"bar"}), Succeeded()); - ASSERT_THAT_ERROR(transport->Send(Evt{"baz"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Req{7, "foo", std::nullopt}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Resp{5, "bar"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Evt{"baz", std::nullopt}), Succeeded()); output.CloseWriteFileDescriptor(); char buf[1024]; Expected<size_t> bytes_read = output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); - ASSERT_EQ(StringRef(buf, *bytes_read), StringRef("Content-Length: 13\r\n\r\n" - R"({"req":"foo"})" - "Content-Length: 14\r\n\r\n" - R"({"resp":"bar"})" - "Content-Length: 13\r\n\r\n" - R"({"evt":"baz"})")); + ASSERT_EQ(StringRef(buf, *bytes_read), + StringRef("Content-Length: 35\r\n\r\n" + R"({"id":7,"name":"foo","params":null})" + "Content-Length: 23\r\n\r\n" + R"({"id":5,"result":"bar"})" + "Content-Length: 28\r\n\r\n" + R"({"name":"baz","params":null})")); } 
TEST_F(JSONRPCTransportTest, MalformedRequests) { @@ -395,17 +510,18 @@ TEST_F(JSONRPCTransportTest, MalformedRequests) { } TEST_F(JSONRPCTransportTest, Read) { - Write(Message{Req{"foo"}}); - EXPECT_CALL(message_handler, Received(Req{"foo"})); + Write(Message{Req{1, "foo", std::nullopt}}); + EXPECT_CALL(message_handler, Received(Req{1, "foo", std::nullopt})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, ReadMultipleMessagesInSingleWrite) { InSequence seq; - Write(Message{Req{"one"}}, Message{Evt{"two"}}, Message{Resp{"three"}}); - EXPECT_CALL(message_handler, Received(Req{"one"})); - EXPECT_CALL(message_handler, Received(Evt{"two"})); - EXPECT_CALL(message_handler, Received(Resp{"three"})); + Write(Message{Req{1, "one", std::nullopt}}, Message{Evt{"two", std::nullopt}}, + Message{Resp{3, "three"}}); + EXPECT_CALL(message_handler, Received(Req{1, "one", std::nullopt})); + EXPECT_CALL(message_handler, Received(Evt{"two", std::nullopt})); + EXPECT_CALL(message_handler, Received(Resp{3, "three"})); ASSERT_THAT_ERROR(Run(), Succeeded()); } @@ -413,19 +529,22 @@ TEST_F(JSONRPCTransportTest, ReadAcrossMultipleChunks) { // Use a string longer than the chunk size to ensure we split the message // across the chunk boundary. 
std::string long_str = - std::string(JSONTransport<Req, Resp, Evt>::kReadBufferSize * 2, 'x'); - Write(Req{long_str}); - EXPECT_CALL(message_handler, Received(Req{long_str})); + std::string(IOTransport<int, test_protocol::Req, test_protocol::Resp, + test_protocol::Evt>::kReadBufferSize * + 2, + 'x'); + Write(Req{42, long_str, std::nullopt}); + EXPECT_CALL(message_handler, Received(Req{42, long_str, std::nullopt})); ASSERT_THAT_ERROR(Run(), Succeeded()); } TEST_F(JSONRPCTransportTest, ReadPartialMessage) { - std::string message = R"({"req": "foo"})" + std::string message = R"({"id":42,"name":"foo","params":null})" "\n"; std::string part1 = message.substr(0, 7); std::string part2 = message.substr(7); - EXPECT_CALL(message_handler, Received(Req{"foo"})); + EXPECT_CALL(message_handler, Received(Req{42, "foo", std::nullopt})); ASSERT_THAT_EXPECTED(input.Write(part1.data(), part1.size()), Succeeded()); loop.AddPendingCallback( @@ -455,20 +574,21 @@ TEST_F(JSONRPCTransportTest, ReaderWithUnhandledData) { } TEST_F(JSONRPCTransportTest, Write) { - ASSERT_THAT_ERROR(transport->Send(Req{"foo"}), Succeeded()); - ASSERT_THAT_ERROR(transport->Send(Resp{"bar"}), Succeeded()); - ASSERT_THAT_ERROR(transport->Send(Evt{"baz"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Req{11, "foo", std::nullopt}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Resp{14, "bar"}), Succeeded()); + ASSERT_THAT_ERROR(transport->Send(Evt{"baz", std::nullopt}), Succeeded()); output.CloseWriteFileDescriptor(); char buf[1024]; Expected<size_t> bytes_read = output.Read(buf, sizeof(buf), std::chrono::milliseconds(1)); ASSERT_THAT_EXPECTED(bytes_read, Succeeded()); - ASSERT_EQ(StringRef(buf, *bytes_read), StringRef(R"({"req":"foo"})" - "\n" - R"({"resp":"bar"})" - "\n" - R"({"evt":"baz"})" - "\n")); + ASSERT_EQ(StringRef(buf, *bytes_read), + StringRef(R"({"id":11,"name":"foo","params":null})" + "\n" + R"({"id":14,"result":"bar"})" + "\n" + R"({"name":"baz","params":null})" + "\n")); } 
TEST_F(JSONRPCTransportTest, InvalidTransport) { @@ -477,4 +597,58 @@ TEST_F(JSONRPCTransportTest, InvalidTransport) { FailedWithMessage("IO object is not valid.")); } +// Out-bound binding request handler. +TEST_F(TestTransportBinder, OutBoundRequests) { + auto addFn = binder->bind<MyFnResult, MyFnParams>("add"); + addFn(MyFnParams{1, 2}, [](Expected<MyFnResult> result) { + EXPECT_THAT_EXPECTED(result, Succeeded()); + EXPECT_EQ(result->c, 3); + }); + EXPECT_CALL(remote, Received(Req{1, "add", MyFnParams{1, 2}})); + // Queue a reply that will be sent during 'Run'. + EXPECT_THAT_ERROR(from_remote->Send(Resp{1, toJSON(MyFnResult{3})}), + Succeeded()); + Run(); +} + +// In-bound binding request handler. +TEST_F(TestTransportBinder, InBoundRequests) { + binder->bind<MyFnResult, MyFnParams>( + "add", + [](int captured_param, const MyFnParams ¶ms) -> Expected<MyFnResult> { + return MyFnResult{params.a + params.b + captured_param}; + }, + 2, std::placeholders::_1); + EXPECT_THAT_ERROR(from_remote->Send(Req{2, "add", MyFnParams{3, 4}}), + Succeeded()); + EXPECT_CALL(remote, Received(Resp{2, MyFnResult{9}})); + Run(); +} + +// Out-bound binding event handler. +TEST_F(TestTransportBinder, OutBoundEvents) { + auto emitEvent = binder->bind<MyFnParams>("evt"); + emitEvent(MyFnParams{1, 2}); + EXPECT_CALL(remote, Received(Evt{"evt", MyFnParams{1, 2}})); + Run(); +} + +// In-bound binding event handler. 
+TEST_F(TestTransportBinder, InBoundEvents) { + bool called = false; + binder->bind<MyFnParams>( + "evt", + [&](int captured_arg, const MyFnParams ¶ms) { + EXPECT_EQ(captured_arg, 42); + EXPECT_EQ(params.a, 3); + EXPECT_EQ(params.b, 4); + called = true; + }, + 42, std::placeholders::_1); + EXPECT_THAT_ERROR(from_remote->Send(Evt{"evt", MyFnParams{3, 4}}), + Succeeded()); + Run(); + EXPECT_TRUE(called); +} + #endif diff --git a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp index f686255c6d41d..5e43fb026197a 100644 --- a/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPServerTest.cpp @@ -6,9 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "ProtocolMCPTestUtilities.h" +#include "ProtocolMCPTestUtilities.h" // IWYU pragma: keep #include "TestingSupport/Host/JSONTransportTestUtilities.h" -#include "TestingSupport/Host/PipeTestUtilities.h" #include "TestingSupport/SubsystemRAII.h" #include "lldb/Host/FileSystem.h" #include "lldb/Host/HostInfo.h" @@ -28,20 +27,21 @@ #include "llvm/Testing/Support/Error.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -#include <chrono> -#include <condition_variable> +#include <future> +#include <memory> +#include <optional> using namespace llvm; using namespace lldb; using namespace lldb_private; using namespace lldb_protocol::mcp; +using testing::_; namespace { -class TestServer : public Server { -public: - using Server::Server; -}; +template <typename T> Response make_response(T &&result, Id id = 1) { + return Response{id, std::forward<T>(result)}; +} /// Test tool that returns it argument as text. 
class TestTool : public Tool { @@ -118,175 +118,211 @@ class FailTool : public Tool { } }; -class ProtocolServerMCPTest : public PipePairTest { +class TestServer : public Server { +public: + using Server::Bind; + using Server::Server; +}; + +using Transport = TestTransport<int64_t, lldb_protocol::mcp::Request, + lldb_protocol::mcp::Response, + lldb_protocol::mcp::Notification>; + +class ProtocolServerMCPTest : public testing::Test { public: SubsystemRAII<FileSystem, HostInfo, Socket> subsystems; - std::unique_ptr<lldb_protocol::mcp::Transport> transport_up; - std::unique_ptr<TestServer> server_up; MainLoop loop; - MockMessageHandler<Request, Response, Notification> message_handler; + lldb_private::MainLoop::ReadHandleUP handles[2]; - llvm::Error Write(llvm::StringRef message) { - llvm::Expected<json::Value> value = json::parse(message); - if (!value) - return value.takeError(); - return transport_up->Write(*value); - } + std::unique_ptr<Transport> to_server; + Transport::BinderUP binder; + std::unique_ptr<TestServer> server_up; - llvm::Error Write(json::Value value) { return transport_up->Write(value); } + std::unique_ptr<Transport> to_client; + MockMessageHandler<int64_t, Request, Response, Notification> client; - /// Run the transport MainLoop and return any messages received. - llvm::Error - Run(std::chrono::milliseconds timeout = std::chrono::milliseconds(200)) { - loop.AddCallback([](MainLoopBase &loop) { loop.RequestTermination(); }, - timeout); - auto handle = transport_up->RegisterMessageHandler(loop, message_handler); - if (!handle) - return handle.takeError(); + std::vector<std::string> logged_messages; - return server_up->Run(); + /// Runs the MainLoop a single time, executing any pending callbacks. 
+ void Run() { + loop.AddPendingCallback( + [](MainLoopBase &loop) { loop.RequestTermination(); }); + EXPECT_THAT_ERROR(loop.Run().takeError(), Succeeded()); } void SetUp() override { - PipePairTest::SetUp(); - - transport_up = std::make_unique<lldb_protocol::mcp::Transport>( - std::make_shared<NativeFile>(input.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared<NativeFile>(output.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned)); + std::tie(to_client, to_server) = Transport::createPair(); server_up = std::make_unique<TestServer>( "lldb-mcp", "0.1.0", - std::make_unique<lldb_protocol::mcp::Transport>( - std::make_shared<NativeFile>(output.GetReadFileDescriptor(), - File::eOpenOptionReadOnly, - NativeFile::Unowned), - std::make_shared<NativeFile>(input.GetWriteFileDescriptor(), - File::eOpenOptionWriteOnly, - NativeFile::Unowned)), - loop); + [this](StringRef msg) { logged_messages.push_back(msg.str()); }); + binder = server_up->Bind(*to_client); + auto server_handle = to_server->RegisterMessageHandler(loop, *binder); + EXPECT_THAT_EXPECTED(server_handle, Succeeded()); + binder->error( + [](llvm::Error error) { + llvm::errs() << formatv("Server transport error: {0}", error); + }, + std::placeholders::_1); + handles[0] = std::move(*server_handle); + + auto client_handle = to_client->RegisterMessageHandler(loop, client); + EXPECT_THAT_EXPECTED(client_handle, Succeeded()); + handles[1] = std::move(*client_handle); + } + + template <typename Result, typename Params> + Expected<json::Value> Call(StringRef method, const Params ¶ms) { + std::promise<Response> promised_result; + Request req = make_request<int64_t, lldb_protocol::mcp::Request>( + /*id=*/1, method, toJSON(params)); + EXPECT_THAT_ERROR(to_server->Send(req), Succeeded()); + EXPECT_CALL(client, Received(testing::An<const Response &>())) + .WillOnce( + [&](const Response &resp) { promised_result.set_value(resp); }); + Run(); + Response 
resp = promised_result.get_future().get(); + return toJSON(resp); + } + + template <typename Result> + Expected<json::Value> + Capture(llvm::unique_function<void(Reply<Result>)> &fn) { + std::promise<llvm::Expected<Result>> promised_result; + fn([&promised_result](llvm::Expected<Result> result) { + promised_result.set_value(std::move(result)); + }); + Run(); + llvm::Expected<Result> result = promised_result.get_future().get(); + if (!result) + return result.takeError(); + return toJSON(*result); + } + + template <typename Result, typename Params> + Expected<json::Value> + Capture(llvm::unique_function<void(const Params &, Reply<Result>)> &fn, + const Params ¶ms) { + std::promise<llvm::Expected<Result>> promised_result; + fn(params, [&promised_result](llvm::Expected<Result> result) { + promised_result.set_value(std::move(result)); + }); + Run(); + llvm::Expected<Result> result = promised_result.get_future().get(); + if (!result) + return result.takeError(); + return toJSON(*result); } }; template <typename T> -Request make_request(StringLiteral method, T &¶ms, Id id = 1) { - return Request{id, method.str(), toJSON(std::forward<T>(params))}; -} - -template <typename T> Response make_response(T &&result, Id id = 1) { - return Response{id, std::forward<T>(result)}; +inline testing::internal::EqMatcher<llvm::json::Value> HasJSON(T x) { + return testing::internal::EqMatcher<llvm::json::Value>(toJSON(x)); } } // namespace TEST_F(ProtocolServerMCPTest, Initialization) { - Request request = make_request( - "initialize", InitializeParams{/*protocolVersion=*/"2024-11-05", - /*capabilities=*/{}, - /*clientInfo=*/{"lldb-unit", "0.1.0"}}); - Response response = make_response( - InitializeResult{/*protocolVersion=*/"2024-11-05", - /*capabilities=*/{/*supportsToolsList=*/true}, - /*serverInfo=*/{"lldb-mcp", "0.1.0"}}); - - ASSERT_THAT_ERROR(Write(request), Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + 
EXPECT_THAT_EXPECTED( + (Call<InitializeResult, InitializeParams>( + "initialize", + InitializeParams{/*protocolVersion=*/"2024-11-05", + /*capabilities=*/{}, + /*clientInfo=*/{"lldb-unit", "0.1.0"}})), + HasValue(make_response( + InitializeResult{/*protocolVersion=*/"2024-11-05", + /*capabilities=*/ + { + /*supportsToolsList=*/true, + /*supportsResourcesList=*/true, + }, + /*serverInfo=*/{"lldb-mcp", "0.1.0"}}))); } TEST_F(ProtocolServerMCPTest, ToolsList) { server_up->AddTool(std::make_unique<TestTool>("test", "test tool")); - Request request = make_request("tools/list", Void{}, /*id=*/"one"); - ToolDefinition test_tool; test_tool.name = "test"; test_tool.description = "test tool"; test_tool.inputSchema = json::Object{{"type", "object"}}; - Response response = make_response(ListToolsResult{{test_tool}}, /*id=*/"one"); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED(Call<ListToolsResult>("tools/list", Void{}), + HasValue(make_response(ListToolsResult{{test_tool}}))); } TEST_F(ProtocolServerMCPTest, ResourcesList) { server_up->AddResourceProvider(std::make_unique<TestResourceProvider>()); - Request request = make_request("resources/list", Void{}); - Response response = make_response(ListResourcesResult{ - {{/*uri=*/"lldb://foo/bar", /*name=*/"name", - /*description=*/"description", /*mimeType=*/"application/json"}}}); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED(Call<ListResourcesResult>("resources/list", Void{}), + HasValue(make_response(ListResourcesResult{{ + { + /*uri=*/"lldb://foo/bar", + /*name=*/"name", + /*description=*/"description", + /*mimeType=*/"application/json", + }, + }}))); } TEST_F(ProtocolServerMCPTest, ToolsCall) { server_up->AddTool(std::make_unique<TestTool>("test", "test tool")); - Request 
request = make_request( - "tools/call", CallToolParams{/*name=*/"test", /*arguments=*/json::Object{ - {"arguments", "foo"}, - {"debugger_id", 0}, - }}); - Response response = make_response(CallToolResult{{{/*text=*/"foo"}}}); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED( + (Call<CallToolResult, CallToolParams>("tools/call", + CallToolParams{ + /*name=*/"test", + /*arguments=*/ + json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }, + })), + HasValue(make_response(CallToolResult{{{/*text=*/"foo"}}}))); } TEST_F(ProtocolServerMCPTest, ToolsCallError) { server_up->AddTool(std::make_unique<ErrorTool>("error", "error tool")); - Request request = make_request( - "tools/call", CallToolParams{/*name=*/"error", /*arguments=*/json::Object{ - {"arguments", "foo"}, - {"debugger_id", 0}, - }}); - Response response = - make_response(lldb_protocol::mcp::Error{eErrorCodeInternalError, - /*message=*/"error"}); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED((Call<CallToolResult, CallToolParams>( + "tools/call", CallToolParams{ + /*name=*/"error", + /*arguments=*/ + json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }, + })), + HasValue(make_response(lldb_protocol::mcp::Error{ + eErrorCodeInternalError, "error"}))); } TEST_F(ProtocolServerMCPTest, ToolsCallFail) { server_up->AddTool(std::make_unique<FailTool>("fail", "fail tool")); - Request request = make_request( - "tools/call", CallToolParams{/*name=*/"fail", /*arguments=*/json::Object{ - {"arguments", "foo"}, - {"debugger_id", 0}, - }}); - Response response = - make_response(CallToolResult{{{/*text=*/"failed"}}, /*isError=*/true}); - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_CALL(message_handler, Received(response)); - 
EXPECT_THAT_ERROR(Run(), Succeeded()); + EXPECT_THAT_EXPECTED((Call<CallToolResult, CallToolParams>( + "tools/call", CallToolParams{ + /*name=*/"fail", + /*arguments=*/ + json::Object{ + {"arguments", "foo"}, + {"debugger_id", 0}, + }, + })), + HasValue(make_response(CallToolResult{ + {{/*text=*/"failed"}}, + /*isError=*/true, + }))); } TEST_F(ProtocolServerMCPTest, NotificationInitialized) { - bool handler_called = false; - std::condition_variable cv; - - server_up->AddNotificationHandler( - "notifications/initialized", - [&](const Notification ¬ification) { handler_called = true; }); - llvm::StringLiteral request = - R"json({"method":"notifications/initialized","jsonrpc":"2.0"})json"; - - ASSERT_THAT_ERROR(Write(request), llvm::Succeeded()); - EXPECT_THAT_ERROR(Run(), Succeeded()); - EXPECT_TRUE(handler_called); + EXPECT_THAT_ERROR(to_server->Send(lldb_protocol::mcp::Notification{ + "notifications/initialized", + std::nullopt, + }), + Succeeded()); + Run(); + EXPECT_THAT(logged_messages, + testing::Contains("MCP initialization complete")); } diff --git a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h index 5a9eb8e59f2b6..4dbcd614e400b 100644 --- a/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h +++ b/lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h @@ -6,19 +6,105 @@ // //===----------------------------------------------------------------------===// -#ifndef LLDB_UNITTESTS_TESTINGSUPPORT_HOST_NATIVEPROCESSTESTUTILS_H -#define LLDB_UNITTESTS_TESTINGSUPPORT_HOST_NATIVEPROCESSTESTUTILS_H +#ifndef LLDB_UNITTESTS_TESTINGSUPPORT_HOST_JSONTRANSPORTTESTUTILITIES_H +#define LLDB_UNITTESTS_TESTINGSUPPORT_HOST_JSONTRANSPORTTESTUTILITIES_H +#include "lldb/Host/FileSystem.h" #include "lldb/Host/JSONTransport.h" +#include "lldb/Host/MainLoop.h" +#include "lldb/Utility/FileSpec.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Testing/Support/Error.h" 
#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include <cstddef> +#include <memory> +#include <utility> -template <typename Req, typename Resp, typename Evt> +template <typename Id, typename Req, typename Resp, typename Evt> +class TestTransport final + : public lldb_private::JSONTransport<Id, Req, Resp, Evt> { +public: + using MessageHandler = + typename lldb_private::JSONTransport<Id, Req, Resp, Evt>::MessageHandler; + + static std::pair<std::unique_ptr<TestTransport<Id, Req, Resp, Evt>>, + std::unique_ptr<TestTransport<Id, Req, Resp, Evt>>> + createPair() { + std::unique_ptr<TestTransport<Id, Req, Resp, Evt>> transports[2] = { + std::make_unique<TestTransport<Id, Req, Resp, Evt>>(), + std::make_unique<TestTransport<Id, Req, Resp, Evt>>()}; + return std::make_pair(std::move(transports[0]), std::move(transports[1])); + } + + explicit TestTransport() { + llvm::Expected<lldb::FileUP> dummy_file = + lldb_private::FileSystem::Instance().Open( + lldb_private::FileSpec(lldb_private::FileSystem::DEV_NULL), + lldb_private::File::eOpenOptionReadWrite); + EXPECT_THAT_EXPECTED(dummy_file, llvm::Succeeded()); + m_dummy_file = std::move(*dummy_file); + } + + llvm::Error Send(const Evt &evt) override { + EXPECT_TRUE(m_loop && m_handler) + << "Send called before RegisterMessageHandler"; + m_loop->AddPendingCallback([this, evt](lldb_private::MainLoopBase &) { + m_handler->Received(evt); + }); + return llvm::Error::success(); + } + + llvm::Error Send(const Req &req) override { + EXPECT_TRUE(m_loop && m_handler) + << "Send called before RegisterMessageHandler"; + m_loop->AddPendingCallback([this, req](lldb_private::MainLoopBase &) { + m_handler->Received(req); + }); + return llvm::Error::success(); + } + + llvm::Error Send(const Resp &resp) override { + EXPECT_TRUE(m_loop && m_handler) + << "Send called before RegisterMessageHandler"; + m_loop->AddPendingCallback([this, resp](lldb_private::MainLoopBase &) { + m_handler->Received(resp); + }); + return llvm::Error::success(); + 
} + + llvm::Expected<lldb_private::MainLoop::ReadHandleUP> + RegisterMessageHandler(lldb_private::MainLoop &loop, + MessageHandler &handler) override { + if (!m_loop) + m_loop = &loop; + if (!m_handler) + m_handler = &handler; + lldb_private::Status status; + auto handle = loop.RegisterReadObject( + m_dummy_file, [](lldb_private::MainLoopBase &) {}, status); + if (status.Fail()) + return status.takeError(); + return handle; + } + +protected: + void Log(llvm::StringRef message) override {}; + +private: + lldb_private::MainLoop *m_loop = nullptr; + MessageHandler *m_handler = nullptr; + // Dummy file for registering with the MainLoop. + lldb::FileSP m_dummy_file = nullptr; +}; + +template <typename Id, typename Req, typename Resp, typename Evt> class MockMessageHandler final - : public lldb_private::Transport<Req, Resp, Evt>::MessageHandler { + : public lldb_private::JSONTransport<Id, Req, Resp, Evt>::MessageHandler { public: - MOCK_METHOD(void, Received, (const Evt &), (override)); MOCK_METHOD(void, Received, (const Req &), (override)); MOCK_METHOD(void, Received, (const Resp &), (override)); + MOCK_METHOD(void, Received, (const Evt &), (override)); MOCK_METHOD(void, OnError, (llvm::Error), (override)); MOCK_METHOD(void, OnClosed, (), (override)); }; _______________________________________________ lldb-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits
