This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 40093c58be38 [SPARK-51836][PYTHON][CONNECT][TESTS] Avoid per-test-function connect session setup
40093c58be38 is described below

commit 40093c58be3823fd944bd59b46385eeddb5e27fb
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Fri Apr 18 13:22:01 2025 +0800

    [SPARK-51836][PYTHON][CONNECT][TESTS] Avoid per-test-function connect session setup

    ### What changes were proposed in this pull request?
    Avoid per-test-function session setup.

    ### Why are the changes needed?
    To make the tests stable and fast. The `python/pyspark/ml/tests/connect/test_connect_xxx` suites are known to be somewhat flaky, and I noticed that they set up and tear down a Spark session for each test function, which seems unnecessary and costly. A few similar places remain, but the same change caused test failures there, so they will be revisited later.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    CI.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #50632 from zhengruifeng/py_ml_test_session_setup.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
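The diff below applies one pattern to all five files: the per-test `setUp`/`tearDown` pair is replaced by the shared-session base class `ReusedConnectTestCase`, with the `master()` (and, where extra settings are needed, `conf()`) classmethods overridden. Here is a minimal sketch of the resulting shape; the `ExampleTestsOnConnect` class and its test method are hypothetical, and it assumes `ReusedConnectTestCase` builds one session per class in `setUpClass`, consulting the `master()`/`conf()` hooks the diff relies on:

```python
import os
import unittest

from pyspark.testing.connectutils import should_test_connect, ReusedConnectTestCase

if should_test_connect:

    class ExampleTestsOnConnect(ReusedConnectTestCase):  # hypothetical example class
        @classmethod
        def master(cls):
            # Use an externally provided Connect endpoint when set, otherwise
            # fall back to a local two-thread master.
            return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")

        @classmethod
        def conf(cls):
            # Extra settings stack on top of the parent's config, mirroring
            # the pipeline/tuning tests in the diff below.
            config = super().conf()
            config.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
            return config

        def test_shared_session(self):
            # self.spark is created once in setUpClass and reused by every
            # test method in this class, not rebuilt per test function.
            self.assertEqual(self.spark.range(3).count(), 3)


if __name__ == "__main__":
    unittest.main()
```

Because the session is created once per class rather than once per test method, the suites avoid repeated Connect session startup and teardown, which is the cost the commit message calls out.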
---
 .../ml/tests/connect/test_connect_evaluation.py    | 14 +++++-------
 .../ml/tests/connect/test_connect_feature.py       | 14 +++++-------
 .../ml/tests/connect/test_connect_pipeline.py      | 26 ++++++++++------------
 .../ml/tests/connect/test_connect_summarizer.py    | 16 +++++--------
 .../ml/tests/connect/test_connect_tuning.py        | 26 ++++++++++------------
 5 files changed, 40 insertions(+), 56 deletions(-)

diff --git a/python/pyspark/ml/tests/connect/test_connect_evaluation.py b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
index 662fe8a2ffdf..73b9e0943bea 100644
--- a/python/pyspark/ml/tests/connect/test_connect_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
@@ -18,21 +18,17 @@
 import os
 import unittest
 
-from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_evaluation import EvaluationTestsMixin
 
     @unittest.skip("SPARK-50956: Flaky with RetriesExceeded")
-    class EvaluationTestsOnConnect(EvaluationTestsMixin, unittest.TestCase):
-        def setUp(self) -> None:
-            self.spark = SparkSession.builder.remote(
-                os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
-            ).getOrCreate()
-
-        def tearDown(self) -> None:
-            self.spark.stop()
+    class EvaluationTestsOnConnect(EvaluationTestsMixin, ReusedConnectTestCase):
+        @classmethod
+        def master(cls):
+            return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_connect_feature.py b/python/pyspark/ml/tests/connect/test_connect_feature.py
index 879cbff6d0cc..04a8d3664b96 100644
--- a/python/pyspark/ml/tests/connect/test_connect_feature.py
+++ b/python/pyspark/ml/tests/connect/test_connect_feature.py
@@ -18,9 +18,9 @@
 import os
 import unittest
 
-from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 from pyspark.testing.utils import have_sklearn, sklearn_requirement_message
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_feature import FeatureTestsMixin
@@ -29,14 +29,10 @@ if should_test_connect:
         not should_test_connect or not have_sklearn,
         connect_requirement_message or sklearn_requirement_message,
     )
-    class FeatureTestsOnConnect(FeatureTestsMixin, unittest.TestCase):
-        def setUp(self) -> None:
-            self.spark = SparkSession.builder.remote(
-                os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
-            ).getOrCreate()
-
-        def tearDown(self) -> None:
-            self.spark.stop()
+    class FeatureTestsOnConnect(FeatureTestsMixin, ReusedConnectTestCase):
+        @classmethod
+        def master(cls):
+            return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
index f8576d0cb09d..2b408911fbd2 100644
--- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
@@ -15,14 +15,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
 import os
 import unittest
 
 from pyspark.util import is_remote_only
-from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 from pyspark.testing.utils import have_torch, torch_requirement_message
-
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_pipeline import PipelineTestsMixin
@@ -33,18 +33,16 @@ if should_test_connect:
         or torch_requirement_message
         or "Requires PySpark core library in Spark Connect server",
     )
-    class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase):
-        def setUp(self) -> None:
-            self.spark = (
-                SparkSession.builder.remote(
-                    os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
-                )
-                .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
-                .getOrCreate()
-            )
-
-        def tearDown(self) -> None:
-            self.spark.stop()
+    class PipelineTestsOnConnect(PipelineTestsMixin, ReusedConnectTestCase):
+        @classmethod
+        def conf(cls):
+            config = super(PipelineTestsOnConnect, cls).conf()
+            config.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
+            return config
+
+        @classmethod
+        def master(cls):
+            return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_connect_summarizer.py b/python/pyspark/ml/tests/connect/test_connect_summarizer.py
index 9c737c96ee87..57911779d6bb 100644
--- a/python/pyspark/ml/tests/connect/test_connect_summarizer.py
+++ b/python/pyspark/ml/tests/connect/test_connect_summarizer.py
@@ -15,24 +15,20 @@
 # limitations under the License.
 #
 
-import unittest
 import os
+import unittest
 
-from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_summarizer import SummarizerTestsMixin
 
     @unittest.skipIf(not should_test_connect, connect_requirement_message)
-    class SummarizerTestsOnConnect(SummarizerTestsMixin, unittest.TestCase):
-        def setUp(self) -> None:
-            self.spark = SparkSession.builder.remote(
-                os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
-            ).getOrCreate()
-
-        def tearDown(self) -> None:
-            self.spark.stop()
+    class SummarizerTestsOnConnect(SummarizerTestsMixin, ReusedConnectTestCase):
+        @classmethod
+        def master(cls):
+            return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_connect_tuning.py b/python/pyspark/ml/tests/connect/test_connect_tuning.py
index d737dd5767db..3b7f977b57ae 100644
--- a/python/pyspark/ml/tests/connect/test_connect_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_connect_tuning.py
@@ -16,13 +16,13 @@
 # limitations under the License.
 #
 
-import unittest
 import os
+import unittest
 
 from pyspark.util import is_remote_only
-from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 from pyspark.testing.utils import have_torch, torch_requirement_message
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_tuning import CrossValidatorTestsMixin
@@ -33,18 +33,16 @@ if should_test_connect:
         or torch_requirement_message
         or "Requires PySpark core library in Spark Connect server",
     )
-    class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, unittest.TestCase):
-        def setUp(self) -> None:
-            self.spark = (
-                SparkSession.builder.remote(
-                    os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
-                )
-                .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
-                .getOrCreate()
-            )
-
-        def tearDown(self) -> None:
-            self.spark.stop()
+    class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, ReusedConnectTestCase):
+        @classmethod
+        def conf(cls):
+            config = super(CrossValidatorTestsOnConnect, cls).conf()
+            config.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
+            return config
+
+        @classmethod
+        def master(cls):
+            return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
 
 
 if __name__ == "__main__":

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org