commit:     d7e64317e3a0d789e17dfea7b451c6b3b6cdb3c0
Author:     Zac Medico <zmedico <AT> gentoo <DOT> org>
AuthorDate: Tue Oct 24 23:52:46 2023 +0000
Commit:     Zac Medico <zmedico <AT> gentoo <DOT> org>
CommitDate: Tue Oct 24 23:56:15 2023 +0000
URL:        https://gitweb.gentoo.org/proj/portage.git/commit/?id=d7e64317

test_auxdb: multiprocessing spawn compat

Use staticmethod and functools.partial to avoid unpicklable
local functions. Also, don't try to pickle the anydbm or sqlite
auxdb databases, since they are not currently picklable.
Ultimately, it might be a good idea to implement pickling for
the sqlite module.

Bug: https://bugs.gentoo.org/916245
Signed-off-by: Zac Medico <zmedico <AT> gentoo.org>

 lib/portage/tests/dbapi/test_auxdb.py | 103 ++++++++++++++++++++++------------
 1 file changed, 66 insertions(+), 37 deletions(-)

diff --git a/lib/portage/tests/dbapi/test_auxdb.py b/lib/portage/tests/dbapi/test_auxdb.py
index f022e02adc..1bbf1bde35 100644
--- a/lib/portage/tests/dbapi/test_auxdb.py
+++ b/lib/portage/tests/dbapi/test_auxdb.py
@@ -1,6 +1,9 @@
-# Copyright 2020-2021 Gentoo Authors
+# Copyright 2020-2023 Gentoo Authors
 # Distributed under the terms of the GNU General Public License v2
 
+import functools
+import multiprocessing
+
 from portage.tests import TestCase
 from portage.tests.resolver.ResolverPlayground import ResolverPlayground
 from portage.util.futures import asyncio
@@ -13,7 +16,9 @@ class AuxdbTestCase(TestCase):
             from portage.cache.anydbm import database
         except ImportError:
             self.skipTest("dbm import failed")
-        self._test_mod("portage.cache.anydbm.database", multiproc=False)
+        self._test_mod(
+            "portage.cache.anydbm.database", multiproc=False, picklable=False
+        )
 
     def test_flat_hash_md5(self):
         self._test_mod("portage.cache.flat_hash.md5_database")
@@ -26,9 +31,9 @@ class AuxdbTestCase(TestCase):
             import sqlite3
         except ImportError:
             self.skipTest("sqlite3 import failed")
-        self._test_mod("portage.cache.sqlite.database")
+        self._test_mod("portage.cache.sqlite.database", picklable=False)
 
-    def _test_mod(self, auxdbmodule, multiproc=True):
+    def _test_mod(self, auxdbmodule, multiproc=True, picklable=True):
         ebuilds = {
             "cat/A-1": {
                 "EAPI": "7",
@@ -60,55 +65,81 @@ class AuxdbTestCase(TestCase):
         )
 
         portdb = playground.trees[playground.eroot]["porttree"].dbapi
+        metadata_keys = ["DEFINED_PHASES", "DEPEND", "EAPI", "INHERITED"]
 
-        def test_func():
-            loop = asyncio._wrap_loop()
-            return loop.run_until_complete(
-                self._test_mod_async(
-                    ebuilds,
-                    ebuild_inherited,
-                    eclass_defined_phases,
-                    eclass_depend,
-                    portdb,
-                )
-            )
+        test_func = functools.partial(
+            self._run_test_mod_async, ebuilds, metadata_keys, portdb
+        )
 
-        self.assertTrue(test_func())
+        results = test_func()
 
-        loop = asyncio._wrap_loop()
-        self.assertTrue(
-            loop.run_until_complete(loop.run_in_executor(ForkExecutor(), test_func))
+        self._compare_results(
+            ebuilds, eclass_defined_phases, eclass_depend, ebuild_inherited, results
         )
 
+        loop = asyncio._wrap_loop()
+        # Unpicklable auxdb modules cannot be pickled into a spawned
+        # subprocess, so only run them in one when it will be forked.
+        picklable_or_fork = picklable or multiprocessing.get_start_method() == "fork"
+        if picklable_or_fork:
+            results = loop.run_until_complete(
+                loop.run_in_executor(ForkExecutor(), test_func)
+            )
+
+            self._compare_results(
+                ebuilds, eclass_defined_phases, eclass_depend, ebuild_inherited, results
+            )
+
         auxdb = portdb.auxdb[portdb.getRepositoryPath("test_repo")]
         cpv = next(iter(ebuilds))
 
-        def modify_auxdb():
-            metadata = auxdb[cpv]
-            metadata["RESTRICT"] = "test"
-            try:
-                del metadata["_eclasses_"]
-            except KeyError:
-                pass
-            auxdb[cpv] = metadata
+        modify_auxdb = functools.partial(self._modify_auxdb, auxdb, cpv)
 
-        if multiproc:
+        if multiproc and picklable_or_fork:
             loop.run_until_complete(loop.run_in_executor(ForkExecutor(), modify_auxdb))
         else:
             modify_auxdb()
 
         self.assertEqual(auxdb[cpv]["RESTRICT"], "test")
 
-    async def _test_mod_async(
-        self, ebuilds, ebuild_inherited, eclass_defined_phases, eclass_depend, portdb
+    def _compare_results(
+        self, ebuilds, eclass_defined_phases, eclass_depend, ebuild_inherited, results
     ):
         for cpv, metadata in ebuilds.items():
-            defined_phases, depend, eapi, inherited = await portdb.async_aux_get(
-                cpv, ["DEFINED_PHASES", "DEPEND", "EAPI", "INHERITED"]
+            self.assertEqual(results[cpv]["DEFINED_PHASES"], eclass_defined_phases)
+            self.assertEqual(results[cpv]["DEPEND"], eclass_depend)
+            self.assertEqual(results[cpv]["EAPI"], metadata["EAPI"])
+            self.assertEqual(
+                frozenset(results[cpv]["INHERITED"].split()), ebuild_inherited
             )
-            self.assertEqual(defined_phases, eclass_defined_phases)
-            self.assertEqual(depend, eclass_depend)
-            self.assertEqual(eapi, metadata["EAPI"])
-            self.assertEqual(frozenset(inherited.split()), ebuild_inherited)
 
-        return True
+    @staticmethod
+    def _run_test_mod_async(ebuilds, metadata_keys, portdb):
+        loop = asyncio._wrap_loop()
+        return loop.run_until_complete(
+            AuxdbTestCase._test_mod_async(
+                ebuilds,
+                metadata_keys,
+                portdb,
+            )
+        )
+
+    @staticmethod
+    async def _test_mod_async(ebuilds, metadata_keys, portdb):
+        results = {}
+        for cpv, metadata in ebuilds.items():
+            results[cpv] = dict(
+                zip(metadata_keys, await portdb.async_aux_get(cpv, metadata_keys))
+            )
+
+        return results
+
+    @staticmethod
+    def _modify_auxdb(auxdb, cpv):
+        metadata = auxdb[cpv]
+        metadata["RESTRICT"] = "test"
+        try:
+            del metadata["_eclasses_"]
+        except KeyError:
+            pass
+        auxdb[cpv] = metadata

Reply via email to