This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 54eb1a2f863b [MINOR][PYTHON][TESTS] Clean up a group of connect tests
54eb1a2f863b is described below

commit 54eb1a2f863bd7d8706c5c9a568895adb026c78d
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Apr 30 07:13:49 2025 +0800

    [MINOR][PYTHON][TESTS] Clean up a group of connect tests

    ### What changes were proposed in this pull request?
    Clean up a group of connect tests

    ### Why are the changes needed?
    to avoid expensive and unnecessary setup

    ### Does this PR introduce _any_ user-facing change?
    no, test-only

    ### How was this patch tested?
    CI

    ### Was this patch authored or co-authored using generative AI tooling?
    no

    Closes #50748 from zhengruifeng/test-connect-col.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../sql/tests/connect/test_connect_collection.py  | 23 ++++++-------
 .../sql/tests/connect/test_connect_column.py      | 38 +++++++---------------
 .../sql/tests/connect/test_connect_creation.py    |  7 ++--
 .../connect/test_connect_dataframe_property.py    | 21 ++++++++----
 4 files changed, 41 insertions(+), 48 deletions(-)

diff --git a/python/pyspark/sql/tests/connect/test_connect_collection.py b/python/pyspark/sql/tests/connect/test_connect_collection.py
index 9fe7ce1baa4d..61932c38733b 100644
--- a/python/pyspark/sql/tests/connect/test_connect_collection.py
+++ b/python/pyspark/sql/tests/connect/test_connect_collection.py
@@ -16,18 +16,19 @@
 #
 import unittest

-from pyspark.testing.connectutils import should_test_connect
-from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
+from pyspark.testing.connectutils import should_test_connect, ReusedMixedTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils

 if should_test_connect:
     from pyspark.sql import functions as SF
     from pyspark.sql.connect import functions as CF


-class SparkConnectCollectionTests(SparkConnectSQLTestCase):
+class SparkConnectCollectionTests(ReusedMixedTestCase, PandasOnSparkTestUtils):
     def test_collect(self):
-        cdf = self.connect.read.table(self.tbl_name)
-        sdf = self.spark.read.table(self.tbl_name)
+        query = "SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)"
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)

         data = cdf.limit(10).collect()
         self.assertEqual(len(data), 10)
@@ -73,25 +74,25 @@ class SparkConnectCollectionTests(SparkConnectSQLTestCase):

     def test_head(self):
         # SPARK-41002: test `head` API in Python Client
-        df = self.connect.read.table(self.tbl_name)
+        df = self.connect.sql("SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)")
         self.assertIsNotNone(len(df.head()))
         self.assertIsNotNone(len(df.head(1)))
         self.assertIsNotNone(len(df.head(5)))
-        df2 = self.connect.read.table(self.tbl_name_empty)
+        df2 = self.connect.sql("SELECT '' AS x LIMIT 0")
         self.assertIsNone(df2.head())

     def test_first(self):
         # SPARK-41002: test `first` API in Python Client
-        df = self.connect.read.table(self.tbl_name)
+        df = self.connect.sql("SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)")
         self.assertIsNotNone(len(df.first()))
-        df2 = self.connect.read.table(self.tbl_name_empty)
+        df2 = self.connect.sql("SELECT '' AS x LIMIT 0")
         self.assertIsNone(df2.first())

     def test_take(self) -> None:
         # SPARK-41002: test `take` API in Python Client
-        df = self.connect.read.table(self.tbl_name)
+        df = self.connect.sql("SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)")
         self.assertEqual(5, len(df.take(5)))
-        df2 = self.connect.read.table(self.tbl_name_empty)
+        df2 = self.connect.sql("SELECT '' AS x LIMIT 0")
         self.assertEqual(0, len(df2.take(5)))

     def test_to_pandas(self):
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index 60ddcb6f22a5..4873006fbbb9 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -40,9 +40,8 @@ from pyspark.sql.types import (
     BooleanType,
 )
 from pyspark.errors import PySparkTypeError, PySparkValueError
-from pyspark.testing.connectutils import should_test_connect
-from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
-
+from pyspark.testing.connectutils import should_test_connect, ReusedMixedTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils

 if should_test_connect:
     import pandas as pd
@@ -63,25 +62,7 @@ if should_test_connect:
     from pyspark.errors.exceptions.connect import SparkConnectException


-class SparkConnectColumnTests(SparkConnectSQLTestCase):
-    def compare_by_show(self, df1, df2, n: int = 20, truncate: int = 20):
-        from pyspark.sql.classic.dataframe import DataFrame as SDF
-        from pyspark.sql.connect.dataframe import DataFrame as CDF
-
-        assert isinstance(df1, (SDF, CDF))
-        if isinstance(df1, SDF):
-            str1 = df1._jdf.showString(n, truncate, False)
-        else:
-            str1 = df1._show_string(n, truncate, False)
-
-        assert isinstance(df2, (SDF, CDF))
-        if isinstance(df2, SDF):
-            str2 = df2._jdf.showString(n, truncate, False)
-        else:
-            str2 = df2._show_string(n, truncate, False)
-
-        self.assertEqual(str1, str2)
-
+class SparkConnectColumnTests(ReusedMixedTestCase, PandasOnSparkTestUtils):
     def test_column_operator(self):
         # SPARK-41351: Column needs to support !=
         df = self.connect.range(10)
@@ -89,8 +70,9 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):

     def test_columns(self):
         # SPARK-41036: test `columns` API for python client.
-        df = self.connect.read.table(self.tbl_name)
-        df2 = self.spark.read.table(self.tbl_name)
+        query = "SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)"
+        df = self.connect.sql(query)
+        df2 = self.spark.sql(query)

         self.assertEqual(["id", "name"], df.columns)
         self.assert_eq(
@@ -372,7 +354,8 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):

     def test_simple_binary_expressions(self):
         """Test complex expression"""
-        cdf = self.connect.read.table(self.tbl_name)
+        query = "SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)"
+        cdf = self.connect.sql(query)
         pdf = (
             cdf.select(cdf.id).where(cdf.id % CF.lit(30) == CF.lit(0)).sort(cdf.id.asc()).toPandas()
         )
@@ -534,8 +517,9 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):

     def test_cast(self):
         # SPARK-41412: test basic Column.cast
-        df = self.connect.read.table(self.tbl_name)
-        df2 = self.spark.read.table(self.tbl_name)
+        query = "SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)"
+        df = self.connect.sql(query)
+        df2 = self.spark.sql(query)

         self.assert_eq(
             df.select(df.id.cast("string")).toPandas(), df2.select(df2.id.cast("string")).toPandas()
diff --git a/python/pyspark/sql/tests/connect/test_connect_creation.py b/python/pyspark/sql/tests/connect/test_connect_creation.py
index 3d67c33a5834..26e3596fc67d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_creation.py
+++ b/python/pyspark/sql/tests/connect/test_connect_creation.py
@@ -33,9 +33,8 @@ from pyspark.sql.types import (
     Row,
 )
 from pyspark.testing.objects import MyObject, PythonOnlyUDT
-
-from pyspark.testing.connectutils import should_test_connect
-from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
+from pyspark.testing.connectutils import should_test_connect, ReusedMixedTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils

 if should_test_connect:
     import pandas as pd
@@ -45,7 +44,7 @@ if should_test_connect:
     from pyspark.errors.exceptions.connect import ParseException


-class SparkConnectCreationTests(SparkConnectSQLTestCase):
+class SparkConnectCreationTests(ReusedMixedTestCase, PandasOnSparkTestUtils):
     def test_with_local_data(self):
         """SPARK-41114: Test creating a dataframe using local data"""
         pdf = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
diff --git a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
index c4c10c963a48..76007137bc7a 100644
--- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
+++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -17,11 +17,19 @@

 import unittest

-from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DoubleType
+from pyspark.sql.types import (
+    StructType,
+    StructField,
+    StringType,
+    IntegerType,
+    LongType,
+    DoubleType,
+    Row,
+)
 from pyspark.sql.utils import is_remote
 from pyspark.sql import functions as SF
-from pyspark.sql.tests.connect.test_connect_basic import SparkConnectSQLTestCase
-from pyspark.testing.connectutils import should_test_connect
+from pyspark.testing.connectutils import should_test_connect, ReusedMixedTestCase
+from pyspark.testing.pandasutils import PandasOnSparkTestUtils
 from pyspark.testing.sqlutils import (
     have_pandas,
     have_pyarrow,
@@ -40,7 +48,7 @@ if should_test_connect:
     from pyspark.sql.connect import functions as CF


-class SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase):
+class SparkConnectDataFramePropertyTests(ReusedMixedTestCase, PandasOnSparkTestUtils):
     def test_cached_property_is_copied(self):
         schema = StructType(
             [
@@ -65,8 +73,9 @@ class SparkConnectDataFramePropertyTests(SparkConnectSQLTestCase):
         assert len(df.columns) == 4

     def test_cached_schema_to(self):
-        cdf = self.connect.read.table(self.tbl_name)
-        sdf = self.spark.read.table(self.tbl_name)
+        rows = [Row(id=x, name=str(x)) for x in range(100)]
+        cdf = self.connect.createDataFrame(rows)
+        sdf = self.spark.createDataFrame(rows)

         schema = StructType(
             [
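
The essence of the cleanup, restated: each test now derives its input from an inline query (or a local list of Rows) instead of reading a shared, pre-created test table, which lets these classes drop the heavier SparkConnectSQLTestCase base whose setup materializes those tables. Below is a minimal, self-contained sketch of the same pattern; it is illustrative only, substitutes a plain unittest.TestCase with a locally built SparkSession for the ReusedMixedTestCase fixture, and its class and method names are hypothetical:

    import unittest

    from pyspark.sql import Row, SparkSession


    class InlineDataExample(unittest.TestCase):
        # Stand-in for the shared-session fixture: one local session per
        # class, and no per-class test table to create, populate, or drop.
        @classmethod
        def setUpClass(cls):
            cls.spark = SparkSession.builder.master("local[2]").getOrCreate()

        @classmethod
        def tearDownClass(cls):
            cls.spark.stop()

        def test_take_inline_query(self):
            # Input derived from an inline query, mirroring the cleaned-up
            # tests, rather than spark.read.table(<table built during setup>).
            df = self.spark.sql("SELECT id, CAST(id AS STRING) AS name FROM RANGE(100)")
            self.assertEqual(5, len(df.take(5)))

            # An empty frame is likewise built inline (replacing tbl_name_empty).
            df2 = self.spark.sql("SELECT '' AS x LIMIT 0")
            self.assertEqual(0, len(df2.take(5)))

        def test_local_rows(self):
            # Local Rows work too, as in the updated test_cached_schema_to.
            rows = [Row(id=x, name=str(x)) for x in range(100)]
            sdf = self.spark.createDataFrame(rows)
            self.assertEqual(["id", "name"], sdf.columns)


    if __name__ == "__main__":
        unittest.main()

Because nothing is registered in the catalog or written to the warehouse directory, there is no table setup or teardown cost, and nothing needs cleaning up between test runs.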