This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 392f8d80c8c [SPARK-44264][ML][PYTHON] Support Distributed Training of Functions Using Deepspeed
392f8d80c8c is described below

commit 392f8d80c8cc4823ea513e78d452bba7f1a7d76c
Author: Mathew Jacob <[email protected]>
AuthorDate: Wed Jul 19 17:29:29 2023 +0900

    [SPARK-44264][ML][PYTHON] Support Distributed Training of Functions Using Deepspeed
    
    ### What changes were proposed in this pull request?
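    Made `DeepspeedTorchDistributor.run()` use the `_run()` method as its backbone. The choice between training on a file and training on a function is now handled by that shared logic instead of inside `run()` itself.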
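
    To show what "use `_run()` as the backbone" means, here is a heavily simplified, hypothetical model of the delegation. Every class and helper name below is invented for illustration: `BaseDistributor`, `DeepspeedLikeDistributor`, `_run_training_on_function` and `_run_training_on_file`. The stubs return values instead of launching processes. The dispatch rule follows the comments in `run()` in the diff: a string is a file path, and anything else is a function. The real body of `_run()` is not part of this commit and may differ.

```python
from typing import Any, Callable


class BaseDistributor:
    """Hypothetical, heavily simplified stand-in for TorchDistributor."""

    def __init__(self, local_mode: bool = True) -> None:
        self.local_mode = local_mode

    @staticmethod
    def _run_training_on_function(fn: Callable, *args: Any, **kwargs: Any) -> Any:
        # Hypothetical stub: the real library would start training processes here.
        return fn(*args, **kwargs)

    def _run_local_training(self, wrapper_fn: Callable, train_object: Any,
                            *args: Any, **kwargs: Any) -> Any:
        return wrapper_fn(train_object, *args, **kwargs)

    def _run_distributed_training(self, wrapper_fn: Callable, train_object: Any,
                                  *args: Any, **kwargs: Any) -> Any:
        # Hypothetical stub: the real method would run as a Spark job on executors.
        return wrapper_fn(train_object, *args, **kwargs)

    def _run(self, train_object: Any, run_file_fn: Callable,
             *args: Any, **kwargs: Any) -> Any:
        # A string is treated as a path to a training script; anything else as a function.
        if isinstance(train_object, str):
            wrapper_fn = run_file_fn
        else:
            wrapper_fn = self._run_training_on_function
        if self.local_mode:
            return self._run_local_training(wrapper_fn, train_object, *args, **kwargs)
        return self._run_distributed_training(wrapper_fn, train_object, *args, **kwargs)


class DeepspeedLikeDistributor(BaseDistributor):
    @staticmethod
    def _run_training_on_file(path: str, *args: Any, **kwargs: Any) -> str:
        # Hypothetical stub: a real implementation would launch DeepSpeed on `path`.
        return f"would launch deepspeed on {path}"

    def run(self, train_object: Any, *args: Any, **kwargs: Any) -> Any:
        # Same shape as the new run(): delegate to _run and supply only the
        # file-launching strategy.
        return self._run(
            train_object, DeepspeedLikeDistributor._run_training_on_file, *args, **kwargs
        )


distributor = DeepspeedLikeDistributor()
print(distributor.run(lambda x, y: x + y, 2, y=3))  # 5
print(distributor.run("train.py"))  # would launch deepspeed on train.py
```

    In this model, the subclass contributes only how to launch a training file. The shared method decides between file and function, and between local and distributed execution.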
    ### Why are the changes needed?
    It lets users run distributed DeepSpeed training of a Python function, not only of a training script file.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. Users can now pass a function as the `train_object` when calling `DeepspeedTorchDistributor.run()`. Previously this raised a `RuntimeError`. The function must be picklable and must contain all of its necessary imports within its own body. An example use case can be found in the Python file linked in the JIRA ticket.
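
    The following is a minimal, hypothetical training function written to those constraints. `train_fn` and its toy gradient-descent body are invented for this sketch and are not the example from the JIRA ticket. All of its imports sit inside the function body. It also accepts the positional and keyword arguments that, judging by its signature, `run()` forwards through `*args` and `**kwargs`.

```python
def train_fn(learning_rate: float, num_steps: int = 50) -> dict:
    # All imports live inside the function, so it does not depend on the
    # driver's global namespace after being serialized and shipped elsewhere.
    # In a real job, `import torch` / `import deepspeed` would go here; `math`
    # stands in so this sketch runs with the standard library alone.
    import math

    # Toy "model": one weight fitted to minimize (w - 3)^2 by gradient descent.
    target = 3.0
    w = 0.0
    for _ in range(num_steps):
        grad = 2.0 * (w - target)
        w -= learning_rate * grad
    loss = math.pow(w - target, 2)
    return {"weight": w, "loss": loss}


# Run locally to check the function before handing it to a distributor.
result = train_fn(0.1, num_steps=50)
print(result)
```

    With Spark and DeepSpeed available, the assumed usage would be `distributor.run(train_fn, 0.1, num_steps=50)`, where `distributor` is a `DeepspeedTorchDistributor` instance configured for your cluster. Its constructor arguments are deliberately omitted here.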
    
    ### How was this patch tested?
    Tested with the notebook/file linked in the JIRA ticket. Formal e2e tests will come in a future PR.
    
    ### Next Steps/Timeline
    
    - [ ] Add more e2e tests covering both training from a regular PyTorch file and training from a function
    - [ ] Write more documentation
    
    Closes #42067 from mathewjacob1002/add_func_deepspeed.
    
    Authored-by: Mathew Jacob <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/ml/deepspeed/deepspeed_distributor.py | 18 ++----------------
 1 file changed, 2 insertions(+), 16 deletions(-)

diff --git a/python/pyspark/ml/deepspeed/deepspeed_distributor.py b/python/pyspark/ml/deepspeed/deepspeed_distributor.py
index df1aac21e1f..d6ae98de5e3 100644
--- a/python/pyspark/ml/deepspeed/deepspeed_distributor.py
+++ b/python/pyspark/ml/deepspeed/deepspeed_distributor.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 import json
-import os
 import sys
 import tempfile
 from typing import (
@@ -135,19 +134,6 @@ class DeepspeedTorchDistributor(TorchDistributor):
     def run(self, train_object: Union[Callable, str], *args: Any, **kwargs: Any) -> Optional[Any]:
         # If the "train_object" is a string, then we assume it's a filepath.
         # Otherwise, we assume it's a function.
-        if isinstance(train_object, str):
-            if os.path.exists(train_object) is False:
-                raise FileNotFoundError(f"The path to training file {train_object} does not exist.")
-            framework_wrapper_fn = DeepspeedTorchDistributor._run_training_on_pytorch_file
-        else:
-            raise RuntimeError("Python training functions aren't supported as inputs at this time")
-
-        if self.local_mode:
-            return self._run_local_training(framework_wrapper_fn, train_object, *args, **kwargs)
-        return self._run_distributed_training(
-            framework_wrapper_fn,
-            train_object,
-            spark_dataframe=None,
-            *args,
-            **kwargs,  # type:ignore[misc]
+        return self._run(
+            train_object, DeepspeedTorchDistributor._run_training_on_pytorch_file, *args, **kwargs
         )

