This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 05cdec8e065 [SPARK-45412][PYTHON][CONNECT] Validate the plan and session in `DataFrame.__init__`
05cdec8e065 is described below

commit 05cdec8e06507b7e4d84198209038c39fbb234d1
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Oct 5 09:36:08 2023 +0900

    [SPARK-45412][PYTHON][CONNECT] Validate the plan and session in `DataFrame.__init__`
    
    ### What changes were proposed in this pull request?
    Validate the plan and session in `DataFrame.__init__`
    
    ### Why are the changes needed?
    
    - Many DataFrame APIs validate the plan and session, but they throw inconsistent exceptions: `SparkConnectException`, `Exception`, `PySparkValueError`, `AssertionError`.
    - Some methods still do not validate the plan and session at all.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the exceptions raised for a missing plan or session are unified into `PySparkRuntimeError`.
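    For example, a minimal sketch of the new behaviour (the `spark` session object and this snippet are illustrative, not part of the patch):

    ```python
    from pyspark.errors import PySparkRuntimeError
    from pyspark.sql.connect.dataframe import DataFrame

    try:
        # Constructing a DataFrame without a valid plan now fails once,
        # consistently, in __init__ rather than in individual APIs.
        DataFrame(plan=None, session=spark)
    except PySparkRuntimeError as e:
        print(e.getErrorClass())  # MISSING_VALID_PLAN
    ```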
    
    ### How was this patch tested?
    ci
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #43215 from zhengruifeng/df_validate_v2.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/dataframe.py | 91 +++++++--------------------------
 python/pyspark/testing/connectutils.py  |  7 ++-
 2 files changed, 21 insertions(+), 77 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index e4e6613f997..9e46c8e1bf3 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -58,8 +58,8 @@ from pyspark.errors import (
     PySparkAttributeError,
     PySparkValueError,
     PySparkNotImplementedError,
+    PySparkRuntimeError,
 )
-from pyspark.errors.exceptions.connect import SparkConnectException
 from pyspark.rdd import PythonEvalType
 from pyspark.storagelevel import StorageLevel
 import pyspark.sql.connect.plan as plan
@@ -99,11 +99,24 @@ if TYPE_CHECKING:
 class DataFrame:
     def __init__(
         self,
+        plan: plan.LogicalPlan,
         session: "SparkSession",
     ):
         """Creates a new data frame"""
-        self._plan: Optional[plan.LogicalPlan] = None
+        self._plan = plan
+        if self._plan is None:
+            raise PySparkRuntimeError(
+                error_class="MISSING_VALID_PLAN",
+                message_parameters={"operator": "__init__"},
+            )
+
         self._session: "SparkSession" = session
+        if self._session is None:
+            raise PySparkRuntimeError(
+                error_class="NO_ACTIVE_SESSION",
+                message_parameters={"operator": "__init__"},
+            )
+
         # Check whether _repr_html is supported or not, we use it to avoid calling RPC twice
         # by __repr__ and _repr_html_ while eager evaluation opens.
         self._support_repr_html = False
@@ -212,10 +225,6 @@ class DataFrame:
     alias.__doc__ = PySparkDataFrame.alias.__doc__
 
     def colRegex(self, colName: str) -> Column:
-        if self._plan is None:
-            raise SparkConnectException("Cannot colRegex on empty plan.")
-        if self._session is None:
-            raise Exception("Cannot analyze without SparkSession.")
         if not isinstance(colName, str):
             raise PySparkTypeError(
                 error_class="NOT_STR",
@@ -250,10 +259,6 @@ class DataFrame:
     count.__doc__ = PySparkDataFrame.count.__doc__
 
     def crossJoin(self, other: "DataFrame") -> "DataFrame":
-        if self._plan is None:
-            raise Exception("Cannot cartesian join when self._plan is empty.")
-        if other._plan is None:
-            raise Exception("Cannot cartesian join when other._plan is empty.")
         self.checkSameSparkSession(other)
         return DataFrame.withPlan(
             plan.Join(left=self._plan, right=other._plan, on=None, how="cross"),
@@ -573,10 +578,6 @@ class DataFrame:
         on: Optional[Union[str, List[str], Column, List[Column]]] = None,
         how: Optional[str] = None,
     ) -> "DataFrame":
-        if self._plan is None:
-            raise Exception("Cannot join when self._plan is empty.")
-        if other._plan is None:
-            raise Exception("Cannot join when other._plan is empty.")
         if how is not None and isinstance(how, str):
             how = how.lower().replace("_", "")
         self.checkSameSparkSession(other)
@@ -599,11 +600,6 @@ class DataFrame:
         allowExactMatches: bool = True,
         direction: str = "backward",
     ) -> "DataFrame":
-        if self._plan is None:
-            raise Exception("Cannot join when self._plan is empty.")
-        if other._plan is None:
-            raise Exception("Cannot join when other._plan is empty.")
-
         if how is None:
             how = "inner"
         assert isinstance(how, str), "how should be a string"
@@ -1111,11 +1107,6 @@ class DataFrame:
     unionAll.__doc__ = PySparkDataFrame.unionAll.__doc__
 
     def unionByName(self, other: "DataFrame", allowMissingColumns: bool = False) -> "DataFrame":
-        if other._plan is None:
-            raise PySparkValueError(
-                error_class="MISSING_VALID_PLAN",
-                message_parameters={"operator": "UnionByName"},
-            )
         self.checkSameSparkSession(other)
         return DataFrame.withPlan(
             plan.SetOperation(
@@ -1649,9 +1640,6 @@ class DataFrame:
     sampleBy.__doc__ = PySparkDataFrame.sampleBy.__doc__
 
     def __getattr__(self, name: str) -> "Column":
-        if self._plan is None:
-            raise SparkConnectException("Cannot analyze on empty plan.")
-
         if name in ["_jseq", "_jdf", "_jmap", "_jcols"]:
             raise PySparkAttributeError(
                 error_class="JVM_ATTRIBUTE_NOT_SUPPORTED", 
message_parameters={"attr_name": name}
@@ -1736,20 +1724,12 @@ class DataFrame:
     collect.__doc__ = PySparkDataFrame.collect.__doc__
 
     def _to_table(self) -> Tuple["pa.Table", Optional[StructType]]:
-        if self._plan is None:
-            raise Exception("Cannot collect on empty plan.")
-        if self._session is None:
-            raise Exception("Cannot collect on empty session.")
         query = self._plan.to_proto(self._session.client)
         table, schema = self._session.client.to_table(query)
         assert table is not None
         return (table, schema)
 
     def toPandas(self) -> "pandas.DataFrame":
-        if self._plan is None:
-            raise Exception("Cannot collect on empty plan.")
-        if self._session is None:
-            raise Exception("Cannot collect on empty session.")
         query = self._plan.to_proto(self._session.client)
         return self._session.client.to_pandas(query)
 
@@ -1757,19 +1737,12 @@ class DataFrame:
 
     @property
     def schema(self) -> StructType:
-        if self._plan is not None:
-            if self._session is None:
-                raise Exception("Cannot analyze without SparkSession.")
-            query = self._plan.to_proto(self._session.client)
-            return self._session.client.schema(query)
-        else:
-            raise Exception("Empty plan.")
+        query = self._plan.to_proto(self._session.client)
+        return self._session.client.schema(query)
 
     schema.__doc__ = PySparkDataFrame.schema.__doc__
 
     def isLocal(self) -> bool:
-        if self._plan is None:
-            raise Exception("Cannot analyze on empty plan.")
         query = self._plan.to_proto(self._session.client)
         result = self._session.client._analyze(method="is_local", 
plan=query).is_local
         assert result is not None
@@ -1779,8 +1752,6 @@ class DataFrame:
 
     @property
     def isStreaming(self) -> bool:
-        if self._plan is None:
-            raise Exception("Cannot analyze on empty plan.")
         query = self._plan.to_proto(self._session.client)
         result = self._session.client._analyze(method="is_streaming", 
plan=query).is_streaming
         assert result is not None
@@ -1789,8 +1760,6 @@ class DataFrame:
     isStreaming.__doc__ = PySparkDataFrame.isStreaming.__doc__
 
     def _tree_string(self, level: Optional[int] = None) -> str:
-        if self._plan is None:
-            raise Exception("Cannot analyze on empty plan.")
         query = self._plan.to_proto(self._session.client)
         result = self._session.client._analyze(
             method="tree_string", plan=query, level=level
@@ -1804,8 +1773,6 @@ class DataFrame:
     printSchema.__doc__ = PySparkDataFrame.printSchema.__doc__
 
     def inputFiles(self) -> List[str]:
-        if self._plan is None:
-            raise Exception("Cannot analyze on empty plan.")
         query = self._plan.to_proto(self._session.client)
         result = self._session.client._analyze(method="input_files", 
plan=query).input_files
         assert result is not None
@@ -1845,11 +1812,6 @@ class DataFrame:
     def _explain_string(
         self, extended: Optional[Union[bool, str]] = None, mode: Optional[str] = None
     ) -> str:
-        if self._plan is None:
-            raise SparkConnectException("Cannot explain on empty plan.")
-        if self._session is None:
-            raise Exception("Cannot analyze without SparkSession.")
-
         if extended is not None and mode is not None:
             raise PySparkValueError(
                 error_class="CANNOT_SET_TOGETHER",
@@ -1935,8 +1897,6 @@ class DataFrame:
     createOrReplaceGlobalTempView.__doc__ = PySparkDataFrame.createOrReplaceGlobalTempView.__doc__
 
     def cache(self) -> "DataFrame":
-        if self._plan is None:
-            raise Exception("Cannot cache on empty plan.")
         return self.persist()
 
     cache.__doc__ = PySparkDataFrame.cache.__doc__
@@ -1945,8 +1905,6 @@ class DataFrame:
         self,
         storageLevel: StorageLevel = (StorageLevel.MEMORY_AND_DISK_DESER),
     ) -> "DataFrame":
-        if self._plan is None:
-            raise Exception("Cannot persist on empty plan.")
         relation = self._plan.plan(self._session.client)
         self._session.client._analyze(
             method="persist", relation=relation, storage_level=storageLevel
@@ -1957,8 +1915,6 @@ class DataFrame:
 
     @property
     def storageLevel(self) -> StorageLevel:
-        if self._plan is None:
-            raise Exception("Cannot persist on empty plan.")
         relation = self._plan.plan(self._session.client)
         storage_level = self._session.client._analyze(
             method="get_storage_level", relation=relation
@@ -1969,8 +1925,6 @@ class DataFrame:
     storageLevel.__doc__ = PySparkDataFrame.storageLevel.__doc__
 
     def unpersist(self, blocking: bool = False) -> "DataFrame":
-        if self._plan is None:
-            raise Exception("Cannot unpersist on empty plan.")
         relation = self._plan.plan(self._session.client)
         self._session.client._analyze(method="unpersist", relation=relation, 
blocking=blocking)
         return self
@@ -1982,10 +1936,6 @@ class DataFrame:
         return self.storageLevel != StorageLevel.NONE
 
     def toLocalIterator(self, prefetchPartitions: bool = False) -> Iterator[Row]:
-        if self._plan is None:
-            raise Exception("Cannot collect on empty plan.")
-        if self._session is None:
-            raise Exception("Cannot collect on empty session.")
         query = self._plan.to_proto(self._session.client)
 
         schema: Optional[StructType] = None
@@ -2043,9 +1993,6 @@ class DataFrame:
     ) -> "DataFrame":
         from pyspark.sql.connect.udf import UserDefinedFunction
 
-        if self._plan is None:
-            raise Exception("Cannot mapInPandas when self._plan is empty.")
-
         udf_obj = UserDefinedFunction(
             func,
             returnType=schema,
@@ -2160,9 +2107,7 @@ class DataFrame:
         Main initialization method used to construct a new data frame with a child plan.
         This is for internal purpose.
         """
-        new_frame = DataFrame(session=session)
-        new_frame._plan = plan
-        return new_frame
+        return DataFrame(plan=plan, session=session)
 
 
 class DataFrameNaFunctions:
diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py
index 33c920eff86..01b109d2016 100644
--- a/python/pyspark/testing/connectutils.py
+++ b/python/pyspark/testing/connectutils.py
@@ -93,9 +93,8 @@ class MockRemoteSession:
 class MockDF(DataFrame):
     """Helper class that must only be used for the mock plan tests."""
 
-    def __init__(self, session: SparkSession, plan: LogicalPlan):
-        super().__init__(session)
-        self._plan = plan
+    def __init__(self, plan: LogicalPlan, session: SparkSession):
+        super().__init__(plan, session)
 
     def __getattr__(self, name):
         """All attributes are resolved to columns, because none really exist 
in the
@@ -115,7 +114,7 @@ class PlanOnlyTestFixture(unittest.TestCase, PySparkErrorTestUtils):
 
     @classmethod
     def _df_mock(cls, plan: LogicalPlan) -> MockDF:
-        return MockDF(cls.connect, plan)
+        return MockDF(plan, cls.connect)
 
     @classmethod
     def _session_range(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
