This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 40093c58be38 [SPARK-51836][PYTHON][CONNECT][TESTS] Avoid per-test-function connect session setup
40093c58be38 is described below

commit 40093c58be3823fd944bd59b46385eeddb5e27fb
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Fri Apr 18 13:22:01 2025 +0800

    [SPARK-51836][PYTHON][CONNECT][TESTS] Avoid per-test-function connect session setup
    
    ### What changes were proposed in this pull request?
    Avoid per-test-function session setup
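    Concretely, instead of each suite creating and stopping its own remote
    session around every test function (a sketch of the removed pattern; the
    `SomeTests*` names are placeholders for the suites in the diff below):

        class SomeTestsOnConnect(SomeTestsMixin, unittest.TestCase):
            def setUp(self) -> None:
                self.spark = SparkSession.builder.remote(
                    os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
                ).getOrCreate()

            def tearDown(self) -> None:
                self.spark.stop()

    the suites now inherit from `ReusedConnectTestCase`, which manages a
    session shared across the whole class, and only override its `master()`
    (and, where needed, `conf()`) hooks:

        class SomeTestsOnConnect(SomeTestsMixin, ReusedConnectTestCase):
            @classmethod
            def master(cls):
                return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")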
    
    ### Why are the changes needed?
    To make the tests stable and fast.
    The `python/pyspark/ml/tests/connect/test_connect_xxx` suites are known to be somewhat flaky. I noticed that they set up and tear down a Spark session for each test function, which seems unnecessary and costly.
    
    There are still some similar places, but the same change causes test failures there; they will be revisited later.
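    For context, `unittest` runs `setUp`/`tearDown` around every single test
    method, while `setUpClass`/`tearDownClass` run once per class; that is
    where the savings come from. A minimal sketch of the shared-session
    pattern (an approximation of what `ReusedConnectTestCase` provides, not
    its actual implementation):

        import os
        import unittest

        from pyspark.sql import SparkSession

        class SharedSessionTestCase(unittest.TestCase):
            @classmethod
            def setUpClass(cls):
                # One remote session for the whole class instead of one per test.
                cls.spark = SparkSession.builder.remote(
                    os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
                ).getOrCreate()

            @classmethod
            def tearDownClass(cls):
                cls.spark.stop()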
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #50632 from zhengruifeng/py_ml_test_session_setup.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../ml/tests/connect/test_connect_evaluation.py    | 14 +++++-------
 .../ml/tests/connect/test_connect_feature.py       | 14 +++++-------
 .../ml/tests/connect/test_connect_pipeline.py      | 26 ++++++++++------------
 .../ml/tests/connect/test_connect_summarizer.py    | 16 +++++--------
 .../ml/tests/connect/test_connect_tuning.py        | 26 ++++++++++------------
 5 files changed, 40 insertions(+), 56 deletions(-)

diff --git a/python/pyspark/ml/tests/connect/test_connect_evaluation.py b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
index 662fe8a2ffdf..73b9e0943bea 100644
--- a/python/pyspark/ml/tests/connect/test_connect_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_connect_evaluation.py
@@ -18,21 +18,17 @@
 import os
 import unittest
 
-from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_evaluation import EvaluationTestsMixin
 
     @unittest.skip("SPARK-50956: Flaky with RetriesExceeded")
-    class EvaluationTestsOnConnect(EvaluationTestsMixin, unittest.TestCase):
-        def setUp(self) -> None:
-            self.spark = SparkSession.builder.remote(
-                os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
-            ).getOrCreate()
-
-        def tearDown(self) -> None:
-            self.spark.stop()
+    class EvaluationTestsOnConnect(EvaluationTestsMixin, ReusedConnectTestCase):
+        @classmethod
+        def master(cls):
+            return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_connect_feature.py b/python/pyspark/ml/tests/connect/test_connect_feature.py
index 879cbff6d0cc..04a8d3664b96 100644
--- a/python/pyspark/ml/tests/connect/test_connect_feature.py
+++ b/python/pyspark/ml/tests/connect/test_connect_feature.py
@@ -18,9 +18,9 @@
 import os
 import unittest
 
-from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 from pyspark.testing.utils import have_sklearn, sklearn_requirement_message
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_feature import FeatureTestsMixin
@@ -29,14 +29,10 @@ if should_test_connect:
         not should_test_connect or not have_sklearn,
         connect_requirement_message or sklearn_requirement_message,
     )
-    class FeatureTestsOnConnect(FeatureTestsMixin, unittest.TestCase):
-        def setUp(self) -> None:
-            self.spark = SparkSession.builder.remote(
-                os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
-            ).getOrCreate()
-
-        def tearDown(self) -> None:
-            self.spark.stop()
+    class FeatureTestsOnConnect(FeatureTestsMixin, ReusedConnectTestCase):
+        @classmethod
+        def master(cls):
+            return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
index f8576d0cb09d..2b408911fbd2 100644
--- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
@@ -15,14 +15,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
 import os
 import unittest
 
 from pyspark.util import is_remote_only
-from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 from pyspark.testing.utils import have_torch, torch_requirement_message
-
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_pipeline import PipelineTestsMixin
@@ -33,18 +33,16 @@ if should_test_connect:
         or torch_requirement_message
         or "Requires PySpark core library in Spark Connect server",
     )
-    class PipelineTestsOnConnect(PipelineTestsMixin, unittest.TestCase):
-        def setUp(self) -> None:
-            self.spark = (
-                SparkSession.builder.remote(
-                    os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
-                )
-                .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
-                .getOrCreate()
-            )
-
-        def tearDown(self) -> None:
-            self.spark.stop()
+    class PipelineTestsOnConnect(PipelineTestsMixin, ReusedConnectTestCase):
+        @classmethod
+        def conf(cls):
+            config = super(PipelineTestsOnConnect, cls).conf()
+            config.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
+            return config
+
+        @classmethod
+        def master(cls):
+            return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_connect_summarizer.py b/python/pyspark/ml/tests/connect/test_connect_summarizer.py
index 9c737c96ee87..57911779d6bb 100644
--- a/python/pyspark/ml/tests/connect/test_connect_summarizer.py
+++ b/python/pyspark/ml/tests/connect/test_connect_summarizer.py
@@ -15,24 +15,20 @@
 # limitations under the License.
 #
 
-import unittest
 import os
+import unittest
 
-from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_summarizer import SummarizerTestsMixin
 
     @unittest.skipIf(not should_test_connect, connect_requirement_message)
-    class SummarizerTestsOnConnect(SummarizerTestsMixin, unittest.TestCase):
-        def setUp(self) -> None:
-            self.spark = SparkSession.builder.remote(
-                os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
-            ).getOrCreate()
-
-        def tearDown(self) -> None:
-            self.spark.stop()
+    class SummarizerTestsOnConnect(SummarizerTestsMixin, ReusedConnectTestCase):
+        @classmethod
+        def master(cls):
+            return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/ml/tests/connect/test_connect_tuning.py b/python/pyspark/ml/tests/connect/test_connect_tuning.py
index d737dd5767db..3b7f977b57ae 100644
--- a/python/pyspark/ml/tests/connect/test_connect_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_connect_tuning.py
@@ -16,13 +16,13 @@
 # limitations under the License.
 #
 
-import unittest
 import os
+import unittest
 
 from pyspark.util import is_remote_only
-from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 from pyspark.testing.utils import have_torch, torch_requirement_message
+from pyspark.testing.connectutils import ReusedConnectTestCase
 
 if should_test_connect:
     from pyspark.ml.tests.connect.test_legacy_mode_tuning import CrossValidatorTestsMixin
@@ -33,18 +33,16 @@ if should_test_connect:
         or torch_requirement_message
         or "Requires PySpark core library in Spark Connect server",
     )
-    class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, unittest.TestCase):
-        def setUp(self) -> None:
-            self.spark = (
-                SparkSession.builder.remote(
-                    os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
-                )
-                .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
-                .getOrCreate()
-            )
-
-        def tearDown(self) -> None:
-            self.spark.stop()
+    class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin, ReusedConnectTestCase):
+        @classmethod
+        def conf(cls):
+            config = super(CrossValidatorTestsOnConnect, cls).conf()
+            config.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
+            return config
+
+        @classmethod
+        def master(cls):
+            return os.environ.get("SPARK_CONNECT_TESTING_REMOTE", "local[2]")
 
 
 if __name__ == "__main__":

