This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new faf094a4a21a [SPARK-55904][PYTHON][CONNECT] Utilize _check_same_session to narrow down types
faf094a4a21a is described below

commit faf094a4a21af79c4fd9dcbf3ca70aef2d50b4f3
Author: Tian Gao <[email protected]>
AuthorDate: Tue Mar 10 10:42:21 2026 +0800

    [SPARK-55904][PYTHON][CONNECT] Utilize _check_same_session to narrow down types
    
    ### What changes were proposed in this pull request?
    
    * Make `_check_same_session` return the input argument if it's on the same
      session
    * Use that returned value to narrow down types so we can remove some
      `# type: ignore` comments and a `cast`
    
    ### Why are the changes needed?
    
    When we can narrow down types, we should do so, so that type hints work
    better (fewer `type: ignore` exceptions).
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    mypy passed.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #54707 from gaogaotiantian/check-same-session.
    
    Authored-by: Tian Gao <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/connect/dataframe.py | 66 ++++++++++++++-------------------
 1 file changed, 28 insertions(+), 38 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py 
b/python/pyspark/sql/connect/dataframe.py
index 4c40f9512ce3..f07a14403698 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -315,20 +315,22 @@ class DataFrame(ParentDataFrame):
         return table[0][0].as_py()
 
     def crossJoin(self, other: ParentDataFrame) -> ParentDataFrame:
-        self._check_same_session(other)
+        other = self._check_same_session(other)
         return DataFrame(
-            plan.Join(
-                left=self._plan, right=other._plan, on=None, how="cross"  # 
type: ignore[arg-type]
-            ),
+            plan.Join(left=self._plan, right=other._plan, on=None, 
how="cross"),
             session=self._session,
         )
 
-    def _check_same_session(self, other: ParentDataFrame) -> None:
-        if self._session.session_id != other._session.session_id:  # type: 
ignore[attr-defined]
+    def _check_same_session(self, other: ParentDataFrame) -> "DataFrame":
+        if (
+            not isinstance(other, DataFrame)
+            or self._session.session_id != other._session.session_id
+        ):
             raise SessionNotSameException(
                 errorClass="SESSION_NOT_SAME",
                 messageParameters={},
             )
+        return other
 
     def coalesce(self, numPartitions: int) -> ParentDataFrame:
         if not numPartitions > 0:
@@ -724,11 +726,11 @@ class DataFrame(ParentDataFrame):
         on: Optional[Union[str, List[str], Column, List[Column]]] = None,
         how: Optional[str] = None,
     ) -> ParentDataFrame:
-        self._check_same_session(other)
+        other = self._check_same_session(other)
         if how is not None and isinstance(how, str):
             how = how.lower().replace("_", "")
         return DataFrame(
-            plan.Join(left=self._plan, right=other._plan, on=on, how=how),  # 
type: ignore[arg-type]
+            plan.Join(left=self._plan, right=other._plan, on=on, how=how),
             session=self._session,
         )
 
@@ -738,13 +740,11 @@ class DataFrame(ParentDataFrame):
         on: Optional[Column] = None,
         how: Optional[str] = None,
     ) -> ParentDataFrame:
-        self._check_same_session(other)
+        other = self._check_same_session(other)
         if how is not None and isinstance(how, str):
             how = how.lower().replace("_", "")
         return DataFrame(
-            plan.LateralJoin(
-                left=self._plan, right=cast(plan.LogicalPlan, other._plan), 
on=on, how=how
-            ),
+            plan.LateralJoin(left=self._plan, right=other._plan, on=on, 
how=how),
             session=self._session,
         )
 
@@ -760,7 +760,7 @@ class DataFrame(ParentDataFrame):
         allowExactMatches: bool = True,
         direction: str = "backward",
     ) -> ParentDataFrame:
-        self._check_same_session(other)
+        other = self._check_same_session(other)
         if how is None:
             how = "inner"
         assert isinstance(how, str), "how should be a string"
