This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b4282266bed [SPARK-43314][CONNECT][PYTHON] Migrate Spark Connect client errors into error class
b4282266bed is described below

commit b4282266bed49d1dbf2a55e0d91622bef18a25d8
Author: itholic <[email protected]>
AuthorDate: Tue May 2 10:02:41 2023 +0800

    [SPARK-43314][CONNECT][PYTHON] Migrate Spark Connect client errors into error class
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to migrate all Spark Connect client errors into error classes.
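    
    For example, ad-hoc built-in exceptions are replaced with PySpark errors that carry an error class and message parameters. A minimal sketch of the pattern (mirroring the `UNKNOWN_RESPONSE` change in `SparkConnectClient` below; the `response` value here is a hypothetical placeholder for illustration):
    
        from pyspark.errors import PySparkValueError
    
        response = "unexpected-server-reply"  # hypothetical value, for illustration only
    
        # Before: a generic built-in exception with an ad-hoc message
        # raise ValueError(f"Unknown response: {response}")
    
        # After: a PySpark error bound to the "UNKNOWN_RESPONSE" entry in error_classes.py
        raise PySparkValueError(
            error_class="UNKNOWN_RESPONSE",
            message_parameters={"response": response},
        )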
    
    ### Why are the changes needed?
    
    To improve PySpark errors.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No API changes, only error improvement.
    
    ### How was this patch tested?
    
    The existing CI should pass.
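    
    For example, the migrated errors are asserted by error class and message parameters via `check_error`; a sketch mirroring the updated `test_connect_basic.py` assertion below (assumes the `SparkConnectBasicTests` test context):
    
        with self.assertRaises(PySparkValueError) as pe:
            self.connect.sql("SELECT 1")._explain_string(mode="unknown")
        self.check_error(
            exception=pe.exception,
            error_class="UNKNOWN_EXPLAIN_MODE",
            message_parameters={"explain_mode": "unknown"},
        )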
    
    Closes #40985 from itholic/error_connect_client.
    
    Authored-by: itholic <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/errors/error_classes.py             | 20 ++++++++
 python/pyspark/sql/connect/client.py               | 60 +++++++++++++++++-----
 .../sql/tests/connect/test_connect_basic.py        | 10 ++--
 3 files changed, 73 insertions(+), 17 deletions(-)

diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 1dcac48fab1..8515fdbcce3 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -144,6 +144,11 @@ ERROR_CLASSES_JSON = """
       "Argument `<arg_name>`(type: <arg_type>) should only contain a type in 
[<allowed_types>], got <return_type>"
     ]
   },
+  "EXCEED_RETRY" : {
+    "message" : [
+      "Retries exceeded but no exception caught."
+    ]
+  },
   "HIGHER_ORDER_FUNCTION_SHOULD_RETURN_COLUMN" : {
     "message" : [
       "Function `<func_name>` should return Column, got <return_type>."
@@ -165,6 +170,11 @@ ERROR_CLASSES_JSON = """
       "Invalid call to `<func_name>` on unresolved object."
     ]
   },
+  "INVALID_CONNECT_URL" : {
+    "message" : [
+      "Invalid URL for Spark Connect: <detail>"
+    ]
+  },
   "INVALID_ITEM_FOR_CONTAINER": {
     "message": [
       "All items in `<arg_name>` should be in <allowed_types>, got 
<item_type>."
@@ -536,6 +546,16 @@ ERROR_CLASSES_JSON = """
       "Unexpected response from iterator server."
     ]
   },
+  "UNKNOWN_EXPLAIN_MODE" : {
+    "message" : [
+      "Unknown explain mode: '<explain_mode>'. Accepted explain modes are 
'simple', 'extended', 'codegen', 'cost', 'formatted'."
+    ]
+  },
+  "UNKNOWN_RESPONSE" : {
+    "message" : [
+      "Unknown response: <response>."
+    ]
+  },
   "UNSUPPORTED_DATA_TYPE" : {
     "message" : [
       "Unsupported DataType `<data_type>`."
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index beb5ae86138..2f061ecfb89 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -76,6 +76,7 @@ from pyspark.sql.pandas.types import _check_series_localize_timestamps, _convert
 from pyspark.sql.types import DataType, MapType, StructType, TimestampType
 from pyspark.rdd import PythonEvalType
 from pyspark.storagelevel import StorageLevel
+from pyspark.errors import PySparkValueError, PySparkRuntimeError
 
 
 if TYPE_CHECKING:
@@ -169,15 +170,24 @@ class ChannelBuilder:
         """
         # Explicitly check the scheme of the URL.
         if url[:5] != "sc://":
-            raise AttributeError("URL scheme must be set to `sc`.")
+            raise PySparkValueError(
+                error_class="INVALID_CONNECT_URL",
+                message_parameters={
+                    "detail": "URL scheme must be set to `sc`.",
+                },
+            )
         # Rewrite the URL to use http as the scheme so that we can leverage
         # Python's built-in parser.
         tmp_url = "http" + url[2:]
         self.url = urllib.parse.urlparse(tmp_url)
         self.params: Dict[str, str] = {}
         if len(self.url.path) > 0 and self.url.path != "/":
-            raise AttributeError(
-                f"Path component for connection URI must be empty: {self.url.path}"
+            raise PySparkValueError(
+                error_class="INVALID_CONNECT_URL",
+                message_parameters={
+                    "detail": f"Path component for connection URI `{self.url.path}` "
+                    f"must be empty.",
+                },
             )
         self._extract_attributes()
 
@@ -197,7 +207,12 @@ class ChannelBuilder:
             for p in parts:
                 kv = p.split("=")
                 if len(kv) != 2:
-                    raise AttributeError(f"Parameter '{p}' is not a valid parameter key-value pair")
+                    raise PySparkValueError(
+                        error_class="INVALID_CONNECT_URL",
+                        message_parameters={
+                            "detail": f"Parameter '{p}' is not a valid parameter key-value pair.",
+                        },
+                    )
                 self.params[kv[0]] = urllib.parse.unquote(kv[1])
 
         netloc = self.url.netloc.split(":")
@@ -208,8 +223,12 @@ class ChannelBuilder:
             self.host = netloc[0]
             self.port = int(netloc[1])
         else:
-            raise AttributeError(
-                f"Target destination {self.url.netloc} does not match '<host>:<port>' pattern"
+            raise PySparkValueError(
+                error_class="INVALID_CONNECT_URL",
+                message_parameters={
+                    "detail": f"Target destination {self.url.netloc} does not match "
+                    f"'<host>:<port>' pattern.",
+                },
             )
 
     def metadata(self) -> Iterable[Tuple[str, str]]:
@@ -820,11 +839,11 @@ class SparkConnectClient(object):
             req.explain.plan.CopyFrom(cast(pb2.Plan, kwargs.get("plan")))
             explain_mode = kwargs.get("explain_mode")
             if explain_mode not in ["simple", "extended", "codegen", "cost", "formatted"]:
-                raise ValueError(
-                    f"""
-                    Unknown explain mode: {explain_mode}. Accepted "
-                    "explain modes are 'simple', 'extended', 'codegen', 'cost', 'formatted'."
-                    """
+                raise PySparkValueError(
+                    error_class="UNKNOWN_EXPLAIN_MODE",
+                    message_parameters={
+                        "explain_mode": str(explain_mode),
+                    },
                 )
             if explain_mode == "simple":
                 req.explain.explain_mode = (
@@ -878,7 +897,12 @@ class SparkConnectClient(object):
         elif method == "get_storage_level":
             req.get_storage_level.relation.CopyFrom(cast(pb2.Relation, kwargs.get("relation")))
         else:
-            raise ValueError(f"Unknown Analyze method: {method}")
+            raise PySparkValueError(
+                error_class="UNSUPPORTED_OPERATION",
+                message_parameters={
+                    "operation": method,
+                },
+            )
 
         try:
             for attempt in Retrying(
@@ -1015,7 +1039,12 @@ class SparkConnectClient(object):
             elif isinstance(response, dict):
                 properties.update(**response)
             else:
-                raise ValueError(f"Unknown response: {response}")
+                raise PySparkValueError(
+                    error_class="UNKNOWN_RESPONSE",
+                    message_parameters={
+                        "response": response,
+                    },
+                )
 
         if len(batches) > 0:
             table = pa.Table.from_batches(batches=batches)
@@ -1232,7 +1261,10 @@ class Retrying:
                 if e is not None:
                     raise e
                 else:
-                    raise ValueError("Retries exceeded but no exception caught.")
+                    raise PySparkRuntimeError(
+                        error_class="EXCEED_RETRY",
+                        message_parameters={},
+                    )
 
             # Do backoff
             if retry_state.count() > 0:
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index f4b6dbe26cd..00db1bd98f6 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2073,9 +2073,13 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         self.assertTrue("Optimized Logical Plan" in plan_str)
         self.assertTrue("Physical Plan" in plan_str)
 
-        with self.assertRaises(ValueError) as context:
+        with self.assertRaises(PySparkValueError) as pe:
             self.connect.sql("SELECT 1")._explain_string(mode="unknown")
-        self.assertTrue("unknown" in str(context.exception))
+        self.check_error(
+            exception=pe.exception,
+            error_class="UNKNOWN_EXPLAIN_MODE",
+            message_parameters={"explain_mode": "unknown"},
+        )
 
     def test_simple_datasource_read(self) -> None:
         writeDf = self.df_text
@@ -3365,7 +3369,7 @@ class ChannelBuilderTests(unittest.TestCase):
             "sc://host/;parm1;param2",
         ]
         for i in invalid:
-            self.assertRaises(AttributeError, ChannelBuilder, i)
+            self.assertRaises(PySparkValueError, ChannelBuilder, i)
 
     def test_sensible_defaults(self):
         chan = ChannelBuilder("sc://host")


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
