This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c7f518fe09 Use base aws classes in AWS Datasync Operators (#36766)
c7f518fe09 is described below
commit c7f518fe0963e6957f0e57519177788217f9bc01
Author: Andrey Anshin <[email protected]>
AuthorDate: Mon Jan 15 03:05:54 2024 +0400
Use base aws classes in AWS Datasync Operators (#36766)
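For context, a minimal usage sketch of the migrated operator (the task id, ARN, and connection values below are placeholders, not part of this commit; the generic parameters mirror those exercised in the new test further down):

    from airflow.providers.amazon.aws.operators.datasync import DataSyncOperator

    # After this change the generic AWS parameters (aws_conn_id, region_name,
    # verify, botocore_config) are accepted via the base class rather than
    # handled ad hoc by DataSyncOperator. All literal values are placeholders.
    execute_task = DataSyncOperator(
        task_id="execute_datasync_task",
        task_arn="arn:aws:datasync:us-east-1:111222333444:task/task-example",  # placeholder
        aws_conn_id="aws_default",             # Airflow connection for AWS credentials
        region_name="us-east-1",               # optional; falls back to boto3 defaults
        verify=True,                           # SSL certificate verification
        botocore_config={"read_timeout": 42},  # passed through to the botocore client config
        wait_interval_seconds=30,              # operator-specific, forwarded to DataSyncHook
    )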
---
airflow/providers/amazon/aws/operators/datasync.py | 37 +++++++++++-----------
.../operators/datasync.rst                         |  5 +++
.../amazon/aws/operators/test_datasync.py          | 36 +++++++++++++++++++++
3 files changed, 60 insertions(+), 18 deletions(-)
diff --git a/airflow/providers/amazon/aws/operators/datasync.py b/airflow/providers/amazon/aws/operators/datasync.py
index f990ffbb70..c280b53102 100644
--- a/airflow/providers/amazon/aws/operators/datasync.py
+++ b/airflow/providers/amazon/aws/operators/datasync.py
@@ -19,20 +19,20 @@ from __future__ import annotations
import logging
import random
-from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
from deprecated.classic import deprecated
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowTaskTimeout
-from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.datasync import DataSyncHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
if TYPE_CHECKING:
from airflow.utils.context import Context
-class DataSyncOperator(BaseOperator):
+class DataSyncOperator(AwsBaseOperator[DataSyncHook]):
"""Find, Create, Update, Execute and Delete AWS DataSync Tasks.
If ``do_xcom_push`` is True, then the DataSync TaskArn and TaskExecutionArn
@@ -46,7 +46,6 @@ class DataSyncOperator(BaseOperator):
environment. The default behavior is to create a new Task if there are 0, or
execute the Task if there was 1 Task, or fail if there were many Tasks.
- :param aws_conn_id: AWS connection to use.
:param wait_interval_seconds: Time to wait between two
consecutive calls to check TaskExecution status.
:param max_iterations: Maximum number of
@@ -91,6 +90,16 @@ class DataSyncOperator(BaseOperator):
``boto3.start_task_execution(TaskArn=task_arn, **task_execution_kwargs)``
:param delete_task_after_execution: If True then the TaskArn which was executed
will be deleted from AWS DataSync on successful completion.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
:raises AirflowException: If ``task_arn`` was not specified, or if
either ``source_location_uri`` or ``destination_location_uri`` were
not specified.
@@ -100,7 +109,8 @@ class DataSyncOperator(BaseOperator):
:raises AirflowException: If Task creation, update, execution or delete fails.
"""
- template_fields: Sequence[str] = (
+ aws_hook_class = DataSyncHook
+ template_fields: Sequence[str] = aws_template_fields(
"task_arn",
"source_location_uri",
"destination_location_uri",
@@ -122,7 +132,6 @@ class DataSyncOperator(BaseOperator):
def __init__(
self,
*,
- aws_conn_id: str = "aws_default",
wait_interval_seconds: int = 30,
max_iterations: int = 60,
wait_for_completion: bool = True,
@@ -142,7 +151,6 @@ class DataSyncOperator(BaseOperator):
super().__init__(**kwargs)
# Assignments
- self.aws_conn_id = aws_conn_id
self.wait_interval_seconds = wait_interval_seconds
self.max_iterations = max_iterations
self.wait_for_completion = wait_for_completion
@@ -185,16 +193,9 @@ class DataSyncOperator(BaseOperator):
self.destination_location_arn: str | None = None
self.task_execution_arn: str | None = None
- @cached_property
- def hook(self) -> DataSyncHook:
- """Create and return DataSyncHook.
-
- :return DataSyncHook: An DataSyncHook instance.
- """
- return DataSyncHook(
- aws_conn_id=self.aws_conn_id,
- wait_interval_seconds=self.wait_interval_seconds,
- )
+ @property
+ def _hook_parameters(self) -> dict[str, Any]:
+ return {**super()._hook_parameters, "wait_interval_seconds": self.wait_interval_seconds}
@deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning)
def get_hook(self) -> DataSyncHook:
diff --git a/docs/apache-airflow-providers-amazon/operators/datasync.rst b/docs/apache-airflow-providers-amazon/operators/datasync.rst
index aca6d5d755..16e65db42d 100644
--- a/docs/apache-airflow-providers-amazon/operators/datasync.rst
+++ b/docs/apache-airflow-providers-amazon/operators/datasync.rst
@@ -28,6 +28,11 @@ Prerequisite Tasks
.. include:: ../_partials/prerequisite_tasks.rst
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
Operators
---------
diff --git a/tests/providers/amazon/aws/operators/test_datasync.py b/tests/providers/amazon/aws/operators/test_datasync.py
index 829dca7082..fa666dd476 100644
--- a/tests/providers/amazon/aws/operators/test_datasync.py
+++ b/tests/providers/amazon/aws/operators/test_datasync.py
@@ -108,6 +108,42 @@ class DataSyncTestCaseBase:
self.client = None
+def test_generic_params():
+ op = DataSyncOperator(
+ task_id="generic-task",
+ task_arn="arn:fake",
+ source_location_uri="fake://source",
+ destination_location_uri="fake://destination",
+ aws_conn_id="fake-conn-id",
+ region_name="cn-north-1",
+ verify=False,
+ botocore_config={"read_timeout": 42},
+ # Non-generic hook params
+ wait_interval_seconds=42,
+ )
+
+ assert op.hook.client_type == "datasync"
+ assert op.hook.resource_type is None
+ assert op.hook.aws_conn_id == "fake-conn-id"
+ assert op.hook._region_name == "cn-north-1"
+ assert op.hook._verify is False
+ assert op.hook._config is not None
+ assert op.hook._config.read_timeout == 42
+ assert op.hook.wait_interval_seconds == 42
+
+ op = DataSyncOperator(
+ task_id="generic-task",
+ task_arn="arn:fake",
+ source_location_uri="fake://source",
+ destination_location_uri="fake://destination",
+ )
+ assert op.hook.aws_conn_id == "aws_default"
+ assert op.hook._region_name is None
+ assert op.hook._verify is None
+ assert op.hook._config is None
+ assert op.hook.wait_interval_seconds is not None
+
+
@mock_datasync
@mock.patch.object(DataSyncHook, "get_conn")
class TestDataSyncOperatorCreate(DataSyncTestCaseBase):
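The `_hook_parameters` override in the operator diff above is the extension point this migration relies on: `AwsBaseOperator` builds its `hook` from the `aws_hook_class` attribute plus whatever `_hook_parameters` returns, so subclasses merge in extra keys instead of constructing a hook themselves. Below is a simplified sketch of that pattern, assuming the base class wires the four generic parameters into the hook constructor; it is an illustration only, not the actual code in `airflow.providers.amazon.aws.operators.base_aws`:

    from __future__ import annotations

    from functools import cached_property
    from typing import Any, Generic, TypeVar

    HookT = TypeVar("HookT")

    class AwsBaseOperatorSketch(Generic[HookT]):
        """Illustrative stand-in for AwsBaseOperator (pattern only)."""

        aws_hook_class: type[HookT]  # e.g. DataSyncHook in the diff above

        def __init__(self, *, aws_conn_id: str | None = "aws_default",
                     region_name: str | None = None,
                     verify: bool | str | None = None,
                     botocore_config: dict | None = None) -> None:
            self.aws_conn_id = aws_conn_id
            self.region_name = region_name
            self.verify = verify
            self.botocore_config = botocore_config

        @property
        def _hook_parameters(self) -> dict[str, Any]:
            # Subclasses extend this dict (as DataSyncOperator does with
            # wait_interval_seconds) rather than overriding `hook`.
            return {
                "aws_conn_id": self.aws_conn_id,
                "region_name": self.region_name,
                "verify": self.verify,
                "config": self.botocore_config,  # key name assumed for this sketch
            }

        @cached_property
        def hook(self) -> HookT:
            # The hook class declared on the subclass is instantiated with the
            # merged parameter dict, which is what the new test asserts against.
            return self.aws_hook_class(**self._hook_parameters)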