This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new 6fea291 [SPARK-31186][PYSPARK][SQL] toPandas should not fail on duplicate column names
6fea291 is described below
commit 6fea291762af3e802cb4c237bdad51ebf5d7152c
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Fri Mar 27 12:10:30 2020 +0900
    [SPARK-31186][PYSPARK][SQL] toPandas should not fail on duplicate column names
### What changes were proposed in this pull request?
When the `toPandas` API works on duplicate column names produced by operators
like join, we see an error like:
```
ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
```
This patch fixes the error in `toPandas` API.
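For illustration, a minimal sketch that reproduces the duplicate-name situation (not part of the commit; assumes an active `SparkSession` named `spark` and the default non-Arrow conversion path this patch touches):
```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Two relations that both expose a column named `v`; the cross join
# keeps both, so the result has duplicate column names.
df1 = spark.range(1, 4).selectExpr("id AS v")
df2 = spark.range(1, 4).selectExpr("id AS v")
joined = df1.crossJoin(df2)

# Before this patch, the non-Arrow `toPandas` path raised the
# "truth value of a Series is ambiguous" ValueError here.
pdf = joined.toPandas()
```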
### Why are the changes needed?
To make `toPandas` work on DataFrames with duplicate column names.
### Does this PR introduce any user-facing change?
Yes. Previously, calling the `toPandas` API on a DataFrame with duplicate column
names would fail. After this patch, it produces the correct result.
### How was this patch tested?
Unit test.
Closes #28025 from viirya/SPARK-31186.
Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
(cherry picked from commit 559d3e4051500d5c49e9a7f3ac33aac3de19c9c6)
Signed-off-by: HyukjinKwon <[email protected]>
---
python/pyspark/sql/pandas/conversion.py | 48 +++++++++++++++++++++++-------
python/pyspark/sql/tests/test_dataframe.py | 18 +++++++++++
2 files changed, 56 insertions(+), 10 deletions(-)
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index 8548cd2..47cf8bb 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -21,6 +21,7 @@ if sys.version >= '3':
     xrange = range
 else:
     from itertools import izip as zip
 
+from collections import Counter
 from pyspark import since
 from pyspark.rdd import _load_from_socket
@@ -131,9 +132,16 @@ class PandasConversionMixin(object):
         # Below is toPandas without Arrow optimization.
         pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+        column_counter = Counter(self.columns)
+
+        dtype = [None] * len(self.schema)
+        for fieldIdx, field in enumerate(self.schema):
+            # For duplicate column name, we use `iloc` to access it.
+            if column_counter[field.name] > 1:
+                pandas_col = pdf.iloc[:, fieldIdx]
+            else:
+                pandas_col = pdf[field.name]
 
-        dtype = {}
-        for field in self.schema:
             pandas_type = PandasConversionMixin._to_corrected_pandas_type(field.dataType)
             # SPARK-21766: if an integer field is nullable and has null values, it can be
             # inferred by pandas as float column. Once we convert the column with NaN back
@@ -141,16 +149,36 @@ class PandasConversionMixin(object):
             # float type, not the corrected type from the schema in this case.
             if pandas_type is not None and \
                 not(isinstance(field.dataType, IntegralType) and field.nullable and
-                    pdf[field.name].isnull().any()):
-                dtype[field.name] = pandas_type
+                    pandas_col.isnull().any()):
+                dtype[fieldIdx] = pandas_type
             # Ensure we fall back to nullable numpy types, even when whole column is null:
-            if isinstance(field.dataType, IntegralType) and pdf[field.name].isnull().any():
-                dtype[field.name] = np.float64
-            if isinstance(field.dataType, BooleanType) and pdf[field.name].isnull().any():
-                dtype[field.name] = np.object
+            if isinstance(field.dataType, IntegralType) and pandas_col.isnull().any():
+                dtype[fieldIdx] = np.float64
+            if isinstance(field.dataType, BooleanType) and pandas_col.isnull().any():
+                dtype[fieldIdx] = np.object
+
+        df = pd.DataFrame()
+        for index, t in enumerate(dtype):
+            column_name = self.schema[index].name
+
+            # For duplicate column name, we use `iloc` to access it.
+            if column_counter[column_name] > 1:
+                series = pdf.iloc[:, index]
+            else:
+                series = pdf[column_name]
+
+            if t is not None:
+                series = series.astype(t, copy=False)
+
+            # `insert` API makes copy of data, we only do it for Series of duplicate column names.
+            # `pdf.iloc[:, index] = pdf.iloc[:, index]...` doesn't always work because `iloc` could
+            # return a view or a copy depending by context.
+            if column_counter[column_name] > 1:
+                df.insert(index, column_name, series, allow_duplicates=True)
+            else:
+                df[column_name] = series
 
-        for f, t in dtype.items():
-            pdf[f] = pdf[f].astype(t, copy=False)
+        pdf = df
 
         if timezone is None:
             return pdf
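The rewritten loop accesses columns by position (`iloc`) and rebuilds the result with `DataFrame.insert(..., allow_duplicates=True)` because label-based access becomes ambiguous once names repeat. A pandas-only sketch of the behavior the fix works around (illustrative data, independent of Spark):
```python
import pandas as pd

# A frame with two columns that share the name "v", as a join can produce.
pdf = pd.DataFrame([[1, 10], [2, 20]], columns=["v", "v"])

# Label-based access returns a DataFrame (both "v" columns), so using
# `pdf["v"].isnull().any()` in a boolean context raises the ambiguous
# truth-value ValueError quoted in the commit message.
print(type(pdf["v"]))        # <class 'pandas.core.frame.DataFrame'>

# Positional access always yields a single Series, which is why the
# patch switches to `iloc` for duplicated names.
print(type(pdf.iloc[:, 0]))  # <class 'pandas.core.series.Series'>

# Rebuilding the frame column by column: plain assignment would
# overwrite the earlier "v", while `insert` keeps both columns.
out = pd.DataFrame()
out.insert(0, "v", pdf.iloc[:, 0], allow_duplicates=True)
out.insert(1, "v", pdf.iloc[:, 1], allow_duplicates=True)
print(list(out.columns))     # ['v', 'v']
```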
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index d738449..d9dcbc0 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -529,6 +529,24 @@ class DataFrameTests(ReusedSQLTestCase):
         self.assertEquals(types[4], np.object)  # datetime.date
         self.assertEquals(types[5], 'datetime64[ns]')
 
+    @unittest.skipIf(not have_pandas, pandas_requirement_message)
+    def test_to_pandas_on_cross_join(self):
+        import numpy as np
+
+        sql = """
+        select t1.*, t2.* from (
+            select explode(sequence(1, 3)) v
+        ) t1 left join (
+            select explode(sequence(1, 3)) v
+        ) t2
+        """
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            df = self.spark.sql(sql)
+            pdf = df.toPandas()
+            types = pdf.dtypes
+            self.assertEquals(types.iloc[0], np.int32)
+            self.assertEquals(types.iloc[1], np.int32)
+
     @unittest.skipIf(have_pandas, "Required Pandas was found.")
     def test_to_pandas_required_pandas_not_found(self):
         with QuietTest(self.sc):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]