llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-lldb Author: Pavel Labath (labath) <details> <summary>Changes</summary> This makes the MainLoop-based Accept a virtual method of Socket and adds an implementation of it to DomainSocket, to go along with the existing TCPSocket implementation. --- Full diff: https://github.com/llvm/llvm-project/pull/108188.diff 9 Files Affected: - (modified) lldb/include/lldb/Host/Socket.h (+12-1) - (modified) lldb/include/lldb/Host/common/TCPSocket.h (+2-8) - (modified) lldb/include/lldb/Host/common/UDPSocket.h (+7-1) - (modified) lldb/include/lldb/Host/posix/DomainSocket.h (+7-1) - (modified) lldb/source/Host/common/Socket.cpp (+14) - (modified) lldb/source/Host/common/TCPSocket.cpp (+3-16) - (modified) lldb/source/Host/common/UDPSocket.cpp (-4) - (modified) lldb/source/Host/posix/DomainSocket.cpp (+34-8) - (modified) lldb/unittests/Host/SocketTest.cpp (+38-2) ``````````diff diff --git a/lldb/include/lldb/Host/Socket.h b/lldb/include/lldb/Host/Socket.h index 764a048976eb41..14468c98ac5a3a 100644 --- a/lldb/include/lldb/Host/Socket.h +++ b/lldb/include/lldb/Host/Socket.h @@ -12,6 +12,7 @@ #include <memory> #include <string> +#include "lldb/Host/MainLoopBase.h" #include "lldb/lldb-private.h" #include "lldb/Host/SocketAddress.h" @@ -97,7 +98,17 @@ class Socket : public IOObject { virtual Status Connect(llvm::StringRef name) = 0; virtual Status Listen(llvm::StringRef name, int backlog) = 0; - virtual Status Accept(Socket *&socket) = 0; + + // Use the provided main loop instance to accept new connections. The callback + // will be called (from MainLoop::Run) for each new connection. This function + // does not block. + virtual llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> + Accept(MainLoopBase &loop, + std::function<void(std::unique_ptr<Socket> socket)> sock_cb) = 0; + + // Accept a single connection and "return" it in the pointer argument. This + // function blocks until the connection arrives. + virtual Status Accept(Socket *&socket); // Initialize a Tcp Socket object in listening mode.
listen and accept are // implemented separately because the caller may wish to manipulate or query diff --git a/lldb/include/lldb/Host/common/TCPSocket.h b/lldb/include/lldb/Host/common/TCPSocket.h index 78e80568e39967..eefe0240fe4a95 100644 --- a/lldb/include/lldb/Host/common/TCPSocket.h +++ b/lldb/include/lldb/Host/common/TCPSocket.h @@ -42,16 +42,10 @@ class TCPSocket : public Socket { Status Connect(llvm::StringRef name) override; Status Listen(llvm::StringRef name, int backlog) override; - // Use the provided main loop instance to accept new connections. The callback - // will be called (from MainLoop::Run) for each new connection. This function - // does not block. + using Socket::Accept; llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> Accept(MainLoopBase &loop, - std::function<void(std::unique_ptr<TCPSocket> socket)> sock_cb); - - // Accept a single connection and "return" it in the pointer argument. This - // function blocks until the connection arrives. - Status Accept(Socket *&conn_socket) override; + std::function<void(std::unique_ptr<Socket> socket)> sock_cb) override; Status CreateSocket(int domain); diff --git a/lldb/include/lldb/Host/common/UDPSocket.h b/lldb/include/lldb/Host/common/UDPSocket.h index bae707e345d87c..7348010d02ada2 100644 --- a/lldb/include/lldb/Host/common/UDPSocket.h +++ b/lldb/include/lldb/Host/common/UDPSocket.h @@ -27,7 +27,13 @@ class UDPSocket : public Socket { size_t Send(const void *buf, const size_t num_bytes) override; Status Connect(llvm::StringRef name) override; Status Listen(llvm::StringRef name, int backlog) override; - Status Accept(Socket *&socket) override; + + llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> + Accept(MainLoopBase &loop, + std::function<void(std::unique_ptr<Socket> socket)> sock_cb) override { + return llvm::errorCodeToError( + std::make_error_code(std::errc::operation_not_supported)); + } SocketAddress m_sockaddr; }; diff --git a/lldb/include/lldb/Host/posix/DomainSocket.h 
b/lldb/include/lldb/Host/posix/DomainSocket.h index 35c33811f60de6..983f43bd633719 100644 --- a/lldb/include/lldb/Host/posix/DomainSocket.h +++ b/lldb/include/lldb/Host/posix/DomainSocket.h @@ -14,11 +14,17 @@ namespace lldb_private { class DomainSocket : public Socket { public: + DomainSocket(NativeSocket socket, bool should_close, + bool child_processes_inherit); DomainSocket(bool should_close, bool child_processes_inherit); Status Connect(llvm::StringRef name) override; Status Listen(llvm::StringRef name, int backlog) override; - Status Accept(Socket *&socket) override; + + using Socket::Accept; + llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> + Accept(MainLoopBase &loop, + std::function<void(std::unique_ptr<Socket> socket)> sock_cb) override; std::string GetRemoteConnectionURI() const override; diff --git a/lldb/source/Host/common/Socket.cpp b/lldb/source/Host/common/Socket.cpp index 1a63571b94c6f1..d69eb608204033 100644 --- a/lldb/source/Host/common/Socket.cpp +++ b/lldb/source/Host/common/Socket.cpp @@ -10,6 +10,7 @@ #include "lldb/Host/Config.h" #include "lldb/Host/Host.h" +#include "lldb/Host/MainLoop.h" #include "lldb/Host/SocketAddress.h" #include "lldb/Host/common/TCPSocket.h" #include "lldb/Host/common/UDPSocket.h" @@ -459,6 +460,19 @@ NativeSocket Socket::CreateSocket(const int domain, const int type, return sock; } +Status Socket::Accept(Socket *&socket) { + MainLoop accept_loop; + llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> expected_handles = + Accept(accept_loop, + [&accept_loop, &socket](std::unique_ptr<Socket> sock) { + socket = sock.release(); + accept_loop.RequestTermination(); + }); + if (!expected_handles) + return Status::FromError(expected_handles.takeError()); + return accept_loop.Run(); +} + NativeSocket Socket::AcceptSocket(NativeSocket sockfd, struct sockaddr *addr, socklen_t *addrlen, bool child_processes_inherit, Status &error) { diff --git a/lldb/source/Host/common/TCPSocket.cpp 
b/lldb/source/Host/common/TCPSocket.cpp index b28ba148ee1afa..2d16b605af9497 100644 --- a/lldb/source/Host/common/TCPSocket.cpp +++ b/lldb/source/Host/common/TCPSocket.cpp @@ -232,9 +232,9 @@ void TCPSocket::CloseListenSockets() { m_listen_sockets.clear(); } -llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> TCPSocket::Accept( - MainLoopBase &loop, - std::function<void(std::unique_ptr<TCPSocket> socket)> sock_cb) { +llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> +TCPSocket::Accept(MainLoopBase &loop, + std::function<void(std::unique_ptr<Socket> socket)> sock_cb) { if (m_listen_sockets.size() == 0) return llvm::createStringError("No open listening sockets!"); @@ -278,19 +278,6 @@ llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> TCPSocket::Accept( return handles; } -Status TCPSocket::Accept(Socket *&conn_socket) { - MainLoop accept_loop; - llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> expected_handles = - Accept(accept_loop, - [&accept_loop, &conn_socket](std::unique_ptr<TCPSocket> sock) { - conn_socket = sock.release(); - accept_loop.RequestTermination(); - }); - if (!expected_handles) - return Status::FromError(expected_handles.takeError()); - return accept_loop.Run(); -} - int TCPSocket::SetOptionNoDelay() { return SetOption(IPPROTO_TCP, TCP_NODELAY, 1); } diff --git a/lldb/source/Host/common/UDPSocket.cpp b/lldb/source/Host/common/UDPSocket.cpp index 2a7a6cff414b14..05d7b2e6506027 100644 --- a/lldb/source/Host/common/UDPSocket.cpp +++ b/lldb/source/Host/common/UDPSocket.cpp @@ -47,10 +47,6 @@ Status UDPSocket::Listen(llvm::StringRef name, int backlog) { return Status::FromErrorStringWithFormat("%s", g_not_supported_error); } -Status UDPSocket::Accept(Socket *&socket) { - return Status::FromErrorStringWithFormat("%s", g_not_supported_error); -} - llvm::Expected<std::unique_ptr<UDPSocket>> UDPSocket::Connect(llvm::StringRef name, bool child_processes_inherit) { std::unique_ptr<UDPSocket> socket; diff --git 
a/lldb/source/Host/posix/DomainSocket.cpp b/lldb/source/Host/posix/DomainSocket.cpp index 2d18995c3bb469..369123f2239302 100644 --- a/lldb/source/Host/posix/DomainSocket.cpp +++ b/lldb/source/Host/posix/DomainSocket.cpp @@ -7,11 +7,13 @@ //===----------------------------------------------------------------------===// #include "lldb/Host/posix/DomainSocket.h" +#include "lldb/Utility/LLDBLog.h" #include "llvm/Support/Errno.h" #include "llvm/Support/FileSystem.h" #include <cstddef> +#include <memory> #include <sys/socket.h> #include <sys/un.h> @@ -57,7 +59,14 @@ static bool SetSockAddr(llvm::StringRef name, const size_t name_offset, } DomainSocket::DomainSocket(bool should_close, bool child_processes_inherit) - : Socket(ProtocolUnixDomain, should_close, child_processes_inherit) {} + : DomainSocket(kInvalidSocketValue, should_close, child_processes_inherit) { +} + +DomainSocket::DomainSocket(NativeSocket socket, bool should_close, + bool child_processes_inherit) + : Socket(ProtocolUnixDomain, should_close, child_processes_inherit) { + m_socket = socket; +} DomainSocket::DomainSocket(SocketProtocol protocol, bool child_processes_inherit) @@ -108,14 +117,31 @@ Status DomainSocket::Listen(llvm::StringRef name, int backlog) { return error; } -Status DomainSocket::Accept(Socket *&socket) { - Status error; - auto conn_fd = AcceptSocket(GetNativeSocket(), nullptr, nullptr, - m_child_processes_inherit, error); - if (error.Success()) - socket = new DomainSocket(conn_fd, *this); +llvm::Expected<std::vector<MainLoopBase::ReadHandleUP>> DomainSocket::Accept( + MainLoopBase &loop, + std::function<void(std::unique_ptr<Socket> socket)> sock_cb) { + // TODO: Refactor MainLoop to avoid the shared_ptr requirement. 
+ auto io_sp = std::make_shared<DomainSocket>(GetNativeSocket(), false, + m_child_processes_inherit); + auto cb = [this, sock_cb](MainLoopBase &loop) { + Log *log = GetLog(LLDBLog::Host); + Status error; + auto conn_fd = AcceptSocket(GetNativeSocket(), nullptr, nullptr, + m_child_processes_inherit, error); + if (error.Fail()) { + LLDB_LOG(log, "AcceptSocket({0}): {1}", GetNativeSocket(), error); + return; + } + std::unique_ptr<DomainSocket> sock_up(new DomainSocket(conn_fd, *this)); + sock_cb(std::move(sock_up)); + }; - return error; + Status error; + std::vector<MainLoopBase::ReadHandleUP> handles; + handles.emplace_back(loop.RegisterReadObject(io_sp, cb, error)); + if (error.Fail()) + return error.ToError(); + return handles; } size_t DomainSocket::GetNameOffset() const { return 0; } diff --git a/lldb/unittests/Host/SocketTest.cpp b/lldb/unittests/Host/SocketTest.cpp index 3a356d11ba1a51..a93b928e274d03 100644 --- a/lldb/unittests/Host/SocketTest.cpp +++ b/lldb/unittests/Host/SocketTest.cpp @@ -85,6 +85,42 @@ TEST_P(SocketTest, DomainListenConnectAccept) { std::unique_ptr<DomainSocket> socket_b_up; CreateDomainConnectedSockets(Path, &socket_a_up, &socket_b_up); } + +TEST_P(SocketTest, DomainMainLoopAccept) { + llvm::SmallString<64> Path; + std::error_code EC = llvm::sys::fs::createUniqueDirectory("DomainListenConnectAccept", Path); + ASSERT_FALSE(EC); + llvm::sys::path::append(Path, "test"); + + // Skip the test if the $TMPDIR is too long to hold a domain socket. 
+ if (Path.size() > 107u) + return; + + auto listen_socket_up = std::make_unique<DomainSocket>( + /*should_close=*/true, /*child_process_inherit=*/false); + Status error = listen_socket_up->Listen(Path, 5); + ASSERT_THAT_ERROR(error.ToError(), llvm::Succeeded()); + ASSERT_TRUE(listen_socket_up->IsValid()); + + MainLoop loop; + std::unique_ptr<Socket> accepted_socket_up; + auto expected_handles = listen_socket_up->Accept( + loop, [&accepted_socket_up, &loop](std::unique_ptr<Socket> sock_up) { + accepted_socket_up = std::move(sock_up); + loop.RequestTermination(); + }); + ASSERT_THAT_EXPECTED(expected_handles, llvm::Succeeded()); + + auto connect_socket_up = std::make_unique<DomainSocket>( + /*should_close=*/true, /*child_process_inherit=*/false); + ASSERT_THAT_ERROR(connect_socket_up->Connect(Path).ToError(), + llvm::Succeeded()); + ASSERT_TRUE(connect_socket_up->IsValid()); + + loop.Run(); + ASSERT_TRUE(accepted_socket_up); + ASSERT_TRUE(accepted_socket_up->IsValid()); +} #endif TEST_P(SocketTest, TCPListen0ConnectAccept) { @@ -109,9 +145,9 @@ TEST_P(SocketTest, TCPMainLoopAccept) { ASSERT_TRUE(listen_socket_up->IsValid()); MainLoop loop; - std::unique_ptr<TCPSocket> accepted_socket_up; + std::unique_ptr<Socket> accepted_socket_up; auto expected_handles = listen_socket_up->Accept( - loop, [&accepted_socket_up, &loop](std::unique_ptr<TCPSocket> sock_up) { + loop, [&accepted_socket_up, &loop](std::unique_ptr<Socket> sock_up) { accepted_socket_up = std::move(sock_up); loop.RequestTermination(); }); `````````` </details> https://github.com/llvm/llvm-project/pull/108188 _______________________________________________ lldb-commits mailing list lldb-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/lldb-commits