commit: 45f2a895b9eac1fe409cd204865aaaf6b484ba9c
Author: Brian Harring <ferringb <AT> gmail <DOT> com>
AuthorDate: Sat Nov 29 19:00:22 2025 +0000
Commit: Brian Harring <ferringb <AT> gmail <DOT> com>
CommitDate: Sat Nov 29 22:40:39 2025 +0000
URL:
https://gitweb.gentoo.org/proj/pkgcore/snakeoil.git/commit/?id=45f2a895
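get_submodules_of: add include_root option for simpler common usage

Add an `include_root` parameter (default False) so callers can have the
root module yielded along with its submodules, and use it in
code_quality's collect_modules. Also fix a test assertion that used a
comma (`assert [...], get_it(...)`) and therefore always passed; it now
compares with `==`.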
Signed-off-by: Brian Harring <ferringb <AT> gmail.com>
src/snakeoil/python_namespaces.py | 3 ++-
src/snakeoil/test/code_quality.py | 4 +++-
tests/test_python_namespaces.py | 6 +++++-
3 files changed, 10 insertions(+), 3 deletions(-)
diff --git a/src/snakeoil/python_namespaces.py
b/src/snakeoil/python_namespaces.py
index c4a050b..a926f5c 100644
--- a/src/snakeoil/python_namespaces.py
+++ b/src/snakeoil/python_namespaces.py
@@ -15,6 +15,7 @@ def get_submodules_of(
/,
dont_import: T_class_filter | typing.Container[str] | None = None,
ignore_import_failures: T_class_filter | typing.Container[str] | bool =
False,
+ include_root=False,
) -> typing.Iterable[types.ModuleType]:
"""Visit all submodules of the target via walking the underlying filesystem
@@ -46,7 +47,7 @@ def get_submodules_of(
f"module {current!r} lacks __file__ attribute. If this is a
PEP420 namespace module, that is unsupported currently"
)
- if current is not root:
+ if current is not root or include_root:
yield current
base = pathlib.Path(os.path.abspath(current.__file__))
# if it's not the root of a module, there's nothing to do- return it..
diff --git a/src/snakeoil/test/code_quality.py
b/src/snakeoil/test/code_quality.py
index 8fb7f4c..f6e2397 100644
--- a/src/snakeoil/test/code_quality.py
+++ b/src/snakeoil/test/code_quality.py
@@ -86,7 +86,9 @@ class ParameterizeBase(typing.Generic[T], abc.ABC):
def collect_modules(cls) -> typing.Iterable[ModuleType]:
for namespace in cls.namespaces:
yield from get_submodules_of(
- __import__(namespace), dont_import=cls.namespace_ignores
+ __import__(namespace),
+ dont_import=cls.namespace_ignores,
+ include_root=True,
)
diff --git a/tests/test_python_namespaces.py b/tests/test_python_namespaces.py
index 0614e69..c06037e 100644
--- a/tests/test_python_namespaces.py
+++ b/tests/test_python_namespaces.py
@@ -63,10 +63,14 @@ class TestNamespaceCollector:
), (
"dont_import filter failed to prevent scanning a submodule and
it's children"
)
- assert ["_ns_test.blah", "_ns_test.real"], get_it(
+ assert ["_ns_test.blah", "_ns_test.real"] == get_it(
"_ns_test", dont_import="_ns_test.real.extra".__eq__
)
+ assert ["_ns_test.real", "_ns_test.real.extra"] == get_it(
+ "_ns_test.real", include_root=True
+ )
+
def test_load(self, tmp_path):
self.write_tree(
pathlib.Path(tmp_path) / "_ns_test",