This is an automated email from the ASF dual-hosted git repository.
weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e1619653895 [SPARK-41593][FOLLOW-UP] Fix the case where the torch distributor logging server is not shut down
e1619653895 is described below
commit e1619653895b4d5e11d7121bdb7906355d8c17bf
Author: Weichen Xu <[email protected]>
AuthorDate: Tue May 30 19:13:20 2023 +0800
[SPARK-41593][FOLLOW-UP] Fix the case where the torch distributor logging server is not shut down
### What changes were proposed in this pull request?
Fix the case where the torch distributor logging server is not shut down.
The `_get_spark_task_function` and `_check_encryption` calls might raise an exception; in that case the logging server must be shut down, but previously it was not. This PR moves those calls inside the existing `try` block so that the server shutdown also covers failures raised there.
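For illustration, here is a minimal, self-contained sketch of the underlying pattern. The class and function names below are made up for this example and are not the actual pyspark implementation: once the logging server has started, every call that may raise runs inside the `try` block, so the `finally` clause always shuts the server down.
```python
# Hypothetical sketch of the shutdown-safety pattern (not the real pyspark code).
from typing import Any, Callable


class FakeLogServer:
    """Stand-in for a log streaming server started on the driver."""

    def __init__(self) -> None:
        self.running = False

    def start(self) -> None:
        self.running = True

    def shutdown(self) -> None:
        self.running = False


def run_training(server: FakeLogServer, build_task_fn: Callable[[], Callable[[], Any]]) -> Any:
    server.start()
    try:
        # Building the task function may raise (like _get_spark_task_function
        # or _check_encryption in the real code), so it belongs inside the try.
        task_fn = build_task_fn()
        return task_fn()
    finally:
        # Always executed, whether training succeeded or an exception was raised.
        server.shutdown()


if __name__ == "__main__":
    def failing_builder() -> Callable[[], Any]:
        raise RuntimeError("boom")

    srv = FakeLogServer()
    try:
        run_training(srv, failing_builder)
    except RuntimeError:
        pass
    assert srv.running is False  # server was shut down despite the exception
```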
### Why are the changes needed?
To ensure the torch distributor logging server is shut down even when an exception is raised before training starts.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Unit tests.
Closes #41375 from WeichenXu123/improve-torch-distributor-log-server-exception-handling.
Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
---
python/pyspark/ml/torch/distributor.py | 26 +++++++++++++-------------
1 file changed, 13 insertions(+), 13 deletions(-)
diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index ad8b4d8cc25..0249e6b4b2c 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -665,20 +665,20 @@ class TorchDistributor(Distributor):
         time.sleep(1)  # wait for the server to start
         self.log_streaming_server_port = log_streaming_server.port
 
-        spark_task_function = self._get_spark_task_function(
-            framework_wrapper_fn, train_object, spark_dataframe, *args, **kwargs
-        )
-        self._check_encryption()
-        self.logger.info(
-            f"Started distributed training with {self.num_processes} executor processes"
-        )
-        if spark_dataframe is not None:
-            input_df = spark_dataframe
-        else:
-            input_df = self.spark.range(
-                start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks
-            )
         try:
+            spark_task_function = self._get_spark_task_function(
+                framework_wrapper_fn, train_object, spark_dataframe, *args, **kwargs
+            )
+            self._check_encryption()
+            self.logger.info(
+                f"Started distributed training with {self.num_processes} executor processes"
+            )
+            if spark_dataframe is not None:
+                input_df = spark_dataframe
+            else:
+                input_df = self.spark.range(
+                    start=0, end=self.num_tasks, step=1, numPartitions=self.num_tasks
+                )
             rows = input_df.mapInArrow(
                 func=spark_task_function, schema="chunk binary", barrier=True
             ).collect()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]