This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b45f071cd1b6 [SPARK-54868][PYTHON][INFRA][FOLLOW-UP] Also enable `faulthandler` in classic tests
b45f071cd1b6 is described below

commit b45f071cd1b653731ae5b1c97c77d1f47560b177
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Dec 31 13:20:41 2025 +0800

    [SPARK-54868][PYTHON][INFRA][FOLLOW-UP] Also enable `faulthandler` in classic tests
    
    ### What changes were proposed in this pull request?
    Also enable `faulthandler` in classic tests by introducing a base class, `PySparkBaseTestCase`.
    
    ### Why are the changes needed?
    `faulthandler` was previously enabled only in `ReusedConnectTestCase` for Spark Connect.

    After this change, `faulthandler` will be enabled in most classic tests as well. (There are still some tests that use `unittest.TestCase` directly; we can migrate them when they hit hanging issues.)
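
    For illustration, a classic test picks up the handler simply by inheriting the new base class (directly, or via `ReusedPySparkTestCase`). A minimal sketch; the test class and assertion below are hypothetical, and only `PySparkBaseTestCase` comes from this patch:
    ```
    import unittest

    from pyspark.testing.utils import PySparkBaseTestCase


    class MyClassicTest(PySparkBaseTestCase):  # hypothetical; was unittest.TestCase
        def test_something(self):
            # With PYSPARK_TEST_TIMEOUT set, sending SIGTERM to a hanging run
            # now dumps all thread tracebacks via `faulthandler`.
            self.assertTrue(True)


    if __name__ == "__main__":
        unittest.main()
    ```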
    
    ### Does this PR introduce _any_ user-facing change?
    no, test-only
    
    ### How was this patch tested?
    CI, and manually checked with
    ```
    PYSPARK_TEST_TIMEOUT=10 python/run-tests -k --python-executables python3 --testnames 'pyspark.tests.test_util'
    ```
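
    For context, here is a minimal standalone sketch (not part of this patch) of the underlying mechanism: once `faulthandler` is registered for `SIGTERM`, delivering that signal to a hung process prints every thread's traceback to stderr (the handler itself does not terminate the process):
    ```
    import faulthandler
    import os
    import signal
    import sys
    import time

    # Mirrors what PySparkBaseTestCase.setUpClass does, guarded by the env var.
    if os.environ.get("PYSPARK_TEST_TIMEOUT"):
        faulthandler.register(signal.SIGTERM, file=sys.__stderr__, all_threads=True)

    # Simulate a hang; `kill -TERM <pid>` now dumps all thread tracebacks.
    time.sleep(3600)
    ```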
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #53651 from zhengruifeng/super_class.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/testing/connectutils.py | 21 +++++++++------------
 python/pyspark/testing/sqlutils.py     |  9 ++++++---
 python/pyspark/testing/utils.py        | 24 +++++++++++++++++++++---
 3 files changed, 36 insertions(+), 18 deletions(-)

diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 63dc350dd011..08c86561b57a 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -17,9 +17,6 @@
 import shutil
 import tempfile
 import os
-import sys
-import signal
-import faulthandler
 import functools
 import unittest
 import uuid
@@ -45,6 +42,7 @@ from pyspark.testing.utils import (
     should_test_connect,
     PySparkErrorTestUtils,
 )
+from pyspark.testing.utils import PySparkBaseTestCase
 from pyspark.testing.sqlutils import SQLTestUtils
 from pyspark.sql.session import SparkSession as PySparkSession
 
@@ -75,7 +73,7 @@ class MockRemoteSession:
 
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
-class PlanOnlyTestFixture(unittest.TestCase, PySparkErrorTestUtils):
+class PlanOnlyTestFixture(PySparkBaseTestCase, PySparkErrorTestUtils):
     if should_test_connect:
 
         class MockDF(DataFrame):
@@ -152,7 +150,7 @@ class PlanOnlyTestFixture(unittest.TestCase, PySparkErrorTestUtils):
 
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
-class ReusedConnectTestCase(unittest.TestCase, SQLTestUtils, PySparkErrorTestUtils):
+class ReusedConnectTestCase(PySparkBaseTestCase, SQLTestUtils, PySparkErrorTestUtils):
     """
     Spark Connect version of :class:`pyspark.testing.sqlutils.ReusedSQLTestCase`.
     """
@@ -180,8 +178,7 @@ class ReusedConnectTestCase(unittest.TestCase, SQLTestUtils, PySparkErrorTestUti
 
     @classmethod
     def setUpClass(cls):
-        if os.environ.get("PYSPARK_TEST_TIMEOUT"):
-            faulthandler.register(signal.SIGTERM, file=sys.__stderr__, all_threads=True)
+        super().setUpClass()
 
         # This environment variable is for interrupting hanging ML-handler and making the
         # tests fail fast.
@@ -203,11 +200,11 @@ class ReusedConnectTestCase(unittest.TestCase, SQLTestUtils, PySparkErrorTestUti
 
     @classmethod
     def tearDownClass(cls):
-        if os.environ.get("PYSPARK_TEST_TIMEOUT"):
-            faulthandler.unregister(signal.SIGTERM)
-
-        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
-        cls.spark.stop()
+        try:
+            shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+            cls.spark.stop()
+        finally:
+            super().tearDownClass()
 
     def setUp(self) -> None:
         # force to clean up the ML cache before each test
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index b63c98f96f4e..927e4f4250c3 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -209,6 +209,7 @@ class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils, PySparkErrorTestUti
     @classmethod
     def setUpClass(cls):
         super().setUpClass()
+
         cls._legacy_sc = cls.sc
         cls.spark = SparkSession(cls.sc)
         cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
@@ -218,9 +219,11 @@ class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils, PySparkErrorTestUti
 
     @classmethod
     def tearDownClass(cls):
-        super().tearDownClass()
-        cls.spark.stop()
-        shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+        try:
+            cls.spark.stop()
+            shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+        finally:
+            super().tearDownClass()
 
     def tearDown(self):
         try:
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 4ea3dbf30178..4286a55bb699 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -20,6 +20,7 @@ import struct
 import sys
 import unittest
 import difflib
+import faulthandler
 import functools
 from decimal import Decimal
 from time import time, sleep
@@ -268,7 +269,19 @@ class QuietTest:
         self.log4j.LogManager.getRootLogger().setLevel(self.old_level)
 
 
-class PySparkTestCase(unittest.TestCase):
+class PySparkBaseTestCase(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        if os.environ.get("PYSPARK_TEST_TIMEOUT"):
+            faulthandler.register(signal.SIGTERM, file=sys.__stderr__, all_threads=True)
+
+    @classmethod
+    def tearDownClass(cls):
+        if os.environ.get("PYSPARK_TEST_TIMEOUT"):
+            faulthandler.unregister(signal.SIGTERM)
+
+
+class PySparkTestCase(PySparkBaseTestCase):
     def setUp(self):
         from pyspark import SparkContext
 
@@ -281,7 +294,7 @@ class PySparkTestCase(unittest.TestCase):
         sys.path = self._old_sys_path
 
 
-class ReusedPySparkTestCase(unittest.TestCase):
+class ReusedPySparkTestCase(PySparkBaseTestCase):
     @classmethod
     def conf(cls):
         """
@@ -291,6 +304,8 @@ class ReusedPySparkTestCase(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
+        super().setUpClass()
+
         from pyspark import SparkContext
 
         cls.sc = SparkContext(cls.master(), cls.__name__, conf=cls.conf())
@@ -301,7 +316,10 @@ class ReusedPySparkTestCase(unittest.TestCase):
 
     @classmethod
     def tearDownClass(cls):
-        cls.sc.stop()
+        try:
+            cls.sc.stop()
+        finally:
+            super().tearDownClass()
 
     def test_assert_classic_mode(self):
         from pyspark.sql import is_remote

