This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c7f518fe09 Use base aws classes in AWS Datasync Operators (#36766)
c7f518fe09 is described below

commit c7f518fe0963e6957f0e57519177788217f9bc01
Author: Andrey Anshin <[email protected]>
AuthorDate: Mon Jan 15 03:05:54 2024 +0400

    Use base aws classes in AWS Datasync Operators (#36766)
---
 airflow/providers/amazon/aws/operators/datasync.py | 37 +++++++++++-----------
 .../operators/datasync.rst                         |  5 +++
 .../amazon/aws/operators/test_datasync.py          | 36 +++++++++++++++++++++
 3 files changed, 60 insertions(+), 18 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/datasync.py b/airflow/providers/amazon/aws/operators/datasync.py
index f990ffbb70..c280b53102 100644
--- a/airflow/providers/amazon/aws/operators/datasync.py
+++ b/airflow/providers/amazon/aws/operators/datasync.py
@@ -19,20 +19,20 @@ from __future__ import annotations
 
 import logging
 import random
-from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from deprecated.classic import deprecated
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout
-from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
-class DataSyncOperator(BaseOperator):
+class DataSyncOperator(AwsBaseOperator[DataSyncHook]):
     """Find, Create, Update, Execute and Delete AWS DataSync Tasks.
 
     If ``do_xcom_push`` is True, then the DataSync TaskArn and TaskExecutionArn
@@ -46,7 +46,6 @@ class DataSyncOperator(BaseOperator):
         environment. The default behavior is to create a new Task if there are 0, or
         execute the Task if there was 1 Task, or fail if there were many Tasks.
 
-    :param aws_conn_id: AWS connection to use.
     :param wait_interval_seconds: Time to wait between two
         consecutive calls to check TaskExecution status.
     :param max_iterations: Maximum number of
@@ -91,6 +90,16 @@ class DataSyncOperator(BaseOperator):
         ``boto3.start_task_execution(TaskArn=task_arn, **task_execution_kwargs)``
     :param  delete_task_after_execution: If True then the TaskArn which was executed
         will be deleted from AWS DataSync on successful completion.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     :raises AirflowException: If ``task_arn`` was not specified, or if
         either ``source_location_uri`` or ``destination_location_uri`` were
         not specified.
@@ -100,7 +109,8 @@ class DataSyncOperator(BaseOperator):
     :raises AirflowException: If Task creation, update, execution or delete fails.
     """
 
-    template_fields: Sequence[str] = (
+    aws_hook_class = DataSyncHook
+    template_fields: Sequence[str] = aws_template_fields(
         "task_arn",
         "source_location_uri",
         "destination_location_uri",
@@ -122,7 +132,6 @@ class DataSyncOperator(BaseOperator):
     def __init__(
         self,
         *,
-        aws_conn_id: str = "aws_default",
         wait_interval_seconds: int = 30,
         max_iterations: int = 60,
         wait_for_completion: bool = True,
@@ -142,7 +151,6 @@ class DataSyncOperator(BaseOperator):
         super().__init__(**kwargs)
 
         # Assignments
-        self.aws_conn_id = aws_conn_id
         self.wait_interval_seconds = wait_interval_seconds
         self.max_iterations = max_iterations
         self.wait_for_completion = wait_for_completion
@@ -185,16 +193,9 @@ class DataSyncOperator(BaseOperator):
         self.destination_location_arn: str | None = None
         self.task_execution_arn: str | None = None
 
-    @cached_property
-    def hook(self) -> DataSyncHook:
-        """Create and return DataSyncHook.
-
-        :return DataSyncHook: An DataSyncHook instance.
-        """
-        return DataSyncHook(
-            aws_conn_id=self.aws_conn_id,
-            wait_interval_seconds=self.wait_interval_seconds,
-        )
+    @property
+    def _hook_parameters(self) -> dict[str, Any]:
+        return {**super()._hook_parameters, "wait_interval_seconds": self.wait_interval_seconds}
 
     @deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning)
     def get_hook(self) -> DataSyncHook:
diff --git a/docs/apache-airflow-providers-amazon/operators/datasync.rst b/docs/apache-airflow-providers-amazon/operators/datasync.rst
index aca6d5d755..16e65db42d 100644
--- a/docs/apache-airflow-providers-amazon/operators/datasync.rst
+++ b/docs/apache-airflow-providers-amazon/operators/datasync.rst
@@ -28,6 +28,11 @@ Prerequisite Tasks
 
 .. include:: ../_partials/prerequisite_tasks.rst
 
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
 Operators
 ---------
 
diff --git a/tests/providers/amazon/aws/operators/test_datasync.py b/tests/providers/amazon/aws/operators/test_datasync.py
index 829dca7082..fa666dd476 100644
--- a/tests/providers/amazon/aws/operators/test_datasync.py
+++ b/tests/providers/amazon/aws/operators/test_datasync.py
@@ -108,6 +108,42 @@ class DataSyncTestCaseBase:
         self.client = None
 
 
+def test_generic_params():
+    op = DataSyncOperator(
+        task_id="generic-task",
+        task_arn="arn:fake",
+        source_location_uri="fake://source",
+        destination_location_uri="fake://destination",
+        aws_conn_id="fake-conn-id",
+        region_name="cn-north-1",
+        verify=False,
+        botocore_config={"read_timeout": 42},
+        # Non-generic hook params
+        wait_interval_seconds=42,
+    )
+
+    assert op.hook.client_type == "datasync"
+    assert op.hook.resource_type is None
+    assert op.hook.aws_conn_id == "fake-conn-id"
+    assert op.hook._region_name == "cn-north-1"
+    assert op.hook._verify is False
+    assert op.hook._config is not None
+    assert op.hook._config.read_timeout == 42
+    assert op.hook.wait_interval_seconds == 42
+
+    op = DataSyncOperator(
+        task_id="generic-task",
+        task_arn="arn:fake",
+        source_location_uri="fake://source",
+        destination_location_uri="fake://destination",
+    )
+    assert op.hook.aws_conn_id == "aws_default"
+    assert op.hook._region_name is None
+    assert op.hook._verify is None
+    assert op.hook._config is None
+    assert op.hook.wait_interval_seconds is not None
+
+
 @mock_datasync
 @mock.patch.object(DataSyncHook, "get_conn")
 class TestDataSyncOperatorCreate(DataSyncTestCaseBase):
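
For illustration only (not part of the commit): a minimal usage sketch of the refactored operator, which now accepts the generic AWS parameters (aws_conn_id, region_name, verify, botocore_config) via AwsBaseOperator while still forwarding wait_interval_seconds to DataSyncHook. All identifiers below (connection id, ARN, location URIs, values) are placeholders, not taken from the commit.

from airflow.providers.amazon.aws.operators.datasync import DataSyncOperator

# Placeholder ARN, URIs and connection id for illustration only.
transfer = DataSyncOperator(
    task_id="datasync_transfer",
    task_arn="arn:aws:datasync:us-east-1:111122223333:task/task-0example",
    source_location_uri="smb://hostname/share",
    destination_location_uri="s3://example-bucket/prefix",
    # Generic AWS parameters now handled by the AwsBaseOperator base class
    aws_conn_id="aws_default",
    region_name="us-east-1",
    verify=True,
    botocore_config={"retries": {"max_attempts": 3}},
    # DataSync-specific value forwarded to the hook via _hook_parameters
    wait_interval_seconds=30,
)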
