commit:     d5b35fa85ee18756d41cfb831389ffbd537abad6
Author:     Brian Harring <ferringb <AT> gmail <DOT> com>
AuthorDate: Tue Nov 25 09:42:59 2025 +0000
Commit:     Brian Harring <ferringb <AT> gmail <DOT> com>
CommitDate: Thu Nov 27 16:14:04 2025 +0000
URL:        
https://gitweb.gentoo.org/proj/pkgcore/snakeoil.git/commit/?id=d5b35fa8

feat: add python_namespaces.get_submodules_of

This is used to collect all submodules beneath a given module.

The new code is much smaller because it is written against modern
py3k, rather than inheriting py2k crap.

Signed-off-by: Brian Harring <ferringb <AT> gmail.com>

 src/snakeoil/python_namespaces.py   | 210 +++++++++++++++---------------------
 src/snakeoil/test/mixins.py         | 140 +++++++++++++++++++++++-
 src/snakeoil/test/modules.py        |   4 +-
 src/snakeoil/test/slot_shadowing.py |   2 +-
 tests/test_demandload_usage.py      |   5 +-
 tests/test_python_namespaces.py     | 108 +++++++++++++++++++
 6 files changed, 334 insertions(+), 135 deletions(-)

diff --git a/src/snakeoil/python_namespaces.py 
b/src/snakeoil/python_namespaces.py
index 9dbe090..16c702f 100644
--- a/src/snakeoil/python_namespaces.py
+++ b/src/snakeoil/python_namespaces.py
@@ -1,136 +1,94 @@
-import errno
+__all__ = ("import_submodules_of", "get_submodules_of")
+import functools
+import importlib
+import importlib.machinery
 import os
-import stat
-import sys
-
-from .compatibility import IGNORED_EXCEPTIONS
-
-
-class PythonNamespaceWalker:
-    ignore_all_import_failures = False
-
-    valid_inits = frozenset(f"__init__.{x}" for x in ("py", "pyc", "pyo", 
"so"))
-
-    # This is for py3.2/PEP3149; dso's now have the interp + major/minor 
embedded
-    # in the name.
-    # TODO: update this for pypy's naming
-    abi_target = "cpython-%i%i" % tuple(sys.version_info[:2])
-
-    module_blacklist = frozenset(
-        {
-            "snakeoil.cli.arghparse",
-            "snakeoil.pickling",
-        }
-    )
-
-    def _default_module_blacklister(self, target):
-        return target in self.module_blacklist or 
target.startswith("snakeoil.dist")
-
-    def walk_namespace(self, namespace, **kwds):
-        location = os.path.abspath(
-            os.path.dirname(self.poor_mans_load(namespace).__file__)
-        )
-        return self.get_modules(self.recurse(location), namespace=namespace, 
**kwds)
-
-    def get_modules(
-        self, feed, namespace=None, blacklist_func=None, 
ignore_failed_imports=None
-    ):
-        if ignore_failed_imports is None:
-            ignore_failed_imports = self.ignore_all_import_failures
-        if namespace is None:
-
-            def mangle(x):  # pyright: ignore[reportRedeclaration]
-                return x
-        else:
-            orig_namespace = namespace
-
-            def mangle(x):
-                return f"{orig_namespace}.{x}"
-
-        if blacklist_func is None:
-            blacklist_func = self._default_module_blacklister
-        for mod_name in feed:
-            try:
-                if mod_name is None:
-                    if namespace is None:
-                        continue
+import pathlib
+import types
+import typing
+
+T_class_filter = typing.Callable[[str], bool]
+
+
+def get_submodules_of(
+    root: types.ModuleType,
+    /,
+    dont_import: T_class_filter | typing.Container[str] | None = None,
+    ignore_import_failures: T_class_filter | typing.Container[str] | bool = 
False,
+) -> typing.Iterable[types.ModuleType]:
+    """Visit all submodules of the target via walking the underlying filesystem
+
+    This currently cannot work against a frozen python exe (for example), nor 
source only contained within an egg; it currently just walks the FS.
+    :param root: the module to trace
+    :param dont_import: do not try importing anything in this sequence or 
dont_import(qualname)
+      boolean result.  Defaults to no filter
+    :param ignore_import_failures: filter of what modules are known to 
potentially raise an
+      ImportError, and to tolerate those if it occurs.  Defaults to tolerating 
none.
+    """
+
+    if dont_import is None:
+        dont_import = lambda _: False  # noqa: E731
+    elif isinstance(dont_import, typing.Container):
+        dont_import = dont_import.__contains__
+
+    if ignore_import_failures is True:
+        ignore_import_failures = bool
+    elif ignore_import_failures is False:
+        ignore_import_failures = lambda x: False  # noqa: E731
+    elif isinstance(ignore_import_failures, typing.Container):
+        ignore_import_failures = ignore_import_failures.__contains__
+
+    to_scan = [root]
+    while to_scan:
+        current = to_scan.pop()
+        if current.__file__ is None:
+            raise ValueError(
+                f"module {current!r} lacks __file__ attribute.  If this is a 
PEP420 namespace module, that is unsupported currently"
+            )
+
+        if current is not root:
+            yield current
+        base = pathlib.Path(os.path.abspath(current.__file__))
+        # if it's not the root of a module, there's nothing to do- return it..
+        if not base.name.startswith("__init__."):
+            continue
+        for potential in base.parent.iterdir():
+            name = potential.name
+            qualname = f"{current.__name__}.{name.split('.', 1)[0]}"
+            if name.startswith("__init__."):
+                # if we're in this directory, we already imported the 
enclosing namespace.
+                continue
+            if potential.is_dir():
+                if name == "__pycache__":
+                    continue
+            else:
+                for ext in importlib.machinery.all_suffixes():
+                    if name.endswith(ext):
+                        name = name[: -len(ext)]
+                        break
                 else:
