This is an automated email from the ASF dual-hosted git repository.

haejoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ee21e6b07a0d [SPARK-50113][CONNECT][PYTHON][TESTS] Add `@remote_only` to check the APIs that are only supported with Spark Connect
ee21e6b07a0d is described below

commit ee21e6b07a0d30cbdf78a2dd6bfe43d8fc23d518
Author: Haejoon Lee <[email protected]>
AuthorDate: Thu Nov 21 17:29:18 2024 +0900

    [SPARK-50113][CONNECT][PYTHON][TESTS] Add `@remote_only` to check the APIs that are only supported with Spark Connect
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to add `@remote_only` to check the APIs that are only supported with Spark Connect.
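
    As a rough illustration, the idea is a decorator that tags a function with
    a marker attribute (a minimal sketch mirroring the `remote_only` helper
    added to `python/pyspark/sql/utils.py` in this diff; the `Demo` class and
    the final assertion are illustrative only):

        def remote_only(func):
            # Tag the function so reflection-based checks can identify
            # Spark Connect-only APIs.
            func._remote_only = True
            return func

        class Demo:
            @remote_only
            def interruptAll(self):
                ...

        # The marker is visible via getattr:
        assert getattr(Demo.interruptAll, "_remote_only", False)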
    
    ### Why are the changes needed?
    
    The current compatibility check cannot capture missing methods that are only supported with Spark Connect.
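
    A rough sketch of how the check can then skip such APIs (simplified from
    the `test_connect_compatibility.py` change in this diff; not the exact
    test code):

        import inspect

        def get_public_methods(cls):
            methods = {}
            for name, method in inspect.getmembers(cls):
                if inspect.isfunction(method) and not name.startswith("_"):
                    # Record Connect-only APIs as None so the signature
                    # comparison knows to skip them.
                    methods[name] = None if getattr(method, "_remote_only", False) else method
            return methods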
    
    ### Does this PR introduce _any_ user-facing change?
    
    No, it's test-only
    
    ### How was this patch tested?
    
    Updated the existing UT
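
    Locally, the updated test can be run with the standard PySpark test runner
    (assuming a built Spark checkout), e.g.:

        python/run-tests --testnames 'pyspark.sql.tests.test_connect_compatibility'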
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #48651 from itholic/SPARK-50113.
    
    Authored-by: Haejoon Lee <[email protected]>
    Signed-off-by: Haejoon Lee <[email protected]>
---
 python/pyspark/sql/session.py                      | 21 ++++++++++-
 .../sql/tests/test_connect_compatibility.py        | 43 ++++++++++++++++++----
 python/pyspark/sql/utils.py                        | 16 ++++++++
 python/pyspark/util.py                             | 19 +++++++---
 4 files changed, 84 insertions(+), 15 deletions(-)

diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index ef8750b6e72d..7231d6c10b0b 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -63,7 +63,12 @@ from pyspark.sql.types import (
     _from_numpy_type,
 )
 from pyspark.errors.exceptions.captured import install_exception_handler
-from pyspark.sql.utils import is_timestamp_ntz_preferred, to_str, try_remote_session_classmethod
+from pyspark.sql.utils import (
+    is_timestamp_ntz_preferred,
+    to_str,
+    try_remote_session_classmethod,
+    remote_only,
+)
 from pyspark.errors import PySparkValueError, PySparkTypeError, PySparkRuntimeError
 
 if TYPE_CHECKING:
@@ -550,6 +555,7 @@ class SparkSession(SparkConversionMixin):
                 return session
 
         # Spark Connect-specific API
+        @remote_only
         def create(self) -> "SparkSession":
             """Creates a new SparkSession. Can only be used in the context of 
Spark Connect
             and will throw an exception otherwise.
@@ -2067,6 +2073,7 @@ class SparkSession(SparkConversionMixin):
 
     # SparkConnect-specific API
     @property
+    @remote_only
     def client(self) -> "SparkConnectClient":
         """
         Gives access to the Spark Connect client. In normal cases this is not necessary to be used
@@ -2090,6 +2097,7 @@ class SparkSession(SparkConversionMixin):
             messageParameters={"feature": "SparkSession.client"},
         )
 
+    @remote_only
     def addArtifacts(
         self, *path: str, pyfile: bool = False, archive: bool = False, file: bool = False
     ) -> None:
@@ -2125,6 +2133,7 @@ class SparkSession(SparkConversionMixin):
 
     addArtifact = addArtifacts
 
+    @remote_only
     def registerProgressHandler(self, handler: "ProgressHandler") -> None:
         """
         Register a progress handler to be called when a progress update is received from the server.
@@ -2153,6 +2162,7 @@ class SparkSession(SparkConversionMixin):
             messageParameters={"feature": 
"SparkSession.registerProgressHandler"},
         )
 
+    @remote_only
     def removeProgressHandler(self, handler: "ProgressHandler") -> None:
         """
         Remove a progress handler that was previously registered.
@@ -2169,6 +2179,7 @@ class SparkSession(SparkConversionMixin):
             messageParameters={"feature": 
"SparkSession.removeProgressHandler"},
         )
 
+    @remote_only
     def clearProgressHandlers(self) -> None:
         """
         Clear all registered progress handlers.
@@ -2180,6 +2191,7 @@ class SparkSession(SparkConversionMixin):
             messageParameters={"feature": 
"SparkSession.clearProgressHandlers"},
         )
 
+    @remote_only
     def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None:
         """
         Copy file from local to cloud storage file system.
@@ -2208,6 +2220,7 @@ class SparkSession(SparkConversionMixin):
             messageParameters={"feature": "SparkSession.copyFromLocalToFs"},
         )
 
+    @remote_only
     def interruptAll(self) -> List[str]:
         """
         Interrupt all operations of this session currently running on the connected server.
@@ -2228,6 +2241,7 @@ class SparkSession(SparkConversionMixin):
             messageParameters={"feature": "SparkSession.interruptAll"},
         )
 
+    @remote_only
     def interruptTag(self, tag: str) -> List[str]:
         """
         Interrupt all operations of this session with the given operation tag.
@@ -2248,6 +2262,7 @@ class SparkSession(SparkConversionMixin):
             messageParameters={"feature": "SparkSession.interruptTag"},
         )
 
+    @remote_only
     def interruptOperation(self, op_id: str) -> List[str]:
         """
         Interrupt an operation of this session with the given operationId.
@@ -2268,6 +2283,7 @@ class SparkSession(SparkConversionMixin):
             messageParameters={"feature": "SparkSession.interruptOperation"},
         )
 
+    @remote_only
     def addTag(self, tag: str) -> None:
         """
         Add a tag to be assigned to all the operations started by this thread in this session.
@@ -2292,6 +2308,7 @@ class SparkSession(SparkConversionMixin):
             messageParameters={"feature": "SparkSession.addTag"},
         )
 
+    @remote_only
     def removeTag(self, tag: str) -> None:
         """
         Remove a tag previously added to be assigned to all the operations started by this thread in
@@ -2309,6 +2326,7 @@ class SparkSession(SparkConversionMixin):
             messageParameters={"feature": "SparkSession.removeTag"},
         )
 
+    @remote_only
     def getTags(self) -> Set[str]:
         """
         Get the tags that are currently set to be assigned to all the operations started by this
@@ -2326,6 +2344,7 @@ class SparkSession(SparkConversionMixin):
             messageParameters={"feature": "SparkSession.getTags"},
         )
 
+    @remote_only
     def clearTags(self) -> None:
         """
         Clear the current thread's operation tags.
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py
index e20188e8da6f..3d74e796cd7a 100644
--- a/python/pyspark/sql/tests/test_connect_compatibility.py
+++ b/python/pyspark/sql/tests/test_connect_compatibility.py
@@ -64,12 +64,16 @@ if should_test_connect:
 class ConnectCompatibilityTestsMixin:
     def get_public_methods(self, cls):
         """Get public methods of a class."""
-        return {
-            name: method
-            for name, method in inspect.getmembers(cls)
-            if (inspect.isfunction(method) or isinstance(method, functools._lru_cache_wrapper))
-            and not name.startswith("_")
-        }
+        methods = {}
+        for name, method in inspect.getmembers(cls):
+            if (
+                inspect.isfunction(method) or isinstance(method, functools._lru_cache_wrapper)
+            ) and not name.startswith("_"):
+                if getattr(method, "_remote_only", False):
+                    methods[name] = None
+                else:
+                    methods[name] = method
+        return methods
 
     def get_public_properties(self, cls):
         """Get public properties of a class."""
@@ -88,6 +92,10 @@ class ConnectCompatibilityTestsMixin:
         common_methods = set(classic_methods.keys()) & set(connect_methods.keys())
 
         for method in common_methods:
+            # Skip non-callable, Spark Connect-specific methods
+            if classic_methods[method] is None or connect_methods[method] is None:
+                continue
+
             classic_signature = inspect.signature(classic_methods[method])
             connect_signature = inspect.signature(connect_methods[method])
 
@@ -145,7 +153,11 @@ class ConnectCompatibilityTestsMixin:
         connect_methods = self.get_public_methods(connect_cls)
 
         # Identify missing methods
-        classic_only_methods = set(classic_methods.keys()) - set(connect_methods.keys())
+        classic_only_methods = {
+            name
+            for name, method in classic_methods.items()
+            if name not in connect_methods or method is None
+        }
         connect_only_methods = set(connect_methods.keys()) - set(classic_methods.keys())
 
         # Compare the actual missing methods with the expected ones
@@ -249,7 +261,22 @@ class ConnectCompatibilityTestsMixin:
         """Test SparkSession compatibility between classic and connect."""
         expected_missing_connect_properties = {"sparkContext"}
         expected_missing_classic_properties = {"is_stopped", "session_id"}
-        expected_missing_connect_methods = {"newSession"}
+        expected_missing_connect_methods = {
+            "addArtifact",
+            "addArtifacts",
+            "addTag",
+            "clearProgressHandlers",
+            "clearTags",
+            "copyFromLocalToFs",
+            "getTags",
+            "interruptAll",
+            "interruptOperation",
+            "interruptTag",
+            "newSession",
+            "registerProgressHandler",
+            "removeProgressHandler",
+            "removeTag",
+        }
         expected_missing_classic_methods = set()
         self.check_compatibility(
             ClassicSparkSession,
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index cb59e0c7b439..3cacc5b9d021 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -458,3 +458,19 @@ class NumpyHelper:
             return [float(start)]
         step = (float(stop) - float(start)) / (num - 1)
         return [start + step * i for i in range(num)]
+
+
+def remote_only(func: Union[Callable, property]) -> Union[Callable, property]:
+    """
+    Decorator to mark a function or method as only available in Spark Connect.
+
+    This decorator allows for easy identification of Spark Connect-specific APIs.
+    """
+    if isinstance(func, property):
+        # If it's a property, we need to set the attribute on the getter function
+        getter_func = func.fget
+        getter_func._remote_only = True  # type: ignore[union-attr]
+        return property(getter_func)
+    else:
+        func._remote_only = True  # type: ignore[attr-defined]
+        return func
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 86779de49a2e..3b38b8b72c61 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -382,19 +382,24 @@ def inheritable_thread_target(f: Optional[Union[Callable, "SparkSession"]] = Non
         assert session is not None, "Spark Connect session must be provided."
 
         def outer(ff: Callable) -> Callable:
+            thread_local = session.client.thread_local  # type: ignore[union-attr, operator]
             session_client_thread_local_attrs = [
                 (attr, copy.deepcopy(value))
                 for (
                     attr,
                     value,
-                ) in session.client.thread_local.__dict__.items()  # type: ignore[union-attr]
+                ) in thread_local.__dict__.items()
             ]
 
             @functools.wraps(ff)
             def inner(*args: Any, **kwargs: Any) -> Any:
                 # Set thread locals in child thread.
                 for attr, value in session_client_thread_local_attrs:
-                    setattr(session.client.thread_local, attr, value)  # type: ignore[union-attr]
+                    setattr(
+                        session.client.thread_local,  # type: ignore[union-attr, operator]
+                        attr,
+                        value,
+                    )
                 return ff(*args, **kwargs)
 
             return inner
@@ -489,7 +494,8 @@ class InheritableThread(threading.Thread):
             def copy_local_properties(*a: Any, **k: Any) -> Any:
                 # Set tags in child thread.
                 assert hasattr(self, "_tags")
-                session.client.thread_local.tags = self._tags  # type: ignore[union-attr, has-type]
+                thread_local = session.client.thread_local  # type: ignore[union-attr, operator]
+                thread_local.tags = self._tags  # type: ignore[has-type]
                 return target(*a, **k)
 
             super(InheritableThread, self).__init__(
@@ -523,9 +529,10 @@ class InheritableThread(threading.Thread):
         if is_remote():
             # Spark Connect
             assert hasattr(self, "_session")
-            if not hasattr(self._session.client.thread_local, "tags"):
-                self._session.client.thread_local.tags = set()
-            self._tags = set(self._session.client.thread_local.tags)
+            thread_local = self._session.client.thread_local  # type: ignore[union-attr, operator]
+            if not hasattr(thread_local, "tags"):
+                thread_local.tags = set()
+            self._tags = set(thread_local.tags)
         else:
             # Non Spark Connect
             from pyspark import SparkContext


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
