commit:     7d1315589f5a5cc48ba6087a6ca776bec75004f4
Author:     Brian Harring <ferringb <AT> gmail <DOT> com>
AuthorDate: Sat Nov 29 15:30:53 2025 +0000
Commit:     Brian Harring <ferringb <AT> gmail <DOT> com>
CommitDate: Sat Nov 29 18:32:47 2025 +0000
URL:        https://gitweb.gentoo.org/proj/pkgcore/snakeoil.git/commit/?id=7d131558

feat(test): make test.protect_process usable

This previously worked only ad hoc, for the formatters tests, due to a quirk
of snakeoil's module layout.  Since the code uses pytest, instead exec the
fork as a pytest invocation of the wrapped test.

This is now generally usable, although some care is warranted when the
decorator is applied to a function defined within a test (rather than to
the test itself): that entire test scope effectively runs in both the
forked child and the parent.  See 'test_failure' for an example of how to
handle this.

Signed-off-by: Brian Harring <ferringb <AT> gmail.com>

 src/snakeoil/klass/__init__.py        |  2 +
 src/snakeoil/klass/util.py            |  9 ++--
 src/snakeoil/test/__init__.py         | 86 +++++++++++++++++++++--------------
 src/snakeoil/test/argparse_helpers.py |  8 ++++
 tests/klass/test_util.py              | 18 ++++----
 tests/test_code_quality.py            |  7 +++
 tests/test_formatters.py              | 14 +++---
 tests/test_test.py                    | 38 ++++++++++++++++
 8 files changed, 128 insertions(+), 54 deletions(-)

