This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 54eb1a2f863b [MINOR][PYTHON][TESTS] Clean up a group of connect tests
54eb1a2f863b is described below

commit 54eb1a2f863bd7d8706c5c9a568895adb026c78d
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Apr 30 07:13:49 2025 +0800

    [MINOR][PYTHON][TESTS] Clean up a group of connect tests
    
    ### What changes were proposed in this pull request?
    Clean up a group of connect tests
    
    ### Why are the changes needed?
    To avoid expensive and unnecessary setup: these suites do not need the
    test tables pre-created by the SparkConnectSQLTestCase fixture.
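    
    In sketch form, the pattern applied across the diff below: instead of
    reading a pre-created table,
    
        cdf = self.connect.read.table(self.tbl_name)
    
    each test now builds equivalent data inline on top of
    ReusedMixedTestCase, so no table setup or teardown is required:
    
        query = "SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)"
        cdf = self.connect.sql(query)
        sdf = self.spark.sql(query)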
    
    ### Does this PR introduce _any_ user-facing change?
    no, test-only
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #50748 from zhengruifeng/test-connect-col.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../sql/tests/connect/test_connect_collection.py   | 23 ++++++-------
 .../sql/tests/connect/test_connect_column.py       | 38 +++++++---------------
 .../sql/tests/connect/test_connect_creation.py     |  7 ++--
 .../connect/test_connect_dataframe_property.py     | 21 ++++++++----
 4 files changed, 41 insertions(+), 48 deletions(-)

diff --git a/python/pyspark/sql/tests/connect/test_connect_collection.py b/python/pyspark/sql/tests/connect/test_connect_collection.py
index 9fe7ce1baa4d..61932c38733b 100644
--- a/python/pyspark/sql/tests/connect/test_connect_collection.py
+++ b/python/pyspark/sql/tests/connect/test_connect_collection.py
@@ -16,18 +16,19 @@
 #
 
 import unittest
-from pyspark.testing.connectutils import should_test_connect
-from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
+from pyspark.testing.connectutils import should_test_connect, ReusedMixedTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 
 if should_test_connect:
     from pyspark.sql import functions as SF
     from pyspark.sql.connect import functions as CF
 
 
-class SparkConnectCollectionTests(SparkConnectSQLTestCase):
+class SparkConnectCollectionTests(ReusedMixedTestCase, PandasOnSparkTestUtils):
     def test_collect(self):
-        cdf = self.connect.read.table(self.tbl_name)
-        sdf = self.spark.read.table(self.tbl_name)
+        query = "SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)"
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
 
         data = cdf.limit(10).collect()
         self.assertEqual(len(data), 10)
@@ -73,25 +74,25 @@ class SparkConnectCollectionTests(SparkConnectSQLTestCase):
 
     def test_head(self):
         # SPARK-41002: test `head` API in Python Client
-        df = self.connect.read.table(self.tbl_name)
+        df = self.connect.sql("SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)")
         self.assertIsNotNone(len(df.head()))
         self.assertIsNotNone(len(df.head(1)))
         self.assertIsNotNone(len(df.head(5)))
-        df2 = self.connect.read.table(self.tbl_name_empty)
+        df2 = self.connect.sql("SELECT '' AS x LIMIT 0")
         self.assertIsNone(df2.head())
 
     def test_first(self):
         # SPARK-41002: test `first` API in Python Client
-        df = self.connect.read.table(self.tbl_name)
+        df = self.connect.sql("SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)")
         self.assertIsNotNone(len(df.first()))
-        df2 = self.connect.read.table(self.tbl_name_empty)
+        df2 = self.connect.sql("SELECT '' AS x LIMIT 0")
         self.assertIsNone(df2.first())
 
     def test_take(self) -> None:
         # SPARK-41002: test `take` API in Python Client
-        df = self.connect.read.table(self.tbl_name)
+        df = self.connect.sql("SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)")
         self.assertEqual(5, len(df.take(5)))
-        df2 = self.connect.read.table(self.tbl_name_empty)
+        df2 = self.connect.sql("SELECT '' AS x LIMIT 0")
         self.assertEqual(0, len(df2.take(5)))
 
     def test_to_pandas(self):
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index 60ddcb6f22a5..4873006fbbb9 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -40,9 +40,8 @@ from pyspark.sql.types import (
     BooleanType,
 )
 from pyspark.errors import PySparkTypeError, PySparkValueError
-from pyspark.testing.connectutils import should_test_connect
-from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
-
+from pyspark.testing.connectutils import should_test_connect, ReusedMixedTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 
 if should_test_connect:
     import pandas as pd
@@ -63,25 +62,7 @@ if should_test_connect:
     from pyspark.errors.exceptions.connect import SparkConnectException
 
 
-class SparkConnectColumnTests(SparkConnectSQLTestCase):
-    def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
-        from pyspark.sql.classic.dataframe import DataFrame as SDF
-        from pyspark.sql.connect.dataframe import DataFrame as CDF
-
-        assert isinstance(df1, (SDF, CDF))
-        if isinstance(df1, SDF):
-            str1 = df1._jdf.showString(n, truncate, False)
-        else:
-            str1 = df1._show_string(n, truncate, False)
-
-        assert isinstance(df2, (SDF, CDF))
-        if isinstance(df2, SDF):
-            str2 = df2._jdf.showString(n, truncate, False)
-        else:
-            str2 = df2._show_string(n, truncate, False)
-
-        self.assertEqual(str1, str2)
-
+class SparkConnectColumnTests(ReusedMixedTestCase, PandasOnSparkTestUtils):
     def test_column_operator(self):
         # SPARK-41351: Column needs to support !=
         df = self.connect.range(10)
@@ -89,8 +70,9 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
 
     def test_columns(self):
         # SPARK-41036: test `columns` API for python client.
-        df = self.connect.read.table(self.tbl_name)
-        df2 = self.spark.read.table(self.tbl_name)
+        query = "SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)"
+        df = self.connect.sql(query)
+        df2 = self.spark.sql(query)
         self.assertEqual(["id", "name"], df.columns)
 
         self.assert_eq(
@@ -372,7 +354,8 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
 
     def test_simple_binary_expressions(self):
         """Test complex expression"""
-        cdf = self.connect.read.table(self.tbl_name)
+        query = "SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)"
+        cdf = self.connect.sql(query)
         pdf = (
            cdf.select(cdf.id).where(cdf.id % CF.lit(30) == CF.lit(0)).sort(cdf.id.asc()).toPandas()
         )
@@ -534,8 +517,9 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
 
     def test_cast(self):
         # SPARK-41412: test basic Column.cast
-        df = self.connect.read.table(self.tbl_name)
-        df2 = self.spark.read.table(self.tbl_name)
+        query = "SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)"
+        df = self.connect.sql(query)
+        df2 = self.spark.sql(query)
 
         self.assert_eq(
            df.select(df.id.cast("string")).toPandas(), df2.select(df2.id.cast("string")).toPandas()
diff --git a/python/pyspark/sql/tests/connect/test_connect_creation.py b/python/pyspark/sql/tests/connect/test_connect_creation.py
index 3d67c33a5834..26e3596fc67d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_creation.py
+++ b/python/pyspark/sql/tests/connect/test_connect_creation.py
@@ -33,9 +33,8 @@ from pyspark.sql.types import (
     Row,
 )
 from pyspark.testing.objects import MyObject, PythonOnlyUDT
-
-from pyspark.testing.connectutils import should_test_connect
-from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
+from pyspark.testing.connectutils import should_test_connect, ReusedMixedTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 
 if should_test_connect:
     import pandas as pd
@@ -45,7 +44,7 @@ if should_test_connect:
     from pyspark.errors.exceptions.connect import ParseException
 
 
-class SparkConnectCreationTests(SparkConnectSQLTestCase):
+class SparkConnectCreationTests(ReusedMixedTestCase, PandasOnSparkTestUtils):
     def test_with_local_data(self):
         """SPARK-41114: Test creating a dataframe using local data"""
         pdf = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
diff --git a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
index c4c10c963a48..76007137bc7a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
+++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -17,11 +17,19 @@
 
 import unittest
 
-from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DoubleType
+from pyspark.sql.types import (
+    StructType,
+    StructField,
+    StringType,
+    IntegerType,
+    LongType,
+    DoubleType,
+    Row,
+)
 from pyspark.sql.utils import is_remote
 from pyspark.sql import functions as SF
-from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
-from pyspark.testing.connectutils import should_test_connect
+from pyspark.testing.connectutils import should_test_connect, ReusedMixedTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 from pyspark.testing.sqlutils import (
     have_pandas,
     have_pyarrow,
@@ -40,7 +48,7 @@ if should_test_connect:
     from pyspark.sql.connect import functions as CF
 
 
-class SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase):
+class SparkConnectDataFramePropertyTests(ReusedMixedTestCase, PandasOnSparkTestUtils):
     def test_cached_property_is_copied(self):
         schema = StructType(
             [
@@ -65,8 +73,9 @@ class SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase):
         assert len(df.columns) == 4
 
     def test_cached_schema_to(self):
-        cdf = self.connect.read.table(self.tbl_name)
-        sdf = self.spark.read.table(self.tbl_name)
+        rows = [Row(id=x, name=str(x)) for x in range(100)]
+        cdf = self.connect.createDataFrame(rows)
+        sdf = self.spark.createDataFrame(rows)
 
         schema = StructType(
             [

