commit:     1c0b5ff6aabe99224b07f7fc2e09e4cb6c9c70c4
Author:     Zac Medico <zmedico <AT> gentoo <DOT> org>
AuthorDate: Tue Oct  3 20:33:18 2023 +0000
Commit:     Zac Medico <zmedico <AT> gentoo <DOT> org>
CommitDate: Wed Oct  4 00:47:01 2023 +0000
URL:        https://gitweb.gentoo.org/proj/portage.git/commit/?id=1c0b5ff6

AsyncFunction: Use multiprocessing.Pipe for compat with multiprocessing spawn

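Since fd_pipes does not work with the multiprocessing "spawn" start
method, make AsyncFunction pass the target's pickled result back to the
parent through a multiprocessing.Pipe instead of an os.pipe() file
descriptor inherited via fd_pipes.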

Bug: https://bugs.gentoo.org/915136
Signed-off-by: Zac Medico <zmedico <AT> gentoo.org>

 lib/portage/tests/process/test_AsyncFunction.py | 34 ++++++++++++++++---------
 lib/portage/util/_async/AsyncFunction.py        | 15 +++++++----
 2 files changed, 32 insertions(+), 17 deletions(-)

diff --git a/lib/portage/tests/process/test_AsyncFunction.py b/lib/portage/tests/process/test_AsyncFunction.py
index 81b3f41fbf..975b590e53 100644
--- a/lib/portage/tests/process/test_AsyncFunction.py
+++ b/lib/portage/tests/process/test_AsyncFunction.py
@@ -1,6 +1,7 @@
-# Copyright 2020-2021 Gentoo Authors
+# Copyright 2020-2023 Gentoo Authors
 # Distributed under the terms of the GNU General Public License v2
 
+import multiprocessing
 import sys
 
 import portage
@@ -14,22 +15,29 @@ from portage.util.futures.unix_events import 
_set_nonblocking
 
 class AsyncFunctionTestCase(TestCase):
     @staticmethod
-    def _read_from_stdin(pw):
-        os.close(pw)
+    def _read_from_stdin(pr, pw):
+        if pw is not None:
+            os.close(pw)
+        os.dup2(pr.fileno(), sys.stdin.fileno())
         return "".join(sys.stdin)
 
     async def _testAsyncFunctionStdin(self, loop):
         test_string = "1\n2\n3\n"
-        pr, pw = os.pipe()
-        fd_pipes = {0: pr}
+        pr, pw = multiprocessing.Pipe(duplex=False)
         reader = AsyncFunction(
-            scheduler=loop, fd_pipes=fd_pipes, target=self._read_from_stdin, 
args=(pw,)
+            scheduler=loop,
+            target=self._read_from_stdin,
+            args=(
+                pr,
+                pw.fileno() if multiprocessing.get_start_method() == "fork" 
else None,
+            ),
         )
         reader.start()
-        os.close(pr)
-        _set_nonblocking(pw)
-        with open(pw, mode="wb", buffering=0) as pipe_write:
+        pr.close()
+        _set_nonblocking(pw.fileno())
+        with open(pw.fileno(), mode="wb", buffering=0, closefd=False) as 
pipe_write:
             await _writer(pipe_write, test_string.encode("utf_8"))
+        pw.close()
         self.assertEqual((await reader.async_wait()), os.EX_OK)
         self.assertEqual(reader.result, test_string)
 
@@ -37,7 +45,8 @@ class AsyncFunctionTestCase(TestCase):
         loop = asyncio._wrap_loop()
         loop.run_until_complete(self._testAsyncFunctionStdin(loop=loop))
 
-    def _test_getpid_fork(self):
+    @staticmethod
+    def _test_getpid_fork():
         """
         Verify that portage.getpid() cache is updated in a forked child 
process.
         """
@@ -45,10 +54,10 @@ class AsyncFunctionTestCase(TestCase):
         proc = AsyncFunction(scheduler=loop, target=portage.getpid)
         proc.start()
         proc.wait()
-        self.assertEqual(proc.pid, proc.result)
+        return proc.pid == proc.result
 
     def test_getpid_fork(self):
-        self._test_getpid_fork()
+        self.assertTrue(self._test_getpid_fork())
 
     def test_getpid_double_fork(self):
         """
@@ -59,3 +68,4 @@ class AsyncFunctionTestCase(TestCase):
         proc = AsyncFunction(scheduler=loop, target=self._test_getpid_fork)
         proc.start()
         self.assertEqual(proc.wait(), 0)
+        self.assertTrue(proc.result)

diff --git a/lib/portage/util/_async/AsyncFunction.py b/lib/portage/util/_async/AsyncFunction.py
index e13daaebb0..6f55aba565 100644
--- a/lib/portage/util/_async/AsyncFunction.py
+++ b/lib/portage/util/_async/AsyncFunction.py
@@ -2,6 +2,7 @@
 # Distributed under the terms of the GNU General Public License v2
 
 import functools
+import multiprocessing
 import pickle
 import traceback
 
@@ -23,9 +24,7 @@ class AsyncFunction(ForkProcess):
     )
 
     def _start(self):
-        pr, pw = os.pipe()
-        self.fd_pipes = {} if self.fd_pipes is None else self.fd_pipes
-        self.fd_pipes[pw] = pw
+        pr, pw = multiprocessing.Pipe(duplex=False)
         self._async_func_reader = PipeReader(
             input_files={"input": pr}, scheduler=self.scheduler
         )
@@ -34,13 +33,15 @@ class AsyncFunction(ForkProcess):
         # args and kwargs are passed as additional args by 
ForkProcess._bootstrap.
         self.target = functools.partial(self._target_wrapper, pw, self.target)
         ForkProcess._start(self)
-        os.close(pw)
+        pw.close()
 
     @staticmethod
     def _target_wrapper(pw, target, *args, **kwargs):
         try:
             result = target(*args, **kwargs)
-            os.write(pw, pickle.dumps(result))
+            result_bytes = pickle.dumps(result)
+            while result_bytes:
+                result_bytes = result_bytes[os.write(pw.fileno(), 
result_bytes) :]
         except Exception:
             traceback.print_exc()
             return 1
@@ -53,6 +54,10 @@ class AsyncFunction(ForkProcess):
         if self._async_func_reader is None:
             ForkProcess._async_waitpid(self)
 
+    def _async_wait(self):
+        if self._async_func_reader is None:
+            ForkProcess._async_wait(self)
+
     def _async_func_reader_exit(self, pipe_reader):
         try:
             self.result = pickle.loads(pipe_reader.getvalue())

Reply via email to