commit:     45f2a895b9eac1fe409cd204865aaaf6b484ba9c
Author:     Brian Harring <ferringb <AT> gmail <DOT> com>
AuthorDate: Sat Nov 29 19:00:22 2025 +0000
Commit:     Brian Harring <ferringb <AT> gmail <DOT> com>
CommitDate: Sat Nov 29 22:40:39 2025 +0000
URL:        https://gitweb.gentoo.org/proj/pkgcore/snakeoil.git/commit/?id=45f2a895

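tweak get_submodules_of to be simpler for common usage: add an include_root option so the root module itself can be yielded, and use it in code_quality.collect_modules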

Signed-off-by: Brian Harring <ferringb <AT> gmail <DOT> com>

 src/snakeoil/python_namespaces.py | 3 ++-
 src/snakeoil/test/code_quality.py | 4 +++-
 tests/test_python_namespaces.py   | 6 +++++-
 3 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/src/snakeoil/python_namespaces.py b/src/snakeoil/python_namespaces.py
index c4a050b..a926f5c 100644
--- a/src/snakeoil/python_namespaces.py
+++ b/src/snakeoil/python_namespaces.py
@@ -15,6 +15,7 @@ def get_submodules_of(
     /,
     dont_import: T_class_filter | typing.Container[str] | None = None,
     ignore_import_failures: T_class_filter | typing.Container[str] | bool = 
False,
+    include_root=False,
 ) -> typing.Iterable[types.ModuleType]:
     """Visit all submodules of the target via walking the underlying filesystem
 
@@ -46,7 +47,7 @@ def get_submodules_of(
                 f"module {current!r} lacks __file__ attribute.  If this is a 
PEP420 namespace module, that is unsupported currently"
             )
 
-        if current is not root:
+        if current is not root or include_root:
             yield current
         base = pathlib.Path(os.path.abspath(current.__file__))
         # if it's not the root of a module, there's nothing to do- return it..

diff --git a/src/snakeoil/test/code_quality.py b/src/snakeoil/test/code_quality.py
index 8fb7f4c..f6e2397 100644
--- a/src/snakeoil/test/code_quality.py
+++ b/src/snakeoil/test/code_quality.py
@@ -86,7 +86,9 @@ class ParameterizeBase(typing.Generic[T], abc.ABC):
     def collect_modules(cls) -> typing.Iterable[ModuleType]:
         for namespace in cls.namespaces:
             yield from get_submodules_of(
-                __import__(namespace), dont_import=cls.namespace_ignores
+                __import__(namespace),
+                dont_import=cls.namespace_ignores,
+                include_root=True,
             )
 
 

diff --git a/tests/test_python_namespaces.py b/tests/test_python_namespaces.py
index 0614e69..c06037e 100644
--- a/tests/test_python_namespaces.py
+++ b/tests/test_python_namespaces.py
@@ -63,10 +63,14 @@ class TestNamespaceCollector:
             ), (
                 "dont_import filter failed to prevent scanning a submodule and 
it's children"
             )
-            assert ["_ns_test.blah", "_ns_test.real"], get_it(
+            assert ["_ns_test.blah", "_ns_test.real"] == get_it(
                 "_ns_test", dont_import="_ns_test.real.extra".__eq__
             )
 
+            assert ["_ns_test.real", "_ns_test.real.extra"] == get_it(
+                "_ns_test.real", include_root=True
+            )
+
     def test_load(self, tmp_path):
         self.write_tree(
             pathlib.Path(tmp_path) / "_ns_test",

Reply via email to