This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b45f071cd1b6 [SPARK-54868][PYTHON][INFRA][FOLLOW-UP] Also enable `faulthandler` in classic tests
b45f071cd1b6 is described below
commit b45f071cd1b653731ae5b1c97c77d1f47560b177
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Dec 31 13:20:41 2025 +0800
[SPARK-54868][PYTHON][INFRA][FOLLOW-UP] Also enable `faulthandler` in classic tests
### What changes were proposed in this pull request?
Also enable `faulthandler` in classic tests by introducing a base class `PySparkBaseTestCase`.
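For illustration only (not part of this patch), a minimal sketch of how a classic test suite picks up the behavior once it extends the new base class; `MyClassicTest` and its test are hypothetical, while `PySparkBaseTestCase` and the `pyspark.testing.utils` import path come from the diff below:
```python
import unittest

from pyspark.testing.utils import PySparkBaseTestCase


class MyClassicTest(PySparkBaseTestCase):  # hypothetical example suite
    @classmethod
    def setUpClass(cls):
        # Registers faulthandler for SIGTERM when PYSPARK_TEST_TIMEOUT is set.
        super().setUpClass()
        # ... suite-specific setup would go here ...

    @classmethod
    def tearDownClass(cls):
        try:
            pass  # ... suite-specific cleanup would go here ...
        finally:
            # Always unregister faulthandler, even if cleanup raises.
            super().tearDownClass()

    def test_something(self):
        self.assertTrue(True)


if __name__ == "__main__":
    unittest.main()
```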
### Why are the changes needed?
`faulthandler` was previously only enabled in `ReusedConnectTestCase` for Spark Connect; after this change, `faulthandler` will be enabled in most classic tests as well.
(There are still some tests that directly use `unittest.TestCase`; we can migrate them when they hit hanging issues.)
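As a standalone illustration of the mechanism (plain stdlib Python on Unix, not Spark code): once `faulthandler.register` is installed for a signal, delivering that signal makes the interpreter dump the tracebacks of all threads to the given file, so a hanging test process shows where it is stuck instead of dying silently.
```python
import faulthandler
import os
import signal
import sys
import time

# Dump all thread tracebacks to stderr when SIGTERM is received, mirroring
# what the new base class does when PYSPARK_TEST_TIMEOUT is set.
faulthandler.register(signal.SIGTERM, file=sys.__stderr__, all_threads=True)

print("pid:", os.getpid())
# Simulate a hang; from another shell, `kill -TERM <pid>` prints the traceback
# of every thread. The process keeps running afterwards, since faulthandler's
# handler only dumps the tracebacks and does not re-raise the signal.
time.sleep(600)
```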
### Does this PR introduce _any_ user-facing change?
no, test-only
### How was this patch tested?
CI, plus a manual check with
```
PYSPARK_TEST_TIMEOUT=10 python/run-tests -k --python-executables python3 --testnames 'pyspark.tests.test_util'
```
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #53651 from zhengruifeng/super_class.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/testing/connectutils.py | 21 +++++++++------------
python/pyspark/testing/sqlutils.py | 9 ++++++---
python/pyspark/testing/utils.py | 24 +++++++++++++++++++++---
3 files changed, 36 insertions(+), 18 deletions(-)
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 63dc350dd011..08c86561b57a 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -17,9 +17,6 @@
import shutil
import tempfile
import os
-import sys
-import signal
-import faulthandler
import functools
import unittest
import uuid
@@ -45,6 +42,7 @@ from pyspark.testing.utils import (
should_test_connect,
PySparkErrorTestUtils,
)
+from pyspark.testing.utils import PySparkBaseTestCase
from pyspark.testing.sqlutils import SQLTestUtils
from pyspark.sql.session import SparkSession as PySparkSession
@@ -75,7 +73,7 @@ class MockRemoteSession:
@unittest.skipIf(not should_test_connect, connect_requirement_message)
-class PlanOnlyTestFixture(unittest.TestCase, PySparkErrorTestUtils):
+class PlanOnlyTestFixture(PySparkBaseTestCase, PySparkErrorTestUtils):
if should_test_connect:
class MockDF(DataFrame):
@@ -152,7 +150,7 @@ class PlanOnlyTestFixture(unittest.TestCase, PySparkErrorTestUtils):
@unittest.skipIf(not should_test_connect, connect_requirement_message)
-class ReusedConnectTestCase(unittest.TestCase, SQLTestUtils, PySparkErrorTestUtils):
+class ReusedConnectTestCase(PySparkBaseTestCase, SQLTestUtils, PySparkErrorTestUtils):
"""
Spark Connect version of
:class:`pyspark.testing.sqlutils.ReusedSQLTestCase`.
"""
@@ -180,8 +178,7 @@ class ReusedConnectTestCase(unittest.TestCase, SQLTestUtils, PySparkErrorTestUti
@classmethod
def setUpClass(cls):
- if os.environ.get("PYSPARK_TEST_TIMEOUT"):
- faulthandler.register(signal.SIGTERM, file=sys.__stderr__, all_threads=True)
+ super().setUpClass()
# This environment variable is for interrupting hanging ML-handler and making the
# tests fail fast.
@@ -203,11 +200,11 @@ class ReusedConnectTestCase(unittest.TestCase, SQLTestUtils, PySparkErrorTestUti
@classmethod
def tearDownClass(cls):
- if os.environ.get("PYSPARK_TEST_TIMEOUT"):
- faulthandler.unregister(signal.SIGTERM)
-
- shutil.rmtree(cls.tempdir.name, ignore_errors=True)
- cls.spark.stop()
+ try:
+ shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+ cls.spark.stop()
+ finally:
+ super().tearDownClass()
def setUp(self) -> None:
# force to clean up the ML cache before each test
diff --git a/python/pyspark/testing/sqlutils.py b/python/pyspark/testing/sqlutils.py
index b63c98f96f4e..927e4f4250c3 100644
--- a/python/pyspark/testing/sqlutils.py
+++ b/python/pyspark/testing/sqlutils.py
@@ -209,6 +209,7 @@ class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils, PySparkErrorTestUti
@classmethod
def setUpClass(cls):
super().setUpClass()
+
cls._legacy_sc = cls.sc
cls.spark = SparkSession(cls.sc)
cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
@@ -218,9 +219,11 @@ class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils, PySparkErrorTestUti
@classmethod
def tearDownClass(cls):
- super().tearDownClass()
- cls.spark.stop()
- shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+ try:
+ cls.spark.stop()
+ shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+ finally:
+ super().tearDownClass()
def tearDown(self):
try:
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 4ea3dbf30178..4286a55bb699 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -20,6 +20,7 @@ import struct
import sys
import unittest
import difflib
+import faulthandler
import functools
from decimal import Decimal
from time import time, sleep
@@ -268,7 +269,19 @@ class QuietTest:
self.log4j.LogManager.getRootLogger().setLevel(self.old_level)
-class PySparkTestCase(unittest.TestCase):
+class PySparkBaseTestCase(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ if os.environ.get("PYSPARK_TEST_TIMEOUT"):
+ faulthandler.register(signal.SIGTERM, file=sys.__stderr__, all_threads=True)
+
+ @classmethod
+ def tearDownClass(cls):
+ if os.environ.get("PYSPARK_TEST_TIMEOUT"):
+ faulthandler.unregister(signal.SIGTERM)
+
+
+class PySparkTestCase(PySparkBaseTestCase):
def setUp(self):
from pyspark import SparkContext
@@ -281,7 +294,7 @@ class PySparkTestCase(unittest.TestCase):
sys.path = self._old_sys_path
-class ReusedPySparkTestCase(unittest.TestCase):
+class ReusedPySparkTestCase(PySparkBaseTestCase):
@classmethod
def conf(cls):
"""
@@ -291,6 +304,8 @@ class ReusedPySparkTestCase(unittest.TestCase):
@classmethod
def setUpClass(cls):
+ super().setUpClass()
+
from pyspark import SparkContext
cls.sc = SparkContext(cls.master(), cls.__name__, conf=cls.conf())
@@ -301,7 +316,10 @@ class ReusedPySparkTestCase(unittest.TestCase):
@classmethod
def tearDownClass(cls):
- cls.sc.stop()
+ try:
+ cls.sc.stop()
+ finally:
+ super().tearDownClass()
def test_assert_classic_mode(self):
from pyspark.sql import is_remote
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]