This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 392f8d80c8c [SPARK-44264][ML][PYTHON] Support Distributed Training of
Functions Using Deepspeed
392f8d80c8c is described below
commit 392f8d80c8cc4823ea513e78d452bba7f1a7d76c
Author: Mathew Jacob <[email protected]>
AuthorDate: Wed Jul 19 17:29:29 2023 +0900
[SPARK-44264][ML][PYTHON] Support Distributed Training of Functions Using
Deepspeed
### What changes were proposed in this pull request?
Refactored the DeepspeedTorchDistributor run() method to delegate to the shared
_run() helper as its backbone, instead of handling the dispatch logic inline.
### Why are the changes needed?
It allows the user to easily run distributed training of a Python function with
DeepSpeed.
### Does this PR introduce _any_ user-facing change?
This adds the ability for the user to pass in a function as the
train_object when calling DeepspeedTorchDistributor.run(). The user must have
all necessary imports within the function itself, and the function must be
picklable. An example use case can be found in the python file linked in the
JIRA ticket.
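As a rough illustration of the requirements described above (the function below is hypothetical, not taken from the PR or the linked notebook), a train_object function should carry its imports inside its own body and survive pickling:

```python
import pickle


def train_fn(learning_rate: float) -> float:
    # All imports the function needs live inside its body, so the
    # deserialized copy is self-contained when it runs on a worker.
    import math

    return math.exp(-learning_rate)


# The function must survive a pickle round trip before it can be
# shipped to the Spark workers.
restored = pickle.loads(pickle.dumps(train_fn))
print(restored(0.0))  # exp(-0.0) == 1.0
```

In actual use, such a function would be passed as the train_object argument to DeepspeedTorchDistributor.run(), as described above.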
### How was this patch tested?
Tested via the notebook/file linked in the JIRA ticket. Formal e2e tests will
come in a future PR.
### Next Steps/Timeline
- [ ] Add more e2e tests for both running a regular pytorch file and
running a function for training
- [ ] Write more documentation
Closes #42067 from mathewjacob1002/add_func_deepspeed.
Authored-by: Mathew Jacob <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/ml/deepspeed/deepspeed_distributor.py | 18 ++----------------
1 file changed, 2 insertions(+), 16 deletions(-)
diff --git a/python/pyspark/ml/deepspeed/deepspeed_distributor.py
b/python/pyspark/ml/deepspeed/deepspeed_distributor.py
index df1aac21e1f..d6ae98de5e3 100644
--- a/python/pyspark/ml/deepspeed/deepspeed_distributor.py
+++ b/python/pyspark/ml/deepspeed/deepspeed_distributor.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 import json
-import os
 import sys
 import tempfile
 from typing import (
@@ -135,19 +134,6 @@ class DeepspeedTorchDistributor(TorchDistributor):
     def run(self, train_object: Union[Callable, str], *args: Any, **kwargs: Any) -> Optional[Any]:
         # If the "train_object" is a string, then we assume it's a filepath.
         # Otherwise, we assume it's a function.
-        if isinstance(train_object, str):
-            if os.path.exists(train_object) is False:
-                raise FileNotFoundError(f"The path to training file {train_object} does not exist.")
-            framework_wrapper_fn = DeepspeedTorchDistributor._run_training_on_pytorch_file
-        else:
-            raise RuntimeError("Python training functions aren't supported as inputs at this time")
-
-        if self.local_mode:
-            return self._run_local_training(framework_wrapper_fn, train_object, *args, **kwargs)
-        return self._run_distributed_training(
-            framework_wrapper_fn,
-            train_object,
-            spark_dataframe=None,
-            *args,
-            **kwargs,  # type:ignore[misc]
+        return self._run(
+            train_object, DeepspeedTorchDistributor._run_training_on_pytorch_file, *args, **kwargs
         )
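For readers skimming the diff: the removed branch dispatched on the type of train_object inline, and that dispatch now lives in the shared _run() helper inherited from TorchDistributor. A minimal sketch of the dispatch pattern being centralized (the class and return values below are illustrative, not Spark's actual internals):

```python
import os
from typing import Any, Callable, Union


class DistributorSketch:
    """Toy stand-in showing the dispatch a shared _run() centralizes."""

    def _run(self, train_object: Union[Callable, str], *args: Any, **kwargs: Any) -> str:
        if isinstance(train_object, str):
            # A string train_object is treated as a path to a training file.
            if not os.path.exists(train_object):
                raise FileNotFoundError(
                    f"The path to training file {train_object} does not exist."
                )
            return f"file:{train_object}"
        # Anything else is treated as a picklable training function.
        return f"function:{train_object.__name__}"


def demo() -> None:
    pass


d = DistributorSketch()
print(d._run(demo))  # function:demo
```

With the dispatch living in one place, subclasses such as DeepspeedTorchDistributor only need to supply their framework-specific file wrapper, which is what the two-line replacement above does.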
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]