Author: John Harrison Date: 2025-08-12T17:56:52-07:00 New Revision: 350f6abb8304bae39a0ed70a41fde13b7d1df510
URL: https://github.com/llvm/llvm-project/commit/350f6abb8304bae39a0ed70a41fde13b7d1df510 DIFF: https://github.com/llvm/llvm-project/commit/350f6abb8304bae39a0ed70a41fde13b7d1df510.diff LOG: [lldb] Adjusting the base MCP protocol types per the spec. (#153297) * This adjusts the `Request`/`Response` types to have an `id` that is either a string or a number. * Merges 'Error' into 'Response' to have a single response type that represents both errors and results. * Adjusts the `Error.data` field to be any JSON value. * Adds `operator==` support to the base protocol types and simplifies the tests. Added: Modified: lldb/include/lldb/Protocol/MCP/MCPError.h lldb/include/lldb/Protocol/MCP/Protocol.h lldb/source/Protocol/MCP/MCPError.cpp lldb/source/Protocol/MCP/Protocol.cpp lldb/source/Protocol/MCP/Server.cpp lldb/unittests/Protocol/ProtocolMCPTest.cpp lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp lldb/unittests/TestingSupport/TestUtilities.h Removed: ################################################################################ diff --git a/lldb/include/lldb/Protocol/MCP/MCPError.h b/lldb/include/lldb/Protocol/MCP/MCPError.h index 2bdbb9b7a6874..55dd40f124a15 100644 --- a/lldb/include/lldb/Protocol/MCP/MCPError.h +++ b/lldb/include/lldb/Protocol/MCP/MCPError.h @@ -26,7 +26,7 @@ class MCPError : public llvm::ErrorInfo<MCPError> { const std::string &getMessage() const { return m_message; } - lldb_protocol::mcp::Error toProtcolError() const; + lldb_protocol::mcp::Error toProtocolError() const; static constexpr int64_t kResourceNotFound = -32002; static constexpr int64_t kInternalError = -32603; diff --git a/lldb/include/lldb/Protocol/MCP/Protocol.h b/lldb/include/lldb/Protocol/MCP/Protocol.h index 6448416eee08f..141d064804e1e 100644 --- a/lldb/include/lldb/Protocol/MCP/Protocol.h +++ b/lldb/include/lldb/Protocol/MCP/Protocol.h @@ -23,50 +23,72 @@ namespace lldb_protocol::mcp { static llvm::StringLiteral kProtocolVersion = "2024-11-05"; +/// A Request or Response 
'id'. +/// +/// NOTE: This differs from the JSON-RPC 2.0 spec. The MCP spec says this must +/// be a string or number, excluding a json 'null' as a valid id. +using Id = std::variant<int64_t, std::string>; + /// A request that expects a response. struct Request { - uint64_t id = 0; + /// The request id. + Id id = 0; + /// The method to be invoked. std::string method; + /// The method's params. std::optional<llvm::json::Value> params; }; llvm::json::Value toJSON(const Request &); bool fromJSON(const llvm::json::Value &, Request &, llvm::json::Path); +bool operator==(const Request &, const Request &); -struct ErrorInfo { +struct Error { + /// The error type that occurred. int64_t code = 0; + /// A short description of the error. The message SHOULD be limited to a + /// concise single sentence. std::string message; - std::string data; -}; - -llvm::json::Value toJSON(const ErrorInfo &); -bool fromJSON(const llvm::json::Value &, ErrorInfo &, llvm::json::Path); - -struct Error { - uint64_t id = 0; - ErrorInfo error; + /// Additional information about the error. The value of this member is + /// defined by the sender (e.g. detailed error information, nested errors + /// etc.). + std::optional<llvm::json::Value> data; }; llvm::json::Value toJSON(const Error &); bool fromJSON(const llvm::json::Value &, Error &, llvm::json::Path); +bool operator==(const Error &, const Error &); +/// A response to a request, either an error or a result. struct Response { - uint64_t id = 0; - std::optional<llvm::json::Value> result; - std::optional<ErrorInfo> error; + /// The request id. + Id id = 0; + /// The result of the request, either an Error or the JSON value of the + /// response. + std::variant<Error, llvm::json::Value> result; }; llvm::json::Value toJSON(const Response &); bool fromJSON(const llvm::json::Value &, Response &, llvm::json::Path); +bool operator==(const Response &, const Response &); /// A notification which does not expect a response. 
struct Notification { + /// The method to be invoked. std::string method; + /// The notification's params. std::optional<llvm::json::Value> params; }; llvm::json::Value toJSON(const Notification &); bool fromJSON(const llvm::json::Value &, Notification &, llvm::json::Path); +bool operator==(const Notification &, const Notification &); + +/// A general message as defined by the JSON-RPC 2.0 spec. +using Message = std::variant<Request, Response, Notification>; + +bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path); +llvm::json::Value toJSON(const Message &); struct ToolCapability { /// Whether this server supports notifications for changes to the tool list. @@ -176,11 +198,6 @@ struct ToolDefinition { llvm::json::Value toJSON(const ToolDefinition &); bool fromJSON(const llvm::json::Value &, ToolDefinition &, llvm::json::Path); -using Message = std::variant<Request, Response, Notification, Error>; - -bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path); -llvm::json::Value toJSON(const Message &); - using ToolArguments = std::variant<std::monostate, llvm::json::Value>; } // namespace lldb_protocol::mcp diff --git a/lldb/source/Protocol/MCP/MCPError.cpp b/lldb/source/Protocol/MCP/MCPError.cpp index c610e882abf51..e140d11e12cfe 100644 --- a/lldb/source/Protocol/MCP/MCPError.cpp +++ b/lldb/source/Protocol/MCP/MCPError.cpp @@ -25,10 +25,10 @@ std::error_code MCPError::convertToErrorCode() const { return llvm::inconvertibleErrorCode(); } -lldb_protocol::mcp::Error MCPError::toProtcolError() const { +lldb_protocol::mcp::Error MCPError::toProtocolError() const { lldb_protocol::mcp::Error error; - error.error.code = m_error_code; - error.error.message = m_message; + error.code = m_error_code; + error.message = m_message; return error; } diff --git a/lldb/source/Protocol/MCP/Protocol.cpp b/lldb/source/Protocol/MCP/Protocol.cpp index d579b88037e63..d9b11bd766686 100644 --- a/lldb/source/Protocol/MCP/Protocol.cpp +++ 
b/lldb/source/Protocol/MCP/Protocol.cpp @@ -26,8 +26,45 @@ static bool mapRaw(const json::Value &Params, StringLiteral Prop, return true; } +static llvm::json::Value toJSON(const Id &Id) { + if (const int64_t *I = std::get_if<int64_t>(&Id)) + return json::Value(*I); + if (const std::string *S = std::get_if<std::string>(&Id)) + return json::Value(*S); + llvm_unreachable("unexpected type in protocol::Id"); +} + +static bool mapId(const llvm::json::Value &V, StringLiteral Prop, Id &Id, + llvm::json::Path P) { + const auto *O = V.getAsObject(); + if (!O) { + P.report("expected object"); + return false; + } + + const auto *E = O->get(Prop); + if (!E) { + P.field(Prop).report("not found"); + return false; + } + + if (auto S = E->getAsString()) { + Id = S->str(); + return true; + } + + if (auto I = E->getAsInteger()) { + Id = *I; + return true; + } + + P.report("expected string or number"); + return false; +} + llvm::json::Value toJSON(const Request &R) { - json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}, {"method", R.method}}; + json::Object Result{ + {"jsonrpc", "2.0"}, {"id", toJSON(R.id)}, {"method", R.method}}; if (R.params) Result.insert({"params", R.params}); return Result; @@ -35,47 +72,75 @@ llvm::json::Value toJSON(const Request &R) { bool fromJSON(const llvm::json::Value &V, Request &R, llvm::json::Path P) { llvm::json::ObjectMapper O(V, P); - if (!O || !O.map("id", R.id) || !O.map("method", R.method)) - return false; - return mapRaw(V, "params", R.params, P); -} - -llvm::json::Value toJSON(const ErrorInfo &EI) { - llvm::json::Object Result{{"code", EI.code}, {"message", EI.message}}; - if (!EI.data.empty()) - Result.insert({"data", EI.data}); - return Result; + return O && mapId(V, "id", R.id, P) && O.map("method", R.method) && + mapRaw(V, "params", R.params, P); } -bool fromJSON(const llvm::json::Value &V, ErrorInfo &EI, llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - return O && O.map("code", EI.code) && O.map("message", EI.message) && - 
O.mapOptional("data", EI.data); +bool operator==(const Request &a, const Request &b) { + return a.id == b.id && a.method == b.method && a.params == b.params; } llvm::json::Value toJSON(const Error &E) { - return json::Object{{"jsonrpc", "2.0"}, {"id", E.id}, {"error", E.error}}; + llvm::json::Object Result{{"code", E.code}, {"message", E.message}}; + if (E.data) + Result.insert({"data", *E.data}); + return Result; } bool fromJSON(const llvm::json::Value &V, Error &E, llvm::json::Path P) { llvm::json::ObjectMapper O(V, P); - return O && O.map("id", E.id) && O.map("error", E.error); + return O && O.map("code", E.code) && O.map("message", E.message) && + mapRaw(V, "data", E.data, P); +} + +bool operator==(const Error &a, const Error &b) { + return a.code == b.code && a.message == b.message && a.data == b.data; } llvm::json::Value toJSON(const Response &R) { - llvm::json::Object Result{{"jsonrpc", "2.0"}, {"id", R.id}}; - if (R.result) - Result.insert({"result", R.result}); - if (R.error) - Result.insert({"error", R.error}); + llvm::json::Object Result{{"jsonrpc", "2.0"}, {"id", toJSON(R.id)}}; + + if (const Error *error = std::get_if<Error>(&R.result)) + Result.insert({"error", *error}); + if (const json::Value *result = std::get_if<json::Value>(&R.result)) + Result.insert({"result", *result}); return Result; } bool fromJSON(const llvm::json::Value &V, Response &R, llvm::json::Path P) { - llvm::json::ObjectMapper O(V, P); - if (!O || !O.map("id", R.id) || !O.map("error", R.error)) + const json::Object *E = V.getAsObject(); + if (!E) { + P.report("expected object"); + return false; + } + + const json::Value *result = E->get("result"); + const json::Value *raw_error = E->get("error"); + + if (result && raw_error) { + P.report("'result' and 'error' fields are mutually exclusive"); return false; - return mapRaw(V, "result", R.result, P); + } + + if (!result && !raw_error) { + P.report("'result' or 'error' fields are required'"); + return false; + } + + if (result) { + 
R.result = std::move(*result); + } else { + Error error; + if (!fromJSON(*raw_error, error, P)) + return false; + R.result = std::move(error); + } + + return mapId(V, "id", R.id, P); +} + +bool operator==(const Response &a, const Response &b) { + return a.id == b.id && a.result == b.result; } llvm::json::Value toJSON(const Notification &N) { @@ -97,6 +162,10 @@ bool fromJSON(const llvm::json::Value &V, Notification &N, llvm::json::Path P) { return true; } +bool operator==(const Notification &a, const Notification &b) { + return a.method == b.method && a.params == b.params; +} + llvm::json::Value toJSON(const ToolCapability &TC) { return llvm::json::Object{{"listChanged", TC.listChanged}}; } @@ -235,24 +304,16 @@ bool fromJSON(const llvm::json::Value &V, Message &M, llvm::json::Path P) { return true; } - if (O->get("error")) { - Error E; - if (!fromJSON(V, E, P)) - return false; - M = std::move(E); - return true; - } - - if (O->get("result")) { - Response R; + if (O->get("method")) { + Request R; if (!fromJSON(V, R, P)) return false; M = std::move(R); return true; } - if (O->get("method")) { - Request R; + if (O->get("result") || O->get("error")) { + Response R; if (!fromJSON(V, R, P)) return false; M = std::move(R); diff --git a/lldb/source/Protocol/MCP/Server.cpp b/lldb/source/Protocol/MCP/Server.cpp index 4ec127fe75bdd..a9c1482e3e378 100644 --- a/lldb/source/Protocol/MCP/Server.cpp +++ b/lldb/source/Protocol/MCP/Server.cpp @@ -66,13 +66,15 @@ Server::HandleData(llvm::StringRef data) { Error protocol_error; llvm::handleAllErrors( response.takeError(), - [&](const MCPError &err) { protocol_error = err.toProtcolError(); }, + [&](const MCPError &err) { protocol_error = err.toProtocolError(); }, [&](const llvm::ErrorInfoBase &err) { - protocol_error.error.code = MCPError::kInternalError; - protocol_error.error.message = err.message(); + protocol_error.code = MCPError::kInternalError; + protocol_error.message = err.message(); }); - protocol_error.id = request->id; - 
return protocol_error; + Response error_response; + error_response.id = request->id; + error_response.result = std::move(protocol_error); + return error_response; } return *response; @@ -84,9 +86,6 @@ Server::HandleData(llvm::StringRef data) { return std::nullopt; } - if (std::get_if<Error>(&(*message))) - return llvm::createStringError("unexpected MCP message: error"); - if (std::get_if<Response>(&(*message))) return llvm::createStringError("unexpected MCP message: response"); @@ -123,11 +122,11 @@ void Server::AddNotificationHandler(llvm::StringRef method, llvm::Expected<Response> Server::InitializeHandler(const Request &request) { Response response; - response.result.emplace(llvm::json::Object{ + response.result = llvm::json::Object{ {"protocolVersion", mcp::kProtocolVersion}, {"capabilities", GetCapabilities()}, {"serverInfo", - llvm::json::Object{{"name", m_name}, {"version", m_version}}}}); + llvm::json::Object{{"name", m_name}, {"version", m_version}}}}; return response; } @@ -138,7 +137,7 @@ llvm::Expected<Response> Server::ToolsListHandler(const Request &request) { for (const auto &tool : m_tools) tools.emplace_back(toJSON(tool.second->GetDefinition())); - response.result.emplace(llvm::json::Object{{"tools", std::move(tools)}}); + response.result = llvm::json::Object{{"tools", std::move(tools)}}; return response; } @@ -173,7 +172,7 @@ llvm::Expected<Response> Server::ToolsCallHandler(const Request &request) { if (!text_result) return text_result.takeError(); - response.result.emplace(toJSON(*text_result)); + response.result = toJSON(*text_result); return response; } @@ -189,8 +188,7 @@ llvm::Expected<Response> Server::ResourcesListHandler(const Request &request) { for (const Resource &resource : resource_provider_up->GetResources()) resources.push_back(resource); } - response.result.emplace( - llvm::json::Object{{"resources", std::move(resources)}}); + response.result = llvm::json::Object{{"resources", std::move(resources)}}; return response; } @@ -226,7 
+224,7 @@ llvm::Expected<Response> Server::ResourcesReadHandler(const Request &request) { return result.takeError(); Response response; - response.result.emplace(std::move(*result)); + response.result = std::move(*result); return response; } diff --git a/lldb/unittests/Protocol/ProtocolMCPTest.cpp b/lldb/unittests/Protocol/ProtocolMCPTest.cpp index 1dca0e5fc5bb8..ea19922522ffe 100644 --- a/lldb/unittests/Protocol/ProtocolMCPTest.cpp +++ b/lldb/unittests/Protocol/ProtocolMCPTest.cpp @@ -149,9 +149,7 @@ TEST(ProtocolMCPTest, MessageWithRequest) { const Request &deserialized_request = std::get<Request>(*deserialized_message); - EXPECT_EQ(request.id, deserialized_request.id); - EXPECT_EQ(request.method, deserialized_request.method); - EXPECT_EQ(request.params, deserialized_request.params); + EXPECT_EQ(request, deserialized_request); } TEST(ProtocolMCPTest, MessageWithResponse) { @@ -168,8 +166,7 @@ TEST(ProtocolMCPTest, MessageWithResponse) { const Response &deserialized_response = std::get<Response>(*deserialized_message); - EXPECT_EQ(response.id, deserialized_response.id); - EXPECT_EQ(response.result, deserialized_response.result); + EXPECT_EQ(response, deserialized_response); } TEST(ProtocolMCPTest, MessageWithNotification) { @@ -186,49 +183,28 @@ TEST(ProtocolMCPTest, MessageWithNotification) { const Notification &deserialized_notification = std::get<Notification>(*deserialized_message); - EXPECT_EQ(notification.method, deserialized_notification.method); - EXPECT_EQ(notification.params, deserialized_notification.params); + EXPECT_EQ(notification, deserialized_notification); } -TEST(ProtocolMCPTest, MessageWithError) { - ErrorInfo error_info; - error_info.code = -32603; - error_info.message = "Internal error"; - +TEST(ProtocolMCPTest, MessageWithErrorResponse) { Error error; - error.id = 3; - error.error = error_info; + error.code = -32603; + error.message = "Internal error"; + + Response error_response; + error_response.id = 3; + error_response.result = error; - 
Message message = error; + Message message = error_response; llvm::Expected<Message> deserialized_message = roundtripJSON(message); ASSERT_THAT_EXPECTED(deserialized_message, llvm::Succeeded()); - ASSERT_TRUE(std::holds_alternative<Error>(*deserialized_message)); - const Error &deserialized_error = std::get<Error>(*deserialized_message); - - EXPECT_EQ(error.id, deserialized_error.id); - EXPECT_EQ(error.error.code, deserialized_error.error.code); - EXPECT_EQ(error.error.message, deserialized_error.error.message); -} - -TEST(ProtocolMCPTest, ResponseWithError) { - ErrorInfo error_info; - error_info.code = -32700; - error_info.message = "Parse error"; - - Response response; - response.id = 4; - response.error = error_info; - - llvm::Expected<Response> deserialized_response = roundtripJSON(response); - ASSERT_THAT_EXPECTED(deserialized_response, llvm::Succeeded()); + ASSERT_TRUE(std::holds_alternative<Response>(*deserialized_message)); + const Response &deserialized_error = + std::get<Response>(*deserialized_message); - EXPECT_EQ(response.id, deserialized_response->id); - EXPECT_FALSE(deserialized_response->result.has_value()); - ASSERT_TRUE(deserialized_response->error.has_value()); - EXPECT_EQ(response.error->code, deserialized_response->error->code); - EXPECT_EQ(response.error->message, deserialized_response->error->message); + EXPECT_EQ(error_response, deserialized_error); } TEST(ProtocolMCPTest, Resource) { diff --git a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp index ade299e08698d..2ac40c41dd28e 100644 --- a/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp +++ b/lldb/unittests/ProtocolServer/ProtocolMCPServerTest.cpp @@ -200,7 +200,7 @@ class ProtocolServerMCPTest : public ::testing::Test { } // namespace -TEST_F(ProtocolServerMCPTest, Intialization) { +TEST_F(ProtocolServerMCPTest, Initialization) { llvm::StringLiteral request = 
R"json({"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"lldb-unit","version":"0.1.0"}},"jsonrpc":"2.0","id":0})json"; llvm::StringLiteral response = diff --git a/lldb/unittests/TestingSupport/TestUtilities.h b/lldb/unittests/TestingSupport/TestUtilities.h index db62881872fef..cc93a68a6a431 100644 --- a/lldb/unittests/TestingSupport/TestUtilities.h +++ b/lldb/unittests/TestingSupport/TestUtilities.h @@ -11,11 +11,11 @@ #include "lldb/Core/ModuleSpec.h" #include "lldb/Utility/DataBuffer.h" -#include "llvm/ADT/SmallString.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/Error.h" #include "llvm/Support/FileSystem.h" -#include "llvm/Support/FileUtilities.h" +#include "llvm/Support/JSON.h" +#include "llvm/Support/raw_ostream.h" #include <string> #define ASSERT_NO_ERROR(x) \ @@ -61,12 +61,10 @@ class TestFile { }; template <typename T> static llvm::Expected<T> roundtripJSON(const T &input) { - llvm::json::Value value = toJSON(input); - llvm::json::Path::Root root; - T output; - if (!fromJSON(value, output, root)) - return root.getError(); - return output; + std::string encoded; + llvm::raw_string_ostream OS(encoded); + OS << toJSON(input); + return llvm::json::parse<T>(encoded); } } // namespace lldb_private _______________________________________________ lldb-commits mailing list lldb-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits