This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 74c043d697b [SPARK-41884][CONNECT] Support naive tuple as a nested row
74c043d697b is described below

commit 74c043d697bede8e22bd14de5ed29684614188fc
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Fri Jan 20 15:27:10 2023 +0900

    [SPARK-41884][CONNECT] Support naive tuple as a nested row
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to support a plain `tuple` as a nested row (struct value),
    matching PySpark's existing behavior.
    
    That is, the following now works in Spark Connect:
    
    ```python
    spark.createDataFrame(
        [[[("a", 2, 3.0), ("a", 2, 3.0)]], [[("b", 5, 6.0), ("b", 5, 6.0)]]],
        "array_struct_col Array<struct<col1:string, col2:long, col3:double>>"
    )
    ```
    
    ### Why are the changes needed?
    
    For feature parity in Spark Connect.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No to end users, because Spark Connect has not been released yet.
    
    ### How was this patch tested?
    
    Re-enabled the previously skipped unit test `test_to_pandas_for_array_of_struct`
    in the Spark Connect parity tests.
    
    Closes #39661 from HyukjinKwon/SPARK-41884.
    
    Authored-by: Hyukjin Kwon <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/conversion.py           |  2 +-
 .../sql/tests/connect/test_parity_dataframe.py     |  5 ++---
 python/pyspark/sql/tests/test_dataframe.py         | 22 +++++++++++++---------
 3 files changed, 16 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/connect/conversion.py 
b/python/pyspark/sql/connect/conversion.py
index cfcd94d4c6c..56712ae18f2 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -92,7 +92,7 @@ class LocalDataToArrowConversion:
                 if value is None:
                     return None
                 else:
-                    assert isinstance(value, (Row, dict)), f"{type(value)} 
{value}"
+                    assert isinstance(value, (tuple, dict)), f"{type(value)} 
{value}"
 
                     _dict = {}
                     if isinstance(value, dict):
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py 
b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index ee785ebd534..cebe501938f 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -129,10 +129,9 @@ class DataFrameParityTests(DataFrameTestsMixin, 
ReusedConnectTestCase):
     def test_to_pandas(self):
         super().test_to_pandas()
 
-    # TODO(SPARK-41884): DataFrame `toPandas` parity in return types
-    @unittest.skip("Fails in Spark Connect, should enable.")
     def test_to_pandas_for_array_of_struct(self):
-        super().test_to_pandas_for_array_of_struct()
+        # Spark Connect's implementation is based on Arrow.
+        super().check_to_pandas_for_array_of_struct(True)
 
     # TODO(SPARK-41834): Implement SparkSession.conf
     @unittest.skip("Fails in Spark Connect, should enable.")
diff --git a/python/pyspark/sql/tests/test_dataframe.py 
b/python/pyspark/sql/tests/test_dataframe.py
index 4101158105a..7ccb78856d2 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1210,6 +1210,11 @@ class DataFrameTestsMixin:
         or "Pyarrow version must be 2.0.0 or higher",
     )
     def test_to_pandas_for_array_of_struct(self):
+        for is_arrow_enabled in [True, False]:
+            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": 
is_arrow_enabled}):
+                self.check_to_pandas_for_array_of_struct(is_arrow_enabled)
+
+    def check_to_pandas_for_array_of_struct(self, is_arrow_enabled):
         # SPARK-38098: Support Array of Struct for Pandas UDFs and toPandas
         import numpy as np
         import pandas as pd
@@ -1218,15 +1223,14 @@ class DataFrameTestsMixin:
             [[[("a", 2, 3.0), ("a", 2, 3.0)]], [[("b", 5, 6.0), ("b", 5, 
6.0)]]],
             "array_struct_col Array<struct<col1:string, col2:long, 
col3:double>>",
         )
-        for is_arrow_enabled in [True, False]:
-            with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": 
is_arrow_enabled}):
-                pdf = df.toPandas()
-                self.assertEqual(type(pdf), pd.DataFrame)
-                self.assertEqual(type(pdf["array_struct_col"]), pd.Series)
-                if is_arrow_enabled:
-                    self.assertEqual(type(pdf["array_struct_col"][0]), 
np.ndarray)
-                else:
-                    self.assertEqual(type(pdf["array_struct_col"][0]), list)
+
+        pdf = df.toPandas()
+        self.assertEqual(type(pdf), pd.DataFrame)
+        self.assertEqual(type(pdf["array_struct_col"]), pd.Series)
+        if is_arrow_enabled:
+            self.assertEqual(type(pdf["array_struct_col"][0]), np.ndarray)
+        else:
+            self.assertEqual(type(pdf["array_struct_col"][0]), list)
 
     def test_create_dataframe_from_array_of_long(self):
         import array


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to