This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9eab3e199e Use base aws classes in Amazon QuickSight Operators/Sensors
(#36776)
9eab3e199e is described below
commit 9eab3e199ecfcaca2c39cfcf66ff4d7fe83c69ef
Author: Andrey Anshin <[email protected]>
AuthorDate: Mon Jan 15 03:15:16 2024 +0400
Use base aws classes in Amazon QuickSight Operators/Sensors (#36776)
---
airflow/providers/amazon/aws/hooks/base_aws.py | 14 ++
airflow/providers/amazon/aws/hooks/quicksight.py | 51 +++--
.../providers/amazon/aws/operators/quicksight.py | 41 ++--
airflow/providers/amazon/aws/sensors/quicksight.py | 58 +++--
.../operators/quicksight.rst | 5 +
tests/providers/amazon/aws/hooks/test_base_aws.py | 4 +
.../providers/amazon/aws/hooks/test_quicksight.py | 245 +++++++++++++++------
.../amazon/aws/operators/test_quicksight.py | 42 +++-
.../amazon/aws/sensors/test_quicksight.py | 133 ++++++-----
9 files changed, 402 insertions(+), 191 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py
b/airflow/providers/amazon/aws/hooks/base_aws.py
index d6e0762a1a..635a874e26 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -629,6 +629,20 @@ class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
"""Verify or not SSL certificates boto3 client/resource read-only
property."""
return self.conn_config.verify
+ @cached_property
+ def account_id(self) -> str:
+ """Return associated AWS Account ID."""
+ return (
+ self.get_session(region_name=self.region_name)
+ .client(
+ service_name="sts",
+ endpoint_url=self.conn_config.get_service_endpoint_url("sts"),
+ config=self.config,
+ verify=self.verify,
+ )
+ .get_caller_identity()["Account"]
+ )
+
def get_session(self, region_name: str | None = None, deferrable: bool =
False) -> boto3.session.Session:
"""Get the underlying
boto3.session.Session(region_name=region_name)."""
return SessionFactory(
diff --git a/airflow/providers/amazon/aws/hooks/quicksight.py
b/airflow/providers/amazon/aws/hooks/quicksight.py
index 6ee7c5bfd4..1106a793c1 100644
--- a/airflow/providers/amazon/aws/hooks/quicksight.py
+++ b/airflow/providers/amazon/aws/hooks/quicksight.py
@@ -18,13 +18,13 @@
from __future__ import annotations
import time
+import warnings
from functools import cached_property
from botocore.exceptions import ClientError
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.providers.amazon.aws.hooks.sts import StsHook
class QuickSightHook(AwsBaseHook):
@@ -46,10 +46,6 @@ class QuickSightHook(AwsBaseHook):
def __init__(self, *args, **kwargs):
super().__init__(client_type="quicksight", *args, **kwargs)
- @cached_property
- def sts_hook(self):
- return StsHook(aws_conn_id=self.aws_conn_id)
-
def create_ingestion(
self,
data_set_id: str,
@@ -57,6 +53,7 @@ class QuickSightHook(AwsBaseHook):
ingestion_type: str,
wait_for_completion: bool = True,
check_interval: int = 30,
+ aws_account_id: str | None = None,
) -> dict:
"""
Create and start a new SPICE ingestion for a dataset; refresh the
SPICE datasets.
@@ -66,18 +63,18 @@ class QuickSightHook(AwsBaseHook):
:param data_set_id: ID of the dataset used in the ingestion.
:param ingestion_id: ID for the ingestion.
- :param ingestion_type: Type of ingestion .
"INCREMENTAL_REFRESH"|"FULL_REFRESH"
+ :param ingestion_type: Type of ingestion:
"INCREMENTAL_REFRESH"|"FULL_REFRESH"
:param wait_for_completion: if the program should keep running until
job finishes
:param check_interval: the time interval in seconds which the operator
will check the status of QuickSight Ingestion
+ :param aws_account_id: An AWS Account ID; if set to ``None``, the
associated AWS Account ID is used.
:return: Returns descriptive information about the created data
ingestion
having Ingestion ARN, HTTP status, ingestion ID and ingestion
status.
"""
+ aws_account_id = aws_account_id or self.account_id
self.log.info("Creating QuickSight Ingestion for data set id %s.",
data_set_id)
- quicksight_client = self.get_conn()
try:
- aws_account_id = self.sts_hook.get_account_number()
- create_ingestion_response = quicksight_client.create_ingestion(
+ create_ingestion_response = self.conn.create_ingestion(
DataSetId=data_set_id,
IngestionId=ingestion_id,
IngestionType=ingestion_type,
@@ -97,20 +94,21 @@ class QuickSightHook(AwsBaseHook):
self.log.error("Failed to run Amazon QuickSight create_ingestion
API, error: %s", general_error)
raise
- def get_status(self, aws_account_id: str, data_set_id: str, ingestion_id:
str) -> str:
+ def get_status(self, aws_account_id: str | None, data_set_id: str,
ingestion_id: str) -> str:
"""
Get the current status of QuickSight Create Ingestion API.
.. seealso::
- :external+boto3:py:meth:`QuickSight.Client.describe_ingestion`
- :param aws_account_id: An AWS Account ID
+ :param aws_account_id: An AWS Account ID; if set to ``None``, the
associated AWS Account ID is used.
:param data_set_id: QuickSight Data Set ID
:param ingestion_id: QuickSight Ingestion ID
:return: An QuickSight Ingestion Status
"""
+ aws_account_id = aws_account_id or self.account_id
try:
- describe_ingestion_response = self.get_conn().describe_ingestion(
+ describe_ingestion_response = self.conn.describe_ingestion(
AwsAccountId=aws_account_id, DataSetId=data_set_id,
IngestionId=ingestion_id
)
return describe_ingestion_response["Ingestion"]["IngestionStatus"]
@@ -119,17 +117,19 @@ class QuickSightHook(AwsBaseHook):
except ClientError as e:
raise AirflowException(f"AWS request failed: {e}")
- def get_error_info(self, aws_account_id: str, data_set_id: str,
ingestion_id: str) -> dict | None:
+ def get_error_info(self, aws_account_id: str | None, data_set_id: str,
ingestion_id: str) -> dict | None:
"""
Get info about the error if any.
- :param aws_account_id: An AWS Account ID
+ :param aws_account_id: An AWS Account ID; if set to ``None``, the
associated AWS Account ID is used.
:param data_set_id: QuickSight Data Set ID
:param ingestion_id: QuickSight Ingestion ID
:return: Error info dict containing the error type (key 'Type') and
message (key 'Message')
if available. Else, returns None.
"""
- describe_ingestion_response = self.get_conn().describe_ingestion(
+ aws_account_id = aws_account_id or self.account_id
+
+ describe_ingestion_response = self.conn.describe_ingestion(
AwsAccountId=aws_account_id, DataSetId=data_set_id,
IngestionId=ingestion_id
)
# using .get() to get None if the key is not present, instead of an
exception.
@@ -137,7 +137,7 @@ class QuickSightHook(AwsBaseHook):
def wait_for_state(
self,
- aws_account_id: str,
+ aws_account_id: str | None,
data_set_id: str,
ingestion_id: str,
target_state: set,
@@ -146,7 +146,7 @@ class QuickSightHook(AwsBaseHook):
"""
Check status of a QuickSight Create Ingestion API.
- :param aws_account_id: An AWS Account ID
+ :param aws_account_id: An AWS Account ID; if set to ``None``, the
associated AWS Account ID is used.
:param data_set_id: QuickSight Data Set ID
:param ingestion_id: QuickSight Ingestion ID
:param target_state: Describes the QuickSight Job's Target State
@@ -154,6 +154,8 @@ class QuickSightHook(AwsBaseHook):
will check the status of QuickSight Ingestion
:return: response of describe_ingestion call after Ingestion is done
"""
+ aws_account_id = aws_account_id or self.account_id
+
while True:
status = self.get_status(aws_account_id, data_set_id, ingestion_id)
self.log.info("Current status is %s", status)
@@ -168,3 +170,16 @@ class QuickSightHook(AwsBaseHook):
self.log.info("QuickSight Ingestion completed")
return status
+
+ @cached_property
+ def sts_hook(self):
+ warnings.warn(
+ f"`{type(self).__name__}.sts_hook` property is deprecated and will
be removed in the future. "
+ "This property used for obtain AWS Account ID, "
+ f"please consider to use `{type(self).__name__}.account_id`
instead",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ from airflow.providers.amazon.aws.hooks.sts import StsHook
+
+ return StsHook(aws_conn_id=self.aws_conn_id)
diff --git a/airflow/providers/amazon/aws/operators/quicksight.py
b/airflow/providers/amazon/aws/operators/quicksight.py
index 4268374117..9555e0d63a 100644
--- a/airflow/providers/amazon/aws/operators/quicksight.py
+++ b/airflow/providers/amazon/aws/operators/quicksight.py
@@ -18,16 +18,15 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Sequence
-from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
if TYPE_CHECKING:
from airflow.utils.context import Context
-DEFAULT_CONN_ID = "aws_default"
-
-class QuickSightCreateIngestionOperator(BaseOperator):
+class QuickSightCreateIngestionOperator(AwsBaseOperator[QuickSightHook]):
"""
Creates and starts a new SPICE ingestion for a dataset; also helps to
Refresh existing SPICE datasets.
@@ -43,23 +42,25 @@ class QuickSightCreateIngestionOperator(BaseOperator):
that the operation waits to check the status of the Amazon QuickSight
Ingestion.
:param check_interval: if wait is set to be true, this is the time interval
in seconds which the operator will check the status of the Amazon
QuickSight Ingestion
- :param aws_conn_id: The Airflow connection used for AWS credentials.
(templated)
- If this is None or empty then the default boto3 behaviour is used. If
- running Airflow in a distributed manner and aws_conn_id is None or
- empty, then the default boto3 configuration would be used (and must be
- maintained on each worker node).
- :param region: Which AWS region the connection should use. (templated)
- If this is None or empty then the default boto3 behaviour is used.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used.
If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default
boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore
client. See:
+
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""
- template_fields: Sequence[str] = (
+ aws_hook_class = QuickSightHook
+ template_fields: Sequence[str] = aws_template_fields(
"data_set_id",
"ingestion_id",
"ingestion_type",
"wait_for_completion",
"check_interval",
- "aws_conn_id",
- "region",
)
ui_color = "#ffd700"
@@ -70,26 +71,18 @@ class QuickSightCreateIngestionOperator(BaseOperator):
ingestion_type: str = "FULL_REFRESH",
wait_for_completion: bool = True,
check_interval: int = 30,
- aws_conn_id: str = DEFAULT_CONN_ID,
- region: str | None = None,
**kwargs,
):
+ super().__init__(**kwargs)
self.data_set_id = data_set_id
self.ingestion_id = ingestion_id
self.ingestion_type = ingestion_type
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
- self.aws_conn_id = aws_conn_id
- self.region = region
- super().__init__(**kwargs)
def execute(self, context: Context):
- hook = QuickSightHook(
- aws_conn_id=self.aws_conn_id,
- region_name=self.region,
- )
self.log.info("Running the Amazon QuickSight SPICE Ingestion on
Dataset ID: %s", self.data_set_id)
- return hook.create_ingestion(
+ return self.hook.create_ingestion(
data_set_id=self.data_set_id,
ingestion_id=self.ingestion_id,
ingestion_type=self.ingestion_type,
diff --git a/airflow/providers/amazon/aws/sensors/quicksight.py
b/airflow/providers/amazon/aws/sensors/quicksight.py
index fc90ecbe45..ebd8310fe4 100644
--- a/airflow/providers/amazon/aws/sensors/quicksight.py
+++ b/airflow/providers/amazon/aws/sensors/quicksight.py
@@ -17,19 +17,19 @@
# under the License.
from __future__ import annotations
+import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Sequence
-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook
-from airflow.providers.amazon.aws.hooks.sts import StsHook
-from airflow.sensors.base import BaseSensorOperator
+from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
if TYPE_CHECKING:
from airflow.utils.context import Context
-class QuickSightSensor(BaseSensorOperator):
+class QuickSightSensor(AwsBaseSensor[QuickSightHook]):
"""
Watches for the status of an Amazon QuickSight Ingestion.
@@ -39,27 +39,25 @@ class QuickSightSensor(BaseSensorOperator):
:param data_set_id: ID of the dataset used in the ingestion.
:param ingestion_id: ID for the ingestion.
- :param aws_conn_id: The Airflow connection used for AWS credentials.
(templated)
- If this is None or empty then the default boto3 behaviour is used. If
- running Airflow in a distributed manner and aws_conn_id is None or
- empty, then the default boto3 configuration would be used (and must be
- maintained on each worker node).
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is ``None`` or empty then the default boto3 behaviour is used.
If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node).
+ :param region_name: AWS region_name. If not specified then the default
boto3 behaviour is used.
+ :param verify: Whether or not to verify SSL certificates. See:
+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+ :param botocore_config: Configuration dictionary (key-values) for botocore
client. See:
+
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""
+ aws_hook_class = QuickSightHook
template_fields: Sequence[str] = ("data_set_id", "ingestion_id",
"aws_conn_id")
- def __init__(
- self,
- *,
- data_set_id: str,
- ingestion_id: str,
- aws_conn_id: str = "aws_default",
- **kwargs,
- ) -> None:
+ def __init__(self, *, data_set_id: str, ingestion_id: str, **kwargs):
super().__init__(**kwargs)
self.data_set_id = data_set_id
self.ingestion_id = ingestion_id
- self.aws_conn_id = aws_conn_id
self.success_status = "COMPLETED"
self.errored_statuses = ("FAILED", "CANCELLED")
@@ -71,13 +69,10 @@ class QuickSightSensor(BaseSensorOperator):
:return: True if it COMPLETED and False if not.
"""
self.log.info("Poking for Amazon QuickSight Ingestion ID: %s",
self.ingestion_id)
- aws_account_id = self.sts_hook.get_account_number()
- quicksight_ingestion_state = self.quicksight_hook.get_status(
- aws_account_id, self.data_set_id, self.ingestion_id
- )
+ quicksight_ingestion_state = self.hook.get_status(None,
self.data_set_id, self.ingestion_id)
self.log.info("QuickSight Status: %s", quicksight_ingestion_state)
if quicksight_ingestion_state in self.errored_statuses:
- error = self.quicksight_hook.get_error_info(aws_account_id,
self.data_set_id, self.ingestion_id)
+ error = self.hook.get_error_info(None, self.data_set_id,
self.ingestion_id)
message = f"The QuickSight Ingestion failed. Error info: {error}"
if self.soft_fail:
raise AirflowSkipException(message)
@@ -86,8 +81,23 @@ class QuickSightSensor(BaseSensorOperator):
@cached_property
def quicksight_hook(self):
- return QuickSightHook(aws_conn_id=self.aws_conn_id)
+ warnings.warn(
+ f"`{type(self).__name__}.quicksight_hook` property is deprecated, "
+ f"please use `{type(self).__name__}.hook` property instead.",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ return self.hook
@cached_property
def sts_hook(self):
+ warnings.warn(
+ f"`{type(self).__name__}.sts_hook` property is deprecated and will
be removed in the future. "
+ "This property used for obtain AWS Account ID, "
+ f"please consider to use `{type(self).__name__}.hook.account_id`
instead",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ from airflow.providers.amazon.aws.hooks.sts import StsHook
+
return StsHook(aws_conn_id=self.aws_conn_id)
diff --git a/docs/apache-airflow-providers-amazon/operators/quicksight.rst
b/docs/apache-airflow-providers-amazon/operators/quicksight.rst
index cbca98d7d5..9cc0abe337 100644
--- a/docs/apache-airflow-providers-amazon/operators/quicksight.rst
+++ b/docs/apache-airflow-providers-amazon/operators/quicksight.rst
@@ -30,6 +30,11 @@ Prerequisite Tasks
.. include:: ../_partials/prerequisite_tasks.rst
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
Operators
---------
diff --git a/tests/providers/amazon/aws/hooks/test_base_aws.py
b/tests/providers/amazon/aws/hooks/test_base_aws.py
index ba94048421..c87aaa98fd 100644
--- a/tests/providers/amazon/aws/hooks/test_base_aws.py
+++ b/tests/providers/amazon/aws/hooks/test_base_aws.py
@@ -1031,6 +1031,10 @@ class TestAwsBaseHook:
assert mock_mask_secret.mock_calls == expected_calls
assert credentials == expected_credentials
+ @mock_sts
+ def test_account_id(self):
+ assert AwsBaseHook(aws_conn_id=None).account_id == DEFAULT_ACCOUNT_ID
+
class ThrowErrorUntilCount:
"""Holds counter state for invoking a method several times in a row."""
diff --git a/tests/providers/amazon/aws/hooks/test_quicksight.py
b/tests/providers/amazon/aws/hooks/test_quicksight.py
index 9c8ef16ce8..6a7795843b 100644
--- a/tests/providers/amazon/aws/hooks/test_quicksight.py
+++ b/tests/providers/amazon/aws/hooks/test_quicksight.py
@@ -22,19 +22,15 @@ from unittest import mock
import pytest
from botocore.exceptions import ClientError
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook
-from airflow.providers.amazon.aws.hooks.sts import StsHook
-
-AWS_ACCOUNT_ID = "123456789012"
+DEFAULT_AWS_ACCOUNT_ID = "123456789012"
MOCK_DATA = {
"DataSetId": "DemoDataSet",
"IngestionId": "DemoDataSet_Ingestion",
"IngestionType": "INCREMENTAL_REFRESH",
- "AwsAccountId": AWS_ACCOUNT_ID,
}
-
MOCK_CREATE_INGESTION_RESPONSE = {
"Status": 201,
"Arn":
"arn:aws:quicksight:us-east-1:123456789012:dataset/DemoDataSet/ingestion/DemoDataSet3_Ingestion",
@@ -42,7 +38,6 @@ MOCK_CREATE_INGESTION_RESPONSE = {
"IngestionStatus": "INITIALIZED",
"RequestId": "fc1f7eea-1327-41d6-9af7-c12f097ed343",
}
-
MOCK_DESCRIBE_INGESTION_SUCCESS = {
"Status": 200,
"Ingestion": {
@@ -59,7 +54,6 @@ MOCK_DESCRIBE_INGESTION_SUCCESS = {
},
"RequestId": "DemoDataSet_Ingestion_Request_ID",
}
-
MOCK_DESCRIBE_INGESTION_FAILURE = {
"Status": 403,
"Ingestion": {
@@ -76,6 +70,23 @@ MOCK_DESCRIBE_INGESTION_FAILURE = {
},
"RequestId": "DemoDataSet_Ingestion_Request_ID",
}
+ACCOUNT_TEST_CASES = [
+ pytest.param(None, DEFAULT_AWS_ACCOUNT_ID, id="default-account-id"),
+ pytest.param("777777777777", "777777777777", id="custom-account-id"),
+]
+
+
[email protected]
+def mocked_account_id():
+ with mock.patch.object(QuickSightHook, "account_id",
new_callable=mock.PropertyMock) as m:
+ m.return_value = DEFAULT_AWS_ACCOUNT_ID
+ yield m
+
+
[email protected]
+def mocked_client():
+ with mock.patch.object(QuickSightHook, "conn") as m:
+ yield m
class TestQuicksight:
@@ -83,70 +94,174 @@ class TestQuicksight:
hook = QuickSightHook(aws_conn_id="aws_default",
region_name="us-east-1")
assert hook.conn is not None
- @mock.patch.object(QuickSightHook, "get_conn")
- @mock.patch.object(StsHook, "get_conn")
- @mock.patch.object(StsHook, "get_account_number")
- def test_create_ingestion(self, mock_get_account_number, sts_conn,
mock_conn):
- mock_conn.return_value.create_ingestion.return_value =
MOCK_CREATE_INGESTION_RESPONSE
- mock_get_account_number.return_value = AWS_ACCOUNT_ID
- quicksight_hook = QuickSightHook(aws_conn_id="aws_default",
region_name="us-east-1")
- result = quicksight_hook.create_ingestion(
- data_set_id="DemoDataSet",
- ingestion_id="DemoDataSet_Ingestion",
- ingestion_type="INCREMENTAL_REFRESH",
+ @pytest.mark.parametrize(
+ "response, expected_status",
+ [
+ pytest.param(MOCK_DESCRIBE_INGESTION_SUCCESS, "COMPLETED",
id="completed"),
+ pytest.param(MOCK_DESCRIBE_INGESTION_FAILURE, "Failed",
id="failed"),
+ ],
+ )
+ @pytest.mark.parametrize("aws_account_id, expected_account_id",
ACCOUNT_TEST_CASES)
+ def test_get_job_status(
+ self, response, expected_status, aws_account_id, expected_account_id,
mocked_account_id, mocked_client
+ ):
+ """Test get job status."""
+ mocked_client.describe_ingestion.return_value = response
+
+ hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1")
+ assert (
+ hook.get_status(
+ data_set_id="DemoDataSet",
+ ingestion_id="DemoDataSet_Ingestion",
+ aws_account_id=aws_account_id,
+ )
+ == expected_status
+ )
+ mocked_client.describe_ingestion.assert_called_with(
+ AwsAccountId=expected_account_id,
+ DataSetId="DemoDataSet",
+ IngestionId="DemoDataSet_Ingestion",
+ )
+
+ @pytest.mark.parametrize(
+ "exception, error_match",
+ [
+ pytest.param(KeyError("Foo"), "Could not get status",
id="key-error"),
+ pytest.param(
+ ClientError(error_response={}, operation_name="fake"),
+ "AWS request failed",
+ id="botocore-client",
+ ),
+ ],
+ )
+ def test_get_job_status_exception(self, exception, error_match,
mocked_client, mocked_account_id):
+ mocked_client.describe_ingestion.side_effect = exception
+
+ hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1")
+ with pytest.raises(AirflowException, match=error_match):
+ assert hook.get_status(
+ data_set_id="DemoDataSet",
+ ingestion_id="DemoDataSet_Ingestion",
+ aws_account_id=None,
+ )
+
+ @pytest.mark.parametrize(
+ "error_info",
+ [
+ pytest.param({"foo": "bar"}, id="error-info-exists"),
+ pytest.param(None, id="error-info-not-exists"),
+ ],
+ )
+ @pytest.mark.parametrize("aws_account_id, expected_account_id",
ACCOUNT_TEST_CASES)
+ def test_get_error_info(
+ self, error_info, aws_account_id, expected_account_id, mocked_client,
mocked_account_id
+ ):
+ mocked_response = {"Ingestion": {}}
+ if error_info:
+ mocked_response["Ingestion"]["ErrorInfo"] = error_info
+ mocked_client.describe_ingestion.return_value = mocked_response
+
+ hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1")
+ assert (
+ hook.get_error_info(
+ data_set_id="DemoDataSet",
ingestion_id="DemoDataSet_Ingestion", aws_account_id=None
+ )
+ == error_info
)
- expected_call_params = MOCK_DATA
-
mock_conn.return_value.create_ingestion.assert_called_with(**expected_call_params)
- assert result == MOCK_CREATE_INGESTION_RESPONSE
- @mock.patch.object(QuickSightHook, "get_conn")
+ @mock.patch.object(QuickSightHook, "get_status", return_value="FAILED")
+ @mock.patch.object(QuickSightHook, "get_error_info")
+ @pytest.mark.parametrize("aws_account_id, expected_account_id",
ACCOUNT_TEST_CASES)
+ def test_wait_for_state_failure(
+ self,
+ mocked_get_error_info,
+ mocked_get_status,
+ aws_account_id,
+ expected_account_id,
+ mocked_client,
+ mocked_account_id,
+ ):
+ mocked_get_error_info.return_value = "Something Bad Happen"
+ hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1")
+ with pytest.raises(AirflowException, match="Error info: Something Bad
Happen"):
+ hook.wait_for_state(
+ aws_account_id, "data_set_id", "ingestion_id",
target_state={"COMPLETED"}, check_interval=0
+ )
+ mocked_get_status.assert_called_with(expected_account_id,
"data_set_id", "ingestion_id")
+ mocked_get_error_info.assert_called_with(expected_account_id,
"data_set_id", "ingestion_id")
+
+ @mock.patch.object(QuickSightHook, "get_status", return_value="CANCELLED")
+ def test_wait_for_state_canceled(self, _):
+ hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1")
+ with pytest.raises(AirflowException, match="The Amazon QuickSight
SPICE ingestion cancelled"):
+ hook.wait_for_state(
+ "aws_account_id", "data_set_id", "ingestion_id",
target_state={"COMPLETED"}, check_interval=0
+ )
+
@mock.patch.object(QuickSightHook, "get_status")
- def test_fast_failing_ingestion(self, mock_get_status, mock_conn):
- quicksight_hook = QuickSightHook(aws_conn_id="aws_default",
region_name="us-east-1")
- mock_get_status.return_value = "FAILED"
- with pytest.raises(AirflowException):
- quicksight_hook.wait_for_state(
- "account_id", "data_set_id", "ingestion_id",
target_state={"COMPLETED"}, check_interval=1
+ def test_wait_for_state_completed(self, mocked_get_status):
+ mocked_get_status.side_effect = ["INITIALIZED", "QUEUED", "RUNNING",
"COMPLETED"]
+ hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1")
+ assert (
+ hook.wait_for_state(
+ "aws_account_id", "data_set_id", "ingestion_id",
target_state={"COMPLETED"}, check_interval=0
)
+ == "COMPLETED"
+ )
+ assert mocked_get_status.call_count == 4
+
+ @pytest.mark.parametrize(
+ "wait_for_completion", [pytest.param(True, id="wait"),
pytest.param(False, id="no-wait")]
+ )
+ @pytest.mark.parametrize("aws_account_id, expected_account_id",
ACCOUNT_TEST_CASES)
+ def test_create_ingestion(
+ self, wait_for_completion, aws_account_id, expected_account_id,
mocked_account_id, mocked_client
+ ):
+ mocked_client.create_ingestion.return_value =
MOCK_CREATE_INGESTION_RESPONSE
+
+ hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1")
+ with mock.patch.object(QuickSightHook, "wait_for_state") as
mocked_wait_for_state:
+ assert (
+ hook.create_ingestion(
+ data_set_id="DemoDataSet",
+ ingestion_id="DemoDataSet_Ingestion",
+ ingestion_type="INCREMENTAL_REFRESH",
+ aws_account_id=aws_account_id,
+ wait_for_completion=wait_for_completion,
+ check_interval=0,
+ )
+ == MOCK_CREATE_INGESTION_RESPONSE
+ )
+ if wait_for_completion:
+ mocked_wait_for_state.assert_called_once_with(
+ aws_account_id=expected_account_id,
+ data_set_id="DemoDataSet",
+ ingestion_id="DemoDataSet_Ingestion",
+ target_state={"COMPLETED"},
+ check_interval=0,
+ )
+ else:
+ mocked_wait_for_state.assert_not_called()
- @mock.patch.object(StsHook, "get_conn")
- @mock.patch.object(StsHook, "get_account_number")
- def test_create_ingestion_exception(self, mock_get_account_number,
sts_conn):
- mock_get_account_number.return_value = AWS_ACCOUNT_ID
- hook = QuickSightHook(aws_conn_id="aws_default")
- with pytest.raises(ClientError) as raised_exception:
+
mocked_client.create_ingestion.assert_called_with(AwsAccountId=expected_account_id,
**MOCK_DATA)
+
+ def test_create_ingestion_exception(self, mocked_account_id,
mocked_client, caplog):
+ mocked_client.create_ingestion.side_effect = ValueError("Fake Error")
+ hook = QuickSightHook(aws_conn_id=None)
+ with pytest.raises(ValueError, match="Fake Error"):
hook.create_ingestion(
data_set_id="DemoDataSet",
ingestion_id="DemoDataSet_Ingestion",
ingestion_type="INCREMENTAL_REFRESH",
)
- ex = raised_exception.value
- assert ex.operation_name == "CreateIngestion"
-
- @mock.patch.object(QuickSightHook, "get_conn")
- def test_get_job_status(self, mock_conn):
- """
- Test get job status
- """
- mock_conn.return_value.describe_ingestion.return_value =
MOCK_DESCRIBE_INGESTION_SUCCESS
- quicksight_hook = QuickSightHook(aws_conn_id="aws_default",
region_name="us-east-1")
- result = quicksight_hook.get_status(
- data_set_id="DemoDataSet",
- ingestion_id="DemoDataSet_Ingestion",
- aws_account_id="123456789012",
- )
- assert result == "COMPLETED"
-
- @mock.patch.object(QuickSightHook, "get_conn")
- def test_get_job_status_failed(self, mock_conn):
- """
- Test get job status
- """
- mock_conn.return_value.describe_ingestion.return_value =
MOCK_DESCRIBE_INGESTION_FAILURE
- quicksight_hook = QuickSightHook(aws_conn_id="aws_default",
region_name="us-east-1")
- result = quicksight_hook.get_status(
- data_set_id="DemoDataSet",
- ingestion_id="DemoDataSet_Ingestion",
- aws_account_id="123456789012",
- )
- assert result == "Failed"
+ assert "create_ingestion API, error: Fake Error" in caplog.text
+
+ def test_deprecated_properties(self):
+ hook = QuickSightHook(aws_conn_id=None, region_name="us-east-1")
+ with mock.patch("airflow.providers.amazon.aws.hooks.sts.StsHook") as
mocked_class, pytest.warns(
+ AirflowProviderDeprecationWarning, match="consider to use
`.*account_id` instead"
+ ):
+ mocked_sts_hook = mock.MagicMock(name="FakeStsHook")
+ mocked_class.return_value = mocked_sts_hook
+ assert hook.sts_hook is mocked_sts_hook
+ mocked_class.assert_called_once_with(aws_conn_id=None)
diff --git a/tests/providers/amazon/aws/operators/test_quicksight.py
b/tests/providers/amazon/aws/operators/test_quicksight.py
index 2b7b0dc35f..fd30426293 100644
--- a/tests/providers/amazon/aws/operators/test_quicksight.py
+++ b/tests/providers/amazon/aws/operators/test_quicksight.py
@@ -38,17 +38,41 @@ MOCK_RESPONSE = {
class TestQuickSightCreateIngestionOperator:
def setup_method(self):
- self.quicksight = QuickSightCreateIngestionOperator(
- task_id="test_quicksight_operator",
- data_set_id=DATA_SET_ID,
- ingestion_id=INGESTION_ID,
+ self.default_op_kwargs = {
+ "task_id": "quicksight_create",
+ "aws_conn_id": None,
+ "data_set_id": DATA_SET_ID,
+ "ingestion_id": INGESTION_ID,
+ }
+
+ def test_init(self):
+ self.default_op_kwargs.pop("aws_conn_id", None)
+
+ op = QuickSightCreateIngestionOperator(
+ **self.default_op_kwargs,
+ # Generic hooks parameters
+ aws_conn_id="fake-conn-id",
+ region_name="cn-north-1",
+ verify=False,
+ botocore_config={"read_timeout": 42},
)
+ assert op.hook.client_type == "quicksight"
+ assert op.hook.resource_type is None
+ assert op.hook.aws_conn_id == "fake-conn-id"
+ assert op.hook._region_name == "cn-north-1"
+ assert op.hook._verify is False
+ assert op.hook._config is not None
+ assert op.hook._config.read_timeout == 42
+
+ op = QuickSightCreateIngestionOperator(**self.default_op_kwargs)
+ assert op.hook.aws_conn_id == "aws_default"
+ assert op.hook._region_name is None
+ assert op.hook._verify is None
+ assert op.hook._config is None
- @mock.patch.object(QuickSightHook, "get_conn")
- @mock.patch.object(QuickSightHook, "create_ingestion")
- def test_execute(self, mock_create_ingestion, mock_client):
- mock_create_ingestion.return_value = MOCK_RESPONSE
- self.quicksight.execute(None)
+ @mock.patch.object(QuickSightHook, "create_ingestion",
return_value=MOCK_RESPONSE)
+ def test_execute(self, mock_create_ingestion):
+ QuickSightCreateIngestionOperator(**self.default_op_kwargs).execute({})
mock_create_ingestion.assert_called_once_with(
data_set_id=DATA_SET_ID,
ingestion_id=INGESTION_ID,
diff --git a/tests/providers/amazon/aws/sensors/test_quicksight.py
b/tests/providers/amazon/aws/sensors/test_quicksight.py
index ba3ce83789..bef78d072d 100644
--- a/tests/providers/amazon/aws/sensors/test_quicksight.py
+++ b/tests/providers/amazon/aws/sensors/test_quicksight.py
@@ -20,10 +20,8 @@ from __future__ import annotations
from unittest import mock
import pytest
-from moto import mock_sts
-from moto.core import DEFAULT_ACCOUNT_ID
-from airflow.exceptions import AirflowException, AirflowSkipException
+from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning, AirflowSkipException
from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook
from airflow.providers.amazon.aws.sensors.quicksight import QuickSightSensor
@@ -31,58 +29,91 @@ DATA_SET_ID = "DemoDataSet"
INGESTION_ID = "DemoDataSet_Ingestion"
[email protected]
+def mocked_get_status():
+ with mock.patch.object(QuickSightHook, "get_status") as m:
+ yield m
+
+
[email protected]
+def mocked_get_error_info():
+ with mock.patch.object(QuickSightHook, "get_error_info") as m:
+ yield m
+
+
class TestQuickSightSensor:
def setup_method(self):
- self.sensor = QuickSightSensor(
- task_id="test_quicksight_sensor",
- aws_conn_id="aws_default",
- data_set_id="DemoDataSet",
- ingestion_id="DemoDataSet_Ingestion",
+ self.default_op_kwargs = {
+ "task_id": "quicksight_sensor",
+ "aws_conn_id": None,
+ "data_set_id": DATA_SET_ID,
+ "ingestion_id": INGESTION_ID,
+ }
+
+ def test_init(self):
+ self.default_op_kwargs.pop("aws_conn_id", None)
+
+ sensor = QuickSightSensor(
+ **self.default_op_kwargs,
+ # Generic hooks parameters
+ aws_conn_id="fake-conn-id",
+ region_name="ca-west-1",
+ verify=True,
+ botocore_config={"read_timeout": 42},
)
+ assert sensor.hook.client_type == "quicksight"
+ assert sensor.hook.resource_type is None
+ assert sensor.hook.aws_conn_id == "fake-conn-id"
+ assert sensor.hook._region_name == "ca-west-1"
+ assert sensor.hook._verify is True
+ assert sensor.hook._config is not None
+ assert sensor.hook._config.read_timeout == 42
+
+ sensor = QuickSightSensor(**self.default_op_kwargs)
+ assert sensor.hook.aws_conn_id == "aws_default"
+ assert sensor.hook._region_name is None
+ assert sensor.hook._verify is None
+ assert sensor.hook._config is None
+
+ @pytest.mark.parametrize("status", ["COMPLETED"])
+ def test_poke_completed(self, status, mocked_get_status):
+ mocked_get_status.return_value = status
+ assert QuickSightSensor(**self.default_op_kwargs).poke({}) is True
+        mocked_get_status.assert_called_once_with(None, DATA_SET_ID, INGESTION_ID)
- @mock_sts
- @mock.patch.object(QuickSightHook, "get_status")
- def test_poke_success(self, mock_get_status):
- mock_get_status.return_value = "COMPLETED"
- assert self.sensor.poke({}) is True
-        mock_get_status.assert_called_once_with(DEFAULT_ACCOUNT_ID, DATA_SET_ID, INGESTION_ID)
-
- @mock_sts
- @mock.patch.object(QuickSightHook, "get_status")
- @mock.patch.object(QuickSightHook, "get_error_info")
- def test_poke_cancelled(self, _, mock_get_status):
- mock_get_status.return_value = "CANCELLED"
- with pytest.raises(AirflowException):
- self.sensor.poke({})
-        mock_get_status.assert_called_once_with(DEFAULT_ACCOUNT_ID, DATA_SET_ID, INGESTION_ID)
-
- @mock_sts
- @mock.patch.object(QuickSightHook, "get_status")
- @mock.patch.object(QuickSightHook, "get_error_info")
- def test_poke_failed(self, _, mock_get_status):
- mock_get_status.return_value = "FAILED"
- with pytest.raises(AirflowException):
- self.sensor.poke({})
-        mock_get_status.assert_called_once_with(DEFAULT_ACCOUNT_ID, DATA_SET_ID, INGESTION_ID)
-
- @mock_sts
- @mock.patch.object(QuickSightHook, "get_status")
- def test_poke_initialized(self, mock_get_status):
- mock_get_status.return_value = "INITIALIZED"
- assert self.sensor.poke({}) is False
-        mock_get_status.assert_called_once_with(DEFAULT_ACCOUNT_ID, DATA_SET_ID, INGESTION_ID)
+ @pytest.mark.parametrize("status", ["INITIALIZED"])
+ def test_poke_not_completed(self, status, mocked_get_status):
+ mocked_get_status.return_value = status
+ assert QuickSightSensor(**self.default_op_kwargs).poke({}) is False
+        mocked_get_status.assert_called_once_with(None, DATA_SET_ID, INGESTION_ID)
+ @pytest.mark.parametrize("status", ["FAILED", "CANCELLED"])
@pytest.mark.parametrize(
-        "soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
+ "soft_fail, expected_exception",
+ [
+ pytest.param(True, AirflowSkipException, id="soft-fail"),
+ pytest.param(False, AirflowException, id="non-soft-fail"),
+ ],
)
-    @mock.patch("airflow.providers.amazon.aws.hooks.sts.StsHook.get_account_number")
-    @mock.patch("airflow.providers.amazon.aws.hooks.quicksight.QuickSightHook.get_status")
-    @mock.patch("airflow.providers.amazon.aws.hooks.quicksight.QuickSightHook.get_error_info")
-    def test_fail_poke(self, get_error_info, get_status, _, soft_fail, expected_exception):
- self.sensor.soft_fail = soft_fail
- error = "expected error"
- message = f"The QuickSight Ingestion failed. Error info: {error}"
- with pytest.raises(expected_exception, match=message):
- get_status.return_value = "FAILED"
- get_error_info.return_value = message
- self.sensor.poke(context={})
+ def test_poke_terminated_status(
+        self, status, soft_fail, expected_exception, mocked_get_status, mocked_get_error_info
+ ):
+ mocked_get_status.return_value = status
+ mocked_get_error_info.return_value = "something bad happen"
+ with pytest.raises(expected_exception, match="Error info: something
bad happen"):
+ QuickSightSensor(**self.default_op_kwargs,
soft_fail=soft_fail).poke({})
+        mocked_get_status.assert_called_once_with(None, DATA_SET_ID, INGESTION_ID)
+        mocked_get_error_info.assert_called_once_with(None, DATA_SET_ID, INGESTION_ID)
+
+ def test_deprecated_properties(self):
+ sensor = QuickSightSensor(**self.default_op_kwargs)
+        with pytest.warns(AirflowProviderDeprecationWarning, match="please use `.*hook` property instead"):
+ assert sensor.quicksight_hook is sensor.hook
+
+        with mock.patch("airflow.providers.amazon.aws.hooks.sts.StsHook") as mocked_class, pytest.warns(
+            AirflowProviderDeprecationWarning, match="consider to use `.*hook\.account_id` instead"
+        ):
+ ):
+ mocked_sts_hook = mock.MagicMock(name="FakeStsHook")
+ mocked_class.return_value = mocked_sts_hook
+ assert sensor.sts_hook is mocked_sts_hook
+ mocked_class.assert_called_once_with(aws_conn_id=None)