This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new e6d30092653 [SPARK-44264][ML][PYTHON] Support Distributed Training of Functions Using Deepspeed
e6d30092653 is described below
commit e6d3009265374ff2e6f431ca372f3a6f0f9554a9
Author: Mathew Jacob <[email protected]>
AuthorDate: Wed Jul 19 17:29:29 2023 +0900
[SPARK-44264][ML][PYTHON] Support Distributed Training of Functions Using Deepspeed
Made the DeepspeedTorchDistributor run() method use the _run() function as its backbone. This lets the user easily run distributed training of a function with Deepspeed.
This adds the ability for the user to pass a function as the train_object when calling DeepspeedTorchDistributor.run(). The function must be picklable and must contain all necessary imports within itself. An example use case can be found in the Python file linked in the JIRA ticket.
Tested via the notebook/file linked in the JIRA ticket. Formal e2e tests will come in a future PR.
- [ ] Add more e2e tests for both running a regular pytorch file and
running a function for training
- [ ] Write more documentation
Closes #42067 from mathewjacob1002/add_func_deepspeed.
Authored-by: Mathew Jacob <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit 392f8d80c8cc4823ea513e78d452bba7f1a7d76c)
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/ml/deepspeed/deepspeed_distributor.py | 18 ++----------------
1 file changed, 2 insertions(+), 16 deletions(-)
diff --git a/python/pyspark/ml/deepspeed/deepspeed_distributor.py b/python/pyspark/ml/deepspeed/deepspeed_distributor.py
index df1aac21e1f..d6ae98de5e3 100644
--- a/python/pyspark/ml/deepspeed/deepspeed_distributor.py
+++ b/python/pyspark/ml/deepspeed/deepspeed_distributor.py
@@ -15,7 +15,6 @@
# limitations under the License.
#
import json
-import os
import sys
import tempfile
from typing import (
@@ -135,19 +134,6 @@ class DeepspeedTorchDistributor(TorchDistributor):
     def run(self, train_object: Union[Callable, str], *args: Any, **kwargs: Any) -> Optional[Any]:
# If the "train_object" is a string, then we assume it's a filepath.
# Otherwise, we assume it's a function.
-        if isinstance(train_object, str):
-            if os.path.exists(train_object) is False:
-                raise FileNotFoundError(f"The path to training file {train_object} does not exist.")
-            framework_wrapper_fn = DeepspeedTorchDistributor._run_training_on_pytorch_file
-        else:
-            raise RuntimeError("Python training functions aren't supported as inputs at this time")
-
-        if self.local_mode:
-            return self._run_local_training(framework_wrapper_fn, train_object, *args, **kwargs)
-        return self._run_distributed_training(
-            framework_wrapper_fn,
-            train_object,
-            spark_dataframe=None,
-            *args,
-            **kwargs,  # type:ignore[misc]
+        return self._run(
+            train_object, DeepspeedTorchDistributor._run_training_on_pytorch_file, *args, **kwargs
         )
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]