This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3721c9a441 Use base aws classes in Amazon S3 Glacier Operators/Sensors
(#35108)
3721c9a441 is described below
commit 3721c9a4413d3f5002b46589beeff490827cd9cb
Author: Andrey Anshin <[email protected]>
AuthorDate: Tue Oct 24 18:54:21 2023 +0400
Use base aws classes in Amazon S3 Glacier Operators/Sensors (#35108)
---
airflow/providers/amazon/aws/hooks/glacier.py | 6 +-
airflow/providers/amazon/aws/operators/glacier.py | 30 ++++-----
airflow/providers/amazon/aws/sensors/glacier.py | 15 ++---
.../operators/s3/glacier.rst | 5 ++
.../providers/amazon/aws/operators/test_glacier.py | 72 +++++++++++++++++-----
tests/providers/amazon/aws/sensors/test_glacier.py | 65 ++++++++++++-------
6 files changed, 124 insertions(+), 69 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/glacier.py
b/airflow/providers/amazon/aws/hooks/glacier.py
index bd260000e7..4655b28e30 100644
--- a/airflow/providers/amazon/aws/hooks/glacier.py
+++ b/airflow/providers/amazon/aws/hooks/glacier.py
@@ -35,9 +35,9 @@ class GlacierHook(AwsBaseHook):
- :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
- def __init__(self, aws_conn_id: str = "aws_default") -> None:
- super().__init__(client_type="glacier")
- self.aws_conn_id = aws_conn_id
+ def __init__(self, *args, **kwargs) -> None:
+ kwargs.update({"client_type": "glacier", "resource_type": None})
+ super().__init__(*args, **kwargs)
def retrieve_inventory(self, vault_name: str) -> dict[str, Any]:
"""Initiate an Amazon Glacier inventory-retrieval job.
diff --git a/airflow/providers/amazon/aws/operators/glacier.py
b/airflow/providers/amazon/aws/operators/glacier.py
index 54123e586d..3164004181 100644
--- a/airflow/providers/amazon/aws/operators/glacier.py
+++ b/airflow/providers/amazon/aws/operators/glacier.py
@@ -19,14 +19,15 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Sequence
-from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.glacier import GlacierHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
if TYPE_CHECKING:
from airflow.utils.context import Context
-class GlacierCreateJobOperator(BaseOperator):
+class GlacierCreateJobOperator(AwsBaseOperator[GlacierHook]):
"""
Initiate an Amazon Glacier inventory-retrieval job.
@@ -38,25 +39,18 @@ class GlacierCreateJobOperator(BaseOperator):
:param vault_name: the Glacier vault on which job is executed
"""
- template_fields: Sequence[str] = ("vault_name",)
+ aws_hook_class = GlacierHook
+ template_fields: Sequence[str] = aws_template_fields("vault_name")
- def __init__(
- self,
- *,
- aws_conn_id="aws_default",
- vault_name: str,
- **kwargs,
- ):
+ def __init__(self, *, vault_name: str, **kwargs):
super().__init__(**kwargs)
- self.aws_conn_id = aws_conn_id
self.vault_name = vault_name
def execute(self, context: Context):
- hook = GlacierHook(aws_conn_id=self.aws_conn_id)
- return hook.retrieve_inventory(vault_name=self.vault_name)
+ return self.hook.retrieve_inventory(vault_name=self.vault_name)
-class GlacierUploadArchiveOperator(BaseOperator):
+class GlacierUploadArchiveOperator(AwsBaseOperator[GlacierHook]):
"""
    This operator adds an archive to an Amazon S3 Glacier vault.
@@ -74,7 +68,8 @@ class GlacierUploadArchiveOperator(BaseOperator):
:param aws_conn_id: The reference to the AWS connection details
"""
- template_fields: Sequence[str] = ("vault_name",)
+ aws_hook_class = GlacierHook
+ template_fields: Sequence[str] = aws_template_fields("vault_name")
def __init__(
self,
@@ -84,11 +79,9 @@ class GlacierUploadArchiveOperator(BaseOperator):
checksum: str | None = None,
archive_description: str | None = None,
account_id: str | None = None,
- aws_conn_id="aws_default",
**kwargs,
):
super().__init__(**kwargs)
- self.aws_conn_id = aws_conn_id
self.account_id = account_id
self.vault_name = vault_name
self.body = body
@@ -96,8 +89,7 @@ class GlacierUploadArchiveOperator(BaseOperator):
self.archive_description = archive_description
def execute(self, context: Context):
- hook = GlacierHook(aws_conn_id=self.aws_conn_id)
- return hook.get_conn().upload_archive(
+ return self.hook.conn.upload_archive(
accountId=self.account_id,
vaultName=self.vault_name,
archiveDescription=self.archive_description,
diff --git a/airflow/providers/amazon/aws/sensors/glacier.py
b/airflow/providers/amazon/aws/sensors/glacier.py
index e9cc8fc4b7..7a65fc6fc3 100644
--- a/airflow/providers/amazon/aws/sensors/glacier.py
+++ b/airflow/providers/amazon/aws/sensors/glacier.py
@@ -18,12 +18,12 @@
from __future__ import annotations
from enum import Enum
-from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence
from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.hooks.glacier import GlacierHook
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -36,7 +36,7 @@ class JobStatus(Enum):
SUCCEEDED = "Succeeded"
-class GlacierJobOperationSensor(BaseSensorOperator):
+class GlacierJobOperationSensor(AwsBaseSensor[GlacierHook]):
"""
Glacier sensor for checking job state. This operator runs only in
reschedule mode.
@@ -63,12 +63,12 @@ class GlacierJobOperationSensor(BaseSensorOperator):
prevent too much load on the scheduler.
"""
- template_fields: Sequence[str] = ("vault_name", "job_id")
+ aws_hook_class = GlacierHook
+ template_fields: Sequence[str] = aws_template_fields("vault_name",
"job_id")
def __init__(
self,
*,
- aws_conn_id: str = "aws_default",
vault_name: str,
job_id: str,
poke_interval: int = 60 * 20,
@@ -76,16 +76,11 @@ class GlacierJobOperationSensor(BaseSensorOperator):
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
- self.aws_conn_id = aws_conn_id
self.vault_name = vault_name
self.job_id = job_id
self.poke_interval = poke_interval
self.mode = mode
- @cached_property
- def hook(self):
- return GlacierHook(aws_conn_id=self.aws_conn_id)
-
def poke(self, context: Context) -> bool:
response = self.hook.describe_job(vault_name=self.vault_name,
job_id=self.job_id)
diff --git a/docs/apache-airflow-providers-amazon/operators/s3/glacier.rst
b/docs/apache-airflow-providers-amazon/operators/s3/glacier.rst
index 9dca7a776c..c85e7ac294 100644
--- a/docs/apache-airflow-providers-amazon/operators/s3/glacier.rst
+++ b/docs/apache-airflow-providers-amazon/operators/s3/glacier.rst
@@ -27,6 +27,11 @@ Prerequisite Tasks
.. include:: ../../_partials/prerequisite_tasks.rst
+Generic Parameters
+------------------
+
+.. include:: ../../_partials/generic_parameters.rst
+
Operators
---------
diff --git a/tests/providers/amazon/aws/operators/test_glacier.py
b/tests/providers/amazon/aws/operators/test_glacier.py
index d9afe50511..4dbd8f2f5a 100644
--- a/tests/providers/amazon/aws/operators/test_glacier.py
+++ b/tests/providers/amazon/aws/operators/test_glacier.py
@@ -17,13 +17,19 @@
# under the License.
from __future__ import annotations
+from typing import TYPE_CHECKING, Any
from unittest import mock
+import pytest
+
from airflow.providers.amazon.aws.operators.glacier import (
GlacierCreateJobOperator,
GlacierUploadArchiveOperator,
)
+if TYPE_CHECKING:
+ from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+
AWS_CONN_ID = "aws_default"
BUCKET_NAME = "airflow_bucket"
FILENAME = "path/to/file/"
@@ -34,22 +40,60 @@ TASK_ID = "glacier_job"
VAULT_NAME = "airflow"
-class TestGlacierCreateJobOperator:
- @mock.patch("airflow.providers.amazon.aws.operators.glacier.GlacierHook")
+class BaseGlacierOperatorsTests:
+ op_class: type[AwsBaseOperator]
+ default_op_kwargs: dict[str, Any]
+
+ def test_base_aws_op_attributes(self):
+ op = self.op_class(**self.default_op_kwargs)
+ assert op.hook.aws_conn_id == "aws_default"
+ assert op.hook._region_name is None
+ assert op.hook._verify is None
+ assert op.hook._config is None
+
+ op = self.op_class(
+ **self.default_op_kwargs,
+ aws_conn_id="aws-test-custom-conn",
+ region_name="eu-west-1",
+ verify=False,
+ botocore_config={"read_timeout": 42},
+ )
+ assert op.hook.aws_conn_id == "aws-test-custom-conn"
+ assert op.hook._region_name == "eu-west-1"
+ assert op.hook._verify is False
+ assert op.hook._config is not None
+ assert op.hook._config.read_timeout == 42
+
+
+class TestGlacierCreateJobOperator(BaseGlacierOperatorsTests):
+ op_class = GlacierCreateJobOperator
+
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self):
+ self.default_op_kwargs = {"vault_name": VAULT_NAME, "task_id": TASK_ID}
+
+ @mock.patch.object(GlacierCreateJobOperator, "hook",
new_callable=mock.PropertyMock)
def test_execute(self, hook_mock):
- op = GlacierCreateJobOperator(aws_conn_id=AWS_CONN_ID,
vault_name=VAULT_NAME, task_id=TASK_ID)
+ op = self.op_class(aws_conn_id=None, **self.default_op_kwargs)
op.execute(mock.MagicMock())
- hook_mock.assert_called_once_with(aws_conn_id=AWS_CONN_ID)
hook_mock.return_value.retrieve_inventory.assert_called_once_with(vault_name=VAULT_NAME)
-class TestGlacierUploadArchiveOperator:
-
@mock.patch("airflow.providers.amazon.aws.operators.glacier.GlacierHook.get_conn")
- def test_execute(self, hook_mock):
- op = GlacierUploadArchiveOperator(
- aws_conn_id=AWS_CONN_ID, vault_name=VAULT_NAME, body=b"Test Data",
task_id=TASK_ID
- )
- op.execute(mock.MagicMock())
- hook_mock.return_value.upload_archive.assert_called_once_with(
- accountId=None, vaultName=VAULT_NAME, archiveDescription=None,
body=b"Test Data", checksum=None
- )
+class TestGlacierUploadArchiveOperator(BaseGlacierOperatorsTests):
+ op_class = GlacierUploadArchiveOperator
+
+ @pytest.fixture(autouse=True)
+ def setup_test_cases(self):
+ self.default_op_kwargs = {"vault_name": VAULT_NAME, "task_id":
TASK_ID, "body": b"Test Data"}
+
+ def test_execute(self):
+ with mock.patch.object(self.op_class.aws_hook_class, "conn",
new_callable=mock.PropertyMock) as m:
+ op = self.op_class(aws_conn_id=None, **self.default_op_kwargs)
+ op.execute(mock.MagicMock())
+ m.return_value.upload_archive.assert_called_once_with(
+ accountId=None,
+ vaultName=VAULT_NAME,
+ archiveDescription=None,
+ body=b"Test Data",
+ checksum=None,
+ )
diff --git a/tests/providers/amazon/aws/sensors/test_glacier.py
b/tests/providers/amazon/aws/sensors/test_glacier.py
index 4213eed9d0..5019a4dd0c 100644
--- a/tests/providers/amazon/aws/sensors/test_glacier.py
+++ b/tests/providers/amazon/aws/sensors/test_glacier.py
@@ -28,49 +28,68 @@ SUCCEEDED = "Succeeded"
IN_PROGRESS = "InProgress"
[email protected]
+def mocked_describe_job():
+ with
mock.patch("airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job")
as m:
+ yield m
+
+
class TestAmazonGlacierSensor:
def setup_method(self):
- self.op = GlacierJobOperationSensor(
+ self.default_op_kwargs = dict(
task_id="test_athena_sensor",
- aws_conn_id="aws_default",
vault_name="airflow",
job_id="1a2b3c4d",
poke_interval=60 * 20,
)
+ self.op = GlacierJobOperationSensor(**self.default_op_kwargs,
aws_conn_id=None)
- @mock.patch(
-
"airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job",
- side_effect=[{"Action": "", "StatusCode": JobStatus.SUCCEEDED.value}],
- )
- def test_poke_succeeded(self, _):
+ def test_base_aws_op_attributes(self):
+ op = GlacierJobOperationSensor(**self.default_op_kwargs)
+ assert op.hook.aws_conn_id == "aws_default"
+ assert op.hook._region_name is None
+ assert op.hook._verify is None
+ assert op.hook._config is None
+
+ op = GlacierJobOperationSensor(
+ **self.default_op_kwargs,
+ aws_conn_id="aws-test-custom-conn",
+ region_name="eu-west-1",
+ verify=False,
+ botocore_config={"read_timeout": 42},
+ )
+ assert op.hook.aws_conn_id == "aws-test-custom-conn"
+ assert op.hook._region_name == "eu-west-1"
+ assert op.hook._verify is False
+ assert op.hook._config is not None
+ assert op.hook._config.read_timeout == 42
+
+ def test_poke_succeeded(self, mocked_describe_job):
+ mocked_describe_job.side_effect = [{"Action": "", "StatusCode":
JobStatus.SUCCEEDED.value}]
assert self.op.poke(None)
- @mock.patch(
-
"airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job",
- side_effect=[{"Action": "", "StatusCode":
JobStatus.IN_PROGRESS.value}],
- )
- def test_poke_in_progress(self, _):
+ def test_poke_in_progress(self, mocked_describe_job):
+ mocked_describe_job.side_effect = [{"Action": "", "StatusCode":
JobStatus.IN_PROGRESS.value}]
assert not self.op.poke(None)
- @mock.patch(
-
"airflow.providers.amazon.aws.sensors.glacier.GlacierHook.describe_job",
- side_effect=[{"Action": "", "StatusCode": ""}],
- )
- def test_poke_fail(self, _):
- with pytest.raises(AirflowException) as ctx:
+ def test_poke_fail(self, mocked_describe_job):
+ mocked_describe_job.side_effect = [{"Action": "", "StatusCode": ""}]
+ with pytest.raises(AirflowException, match="Sensor failed"):
self.op.poke(None)
- assert "Sensor failed" in str(ctx.value)
@pytest.mark.parametrize(
- "soft_fail, expected_exception", ((False, AirflowException), (True,
AirflowSkipException))
+ "soft_fail, expected_exception",
+ [
+ pytest.param(False, AirflowException, id="not-soft-fail"),
+ pytest.param(True, AirflowSkipException, id="soft-fail"),
+ ],
)
-
@mock.patch("airflow.providers.amazon.aws.hooks.glacier.GlacierHook.describe_job")
- def test_fail_poke(self, describe_job, soft_fail, expected_exception):
+ def test_fail_poke(self, soft_fail, expected_exception,
mocked_describe_job):
self.op.soft_fail = soft_fail
response = {"Action": "some action", "StatusCode": "Failed"}
message = f'Sensor failed. Job status: {response["Action"]}, code
status: {response["StatusCode"]}'
with pytest.raises(expected_exception, match=message):
- describe_job.return_value = response
+ mocked_describe_job.return_value = response
self.op.poke(context={})