This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b4282266bed [SPARK-43314][CONNECT][PYTHON] Migrate Spark Connect client errors into error class
b4282266bed is described below
commit b4282266bed49d1dbf2a55e0d91622bef18a25d8
Author: itholic <[email protected]>
AuthorDate: Tue May 2 10:02:41 2023 +0800
[SPARK-43314][CONNECT][PYTHON] Migrate Spark Connect client errors into error class
### What changes were proposed in this pull request?
This PR proposes to migrate all Spark Connect client errors into error classes.
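As a rough before/after sketch of the pattern this migration follows (the `sc://` scheme check in `ChannelBuilder` is one of the call sites touched below; the snippet is illustrative only):

```python
from pyspark.errors import PySparkValueError

# Before: an ad-hoc built-in exception with a free-form message.
# raise AttributeError("URL scheme must be set to `sc`.")

# After: a PySpark error class plus structured message parameters,
# defined centrally in python/pyspark/errors/error_classes.py.
raise PySparkValueError(
    error_class="INVALID_CONNECT_URL",
    message_parameters={"detail": "URL scheme must be set to `sc`."},
)
```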
### Why are the changes needed?
To improve PySpark errors.
### Does this PR introduce _any_ user-facing change?
No API changes; only the errors are improved.
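One practical effect is that callers can match on the error class instead of parsing message text. A minimal sketch, assuming a build that includes this change and using the `getErrorClass()` / `getMessageParameters()` accessors that `PySparkException` exposes:

```python
from pyspark.errors import PySparkValueError
from pyspark.sql.connect.client import ChannelBuilder

try:
    # The path component must be empty, so this now raises PySparkValueError
    # (previously AttributeError).
    ChannelBuilder("sc://host/path")
except PySparkValueError as e:
    assert e.getErrorClass() == "INVALID_CONNECT_URL"
    print(e.getMessageParameters()["detail"])
```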
### How was this patch tested?
The existing CI should pass.
Closes #40985 from itholic/error_connect_client.
Authored-by: itholic <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/errors/error_classes.py | 20 ++++++++
python/pyspark/sql/connect/client.py | 60 +++++++++++++++++-----
.../sql/tests/connect/test_connect_basic.py | 10 ++--
3 files changed, 73 insertions(+), 17 deletions(-)
diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 1dcac48fab1..8515fdbcce3 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -144,6 +144,11 @@ ERROR_CLASSES_JSON = """
"Argument `<arg_name>`(type: <arg_type>) should only contain a type in
[<allowed_types>], got <return_type>"
]
},
+ "EXCEED_RETRY" : {
+ "message" : [
+ "Retries exceeded but no exception caught."
+ ]
+ },
"HIGHER_ORDER_FUNCTION_SHOULD_RETURN_COLUMN" : {
"message" : [
"Function `<func_name>` should return Column, got <return_type>."
@@ -165,6 +170,11 @@ ERROR_CLASSES_JSON = """
"Invalid call to `<func_name>` on unresolved object."
]
},
+ "INVALID_CONNECT_URL" : {
+ "message" : [
+ "Invalid URL for Spark Connect: <detail>"
+ ]
+ },
"INVALID_ITEM_FOR_CONTAINER": {
"message": [
"All items in `<arg_name>` should be in <allowed_types>, got
<item_type>."
@@ -536,6 +546,16 @@ ERROR_CLASSES_JSON = """
"Unexpected response from iterator server."
]
},
+ "UNKNOWN_EXPLAIN_MODE" : {
+ "message" : [
+ "Unknown explain mode: '<explain_mode>'. Accepted explain modes are
'simple', 'extended', 'codegen', 'cost', 'formatted'."
+ ]
+ },
+ "UNKNOWN_RESPONSE" : {
+ "message" : [
+ "Unknown response: <response>."
+ ]
+ },
"UNSUPPORTED_DATA_TYPE" : {
"message" : [
"Unsupported DataType `<data_type>`."
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index beb5ae86138..2f061ecfb89 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -76,6 +76,7 @@ from pyspark.sql.pandas.types import _check_series_localize_timestamps, _convert
from pyspark.sql.types import DataType, MapType, StructType, TimestampType
from pyspark.rdd import PythonEvalType
from pyspark.storagelevel import StorageLevel
+from pyspark.errors import PySparkValueError, PySparkRuntimeError
if TYPE_CHECKING:
@@ -169,15 +170,24 @@ class ChannelBuilder:
"""
# Explicitly check the scheme of the URL.
if url[:5] != "sc://":
- raise AttributeError("URL scheme must be set to `sc`.")
+ raise PySparkValueError(
+ error_class="INVALID_CONNECT_URL",
+ message_parameters={
+ "detail": "URL scheme must be set to `sc`.",
+ },
+ )
# Rewrite the URL to use http as the scheme so that we can leverage
# Python's built-in parser.
tmp_url = "http" + url[2:]
self.url = urllib.parse.urlparse(tmp_url)
self.params: Dict[str, str] = {}
if len(self.url.path) > 0 and self.url.path != "/":
- raise AttributeError(
- f"Path component for connection URI must be empty:
{self.url.path}"
+ raise PySparkValueError(
+ error_class="INVALID_CONNECT_URL",
+ message_parameters={
+ "detail": f"Path component for connection URI
`{self.url.path}` "
+ f"must be empty.",
+ },
)
self._extract_attributes()
@@ -197,7 +207,12 @@ class ChannelBuilder:
for p in parts:
kv = p.split("=")
if len(kv) != 2:
- raise AttributeError(f"Parameter '{p}' is not a valid
parameter key-value pair")
+ raise PySparkValueError(
+ error_class="INVALID_CONNECT_URL",
+ message_parameters={
+ "detail": f"Parameter '{p}' is not a valid
parameter key-value pair.",
+ },
+ )
self.params[kv[0]] = urllib.parse.unquote(kv[1])
netloc = self.url.netloc.split(":")
@@ -208,8 +223,12 @@ class ChannelBuilder:
self.host = netloc[0]
self.port = int(netloc[1])
else:
- raise AttributeError(
- f"Target destination {self.url.netloc} does not match
'<host>:<port>' pattern"
+ raise PySparkValueError(
+ error_class="INVALID_CONNECT_URL",
+ message_parameters={
+ "detail": f"Target destination {self.url.netloc} does not
match "
+ f"'<host>:<port>' pattern.",
+ },
)
def metadata(self) -> Iterable[Tuple[str, str]]:
@@ -820,11 +839,11 @@ class SparkConnectClient(object):
req.explain.plan.CopyFrom(cast(pb2.Plan, kwargs.get("plan")))
explain_mode = kwargs.get("explain_mode")
if explain_mode not in ["simple", "extended", "codegen", "cost", "formatted"]:
- raise ValueError(
- f"""
- Unknown explain mode: {explain_mode}. Accepted "
- "explain modes are 'simple', 'extended', 'codegen',
'cost', 'formatted'."
- """
+ raise PySparkValueError(
+ error_class="UNKNOWN_EXPLAIN_MODE",
+ message_parameters={
+ "explain_mode": str(explain_mode),
+ },
)
if explain_mode == "simple":
req.explain.explain_mode = (
@@ -878,7 +897,12 @@ class SparkConnectClient(object):
elif method == "get_storage_level":
req.get_storage_level.relation.CopyFrom(cast(pb2.Relation, kwargs.get("relation")))
else:
- raise ValueError(f"Unknown Analyze method: {method}")
+ raise PySparkValueError(
+ error_class="UNSUPPORTED_OPERATION",
+ message_parameters={
+ "operation": method,
+ },
+ )
try:
for attempt in Retrying(
@@ -1015,7 +1039,12 @@ class SparkConnectClient(object):
elif isinstance(response, dict):
properties.update(**response)
else:
- raise ValueError(f"Unknown response: {response}")
+ raise PySparkValueError(
+ error_class="UNKNOWN_RESPONSE",
+ message_parameters={
+ "response": response,
+ },
+ )
if len(batches) > 0:
table = pa.Table.from_batches(batches=batches)
@@ -1232,7 +1261,10 @@ class Retrying:
if e is not None:
raise e
else:
- raise ValueError("Retries exceeded but no exception
caught.")
+ raise PySparkRuntimeError(
+ error_class="EXCEED_RETRY",
+ message_parameters={},
+ )
# Do backoff
if retry_state.count() > 0:
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index f4b6dbe26cd..00db1bd98f6 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2073,9 +2073,13 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
self.assertTrue("Optimized Logical Plan" in plan_str)
self.assertTrue("Physical Plan" in plan_str)
- with self.assertRaises(ValueError) as context:
+ with self.assertRaises(PySparkValueError) as pe:
self.connect.sql("SELECT 1")._explain_string(mode="unknown")
- self.assertTrue("unknown" in str(context.exception))
+ self.check_error(
+ exception=pe.exception,
+ error_class="UNKNOWN_EXPLAIN_MODE",
+ message_parameters={"explain_mode": "unknown"},
+ )
def test_simple_datasource_read(self) -> None:
writeDf = self.df_text
@@ -3365,7 +3369,7 @@ class ChannelBuilderTests(unittest.TestCase):
"sc://host/;parm1;param2",
]
for i in invalid:
- self.assertRaises(AttributeError, ChannelBuilder, i)
+ self.assertRaises(PySparkValueError, ChannelBuilder, i)
def test_sensible_defaults(self):
chan = ChannelBuilder("sc://host")