This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 74c043d697b [SPARK-41884][CONNECT] Support naive tuple as a nested row
74c043d697b is described below
commit 74c043d697bede8e22bd14de5ed29684614188fc
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Fri Jan 20 15:27:10 2023 +0900
[SPARK-41884][CONNECT] Support naive tuple as a nested row
### What changes were proposed in this pull request?
This PR proposes to support a `tuple` as a nested row to match
PySpark's behavior.
Meaning that:
```python
spark.createDataFrame(
[[[("a", 2, 3.0), ("a", 2, 3.0)]], [[("b", 5, 6.0), ("b", 5, 6.0)]]],
"array_struct_col Array<struct<col1:string, col2:long, col3:double>>"
)
```
### Why are the changes needed?
For feature parity in Spark Connect.
### Does this PR introduce _any_ user-facing change?
No to end users because Spark Connect has not been released yet.
### How was this patch tested?
The previously skipped unit test was enabled back.
Closes #39661 from HyukjinKwon/SPARK-41884.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/conversion.py | 2 +-
.../sql/tests/connect/test_parity_dataframe.py | 5 ++---
python/pyspark/sql/tests/test_dataframe.py | 22 +++++++++++++---------
3 files changed, 16 insertions(+), 13 deletions(-)
diff --git a/python/pyspark/sql/connect/conversion.py
b/python/pyspark/sql/connect/conversion.py
index cfcd94d4c6c..56712ae18f2 100644
--- a/python/pyspark/sql/connect/conversion.py
+++ b/python/pyspark/sql/connect/conversion.py
@@ -92,7 +92,7 @@ class LocalDataToArrowConversion:
if value is None:
return None
else:
- assert isinstance(value, (Row, dict)), f"{type(value)}
{value}"
+ assert isinstance(value, (tuple, dict)), f"{type(value)}
{value}"
_dict = {}
if isinstance(value, dict):
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index ee785ebd534..cebe501938f 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -129,10 +129,9 @@ class DataFrameParityTests(DataFrameTestsMixin,
ReusedConnectTestCase):
def test_to_pandas(self):
super().test_to_pandas()
- # TODO(SPARK-41884): DataFrame `toPandas` parity in return types
- @unittest.skip("Fails in Spark Connect, should enable.")
def test_to_pandas_for_array_of_struct(self):
- super().test_to_pandas_for_array_of_struct()
+ # Spark Connect's implementation is based on Arrow.
+ super().check_to_pandas_for_array_of_struct(True)
# TODO(SPARK-41834): Implement SparkSession.conf
@unittest.skip("Fails in Spark Connect, should enable.")
diff --git a/python/pyspark/sql/tests/test_dataframe.py
b/python/pyspark/sql/tests/test_dataframe.py
index 4101158105a..7ccb78856d2 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1210,6 +1210,11 @@ class DataFrameTestsMixin:
or "Pyarrow version must be 2.0.0 or higher",
)
def test_to_pandas_for_array_of_struct(self):
+ for is_arrow_enabled in [True, False]:
+ with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled":
is_arrow_enabled}):
+ self.check_to_pandas_for_array_of_struct(is_arrow_enabled)
+
+ def check_to_pandas_for_array_of_struct(self, is_arrow_enabled):
# SPARK-38098: Support Array of Struct for Pandas UDFs and toPandas
import numpy as np
import pandas as pd
@@ -1218,15 +1223,14 @@ class DataFrameTestsMixin:
[[[("a", 2, 3.0), ("a", 2, 3.0)]], [[("b", 5, 6.0), ("b", 5,
6.0)]]],
"array_struct_col Array<struct<col1:string, col2:long,
col3:double>>",
)
- for is_arrow_enabled in [True, False]:
- with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled":
is_arrow_enabled}):
- pdf = df.toPandas()
- self.assertEqual(type(pdf), pd.DataFrame)
- self.assertEqual(type(pdf["array_struct_col"]), pd.Series)
- if is_arrow_enabled:
- self.assertEqual(type(pdf["array_struct_col"][0]),
np.ndarray)
- else:
- self.assertEqual(type(pdf["array_struct_col"][0]), list)
+
+ pdf = df.toPandas()
+ self.assertEqual(type(pdf), pd.DataFrame)
+ self.assertEqual(type(pdf["array_struct_col"]), pd.Series)
+ if is_arrow_enabled:
+ self.assertEqual(type(pdf["array_struct_col"][0]), np.ndarray)
+ else:
+ self.assertEqual(type(pdf["array_struct_col"][0]), list)
def test_create_dataframe_from_array_of_long(self):
import array
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]