Allow existing sockets to be passed to connect(). The changes are pretty minimal, and this allows for far greater flexibility in setting up communications with an endpoint.
Signed-off-by: John Snow <[email protected]> Message-id: [email protected] Signed-off-by: John Snow <[email protected]> --- python/qemu/qmp/protocol.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/python/qemu/qmp/protocol.py b/python/qemu/qmp/protocol.py index 22e60298d2..d534db4631 100644 --- a/python/qemu/qmp/protocol.py +++ b/python/qemu/qmp/protocol.py @@ -370,7 +370,7 @@ async def accept(self) -> None: @upper_half @require(Runstate.IDLE) - async def connect(self, address: SocketAddrT, + async def connect(self, address: Union[SocketAddrT, socket.socket], ssl: Optional[SSLContext] = None) -> None: """ Connect to the server and begin processing message queues. @@ -615,7 +615,7 @@ async def _do_accept(self) -> None: self.logger.debug("Connection accepted.") @upper_half - async def _do_connect(self, address: SocketAddrT, + async def _do_connect(self, address: Union[SocketAddrT, socket.socket], ssl: Optional[SSLContext] = None) -> None: """ Acting as the transport client, initiate a connection to a server. @@ -634,9 +634,17 @@ async def _do_connect(self, address: SocketAddrT, # otherwise yield. await asyncio.sleep(0) - self.logger.debug("Connecting to %s ...", address) - - if isinstance(address, tuple): + if isinstance(address, socket.socket): + self.logger.debug("Connecting with existing socket: " + "fd=%d, family=%r, type=%r", + address.fileno(), address.family, address.type) + connect = asyncio.open_connection( + limit=self._limit, + ssl=ssl, + sock=address, + ) + elif isinstance(address, tuple): + self.logger.debug("Connecting to %s ...", address) connect = asyncio.open_connection( address[0], address[1], @@ -644,13 +652,14 @@ async def _do_connect(self, address: SocketAddrT, limit=self._limit, ) else: + self.logger.debug("Connecting to file://%s ...", address) connect = asyncio.open_unix_connection( path=address, ssl=ssl, limit=self._limit, ) + self._reader, self._writer = await connect - self.logger.debug("Connected.") @upper_half -- 2.40.1
