commit: e7c724bbaad864ee81301f3be4de9c353f85f668
Author: Brian Harring <ferringb <AT> gmail <DOT> com>
AuthorDate: Wed Dec 3 17:27:47 2025 +0000
Commit: Brian Harring <ferringb <AT> gmail <DOT> com>
CommitDate: Wed Dec 3 19:52:33 2025 +0000
URL:
https://gitweb.gentoo.org/proj/pkgcore/snakeoil.git/commit/?id=e7c724bb
feat(python_namespaces): add import_module_from_path for arbitrary imports, and
protect_imports for bad ideas.
protect_imports: Use this to protect and unwind changes to sys.modules
and sys.path. It is explicitly not thread safe because it's impossible
for it to be thread safe. Generally this should only be used in tests,
but there are special cases where temporarily importing a secondary tree
has value. For example, to grab a submodule within a tree (just that submodule)
and then discard everything but it, ensuring that there is no conflicting
intersection from that tree. Again, useful for tests, but for plugin
architectures this has a fair amount of value.
import_module_from_path: use this for code that must import python source
that isn't part of a normal python tree, or has strong reasons
to ensure it doesn't get integrated into sys.modules. It's an edge
case.
Signed-off-by: Brian Harring <ferringb <AT> gmail.com>
src/snakeoil/python_namespaces.py | 57 +++++++++++++++++-
tests/test_python_namespaces.py | 120 ++++++++++++++++++++++++++++++++++----
2 files changed, 165 insertions(+), 12 deletions(-)
diff --git a/src/snakeoil/python_namespaces.py
b/src/snakeoil/python_namespaces.py
index 244a17d..931cded 100644
--- a/src/snakeoil/python_namespaces.py
+++ b/src/snakeoil/python_namespaces.py
@@ -1,9 +1,12 @@
__all__ = ("import_submodules_of", "get_submodules_of")
+import contextlib
import os
+import sys
import types
import typing
-from importlib import import_module, machinery
+from importlib import import_module, invalidate_caches, machinery
+from importlib import util as import_util
from pathlib import Path
T_class_filter = typing.Callable[[str], bool]
@@ -103,3 +106,55 @@ def remove_py_extension(path: Path | str) -> str | None:
if name.endswith(ext):
return name[: -len(ext)]
return None
+
+
+@contextlib.contextmanager
+def protect_imports() -> typing.Generator[
+    tuple[list[str], dict[str, types.ModuleType]], None, None
+]:
+    """
+    Non threadsafe mock.patch of sys.path and sys.modules to allow reverting imports.
+
+    Yields the live ``sys.path`` list and ``sys.modules`` dict (the objects
+    themselves, not copies).  On exit- including exit via an exception raised
+    inside the ``with`` body- both are restored in place to their content at
+    entry, and ``importlib.invalidate_caches()`` is called.
+
+    This should be used in tests or very select scenarios.  Assume that
+    underlying c extensions that hold internal static state (curses, for
+    example) will reimport, but will not be 'clean'.  Any changes an import
+    inflicts on the other modules in memory, etc, this cannot block that.
+    Nor is this intended to do so; it's for controlled tests or very specific
+    usages.
+    """
+    orig_path = sys.path[:]
+    orig_modules = sys.modules.copy()
+    try:
+        yield sys.path, sys.modules
+    finally:
+        # Restore in place rather than rebinding: swapping out the sys.path /
+        # sys.modules objects themselves can result in unexpected behavior.
+        sys.path[:] = orig_path
+        # This is explicitly not thread safe, but manipulating sys.path
+        # fundamentally isn't, thus this context isn't thread safe.  TL;dr: nuke
+        # it, and restore, it's the only way to be sure (to paraphrase).
+        sys.modules.clear()
+        sys.modules.update(orig_modules)
+        # Out of paranoia, force loaders to reset their caches.
+        invalidate_caches()
+
+
+def import_module_from_path(
+    path: str | Path, module_name: str | None = None
+) -> types.ModuleType:
+    """Load and return a module from a file path, without needing a package.
+
+    The module's code is executed, but this function does not add the module
+    to sys.modules.
+
+    :param path: the path to load. No python package structure will be
+        inferred from this. Currently it must end in a python extension.
+    :param module_name: If given, this is __name__ within the module. If not
+        given (None or empty), it is inferred from path via
+        remove_py_extension, i.e. with the python extension stripped.
+    :raises ValueError: if path does not end in a python extension. This is
+        raised even when module_name is given.
+    :raises ImportError: if no import spec or loader can be created for path.
+    :return: the executed module. Any exception raised while executing the
+        module's code (SyntaxError, etc) propagates to the caller.
+    """
+    if (default_module_name := remove_py_extension(path)) is None:
+        raise ValueError(f"{path} must end in a valid python extension like .py")
+
+    # An empty __name__ is never meaningful; treat it the same as not given.
+    module_name = module_name or default_module_name
+
+    spec = import_util.spec_from_file_location(module_name, path)
+    if spec is None or spec.loader is None:
+        raise ImportError(f"Cannot create import spec for {path}")
+
+    module = import_util.module_from_spec(spec)
+    spec.loader.exec_module(module)
+    return module
diff --git a/tests/test_python_namespaces.py b/tests/test_python_namespaces.py
index 2ad1bdc..ca43f72 100644
--- a/tests/test_python_namespaces.py
+++ b/tests/test_python_namespaces.py
@@ -1,26 +1,32 @@
+import contextlib
import pathlib
import sys
+import types
from contextlib import contextmanager
from importlib import import_module, invalidate_caches, machinery
+from typing import Any, NamedTuple
import pytest
from snakeoil.python_namespaces import (
get_submodules_of,
+ import_module_from_path,
import_submodules_of,
+ protect_imports,
remove_py_extension,
)
-class test_python_namespaces:
- def write_tree(self, base: pathlib.Path, *paths: str | pathlib.Path):
- base.mkdir(exist_ok=True)
- for path in sorted(paths):
- path = base / pathlib.Path(path)
- if not path.parent.exists():
- path.parent.mkdir(parents=True)
- path.touch()
+def write_tree(base: pathlib.Path, *paths: str | pathlib.Path):
+    """Create an empty file for each of `paths`, relative to `base`.
+
+    `base` itself is created if missing (its parent must already exist), as
+    are any intermediate directories of each path.  Existing files are
+    touched, not truncated.
+    """
+    base.mkdir(exist_ok=True)
+    for path in sorted(paths):
+        path = base / pathlib.Path(path)
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.touch()
+
+class test_python_namespaces:
@contextmanager
def protect_modules(self, base):
python_path = sys.path[:]
@@ -35,7 +41,7 @@ class test_python_namespaces:
invalidate_caches()
def test_it(self, tmp_path):
- self.write_tree(
+ write_tree(
tmp_path,
"_ns_test/__init__.py",
"_ns_test/blah.py",
@@ -73,7 +79,7 @@ class test_python_namespaces:
)
def test_load(self, tmp_path):
- self.write_tree(
+ write_tree(
pathlib.Path(tmp_path) / "_ns_test",
"__init__.py",
"blah.py",
@@ -88,7 +94,7 @@ class test_python_namespaces:
def test_import_failures(self, tmp_path):
base = pathlib.Path(tmp_path) / "_ns_test"
- self.write_tree(base, "__init__.py", "blah.py")
+ write_tree(base, "__init__.py", "blah.py")
with (base / "bad1.py").open("w") as f:
f.write("raise ImportError('bad1')")
@@ -127,3 +133,95 @@ def test_remove_py_extension():
)
assert None is remove_py_extension("asdf")
assert None is remove_py_extension("asdf.txt")
+
+
+@contextlib.contextmanager
+def assert_protect_modules():
+    """Wrap protect_imports, asserting it never swaps and always restores state.
+
+    Yields the (sys.path, sys.modules) pair from protect_imports.  On normal
+    exit, asserts both objects are unchanged in identity and content,
+    including that every original sys.modules entry maps to the same module
+    object as before.
+    """
+    # cpython has notes that swapping the object may result in unexpected
+    # behavior- thus assert we don't.
+    orig_modules_content = (orig_modules := sys.modules).copy()
+    orig_path_content = (orig_path := sys.path)[:]
+    with protect_imports() as (path, modules):
+        assert orig_path is sys.path
+        assert orig_path is path
+        assert orig_path_content == path
+        assert orig_modules is sys.modules
+        assert orig_modules is modules
+        assert orig_modules_content == modules
+        yield (path, modules)
+
+    assert orig_modules is sys.modules, "sys.modules isn't the same object"
+    mismatched = set(orig_modules_content).symmetric_difference(sys.modules)
+    assert not mismatched, (
+        f"sys.modules wasn't reset to its original content: {sorted(mismatched)}"
+    )
+    # Same keys isn't enough; a module re-imported inside the block must have
+    # been swapped back to the original object.
+    replaced = sorted(
+        name
+        for name, module in orig_modules_content.items()
+        if sys.modules[name] is not module
+    )
+    assert not replaced, f"sys.modules entries weren't restored: {replaced}"
+    assert orig_path is sys.path, "sys.path isn't the same object"
+    assert orig_path_content == sys.path, (
+        "sys.path content wasn't reset to its original content"
+    )
+
+
+def test_protect_imports(tmp_path):
+    """Importing from a temporarily added sys.path entry must not leak out."""
+    module_file = tmp_path / "_must_not_exist.py"
+    module_file.touch()
+    name = module_file.stem
+    with assert_protect_modules() as (path, modules):
+        # Confirm the name doesn't already resolve via some other path entry.
+        with pytest.raises(ModuleNotFoundError):
+            import_module(name)
+
+        # Mutating through the yielded objects doubles as validation that
+        # assert_protect_modules hands back the live sys.path / sys.modules.
+        path.append(str(tmp_path))
+        assert sys.path[-1] == str(tmp_path)
+        import_module(name)
+        assert name in modules
+        assert name in sys.modules
+
+
+class ShouldBeReachedOnlyInSuccess(Exception):
+    """Sentinel raised at the very end of a test case's success path."""
+
+
+class params(NamedTuple):
+    """One parametrized case for import_module_from_path tests."""
+
+    # file name, created under the test's tmp_path
+    name: str
+    # module_name argument, and the expected module __name__ on success
+    module_name: str = ""
+    # exception the case must raise; the default marks the success path
+    throws: type[Exception] = ShouldBeReachedOnlyInSuccess
+    # python source written into the file
+    content: str = ""
+    # attributes expected on the loaded module.  NamedTuple defaults are shared
+    # by every instance, so use a read-only mapping rather than a mutable {}.
+    attrs: dict[str, Any] | types.MappingProxyType[str, Any] = (
+        types.MappingProxyType({})
+    )
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        params("blah.py", "blah"),
+        params("asdf", throws=ValueError),
+        # enforce override
+        params("blah.py", module_name="asdf"),
+        # no module_name given; it must be inferred from the path
+        params("inferred.py"),
+        params(
+            "blah.py",
+            throws=DeprecationWarning,
+            content="raise DeprecationWarning()",
+        ),
+        params(
+            "foon.py",
+            "foon",
+            content='x="value";y=2',
+            attrs=dict(x="value", y=2),
+        ),
+        # basic validation of pass through of underlying python machinery failure
+        params("foon.py", "foon", content="fda=", throws=SyntaxError),
+    ],
+)
+def test_import_module_from_path(tmp_path, config):
+    p = tmp_path / config.name
+    p.write_text(config.content)
+
+    # An empty config.module_name means "not given": pass None so the name
+    # inference path is exercised, and expect whatever remove_py_extension
+    # derives from the path.
+    module_name = config.module_name or None
+    expected_name = config.module_name or remove_py_extension(p)
+
+    with assert_protect_modules():
+        # There's a trick here: either the expected exception gets thrown, or
+        # the success path finishes by throwing the default sentinel exception,
+        # which satisfies pytest.raises.  A case expecting a different exception
+        # gets the sentinel instead and thus fails the assertion.
+        with pytest.raises(config.throws):
+            module = import_module_from_path(p, module_name)
+            assert isinstance(module, types.ModuleType)
+            assert expected_name == module.__name__
+            assert str(p) == module.__file__
+            for k, v in config.attrs.items():
+                # Let the AttributeError fly- if it occurs, the test is faulty.
+                assert v == getattr(module, k)
+            raise ShouldBeReachedOnlyInSuccess()