This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new e6d30092653 [SPARK-44264][ML][PYTHON] Support Distributed Training of Functions Using Deepspeed
e6d30092653 is described below

commit e6d3009265374ff2e6f431ca372f3a6f0f9554a9
Author: Mathew Jacob <[email protected]>
AuthorDate: Wed Jul 19 17:29:29 2023 +0900

    [SPARK-44264][ML][PYTHON] Support Distributed Training of Functions Using Deepspeed
    
    Made the DeepspeedTorchDistributor run() method use the _run() function as its backbone. This allows the user to easily run distributed training of a function with Deepspeed.
    
    This adds the ability for the user to pass in a function as the train_object when calling DeepspeedTorchDistributor.run(). The user must have all necessary imports within the function itself, and the function must be picklable. An example use case can be found in the Python file linked in the JIRA ticket.
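    The picklability requirement above can be illustrated with a small sketch (the train() function below is a hypothetical stand-in, not part of this patch): a function passed to run() carries its imports inside its own body and must survive a pickle round trip so it can be shipped to executors.

    ```python
    import pickle

    # Hypothetical training function of the shape run() now accepts:
    # self-contained, with its imports inside the body, so that it can
    # be pickled and sent to the worker processes.
    def train(batch_size, epochs):
        import math  # imports live inside the function, per the note above
        return batch_size * math.ceil(epochs / 2)

    # The function must survive a pickle round trip to be distributable.
    restored = pickle.loads(pickle.dumps(train))
    result = restored(3, 4)  # behaves identically after unpickling
    ```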
    
    Notebook/file linked in the JIRA ticket. Formal e2e tests will come in a future PR.
    
    - [ ] Add more e2e tests for both running a regular PyTorch file and running a function for training
    - [ ] Write more documentation
    
    Closes #42067 from mathewjacob1002/add_func_deepspeed.
    
    Authored-by: Mathew Jacob <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
    (cherry picked from commit 392f8d80c8cc4823ea513e78d452bba7f1a7d76c)
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/ml/deepspeed/deepspeed_distributor.py | 18 ++----------------
 1 file changed, 2 insertions(+), 16 deletions(-)

diff --git a/python/pyspark/ml/deepspeed/deepspeed_distributor.py b/python/pyspark/ml/deepspeed/deepspeed_distributor.py
index df1aac21e1f..d6ae98de5e3 100644
--- a/python/pyspark/ml/deepspeed/deepspeed_distributor.py
+++ b/python/pyspark/ml/deepspeed/deepspeed_distributor.py
@@ -15,7 +15,6 @@
 # limitations under the License.
 #
 import json
-import os
 import sys
 import tempfile
 from typing import (
@@ -135,19 +134,6 @@ class DeepspeedTorchDistributor(TorchDistributor):
     def run(self, train_object: Union[Callable, str], *args: Any, **kwargs: Any) -> Optional[Any]:
         # If the "train_object" is a string, then we assume it's a filepath.
         # Otherwise, we assume it's a function.
-        if isinstance(train_object, str):
-            if os.path.exists(train_object) is False:
-                raise FileNotFoundError(f"The path to training file {train_object} does not exist.")
-            framework_wrapper_fn = DeepspeedTorchDistributor._run_training_on_pytorch_file
-        else:
-            raise RuntimeError("Python training functions aren't supported as inputs at this time")
-
-        if self.local_mode:
-            return self._run_local_training(framework_wrapper_fn, 
train_object, *args, **kwargs)
-        return self._run_distributed_training(
-            framework_wrapper_fn,
-            train_object,
-            spark_dataframe=None,
-            *args,
-            **kwargs,  # type:ignore[misc]
+        return self._run(
+            train_object, DeepspeedTorchDistributor._run_training_on_pytorch_file, *args, **kwargs
         )
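
The net effect of the hunk above can be sketched in plain Python. The class and helper below are simplified stand-ins for the real TorchDistributor internals, not the actual implementation: run() no longer branches on the input type itself, it hands both file paths and callables to a single _run() helper.

```python
from typing import Any, Callable, Optional, Union

class DistributorSketch:
    """Simplified stand-in for DeepspeedTorchDistributor; _run here only
    mimics the dispatch, not the actual distributed execution."""

    def _run(self, train_object, file_runner, *args: Any, **kwargs: Any):
        if isinstance(train_object, str):
            # A string is treated as a path to a training script.
            return file_runner(train_object, *args, **kwargs)
        # A callable is invoked directly (the real code pickles it first).
        return train_object(*args, **kwargs)

    def run(self, train_object: Union[Callable, str], *args: Any, **kwargs: Any) -> Optional[Any]:
        # After this commit, run() simply delegates to _run(); the
        # callable branch no longer raises RuntimeError.
        return self._run(train_object, lambda path, *a, **kw: f"ran {path}", *args, **kwargs)

d = DistributorSketch()
function_result = d.run(lambda x, y: x + y, 2, 3)  # callable branch
file_result = d.run("train.py")                    # file-path branch
```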


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
