llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-lldb Author: John Harrison (ashgti) <details> <summary>Changes</summary> This adds a new Binding helper class to allow mapping of incoming and outgoing requests / events to specific handlers. This should make it easier to create new protocol implementations and allow us to create a relay in the lldb-mcp binary. --- Patch is 95.85 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/159160.diff 18 Files Affected: - (modified) lldb/include/lldb/Host/JSONTransport.h (+361-26) - (modified) lldb/include/lldb/Protocol/MCP/Protocol.h (+8) - (modified) lldb/include/lldb/Protocol/MCP/Server.h (+32-41) - (modified) lldb/include/lldb/Protocol/MCP/Transport.h (+75-2) - (modified) lldb/source/Host/common/JSONTransport.cpp (+10) - (modified) lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp (+19-23) - (modified) lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h (+13-7) - (modified) lldb/source/Protocol/MCP/Server.cpp (+56-155) - (modified) lldb/tools/lldb-dap/DAP.h (+3-3) - (modified) lldb/tools/lldb-dap/Protocol/ProtocolBase.h (+4-2) - (modified) lldb/tools/lldb-dap/Transport.h (+3-3) - (modified) lldb/unittests/DAP/DAPTest.cpp (+3-17) - (modified) lldb/unittests/DAP/Handler/DisconnectTest.cpp (+2-2) - (modified) lldb/unittests/DAP/TestBase.cpp (+21-21) - (modified) lldb/unittests/DAP/TestBase.h (+53-69) - (modified) lldb/unittests/Host/JSONTransportTest.cpp (+259-79) - (modified) lldb/unittests/Protocol/ProtocolMCPServerTest.cpp (+157-123) - (modified) lldb/unittests/TestingSupport/Host/JSONTransportTestUtilities.h (+91-5) ``````````diff diff --git a/lldb/include/lldb/Host/JSONTransport.h b/lldb/include/lldb/Host/JSONTransport.h index 210f33edace6e..da1ae43118538 100644 --- a/lldb/include/lldb/Host/JSONTransport.h +++ b/lldb/include/lldb/Host/JSONTransport.h @@ -18,6 +18,7 @@ #include "lldb/Utility/IOObject.h" #include "lldb/Utility/Status.h" #include "lldb/lldb-forward.h" 
+#include "llvm/ADT/FunctionExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Error.h" @@ -25,8 +26,13 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/JSON.h" #include "llvm/Support/raw_ostream.h" +#include <functional> +#include <mutex> +#include <optional> #include <string> #include <system_error> +#include <type_traits> +#include <utility> #include <variant> #include <vector> @@ -50,17 +56,70 @@ class TransportUnhandledContentsError std::string m_unhandled_contents; }; +class InvalidParams : public llvm::ErrorInfo<InvalidParams> { +public: + static char ID; + + explicit InvalidParams(std::string method, std::string context) + : m_method(std::move(method)), m_context(std::move(context)) {} + + void log(llvm::raw_ostream &OS) const override; + std::error_code convertToErrorCode() const override; + +private: + std::string m_method; + std::string m_context; +}; + +// Value for tracking functions that have a void param or result. 
+using VoidT = std::monostate; + +template <typename T> using Callback = llvm::unique_function<T>; + +template <typename T> +using Reply = typename std::conditional< + std::is_same_v<T, VoidT> == true, llvm::unique_function<void(llvm::Error)>, + llvm::unique_function<void(llvm::Expected<T>)>>::type; + +template <typename Result, typename Params> +using OutgoingRequest = typename std::conditional< + std::is_same_v<Params, VoidT> == true, + llvm::unique_function<void(Reply<Result>)>, + llvm::unique_function<void(const Params &, Reply<Result>)>>::type; + +template <typename Params> +using OutgoingEvent = typename std::conditional< + std::is_same_v<Params, VoidT> == true, llvm::unique_function<void()>, + llvm::unique_function<void(const Params &)>>::type; + +template <typename Id, typename Req> +Req make_request(Id id, llvm::StringRef method, + std::optional<llvm::json::Value> params = std::nullopt); +template <typename Req, typename Resp> +Resp make_response(const Req &req, llvm::Error error); +template <typename Req, typename Resp> +Resp make_response(const Req &req, llvm::json::Value result); +template <typename Evt> +Evt make_event(llvm::StringRef method, + std::optional<llvm::json::Value> params = std::nullopt); +template <typename Resp> +llvm::Expected<llvm::json::Value> get_result(const Resp &resp); +template <typename Id, typename T> Id get_id(const T &); +template <typename T> llvm::StringRef get_method(const T &); +template <typename T> llvm::json::Value get_params(const T &); + /// A transport is responsible for maintaining the connection to a client /// application, and reading/writing structured messages to it. /// /// Transports have limited thread safety requirements: /// - Messages will not be sent concurrently. /// - Messages MAY be sent while Run() is reading, or its callback is active. 
-template <typename Req, typename Resp, typename Evt> class Transport { +template <typename Id, typename Req, typename Resp, typename Evt> +class JSONTransport { public: using Message = std::variant<Req, Resp, Evt>; - virtual ~Transport() = default; + virtual ~JSONTransport() = default; /// Sends an event, a message that does not require a response. virtual llvm::Error Send(const Evt &) = 0; @@ -90,8 +149,6 @@ template <typename Req, typename Resp, typename Evt> class Transport { virtual void OnClosed() = 0; }; - using MessageHandlerSP = std::shared_ptr<MessageHandler>; - /// RegisterMessageHandler registers the Transport with the given MainLoop and /// handles any incoming messages using the given MessageHandler. /// @@ -100,22 +157,302 @@ template <typename Req, typename Resp, typename Evt> class Transport { virtual llvm::Expected<MainLoop::ReadHandleUP> RegisterMessageHandler(MainLoop &loop, MessageHandler &handler) = 0; - // FIXME: Refactor mcp::Server to not directly access log on the transport. - // protected: +protected: template <typename... Ts> inline auto Logv(const char *Fmt, Ts &&...Vals) { Log(llvm::formatv(Fmt, std::forward<Ts>(Vals)...).str()); } virtual void Log(llvm::StringRef message) = 0; + + /// Function object to reply to a call. + /// Each instance must be called exactly once, otherwise: + /// - the bug is logged, and (in debug mode) an assert will fire + /// - if there was no reply, an error reply is sent + /// - if there were multiple replies, only the first is sent + class ReplyOnce { + std::atomic<bool> replied = {false}; + const Req req; + JSONTransport *transport; // Null when moved-from. + JSONTransport::MessageHandler *handler; // Null when moved-from. 
+ + public: + ReplyOnce(const Req req, JSONTransport *transport, + JSONTransport::MessageHandler *handler) + : req(req), transport(transport), handler(handler) { + assert(handler); + } + ReplyOnce(ReplyOnce &&other) + : replied(other.replied.load()), req(other.req), + transport(other.transport), handler(other.handler) { + other.transport = nullptr; + other.handler = nullptr; + } + ReplyOnce &operator=(ReplyOnce &&) = delete; + ReplyOnce(const ReplyOnce &) = delete; + ReplyOnce &operator=(const ReplyOnce &) = delete; + + ~ReplyOnce() { + if (transport && handler && !replied) { + assert(false && "must reply to all calls!"); + (*this)(make_response<Req, Resp>( + req, llvm::createStringError("failed to reply"))); + } + } + + void operator()(const Resp &resp) { + assert(transport && handler && "moved-from!"); + if (replied.exchange(true)) { + assert(false && "must reply to each call only once!"); + return; + } + + if (llvm::Error error = transport->Send(resp)) + handler->OnError(std::move(error)); + } + }; + +public: + class Binder; + using BinderUP = std::unique_ptr<Binder>; + + /// Binder collects a table of functions that handle calls. + /// + /// The wrapper takes care of parsing/serializing responses. + class Binder : public JSONTransport::MessageHandler { + public: + explicit Binder(JSONTransport &transport) + : m_transport(transport), m_seq(0) {} + + Binder(const Binder &) = delete; + Binder &operator=(const Binder &) = delete; + + /// Bind a handler on transport disconnect. + template <typename Fn, typename... Args> + void disconnected(Fn &&fn, Args &&...args) { + m_disconnect_handler = [&, args...]() mutable { + std::invoke(std::forward<Fn>(fn), std::forward<Args>(args)...); + }; + } + + /// Bind a handler on error when communicating with the transport. + template <typename Fn, typename... 
Args> + void error(Fn &&fn, Args &&...args) { + m_error_handler = [&, args...](llvm::Error error) mutable { + std::invoke(std::forward<Fn>(fn), std::forward<Args>(args)..., + std::move(error)); + }; + } + + template <typename T> + static llvm::Expected<T> parse(const llvm::json::Value &raw, + llvm::StringRef method) { + T result; + llvm::json::Path::Root root; + if (!fromJSON(raw, result, root)) { + // Dump the relevant parts of the broken message. + std::string context; + llvm::raw_string_ostream OS(context); + root.printErrorContext(raw, OS); + return llvm::make_error<InvalidParams>(method.str(), context); + } + return std::move(result); + } + + /// Bind a handler for a request. + /// e.g. `bind("peek", &ThisModule::peek, this, std::placeholders::_1);`. + /// Handler should be e.g. `Expected<PeekResult> peek(const PeekParams&);` + /// PeekParams must be JSON parsable and PeekResult must be serializable. + template <typename Result, typename Params, typename Fn, typename... Args> + void bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) { + assert(m_request_handlers.find(method) == m_request_handlers.end() && + "request already bound"); + if constexpr (std::is_void_v<Params> || std::is_same_v<VoidT, Params>) { + m_request_handlers[method] = + [fn, + args...](const Req &req, + llvm::unique_function<void(const Resp &)> reply) mutable { + llvm::Expected<Result> result = std::invoke( + std::forward<Fn>(fn), std::forward<Args>(args)...); + if (!result) + return reply(make_response<Req, Resp>(req, result.takeError())); + reply(make_response<Req, Resp>(req, toJSON(*result))); + }; + } else { + m_request_handlers[method] = + [method, fn, + args...](const Req &req, + llvm::unique_function<void(const Resp &)> reply) mutable { + llvm::Expected<Params> params = + parse<Params>(get_params<Req>(req), method); + if (!params) + return reply(make_response<Req, Resp>(req, params.takeError())); + + llvm::Expected<Result> result = std::invoke( + std::forward<Fn>(fn), 
std::forward<Args>(args)..., *params); + if (!result) + return reply(make_response<Req, Resp>(req, result.takeError())); + + reply(make_response<Req, Resp>(req, toJSON(*result))); + }; + } + } + + /// Bind a handler for a event. + /// e.g. `bind("peek", &ThisModule::peek, this);` + /// Handler should be e.g. `void peek(const PeekParams&);` + /// PeekParams must be JSON parsable. + template <typename Params, typename Fn, typename... Args> + void bind(llvm::StringLiteral method, Fn &&fn, Args &&...args) { + assert(m_event_handlers.find(method) == m_event_handlers.end() && + "event already bound"); + if constexpr (std::is_void_v<Params> || std::is_same_v<VoidT, Params>) { + m_event_handlers[method] = [fn, args...](const Evt &) mutable { + std::invoke(std::forward<Fn>(fn), std::forward<Args>(args)...); + }; + } else { + m_event_handlers[method] = [this, method, fn, + args...](const Evt &evt) mutable { + llvm::Expected<Params> params = + parse<Params>(get_params<Evt>(evt), method); + if (!params) + return OnError(params.takeError()); + std::invoke(std::forward<Fn>(fn), std::forward<Args>(args)..., + *params); + }; + } + } + + /// Bind a function object to be used for outgoing requests. + /// e.g. `OutgoingRequest<Params, Result> Edit = bind("edit");` + /// Params must be JSON-serializable, Result must be parsable. 
+ template <typename Result, typename Params> + OutgoingRequest<Result, Params> bind(llvm::StringLiteral method) { + if constexpr (std::is_void_v<Params> || std::is_same_v<VoidT, Params>) { + return [this, method](Reply<Result> fn) { + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + Id id = ++m_seq; + Req req = make_request<Req, Resp>(id, method, std::nullopt); + m_pending_responses[id] = [fn = std::move(fn), + method](const Resp &resp) mutable { + llvm::Expected<llvm::json::Value> result = get_result<Resp>(resp); + if (!result) + return fn(result.takeError()); + fn(parse<Result>(*result, method)); + }; + if (llvm::Error error = m_transport.Send(req)) + OnError(std::move(error)); + }; + } else { + return [this, method](const Params ¶ms, Reply<Result> fn) { + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + Id id = ++m_seq; + Req req = + make_request<Id, Req>(id, method, llvm::json::Value(params)); + m_pending_responses[id] = [fn = std::move(fn), + method](const Resp &resp) mutable { + llvm::Expected<llvm::json::Value> result = get_result<Resp>(resp); + if (llvm::Error err = result.takeError()) + return fn(std::move(err)); + fn(parse<Result>(*result, method)); + }; + if (llvm::Error error = m_transport.Send(req)) + OnError(std::move(error)); + }; + } + } + + /// Bind a function object to be used for outgoing events. + /// e.g. `OutgoingEvent<LogParams> Log = bind("log");` + /// LogParams must be JSON-serializable. 
+ template <typename Params> + OutgoingEvent<Params> bind(llvm::StringLiteral method) { + if constexpr (std::is_void_v<Params> || std::is_same_v<VoidT, Params>) { + return [this, method]() { + if (llvm::Error error = + m_transport.Send(make_event<Evt>(method, std::nullopt))) + OnError(std::move(error)); + }; + } else { + return [this, method](const Params ¶ms) { + if (llvm::Error error = + m_transport.Send(make_event<Evt>(method, toJSON(params)))) + OnError(std::move(error)); + }; + } + } + + void Received(const Evt &evt) override { + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + auto it = m_event_handlers.find(get_method<Evt>(evt)); + if (it == m_event_handlers.end()) { + OnError(llvm::createStringError( + llvm::formatv("no handled for event {0}", toJSON(evt)))); + return; + } + it->second(evt); + } + + void Received(const Req &req) override { + ReplyOnce reply(req, &m_transport, this); + + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + auto it = m_request_handlers.find(get_method<Req>(req)); + if (it == m_request_handlers.end()) { + reply(make_response<Req, Resp>( + req, llvm::createStringError("method not found"))); + return; + } + + it->second(req, std::move(reply)); + } + + void Received(const Resp &resp) override { + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + auto it = m_pending_responses.find(get_id<Id, Resp>(resp)); + if (it == m_pending_responses.end()) { + OnError(llvm::createStringError( + llvm::formatv("no pending request for {0}", toJSON(resp)))); + return; + } + + it->second(resp); + m_pending_responses.erase(it); + } + + void OnError(llvm::Error err) override { + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + if (m_error_handler) + m_error_handler(std::move(err)); + } + + void OnClosed() override { + std::scoped_lock<std::recursive_mutex> guard(m_mutex); + if (m_disconnect_handler) + m_disconnect_handler(); + } + + private: + std::recursive_mutex m_mutex; + JSONTransport &m_transport; + Id m_seq; + 
std::map<Id, Callback<void(const Resp &)>> m_pending_responses; + llvm::StringMap<Callback<void(const Req &, Callback<void(const Resp &)>)>> + m_request_handlers; + llvm::StringMap<Callback<void(const Evt &)>> m_event_handlers; + Callback<void()> m_disconnect_handler; + Callback<void(llvm::Error)> m_error_handler; + }; }; -/// A JSONTransport will encode and decode messages using JSON. -template <typename Req, typename Resp, typename Evt> -class JSONTransport : public Transport<Req, Resp, Evt> { +/// A IOTransport will encode and decode messages using an IOObject like a +/// file or a socket. +template <typename Id, typename Req, typename Resp, typename Evt> +class IOTransport : public JSONTransport<Id, Req, Resp, Evt> { public: - using Transport<Req, Resp, Evt>::Transport; - using MessageHandler = typename Transport<Req, Resp, Evt>::MessageHandler; + using Message = typename JSONTransport<Id, Req, Resp, Evt>::Message; + using MessageHandler = + typename JSONTransport<Id, Req, Resp, Evt>::MessageHandler; - JSONTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) + IOTransport(lldb::IOObjectSP in, lldb::IOObjectSP out) : m_in(in), m_out(out) {} llvm::Error Send(const Evt &evt) override { return Write(evt); } @@ -127,7 +464,7 @@ class JSONTransport : public Transport<Req, Resp, Evt> { Status status; MainLoop::ReadHandleUP read_handle = loop.RegisterReadObject( m_in, - std::bind(&JSONTransport::OnRead, this, std::placeholders::_1, + std::bind(&IOTransport::OnRead, this, std::placeholders::_1, std::ref(handler)), status); if (status.Fail()) { @@ -140,7 +477,7 @@ class JSONTransport : public Transport<Req, Resp, Evt> { /// detail. static constexpr size_t kReadBufferSize = 1024; - // FIXME: Write should be protected. 
+protected: llvm::Error Write(const llvm::json::Value &message) { this->Logv("<-- {0}", message); std::string output = Encode(message); @@ -148,7 +485,6 @@ class JSONTransport : public Transport<Req, Resp, Evt> { return m_out->Write(output.data(), bytes_written).takeError(); } -protected: virtual llvm::Expected<std::vector<std::string>> Parse() = 0; virtual std::string Encode(const llvm::json::Value &message) = 0; @@ -175,9 +511,8 @@ class JSONTransport : public Transport<Req, Resp, Evt> { } for (const std::string &raw_message : *raw_messages) { - llvm::Expected<typename Transport<Req, Resp, Evt>::Message> message = - llvm::json::parse<typename Transport<Req, Resp, Evt>::Message>( - raw_message); + llvm::Expected<Message> message = + llvm::json::parse<Message>(raw_message); if (!message) { handler.OnError(message.takeError()); return; @@ -202,10 +537,10 @@ class JSONTransport : public Transport<Req, Resp, Evt> { }; /// A transport class for JSON with a HTTP header. -template <typename Req, typename Resp, typename Evt> -class HTTPDelimitedJSONTransport : public JSONTransport<Req, Resp, Evt> { +template <typename Id, typename Req, typename Resp, typename Evt> +class HTTPDelimitedJSONTransport : public IOTransport<Id, Req, Resp, Evt> { public: - using JSONTransport<Req, Resp, Evt>::JSONTransport; + using IOTransport<Id, Req, Resp, Evt>::IOTransport; protected: /// Encodes messages based on @@ -231,8 +566,8 @@ class HTTPDelimitedJSONTransport : public JSONTransport<Req, Resp, Evt> { for (const llvm::StringRef &header : llvm::split(headers, kHeaderSeparator)) { auto [key, value] = header.split(kHeaderFieldSeparator); - // 'Content-Length' is the only meaningful key at the moment. Others are - // ignored. + // 'Content-Length' is the only meaningful key at the moment. Others + // are ignored. 
if (!key.equals_insensitive(kHeaderContentLength)) continue; @@ -269,10 +604,10 @@ class HTTPDelimitedJSONTransport : public JSONTransport<Req, Resp, Evt> { }; /// A transport class for JSON RPC. -template <typename Req, typename Resp, typename Evt> -class JSONRPCTransport : public JSONTransport<Req, Resp, Evt> { +template <typename Id, typename Req, typename Resp, typename Evt> +class JSONRPCTransport : public IOTransport<Id, Req, Resp, Evt> { public: - using JSONTransport<Req, Resp, Evt>::JSONTransport; + using IOTransport<Id, Req, Resp, Evt>::IOTransport; protected: std::string Encode(const llvm::json::Value &message) override { diff --git a/lldb/include/lldb/Protocol/MCP/Protocol.h b/lldb/include/lldb/Protocol/MCP/Protocol.h index 6e1ffcbe1f3e3..1e0816110b80a 100644 --- a/lldb/include/lldb/Protocol/MCP/Protocol.h +++ b/lldb/include/lldb/Protocol/MCP/Protocol.h @@ -14,6 +14,7 @@ #ifndef LLDB_PROTOCOL_MCP_PROTOCOL_H #define LLDB_PROTOCOL_MCP_PROTOCOL_H +#include "llvm/ADT/StringRef.h" #include "llvm/Support/JSON.h" #include <optional> #include <string> @@ -324,4 +325,11 @@ bool fromJSON(const llvm::json::Value &, CallToolResult &, llvm::json::Path); } // namespace lldb_pr... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/159160 _______________________________________________ lldb-commits mailing list lldb-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits
