This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4f60ebcf0f8 [SPARK-42187][CONNECT][TESTS] Avoid using 
RemoteSparkSession.builder.getOrCreate in tests
4f60ebcf0f8 is described below

commit 4f60ebcf0f83b767e17199a4ffe1edc24862fcfa
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Thu Jan 26 13:33:07 2023 +0900

    [SPARK-42187][CONNECT][TESTS] Avoid using 
RemoteSparkSession.builder.getOrCreate in tests
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to use `pyspark.sql.SparkSession.getOrCreate` instead of 
`pyspark.sql.connect.SparkSession.builder.getOrCreate` in tests.
    
    ### Why are the changes needed?
    
    Because `pyspark.sql.connect.SparkSession.builder.getOrCreate` is supposed 
to be internal, and it lacks the unified session handling that covers both 
regular PySpark sessions and Spark Connect sessions.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, test-only.
    
    ### How was this patch tested?
    
    Existing unit tests were fixed.
    
    Closes #39743 from HyukjinKwon/cleanup-test.
    
    Lead-authored-by: Hyukjin Kwon <[email protected]>
    Co-authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../sql/tests/connect/test_connect_basic.py        | 41 +++++++++++----------
 .../sql/tests/connect/test_connect_function.py     | 42 +++++++++-------------
 python/pyspark/sql/utils.py                        |  8 ++---
 python/pyspark/testing/connectutils.py             | 20 ++++-------
 python/pyspark/testing/pandasutils.py              | 14 ++++----
 5 files changed, 58 insertions(+), 67 deletions(-)

diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py 
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 3f7494a6385..94aed9fcc30 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -17,12 +17,13 @@
 
 import array
 import datetime
+import os
 import unittest
 import shutil
 import tempfile
 
 from pyspark.testing.sqlutils import SQLTestUtils
-from pyspark.sql import SparkSession, Row
+from pyspark.sql import SparkSession as PySparkSession, Row
 from pyspark.sql.types import (
     StructType,
     StructField,
@@ -33,9 +34,11 @@ from pyspark.sql.types import (
     ArrayType,
     Row,
 )
-from pyspark.testing.utils import ReusedPySparkTestCase
-from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
-from pyspark.testing.pandasutils import PandasOnSparkTestCase
+from pyspark.testing.connectutils import (
+    should_test_connect,
+    ReusedConnectTestCase,
+)
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 from pyspark.errors import (
     SparkConnectException,
     SparkConnectAnalysisException,
@@ -57,22 +60,25 @@ if should_test_connect:
     from pyspark.sql.connect import functions as CF
 
 
[email protected](not should_test_connect, connect_requirement_message)
-class SparkConnectSQLTestCase(PandasOnSparkTestCase, ReusedPySparkTestCase, 
SQLTestUtils):
+class SparkConnectSQLTestCase(ReusedConnectTestCase, SQLTestUtils, 
PandasOnSparkTestUtils):
     """Parent test fixture class for all Spark Connect related
     test cases."""
 
     @classmethod
     def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
-        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
-        cls.hive_available = True
-        # Create the new Spark Session
-        cls.spark = SparkSession(cls.sc)
+        super(SparkConnectSQLTestCase, cls).setUpClass()
+        # Disable the shared namespace so pyspark.sql.functions, etc point the 
regular
+        # PySpark libraries.
+        os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
+
+        cls.connect = cls.spark  # Switch Spark Connect session and regular 
PySpark sesion.
+        cls.spark = PySparkSession._instantiatedSession
+        assert cls.spark is not None
+
         cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
         cls.testDataStr = [Row(key=str(i)) for i in range(100)]
-        cls.df = cls.sc.parallelize(cls.testData).toDF()
-        cls.df_text = cls.sc.parallelize(cls.testDataStr).toDF()
+        cls.df = cls.spark.sparkContext.parallelize(cls.testData).toDF()
+        cls.df_text = 
cls.spark.sparkContext.parallelize(cls.testDataStr).toDF()
 
         cls.tbl_name = "test_connect_basic_table_1"
         cls.tbl_name2 = "test_connect_basic_table_2"
@@ -88,12 +94,12 @@ class SparkConnectSQLTestCase(PandasOnSparkTestCase, 
ReusedPySparkTestCase, SQLT
     @classmethod
     def tearDownClass(cls):
         cls.spark_connect_clean_up_test_data()
-        ReusedPySparkTestCase.tearDownClass()
+        cls.spark = cls.connect  # Stopping Spark Connect closes the session 
in JVM at the server.
+        super(SparkConnectSQLTestCase, cls).setUpClass()
+        del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
 
     @classmethod
     def spark_connect_load_test_data(cls):
-        # Setup Remote Spark Session
-        cls.connect = RemoteSparkSession.builder.remote().getOrCreate()
         df = cls.spark.createDataFrame([(x, f"{x}") for x in range(100)], 
["id", "name"])
         # Since we might create multiple Spark sessions, we need to create 
global temporary view
         # that is specifically maintained in the "global_temp" schema.
@@ -2596,8 +2602,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
                 getattr(df.write, f)()
 
 
[email protected](not should_test_connect, connect_requirement_message)
-class ChannelBuilderTests(ReusedPySparkTestCase):
+class ChannelBuilderTests(unittest.TestCase):
     def test_invalid_connection_strings(self):
         invalid = [
             "scc://host:12",
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py 
b/python/pyspark/sql/tests/connect/test_connect_function.py
index b74b1a9ee69..7042a7e8e6f 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -14,45 +14,37 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import os
 import unittest
-import tempfile
 
 from pyspark.errors import PySparkTypeError
-from pyspark.sql import SparkSession
+from pyspark.sql import SparkSession as PySparkSession
 from pyspark.sql.types import StringType, StructType, StructField, ArrayType, 
IntegerType
-from pyspark.testing.pandasutils import PandasOnSparkTestCase
-from pyspark.testing.connectutils import should_test_connect, 
connect_requirement_message
-from pyspark.testing.utils import ReusedPySparkTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
+from pyspark.testing.connectutils import ReusedConnectTestCase
 from pyspark.testing.sqlutils import SQLTestUtils
 from pyspark.errors import SparkConnectAnalysisException, SparkConnectException
 
-if should_test_connect:
-    from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
 
-
[email protected](not should_test_connect, connect_requirement_message)
-class SparkConnectFuncTestCase(PandasOnSparkTestCase, ReusedPySparkTestCase, 
SQLTestUtils):
-    """Parent test fixture class for all Spark Connect related
-    test cases."""
+class SparkConnectFunctionTests(ReusedConnectTestCase, PandasOnSparkTestUtils, 
SQLTestUtils):
+    """These test cases exercise the interface to the proto plan
+    generation but do not call Spark."""
 
     @classmethod
     def setUpClass(cls):
-        ReusedPySparkTestCase.setUpClass()
-        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
-        cls.hive_available = True
-        # Create the new Spark Session
-        cls.spark = SparkSession(cls.sc)
-        # Setup Remote Spark Session
-        cls.connect = RemoteSparkSession.builder.remote().getOrCreate()
+        super(SparkConnectFunctionTests, cls).setUpClass()
+        # Disable the shared namespace so pyspark.sql.functions, etc point the 
regular
+        # PySpark libraries.
+        os.environ["PYSPARK_NO_NAMESPACE_SHARE"] = "1"
+        cls.connect = cls.spark  # Switch Spark Connect session and regular 
PySpark sesion.
+        cls.spark = PySparkSession._instantiatedSession
+        assert cls.spark is not None
 
     @classmethod
     def tearDownClass(cls):
-        ReusedPySparkTestCase.tearDownClass()
-
-
-class SparkConnectFunctionTests(SparkConnectFuncTestCase):
-    """These test cases exercise the interface to the proto plan
-    generation but do not call Spark."""
+        cls.spark = cls.connect  # Stopping Spark Connect closes the session 
in JVM at the server.
+        super(SparkConnectFunctionTests, cls).setUpClass()
+        del os.environ["PYSPARK_NO_NAMESPACE_SHARE"]
 
     def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
         from pyspark.sql.dataframe import DataFrame as SDF
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 4f99a23b82d..b9b045541a6 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -151,7 +151,7 @@ def try_remote_functions(f: FuncT) -> FuncT:
     @functools.wraps(f)
     def wrapped(*args: Any, **kwargs: Any) -> Any:
 
-        if is_remote():
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
             from pyspark.sql.connect import functions
 
             return getattr(functions, f.__name__)(*args, **kwargs)
@@ -167,7 +167,7 @@ def try_remote_window(f: FuncT) -> FuncT:
     @functools.wraps(f)
     def wrapped(*args: Any, **kwargs: Any) -> Any:
 
-        if is_remote():
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
             from pyspark.sql.connect.window import Window
 
             return getattr(Window, f.__name__)(*args, **kwargs)
@@ -183,7 +183,7 @@ def try_remote_windowspec(f: FuncT) -> FuncT:
     @functools.wraps(f)
     def wrapped(*args: Any, **kwargs: Any) -> Any:
 
-        if is_remote():
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
             from pyspark.sql.connect.window import WindowSpec
 
             return getattr(WindowSpec, f.__name__)(*args, **kwargs)
@@ -199,7 +199,7 @@ def try_remote_observation(f: FuncT) -> FuncT:
     @functools.wraps(f)
     def wrapped(*args: Any, **kwargs: Any) -> Any:
         # TODO(SPARK-41527): Add the support of Observation.
-        if is_remote():
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
             raise NotImplementedError()
         return f(*args, **kwargs)
 
diff --git a/python/pyspark/testing/connectutils.py 
b/python/pyspark/testing/connectutils.py
index 64934c763c3..210d525ade7 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -55,7 +55,6 @@ except ImportError as e:
     googleapis_common_protos_requirement_message = str(e)
 have_googleapis_common_protos = googleapis_common_protos_requirement_message 
is None
 
-connect_not_compiled_message = None
 if (
     have_pandas
     and have_pyarrow
@@ -63,19 +62,7 @@ if (
     and have_grpc_status
     and have_googleapis_common_protos
 ):
-    from pyspark.sql.connect import DataFrame
-    from pyspark.sql.connect.plan import Read, Range, SQL
-    from pyspark.testing.utils import search_jar
-    from pyspark.sql.connect.session import SparkSession
-
-    connect_jar = search_jar("connector/connect/server", 
"spark-connect-assembly-", "spark-connect")
-    existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
-    connect_url = "--remote sc://localhost"
-    jars_args = "--jars %s" % connect_jar
-    plugin_args = "--conf 
spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin"
-    os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join(
-        [connect_url, jars_args, plugin_args, existing_args]
-    )
+    connect_not_compiled_message = None
 else:
     connect_not_compiled_message = (
         "Skipping all Spark Connect Python tests as the optional Spark Connect 
project was "
@@ -94,6 +81,11 @@ connect_requirement_message = (
 )
 should_test_connect: str = typing.cast(str, connect_requirement_message is 
None)
 
+if should_test_connect:
+    from pyspark.sql.connect import DataFrame
+    from pyspark.sql.connect.plan import Read, Range, SQL
+    from pyspark.sql.connect.session import SparkSession
+
 
 class MockRemoteSession:
     def __init__(self):
diff --git a/python/pyspark/testing/pandasutils.py 
b/python/pyspark/testing/pandasutils.py
index 6a828f10026..202603ca5c0 100644
--- a/python/pyspark/testing/pandasutils.py
+++ b/python/pyspark/testing/pandasutils.py
@@ -54,12 +54,7 @@ except ImportError as e:
 have_plotly = plotly_requirement_message is None
 
 
-class PandasOnSparkTestCase(ReusedSQLTestCase):
-    @classmethod
-    def setUpClass(cls):
-        super(PandasOnSparkTestCase, cls).setUpClass()
-        cls.spark.conf.set(SPARK_CONF_ARROW_ENABLED, True)
-
+class PandasOnSparkTestUtils:
     def convert_str_to_lambda(self, func):
         """
         This function coverts `func` str to lambda call
@@ -248,6 +243,13 @@ class PandasOnSparkTestCase(ReusedSQLTestCase):
             return obj
 
 
+class PandasOnSparkTestCase(ReusedSQLTestCase, PandasOnSparkTestUtils):
+    @classmethod
+    def setUpClass(cls):
+        super(PandasOnSparkTestCase, cls).setUpClass()
+        cls.spark.conf.set(SPARK_CONF_ARROW_ENABLED, True)
+
+
 class TestUtils:
     @contextmanager
     def temp_dir(self):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to