This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch SPARK-41916
in repository https://gitbox.apache.org/repos/asf/spark.git
commit df9805f9828707ef0b1d6c2347e0ff7f928d4249
Author: Weichen Xu <[email protected]>
AuthorDate: Wed Dec 17 15:43:49 2025 +0800

    init

    Signed-off-by: Weichen Xu <[email protected]>
---
 python/pyspark/ml/torch/distributor.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index ef86f38b716b..62f9ab4b7497 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -212,8 +212,14 @@ class Distributor:
                 task_gpu_amount = int(_get_conf(self.spark, key, "0"))
                 if task_gpu_amount < 1:
                     raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
-                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
-                return math.ceil(self.num_processes / task_gpu_amount)
+
+                if task_gpu_amount > 1:
+                    if not (self.num_processes % task_gpu_amount == 0):
+                        raise RuntimeError(
+                            f"TorchDistributor 'num_processes' value must be a multiple of "
+                            "'spark.task.resource.gpu.amount' value"
+                        )
+                return self.num_processes // task_gpu_amount
             else:
                 key = "spark.driver.resource.gpu.amount"
                 if "gpu" not in _get_resources(self.spark):
@@ -421,14 +427,19 @@ class TorchDistributor(Distributor):
         master_addr = os.environ["MASTER_ADDR"]
         master_port = os.environ["MASTER_PORT"]
+
+        if cuda_visible_devices := os.environ.get("CUDA_VISIBLE_DEVICES"):
+            processes_per_node = len(cuda_visible_devices.split(","))
+        else:
+            processes_per_node = 1
         node_rank = os.environ["RANK"]
+
         torchrun_args = [
-            f"--nnodes={num_processes}",
+            f"--nnodes={num_processes // processes_per_node}",
             f"--node_rank={node_rank}",
             f"--rdzv_endpoint={master_addr}:{master_port}",
             "--rdzv_id=0",  # TODO: setup random ID that is gleaned from env variables
         ]
-        processes_per_node = 1
         return torchrun_args, processes_per_node
 
     @staticmethod
@@ -1091,4 +1102,4 @@ def _get_spark_partition_data_loader(
     else:
         # if num_workers is zero, we cannot set `prefetch_factor` otherwise
         # torch will raise error.
-        return DataLoader(dataset, batch_size, num_workers=num_workers)
+        return DataLoader(dataset, batch_size, num_workers=num_workers)
\ No newline at end of file
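For context on the first hunk: when spark.task.resource.gpu.amount is greater than 1, the number of Spark tasks is now obtained by requiring num_processes to be an exact multiple of the per-task GPU amount and then using integer division, instead of math.ceil. A minimal standalone sketch of that arithmetic follows; the helper name _num_tasks_for_gpu_amount is hypothetical and only illustrates the check, it is not the committed code.

    def _num_tasks_for_gpu_amount(num_processes: int, task_gpu_amount: int) -> int:
        # Hypothetical illustration of the new task-count logic (not the committed code).
        if task_gpu_amount < 1:
            raise RuntimeError("'spark.task.resource.gpu.amount' was unset.")
        if task_gpu_amount > 1 and num_processes % task_gpu_amount != 0:
            raise RuntimeError(
                "TorchDistributor 'num_processes' value must be a multiple of "
                "'spark.task.resource.gpu.amount' value"
            )
        return num_processes // task_gpu_amount

    # Example: 8 training processes with 2 GPUs per task -> 4 Spark tasks.
    assert _num_tasks_for_gpu_amount(8, 2) == 4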

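For the second hunk, the number of processes per node is no longer hard-coded to 1 but inferred from CUDA_VISIBLE_DEVICES, and --nnodes is computed as num_processes // processes_per_node. Below is a rough standalone sketch of that derivation, assuming MASTER_ADDR, MASTER_PORT and RANK are set as in the distributed branch; the function name _sketch_torchrun_args is made up for illustration and is not the committed code.

    import os
    from typing import List, Tuple

    def _sketch_torchrun_args(num_processes: int) -> Tuple[List[str], int]:
        # Hypothetical sketch of the distributed-mode branch (not the committed code).
        master_addr = os.environ["MASTER_ADDR"]
        master_port = os.environ["MASTER_PORT"]
        node_rank = os.environ["RANK"]

        # One training process per GPU visible on this node; fall back to 1.
        cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
        if cuda_visible_devices:
            processes_per_node = len(cuda_visible_devices.split(","))
        else:
            processes_per_node = 1

        torchrun_args = [
            f"--nnodes={num_processes // processes_per_node}",
            f"--node_rank={node_rank}",
            f"--rdzv_endpoint={master_addr}:{master_port}",
            "--rdzv_id=0",
        ]
        return torchrun_args, processes_per_node

For example, with num_processes=8 and CUDA_VISIBLE_DEVICES=0,1,2,3 on each executor, this would yield --nnodes=2 with 4 processes per node.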