This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 05cdec8e065 [SPARK-45412][PYTHON][CONNECT] Validate the plan and session in `DataFrame.__init__`
05cdec8e065 is described below
commit 05cdec8e06507b7e4d84198209038c39fbb234d1
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Oct 5 09:36:08 2023 +0900
[SPARK-45412][PYTHON][CONNECT] Validate the plan and session in `DataFrame.__init__`
### What changes were proposed in this pull request?
Validate the plan and session in `DataFrame.__init__`
### Why are the changes needed?
- many DataFrame APIs validate the plan and session themselves, but throw inconsistent exceptions: `SparkConnectException`, `Exception`, `PySparkValueError`, `AssertionError`;
- some methods still don't validate the plan and session at all.

Validating both once in `DataFrame.__init__` removes the per-method checks and unifies the errors (see the sketch below).
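A minimal sketch of the resulting behavior (not part of the patch; `spark` is assumed to be an existing Spark Connect session, and the error classes are the ones used in the diff below):

```python
from pyspark.errors import PySparkRuntimeError
from pyspark.sql.connect.dataframe import DataFrame

# Borrow a valid LogicalPlan from an existing DataFrame, purely for illustration.
valid_plan = spark.range(1)._plan

# A missing plan is now rejected eagerly in __init__ ...
try:
    DataFrame(plan=None, session=spark)  # type: ignore[arg-type]
except PySparkRuntimeError as e:
    print(e)  # error_class: MISSING_VALID_PLAN

# ... and so is a missing session, with the same unified error type.
try:
    DataFrame(plan=valid_plan, session=None)  # type: ignore[arg-type]
except PySparkRuntimeError as e:
    print(e)  # error_class: NO_ACTIVE_SESSION
```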
### Does this PR introduce _any_ user-facing change?
yes, the exceptions raised for a missing plan or session are unified
### How was this patch tested?
ci
### Was this patch authored or co-authored using generative AI tooling?
no
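For internal callers, the only visible change is that the constructor now takes the plan explicitly (plan first, then session); `DataFrame.withPlan` keeps working and simply forwards to the validating `__init__`, as the diff below shows. A hedged example of the construction pattern (again assuming an existing Spark Connect session `spark`):

```python
from pyspark.sql.connect.dataframe import DataFrame

# Reuse the plan of an existing DataFrame just to have a valid LogicalPlan at hand.
rel = spark.range(10)._plan

# Both forms go through the validating __init__ (plan first, then session):
df1 = DataFrame(plan=rel, session=spark)
df2 = DataFrame.withPlan(rel, spark)
```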
Closes #43215 from zhengruifeng/df_validate_v2.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/dataframe.py | 91 +++++++--------------------------
python/pyspark/testing/connectutils.py | 7 ++-
2 files changed, 21 insertions(+), 77 deletions(-)
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index e4e6613f997..9e46c8e1bf3 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -58,8 +58,8 @@ from pyspark.errors import (
PySparkAttributeError,
PySparkValueError,
PySparkNotImplementedError,
+ PySparkRuntimeError,
)
-from pyspark.errors.exceptions.connect import SparkConnectException
from pyspark.rdd import PythonEvalType
from pyspark.storagelevel import StorageLevel
import pyspark.sql.connect.plan as plan
@@ -99,11 +99,24 @@ if TYPE_CHECKING:
class DataFrame:
def __init__(
self,
+ plan: plan.LogicalPlan,
session: "SparkSession",
):
"""Creates a new data frame"""
- self._plan: Optional[plan.LogicalPlan] = None
+ self._plan = plan
+ if self._plan is None:
+ raise PySparkRuntimeError(
+ error_class="MISSING_VALID_PLAN",
+ message_parameters={"operator": "__init__"},
+ )
+
self._session: "SparkSession" = session
+ if self._session is None:
+ raise PySparkRuntimeError(
+ error_class="NO_ACTIVE_SESSION",
+ message_parameters={"operator": "__init__"},
+ )
+
# Check whether _repr_html is supported or not, we use it to avoid calling RPC twice
# by __repr__ and _repr_html_ while eager evaluation opens.
self._support_repr_html = False
@@ -212,10 +225,6 @@ class DataFrame:
alias.__doc__ = PySparkDataFrame.alias.__doc__
def colRegex(self, colName: str) -> Column:
- if self._plan is None:
- raise SparkConnectException("Cannot colRegex on empty plan.")
- if self._session is None:
- raise Exception("Cannot analyze without SparkSession.")
if not isinstance(colName, str):
raise PySparkTypeError(
error_class="NOT_STR",
@@ -250,10 +259,6 @@ class DataFrame:
count.__doc__ = PySparkDataFrame.count.__doc__
def crossJoin(self, other: "DataFrame") -> "DataFrame":
- if self._plan is None:
- raise Exception("Cannot cartesian join when self._plan is empty.")
- if other._plan is None:
- raise Exception("Cannot cartesian join when other._plan is empty.")
self.checkSameSparkSession(other)
return DataFrame.withPlan(
plan.Join(left=self._plan, right=other._plan, on=None, how="cross"),
@@ -573,10 +578,6 @@ class DataFrame:
on: Optional[Union[str, List[str], Column, List[Column]]] = None,
how: Optional[str] = None,
) -> "DataFrame":
- if self._plan is None:
- raise Exception("Cannot join when self._plan is empty.")
- if other._plan is None:
- raise Exception("Cannot join when other._plan is empty.")
if how is not None and isinstance(how, str):
how = how.lower().replace("_", "")
self.checkSameSparkSession(other)
@@ -599,11 +600,6 @@ class DataFrame:
allowExactMatches: bool = True,
direction: str = "backward",
) -> "DataFrame":
- if self._plan is None:
- raise Exception("Cannot join when self._plan is empty.")
- if other._plan is None:
- raise Exception("Cannot join when other._plan is empty.")
-
if how is None:
how = "inner"
assert isinstance(how, str), "how should be a string"
@@ -1111,11 +1107,6 @@ class DataFrame:
unionAll.__doc__ = PySparkDataFrame.unionAll.__doc__
def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> "DataFrame":
- if other._plan is None:
- raise PySparkValueError(
- error_class="MISSING_VALID_PLAN",
- message_parameters={"operator": "UnionByName"},
- )
self.checkSameSparkSession(other)
return DataFrame.withPlan(
plan.SetOperation(
@@ -1649,9 +1640,6 @@ class DataFrame:
sampleBy.__doc__ = PySparkDataFrame.sampleBy.__doc__
def __getattr__(self, name: str) -> "Column":
- if self._plan is None:
- raise SparkConnectException("Cannot analyze on empty plan.")
-
if name in ["_jseq", "_jdf", "_jmap", "_jcols"]:
raise PySparkAttributeError(
error_class="JVM_ATTRIBUTE_NOT_SUPPORTED",
message_parameters={"attr_name": name}
@@ -1736,20 +1724,12 @@ class DataFrame:
collect.__doc__ = PySparkDataFrame.collect.__doc__
def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]:
- if self._plan is None:
- raise Exception("Cannot collect on empty plan.")
- if self._session is None:
- raise Exception("Cannot collect on empty session.")
query = self._plan.to_proto(self._session.client)
table, schema = self._session.client.to_table(query)
assert table is not None
return (table, schema)
def toPandas(self) -> "pandas.DataFrame":
- if self._plan is None:
- raise Exception("Cannot collect on empty plan.")
- if self._session is None:
- raise Exception("Cannot collect on empty session.")
query = self._plan.to_proto(self._session.client)
return self._session.client.to_pandas(query)
@@ -1757,19 +1737,12 @@ class DataFrame:
@property
def schema(self) -> StructType:
- if self._plan is not None:
- if self._session is None:
- raise Exception("Cannot analyze without SparkSession.")
- query = self._plan.to_proto(self._session.client)
- return self._session.client.schema(query)
- else:
- raise Exception("Empty plan.")
+ query = self._plan.to_proto(self._session.client)
+ return self._session.client.schema(query)
schema.__doc__ = PySparkDataFrame.schema.__doc__
def isLocal(self) -> bool:
- if self._plan is None:
- raise Exception("Cannot analyze on empty plan.")
query = self._plan.to_proto(self._session.client)
result = self._session.client._analyze(method="is_local",
plan=query).is_local
assert result is not None
@@ -1779,8 +1752,6 @@ class DataFrame:
@property
def isStreaming(self) -> bool:
- if self._plan is None:
- raise Exception("Cannot analyze on empty plan.")
query = self._plan.to_proto(self._session.client)
result = self._session.client._analyze(method="is_streaming",
plan=query).is_streaming
assert result is not None
@@ -1789,8 +1760,6 @@ class DataFrame:
isStreaming.__doc__ = PySparkDataFrame.isStreaming.__doc__
def _tree_string(self, level: Optional[int] = None) -> str:
- if self._plan is None:
- raise Exception("Cannot analyze on empty plan.")
query = self._plan.to_proto(self._session.client)
result = self._session.client._analyze(
method="tree_string", plan=query, level=level
@@ -1804,8 +1773,6 @@ class DataFrame:
printSchema.__doc__ = PySparkDataFrame.printSchema.__doc__
def inputFiles(self) -> List[str]:
- if self._plan is None:
- raise Exception("Cannot analyze on empty plan.")
query = self._plan.to_proto(self._session.client)
result = self._session.client._analyze(method="input_files",
plan=query).input_files
assert result is not None
@@ -1845,11 +1812,6 @@ class DataFrame:
def _explain_string(
self, extended: Optional[Union[bool, str]] = None, mode: Optional[str] = None
) -> str:
- if self._plan is None:
- raise SparkConnectException("Cannot explain on empty plan.")
- if self._session is None:
- raise Exception("Cannot analyze without SparkSession.")
-
if extended is not None and mode is not None:
raise PySparkValueError(
error_class="CANNOT_SET_TOGETHER",
@@ -1935,8 +1897,6 @@ class DataFrame:
createOrReplaceGlobalTempView.__doc__ = PySparkDataFrame.createOrReplaceGlobalTempView.__doc__
def cache(self) -> "DataFrame":
- if self._plan is None:
- raise Exception("Cannot cache on empty plan.")
return self.persist()
cache.__doc__ = PySparkDataFrame.cache.__doc__
@@ -1945,8 +1905,6 @@ class DataFrame:
self,
storageLevel: StorageLevel = (StorageLevel.MEMORY_AND_DISK_DESER),
) -> "DataFrame":
- if self._plan is None:
- raise Exception("Cannot persist on empty plan.")
relation = self._plan.plan(self._session.client)
self._session.client._analyze(
method="persist", relation=relation, storage_level=storageLevel
@@ -1957,8 +1915,6 @@ class DataFrame:
@property
def storageLevel(self) -> StorageLevel:
- if self._plan is None:
- raise Exception("Cannot persist on empty plan.")
relation = self._plan.plan(self._session.client)
storage_level = self._session.client._analyze(
method="get_storage_level", relation=relation
@@ -1969,8 +1925,6 @@ class DataFrame:
storageLevel.__doc__ = PySparkDataFrame.storageLevel.__doc__
def unpersist(self, blocking: bool = False) -> "DataFrame":
- if self._plan is None:
- raise Exception("Cannot unpersist on empty plan.")
relation = self._plan.plan(self._session.client)
self._session.client._analyze(method="unpersist", relation=relation,
blocking=blocking)
return self
@@ -1982,10 +1936,6 @@ class DataFrame:
return self.storageLevel != StorageLevel.NONE
def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]:
- if self._plan is None:
- raise Exception("Cannot collect on empty plan.")
- if self._session is None:
- raise Exception("Cannot collect on empty session.")
query = self._plan.to_proto(self._session.client)
schema: Optional[StructType] = None
@@ -2043,9 +1993,6 @@ class DataFrame:
) -> "DataFrame":
from pyspark.sql.connect.udf import UserDefinedFunction
- if self._plan is None:
- raise Exception("Cannot mapInPandas when self._plan is empty.")
-
udf_obj = UserDefinedFunction(
func,
returnType=schema,
@@ -2160,9 +2107,7 @@ class DataFrame:
Main initialization method used to construct a new data frame with a child plan.
This is for internal purpose.
"""
- new_frame = DataFrame(session=session)
- new_frame._plan = plan
- return new_frame
+ return DataFrame(plan=plan, session=session)
class DataFrameNaFunctions:
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 33c920eff86..01b109d2016 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -93,9 +93,8 @@ class MockRemoteSession:
class MockDF(DataFrame):
"""Helper class that must only be used for the mock plan tests."""
- def __init__(self, session: SparkSession, plan: LogicalPlan):
- super().__init__(session)
- self._plan = plan
+ def __init__(self, plan: LogicalPlan, session: SparkSession):
+ super().__init__(plan, session)
def __getattr__(self, name):
"""All attributes are resolved to columns, because none really exist
in the
@@ -115,7 +114,7 @@ class PlanOnlyTestFixture(unittest.TestCase, PySparkErrorTestUtils):
@classmethod
def _df_mock(cls, plan: LogicalPlan) -> MockDF:
- return MockDF(cls.connect, plan)
+ return MockDF(plan, cls.connect)
@classmethod
def _session_range(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]