diff --git a/src/snakeoil/klass/__init__.py b/src/snakeoil/klass/__init__.py
index 1e5516a..2295827 100644
--- a/src/snakeoil/klass/__init__.py
+++ b/src/snakeoil/klass/__init__.py
@@ -23,6 +23,7 @@ __all__ = (
     "cached_property_named",
     "copy_docs",
     "steal_docs",
+    "is_metaclass",
     "ImmutableInstance",
     "immutable_instance",
     "inject_immutable_instance",
@@ -68,6 +69,7 @@ from .util import (
     get_slot_of,
     get_slots_of,
     get_subclasses_of,
+    is_metaclass,
 )
 
 sentinel = object()

diff --git a/src/snakeoil/klass/util.py b/src/snakeoil/klass/util.py
index fa2876c..4d6aca3 100644
--- a/src/snakeoil/klass/util.py
+++ b/src/snakeoil/klass/util.py
@@ -90,6 +90,11 @@ def get_attrs_of(
                     seen.add(slot)
 
 
+def is_metaclass(cls: type) -> typing.TypeGuard[type[type]]:
+    """Discern if ``cls`` is a metaclass (a subclass of ``type``); function based metaclasses are intentionally ignored."""
+    return issubclass(cls, type)
+
+
 def get_subclasses_of(
     cls: type,
     only_leaf_nodes=False,
@@ -133,10 +138,6 @@ def get_subclasses_of(
             yield current
 
 
-def is_metaclass(cls: type) -> typing.TypeGuard[type[type]]:
-    return issubclass(cls, type)
-
-
 @functools.lru_cache
 def combine_classes(kls: type, *extra: type) -> type:
     """Given a set of classes, combine this as if one had wrote the class by hand

diff --git a/src/snakeoil/test/__init__.py b/src/snakeoil/test/__init__.py
index 8e97367..cbfead8 100644
--- a/src/snakeoil/test/__init__.py
+++ b/src/snakeoil/test/__init__.py
@@ -1,5 +1,16 @@
 """Our unittest extensions."""
 
+__all__ = (
+    "coverage",
+    "hide_imports",
+    "Modules",
+    "ParameterizeBase",
+    "protect_process",
+    "random_str",
+    "Slots",
+)
+
+import functools
 import os
 import random
 import string
@@ -7,8 +18,7 @@ import subprocess
 import sys
 from unittest.mock import patch
 
-# not relative imports so protect_process() works properly
-from snakeoil import klass
+from .code_quality import Modules, ParameterizeBase, Slots
 
 
 def random_str(length):
@@ -35,45 +45,73 @@ def coverage():
     return cov


-_PROTECT_ENV_VAR = "SNAKEOIL_UNITTEST_PROTECT_PROCESS"
-
-
-def protect_process(functor, name=None):
-    def _inner_run(self, name=name):
-        if os.environ.get(_PROTECT_ENV_VAR, False):
-            return functor(self)
-        if name is None:
-            name = (
-                f"{self.__class__.__module__}.{self.__class__.__name__}.{method_name}"
+def protect_process(
+    forced_test: None | str = None,
+    marker_env_var="SNAKEOIL_UNITTEST_PROTECT_PROCESS",
+    extra_env: dict[str, str] | None = None,
+):
+    """Decorator running the wrapped test method in a separate pytest process.
+
+    In the parent, ``python -m pytest -v <test>`` is spawned for the current
+    test; in that child (detected via ``marker_env_var``) the wrapped function
+    runs directly.  The parent fails with an AssertionError carrying the
+    child's captured output if the child exits non-zero.
+
+    :param forced_test: pytest node id to run instead of the one taken from
+        ``PYTEST_CURRENT_TEST``; the ``PYTEST_CURRENT_TEST`` form with its
+        trailing " (call)" is accepted as well.
+    :param marker_env_var: environment variable that marks the child process.
+    :param extra_env: additional environment variables for the child process.
+    """
+    def wrapper(functor):
+        @functools.wraps(functor)
+        def _inner_run(self, *args, **kwargs):
+            # we're in the child.  Just run it.
+            if os.environ.get(marker_env_var, False):
+                return functor(self, *args, **kwargs)
+
+            # we're in the parent.  if capsys is in there, we have
+            # to intercept it for the code below.
+            capsys = kwargs.get("capsys")
+            env = os.environ.copy()
+            if extra_env:
+                env.update(extra_env)
+            env[marker_env_var] = "disable"
+            test = (
+                os.environ["PYTEST_CURRENT_TEST"]
+                if forced_test is None
+                else forced_test
             )
-        runner_path = __file__
-        if runner_path.endswith(".pyc") or runner_path.endswith(".pyo"):
-            runner_path = runner_path.rsplit(".", maxsplit=1)[0] + ".py"
-        wipe = _PROTECT_ENV_VAR not in os.environ
-        try:
-            os.environ[_PROTECT_ENV_VAR] = "yes"
-            args = [sys.executable, __file__, name]
+            # https://docs.pytest.org/en/latest/example/simple.html#pytest-current-test-environment-variable
+            if forced_test is None:
+                assert test.endswith(" (call)"), test
+            # the pytest command line takes the bare node id.
+            test = test.removesuffix(" (call)")
+            args = [sys.executable, "-m", "pytest", "-v", test]
             p = subprocess.Popen(
                 args,
-                shell=False,
-                env=os.environ.copy(),
-                stdout=subprocess.PIPE,
-                stderr=subprocess.STDOUT,
+                env=env,
+                stdout=None if capsys else subprocess.PIPE,
+                stderr=None if capsys else subprocess.PIPE,
+                # decode to str so the failure message below works for both
+                # piped output and capsys output.
+                text=True,
+                errors="replace",
             )
-            stdout, _ = p.communicate()
+            stdout, stderr = p.communicate()
+            if capsys:
+                # NOTE(review): capsys only sees Python-level sys.stdout/err;
+                # the child writes to the inherited fds, so capfd may be needed.
+                result = capsys.readouterr()
+                stdout, stderr = result.out, result.err
             ret = p.wait()
             assert ret == 0, (
-                f"subprocess run: {args!r}\nnon zero exit: {ret}\nstdout:\n{stdout}"
+                f"subprocess run: {args!r}\nnon zero exit: {ret}\nstdout:\n{stdout}\n\nstderr:\n{stderr}"
             )
-        finally:
-            if wipe:
-                os.environ.pop(_PROTECT_ENV_VAR, None)
-
-    for x in ("__doc__", "__name__"):
-        if hasattr(functor, x):
-            setattr(_inner_run, x, getattr(functor, x))
-    method_name = getattr(functor, "__name__", None)
-    return _inner_run
+
+        return _inner_run
+
+    return wrapper
 
 
 def hide_imports(*import_names: str):

diff --git a/src/snakeoil/test/argparse_helpers.py b/src/snakeoil/test/argparse_helpers.py
index bfcdc3b..7887f1c 100644
--- a/src/snakeoil/test/argparse_helpers.py
+++ b/src/snakeoil/test/argparse_helpers.py
@@ -1,3 +1,11 @@
+__all__ = (
+    "ArgParseMixin",
+    "Bold",
+    "Color",
+    "FakeStreamFormatter",
+    "mangle_parser",
+    "Reset",
+)
 import difflib
 from copy import copy
 

diff --git a/tests/klass/test_util.py b/tests/klass/test_util.py
index 2a66cf8..81990be 100644
--- a/tests/klass/test_util.py
+++ b/tests/klass/test_util.py
@@ -129,6 +129,15 @@ def test_combine_classes():
     assert [combined, kls1, kls2, type, object] == list(combined.__mro__)
 
 
+def test_is_metaclass():
+    assert not is_metaclass(object)
+    assert is_metaclass(type)
+
+    class foon(type): ...
+
+    assert is_metaclass(foon)
+
+
 def test_get_subclasses_of():
     attr = operator.attrgetter("__name__")
 
@@ -182,12 +191,3 @@ def test_get_subclasses_of():
     class combined(left, right): ...
 
     assert_it(base, [left, right, combined])
-
-
-def test_is_metaclass():
-    assert not is_metaclass(object)
-    assert is_metaclass(type)
-
-    class foon(type): ...
-
-    assert is_metaclass(foon)

diff --git a/tests/test_code_quality.py b/tests/test_code_quality.py
index 83b253d..98d3f95 100644
--- a/tests/test_code_quality.py
+++ b/tests/test_code_quality.py
@@ -28,3 +28,10 @@ class TestSlots(code_quality.Slots):
 
 class TestModules(code_quality.Modules):
     namespaces = ("snakeoil",)
+    namespace_ignores = (
+        # dead code only existing to not break versions of down stream
+        # packaging.  They'll be removed
+        "snakeoil.test.eq_hash_inheritance",
+        "snakeoil.test.mixins",
+        "snakeoil.test.slot_shadowing",
+    )

diff --git a/tests/test_formatters.py b/tests/test_formatters.py
index 2b7552b..78b4dee 100644
--- a/tests/test_formatters.py
+++ b/tests/test_formatters.py
@@ -169,7 +169,7 @@ class TerminfoFormatterTest:
             result,
         )
 
-    @pythonGHissue51816
+    @pythonGHissue51816()
     def test_terminfo(self):
         esc = "\x1b["
         stream = TemporaryFile()
@@ -217,7 +217,7 @@ class TerminfoFormatterTest:
         with pytest.raises(formatters.TerminfoUnsupported):
             formatters.TerminfoFormatter(stream, term="dumb")
 
-    @pythonGHissue51816
+    @pythonGHissue51816()
     def test_title(self):
         stream = TemporaryFile()
         try:
@@ -251,35 +251,35 @@ def _get_pty_pair(encoding="ascii"):
 
 
 class TestGetFormatter:
-    @pythonGHissue51816
+    @pythonGHissue51816()
     def test_dumb_terminal(self):
         master, _out = _get_pty_pair()
         with forced_term("dumb"):
             formatter = formatters.get_formatter(master)
             assert isinstance(formatter, formatters.PlainTextFormatter)
 
-    @pythonGHissue51816
+    @pythonGHissue51816()
     def test_vt100_terminal(self):
         master, _out = _get_pty_pair()
         with forced_term("vt100"):
             formatter = formatters.get_formatter(master)
             assert isinstance(formatter, formatters.PlainTextFormatter)
 
-    @pythonGHissue51816
+    @pythonGHissue51816()
     def test_smart_terminal(self):
         master, _out = _get_pty_pair()
         with forced_term("xterm"):
             formatter = formatters.get_formatter(master)
             assert isinstance(formatter, formatters.TerminfoFormatter)
 
-    @pythonGHissue51816
+    @pythonGHissue51816()
     def test_not_a_tty(self):
         with TemporaryFile() as stream:
             with forced_term("xterm"):
                 formatter = formatters.get_formatter(stream)
                 assert isinstance(formatter, formatters.PlainTextFormatter)
 
-    @pythonGHissue51816
+    @pythonGHissue51816()
     def test_no_fd(self):
         stream = BytesIO()
         with forced_term("xterm"):

diff --git a/tests/test_test.py b/tests/test_test.py
new file mode 100644
index 0000000..f325d66
--- /dev/null
+++ b/tests/test_test.py
@@ -0,0 +1,45 @@
+import os
+
+import pytest
+
+from snakeoil import test
+
+
+class Test_protect_process:
+    """Tests for ``snakeoil.test.protect_process``.
+
+    The bodies of these tests run twice: once in the parent and once in the
+    pytest child process that ``protect_process`` spawns for the same test.
+    """
+
+    def test_success(self, capsys):
+        @test.protect_process()
+        def no_fail(self) -> None:
+            pass
+
+        assert None is no_fail(self)
+        # no capsys kwarg was passed, so the child's output went to pipes.
+        captured = capsys.readouterr()
+        assert "" == captured.out, (
+            f"no stdout should be captured for success: {captured.out}"
+        )
+        assert "" == captured.err, (
+            f"no stderr should be captured for success: {captured.err}"
+        )
+
+    def test_failure(self, capsys):
+        unique_string = "asdfasdfasdfasdfdasdfdasdfasdfasdfasdfasdfasdfasdfsadf"
+
+        @test.protect_process(extra_env={unique_string: unique_string})
+        def fail(self, capsys) -> None:
+            raise AssertionError(unique_string)
+
+        if os.environ.get(unique_string):
+            # we're in the child.
+            fail(self, capsys)
+            raise Exception("implementation is broke, fail didn't throw an exception")
+
+        with pytest.raises(AssertionError) as failed:
+            fail(self, capsys)
+
+        assert unique_string in str(failed.value)

Reply via email to