@@ -777,7 +777,7 @@ class DataFrame(ParentDataFrame):
         return DataFrame(
             plan.AsOfJoin(
                 left=self._plan,
-                right=other._plan,  # type: ignore[arg-type]
+                right=other._plan,
                 left_as_of=_convert_col(self, leftAsOfColumn),
                 right_as_of=_convert_col(other, rightAsOfColumn),
                 on=on,
@@ -1159,15 +1159,13 @@ class DataFrame(ParentDataFrame):
         return None
 
     def union(self, other: ParentDataFrame) -> ParentDataFrame:
-        self._check_same_session(other)
+        other = self._check_same_session(other)
         return self.unionAll(other)
 
     def unionAll(self, other: ParentDataFrame) -> ParentDataFrame:
-        self._check_same_session(other)
+        other = self._check_same_session(other)
         res = DataFrame(
-            plan.SetOperation(
-                self._plan, other._plan, "union", is_all=True  # type: 
ignore[arg-type]
-            ),
+            plan.SetOperation(self._plan, other._plan, "union", is_all=True),
             session=self._session,
         )
         res._cached_schema = self._merge_cached_schema(other)
@@ -1176,11 +1174,11 @@ class DataFrame(ParentDataFrame):
     def unionByName(
         self, other: ParentDataFrame, allowMissingColumns: bool = False
     ) -> ParentDataFrame:
-        self._check_same_session(other)
+        other = self._check_same_session(other)
         res = DataFrame(
             plan.SetOperation(
                 self._plan,
-                other._plan,  # type: ignore[arg-type]
+                other._plan,
                 "union",
                 by_name=True,
                 allow_missing_columns=allowMissingColumns,
@@ -1191,22 +1189,18 @@ class DataFrame(ParentDataFrame):
         return res
 
     def subtract(self, other: ParentDataFrame) -> ParentDataFrame:
-        self._check_same_session(other)
+        other = self._check_same_session(other)
         res = DataFrame(
-            plan.SetOperation(
-                self._plan, other._plan, "except", is_all=False  # type: 
ignore[arg-type]
-            ),
+            plan.SetOperation(self._plan, other._plan, "except", is_all=False),
             session=self._session,
         )
         res._cached_schema = self._merge_cached_schema(other)
         return res
 
     def exceptAll(self, other: ParentDataFrame) -> ParentDataFrame:
-        self._check_same_session(other)
+        other = self._check_same_session(other)
         res = DataFrame(
-            plan.SetOperation(
-                self._plan, other._plan, "except", is_all=True  # type: 
ignore[arg-type]
-            ),
+            plan.SetOperation(self._plan, other._plan, "except", is_all=True),
             session=self._session,
         )
         res._cached_schema = self._merge_cached_schema(other)
@@ -1218,22 +1212,18 @@ class DataFrame(ParentDataFrame):
         )
 
     def intersect(self, other: ParentDataFrame) -> ParentDataFrame:
-        self._check_same_session(other)
+        other = self._check_same_session(other)
         res = DataFrame(
-            plan.SetOperation(
-                self._plan, other._plan, "intersect", is_all=False  # type: 
ignore[arg-type]
-            ),
+            plan.SetOperation(self._plan, other._plan, "intersect", 
is_all=False),
             session=self._session,
         )
         res._cached_schema = self._merge_cached_schema(other)
         return res
 
     def intersectAll(self, other: ParentDataFrame) -> ParentDataFrame:
-        self._check_same_session(other)
+        other = self._check_same_session(other)
         res = DataFrame(
-            plan.SetOperation(
-                self._plan, other._plan, "intersect", is_all=True  # type: 
ignore[arg-type]
-            ),
+            plan.SetOperation(self._plan, other._plan, "intersect", 
is_all=True),
             session=self._session,
         )
         res._cached_schema = self._merge_cached_schema(other)
@@ -2226,7 +2216,7 @@ class DataFrame(ParentDataFrame):
                 errorClass="NOT_DATAFRAME",
                 messageParameters={"arg_name": "other", "arg_type": 
type(other).__name__},
             )
-        self._check_same_session(other)
+        other = self._check_same_session(other)
         return self._session.client.same_semantics(
             plan=self._plan.to_proto(self._session.client),
             other=other._plan.to_proto(other._session.client),


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to