This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 53363a5236d [SPARK-43234][CONNECT][PYTHON] Migrate `ValueError` from Connect DataFrame into error class
53363a5236d is described below
commit 53363a5236d33fa3b31f608c31acbbcaf1cd1832
Author: itholic <[email protected]>
AuthorDate: Mon Apr 24 14:45:45 2023 +0800
[SPARK-43234][CONNECT][PYTHON] Migrate `ValueError` from Connect DataFrame into error class
### What changes were proposed in this pull request?
This PR proposes to migrate the plain `ValueError`s raised by the Spark Connect DataFrame to `PySparkValueError` with proper error classes.
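As a rough illustration (not taken verbatim from the diff; `agg_before`/`agg_after` are hypothetical stand-ins for the affected `DataFrame.agg` check), the migration replaces hand-written messages with error classes resolved from `python/pyspark/errors/error_classes.py`:

```python
from pyspark.errors import PySparkValueError

# Before: a plain ValueError with an ad-hoc message string.
def agg_before(*exprs):
    if not exprs:
        raise ValueError("Argument 'exprs' must not be empty")

# After: the same validation raises PySparkValueError with an error class,
# so the message text comes from the centralized template in error_classes.py.
def agg_after(*exprs):
    if not exprs:
        raise PySparkValueError(
            error_class="CANNOT_BE_EMPTY",
            message_parameters={"item": "exprs"},
        )
```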
### Why are the changes needed?
To improve the errors from Spark Connect.
### Does this PR introduce _any_ user-facing change?
No, this is an error-message improvement only.
### How was this patch tested?
The existing tests should pass.
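The updated suites assert on the error class rather than the message text. A minimal, self-contained sketch of that pattern (using a hypothetical `check_num_partitions` helper in place of the real DataFrame method, and assuming the `getErrorClass`/`getMessageParameters` accessors that `pyspark.errors` exceptions expose) looks like:

```python
from pyspark.errors import PySparkValueError

def check_num_partitions(numPartitions: int) -> None:
    # Mirrors the validation added to DataFrame.coalesce / repartition.
    if not numPartitions > 0:
        raise PySparkValueError(
            error_class="VALUE_NOT_POSITIVE",
            message_parameters={"arg_name": "numPartitions", "arg_value": str(numPartitions)},
        )

try:
    check_num_partitions(-1)
except PySparkValueError as e:
    # The tests' check_error helper asserts on these two accessors.
    assert e.getErrorClass() == "VALUE_NOT_POSITIVE"
    assert e.getMessageParameters() == {"arg_name": "numPartitions", "arg_value": "-1"}
```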
Closes #40910 from itholic/connect_df_value_error.
Authored-by: itholic <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/errors/error_classes.py | 7 +-
python/pyspark/sql/connect/dataframe.py | 134 ++++++++++++++++-----
.../sql/tests/connect/test_connect_basic.py | 39 +++++-
.../pyspark/sql/tests/connect/test_connect_plan.py | 19 ++-
python/pyspark/testing/connectutils.py | 2 +-
5 files changed, 162 insertions(+), 39 deletions(-)
diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py
index 4425ed79928..2b41f54def9 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -119,6 +119,11 @@ ERROR_CLASSES_JSON = """
"<arg1> and <arg2> should be of the same length, got <arg1_length> and
<arg2_length>."
]
},
+ "MISSING_VALID_PLAN" : {
+ "message" : [
+ "Argument to <operator> does not contain a valid plan."
+ ]
+ },
"MIXED_TYPE_REPLACEMENT" : {
"message" : [
"Mixed type replacements are not supported."
@@ -126,7 +131,7 @@ ERROR_CLASSES_JSON = """
},
"NEGATIVE_VALUE" : {
"message" : [
- "Value for `<arg_name>` must be >= 0, got '<arg_value>'."
+ "Value for `<arg_name>` must be greater than or equal to 0, got
'<arg_value>'."
]
},
"NOT_BOOL" : {
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 3b39b82196e..16c4389ee44 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -52,7 +52,7 @@ from pyspark.sql.dataframe import (
DataFrameStatFunctions as PySparkDataFrameStatFunctions,
)
-from pyspark.errors import PySparkTypeError, PySparkAttributeError
+from pyspark.errors import PySparkTypeError, PySparkAttributeError, PySparkValueError
from pyspark.errors.exceptions.connect import SparkConnectException
from pyspark.rdd import PythonEvalType
from pyspark.storagelevel import StorageLevel
@@ -183,7 +183,10 @@ class DataFrame:
def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":
if not exprs:
- raise ValueError("Argument 'exprs' must not be empty")
+ raise PySparkValueError(
+ error_class="CANNOT_BE_EMPTY",
+ message_parameters={"item": "exprs"},
+ )
if len(exprs) == 1 and isinstance(exprs[0], dict):
measures = [_invoke_function(f, col(e)) for e, f in exprs[0].items()]
@@ -256,7 +259,10 @@ class DataFrame:
def coalesce(self, numPartitions: int) -> "DataFrame":
if not numPartitions > 0:
- raise ValueError("numPartitions must be positive.")
+ raise PySparkValueError(
+ error_class="VALUE_NOT_POSITIVE",
+ message_parameters={"arg_name": "numPartitions", "arg_value":
str(numPartitions)},
+ )
return DataFrame.withPlan(
plan.Repartition(self._plan, num_partitions=numPartitions, shuffle=False),
self._session,
@@ -277,7 +283,13 @@ class DataFrame:
) -> "DataFrame":
if isinstance(numPartitions, int):
if not numPartitions > 0:
- raise ValueError("numPartitions must be positive.")
+ raise PySparkValueError(
+ error_class="VALUE_NOT_POSITIVE",
+ message_parameters={
+ "arg_name": "numPartitions",
+ "arg_value": str(numPartitions),
+ },
+ )
if len(cols) == 0:
return DataFrame.withPlan(
plan.Repartition(self._plan, num_partitions=numPartitions, shuffle=True),
@@ -329,9 +341,18 @@ class DataFrame:
if isinstance(numPartitions, int):
if not numPartitions > 0:
- raise ValueError("numPartitions must be positive.")
+ raise PySparkValueError(
+ error_class="VALUE_NOT_POSITIVE",
+ message_parameters={
+ "arg_name": "numPartitions",
+ "arg_value": str(numPartitions),
+ },
+ )
if len(cols) == 0:
- raise ValueError("At least one partition-by expression must be
specified.")
+ raise PySparkValueError(
+ error_class="CANNOT_BE_EMPTY",
+ message_parameters={"item": "cols"},
+ )
else:
sort = []
sort.extend([_convert_col(c) for c in cols])
@@ -415,7 +436,10 @@ class DataFrame:
message_parameters={"arg_name": "cols", "arg_type":
type(cols).__name__},
)
if len(_cols) == 0:
- raise ValueError("'cols' must be non-empty")
+ raise PySparkValueError(
+ error_class="CANNOT_BE_EMPTY",
+ message_parameters={"item": "cols"},
+ )
return DataFrame.withPlan(
plan.Drop(
@@ -556,7 +580,10 @@ class DataFrame:
) -> List[Column]:
"""Return a JVM Seq of Columns that describes the sort order"""
if cols is None:
- raise ValueError("should sort by at least one column")
+ raise PySparkValueError(
+ error_class="CANNOT_BE_EMPTY",
+ message_parameters={"item": "cols"},
+ )
_cols: List[Column] = []
if len(cols) == 1 and isinstance(cols[0], list):
@@ -881,11 +908,17 @@ class DataFrame:
) -> List["DataFrame"]:
for w in weights:
if w < 0.0:
- raise ValueError("Weights must be positive. Found weight
value: %s" % w)
+ raise PySparkValueError(
+ error_class="VALUE_NOT_POSITIVE",
+ message_parameters={"arg_name": "weights", "arg_value":
str(w)},
+ )
seed = seed if seed is not None else random.randint(0, sys.maxsize)
total = sum(weights)
if total <= 0:
- raise ValueError("Sum of weights must be positive, but got: %s" %
w)
+ raise PySparkValueError(
+ error_class="VALUE_NOT_POSITIVE",
+ message_parameters={"arg_name": "sum(weights)", "arg_value":
str(total)},
+ )
proportions = list(map(lambda x: x / total, weights))
normalizedCumWeights = [0.0]
for v in proportions:
@@ -920,9 +953,15 @@ class DataFrame:
*exprs: Column,
) -> "DataFrame":
if len(exprs) == 0:
- raise ValueError("'exprs' should not be empty")
+ raise PySparkValueError(
+ error_class="CANNOT_BE_EMPTY",
+ message_parameters={"item": "exprs"},
+ )
if not all(isinstance(c, Column) for c in exprs):
- raise ValueError("all 'exprs' should be Column")
+ raise PySparkTypeError(
+ error_class="NOT_LIST_OF_COLUMN",
+ message_parameters={"arg_name": "exprs"},
+ )
if isinstance(observation, Observation):
return DataFrame.withPlan(
@@ -935,7 +974,13 @@ class DataFrame:
self._session,
)
else:
- raise ValueError("'observation' should be either `Observation` or
`str`.")
+ raise PySparkTypeError(
+ error_class="NOT_OBSERVATION_OR_STR",
+ message_parameters={
+ "arg_name": "observation",
+ "arg_type": type(observation).__name__,
+ },
+ )
observe.__doc__ = PySparkDataFrame.observe.__doc__
@@ -951,7 +996,10 @@ class DataFrame:
def unionAll(self, other: "DataFrame") -> "DataFrame":
if other._plan is None:
- raise ValueError("Argument to Union does not contain a valid
plan.")
+ raise PySparkValueError(
+ error_class="MISSING_VALID_PLAN",
+ message_parameters={"operator": "Union"},
+ )
return DataFrame.withPlan(
plan.SetOperation(self._plan, other._plan, "union", is_all=True), session=self._session
)
@@ -960,7 +1008,10 @@ class DataFrame:
def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> "DataFrame":
if other._plan is None:
- raise ValueError("Argument to UnionByName does not contain a valid
plan.")
+ raise PySparkValueError(
+ error_class="MISSING_VALID_PLAN",
+ message_parameters={"operator": "UnionByName"},
+ )
return DataFrame.withPlan(
plan.SetOperation(
self._plan,
@@ -1033,7 +1084,10 @@ class DataFrame:
)
if isinstance(value, dict):
if len(value) == 0:
- raise ValueError("value dict can not be empty")
+ raise PySparkValueError(
+ error_class="CANNOT_BE_EMPTY",
+ message_parameters={"item": "value"},
+ )
for c, v in value.items():
if not isinstance(c, str):
raise PySparkTypeError(
@@ -1102,7 +1156,10 @@ class DataFrame:
elif how == "any":
min_non_nulls = None
else:
- raise ValueError("how ('" + how + "') should be 'any' or
'all'")
+ raise PySparkValueError(
+ error_class="CANNOT_BE_EMPTY",
+ message_parameters={"arg_name": "how", "arg_value":
str(how)},
+ )
if thresh is not None:
if not isinstance(thresh, int):
@@ -1201,9 +1258,14 @@ class DataFrame:
if isinstance(to_replace, (list, tuple)) and isinstance(value, (list, tuple)):
if len(to_replace) != len(value):
- raise ValueError(
- "to_replace and value lists should be of the same length. "
- "Got {0} and {1}".format(len(to_replace), len(value))
+ raise PySparkValueError(
+ error_class="LENGTH_SHOULD_BE_THE_SAME",
+ message_parameters={
+ "arg1": "to_replace",
+ "arg2": "value",
+ "arg1_length": str(len(to_replace)),
+ "arg2_length": str(len(value)),
+ },
)
if not (subset is None or isinstance(subset, (list, tuple, str))):
@@ -1234,7 +1296,10 @@ class DataFrame:
and all_of_type(x for x in rep_dict.values() if x is not None)
for all_of_type in [all_of_bool, all_of_str, all_of_numeric]
):
- raise ValueError("Mixed type replacements are not supported")
+ raise PySparkValueError(
+ error_class="MIXED_TYPE_REPLACEMENT",
+ message_parameters={},
+ )
return DataFrame.withPlan(
plan.NAReplace(child=self._plan, cols=subset, replacements=rep_dict),
@@ -1316,9 +1381,9 @@ class DataFrame:
if not method:
method = "pearson"
if not method == "pearson":
- raise ValueError(
- "Currently only the calculation of the Pearson Correlation "
- + "coefficient is supported."
+ raise PySparkValueError(
+ error_class="VALUE_NOT_PEARSON",
+ message_parameters={"arg_name": "method", "arg_value": method},
)
pdf = DataFrame.withPlan(
plan.StatCorr(child=self._plan, col1=col1, col2=col2, method=method),
@@ -1368,7 +1433,13 @@ class DataFrame:
probabilities = list(probabilities)
for p in probabilities:
if not isinstance(p, (float, int)) or p < 0 or p > 1:
- raise ValueError("probabilities should be numerical (float,
int) in [0,1].")
+ raise PySparkTypeError(
+ error_class="NOT_LIST_OF_FLOAT_OR_INT",
+ message_parameters={
+ "arg_name": "probabilities",
+ "arg_type": type(p).__name__,
+ },
+ )
if not isinstance(relativeError, (float, int)):
raise PySparkTypeError(
@@ -1379,7 +1450,13 @@ class DataFrame:
},
)
if relativeError < 0:
- raise ValueError("relativeError should be >= 0.")
+ raise PySparkValueError(
+ error_class="NEGATIVE_VALUE",
+ message_parameters={
+ "arg_name": "relativeError",
+ "arg_value": str(relativeError),
+ },
+ )
relativeError = float(relativeError)
pdf = DataFrame.withPlan(
plan.StatApproxQuantile(
@@ -1633,7 +1710,10 @@ class DataFrame:
self, extended: Optional[Union[bool, str]] = None, mode: Optional[str] = None
) -> str:
if extended is not None and mode is not None:
- raise ValueError("extended and mode should not be set together.")
+ raise PySparkValueError(
+ error_class="CANNOT_SET_TOGETHER",
+ message_parameters={"arg_list": "extended and mode"},
+ )
# For the no argument case: df.explain()
is_no_argument = extended is None and mode is None
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 09657155856..8913dec568b 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -27,6 +27,7 @@ from pyspark.errors import (
PySparkAttributeError,
PySparkTypeError,
PySparkException,
+ PySparkValueError,
)
from pyspark.sql import SparkSession as PySparkSession, Row
from pyspark.sql.types import (
@@ -1704,11 +1705,24 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
values = list(map(lambda metric: metric.long, observed_metrics[0].metrics))
self.assert_eq(values, [4, 99, 4944])
- with self.assertRaisesRegex(ValueError, "'exprs' should not be empty"):
+ with self.assertRaises(PySparkValueError) as pe:
self.connect.read.table(self.tbl_name).observe(observation_name)
- with self.assertRaisesRegex(ValueError, "all 'exprs' should be
Column"):
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="CANNOT_BE_EMPTY",
+ message_parameters={"item": "exprs"},
+ )
+
+ with self.assertRaises(PySparkTypeError) as pe:
self.connect.read.table(self.tbl_name).observe(observation_name, CF.lit(1), "id")
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_LIST_OF_COLUMN",
+ message_parameters={"arg_name": "exprs"},
+ )
+
def test_with_columns(self):
# SPARK-41256: test withColumn(s).
self.assert_eq(
@@ -1944,12 +1958,16 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
"arg_type": "float",
},
)
- with self.assertRaisesRegex(
- ValueError, "probabilities should be numerical \\(float, int\\) in
\\[0,1\\]"
- ):
+ with self.assertRaises(PySparkTypeError) as pe:
self.connect.read.table(self.tbl_name2).stat.approxQuantile(
["col1", "col3"], [-0.1], 0.1
)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="NOT_LIST_OF_FLOAT_OR_INT",
+ message_parameters={"arg_name": "probabilities", "arg_type":
"float"},
+ )
with self.assertRaises(PySparkTypeError) as pe:
self.connect.read.table(self.tbl_name2).stat.approxQuantile(
["col1", "col3"], [0.1, 0.5, 0.9], "str"
@@ -1963,11 +1981,20 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
"arg_type": "str",
},
)
- with self.assertRaisesRegex(ValueError, "relativeError should be >=
0."):
+ with self.assertRaises(PySparkValueError) as pe:
self.connect.read.table(self.tbl_name2).stat.approxQuantile(
["col1", "col3"], [0.1, 0.5, 0.9], -0.1
)
+ self.check_error(
+ exception=pe.exception,
+ error_class="NEGATIVE_VALUE",
+ message_parameters={
+ "arg_name": "relativeError",
+ "arg_value": "-0.1",
+ },
+ )
+
def test_stat_freq_items(self):
# SPARK-41065: Test the stat.freqItems method
self.assert_eq(
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index 129a25098b1..c39fb6be24c 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -25,6 +25,7 @@ from pyspark.testing.connectutils import (
should_test_connect,
connect_requirement_message,
)
+from pyspark.errors import PySparkValueError
if should_test_connect:
import pyspark.sql.connect.proto as proto
@@ -637,13 +638,23 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
plan2 = df.repartition(20)._plan.to_proto(self.connect)
self.assertTrue(plan2.root.repartition.shuffle)
- with self.assertRaises(ValueError) as context:
+ with self.assertRaises(PySparkValueError) as pe:
df.coalesce(-1)._plan.to_proto(self.connect)
- self.assertTrue("numPartitions must be positive" in
str(context.exception))
- with self.assertRaises(ValueError) as context:
+ self.check_error(
+ exception=pe.exception,
+ error_class="VALUE_NOT_POSITIVE",
+ message_parameters={"arg_name": "numPartitions", "arg_value":
"-1"},
+ )
+
+ with self.assertRaises(PySparkValueError) as pe:
df.repartition(-1)._plan.to_proto(self.connect)
- self.assertTrue("numPartitions must be positive" in
str(context.exception))
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="VALUE_NOT_POSITIVE",
+ message_parameters={"arg_name": "numPartitions", "arg_value":
"-1"},
+ )
def test_repartition_by_expression(self):
# SPARK-41354: test dataframe.repartition(expressions)
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 5d57ad803bc..c1ca57aa3cc 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -87,7 +87,7 @@ class MockRemoteSession:
@unittest.skipIf(not should_test_connect, connect_requirement_message)
-class PlanOnlyTestFixture(unittest.TestCase):
+class PlanOnlyTestFixture(unittest.TestCase, PySparkErrorTestUtils):
@classmethod
def _read_table(cls, table_name):
return DataFrame.withPlan(Read(table_name), cls.connect)