-                    namespace = mangle(mod_name)
-                if blacklist_func(namespace):
+                    # it's not a python source.
                     continue
-                yield self.poor_mans_load(namespace)
-            except ImportError:
-                if not ignore_failed_imports:
-                    raise
 
-    def recurse(self, location, valid_namespace=True):
-        if os.path.dirname(location) == "__pycache__":
-            # Shouldn't be possible, but make sure we avoid this if it manages
-            # to occur.
-            return
-        dirents = os.listdir(location)
-        if not self.valid_inits.intersection(dirents):
-            if valid_namespace:
-                return
-        else:
-            yield None
+            if dont_import(qualname):
+                continue
 
-        stats: list[tuple[str, int]] = []
-        for x in dirents:
             try:
-                stats.append((x, os.stat(os.path.join(location, x)).st_mode))
-            except OSError as exc:
-                if exc.errno != errno.ENOENT:
+                # intentionally re-examine it; for a file tree this is 
wasteful since
+                # we would know if it's a directory or not, but whenever this 
code gets
+                # extended for working from .whl or .egg directly, we will 
want that
+                # logic in one spot. TL;DR: this is intentionally not 
optimized for the
+                # common case.
+                to_scan.append(importlib.import_module(qualname))
+            except ImportError:
+                if not ignore_import_failures(qualname):
                     raise
-                # file disappeared under our feet... lock file from
-                # trial can cause this.  ignore.
-                import logging
 
-                logging.debug(
-                    "file %r disappeared under our feet, ignoring",
-                    os.path.join(location, x),
-                )
 
