This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new afa7f3d1bb8 [SPARK-43323][SQL][PYTHON] Fix DataFrame.toPandas with Arrow enabled to handle exceptions properly
afa7f3d1bb8 is described below
commit afa7f3d1bb865e319b0ca7e295a9c8bf4106ae0a
Author: Takuya UESHIN <[email protected]>
AuthorDate: Tue May 2 08:16:51 2023 +0900
[SPARK-43323][SQL][PYTHON] Fix DataFrame.toPandas with Arrow enabled to handle exceptions properly
### What changes were proposed in this pull request?
Fixes `DataFrame.toPandas` with Arrow enabled to handle exceptions properly.
```py
>>> spark.conf.set("spark.sql.ansi.enabled", True)
>>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', True)
>>> spark.sql("select 1/0").toPandas()
...
Traceback (most recent call last):
...
pyspark.errors.exceptions.captured.ArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
== SQL(line 1, position 8) ==
select 1/0
       ^^^
```
### Why are the changes needed?
Currently, `DataFrame.toPandas` doesn't properly capture exceptions that happen in Spark.
```py
>>> spark.conf.set("spark.sql.ansi.enabled", True)
>>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', True)
>>> spark.sql("select 1/0").toPandas()
...
An error occurred while calling o53.getResult.
: org.apache.spark.SparkException: Exception thrown in awaitResult:
        at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:322)
...
```
because `jsocket_auth_server.getResult()` always wraps the thrown exception in a `SparkException`, which is never captured and converted into the corresponding PySpark exception; a sketch of the unwrapping idea follows the next example.
Whereas without Arrow:
```py
>>> spark.conf.set("spark.sql.ansi.enabled", True)
>>> spark.conf.set('spark.sql.execution.arrow.pyspark.enabled', False)
>>> spark.sql("select 1/0").toPandas()
Traceback (most recent call last):
...
pyspark.errors.exceptions.captured.ArithmeticException: [DIVIDE_BY_ZERO] Division by zero. Use `try_divide` to tolerate divisor being 0 and return NULL instead. If necessary set "spark.sql.ansi.enabled" to "false" to bypass this error.
== SQL(line 1, position 8) ==
select 1/0
       ^^^
```
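The fix, shown in the diff below, wraps the `getResult()` call in a small context manager that unwraps the `SparkException` and re-raises the converted cause. A minimal sketch of the idea, assuming the existing converters in `pyspark.errors.exceptions.captured` and an active gateway on `SparkContext` (a simplified illustration, not the exact implementation):
```py
from contextlib import contextmanager
from typing import Any, Iterator

from py4j.java_gateway import is_instance_of
from py4j.protocol import Py4JJavaError

from pyspark import SparkContext
from pyspark.errors.exceptions.captured import UnknownException, convert_exception


@contextmanager
def unwrap_spark_exception() -> Iterator[Any]:
    # The Py4J gateway is needed to check the Java class of the raised error.
    gw = SparkContext._gateway
    assert gw is not None
    try:
        yield
    except Py4JJavaError as e:
        je = e.java_exception
        # getResult() wraps the real error in a SparkException; unwrap it and
        # re-raise the converted cause (e.g. ArithmeticException) instead.
        if je is not None and is_instance_of(gw, je, "org.apache.spark.SparkException"):
            converted = convert_exception(je.getCause())
            if not isinstance(converted, UnknownException):
                raise converted from None
        raise
```
`DataFrame.toPandas` then runs `jsocket_auth_server.getResult()` inside this context manager, so the converted exception propagates to the caller instead of a wrapped `Py4JJavaError`.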
### Does this PR introduce _any_ user-facing change?
`DataFrame.toPandas` with Arrow enabled now raises the proper PySpark exception instead of a generic `Py4JJavaError`.
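For example, callers can now catch the specific PySpark exception. A hypothetical snippet (assumes a local `SparkSession`; `getErrorClass()` returns the error class of a captured exception):
```py
from pyspark.sql import SparkSession
from pyspark.errors import ArithmeticException

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.ansi.enabled", True)
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", True)

try:
    spark.sql("select 1/0").toPandas()
except ArithmeticException as e:
    # Previously this surfaced as a generic Py4JJavaError wrapping a SparkException;
    # now the DIVIDE_BY_ZERO error is raised as ArithmeticException.
    print(e.getErrorClass())  # expected: DIVIDE_BY_ZERO
```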
### How was this patch tested?
Added the related test.
Closes #40998 from ueshin/issues/SPARK-43323/getResult.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/errors/exceptions/captured.py | 20 ++++++++++++++++++--
python/pyspark/sql/pandas/conversion.py | 6 ++++--
.../pyspark/sql/tests/connect/test_parity_arrow.py | 3 +++
python/pyspark/sql/tests/test_arrow.py | 17 ++++++++++++++++-
4 files changed, 41 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py
index d1b57997f99..5b008f4ab00 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -14,8 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
-from typing import Any, Callable, Dict, Optional, cast
+from contextlib import contextmanager
+from typing import Any, Callable, Dict, Iterator, Optional, cast
import py4j
from py4j.protocol import Py4JJavaError
@@ -186,6 +186,22 @@ def capture_sql_exception(f: Callable[..., Any]) -> Callable[..., Any]:
     return deco
+@contextmanager
+def unwrap_spark_exception() -> Iterator[Any]:
+    assert SparkContext._gateway is not None
+
+    gw = SparkContext._gateway
+    try:
+        yield
+    except Py4JJavaError as e:
+        je: Py4JJavaError = e.java_exception
+        if je is not None and is_instance_of(gw, je, "org.apache.spark.SparkException"):
+            converted = convert_exception(je.getCause())
+            if not isinstance(converted, UnknownException):
+                raise converted from None
+        raise
+
+
def install_exception_handler() -> None:
"""
Hook an exception handler into Py4j, which could capture some SQL exceptions in Java.
diff --git a/python/pyspark/sql/pandas/conversion.py b/python/pyspark/sql/pandas/conversion.py
index a5f0664ed75..ce0143d1851 100644
--- a/python/pyspark/sql/pandas/conversion.py
+++ b/python/pyspark/sql/pandas/conversion.py
@@ -19,6 +19,7 @@ from collections import Counter
from typing import List, Optional, Type, Union, no_type_check, overload, TYPE_CHECKING
from warnings import catch_warnings, simplefilter, warn
+from pyspark.errors.exceptions.captured import unwrap_spark_exception
from pyspark.rdd import _load_from_socket
from pyspark.sql.pandas.serializers import ArrowCollectSerializer
from pyspark.sql.types import (
@@ -357,8 +358,9 @@ class PandasConversionMixin:
             else:
                 results = list(batch_stream)
         finally:
-            # Join serving thread and raise any exceptions from collectAsArrowToPython
-            jsocket_auth_server.getResult()
+            with unwrap_spark_exception():
+                # Join serving thread and raise any exceptions from collectAsArrowToPython
+                jsocket_auth_server.getResult()
 
         # Separate RecordBatches from batch order indices in results
         batches = results[:-1]
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow.py b/python/pyspark/sql/tests/connect/test_parity_arrow.py
index fd05821f052..f2fa9ece4df 100644
--- a/python/pyspark/sql/tests/connect/test_parity_arrow.py
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow.py
@@ -103,6 +103,9 @@ class ArrowParityTests(ArrowTestsMixin, ReusedConnectTestCase):
     def test_timestamp_nat(self):
         self.check_timestamp_nat(True)
 
+    def test_toPandas_error(self):
+        self.check_toPandas_error(True)
+
if __name__ == "__main__":
from pyspark.sql.tests.connect.test_parity_arrow import * # noqa: F401
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 518e17d57b6..84c782e8d95 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -55,7 +55,7 @@ from pyspark.testing.sqlutils import (
pyarrow_requirement_message,
)
from pyspark.testing.utils import QuietTest
-from pyspark.errors import PySparkTypeError
+from pyspark.errors import ArithmeticException, PySparkTypeError
if have_pandas:
import pandas as pd
@@ -873,6 +873,21 @@ class ArrowTestsMixin:
         self.assertEqual([Row(c1=1, c2="string")], df.collect())
         self.assertGreater(self.spark.sparkContext.defaultParallelism, len(pdf))
+    def test_toPandas_error(self):
+        for arrow_enabled in [True, False]:
+            with self.subTest(arrow_enabled=arrow_enabled):
+                self.check_toPandas_error(arrow_enabled)
+
+    def check_toPandas_error(self, arrow_enabled):
+        with self.sql_conf(
+            {
+                "spark.sql.ansi.enabled": True,
+                "spark.sql.execution.arrow.pyspark.enabled": arrow_enabled,
+            }
+        ):
+            with self.assertRaises(ArithmeticException):
+                self.spark.sql("select 1/0").toPandas()
+
@unittest.skipIf(
not have_pandas or not have_pyarrow,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]