This is an automated email from the ASF dual-hosted git repository.
haejoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new ee21e6b07a0d [SPARK-50113][CONNECT][PYTHON][TESTS] Add `@remote_only` to check the APIs that are only supported with Spark Connect
ee21e6b07a0d is described below
commit ee21e6b07a0d30cbdf78a2dd6bfe43d8fc23d518
Author: Haejoon Lee <[email protected]>
AuthorDate: Thu Nov 21 17:29:18 2024 +0900
[SPARK-50113][CONNECT][PYTHON][TESTS] Add `@remote_only` to check the APIs that are only supported with Spark Connect
### What changes were proposed in this pull request?
This PR proposes to add a `remote_only` decorator to mark the APIs that are only supported with Spark Connect.
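For illustration, a minimal sketch of how the decorator is applied and what it records. `MySession` here is a hypothetical stand-in for this example, not the real `SparkSession`; the `_remote_only` attribute is what the decorator sets, per the `utils.py` change below:

```python
from pyspark.sql.utils import remote_only

class MySession:
    @remote_only
    def interruptAll(self):
        """Spark Connect-only API."""

# The decorator simply tags the underlying function object, so tests
# can recognize Spark Connect-only APIs by attribute:
assert MySession.interruptAll._remote_only is True
```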
### Why are the changes needed?
The current compatibility check cannot capture missing methods that are only supported with Spark Connect.
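Concretely, the classic `SparkSession` defines stubs for these APIs that exist only to raise an error, so the method names appear in both the classic and the Connect classes and a plain name-set difference never flags them. A sketch of the pattern, with the message parameters taken from this diff (the error-class name is an assumption for illustration):

```python
from pyspark.errors import PySparkRuntimeError

class ClassicSparkSession:
    def interruptAll(self):
        # Classic stub: present for API parity, but always raises.
        raise PySparkRuntimeError(
            errorClass="ONLY_SUPPORTED_WITH_SPARK_CONNECT",
            messageParameters={"feature": "SparkSession.interruptAll"},
        )

# Both sessions expose `interruptAll`, so
# set(classic_methods) - set(connect_methods) is empty here even though
# the classic version is not a real implementation.
```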
### Does this PR introduce _any_ user-facing change?
No, it's test-only
### How was this patch tested?
Updated the existing UT
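Roughly, the updated check keys off the new marker when collecting public methods, mapping Spark Connect-only APIs to `None` so their signatures are skipped; condensed from the `test_connect_compatibility.py` change below:

```python
import functools
import inspect

def get_public_methods(cls):
    """Collect public methods; map Spark Connect-only ones to None."""
    methods = {}
    for name, method in inspect.getmembers(cls):
        if (
            inspect.isfunction(method)
            or isinstance(method, functools._lru_cache_wrapper)
        ) and not name.startswith("_"):
            # APIs tagged by @remote_only are kept by name but not compared.
            methods[name] = None if getattr(method, "_remote_only", False) else method
    return methods
```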
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #48651 from itholic/SPARK-50113.
Authored-by: Haejoon Lee <[email protected]>
Signed-off-by: Haejoon Lee <[email protected]>
---
python/pyspark/sql/session.py | 21 ++++++++++-
.../sql/tests/test_connect_compatibility.py | 43 ++++++++++++++++++----
python/pyspark/sql/utils.py | 16 ++++++++
python/pyspark/util.py | 19 +++++++---
4 files changed, 84 insertions(+), 15 deletions(-)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index ef8750b6e72d..7231d6c10b0b 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -63,7 +63,12 @@ from pyspark.sql.types import (
_from_numpy_type,
)
from pyspark.errors.exceptions.captured import install_exception_handler
-from pyspark.sql.utils import is_timestamp_ntz_preferred, to_str, try_remote_session_classmethod
+from pyspark.sql.utils import (
+ is_timestamp_ntz_preferred,
+ to_str,
+ try_remote_session_classmethod,
+ remote_only,
+)
from pyspark.errors import PySparkValueError, PySparkTypeError, PySparkRuntimeError
if TYPE_CHECKING:
@@ -550,6 +555,7 @@ class SparkSession(SparkConversionMixin):
return session
# Spark Connect-specific API
+ @remote_only
def create(self) -> "SparkSession":
"""Creates a new SparkSession. Can only be used in the context of
Spark Connect
and will throw an exception otherwise.
@@ -2067,6 +2073,7 @@ class SparkSession(SparkConversionMixin):
# SparkConnect-specific API
@property
+ @remote_only
def client(self) -> "SparkConnectClient":
"""
Gives access to the Spark Connect client. In normal cases this is not necessary to be used
@@ -2090,6 +2097,7 @@ class SparkSession(SparkConversionMixin):
messageParameters={"feature": "SparkSession.client"},
)
+ @remote_only
def addArtifacts(
self, *path: str, pyfile: bool = False, archive: bool = False, file: bool = False
) -> None:
@@ -2125,6 +2133,7 @@ class SparkSession(SparkConversionMixin):
addArtifact = addArtifacts
+ @remote_only
def registerProgressHandler(self, handler: "ProgressHandler") -> None:
"""
Register a progress handler to be called when a progress update is received from the server.
@@ -2153,6 +2162,7 @@ class SparkSession(SparkConversionMixin):
messageParameters={"feature":
"SparkSession.registerProgressHandler"},
)
+ @remote_only
def removeProgressHandler(self, handler: "ProgressHandler") -> None:
"""
Remove a progress handler that was previously registered.
@@ -2169,6 +2179,7 @@ class SparkSession(SparkConversionMixin):
messageParameters={"feature":
"SparkSession.removeProgressHandler"},
)
+ @remote_only
def clearProgressHandlers(self) -> None:
"""
Clear all registered progress handlers.
@@ -2180,6 +2191,7 @@ class SparkSession(SparkConversionMixin):
messageParameters={"feature":
"SparkSession.clearProgressHandlers"},
)
+ @remote_only
def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None:
"""
Copy file from local to cloud storage file system.
@@ -2208,6 +2220,7 @@ class SparkSession(SparkConversionMixin):
messageParameters={"feature": "SparkSession.copyFromLocalToFs"},
)
+ @remote_only
def interruptAll(self) -> List[str]:
"""
Interrupt all operations of this session currently running on the connected server.
@@ -2228,6 +2241,7 @@ class SparkSession(SparkConversionMixin):
messageParameters={"feature": "SparkSession.interruptAll"},
)
+ @remote_only
def interruptTag(self, tag: str) -> List[str]:
"""
Interrupt all operations of this session with the given operation tag.
@@ -2248,6 +2262,7 @@ class SparkSession(SparkConversionMixin):
messageParameters={"feature": "SparkSession.interruptTag"},
)
+ @remote_only
def interruptOperation(self, op_id: str) -> List[str]:
"""
Interrupt an operation of this session with the given operationId.
@@ -2268,6 +2283,7 @@ class SparkSession(SparkConversionMixin):
messageParameters={"feature": "SparkSession.interruptOperation"},
)
+ @remote_only
def addTag(self, tag: str) -> None:
"""
Add a tag to be assigned to all the operations started by this thread in this session.
@@ -2292,6 +2308,7 @@ class SparkSession(SparkConversionMixin):
messageParameters={"feature": "SparkSession.addTag"},
)
+ @remote_only
def removeTag(self, tag: str) -> None:
"""
Remove a tag previously added to be assigned to all the operations started by this thread in
@@ -2309,6 +2326,7 @@ class SparkSession(SparkConversionMixin):
messageParameters={"feature": "SparkSession.removeTag"},
)
+ @remote_only
def getTags(self) -> Set[str]:
"""
Get the tags that are currently set to be assigned to all the operations started by this
@@ -2326,6 +2344,7 @@ class SparkSession(SparkConversionMixin):
messageParameters={"feature": "SparkSession.getTags"},
)
+ @remote_only
def clearTags(self) -> None:
"""
Clear the current thread's operation tags.
diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py
index e20188e8da6f..3d74e796cd7a 100644
--- a/python/pyspark/sql/tests/test_connect_compatibility.py
+++ b/python/pyspark/sql/tests/test_connect_compatibility.py
@@ -64,12 +64,16 @@ if should_test_connect:
class ConnectCompatibilityTestsMixin:
def get_public_methods(self, cls):
"""Get public methods of a class."""
- return {
- name: method
- for name, method in inspect.getmembers(cls)
- if (inspect.isfunction(method) or isinstance(method, functools._lru_cache_wrapper))
- and not name.startswith("_")
- }
+ methods = {}
+ for name, method in inspect.getmembers(cls):
+ if (
+ inspect.isfunction(method) or isinstance(method, functools._lru_cache_wrapper)
+ ) and not name.startswith("_"):
+ if getattr(method, "_remote_only", False):
+ methods[name] = None
+ else:
+ methods[name] = method
+ return methods
def get_public_properties(self, cls):
"""Get public properties of a class."""
@@ -88,6 +92,10 @@ class ConnectCompatibilityTestsMixin:
common_methods = set(classic_methods.keys()) & set(connect_methods.keys())
for method in common_methods:
+ # Skip non-callable, Spark Connect-specific methods
+ if classic_methods[method] is None or connect_methods[method] is None:
+ continue
+
classic_signature = inspect.signature(classic_methods[method])
connect_signature = inspect.signature(connect_methods[method])
@@ -145,7 +153,11 @@ class ConnectCompatibilityTestsMixin:
connect_methods = self.get_public_methods(connect_cls)
# Identify missing methods
- classic_only_methods = set(classic_methods.keys()) - set(connect_methods.keys())
+ classic_only_methods = {
+ name
+ for name, method in classic_methods.items()
+ if name not in connect_methods or method is None
+ }
connect_only_methods = set(connect_methods.keys()) - set(classic_methods.keys())
# Compare the actual missing methods with the expected ones
@@ -249,7 +261,22 @@ class ConnectCompatibilityTestsMixin:
"""Test SparkSession compatibility between classic and connect."""
expected_missing_connect_properties = {"sparkContext"}
expected_missing_classic_properties = {"is_stopped", "session_id"}
- expected_missing_connect_methods = {"newSession"}
+ expected_missing_connect_methods = {
+ "addArtifact",
+ "addArtifacts",
+ "addTag",
+ "clearProgressHandlers",
+ "clearTags",
+ "copyFromLocalToFs",
+ "getTags",
+ "interruptAll",
+ "interruptOperation",
+ "interruptTag",
+ "newSession",
+ "registerProgressHandler",
+ "removeProgressHandler",
+ "removeTag",
+ }
expected_missing_classic_methods = set()
self.check_compatibility(
ClassicSparkSession,
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index cb59e0c7b439..3cacc5b9d021 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -458,3 +458,19 @@ class NumpyHelper:
return [float(start)]
step = (float(stop) - float(start)) / (num - 1)
return [start + step * i for i in range(num)]
+
+
+def remote_only(func: Union[Callable, property]) -> Union[Callable, property]:
+ """
+ Decorator to mark a function or method as only available in Spark Connect.
+
+ This decorator allows for easy identification of Spark Connect-specific APIs.
+ """
+ if isinstance(func, property):
+ # If it's a property, we need to set the attribute on the getter function
+ getter_func = func.fget
+ getter_func._remote_only = True # type: ignore[union-attr]
+ return property(getter_func)
+ else:
+ func._remote_only = True # type: ignore[attr-defined]
+ return func
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 86779de49a2e..3b38b8b72c61 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -382,19 +382,24 @@ def inheritable_thread_target(f: Optional[Union[Callable, "SparkSession"]] = Non
assert session is not None, "Spark Connect session must be provided."
def outer(ff: Callable) -> Callable:
+ thread_local = session.client.thread_local # type: ignore[union-attr, operator]
session_client_thread_local_attrs = [
(attr, copy.deepcopy(value))
for (
attr,
value,
- ) in session.client.thread_local.__dict__.items() # type: ignore[union-attr]
+ ) in thread_local.__dict__.items()
]
@functools.wraps(ff)
def inner(*args: Any, **kwargs: Any) -> Any:
# Set thread locals in child thread.
for attr, value in session_client_thread_local_attrs:
- setattr(session.client.thread_local, attr, value) # type: ignore[union-attr]
+ setattr(
+ session.client.thread_local, # type: ignore[union-attr, operator]
+ attr,
+ value,
+ )
return ff(*args, **kwargs)
return inner
@@ -489,7 +494,8 @@ class InheritableThread(threading.Thread):
def copy_local_properties(*a: Any, **k: Any) -> Any:
# Set tags in child thread.
assert hasattr(self, "_tags")
- session.client.thread_local.tags = self._tags # type: ignore[union-attr, has-type]
+ thread_local = session.client.thread_local # type: ignore[union-attr, operator]
+ thread_local.tags = self._tags # type: ignore[has-type]
return target(*a, **k)
super(InheritableThread, self).__init__(
@@ -523,9 +529,10 @@ class InheritableThread(threading.Thread):
if is_remote():
# Spark Connect
assert hasattr(self, "_session")
- if not hasattr(self._session.client.thread_local, "tags"):
- self._session.client.thread_local.tags = set()
- self._tags = set(self._session.client.thread_local.tags)
+ thread_local = self._session.client.thread_local # type: ignore[union-attr, operator]
+ if not hasattr(thread_local, "tags"):
+ thread_local.tags = set()
+ self._tags = set(thread_local.tags)
else:
# Non Spark Connect
from pyspark import SparkContext
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]