This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch SPARK-41916
in repository https://gitbox.apache.org/repos/asf/spark.git

commit df9805f9828707ef0b1d6c2347e0ff7f928d4249
Author: Weichen Xu <[email protected]>
AuthorDate: Wed Dec 17 15:43:49 2025 +0800

    init
    
    Signed-off-by: Weichen Xu <[email protected]>
---
 python/pyspark/ml/torch/distributor.py | 21 ++++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/ml/torch/distributor.py b/python/pyspark/ml/torch/distributor.py
index ef86f38b716b..62f9ab4b7497 100644
--- a/python/pyspark/ml/torch/distributor.py
+++ b/python/pyspark/ml/torch/distributor.py
@@ -212,8 +212,14 @@ class Distributor:
                 task_gpu_amount = int(_get_conf(self.spark, key, "0"))
                 if task_gpu_amount < 1:
                     raise RuntimeError(f"'{key}' was unset, so gpu usage is unavailable.")
-                # TODO(SPARK-41916): Address situation when spark.task.resource.gpu.amount > 1
-                return math.ceil(self.num_processes / task_gpu_amount)
+
+                if task_gpu_amount > 1:
+                    if not (self.num_processes % task_gpu_amount == 0):
+                        raise RuntimeError(
+                            f"TorchDistributor 'num_processes' value must be a multiple of "
+                            "'spark.task.resource.gpu.amount' value"
+                        )
+                return self.num_processes // task_gpu_amount
             else:
                 key = "spark.driver.resource.gpu.amount"
                 if "gpu" not in _get_resources(self.spark):
@@ -421,14 +427,19 @@ class TorchDistributor(Distributor):
 
         master_addr = os.environ["MASTER_ADDR"]
         master_port = os.environ["MASTER_PORT"]
+
+        if cuda_visible_devices := os.environ.get("CUDA_VISIBLE_DEVICES"):
+            processes_per_node = len(cuda_visible_devices.split(","))
+        else:
+            processes_per_node = 1
         node_rank = os.environ["RANK"]
+
         torchrun_args = [
-            f"--nnodes={num_processes}",
+            f"--nnodes={num_processes // processes_per_node}",
             f"--node_rank={node_rank}",
             f"--rdzv_endpoint={master_addr}:{master_port}",
             "--rdzv_id=0",  # TODO: setup random ID that is gleaned from env variables
         ]
-        processes_per_node = 1
         return torchrun_args, processes_per_node
 
     @staticmethod
@@ -1091,4 +1102,4 @@ def _get_spark_partition_data_loader(
     else:
         # if num_workers is zero, we cannot set `prefetch_factor` otherwise
         # torch will raise error.
-        return DataLoader(dataset, batch_size, num_workers=num_workers)
+        return DataLoader(dataset, batch_size, num_workers=num_workers)
\ No newline at end of file
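
For context, the first hunk replaces the earlier math.ceil(self.num_processes / task_gpu_amount)
with exact integer division and rejects configurations where 'num_processes' is not a multiple of
'spark.task.resource.gpu.amount'. A minimal standalone sketch of the patched logic follows; the
helper name and signature are illustrative and not part of the commit:

    def num_tasks_for_gpu_training(num_processes: int, task_gpu_amount: int) -> int:
        # task_gpu_amount stands in for int(spark conf "spark.task.resource.gpu.amount");
        # the function name is hypothetical and only mirrors the patched branch above.
        if task_gpu_amount < 1:
            raise RuntimeError(
                "'spark.task.resource.gpu.amount' was unset, so gpu usage is unavailable."
            )
        if task_gpu_amount > 1 and num_processes % task_gpu_amount != 0:
            raise RuntimeError(
                "TorchDistributor 'num_processes' value must be a multiple of "
                "'spark.task.resource.gpu.amount' value"
            )
        # The multiplicity check makes the division exact, so // replaces math.ceil.
        return num_processes // task_gpu_amount

For example, num_processes=8 with task_gpu_amount=2 yields 4 Spark tasks, while num_processes=7
with task_gpu_amount=2 raises the error above.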

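The second hunk stops hard-coding processes_per_node = 1: it counts the GPUs the Spark task
exposes through CUDA_VISIBLE_DEVICES and passes --nnodes=num_processes // processes_per_node to
torchrun. A hedged sketch of that computation as a standalone function (the name
build_torchrun_node_args is illustrative; the real logic runs inside TorchDistributor):

    import os

    def build_torchrun_node_args(num_processes: int):
        # MASTER_ADDR, MASTER_PORT and RANK are set by the distributor before this point.
        master_addr = os.environ["MASTER_ADDR"]
        master_port = os.environ["MASTER_PORT"]
        # CUDA_VISIBLE_DEVICES such as "0,1" means two local worker processes;
        # if it is unset, fall back to a single process per node.
        if cuda_visible_devices := os.environ.get("CUDA_VISIBLE_DEVICES"):
            processes_per_node = len(cuda_visible_devices.split(","))
        else:
            processes_per_node = 1
        node_rank = os.environ["RANK"]
        torchrun_args = [
            f"--nnodes={num_processes // processes_per_node}",
            f"--node_rank={node_rank}",
            f"--rdzv_endpoint={master_addr}:{master_port}",
            "--rdzv_id=0",
        ]
        return torchrun_args, processes_per_node
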

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
