This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new a86324cb52c Revert "[SPARK-40770][PYTHON] Improved error messages for applyInPandas for schema mismatch"
a86324cb52c is described below
commit a86324cb52ce341339389e1f4079297cd2ec9d76
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Thu Feb 9 09:53:56 2023 +0900
Revert "[SPARK-40770][PYTHON] Improved error messages for applyInPandas for
schema mismatch"
This reverts commit c4c28cfbfe7b3f58f08c93d2c1cd421c302b0cd3.
---
python/pyspark/sql/pandas/serializers.py | 29 +-
.../sql/tests/pandas/test_pandas_cogrouped_map.py | 317 ++++++++-------------
.../sql/tests/pandas/test_pandas_grouped_map.py | 183 ++++--------
.../pandas/test_pandas_grouped_map_with_state.py | 2 +-
python/pyspark/sql/tests/test_arrow.py | 17 +-
python/pyspark/worker.py | 108 +++----
6 files changed, 232 insertions(+), 424 deletions(-)
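For context, a minimal illustrative sketch (not part of this commit) of the applyInPandas schema mismatch these messages describe; the session, data, and column names below are assumptions, not code from the patch. With the pre-SPARK-40770 checks restored by this revert, the worker reports only the column count, e.g. "Number of columns of the returned pandas.DataFrame doesn't match specified schema. Expected: 2 Actual: 1":

    # Illustrative only: a grouped applyInPandas returning fewer columns than the schema declares.
    import pandas as pd
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

    def means(key, pdf):
        # one unnamed column, while the schema below declares two
        return pd.DataFrame([key])

    df.groupby("id").applyInPandas(means, schema="id long, mean double").collect()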
diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 30c2d102456..ca249c75ea5 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -231,25 +231,18 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
s = s.astype(s.dtypes.categories.dtype)
try:
array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
- except TypeError as e:
- error_msg = (
- "Exception thrown when converting pandas.Series (%s) "
- "with name '%s' to Arrow Array (%s)."
- )
- raise TypeError(error_msg % (s.dtype, s.name, t)) from e
except ValueError as e:
- error_msg = (
- "Exception thrown when converting pandas.Series (%s) "
- "with name '%s' to Arrow Array (%s)."
- )
if self._safecheck:
- error_msg = error_msg + (
- " It can be caused by overflows or other "
- "unsafe conversions warned by Arrow. Arrow safe type check "
- "can be disabled by using SQL config "
- "`spark.sql.execution.pandas.convertToArrowArraySafely`."
+ error_msg = (
+ "Exception thrown when converting pandas.Series (%s) to "
+ + "Arrow Array (%s). It can be caused by overflows or other "
+ + "unsafe conversions warned by Arrow. Arrow safe type check "
+ + "can be disabled by using SQL config "
+ + "`spark.sql.execution.pandas.convertToArrowArraySafely`."
)
- raise ValueError(error_msg % (s.dtype, s.name, t)) from e
+ raise ValueError(error_msg % (s.dtype, t)) from e
+ else:
+ raise e
return array
arrs = []
@@ -272,9 +265,7 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
# Assign result columns by position
else:
arrs_names = [
- # the selected series has name '1', so we rename it to field.name
- # as the name is used by create_array to provide a meaningful error message
- (create_array(s[s.columns[i]].rename(field.name), field.type), field.name)
+ (create_array(s[s.columns[i]], field.type), field.name)
for i, field in enumerate(t)
]
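The restored create_array path above re-raises pyarrow conversion failures as a plain ValueError keyed on the Series dtype and target Arrow type (dropping the Series name added by SPARK-40770). A minimal sketch of the kind of failure it handles, assuming pyarrow is installed; ArrowInvalid is a ValueError subclass, so the except-branch above catches it when safecheck is enabled:

    # Illustrative only: an unsafe float -> int cast rejected by Arrow's safe type check.
    import pandas as pd
    import pyarrow as pa

    s = pd.Series([1.5, 2.7])
    try:
        pa.Array.from_pandas(s, type=pa.int64(), safe=True)  # truncation is unsafe
    except ValueError:
        print("Exception thrown when converting pandas.Series (%s) to Arrow Array (%s)." % (s.dtype, pa.int64()))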
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index 47ed12d2f46..5cbc9e1caa4 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -43,7 +43,7 @@ if have_pyarrow:
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
-class CogroupedApplyInPandasTests(ReusedSQLTestCase):
+class CogroupedMapInPandasTests(ReusedSQLTestCase):
@property
def data1(self):
return (
@@ -79,9 +79,7 @@ class CogroupedApplyInPandasTests(ReusedSQLTestCase):
def test_different_schemas(self):
right = self.data2.withColumn("v3", lit("a"))
- self._test_merge(
- self.data1, right, output_schema="id long, k int, v int, v2 int,
v3 string"
- )
+ self._test_merge(self.data1, right, "id long, k int, v int, v2 int, v3
string")
def test_different_keys(self):
left = self.data1
@@ -130,7 +128,26 @@ class CogroupedApplyInPandasTests(ReusedSQLTestCase):
assert_frame_equal(expected, result)
def test_empty_group_by(self):
- self._test_merge(self.data1, self.data2, by=[])
+ left = self.data1
+ right = self.data2
+
+ def merge_pandas(lft, rgt):
+ return pd.merge(lft, rgt, on=["id", "k"])
+
+ result = (
+ left.groupby()
+ .cogroup(right.groupby())
+ .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
+ .sort(["id", "k"])
+ .toPandas()
+ )
+
+ left = left.toPandas()
+ right = right.toPandas()
+
+ expected = pd.merge(left, right, on=["id", "k"]).sort_values(by=["id",
"k"])
+
+ assert_frame_equal(expected, result)
def test_different_group_key_cardinality(self):
left = self.data1
@@ -149,35 +166,29 @@ class CogroupedApplyInPandasTests(ReusedSQLTestCase):
)
def test_apply_in_pandas_not_returning_pandas_dataframe(self):
- self._test_merge_error(
- fn=lambda lft, rgt: lft.size + rgt.size,
- error_class=PythonException,
- error_message_regex="Return type of the user-defined function "
- "should be pandas.DataFrame, but is <class 'numpy.int64'>",
- )
-
- def test_apply_in_pandas_returning_column_names(self):
- self._test_merge(fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id",
"k"]))
+ left = self.data1
+ right = self.data2
- def test_apply_in_pandas_returning_no_column_names(self):
def merge_pandas(lft, rgt):
- res = pd.merge(lft, rgt, on=["id", "k"])
- res.columns = range(res.columns.size)
- return res
-
- self._test_merge(fn=merge_pandas)
+ return lft.size + rgt.size
- def test_apply_in_pandas_returning_column_names_sometimes(self):
- def merge_pandas(lft, rgt):
- res = pd.merge(lft, rgt, on=["id", "k"])
- if 0 in lft["id"] and lft["id"][0] % 2 == 0:
- return res
- res.columns = range(res.columns.size)
- return res
+ with QuietTest(self.sc):
+ with self.assertRaisesRegex(
+ PythonException,
+ "Return type of the user-defined function should be
pandas.DataFrame, "
+ "but is <class 'numpy.int64'>",
+ ):
+ (
+ left.groupby("id")
+ .cogroup(right.groupby("id"))
+ .applyInPandas(merge_pandas, "id long, k int, v int, v2
int")
+ .collect()
+ )
- self._test_merge(fn=merge_pandas)
+ def test_apply_in_pandas_returning_wrong_number_of_columns(self):
+ left = self.data1
+ right = self.data2
- def test_apply_in_pandas_returning_wrong_column_names(self):
def merge_pandas(lft, rgt):
if 0 in lft["id"] and lft["id"][0] % 2 == 0:
lft["add"] = 0
@@ -185,77 +196,70 @@ class CogroupedApplyInPandasTests(ReusedSQLTestCase):
rgt["more"] = 1
return pd.merge(lft, rgt, on=["id", "k"])
- self._test_merge_error(
- fn=merge_pandas,
- error_class=PythonException,
- error_message_regex="Column names of the returned pandas.DataFrame
"
- "do not match specified schema. Unexpected: add, more.\n",
- )
+ with QuietTest(self.sc):
+ with self.assertRaisesRegex(
+ PythonException,
+ "Number of columns of the returned pandas.DataFrame "
+ "doesn't match specified schema. Expected: 4 Actual: 6",
+ ):
+ (
+ # merge_pandas returns two columns for even keys while we set schema to four
+ left.groupby("id")
+ .cogroup(right.groupby("id"))
+ .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
+ .collect()
+ )
+
+ def test_apply_in_pandas_returning_empty_dataframe(self):
+ left = self.data1
+ right = self.data2
- def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
def merge_pandas(lft, rgt):
if 0 in lft["id"] and lft["id"][0] % 2 == 0:
- lft[3] = 0
+ return pd.DataFrame([])
if 0 in rgt["id"] and rgt["id"][0] % 3 == 0:
- rgt[3] = 1
- res = pd.merge(lft, rgt, on=["id", "k"])
- res.columns = range(res.columns.size)
- return res
-
- self._test_merge_error(
- fn=merge_pandas,
- error_class=PythonException,
- error_message_regex="Number of columns of the returned
pandas.DataFrame "
- "doesn't match specified schema. Expected: 4 Actual: 6\n",
+ return pd.DataFrame([])
+ return pd.merge(lft, rgt, on=["id", "k"])
+
+ result = (
+ left.groupby("id")
+ .cogroup(right.groupby("id"))
+ .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
+ .sort(["id", "k"])
+ .toPandas()
)
- def test_apply_in_pandas_returning_empty_dataframe(self):
+ left = left.toPandas()
+ right = right.toPandas()
+
+ expected = pd.merge(
+ left[left["id"] % 2 != 0], right[right["id"] % 3 != 0], on=["id",
"k"]
+ ).sort_values(by=["id", "k"])
+
+ assert_frame_equal(expected, result)
+
+ def test_apply_in_pandas_returning_empty_dataframe_and_wrong_number_of_columns(self):
+ left = self.data1
+ right = self.data2
+
def merge_pandas(lft, rgt):
if 0 in lft["id"] and lft["id"][0] % 2 == 0:
- return pd.DataFrame()
- if 0 in rgt["id"] and rgt["id"][0] % 3 == 0:
- return pd.DataFrame()
+ return pd.DataFrame([], columns=["id", "k"])
return pd.merge(lft, rgt, on=["id", "k"])
- self._test_merge_empty(fn=merge_pandas)
-
- def test_apply_in_pandas_returning_incompatible_type(self):
- for safely in [True, False]:
- with self.subTest(convertToArrowArraySafely=safely), self.sql_conf(
- {"spark.sql.execution.pandas.convertToArrowArraySafely":
safely}
- ), QuietTest(self.sc):
- # sometimes we see ValueErrors
- with self.subTest(convert="string to double"):
- expected = (
- r"ValueError: Exception thrown when converting
pandas.Series \(object\) "
- r"with name 'k' to Arrow Array \(double\)."
- )
- if safely:
- expected = expected + (
- " It can be caused by overflows or other "
- "unsafe conversions warned by Arrow. Arrow safe
type check "
- "can be disabled by using SQL config "
-
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
- )
- self._test_merge_error(
- fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k":
["2.0"]}),
- output_schema="id long, k double",
- error_class=PythonException,
- error_message_regex=expected,
- )
-
- # sometimes we see TypeErrors
- with self.subTest(convert="double to string"):
- expected = (
- r"TypeError: Exception thrown when converting
pandas.Series \(float64\) "
- r"with name 'k' to Arrow Array \(string\).\n"
- )
- self._test_merge_error(
- fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k":
[2.0]}),
- output_schema="id long, k string",
- error_class=PythonException,
- error_message_regex=expected,
- )
+ with QuietTest(self.sc):
+ with self.assertRaisesRegex(
+ PythonException,
+ "Number of columns of the returned pandas.DataFrame doesn't "
+ "match specified schema. Expected: 4 Actual: 2",
+ ):
+ (
+ # merge_pandas returns two columns for even keys while we set schema to four
+ left.groupby("id")
+ .cogroup(right.groupby("id"))
+ .applyInPandas(merge_pandas, "id long, k int, v int, v2 int")
+ .collect()
+ )
def test_mixed_scalar_udfs_followed_by_cogrouby_apply(self):
df = self.spark.range(0, 10).toDF("v1")
@@ -308,20 +312,23 @@ class CogroupedApplyInPandasTests(ReusedSQLTestCase):
def test_wrong_return_type(self):
# Test that we get a sensible exception invalid values passed to apply
- self._test_merge_error(
- fn=lambda l, r: l,
- output_schema="id long, v array<timestamp>",
- error_class=NotImplementedError,
- error_message_regex="Invalid return
type.*ArrayType.*TimestampType",
- )
+ left = self.data1
+ right = self.data2
+ with QuietTest(self.sc):
+ with self.assertRaisesRegex(
+ NotImplementedError, "Invalid return
type.*ArrayType.*TimestampType"
+ ):
+ left.groupby("id").cogroup(right.groupby("id")).applyInPandas(
+ lambda l, r: l, "id long, v array<timestamp>"
+ )
def test_wrong_args(self):
- self.__test_merge_error(
- fn=lambda: 1,
- output_schema=StructType([StructField("d", DoubleType())]),
- error_class=ValueError,
- error_message_regex="Invalid function",
- )
+ left = self.data1
+ right = self.data2
+ with self.assertRaisesRegex(ValueError, "Invalid function"):
+ left.groupby("id").cogroup(right.groupby("id")).applyInPandas(
+ lambda: 1, StructType([StructField("d", DoubleType())])
+ )
def test_case_insensitive_grouping_column(self):
# SPARK-31915: case-insensitive grouping column should work.
@@ -427,51 +434,15 @@ class CogroupedApplyInPandasTests(ReusedSQLTestCase):
assert_frame_equal(expected, result)
- def _test_merge_empty(self, fn):
- left = self.data1.toPandas()
- right = self.data2.toPandas()
-
- expected = pd.merge(
- left[left["id"] % 2 != 0], right[right["id"] % 3 != 0], on=["id",
"k"]
- ).sort_values(by=["id", "k"])
-
- self._test_merge(self.data1, self.data2, fn=fn, expected=expected)
-
- def _test_merge(
- self,
- left=None,
- right=None,
- by=["id"],
- fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]),
- output_schema="id long, k int, v int, v2 int",
- expected=None,
- ):
- def fn_with_key(_, lft, rgt):
- return fn(lft, rgt)
-
- # Test fn with and without key argument
- with self.subTest("without key"):
- self.__test_merge(left, right, by, fn, output_schema, expected)
- with self.subTest("with key"):
- self.__test_merge(left, right, by, fn_with_key, output_schema, expected)
-
- def __test_merge(
- self,
- left=None,
- right=None,
- by=["id"],
- fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]),
- output_schema="id long, k int, v int, v2 int",
- expected=None,
- ):
- # Test fn as is, cf. _test_merge
- left = self.data1 if left is None else left
- right = self.data2 if right is None else right
+ @staticmethod
+ def _test_merge(left, right, output_schema="id long, k int, v int, v2
int"):
+ def merge_pandas(lft, rgt):
+ return pd.merge(lft, rgt, on=["id", "k"])
result = (
- left.groupby(*by)
- .cogroup(right.groupby(*by))
- .applyInPandas(fn, output_schema)
+ left.groupby("id")
+ .cogroup(right.groupby("id"))
+ .applyInPandas(merge_pandas, output_schema)
.sort(["id", "k"])
.toPandas()
)
@@ -479,64 +450,10 @@ class CogroupedApplyInPandasTests(ReusedSQLTestCase):
left = left.toPandas()
right = right.toPandas()
- expected = (
- pd.merge(left, right, on=["id", "k"]).sort_values(by=["id", "k"])
- if expected is None
- else expected
- )
+ expected = pd.merge(left, right, on=["id", "k"]).sort_values(by=["id",
"k"])
assert_frame_equal(expected, result)
- def _test_merge_error(
- self,
- error_class,
- error_message_regex,
- left=None,
- right=None,
- by=["id"],
- fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]),
- output_schema="id long, k int, v int, v2 int",
- ):
- def fn_with_key(_, lft, rgt):
- return fn(lft, rgt)
-
- # Test fn with and without key argument
- with self.subTest("without key"):
- self.__test_merge_error(
- left=left,
- right=right,
- by=by,
- fn=fn,
- output_schema=output_schema,
- error_class=error_class,
- error_message_regex=error_message_regex,
- )
- with self.subTest("with key"):
- self.__test_merge_error(
- left=left,
- right=right,
- by=by,
- fn=fn_with_key,
- output_schema=output_schema,
- error_class=error_class,
- error_message_regex=error_message_regex,
- )
-
- def __test_merge_error(
- self,
- error_class,
- error_message_regex,
- left=None,
- right=None,
- by=["id"],
- fn=lambda lft, rgt: pd.merge(lft, rgt, on=["id", "k"]),
- output_schema="id long, k int, v int, v2 int",
- ):
- # Test fn as is, cf. _test_merge_error
- with QuietTest(self.sc):
- with self.assertRaisesRegex(error_class, error_message_regex):
- self.__test_merge(left, right, by, fn, output_schema)
-
if __name__ == "__main__":
from pyspark.sql.tests.pandas.test_pandas_cogrouped_map import * # noqa: F401
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 88e68b04303..5f103c97926 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -73,7 +73,7 @@ if have_pyarrow:
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
-class GroupedApplyInPandasTests(ReusedSQLTestCase):
+class GroupedMapInPandasTests(ReusedSQLTestCase):
@property
def data(self):
return (
@@ -270,101 +270,79 @@ class GroupedApplyInPandasTests(ReusedSQLTestCase):
assert_frame_equal(expected, result)
def test_apply_in_pandas_not_returning_pandas_dataframe(self):
+ df = self.data
+
+ def stats(key, _):
+ return key
+
with QuietTest(self.sc):
with self.assertRaisesRegex(
PythonException,
"Return type of the user-defined function should be
pandas.DataFrame, "
"but is <class 'tuple'>",
):
- self._test_apply_in_pandas(lambda key, pdf: key)
-
- @staticmethod
- def stats_with_column_names(key, pdf):
- # order of column can be different to applyInPandas schema when column names are given
- return pd.DataFrame([(pdf.v.mean(),) + key], columns=["mean", "id"])
-
- @staticmethod
- def stats_with_no_column_names(key, pdf):
- # columns must be in order of applyInPandas schema when no columns given
- return pd.DataFrame([key + (pdf.v.mean(),)])
+ df.groupby("id").applyInPandas(stats, schema="id integer, m
double").collect()
- def test_apply_in_pandas_returning_column_names(self):
- self._test_apply_in_pandas(GroupedApplyInPandasTests.stats_with_column_names)
-
- def test_apply_in_pandas_returning_no_column_names(self):
- self._test_apply_in_pandas(GroupedApplyInPandasTests.stats_with_no_column_names)
+ def test_apply_in_pandas_returning_wrong_number_of_columns(self):
+ df = self.data
- def test_apply_in_pandas_returning_column_names_sometimes(self):
def stats(key, pdf):
- if key[0] % 2:
- return GroupedApplyInPandasTests.stats_with_column_names(key, pdf)
- else:
- return GroupedApplyInPandasTests.stats_with_no_column_names(key, pdf)
-
- self._test_apply_in_pandas(stats)
+ v = pdf.v
+ # returning three columns
+ res = pd.DataFrame([key + (v.mean(), v.std())])
+ return res
- def test_apply_in_pandas_returning_wrong_column_names(self):
with QuietTest(self.sc):
with self.assertRaisesRegex(
PythonException,
- "Column names of the returned pandas.DataFrame do not match
specified schema. "
- "Missing: mean. Unexpected: median, std.\n",
+ "Number of columns of the returned pandas.DataFrame doesn't
match "
+ "specified schema. Expected: 2 Actual: 3",
):
- self._test_apply_in_pandas(
- lambda key, pdf: pd.DataFrame(
- [key + (pdf.v.median(), pdf.v.std())], columns=["id",
"median", "std"]
- )
- )
+ # stats returns three columns while here we set schema with two columns
+ df.groupby("id").applyInPandas(stats, schema="id integer, m double").collect()
+
+ def test_apply_in_pandas_returning_empty_dataframe(self):
+ df = self.data
+
+ def odd_means(key, pdf):
+ if key[0] % 2 == 0:
+ return pd.DataFrame([])
+ else:
+ return pd.DataFrame([key + (pdf.v.mean(),)])
+
+ expected_ids = {row[0] for row in self.data.collect() if row[0] % 2 != 0}
+
+ result = (
+ df.groupby("id")
+ .applyInPandas(odd_means, schema="id integer, m double")
+ .sort("id", "m")
+ .collect()
+ )
+
+ actual_ids = {row[0] for row in result}
+ self.assertSetEqual(expected_ids, actual_ids)
+
+ self.assertEqual(len(expected_ids), len(result))
+ for row in result:
+ self.assertEqual(24.5, row[1])
+
+ def test_apply_in_pandas_returning_empty_dataframe_and_wrong_number_of_columns(self):
+ df = self.data
+
+ def odd_means(key, pdf):
+ if key[0] % 2 == 0:
+ return pd.DataFrame([], columns=["id"])
+ else:
+ return pd.DataFrame([key + (pdf.v.mean(),)])
- def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
with QuietTest(self.sc):
with self.assertRaisesRegex(
PythonException,
"Number of columns of the returned pandas.DataFrame doesn't
match "
- "specified schema. Expected: 2 Actual: 3\n",
+ "specified schema. Expected: 2 Actual: 1",
):
- self._test_apply_in_pandas(
- lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(), pdf.v.std())])
- )
-
- def test_apply_in_pandas_returning_empty_dataframe(self):
- self._test_apply_in_pandas_returning_empty_dataframe(pd.DataFrame())
-
- def test_apply_in_pandas_returning_incompatible_type(self):
- for safely in [True, False]:
- with self.subTest(convertToArrowArraySafely=safely), self.sql_conf(
- {"spark.sql.execution.pandas.convertToArrowArraySafely":
safely}
- ), QuietTest(self.sc):
- # sometimes we see ValueErrors
- with self.subTest(convert="string to double"):
- expected = (
- r"ValueError: Exception thrown when converting
pandas.Series \(object\) "
- r"with name 'mean' to Arrow Array \(double\)."
- )
- if safely:
- expected = expected + (
- " It can be caused by overflows or other "
- "unsafe conversions warned by Arrow. Arrow safe
type check "
- "can be disabled by using SQL config "
-
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
- )
- with self.assertRaisesRegex(PythonException, expected + "\n"):
- self._test_apply_in_pandas(
- lambda key, pdf: pd.DataFrame([key + (str(pdf.v.mean()),)]),
- output_schema="id long, mean double",
- )
-
- # sometimes we see TypeErrors
- with self.subTest(convert="double to string"):
- with self.assertRaisesRegex(
- PythonException,
- r"TypeError: Exception thrown when converting
pandas.Series \(float64\) "
- r"with name 'mean' to Arrow Array \(string\).\n",
- ):
- self._test_apply_in_pandas(
- lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(),)]),
- output_schema="id long, mean string",
- )
+ # stats returns one column for even keys while here we set schema with two columns
+ df.groupby("id").applyInPandas(odd_means, schema="id integer, m double").collect()
def test_datatype_string(self):
df = self.data
@@ -588,11 +566,7 @@ class GroupedApplyInPandasTests(ReusedSQLTestCase):
with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
with QuietTest(self.sc):
- with self.assertRaisesRegex(
- PythonException,
- "RuntimeError: Column names of the returned
pandas.DataFrame do not match "
- "specified schema. Missing: id. Unexpected: iid.\n",
- ):
+ with self.assertRaisesRegex(Exception, "KeyError: 'id'"):
grouped_df.apply(column_name_typo).collect()
with self.assertRaisesRegex(Exception, "[D|d]ecimal.*got.*date"):
grouped_df.apply(invalid_positional_types).collect()
@@ -681,11 +655,10 @@ class GroupedApplyInPandasTests(ReusedSQLTestCase):
df.groupby("group", window("ts", "5 days"))
.applyInPandas(f, df.schema)
.select("id", "result")
- .orderBy("id")
.collect()
)
-
- self.assertListEqual([Row(id=key, result=val) for key, val in expected.items()], result)
+ for r in result:
+ self.assertListEqual(expected[r[0]], r[1])
def test_grouped_over_window_with_key(self):
@@ -747,11 +720,11 @@ class GroupedApplyInPandasTests(ReusedSQLTestCase):
df.groupby("group", window("ts", "5 days"))
.applyInPandas(f, df.schema)
.select("id", "result")
- .orderBy("id")
.collect()
)
- self.assertListEqual([Row(id=key, result=val) for key, val in expected.items()], result)
+ for r in result:
+ self.assertListEqual(expected[r[0]], r[1])
def test_case_insensitive_grouping_column(self):
# SPARK-31915: case-insensitive grouping column should work.
@@ -766,44 +739,6 @@ class GroupedApplyInPandasTests(ReusedSQLTestCase):
)
self.assertEqual(row.asDict(), Row(column=1, score=0.5).asDict())
- def _test_apply_in_pandas(self, f, output_schema="id long, mean double"):
- df = self.data
-
- result = (
- df.groupby("id").applyInPandas(f, schema=output_schema).sort("id",
"mean").toPandas()
- )
- expected = df.select("id").distinct().withColumn("mean",
lit(24.5)).toPandas()
-
- assert_frame_equal(expected, result)
-
- def _test_apply_in_pandas_returning_empty_dataframe(self, empty_df):
- """Tests some returned DataFrames are empty."""
- df = self.data
-
- def stats(key, pdf):
- if key[0] % 2 == 0:
- return GroupedApplyInPandasTests.stats_with_no_column_names(key, pdf)
- return empty_df
-
- result = (
- df.groupby("id")
- .applyInPandas(stats, schema="id long, mean double")
- .sort("id", "mean")
- .collect()
- )
-
- actual_ids = {row[0] for row in result}
- expected_ids = {row[0] for row in self.data.collect() if row[0] % 2 == 0}
- self.assertSetEqual(expected_ids, actual_ids)
- self.assertEqual(len(expected_ids), len(result))
- for row in result:
- self.assertEqual(24.5, row[1])
-
- def _test_apply_in_pandas_returning_empty_dataframe_error(self, empty_df, error):
- with QuietTest(self.sc):
- with self.assertRaisesRegex(PythonException, error):
- self._test_apply_in_pandas_returning_empty_dataframe(empty_df)
-
if __name__ == "__main__":
from pyspark.sql.tests.pandas.test_pandas_grouped_map import * # noqa: F401
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
index 9600d1e3445..655f0bf151d 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map_with_state.py
@@ -53,7 +53,7 @@ if have_pyarrow:
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
-class GroupedApplyInPandasWithStateTests(ReusedSQLTestCase):
+class GroupedMapInPandasWithStateTests(ReusedSQLTestCase):
@classmethod
def conf(cls):
cfg = SparkConf()
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index c61994380e6..6083f31ac81 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -465,24 +465,9 @@ class ArrowTests(ReusedSQLTestCase):
wrong_schema = StructType(fields)
with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
with QuietTest(self.sc):
- with self.assertRaises(Exception) as context:
+ with self.assertRaisesRegex(Exception, "[D|d]ecimal.*got.*date"):
self.spark.createDataFrame(pdf, schema=wrong_schema)
- # the exception provides us with the column that is incorrect
- exception = context.exception
- self.assertTrue(hasattr(exception, "args"))
- self.assertEqual(len(exception.args), 1)
- self.assertRegex(
- exception.args[0],
- "with name '7_date_t' " "to Arrow Array
\\(decimal128\\(38, 18\\)\\)",
- )
-
- # the inner exception provides us with the incorrect types
- exception = exception.__context__
- self.assertTrue(hasattr(exception, "args"))
- self.assertEqual(len(exception.args), 1)
- self.assertRegex(exception.args[0], "[D|d]ecimal.*got.*date")
-
def test_createDataFrame_with_names(self):
pdf = self.create_pandas_data_frame()
new_names = list(map(str, range(len(self.schema.fieldNames()))))
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index f7d98a9a18c..c1c3669701f 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -146,49 +146,7 @@ def wrap_batch_iter_udf(f, return_type):
)
-def verify_pandas_result(result, return_type, assign_cols_by_name):
- import pandas as pd
-
- if not isinstance(result, pd.DataFrame):
- raise TypeError(
- "Return type of the user-defined function should be "
- "pandas.DataFrame, but is {}".format(type(result))
- )
-
- # check the schema of the result only if it is not empty or has columns
- if not result.empty or len(result.columns) != 0:
- # if any column name of the result is a string
- # the column names of the result have to match the return type
- # see create_array in pyspark.sql.pandas.serializers.ArrowStreamPandasSerializer
- field_names = set([field.name for field in return_type.fields])
- column_names = set(result.columns)
- if (
- assign_cols_by_name
- and any(isinstance(name, str) for name in result.columns)
- and column_names != field_names
- ):
- missing = sorted(list(field_names.difference(column_names)))
- missing = f" Missing: {', '.join(missing)}." if missing else ""
-
- extra = sorted(list(column_names.difference(field_names)))
- extra = f" Unexpected: {', '.join(extra)}." if extra else ""
-
- raise RuntimeError(
- "Column names of the returned pandas.DataFrame do not match
specified schema."
- "{}{}".format(missing, extra)
- )
- # otherwise the number of columns of result have to match the return type
- elif len(result.columns) != len(return_type):
- raise RuntimeError(
- "Number of columns of the returned pandas.DataFrame "
- "doesn't match specified schema. "
- "Expected: {} Actual: {}".format(len(return_type),
len(result.columns))
- )
-
-
-def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
- _assign_cols_by_name = assign_cols_by_name(runner_conf)
-
+def wrap_cogrouped_map_pandas_udf(f, return_type, argspec):
def wrapped(left_key_series, left_value_series, right_key_series, right_value_series):
import pandas as pd
@@ -201,16 +159,27 @@ def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf):
key_series = left_key_series if not left_df.empty else right_key_series
key = tuple(s[0] for s in key_series)
result = f(key, left_df, right_df)
- verify_pandas_result(result, return_type, _assign_cols_by_name)
-
+ if not isinstance(result, pd.DataFrame):
+ raise TypeError(
+ "Return type of the user-defined function should be "
+ "pandas.DataFrame, but is {}".format(type(result))
+ )
+ # the number of columns of result have to match the return type
+ # but it is fine for result to have no columns at all if it is empty
+ if not (
+ len(result.columns) == len(return_type) or len(result.columns) ==
0 and result.empty
+ ):
+ raise RuntimeError(
+ "Number of columns of the returned pandas.DataFrame "
+ "doesn't match specified schema. "
+ "Expected: {} Actual: {}".format(len(return_type),
len(result.columns))
+ )
return result
return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), to_arrow_type(return_type))]
-def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
- _assign_cols_by_name = assign_cols_by_name(runner_conf)
-
+def wrap_grouped_map_pandas_udf(f, return_type, argspec):
def wrapped(key_series, value_series):
import pandas as pd
@@ -219,8 +188,22 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf):
elif len(argspec.args) == 2:
key = tuple(s[0] for s in key_series)
result = f(key, pd.concat(value_series, axis=1))
- verify_pandas_result(result, return_type, _assign_cols_by_name)
+ if not isinstance(result, pd.DataFrame):
+ raise TypeError(
+ "Return type of the user-defined function should be "
+ "pandas.DataFrame, but is {}".format(type(result))
+ )
+ # the number of columns of result have to match the return type
+ # but it is fine for result to have no columns at all if it is empty
+ if not (
+ len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty
+ ):
+ raise RuntimeError(
+ "Number of columns of the returned pandas.DataFrame "
+ "doesn't match specified schema. "
+ "Expected: {} Actual: {}".format(len(return_type),
len(result.columns))
+ )
return result
return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
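A standalone sketch (hypothetical helper name, not a Spark API) of the column-count check that both wrappers above inline after this revert: an empty result with no columns at all is accepted, otherwise the number of returned columns must equal the number of schema fields:

    # Illustrative only: mirrors the inlined check from wrap_grouped_map_pandas_udf / wrap_cogrouped_map_pandas_udf.
    import pandas as pd

    def check_result_columns(result, num_fields):
        if not (len(result.columns) == num_fields or len(result.columns) == 0 and result.empty):
            raise RuntimeError(
                "Number of columns of the returned pandas.DataFrame "
                "doesn't match specified schema. "
                "Expected: {} Actual: {}".format(num_fields, len(result.columns))
            )

    check_result_columns(pd.DataFrame([]), 4)                     # ok: empty, no columns
    check_result_columns(pd.DataFrame({"id": [1], "v": [2]}), 2)  # ok: matches schema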
@@ -413,12 +396,12 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
return arg_offsets, wrap_batch_iter_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
argspec = getfullargspec(chained_func) # signature was lost when wrapping it
- return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf)
+ return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type)
elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
argspec = getfullargspec(chained_func) # signature was lost when wrapping it
- return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf)
+ return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec)
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
@@ -429,16 +412,6 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
raise ValueError("Unknown eval type: {}".format(eval_type))
-# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
-def assign_cols_by_name(runner_conf):
- return (
- runner_conf.get(
- "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true"
- ).lower()
- == "true"
- )
-
-
def read_udfs(pickleSer, infile, eval_type):
runner_conf = {}
@@ -471,9 +444,16 @@ def read_udfs(pickleSer, infile, eval_type):
runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely",
"false").lower()
== "true"
)
+ # Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF when returning StructType
+ assign_cols_by_name = (
+ runner_conf.get(
+ "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true"
+ ).lower()
+ == "true"
+ )
if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
- ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name(runner_conf))
+ ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name)
elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
arrow_max_records_per_batch = runner_conf.get(
"spark.sql.execution.arrow.maxRecordsPerBatch", 10000
@@ -483,7 +463,7 @@ def read_udfs(pickleSer, infile, eval_type):
ser = ApplyInPandasWithStateSerializer(
timezone,
safecheck,
- assign_cols_by_name(runner_conf),
+ assign_cols_by_name,
state_object_schema,
arrow_max_records_per_batch,
)
@@ -498,7 +478,7 @@ def read_udfs(pickleSer, infile, eval_type):
or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
)
ser = ArrowStreamPandasUDFSerializer(
- timezone, safecheck, assign_cols_by_name(runner_conf), df_for_struct
+ timezone, safecheck, assign_cols_by_name, df_for_struct
)
else:
ser = BatchedSerializer(CPickleSerializer(), 100)
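Finally, read_udfs now computes the column-assignment flag inline from the runner configuration instead of calling the removed assign_cols_by_name() helper. A minimal sketch of that lookup, with runner_conf as a stand-in dict rather than the worker's real config map:

    # Illustrative only: runner_conf here is a plain dict, not the deserialized worker config.
    runner_conf = {"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": "true"}
    assign_cols_by_name = (
        runner_conf.get(
            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true"
        ).lower()
        == "true"
    )
    print(assign_cols_by_name)  # True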
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]