-        seen = set(["__init__"])
-        for x, st in stats:
-            if not (x.startswith(".") or x.endswith("~")) and stat.S_ISREG(st):
-                if x.endswith((".py", ".pyc", ".pyo", ".so")):
-                    y = x.rsplit(".", 1)[0]
-                    # Ensure we're not looking at a >=py3k .so which injects
-                    # the version name in...
-                    if y not in seen:
-                        if "." in y and x.endswith(".so"):
-                            y, abi = x.rsplit(".", 1)
-                            if abi != self.abi_target:
-                                continue
-                        seen.add(y)
-                        yield y
+def import_submodules_of(target: types.ModuleType, **kwargs) -> None:
+    """load all modules of the given namespace.
 
-        for x, st in stats:
-            if stat.S_ISDIR(st):
-                for y in self.recurse(os.path.join(location, x)):
-                    if y is None:
-                        yield x
-                    else:
-                        yield f"{x}.{y}"
-
-    @staticmethod
-    def poor_mans_load(namespace, existence_check=False):
-        try:
-            obj = __import__(namespace)
-            if existence_check:
-                return True
-        except:
-            if existence_check:
-                return False
-            raise
-        for chunk in namespace.split(".")[1:]:
-            try:
-                obj = getattr(obj, chunk)
-            except IGNORED_EXCEPTIONS:
-                raise
-            except AttributeError:
-                raise AssertionError(f"failed importing target {namespace}")
-            except Exception as e:
-                raise AssertionError(f"failed importing target {namespace}; 
error {e}")
-        return obj
+    See get_submodules_of for the kwargs options.
+    """
+    for _ in get_submodules_of(target, **kwargs):
+        pass

diff --git a/src/snakeoil/test/mixins.py b/src/snakeoil/test/mixins.py
index c4db185..088c834 100644
--- a/src/snakeoil/test/mixins.py
+++ b/src/snakeoil/test/mixins.py
@@ -1,7 +1,139 @@
+import errno
+import os
+import sys
+
+from snakeoil.compatibility import IGNORED_EXCEPTIONS
 from snakeoil.deprecation import deprecated
 
-from ..python_namespaces import PythonNamespaceWalker as _original
 
