This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 31a81983fb2d [SPARK-45780][CONNECT] Propagate all Spark Connect client
threadlocals in InheritableThread
31a81983fb2d is described below
commit 31a81983fb2d19e05fadccdf49c37dd4f5c50465
Author: Juliusz Sompolski <[email protected]>
AuthorDate: Fri Nov 3 17:53:55 2023 -0700
[SPARK-45780][CONNECT] Propagate all Spark Connect client threadlocals in
InheritableThread
### What changes were proposed in this pull request?
Currently, PySpark's InheritableThread propagates only the Spark Connect
session.client.thread_local.tags to child threads. This PR generalizes that to
propagate all thread locals, and deep-copies each value, just as the Scala
equivalent clones them.
### Why are the changes needed?
To generalize the propagation of SparkConnectClient.thread_local, so that every
attribute stored on it, not just tags, is inherited by child threads.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing test for propagating SparkSession tags should cover this.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43649 from juliuszsompolski/SPARK-45780.
Authored-by: Juliusz Sompolski <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/util.py | 16 +++++++++++-----
1 file changed, 11 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 9c70bac2a3d9..4a828d6bfc94 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -16,6 +16,7 @@
# limitations under the License.
#
+import copy
import functools
import itertools
import os
@@ -343,14 +344,19 @@ def inheritable_thread_target(f: Optional[Union[Callable,
"SparkSession"]] = Non
assert session is not None, "Spark Connect session must be provided."
def outer(ff: Callable) -> Callable:
- if not hasattr(session.client.thread_local, "tags"): # type:
ignore[union-attr]
- session.client.thread_local.tags = set() # type:
ignore[union-attr]
- tags = set(session.client.thread_local.tags) # type:
ignore[union-attr]
+ session_client_thread_local_attrs = [
+ (attr, copy.deepcopy(value))
+ for (
+ attr,
+ value,
+ ) in session.client.thread_local.__dict__.items() # type:
ignore[union-attr]
+ ]
@functools.wraps(ff)
def inner(*args: Any, **kwargs: Any) -> Any:
- # Set tags in child thread.
- session.client.thread_local.tags = tags # type:
ignore[union-attr]
+ # Set thread locals in child thread.
+ for attr, value in session_client_thread_local_attrs:
+ setattr(session.client.thread_local, attr, value) # type:
ignore[union-attr]
return ff(*args, **kwargs)
return inner
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]