This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 770a96f4c5 Update GCP Dataproc ClusterGenerator to support GPU params (#37036)
770a96f4c5 is described below

commit 770a96f4c577c3af2bcb6c03dcfaac4a5ad051b6
Author: dmedora <[email protected]>
AuthorDate: Fri Jan 26 16:23:06 2024 -0500

    Update GCP Dataproc ClusterGenerator to support GPU params (#37036)
---
 .../providers/google/cloud/operators/dataproc.py   | 38 +++++++++
 .../google/cloud/operators/test_dataproc.py        | 95 ++++++++++++++++++++++
 2 files changed, 133 insertions(+)

diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index aacc1adb24..61ab4079ba 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -158,12 +158,18 @@ class ClusterGenerator:
         Valid values: ``pd-ssd`` (Persistent Disk Solid State Drive) or
         ``pd-standard`` (Persistent Disk Hard Disk Drive).
     :param master_disk_size: Disk size for the primary node
+    :param master_accelerator_type: Type of the accelerator card (GPU) to attach to the primary node,
+        see https://cloud.google.com/dataproc/docs/reference/rest/v1/InstanceGroupConfig#acceleratorconfig
+    :param master_accelerator_count: Number of accelerator cards (GPUs) to attach to the primary node
     :param worker_machine_type: Compute engine machine type to use for the 
worker nodes
     :param worker_disk_type: Type of the boot disk for the worker node
         (default is ``pd-standard``).
         Valid values: ``pd-ssd`` (Persistent Disk Solid State Drive) or
         ``pd-standard`` (Persistent Disk Hard Disk Drive).
     :param worker_disk_size: Disk size for the worker nodes
+    :param worker_accelerator_type: Type of the accelerator card (GPU) to attach to the worker nodes,
+        see https://cloud.google.com/dataproc/docs/reference/rest/v1/InstanceGroupConfig#acceleratorconfig
+    :param worker_accelerator_count: Number of accelerator cards (GPUs) to attach to the worker nodes
     :param num_preemptible_workers: The # of VM instances in the instance 
group as secondary workers
         inside the cluster with Preemptibility enabled by default.
         Note, that it is not possible to mix non-preemptible and preemptible 
secondary workers in
@@ -200,6 +206,9 @@ class ClusterGenerator:
         identify the driver group in future operations, such as resizing the 
node group.
     :param secondary_worker_instance_flexibility_policy: Instance flexibility 
Policy allowing a mixture of VM
         shapes and provisioning models.
+    :param secondary_worker_accelerator_type: Type of the accelerator card (GPU) to attach to the secondary workers,
+        see https://cloud.google.com/dataproc/docs/reference/rest/v1/InstanceGroupConfig#acceleratorconfig
+    :param secondary_worker_accelerator_count: Number of accelerator cards (GPUs) to attach to the secondary workers
     """
 
     def __init__(
@@ -227,9 +236,13 @@ class ClusterGenerator:
         master_machine_type: str = "n1-standard-4",
         master_disk_type: str = "pd-standard",
         master_disk_size: int = 1024,
+        master_accelerator_type: str | None = None,
+        master_accelerator_count: int | None = None,
         worker_machine_type: str = "n1-standard-4",
         worker_disk_type: str = "pd-standard",
         worker_disk_size: int = 1024,
+        worker_accelerator_type: str | None = None,
+        worker_accelerator_count: int | None = None,
         num_preemptible_workers: int = 0,
         preemptibility: str = PreemptibilityType.PREEMPTIBLE.value,
         service_account: str | None = None,
@@ -242,6 +255,8 @@ class ClusterGenerator:
         driver_pool_size: int = 0,
         driver_pool_id: str | None = None,
         secondary_worker_instance_flexibility_policy: 
InstanceFlexibilityPolicy | None = None,
+        secondary_worker_accelerator_type: str | None = None,
+        secondary_worker_accelerator_count: int | None = None,
         **kwargs,
     ) -> None:
         self.project_id = project_id
@@ -263,10 +278,14 @@ class ClusterGenerator:
         self.master_machine_type = master_machine_type
         self.master_disk_type = master_disk_type
         self.master_disk_size = master_disk_size
+        self.master_accelerator_type = master_accelerator_type
+        self.master_accelerator_count = master_accelerator_count
         self.autoscaling_policy = autoscaling_policy
         self.worker_machine_type = worker_machine_type
         self.worker_disk_type = worker_disk_type
         self.worker_disk_size = worker_disk_size
+        self.worker_accelerator_type = worker_accelerator_type
+        self.worker_accelerator_count = worker_accelerator_count
         self.zone = zone
         self.network_uri = network_uri
         self.subnetwork_uri = subnetwork_uri
@@ -283,6 +302,8 @@ class ClusterGenerator:
         self.driver_pool_size = driver_pool_size
         self.driver_pool_id = driver_pool_id
         self.secondary_worker_instance_flexibility_policy = 
secondary_worker_instance_flexibility_policy
+        self.secondary_worker_accelerator_type = secondary_worker_accelerator_type
+        self.secondary_worker_accelerator_count = secondary_worker_accelerator_count
 
         if self.custom_image and self.image_version:
             raise ValueError("The custom_image and image_version can't be both 
set")
@@ -423,6 +444,18 @@ class ClusterGenerator:
         if self.min_num_workers:
             cluster_data["worker_config"]["min_num_instances"] = 
self.min_num_workers
 
+        if self.master_accelerator_type:
+            cluster_data["master_config"]["accelerators"] = {
+                "accelerator_type_uri": self.master_accelerator_type,
+                "accelerator_count": self.master_accelerator_count,
+            }
+
+        if self.worker_accelerator_type:
+            cluster_data["worker_config"]["accelerators"] = {
+                "accelerator_type_uri": self.worker_accelerator_type,
+                "accelerator_count": self.worker_accelerator_count,
+            }
+
         if self.num_preemptible_workers > 0:
             cluster_data["secondary_worker_config"] = {
                 "num_instances": self.num_preemptible_workers,
@@ -434,6 +467,11 @@ class ClusterGenerator:
                 "is_preemptible": True,
                 "preemptibility": self.preemptibility.value,
             }
+            if self.worker_accelerator_type:
+                cluster_data["secondary_worker_config"]["accelerators"] = {
+                    "accelerator_type_uri": self.secondary_worker_accelerator_type,
+                    "accelerator_count": self.secondary_worker_accelerator_count,
+                }
             if self.secondary_worker_instance_flexibility_policy:
                 
cluster_data["secondary_worker_config"]["instance_flexibility_policy"] = {
                     "instance_selection_list": [
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index 44e20489a2..686c2f10b2 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -273,6 +273,56 @@ CONFIG_WITH_FLEX_MIG = {
     "endpoint_config": {},
 }
 
+CONFIG_WITH_GPU_ACCELERATOR = {
+    "gce_cluster_config": {
+        "zone_uri": "https://www.googleapis.com/compute/v1/projects/project_id/zones/zone",
+        "metadata": {"metadata": "data"},
+        "network_uri": "network_uri",
+        "subnetwork_uri": "subnetwork_uri",
+        "internal_ip_only": True,
+        "tags": ["tags"],
+        "service_account": "service_account",
+        "service_account_scopes": ["service_account_scopes"],
+    },
+    "master_config": {
+        "num_instances": 2,
+        "machine_type_uri": "projects/project_id/zones/zone/machineTypes/master_machine_type",
+        "disk_config": {"boot_disk_type": "master_disk_type", "boot_disk_size_gb": 128},
+        "image_uri": "https://www.googleapis.com/compute/beta/projects/"
+        "custom_image_project_id/global/images/custom_image",
+        "accelerators": {"accelerator_type_uri": "master_accelerator_type", "accelerator_count": 1},
+    },
+    "worker_config": {
+        "num_instances": 2,
+        "machine_type_uri": "projects/project_id/zones/zone/machineTypes/worker_machine_type",
+        "disk_config": {"boot_disk_type": "worker_disk_type", "boot_disk_size_gb": 256},
+        "image_uri": "https://www.googleapis.com/compute/beta/projects/"
+        "custom_image_project_id/global/images/custom_image",
+        "min_num_instances": 1,
+        "accelerators": {"accelerator_type_uri": "worker_accelerator_type", "accelerator_count": 1},
+    },
+    "secondary_worker_config": {
+        "num_instances": 4,
+        "machine_type_uri": "projects/project_id/zones/zone/machineTypes/worker_machine_type",
+        "disk_config": {"boot_disk_type": "worker_disk_type", "boot_disk_size_gb": 256},
+        "is_preemptible": True,
+        "preemptibility": "PREEMPTIBLE",
+        "accelerators": {"accelerator_type_uri": "secondary_worker_accelerator_type", "accelerator_count": 1},
+    },
+    "software_config": {"properties": {"properties": "data"}, "optional_components": ["optional_components"]},
+    "lifecycle_config": {
+        "idle_delete_ttl": {"seconds": 60},
+        "auto_delete_time": "2019-09-12T00:00:00.000000Z",
+    },
+    "encryption_config": {"gce_pd_kms_key_name": "customer_managed_key"},
+    "autoscaling_config": {"policy_uri": "autoscaling_policy"},
+    "config_bucket": "storage_bucket",
+    "initialization_actions": [
+        {"executable_file": "init_actions_uris", "execution_timeout": {"seconds": 600}}
+    ],
+    "endpoint_config": {},
+}
+
 LABELS = {"labels": "data", "airflow-version": AIRFLOW_VERSION}
 
 LABELS.update({"airflow-version": "v" + airflow_version.replace(".", 
"-").replace("+", "-")})
@@ -582,6 +632,51 @@ class TestsClusterGenerator:
         cluster = generator.make()
         assert CONFIG_WITH_FLEX_MIG == cluster
 
+    def test_build_with_gpu_accelerator(self):
+        generator = ClusterGenerator(
+            project_id="project_id",
+            num_workers=2,
+            min_num_workers=1,
+            zone="zone",
+            network_uri="network_uri",
+            subnetwork_uri="subnetwork_uri",
+            internal_ip_only=True,
+            tags=["tags"],
+            storage_bucket="storage_bucket",
+            init_actions_uris=["init_actions_uris"],
+            init_action_timeout="10m",
+            metadata={"metadata": "data"},
+            custom_image="custom_image",
+            custom_image_project_id="custom_image_project_id",
+            autoscaling_policy="autoscaling_policy",
+            properties={"properties": "data"},
+            optional_components=["optional_components"],
+            num_masters=2,
+            master_machine_type="master_machine_type",
+            master_disk_type="master_disk_type",
+            master_disk_size=128,
+            master_accelerator_type="master_accelerator_type",
+            master_accelerator_count=1,
+            worker_machine_type="worker_machine_type",
+            worker_disk_type="worker_disk_type",
+            worker_disk_size=256,
+            worker_accelerator_type="worker_accelerator_type",
+            worker_accelerator_count=1,
+            num_preemptible_workers=4,
+            secondary_worker_accelerator_type="secondary_worker_accelerator_type",
+            secondary_worker_accelerator_count=1,
+            preemptibility="preemptible",
+            region="region",
+            service_account="service_account",
+            service_account_scopes=["service_account_scopes"],
+            idle_delete_ttl=60,
+            auto_delete_time=datetime(2019, 9, 12),
+            auto_delete_ttl=250,
+            customer_managed_key="customer_managed_key",
+        )
+        cluster = generator.make()
+        assert CONFIG_WITH_GPU_ACCELERATOR == cluster
+
 
 class TestDataprocCreateClusterOperator(DataprocClusterTestBase):
     def test_deprecation_warning(self):

Reply via email to