This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6b800ba8461 [SPARK-41591][PYTHON][ML] Training PyTorch Files on Single Node Multi GPU
6b800ba8461 is described below
commit 6b800ba8461935a205d8c15eba2ff11f141dea47
Author: Rithwik Ediga Lakhamsani <[email protected]>
AuthorDate: Thu Jan 12 08:42:01 2023 +0900
[SPARK-41591][PYTHON][ML] Training PyTorch Files on Single Node Multi GPU
### What changes were proposed in this pull request?
This is an addition to https://github.com/apache/spark/pull/39146 that adds
support for single-node training using PyTorch files. Users would follow
the second workflow in the [design
document](https://docs.google.com/document/d/1QPO1Ly8WteL6aIPvVcR7Xne9qVtJiB3fdrRn7NwBcpA/edit#heading=h.8yvw9xq428fh)
to run training. I added some new utility functions and built on top of the
existing ones.
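For illustration, here is a minimal, hypothetical sketch of that second workflow
(single node, multiple GPUs, training driven by a user-provided PyTorch file).
The script path and argument values are placeholders, and the constructor is
called positionally in the same order used by the unit tests below
(num_processes, local_mode, use_gpu):

```python
from pyspark.ml.torch.distributor import TorchDistributor

# Hypothetical usage sketch: "/path/to/train.py" is a user-authored PyTorch
# training script. Extra positional arguments are converted to strings and
# forwarded to the script through torch.distributed.run (torchrun).
distributor = TorchDistributor(2, True, True)  # num_processes, local_mode, use_gpu
distributor.run("/path/to/train.py", "--learning-rate", "1e-3")
```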
### Why are the changes needed?
See the [main ticket](https://issues.apache.org/jira/browse/SPARK-41589) for more details.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Some unit tests were added and integration tests will be added in a later
PR (https://issues.apache.org/jira/browse/SPARK-41777).
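For reference, a minimal sketch (assumed invocation, not part of this patch) of
loading and running the new unit test module with the standard library's
unittest runner; it assumes a local PySpark development environment on
PYTHONPATH, and the GPU-dependent cases additionally rely on the GPU discovery
script configured in the test setup:

```python
# Assumed invocation: mirrors the __main__ block at the bottom of the new test
# file by loading pyspark.ml.torch.tests.test_distributor and running it.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName(
    "pyspark.ml.torch.tests.test_distributor"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```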
Closes #39188 from rithwik-db/pytorch-file-local-training.
Authored-by: Rithwik Ediga Lakhamsani <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/ml/torch/distributor.py | 186 ++++++++++++++++++++-
python/pyspark/ml/torch/tests/test_distributor.py | 147 +++++++++++++++-
.../pyspark/ml/torch/torch_run_process_wrapper.py | 83 +++++++++
3 files changed, 412 insertions(+), 4 deletions(-)
diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index 2a4027cbb25..80d5ad31c3c 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -15,7 +15,16 @@
# limitations under the License.
#
+import collections
+import ctypes
import math
+import os
+import random
+import re
+import signal
+import sys
+import subprocess
+import time
from typing import Union, Callable, Optional, Any
import warnings
@@ -34,8 +43,8 @@ def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
Parameters
----------
- sc : SparkContext
- The SparkContext for the distributor.
+ sc : :class:`SparkContext`
+ The :class:`SparkContext` for the distributor.
key : str
string for conf name
default_value : str
@@ -64,6 +73,42 @@ def get_conf_boolean(sc: SparkContext, key: str, default_value: str) -> bool:
)
+def get_gpus_owned(sc: SparkContext) -> list[str]:
+ """Gets the number of GPUs that Spark scheduled to the calling task.
+
+ Parameters
+ ----------
+ sc : :class:`SparkContext`
+ The :class:`SparkContext` that has GPUs available.
+
+ Returns
+ -------
+ list
+        The GPU addresses (actual device IDs) owned by the calling task.
+
+ Raises
+ ------
+ ValueError
+        Raised if the GPU addresses are not in the format expected by CUDA_VISIBLE_DEVICES.
+ """
+ CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
+ pattern = re.compile("^[1-9][0-9]*|0$")
+ addresses = sc.resources["gpu"].addresses
+ if any(not pattern.match(address) for address in addresses):
+ raise ValueError(
+ f"Found GPU addresses {addresses} which "
+ "are not all in the correct format "
+ "for CUDA_VISIBLE_DEVICES, which requires "
+ "integers with no zero padding."
+ )
+ if CUDA_VISIBLE_DEVICES in os.environ:
+ gpu_indices = list(map(int, addresses))
+ gpu_list = os.environ[CUDA_VISIBLE_DEVICES].split(",")
+ gpu_owned = [gpu_list[i] for i in gpu_indices]
+ return gpu_owned
+ return addresses
+
+
class Distributor:
"""
The parent class for TorchDistributor. This class shouldn't be instantiated directly.
@@ -85,6 +130,12 @@ class Distributor:
self.num_tasks = self._get_num_tasks()
self.ssl_conf = None
+ def _create_input_params(self) -> dict[str, Any]:
+ input_params = self.__dict__.copy()
+ for unneeded_param in ["spark", "sc", "ssl_conf"]:
+ del input_params[unneeded_param]
+ return input_params
+
def _get_num_tasks(self) -> int:
"""
Returns the number of Spark tasks to use for distributed training
@@ -261,6 +312,130 @@ class TorchDistributor(Distributor):
super().__init__(num_processes, local_mode, use_gpu)
self.ssl_conf = "pytorch.spark.distributor.ignoreSsl" # type: ignore
self._validate_input_params()
+ self.input_params = self._create_input_params()
+
+ @staticmethod
+ def _create_torchrun_command(
+ input_params: dict[str, Any], path_to_train_file: str, *args: Any
+ ) -> list[str]:
+ local_mode = input_params["local_mode"]
+ num_processes = input_params["num_processes"]
+
+ if local_mode:
+ torchrun_args = ["--standalone", "--nnodes=1"]
+ processes_per_node = num_processes
+ else:
+ pass
+ # TODO(SPARK-41592): Handle distributed training
+
+ args_string = list(map(str, args)) # converting all args to strings
+
+ return (
+            [sys.executable, "-m", "pyspark.ml.torch.distributor.torch_run_process_wrapper"]
+ + torchrun_args
+ + [f"--nproc_per_node={processes_per_node}"]
+ + [path_to_train_file, *args_string]
+ )
+
+ @staticmethod
+ def _execute_command(
+ cmd: list[str], _prctl: bool = True, redirect_to_stdout: bool = True
+ ) -> None:
+ _TAIL_LINES_TO_KEEP = 100
+
+ def sigterm_on_parent_death() -> None:
+ """
+            Uses prctl to automatically send SIGTERM to the command process when its parent is dead.
+            This handles the case when the parent is a PySpark worker process.
+            If a user cancels the PySpark job, the worker process gets killed, regardless of
+ PySpark daemon and worker reuse settings.
+ """
+ if _prctl:
+ try:
+ libc = ctypes.CDLL("libc.so.6")
+                # Set the parent process death signal of the command process to SIGTERM.
+ libc.prctl(1, signal.SIGTERM)
+ except OSError:
+ pass
+
+ task = subprocess.Popen(
+ cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ stdin=subprocess.PIPE,
+ env=os.environ,
+ preexec_fn=sigterm_on_parent_death,
+ )
+ task.stdin.close() # type: ignore
+ tail: collections.deque = collections.deque(maxlen=_TAIL_LINES_TO_KEEP)
+ try:
+ for line in task.stdout: # type: ignore
+ decoded = line.decode()
+ tail.append(decoded)
+ if redirect_to_stdout:
+ sys.stdout.write(decoded)
+ task.wait()
+ finally:
+ if task.poll() is None:
+ try:
+ task.terminate() # SIGTERM
+ time.sleep(0.5)
+ if task.poll() is None:
+ task.kill() # SIGKILL
+ except OSError:
+ pass
+ if task.returncode != os.EX_OK:
+ if len(tail) == _TAIL_LINES_TO_KEEP:
+                last_n_msg = f"last {_TAIL_LINES_TO_KEEP} lines of the task output are"
+ else:
+ last_n_msg = "task output is"
+ task_output = "".join(tail)
+ raise RuntimeError(
+ f"Command {cmd} failed with return code {task.returncode}."
+ f"The {last_n_msg} included below: {task_output}"
+ )
+
+ def _run_local_training(
+ self,
+ framework_wrapper_fn: Optional[Callable],
+ train_object: Union[Callable, str],
+ *args: Any,
+ ) -> Optional[Any]:
+ CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
+ cuda_state_was_set = CUDA_VISIBLE_DEVICES in os.environ
+ old_cuda_visible_devices = os.environ.get(CUDA_VISIBLE_DEVICES, "")
+ try:
+ if self.use_gpu:
+ gpus_owned = get_gpus_owned(self.sc)
+
+ if self.num_processes > len(gpus_owned):
+ raise ValueError(
+ f"""{self.num_processes} processes were requested
+ for local training with GPU training but only
+ {len(gpus_owned)} GPUs were available."""
+ )
+ random.seed(hash(train_object))
+            selected_gpus = [str(e) for e in random.sample(gpus_owned, self.num_processes)]
+ os.environ[CUDA_VISIBLE_DEVICES] = ",".join(selected_gpus)
+
+            output = framework_wrapper_fn(self.input_params, train_object, *args)  # type: ignore
+ finally:
+ if cuda_state_was_set:
+ os.environ[CUDA_VISIBLE_DEVICES] = old_cuda_visible_devices
+ else:
+ if CUDA_VISIBLE_DEVICES in os.environ:
+ del os.environ[CUDA_VISIBLE_DEVICES]
+
+ return output
+
+ @staticmethod
+ def _run_training_on_pytorch_file(
+ input_params: dict[str, Any], train_path: str, *args: Any
+ ) -> None:
+ training_command = TorchDistributor._create_torchrun_command(
+ input_params, train_path, *args
+ )
+ TorchDistributor._execute_command(training_command)
def run(self, train_object: Union[Callable, str], *args: Any) -> Optional[Any]:
"""Runs distributed training.
@@ -278,4 +453,9 @@ class TorchDistributor(Distributor):
Returns the output of train_object called with args if train_object is a
Callable with an expected output.
"""
- pass
+ framework_wrapper_fn = None
+ if isinstance(train_object, str):
+            framework_wrapper_fn = TorchDistributor._run_training_on_pytorch_file
+ if self.local_mode:
+            output = self._run_local_training(framework_wrapper_fn, train_object, *args)
+ return output
diff --git a/python/pyspark/ml/torch/tests/test_distributor.py b/python/pyspark/ml/torch/tests/test_distributor.py
index e84505f92fe..4b24eff8742 100644
--- a/python/pyspark/ml/torch/tests/test_distributor.py
+++ b/python/pyspark/ml/torch/tests/test_distributor.py
@@ -15,17 +15,38 @@
# limitations under the License.
#
+import contextlib
import os
+from six import StringIO # type: ignore
import stat
+import subprocess
+import sys
+import time
import tempfile
+import threading
+from typing import Callable
import unittest
+from unittest.mock import patch
from pyspark import SparkConf, SparkContext
-from pyspark.ml.torch.distributor import TorchDistributor
+from pyspark.ml.torch.distributor import TorchDistributor, get_gpus_owned
+from pyspark.ml.torch.torch_run_process_wrapper import clean_and_terminate, check_parent_alive
from pyspark.sql import SparkSession
from pyspark.testing.utils import SPARK_HOME
+@contextlib.contextmanager
+def patch_stdout() -> StringIO:
+ """patch stdout and give an output"""
+ sys_stdout = sys.stdout
+ io_out = StringIO()
+ sys.stdout = io_out
+ try:
+ yield io_out
+ finally:
+ sys.stdout = sys_stdout
+
+
class TorchDistributorBaselineUnitTests(unittest.TestCase):
def setUp(self) -> None:
conf = SparkConf()
@@ -35,6 +56,14 @@ class TorchDistributorBaselineUnitTests(unittest.TestCase):
def tearDown(self) -> None:
self.spark.stop()
+ def setup_env_vars(self, input_map: dict[str, str]) -> None:
+ for key, value in input_map.items():
+ os.environ[key] = value
+
+ def delete_env_vars(self, input_map: dict[str, str]) -> None:
+ for key in input_map.keys():
+ del os.environ[key]
+
def test_validate_correct_inputs(self) -> None:
inputs = [
(1, True, False),
@@ -90,6 +119,55 @@ class TorchDistributorBaselineUnitTests(unittest.TestCase):
with self.assertRaisesRegex(RuntimeError, "unset"):
TorchDistributor(num_processes, False, True)
+ def test_execute_command(self) -> None:
+ """Test that run command runs the process and logs are written
correctly"""
+
+ with patch_stdout() as output:
+ stdout_command = ["echo", "hello_stdout"]
+ TorchDistributor._execute_command(stdout_command)
+ self.assertIn(
+ "hello_stdout", output.getvalue().strip(), "hello_stdout
should print to stdout"
+ )
+
+ with patch_stdout() as output:
+ stderr_command = ["bash", "-c", "echo hello_stderr >&2"]
+ TorchDistributor._execute_command(stderr_command)
+ self.assertIn(
+ "hello_stderr", output.getvalue().strip(), "hello_stderr
should print to stdout"
+ )
+
+ # include command in the exception message
+ with self.assertRaisesRegexp(RuntimeError, "exit 1"): # pylint:
disable=deprecated-method
+ error_command = ["bash", "-c", "exit 1"]
+ TorchDistributor._execute_command(error_command)
+
+ with self.assertRaisesRegexp(RuntimeError, "abcdef"): # pylint:
disable=deprecated-method
+ error_command = ["bash", "-c", "'abc''def'"]
+ TorchDistributor._execute_command(error_command)
+
+ def test_create_torchrun_command(self) -> None:
+ train_path = "train.py"
+ args_string = ["1", "3"]
+ local_mode_input_params = {"num_processes": 4, "local_mode": True}
+
+ expected_local_mode_output = [
+ sys.executable,
+ "-m",
+ "pyspark.ml.torch.distributor.torch_run_process_wrapper",
+ "--standalone",
+ "--nnodes=1",
+ "--nproc_per_node=4",
+ "train.py",
+ "1",
+ "3",
+ ]
+ self.assertEqual(
+ TorchDistributor._create_torchrun_command(
+ local_mode_input_params, train_path, *args_string
+ ),
+ expected_local_mode_output,
+ )
+
class TorchDistributorLocalUnitTests(unittest.TestCase):
def setUp(self) -> None:
@@ -118,6 +196,14 @@ class TorchDistributorLocalUnitTests(unittest.TestCase):
os.unlink(self.tempFile.name)
self.spark.stop()
+ def setup_env_vars(self, input_map: dict[str, str]) -> None:
+ for key, value in input_map.items():
+ os.environ[key] = value
+
+ def delete_env_vars(self, input_map: dict[str, str]) -> None:
+ for key in input_map.keys():
+ del os.environ[key]
+
def test_get_num_tasks_locally(self) -> None:
succeeds = [1, 2]
fails = [4, 8]
@@ -133,6 +219,42 @@ class TorchDistributorLocalUnitTests(unittest.TestCase):
distributor = TorchDistributor(num_processes, True, True)
distributor.num_processes = 3
+ def test_get_gpus_owned_local(self) -> None:
+ addresses = ["0", "1", "2"]
+ self.assertEqual(get_gpus_owned(self.sc), addresses)
+
+ env_vars = {"CUDA_VISIBLE_DEVICES": "3,4,5"}
+ self.setup_env_vars(env_vars)
+ self.assertEqual(get_gpus_owned(self.sc), ["3", "4", "5"])
+ self.delete_env_vars(env_vars)
+
+ def test_local_training_succeeds(self) -> None:
+ CUDA_VISIBLE_DEVICES = "CUDA_VISIBLE_DEVICES"
+ inputs = [
+ ("0,1,2", 1, True, "1"),
+ ("0,1,2", 3, True, "1,2,0"),
+ ("0,1,2", 2, False, "0,1,2"),
+ (None, 3, False, "NONE"),
+ ]
+
+        for i, (cuda_env_var, num_processes, use_gpu, expected) in enumerate(inputs):
+ with self.subTest(f"subtest: {i + 1}"):
+ # setup
+ if cuda_env_var:
+ self.setup_env_vars({CUDA_VISIBLE_DEVICES: cuda_env_var})
+
+ dist = TorchDistributor(num_processes, True, use_gpu)
+                dist._run_training_on_pytorch_file = lambda *args: os.environ.get(  # type: ignore
+ CUDA_VISIBLE_DEVICES, "NONE"
+ )
+ self.assertEqual(
+ expected,
+                    dist._run_local_training(dist._run_training_on_pytorch_file, "train.py"),
+ )
+ # cleanup
+ if cuda_env_var:
+ self.delete_env_vars({CUDA_VISIBLE_DEVICES: cuda_env_var})
+
class TorchDistributorDistributedUnitTests(unittest.TestCase):
def setUp(self) -> None:
@@ -178,6 +300,29 @@ class TorchDistributorDistributedUnitTests(unittest.TestCase):
self.spark.sparkContext._conf.set("spark.task.resource.gpu.amount", "1")
+class TorchWrapperUnitTests(unittest.TestCase):
+ def test_clean_and_terminate(self) -> None:
+ def kill_task(task: "subprocess.Popen") -> None:
+ time.sleep(1)
+ clean_and_terminate(task)
+
+ command = [sys.executable, "-c", '"import time; time.sleep(20)"']
+ task = subprocess.Popen(command)
+ t = threading.Thread(target=kill_task, args=(task,))
+ t.start()
+ time.sleep(2)
+ self.assertEqual(task.poll(), 0) # implies task ended
+
+ @patch("pyspark.ml.torch.torch_run_process_wrapper.clean_and_terminate")
+    def test_check_parent_alive(self, mock_clean_and_terminate: Callable) -> None:
+ command = [sys.executable, "-c", '"import time; time.sleep(2)"']
+ task = subprocess.Popen(command)
+        t = threading.Thread(target=check_parent_alive, args=(task,), daemon=True)
+ t.start()
+ time.sleep(2)
+        self.assertEqual(mock_clean_and_terminate.call_count, 0)  # type: ignore[attr-defined]
+
+
if __name__ == "__main__":
from pyspark.ml.torch.tests.test_distributor import *  # noqa: F401,F403 type: ignore
diff --git a/python/pyspark/ml/torch/torch_run_process_wrapper.py b/python/pyspark/ml/torch/torch_run_process_wrapper.py
new file mode 100644
index 00000000000..67ec492329d
--- /dev/null
+++ b/python/pyspark/ml/torch/torch_run_process_wrapper.py
@@ -0,0 +1,83 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import signal
+import subprocess
+import sys
+import threading
+import time
+from typing import Any
+
+
+def clean_and_terminate(task: "subprocess.Popen") -> None:
+ task.terminate()
+ time.sleep(0.5)
+ if task.poll() is None:
+ task.kill()
+ # TODO(SPARK-41775): Cleanup temp files
+
+
+def check_parent_alive(task: "subprocess.Popen") -> None:
+ orig_parent_id = os.getppid()
+ while True:
+ if os.getppid() != orig_parent_id:
+ clean_and_terminate(task)
+ break
+ time.sleep(0.5)
+
+
+if __name__ == "__main__":
+ """
+    This is a wrapper around torch.distributed.run and it kills the child process
+ if the parent process fails, crashes, or exits.
+ """
+
+ args = sys.argv[1:]
+
+ cmd = [sys.executable, "-m", "torch.distributed.run", *args]
+ task = subprocess.Popen(
+ cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ stdin=subprocess.PIPE,
+ env=os.environ,
+ )
+ t = threading.Thread(target=check_parent_alive, args=(task,), daemon=True)
+
+ def sigterm_handler(*args: Any) -> None:
+ clean_and_terminate(task)
+ os._exit(0)
+
+ signal.signal(signal.SIGTERM, sigterm_handler)
+
+ t.start()
+ task.stdin.close() # type: ignore[union-attr]
+ try:
+ for line in task.stdout: # type: ignore[union-attr]
+ decoded = line.decode()
+ print(decoded.rstrip())
+ task.wait()
+ finally:
+ if task.poll() is None:
+ try:
+ task.terminate()
+ time.sleep(0.5)
+ if task.poll() is None:
+ task.kill()
+ except OSError:
+ pass