This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 0f2670e7ac Create DataprocStartClusterOperator and DataprocStopClusterOperator (#36996)
0f2670e7ac is described below

commit 0f2670e7acaabb7110dd800b42b491aac9a8a511
Author: M. Olcay Tercanlı <[email protected]>
AuthorDate: Fri Jan 26 18:12:13 2024 +0000

    Create DataprocStartClusterOperator and DataprocStopClusterOperator (#36996)
---
 airflow/providers/google/cloud/hooks/dataproc.py   |  88 +++++++++
 .../providers/google/cloud/operators/dataproc.py   | 197 +++++++++++++++++++++
 .../operators/cloud/dataproc.rst                   |  24 +++
 tests/always/test_project_structure.py             |   1 +
 .../providers/google/cloud/hooks/test_dataproc.py  |  42 +++++
 .../google/cloud/operators/test_dataproc.py        |  86 +++++++++
 ...proc_cluster_create_existing_stopped_cluster.py | 120 +++++++++++++
 .../example_dataproc_cluster_start_stop.py         | 114 ++++++++++++
 8 files changed, 672 insertions(+)
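
The patch adds start/stop operators for Dataproc clusters and teaches
DataprocCreateClusterOperator to restart an existing STOPPED cluster. As a
quick orientation before the full diff, here is a minimal usage sketch of the
two new operators; it is illustrative only, and the DAG id, project, region,
and cluster name are placeholders rather than values from this commit:

    from __future__ import annotations

    from datetime import datetime

    from airflow.models.dag import DAG
    from airflow.providers.google.cloud.operators.dataproc import (
        DataprocStartClusterOperator,
        DataprocStopClusterOperator,
    )

    with DAG(
        "dataproc_start_stop_sketch",  # placeholder DAG id
        schedule="@once",
        start_date=datetime(2024, 1, 1),
        catchup=False,
    ) as dag:
        # No-op (with a log message) if the cluster is already STOPPED/STOPPING.
        stop_cluster = DataprocStopClusterOperator(
            task_id="stop_cluster",
            project_id="my-project",  # placeholder
            region="europe-west1",  # placeholder
            cluster_name="my-cluster",  # placeholder
        )
        # No-op (with a log message) if the cluster is already RUNNING.
        start_cluster = DataprocStartClusterOperator(
            task_id="start_cluster",
            project_id="my-project",
            region="europe-west1",
            cluster_name="my-cluster",
        )
        stop_cluster >> start_cluster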

diff --git a/airflow/providers/google/cloud/hooks/dataproc.py b/airflow/providers/google/cloud/hooks/dataproc.py
index dae5535e40..4551b24384 100644
--- a/airflow/providers/google/cloud/hooks/dataproc.py
+++ b/airflow/providers/google/cloud/hooks/dataproc.py
@@ -583,6 +583,94 @@ class DataprocHook(GoogleBaseHook):
         )
         return operation
 
+    @GoogleBaseHook.fallback_to_default_project_id
+    def start_cluster(
+        self,
+        region: str,
+        project_id: str,
+        cluster_name: str,
+        cluster_uuid: str | None = None,
+        request_id: str | None = None,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> Operation:
+        """Start a cluster in a project.
+
+        :param region: Cloud Dataproc region to handle the request.
+        :param project_id: Google Cloud project ID that the cluster belongs to.
+        :param cluster_name: The cluster name.
+        :param cluster_uuid: The cluster UUID.
+        :param request_id: A unique id used to identify the request. If the
+            server receives two *StartClusterRequest* requests with the same
+            ID, the second request will be ignored, and an operation created
+            for the first one and stored in the backend is returned.
+        :param retry: A retry object used to retry requests. If *None*, requests
+            will not be retried.
+        :param timeout: The amount of time, in seconds, to wait for the request
+            to complete. If *retry* is specified, the timeout applies to each
+            individual attempt.
+        :param metadata: Additional metadata that is provided to the method.
+        :return: An instance of ``google.api_core.operation.Operation``
+        """
+        client = self.get_cluster_client(region=region)
+        return client.start_cluster(
+            request={
+                "project_id": project_id,
+                "region": region,
+                "cluster_name": cluster_name,
+                "cluster_uuid": cluster_uuid,
+                "request_id": request_id,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    def stop_cluster(
+        self,
+        region: str,
+        project_id: str,
+        cluster_name: str,
+        cluster_uuid: str | None = None,
+        request_id: str | None = None,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> Operation:
+        """Start a cluster in a project.
+
+        :param region: Cloud Dataproc region to handle the request.
+        :param project_id: Google Cloud project ID that the cluster belongs to.
+        :param cluster_name: The cluster name.
+        :param cluster_uuid: The cluster UUID.
+        :param request_id: A unique id used to identify the request. If the
+            server receives two *StopClusterRequest* requests with the same
+            ID, the second request will be ignored, and an operation created
+            for the first one and stored in the backend is returned.
+        :param retry: A retry object used to retry requests. If *None*, requests
+            will not be retried.
+        :param timeout: The amount of time, in seconds, to wait for the request
+            to complete. If *retry* is specified, the timeout applies to each
+            individual attempt.
+        :param metadata: Additional metadata that is provided to the method.
+        :return: An instance of ``google.api_core.operation.Operation``
+        """
+        client = self.get_cluster_client(region=region)
+        return client.stop_cluster(
+            request={
+                "project_id": project_id,
+                "region": region,
+                "cluster_name": cluster_name,
+                "cluster_uuid": cluster_uuid,
+                "request_id": request_id,
+            },
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+
     @GoogleBaseHook.fallback_to_default_project_id
     def create_workflow_template(
         self,
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 7f3fcd5d01..aacc1adb24 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -724,6 +724,17 @@ class DataprocCreateClusterOperator(GoogleCloudBaseOperator):
             cluster = self._get_cluster(hook)
         return cluster
 
+    def _start_cluster(self, hook: DataprocHook):
+        op: operation.Operation = hook.start_cluster(
+            region=self.region,
+            project_id=self.project_id,
+            cluster_name=self.cluster_name,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+        return hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=op)
+
     def execute(self, context: Context) -> dict:
         self.log.info("Creating cluster: %s", self.cluster_name)
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
@@ -801,6 +812,9 @@ class DataprocCreateClusterOperator(GoogleCloudBaseOperator):
             # Create new cluster
             cluster = self._create_cluster(hook)
             self._handle_error_state(hook, cluster)
+        elif cluster.status.state == cluster.status.State.STOPPED:
+            # If the cluster exists and is already stopped, start it
+            self._start_cluster(hook)
 
         return Cluster.to_dict(cluster)
 
@@ -1082,6 +1096,189 @@ class DataprocDeleteClusterOperator(GoogleCloudBaseOperator):
         )
 
 
+class _DataprocStartStopClusterBaseOperator(GoogleCloudBaseOperator):
+    """Base class to start or stop a cluster in a project.
+
+    :param cluster_name: Required. Name of the cluster to create
+    :param region: Required. The specified region where the dataproc cluster 
is created.
+    :param project_id: Optional. The ID of the Google Cloud project the 
cluster belongs to.
+    :param cluster_uuid: Optional. Specifying the ``cluster_uuid`` means the 
RPC should fail
+        if cluster with specified UUID does not exist.
+    :param request_id: Optional. A unique id used to identify the request. If 
the server receives two
+        ``DeleteClusterRequest`` requests with the same id, then the second 
request will be ignored and the
+        first ``google.longrunning.Operation`` created and stored in the 
backend is returned.
+    :param retry: A retry object used to retry requests. If ``None`` is 
specified, requests will not be
+        retried.
+    :param timeout: The amount of time, in seconds, to wait for the request to 
complete. Note that if
+        ``retry`` is specified, the timeout applies to each individual attempt.
+    :param metadata: Additional metadata that is provided to the method.
+    :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+    :param impersonation_chain: Optional service account to impersonate using 
short-term
+        credentials, or chained list of accounts required to get the 
access_token
+        of the last account in the list, which will be impersonated in the 
request.
+        If set as a string, the account must grant the originating account
+        the Service Account Token Creator IAM role.
+        If set as a sequence, the identities from the list must grant
+        Service Account Token Creator IAM role to the directly preceding 
identity, with first
+        account from the list granting this role to the originating account 
(templated).
+    """
+
+    template_fields = (
+        "cluster_name",
+        "region",
+        "project_id",
+        "request_id",
+        "impersonation_chain",
+    )
+
+    def __init__(
+        self,
+        *,
+        cluster_name: str,
+        region: str,
+        project_id: str | None = None,
+        cluster_uuid: str | None = None,
+        request_id: str | None = None,
+        retry: Retry | _MethodDefault = DEFAULT,
+        timeout: float = 1 * 60 * 60,
+        metadata: Sequence[tuple[str, str]] = (),
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.project_id = project_id
+        self.region = region
+        self.cluster_name = cluster_name
+        self.cluster_uuid = cluster_uuid
+        self.request_id = request_id
+        self.retry = retry
+        self.timeout = timeout
+        self.metadata = metadata
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self._hook: DataprocHook | None = None
+
+    @property
+    def hook(self):
+        if self._hook is None:
+            self._hook = DataprocHook(
+                gcp_conn_id=self.gcp_conn_id,
+                impersonation_chain=self.impersonation_chain,
+            )
+        return self._hook
+
+    def _get_project_id(self) -> str:
+        return self.project_id or self.hook.project_id
+
+    def _get_cluster(self) -> Cluster:
+        """Retrieve the cluster information.
+
+        :return: Instance of the ``google.cloud.dataproc_v1.Cluster`` class.
+        """
+        return self.hook.get_cluster(
+            project_id=self._get_project_id(),
+            region=self.region,
+            cluster_name=self.cluster_name,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+    def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool, str | None]:
+        """Implement this method in the child class to return whether the cluster is in the desired state.
+
+        If the cluster is already in the desired state, you can return a log message as
+        the second value of the returned tuple.
+
+        :param cluster: Required. Instance of the ``google.cloud.dataproc_v1.Cluster``
+            class describing the current cluster.
+        :return: Tuple of (Boolean, Optional[str]). The first value indicates whether the cluster
+            is in the desired state. The second value is an optional message to log when the
+            cluster is already in the desired state.
+        """
+        raise NotImplementedError
+
+    def _get_operation(self) -> operation.Operation:
+        """Implement this method in the child class to call the related hook method and return its result.
+
+        :return: An instance of ``google.api_core.operation.Operation`` returned by the hook call
+        """
+        raise NotImplementedError
+
+    def execute(self, context: Context) -> dict | None:
+        cluster: Cluster = self._get_cluster()
+        is_already_desired_state, log_str = self._check_desired_cluster_state(cluster)
+        if is_already_desired_state:
+            self.log.info(log_str)
+            return None
+
+        op: operation.Operation = self._get_operation()
+        result = self.hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=op)
+        return Cluster.to_dict(result)
+
+
+class DataprocStartClusterOperator(_DataprocStartStopClusterBaseOperator):
+    """Start a cluster in a project."""
+
+    operator_extra_links = (DataprocClusterLink(),)
+
+    def execute(self, context: Context) -> dict | None:
+        self.log.info("Starting the cluster: %s", self.cluster_name)
+        cluster = super().execute(context)
+        DataprocClusterLink.persist(
+            context=context,
+            operator=self,
+            cluster_id=self.cluster_name,
+            project_id=self._get_project_id(),
+            region=self.region,
+        )
+        self.log.info("Cluster started")
+        return cluster
+
+    def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool, str | None]:
+        if cluster.status.state == cluster.status.State.RUNNING:
+            return True, f'The cluster "{self.cluster_name}" is already running.'
+        return False, None
+
+    def _get_operation(self) -> operation.Operation:
+        return self.hook.start_cluster(
+            region=self.region,
+            project_id=self._get_project_id(),
+            cluster_name=self.cluster_name,
+            cluster_uuid=self.cluster_uuid,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+
+class DataprocStopClusterOperator(_DataprocStartStopClusterBaseOperator):
+    """Stop a cluster in a project."""
+
+    def execute(self, context: Context) -> dict | None:
+        self.log.info("Stopping the cluster: %s", self.cluster_name)
+        cluster = super().execute(context)
+        self.log.info("Cluster stopped")
+        return cluster
+
+    def _check_desired_cluster_state(self, cluster: Cluster) -> tuple[bool, str | None]:
+        if cluster.status.state in [cluster.status.State.STOPPED, cluster.status.State.STOPPING]:
+            return True, f'The cluster "{self.cluster_name}" is already stopped.'
+        return False, None
+
+    def _get_operation(self) -> operation.Operation:
+        return self.hook.stop_cluster(
+            region=self.region,
+            project_id=self._get_project_id(),
+            cluster_name=self.cluster_name,
+            cluster_uuid=self.cluster_uuid,
+            retry=self.retry,
+            timeout=self.timeout,
+            metadata=self.metadata,
+        )
+
+
 class DataprocJobBaseOperator(GoogleCloudBaseOperator):
     """Base class for operators that launch job on DataProc.
 
diff --git a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
index 67c2831a1a..6277f94e05 100644
--- a/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
+++ b/docs/apache-airflow-providers-google/operators/cloud/dataproc.rst
@@ -201,6 +201,30 @@ You can use deferrable mode for this action in order to run the operator asynchr
     :start-after: [START how_to_cloud_dataproc_update_cluster_operator_async]
     :end-before: [END how_to_cloud_dataproc_update_cluster_operator_async]
 
+Starting a cluster
+------------------
+
+To start a cluster you can use the
+:class:`~airflow.providers.google.cloud.operators.dataproc.DataprocStartClusterOperator`:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_dataproc_start_cluster_operator]
+    :end-before: [END how_to_cloud_dataproc_start_cluster_operator]
+
+Stopping a cluster
+------------------
+
+To stop a cluster you can use the
+:class:`~airflow.providers.google.cloud.operators.dataproc.DataprocStopClusterOperator`:
+
+.. exampleinclude:: /../../tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py
+    :language: python
+    :dedent: 4
+    :start-after: [START how_to_cloud_dataproc_stop_cluster_operator]
+    :end-before: [END how_to_cloud_dataproc_stop_cluster_operator]
+
 Deleting a cluster
 ------------------
 
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index db026aa6bf..bab56abead 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -403,6 +403,7 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
         "airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryToSqlBaseOperator",
         "airflow.providers.google.cloud.operators.cloud_sql.CloudSQLBaseOperator",
         "airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator",
+        "airflow.providers.google.cloud.operators.dataproc._DataprocStartStopClusterBaseOperator",
         "airflow.providers.google.cloud.operators.vertex_ai.custom_job.CustomTrainingJobBaseOperator",
         "airflow.providers.google.cloud.operators.cloud_base.GoogleCloudBaseOperator",
     }
diff --git a/tests/providers/google/cloud/hooks/test_dataproc.py b/tests/providers/google/cloud/hooks/test_dataproc.py
index 1a82fc8a1c..131f5a342b 100644
--- a/tests/providers/google/cloud/hooks/test_dataproc.py
+++ b/tests/providers/google/cloud/hooks/test_dataproc.py
@@ -287,6 +287,48 @@ class TestDataprocHook:
                 update_mask="update-mask",
             )
 
+    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
+    def test_start_cluster(self, mock_client):
+        self.hook.start_cluster(
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+        )
+        mock_client.assert_called_once_with(region=GCP_LOCATION)
+        mock_client.return_value.start_cluster.assert_called_once_with(
+            request=dict(
+                project_id=GCP_PROJECT,
+                region=GCP_LOCATION,
+                cluster_name=CLUSTER_NAME,
+                cluster_uuid=None,
+                request_id=None,
+            ),
+            metadata=(),
+            retry=DEFAULT,
+            timeout=None,
+        )
+
+    @mock.patch(DATAPROC_STRING.format("DataprocHook.get_cluster_client"))
+    def test_stop_cluster(self, mock_client):
+        self.hook.stop_cluster(
+            region=GCP_LOCATION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+        )
+        mock_client.assert_called_once_with(region=GCP_LOCATION)
+        mock_client.return_value.stop_cluster.assert_called_once_with(
+            request=dict(
+                project_id=GCP_PROJECT,
+                region=GCP_LOCATION,
+                cluster_name=CLUSTER_NAME,
+                cluster_uuid=None,
+                request_id=None,
+            ),
+            metadata=(),
+            retry=DEFAULT,
+            timeout=None,
+        )
+
     @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
     def test_create_workflow_template(self, mock_client):
         template = {"test": "test"}
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index d0b04a6fa9..44e20489a2 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -54,6 +54,8 @@ from airflow.providers.google.cloud.operators.dataproc import (
     DataprocLink,
     DataprocListBatchesOperator,
     DataprocScaleClusterOperator,
+    DataprocStartClusterOperator,
+    DataprocStopClusterOperator,
     DataprocSubmitHadoopJobOperator,
     DataprocSubmitHiveJobOperator,
     DataprocSubmitJobOperator,
@@ -1683,6 +1685,90 @@ def test_update_cluster_operator_extra_links(dag_maker, create_task_instance_of_
     assert ti.task.get_extra_links(ti, DataprocClusterLink.name) == DATAPROC_CLUSTER_LINK_EXPECTED
 
 
+class TestDataprocStartClusterOperator(DataprocClusterTestBase):
+    @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute(self, mock_hook, mock_to_dict):
+        cluster = MagicMock()
+        cluster.status.State.RUNNING = 3
+        cluster.status.state = 0
+        mock_hook.return_value.get_cluster.return_value = cluster
+
+        op = DataprocStartClusterOperator(
+            task_id=TASK_ID,
+            cluster_name=CLUSTER_NAME,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        op.execute(context=self.mock_context)
+
+        mock_hook.return_value.get_cluster.assert_called_with(
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        mock_hook.return_value.start_cluster.assert_called_once_with(
+            cluster_name=CLUSTER_NAME,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_uuid=None,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+
+class TestDataprocStopClusterOperator(DataprocClusterTestBase):
+    @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute(self, mock_hook, mock_to_dict):
+        cluster = MagicMock()
+        cluster.status.State.STOPPED = 4
+        cluster.status.state = 0
+        mock_hook.return_value.get_cluster.return_value = cluster
+
+        op = DataprocStopClusterOperator(
+            task_id=TASK_ID,
+            cluster_name=CLUSTER_NAME,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        op.execute(context=self.mock_context)
+
+        mock_hook.return_value.get_cluster.assert_called_with(
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        mock_hook.return_value.stop_cluster.assert_called_once_with(
+            cluster_name=CLUSTER_NAME,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_uuid=None,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+
 class TestDataprocInstantiateWorkflowTemplateOperator:
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute(self, mock_hook):
diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_create_existing_stopped_cluster.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_create_existing_stopped_cluster.py
new file mode 100644
index 0000000000..6a77a14684
--- /dev/null
+++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_create_existing_stopped_cluster.py
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG for DataprocCreateClusterOperator when the cluster already exists and is stopped.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.operators.dataproc import (
+    DataprocCreateClusterOperator,
+    DataprocDeleteClusterOperator,
+    DataprocStartClusterOperator,
+    DataprocStopClusterOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+DAG_ID = "example_dataproc_cluster_create_existing_stopped_cluster"
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
+
+CLUSTER_NAME = f"cluster-{ENV_ID}-{DAG_ID}".replace("_", "-")
+REGION = "europe-west1"
+
+# Cluster definition
+CLUSTER_CONFIG = {
+    "master_config": {
+        "num_instances": 1,
+        "machine_type_uri": "n1-standard-4",
+        "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 
32},
+    },
+    "worker_config": {
+        "num_instances": 2,
+        "machine_type_uri": "n1-standard-4",
+        "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 
32},
+    },
+}
+
+with DAG(
+    DAG_ID, schedule="@once", start_date=datetime(2024, 1, 1), catchup=False, tags=["dataproc", "example"]
+) as dag:
+    create_cluster = DataprocCreateClusterOperator(
+        task_id="create_cluster",
+        project_id=PROJECT_ID,
+        cluster_config=CLUSTER_CONFIG,
+        region=REGION,
+        cluster_name=CLUSTER_NAME,
+        use_if_exists=True,
+    )
+
+    start_cluster = DataprocStartClusterOperator(
+        task_id="start_cluster",
+        project_id=PROJECT_ID,
+        region=REGION,
+        cluster_name=CLUSTER_NAME,
+    )
+
+    stop_cluster = DataprocStopClusterOperator(
+        task_id="stop_cluster",
+        project_id=PROJECT_ID,
+        region=REGION,
+        cluster_name=CLUSTER_NAME,
+    )
+
+    create_cluster_for_stopped_cluster = DataprocCreateClusterOperator(
+        task_id="create_cluster_for_stopped_cluster",
+        project_id=PROJECT_ID,
+        cluster_config=CLUSTER_CONFIG,
+        region=REGION,
+        cluster_name=CLUSTER_NAME,
+        use_if_exists=True,
+    )
+
+    delete_cluster = DataprocDeleteClusterOperator(
+        task_id="delete_cluster",
+        project_id=PROJECT_ID,
+        cluster_name=CLUSTER_NAME,
+        region=REGION,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    (
+        # TEST SETUP
+        create_cluster
+        >> stop_cluster
+        # TEST BODY
+        >> create_cluster_for_stopped_cluster
+        >> start_cluster
+        # TEST TEARDOWN
+        >> delete_cluster
+    )
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "teardown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py
new file mode 100644
index 0000000000..7dcb127cd6
--- /dev/null
+++ b/tests/system/providers/google/cloud/dataproc/example_dataproc_cluster_start_stop.py
@@ -0,0 +1,114 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example Airflow DAG for DataprocStartClusterOperator and DataprocStopClusterOperator.
+"""
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow.models.dag import DAG
+from airflow.providers.google.cloud.operators.dataproc import (
+    DataprocCreateClusterOperator,
+    DataprocDeleteClusterOperator,
+    DataprocStartClusterOperator,
+    DataprocStopClusterOperator,
+)
+from airflow.utils.trigger_rule import TriggerRule
+
+DAG_ID = "dataproc_cluster_start_stop"
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT")
+
+CLUSTER_NAME = f"cluster-{ENV_ID}-{DAG_ID}".replace("_", "-")
+REGION = "europe-west1"
+
+# Cluster definition
+CLUSTER_CONFIG = {
+    "master_config": {
+        "num_instances": 1,
+        "machine_type_uri": "n1-standard-4",
+        "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 
32},
+    },
+    "worker_config": {
+        "num_instances": 2,
+        "machine_type_uri": "n1-standard-4",
+        "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 
32},
+    },
+}
+
+with DAG(
+    DAG_ID, schedule="@once", start_date=datetime(2024, 1, 1), catchup=False, tags=["dataproc", "example"]
+) as dag:
+    create_cluster = DataprocCreateClusterOperator(
+        task_id="create_cluster",
+        project_id=PROJECT_ID,
+        cluster_config=CLUSTER_CONFIG,
+        region=REGION,
+        cluster_name=CLUSTER_NAME,
+        use_if_exists=True,
+    )
+
+    # [START how_to_cloud_dataproc_start_cluster_operator]
+    start_cluster = DataprocStartClusterOperator(
+        task_id="start_cluster",
+        project_id=PROJECT_ID,
+        region=REGION,
+        cluster_name=CLUSTER_NAME,
+    )
+    # [END how_to_cloud_dataproc_start_cluster_operator]
+
+    # [START how_to_cloud_dataproc_stop_cluster_operator]
+    stop_cluster = DataprocStopClusterOperator(
+        task_id="stop_cluster",
+        project_id=PROJECT_ID,
+        region=REGION,
+        cluster_name=CLUSTER_NAME,
+    )
+    # [END how_to_cloud_dataproc_stop_cluster_operator]
+
+    delete_cluster = DataprocDeleteClusterOperator(
+        task_id="delete_cluster",
+        project_id=PROJECT_ID,
+        cluster_name=CLUSTER_NAME,
+        region=REGION,
+        trigger_rule=TriggerRule.ALL_DONE,
+    )
+
+    (
+        # TEST SETUP
+        create_cluster
+        # TEST BODY
+        >> stop_cluster
+        >> start_cluster
+        # TEST TEARDOWN
+        >> delete_cluster
+    )
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "teardown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
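
The same start/stop calls are also available directly on the hook. A minimal
sketch using only methods that appear in this diff (the connection id, region,
project, and cluster name are placeholders):

    from airflow.providers.google.cloud.hooks.dataproc import DataprocHook

    hook = DataprocHook(gcp_conn_id="google_cloud_default")
    # stop_cluster returns a google.api_core.operation.Operation handle.
    op = hook.stop_cluster(
        region="europe-west1",  # placeholder
        project_id="my-project",  # placeholder
        cluster_name="my-cluster",  # placeholder
    )
    # Block until the long-running operation completes, as the operators do.
    hook.wait_for_operation(timeout=3600, operation=op)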