-PythonNamespaceWalker = deprecated(
-    "snakeoil.test.mixins.PythonNamespaceWalker has moved to 
snakeoil._namespaces.  Preferably remove your dependency on it"
-)(_original)  # pyright: ignore[reportAssignmentType]
+@deprecated(
+    "snakeoil.test.mixins.PythonNamespaceWalker is deprecated, instead use 
snakeoil.python_namespaces.submodules_of"
+)
+class PythonNamespaceWalker:
+    ignore_all_import_failures = False
+
+    valid_inits = frozenset(f"__init__.{x}" for x in ("py", "pyc", "pyo", 
"so"))
+
+    # This is for py3.2/PEP3149; dso's now have the interp + major/minor 
embedded
+    # in the name.
+    # TODO: update this for pypy's naming
+    abi_target = "cpython-%i%i" % tuple(sys.version_info[:2])
+
+    module_blacklist = frozenset(
+        {
+            "snakeoil.cli.arghparse",
+            "snakeoil.pickling",
+        }
+    )
+
+    def _default_module_blacklister(self, target):
+        return target in self.module_blacklist or 
target.startswith("snakeoil.dist")
+
+    def walk_namespace(self, namespace, **kwds):
+        location = os.path.abspath(
+            os.path.dirname(self.poor_mans_load(namespace).__file__)
+        )
+        return self.get_modules(self.recurse(location), namespace=namespace, 
**kwds)
+
+    def get_modules(
+        self, feed, namespace=None, blacklist_func=None, 
ignore_failed_imports=None
+    ):
+        if ignore_failed_imports is None:
+            ignore_failed_imports = self.ignore_all_import_failures
+        if namespace is None:
+
+            def mangle(x):  # pyright: ignore[reportRedeclaration]
+                return x
+        else:
+            orig_namespace = namespace
+
+            def mangle(x):
+                return f"{orig_namespace}.{x}"
+
+        if blacklist_func is None:
+            blacklist_func = self._default_module_blacklister
+        for mod_name in feed:
+            try:
+                if mod_name is None:
+                    if namespace is None:
+                        continue
+                else:
+                    namespace = mangle(mod_name)
+                if blacklist_func(namespace):
+                    continue
+                yield self.poor_mans_load(namespace)
+            except ImportError:
+                if not ignore_failed_imports:
+                    raise
+
+    def recurse(self, location, valid_namespace=True):
+        if os.path.dirname(location) == "__pycache__":
+            # Shouldn't be possible, but make sure we avoid this if it manages
+            # to occur.
+            return
+        dirents = os.listdir(location)
+        if not self.valid_inits.intersection(dirents):
+            if valid_namespace:
+                return
+        else:
+            yield None
+
+        stats: list[tuple[str, int]] = []
+        for x in dirents:
+            try:
+                stats.append((x, os.stat(os.path.join(location, x)).st_mode))
+            except OSError as exc:
+                if exc.errno != errno.ENOENT:
+                    raise
+                # file disappeared under our feet... lock file from
+                # trial can cause this.  ignore.
+                import logging
+
+                logging.debug(
+                    "file %r disappeared under our feet, ignoring",
+                    os.path.join(location, x),
+                )
+
+        seen = set(["__init__"])
+        for x, st in stats:
+            if not (x.startswith(".") or x.endswith("~")) and stat.S_ISREG(st):
+                if x.endswith((".py", ".pyc", ".pyo", ".so")):
+                    y = x.rsplit(".", 1)[0]
+                    # Ensure we're not looking at a >=py3k .so which injects
+                    # the version name in...
+                    if y not in seen:
+                        if "." in y and x.endswith(".so"):
+                            y, abi = x.rsplit(".", 1)
+                            if abi != self.abi_target:
+                                continue
+                        seen.add(y)
+                        yield y
+
+        for x, st in stats:
+            if stat.S_ISDIR(st):
+                for y in self.recurse(os.path.join(location, x)):
+                    if y is None:
+                        yield x
+                    else:
+                        yield f"{x}.{y}"
+
+    @staticmethod
+    def poor_mans_load(namespace, existence_check=False):
+        try:
+            obj = __import__(namespace)
+            if existence_check:
+                return True
+        except:
+            if existence_check:
+                return False
+            raise
+        for chunk in namespace.split(".")[1:]:
+            try:
+                obj = getattr(obj, chunk)
+            except IGNORED_EXCEPTIONS:
+                raise
+            except AttributeError:
+                raise AssertionError(f"failed importing target {namespace}")
+            except Exception as e:
+                raise AssertionError(f"failed importing target {namespace}; 
error {e}")
+        return obj

diff --git a/src/snakeoil/test/modules.py b/src/snakeoil/test/modules.py
index 3758960..9983300 100644
--- a/src/snakeoil/test/modules.py
+++ b/src/snakeoil/test/modules.py
@@ -1,7 +1,7 @@
-from . import mixins
+from snakeoil.test.mixins import PythonNamespaceWalker
 
 
-class ExportedModules(mixins.PythonNamespaceWalker):
+class ExportedModules(PythonNamespaceWalker):
     target_namespace = "snakeoil"
 
     def test__all__accuracy(self):

diff --git a/src/snakeoil/test/slot_shadowing.py 
b/src/snakeoil/test/slot_shadowing.py
index ac17e95..71d120f 100644
--- a/src/snakeoil/test/slot_shadowing.py
+++ b/src/snakeoil/test/slot_shadowing.py
@@ -5,7 +5,7 @@ import warnings
 
 import pytest
 
-from snakeoil.python_namespaces import PythonNamespaceWalker
+from snakeoil.test.mixins import PythonNamespaceWalker
 
 
 class TargetedNamespaceWalker(PythonNamespaceWalker):

diff --git a/tests/test_demandload_usage.py b/tests/test_demandload_usage.py
index 437094c..1c4b25d 100644
--- a/tests/test_demandload_usage.py
+++ b/tests/test_demandload_usage.py
@@ -1,8 +1,9 @@
 import pytest
-from snakeoil.test import mixins
 
+from snakeoil.test.mixins import PythonNamespaceWalker
 
-class TestDemandLoadTargets(mixins.PythonNamespaceWalker):
+
+class TestDemandLoadTargets(PythonNamespaceWalker):
     target_namespace = "snakeoil"
     ignore_all_import_failures = False
 

diff --git a/tests/test_python_namespaces.py b/tests/test_python_namespaces.py
new file mode 100644
index 0000000..0614e69
--- /dev/null
+++ b/tests/test_python_namespaces.py
@@ -0,0 +1,108 @@
+import importlib
+import pathlib
+import sys
+from contextlib import contextmanager
+
+import pytest
+
+from snakeoil.python_namespaces import (
+    get_submodules_of,
+    import_submodules_of,
+)
+
+
+class TestNamespaceCollector:
+    def write_tree(self, base: pathlib.Path, *paths: str | pathlib.Path):
+        base.mkdir(exist_ok=True)
+        for path in sorted(paths):
+            path = base / pathlib.Path(path)
+            if not path.parent.exists():
+                path.parent.mkdir(parents=True)
+            path.touch()
+
+    @contextmanager
+    def protect_modules(self, base):
+        python_path = sys.path[:]
+        modules = sys.modules.copy()
+        try:
+            sys.path.append(str(base))
+            importlib.invalidate_caches()
+            yield (modules.copy())
+        finally:
+            sys.path[:] = python_path
+            sys.modules = modules
+            importlib.invalidate_caches()
+
+    def test_it(self, tmp_path):
+        self.write_tree(
+            tmp_path,
+            "_ns_test/__init__.py",
+            "_ns_test/blah.py",
+            "_ns_test/ignored",
+            "_ns_test/real/__init__.py",
+            "_ns_test/real/extra.py",
+        )
+
+        def get_it(target, *args, **kwargs):
+            target = importlib.import_module(target)
+            return list(
+                sorted(x.__name__ for x in get_submodules_of(target, *args, 
**kwargs))
+            )
+
+        with self.protect_modules(tmp_path):
+            assert ["_ns_test.blah", "_ns_test.real", "_ns_test.real.extra"] 
== get_it(
+                "_ns_test"
+            )
+            assert "_ns_test" in sys.modules
+
+            assert ["_ns_test.blah", "_ns_test.real"] == get_it(
+                "_ns_test", dont_import=["_ns_test.real.extra"]
+            )
+            assert ["_ns_test.blah"] == get_it(
+                "_ns_test", dont_import="_ns_test.real".__eq__
+            ), (
+                "dont_import filter failed to prevent scanning a submodule and 
it's children"
+            )
+            assert ["_ns_test.blah", "_ns_test.real"], get_it(
+                "_ns_test", dont_import="_ns_test.real.extra".__eq__
+            )
+
+    def test_load(self, tmp_path):
+        self.write_tree(
+            pathlib.Path(tmp_path) / "_ns_test",
+            "__init__.py",
+            "blah.py",
+            "foon.py",
+            "extra.py",
+        )
+        with self.protect_modules(tmp_path):
+            assert None is 
import_submodules_of(importlib.import_module("_ns_test"))
+            assert set(["_ns_test.blah", "_ns_test.foon", "_ns_test.extra"]) 
== set(
+                x for x in sys.modules if x.startswith("_ns_test.")
+            )
+
+    def test_import_failures(self, tmp_path):
+        base = pathlib.Path(tmp_path) / "_ns_test"
+        self.write_tree(base, "__init__.py", "blah.py")
+
+        with (base / "bad1.py").open("w") as f:
+            f.write("raise ImportError('bad1')")
+        with (base / "bad2.py").open("w") as f:
+            f.write("raise ImportError('bad2')")
+
+        with self.protect_modules(tmp_path):
+            mod = importlib.import_module("_ns_test")
+            with pytest.raises(ImportError) as capture:
+                import_submodules_of(mod, 
ignore_import_failures=["_ns_test.bad2"])
+            assert ("bad1",) == tuple(capture.value.args)
+
+            with pytest.raises(ImportError) as capture:
+                import_submodules_of(mod, 
ignore_import_failures=["_ns_test.bad1"])
+            assert ("bad2",) == tuple(capture.value.args)
+
+            with pytest.raises(ImportError):
+                import_submodules_of(mod, ignore_import_failures=False)
+
+            assert ["_ns_test.blah"] == [
+                x.__name__ for x in get_submodules_of(mod, 
ignore_import_failures=True)
+            ]

Reply via email to