This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 677e53436ce Fix AwsBaseWaiterTrigger losing error details on deferred
task failure (#64085)
677e53436ce is described below
commit 677e53436cea919bd49a593547c214b77cb5a55b
Author: Shivam Rastogi <[email protected]>
AuthorDate: Fri Mar 27 03:28:54 2026 -0700
Fix AwsBaseWaiterTrigger losing error details on deferred task failure
(#64085)
When a deferred AWS task hits a terminal failure state, async_wait()
raises AirflowException with the error details. But
AwsBaseWaiterTrigger.run()
did not catch it — the exception propagated to the triggerer framework, which
replaced it with a generic "Trigger failure" message. execute_complete() was
never called, so operators and on_failure_callbacks lost all error context.
---
.../airflow/providers/amazon/aws/operators/dms.py | 36 ++++++++++++---
.../airflow/providers/amazon/aws/operators/emr.py | 35 ++++++++-------
.../providers/amazon/aws/operators/neptune.py | 27 ++++++------
.../airflow/providers/amazon/aws/sensors/mwaa.py | 11 ++++-
.../airflow/providers/amazon/aws/triggers/base.py | 25 ++++++-----
.../airflow/providers/amazon/aws/triggers/dms.py | 2 +-
.../airflow/providers/amazon/aws/triggers/glue.py | 2 +-
.../tests/unit/amazon/aws/operators/test_dms.py | 51 ++++++++++++++++++++++
.../amazon/aws/operators/test_emr_serverless.py | 20 +++++++++
.../tests/unit/amazon/aws/sensors/test_mwaa.py | 22 ++++++++++
.../tests/unit/amazon/aws/triggers/test_base.py | 19 ++++++++
.../tests/unit/amazon/aws/triggers/test_glue.py | 7 +--
.../tests/unit/amazon/aws/triggers/test_neptune.py | 6 +--
13 files changed, 208 insertions(+), 55 deletions(-)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py
b/providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py
index 49289fead60..aebccf64483 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/dms.py
@@ -30,6 +30,7 @@ from airflow.providers.amazon.aws.triggers.dms import (
DmsReplicationStoppedTrigger,
DmsReplicationTerminalStatusTrigger,
)
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.providers.common.compat.sdk import AirflowException, Context, conf
@@ -510,11 +511,21 @@ class
DmsDeleteReplicationConfigOperator(AwsBaseOperator[DmsHook]):
self.log.info("DMS replication config(%s) deleted.",
self.replication_config_arn)
def execute_complete(self, context, event=None):
- self.replication_config_arn = event.get("replication_config_arn")
+ validated_event = validate_execute_complete_event(event)
+
+ if validated_event["status"] != "success":
+ raise AirflowException(f"Error deleting DMS replication config:
{validated_event}")
+
+ self.replication_config_arn =
validated_event.get("replication_config_arn")
self.log.info("DMS replication config(%s) deleted.",
self.replication_config_arn)
def retry_execution(self, context, event=None):
- self.replication_config_arn = event.get("replication_config_arn")
+ validated_event = validate_execute_complete_event(event)
+
+ if validated_event["status"] != "success":
+ raise AirflowException(f"Error waiting for DMS replication config:
{validated_event}")
+
+ self.replication_config_arn =
validated_event.get("replication_config_arn")
self.log.info("Retrying replication config(%s) deletion.",
self.replication_config_arn)
self.execute(context)
@@ -703,11 +714,21 @@ class
DmsStartReplicationOperator(AwsBaseOperator[DmsHook]):
self.log.info("Status: %s Provision status: %s", current_status,
provision_status)
def execute_complete(self, context, event=None):
- self.replication_config_arn = event.get("replication_config_arn")
+ validated_event = validate_execute_complete_event(event)
+
+ if validated_event["status"] != "success":
+ raise AirflowException(f"Error in DMS replication:
{validated_event}")
+
+ self.replication_config_arn =
validated_event.get("replication_config_arn")
self.log.info("Replication(%s) has completed.",
self.replication_config_arn)
def retry_execution(self, context, event=None):
- self.replication_config_arn = event.get("replication_config_arn")
+ validated_event = validate_execute_complete_event(event)
+
+ if validated_event["status"] != "success":
+ raise AirflowException(f"Error waiting for DMS replication:
{validated_event}")
+
+ self.replication_config_arn =
validated_event.get("replication_config_arn")
self.log.info("Retrying replication %s.", self.replication_config_arn)
self.execute(context)
@@ -794,5 +815,10 @@ class DmsStopReplicationOperator(AwsBaseOperator[DmsHook]):
)
def execute_complete(self, context, event=None):
- self.replication_config_arn = event.get("replication_config_arn")
+ validated_event = validate_execute_complete_event(event)
+
+ if validated_event["status"] != "success":
+ raise AirflowException(f"Error stopping DMS replication:
{validated_event}")
+
+ self.replication_config_arn =
validated_event.get("replication_config_arn")
self.log.info("Replication(%s) has stopped.",
self.replication_config_arn)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
index 3c9d0c75d48..313cc649489 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py
@@ -1620,24 +1620,26 @@ class
EmrServerlessStopApplicationOperator(AwsBaseOperator[EmrServerlessHook]):
if event is None:
self.log.error("Trigger error: event is None")
raise AirflowException("Trigger error: event is None")
- if event["status"] == "success":
- self.hook.conn.stop_application(applicationId=self.application_id)
- self.defer(
- trigger=EmrServerlessStopApplicationTrigger(
- application_id=self.application_id,
- aws_conn_id=self.aws_conn_id,
- waiter_delay=self.waiter_delay,
- waiter_max_attempts=self.waiter_max_attempts,
- ),
- timeout=timedelta(seconds=self.waiter_max_attempts *
self.waiter_delay),
- method_name="execute_complete",
- )
+ if event["status"] != "success":
+ raise AirflowException(f"Error cancelling EMR Serverless jobs:
{event}")
+ self.hook.conn.stop_application(applicationId=self.application_id)
+ self.defer(
+ trigger=EmrServerlessStopApplicationTrigger(
+ application_id=self.application_id,
+ aws_conn_id=self.aws_conn_id,
+ waiter_delay=self.waiter_delay,
+ waiter_max_attempts=self.waiter_max_attempts,
+ ),
+ timeout=timedelta(seconds=self.waiter_max_attempts *
self.waiter_delay),
+ method_name="execute_complete",
+ )
def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> None:
validated_event = validate_execute_complete_event(event)
- if validated_event["status"] == "success":
- self.log.info("EMR serverless application %s stopped
successfully", self.application_id)
+ if validated_event["status"] != "success":
+ raise AirflowException(f"Error stopping EMR Serverless
application: {validated_event}")
+ self.log.info("EMR serverless application %s stopped successfully",
self.application_id)
class
EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperator):
@@ -1743,5 +1745,6 @@ class
EmrServerlessDeleteApplicationOperator(EmrServerlessStopApplicationOperato
def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> None:
validated_event = validate_execute_complete_event(event)
- if validated_event["status"] == "success":
- self.log.info("EMR serverless application %s deleted
successfully", self.application_id)
+ if validated_event["status"] != "success":
+ raise AirflowException(f"Error deleting EMR Serverless
application: {validated_event}")
+ self.log.info("EMR serverless application %s deleted successfully",
self.application_id)
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune.py
b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune.py
index a09c2e27c19..9e3916da5eb 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/neptune.py
@@ -29,6 +29,7 @@ from airflow.providers.amazon.aws.triggers.neptune import (
NeptuneClusterInstancesAvailableTrigger,
NeptuneClusterStoppedTrigger,
)
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.providers.common.compat.sdk import AirflowException, conf
@@ -187,14 +188,13 @@ class
NeptuneStartDbClusterOperator(AwsBaseOperator[NeptuneHook]):
return {"db_cluster_id": self.cluster_id}
def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> dict[str, str]:
- status = ""
- cluster_id = ""
+ validated_event = validate_execute_complete_event(event)
- if event:
- status = event.get("status", "")
- cluster_id = event.get("cluster_id", "")
+ if validated_event["status"] != "success":
+ raise AirflowException(f"Error starting Neptune cluster:
{validated_event}")
- self.log.info("Neptune cluster %s available with status: %s",
cluster_id, status)
+ cluster_id = validated_event.get("db_cluster_id", "")
+ self.log.info("Neptune cluster %s available with status: %s",
cluster_id, validated_event["status"])
return {"db_cluster_id": cluster_id}
@@ -314,13 +314,12 @@ class
NeptuneStopDbClusterOperator(AwsBaseOperator[NeptuneHook]):
return {"db_cluster_id": self.cluster_id}
def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> dict[str, str]:
- status = ""
- cluster_id = ""
- self.log.info(event)
- if event:
- status = event.get("status", "")
- cluster_id = event.get("cluster_id", "")
-
- self.log.info("Neptune cluster %s stopped with status: %s",
cluster_id, status)
+ validated_event = validate_execute_complete_event(event)
+
+ if validated_event["status"] != "success":
+ raise AirflowException(f"Error stopping Neptune cluster:
{validated_event}")
+
+ cluster_id = validated_event.get("db_cluster_id", "")
+ self.log.info("Neptune cluster %s stopped with status: %s",
cluster_id, validated_event["status"])
return {"db_cluster_id": cluster_id}
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
b/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
index ece83acd052..d879a5e7922 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Literal
from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.triggers.mwaa import
MwaaDagRunCompletedTrigger, MwaaTaskCompletedTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.providers.common.compat.sdk import AirflowException, conf
from airflow.utils.state import DagRunState, TaskInstanceState
@@ -143,7 +144,10 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
return state in self.success_states
def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> None:
- return None
+ validated_event = validate_execute_complete_event(event)
+
+ if validated_event["status"] != "success":
+ raise AirflowException(f"Error in MWAA DAG run: {validated_event}")
def execute(self, context: Context):
if self.deferrable:
@@ -281,7 +285,10 @@ class MwaaTaskSensor(AwsBaseSensor[MwaaHook]):
return state in self.success_states
def execute_complete(self, context: Context, event: dict[str, Any] | None
= None) -> None:
- return None
+ validated_event = validate_execute_complete_event(event)
+
+ if validated_event["status"] != "success":
+ raise AirflowException(f"Error in MWAA task: {validated_event}")
def execute(self, context: Context):
if self.external_dag_run_id is None:
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
index 4b7ddfd4054..999b5e5bfd1 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py
@@ -21,6 +21,7 @@ from abc import abstractmethod
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any
+from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.utils.helpers import prune_dict
@@ -149,13 +150,17 @@ class AwsBaseWaiterTrigger(BaseTrigger):
client=client,
config_overrides=self.waiter_config_overrides,
)
- await async_wait(
- waiter,
- self.waiter_delay,
- self.attempts,
- self.waiter_args,
- self.failure_message,
- self.status_message,
- self.status_queries,
- )
- yield TriggerEvent({"status": "success", self.return_key:
self.return_value})
+ try:
+ await async_wait(
+ waiter,
+ self.waiter_delay,
+ self.attempts,
+ self.waiter_args,
+ self.failure_message,
+ self.status_message,
+ self.status_queries,
+ )
+ except AirflowException as e:
+ yield TriggerEvent({"status": "error", "message": str(e),
self.return_key: self.return_value})
+ else:
+ yield TriggerEvent({"status": "success", self.return_key:
self.return_value})
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/dms.py
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/dms.py
index 99addc96970..fb729ab6102 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/dms.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/dms.py
@@ -129,7 +129,7 @@ class DmsReplicationCompleteTrigger(AwsBaseWaiterTrigger):
waiter_max_attempts=waiter_max_attempts,
failure_message="Replication failed to complete.",
status_message="Status replication is",
- status_queries=["Replications[0].Status"],
+ status_queries=["Replications[0].Status",
"Replications[0].FailureMessages"],
return_key="replication_config_arn",
return_value=replication_config_arn,
aws_conn_id=aws_conn_id,
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py
b/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py
index 6314bce5288..f54f761825e 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/glue.py
@@ -65,7 +65,7 @@ class GlueJobCompleteTrigger(AwsBaseWaiterTrigger):
waiter_args={"JobName": job_name, "RunId": run_id},
failure_message="AWS Glue job failed.",
status_message="Status of AWS Glue job is",
- status_queries=["JobRun.JobRunState"],
+ status_queries=["JobRun.JobRunState", "JobRun.ErrorMessage"],
return_key="run_id",
return_value=run_id,
waiter_delay=waiter_delay,
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
index 178a2043f98..f7e07d15ab5 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_dms.py
@@ -862,6 +862,24 @@ class TestDmsDeleteReplicationConfigOperator:
assert isinstance(defer.value.trigger,
DmsReplicationTerminalStatusTrigger)
+ def test_execute_complete_error(self):
+ op = DmsDeleteReplicationConfigOperator(
+ task_id="delete_replication_config",
+ replication_config_arn="arn:test",
+ )
+ error_event = {"status": "error", "message": "Timeout",
"replication_config_arn": "arn:test"}
+ with pytest.raises(AirflowException, match="Error deleting DMS
replication config"):
+ op.execute_complete({}, error_event)
+
+ def test_retry_execution_error(self):
+ op = DmsDeleteReplicationConfigOperator(
+ task_id="delete_replication_config",
+ replication_config_arn="arn:test",
+ )
+ error_event = {"status": "error", "message": "Timeout",
"replication_config_arn": "arn:test"}
+ with pytest.raises(AirflowException, match="Error waiting for DMS
replication config"):
+ op.retry_execution({}, error_event)
+
class TestDmsDescribeReplicationsOperator:
FILTER = [{"Name": "replication-type", "Values": ["cdc"]}]
@@ -1008,6 +1026,30 @@ class TestDmsStartReplicationOperator:
op.execute({})
assert mock_conn.start_replication.call_count == 1
+ def test_execute_complete_error(self):
+ op = DmsStartReplicationOperator(
+ task_id="start_replication",
+ replication_config_arn="arn:test",
+ replication_start_type="reload",
+ )
+ error_event = {
+ "status": "error",
+ "message": "Replication failed",
+ "replication_config_arn": "arn:test",
+ }
+ with pytest.raises(AirflowException, match="Error in DMS replication"):
+ op.execute_complete({}, error_event)
+
+ def test_retry_execution_error(self):
+ op = DmsStartReplicationOperator(
+ task_id="start_replication",
+ replication_config_arn="arn:test",
+ replication_start_type="reload",
+ )
+ error_event = {"status": "error", "message": "Timeout",
"replication_config_arn": "arn:test"}
+ with pytest.raises(AirflowException, match="Error waiting for DMS
replication"):
+ op.retry_execution({}, error_event)
+
class TestDmsStopReplicationOperator:
def mock_describe_replication_response(self, status: str):
@@ -1066,3 +1108,12 @@ class TestDmsStopReplicationOperator:
op.execute({})
mock_get_waiter.assert_called_with("replication_stopped")
mock_get_waiter.assert_called_once()
+
+ def test_execute_complete_error(self):
+ op = DmsStopReplicationOperator(
+ task_id="stop_replication",
+ replication_config_arn="arn:test",
+ )
+ error_event = {"status": "error", "message": "Timeout",
"replication_config_arn": "arn:test"}
+ with pytest.raises(AirflowException, match="Error stopping DMS
replication"):
+ op.execute_complete({}, error_event)
diff --git
a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_serverless.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_serverless.py
index dcedaee3963..62cccc528d0 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_emr_serverless.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_emr_serverless.py
@@ -1406,6 +1406,14 @@ class TestEmrServerlessDeleteOperator:
with pytest.raises(TaskDeferred):
operator.execute(None)
+ def test_execute_complete_error(self):
+ operator = EmrServerlessDeleteApplicationOperator(
+ task_id=task_id, application_id=application_id_delete_operator
+ )
+ error_event = {"status": "error", "message": "Delete failed",
"application_id": "test"}
+ with pytest.raises(AirflowException, match="Error deleting EMR
Serverless application"):
+ operator.execute_complete({}, error_event)
+
def test_template_fields(self):
operator = EmrServerlessDeleteApplicationOperator(
task_id=task_id, application_id=application_id_delete_operator
@@ -1487,6 +1495,18 @@ class TestEmrServerlessStopOperator:
assert "no running jobs found with application ID test" in
caplog.messages
+ def test_execute_complete_error(self):
+ operator = EmrServerlessStopApplicationOperator(task_id=task_id,
application_id="test")
+ error_event = {"status": "error", "message": "Stop failed",
"application_id": "test"}
+ with pytest.raises(AirflowException, match="Error stopping EMR
Serverless application"):
+ operator.execute_complete({}, error_event)
+
+ def test_stop_application_error(self):
+ operator = EmrServerlessStopApplicationOperator(task_id=task_id,
application_id="test")
+ error_event = {"status": "error", "message": "Cancel jobs failed",
"application_id": "test"}
+ with pytest.raises(AirflowException, match="Error cancelling EMR
Serverless jobs"):
+ operator.stop_application({}, error_event)
+
def test_template_fields(self):
operator = EmrServerlessStopApplicationOperator(
task_id=task_id, application_id="test", deferrable=True,
force_stop=True
diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
b/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
index 11a1265003a..938d7228a70 100644
--- a/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
+++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_mwaa.py
@@ -98,6 +98,17 @@ class TestMwaaDagRunSuccessSensor:
with pytest.raises(AirflowException, match=f".*{state}.*"):
MwaaDagRunSensor(**SENSOR_DAG_RUN_KWARGS,
**SENSOR_STATE_KWARGS).poke({})
+ def test_execute_complete_error(self):
+ sensor = MwaaDagRunSensor(**SENSOR_DAG_RUN_KWARGS,
**SENSOR_STATE_KWARGS)
+ error_event = {"status": "error", "message": "DAG run failed",
"dag_run_id": "test_run"}
+ with pytest.raises(AirflowException, match="Error in MWAA DAG run"):
+ sensor.execute_complete({}, error_event)
+
+ def test_execute_complete_success(self):
+ sensor = MwaaDagRunSensor(**SENSOR_DAG_RUN_KWARGS,
**SENSOR_STATE_KWARGS)
+ success_event = {"status": "success", "dag_run_id": "test_run"}
+ sensor.execute_complete({}, success_event) # should not raise
+
class TestMwaaTaskSuccessSensor:
def test_init_success(self):
@@ -137,3 +148,14 @@ class TestMwaaTaskSuccessSensor:
mock_invoke_rest_api.return_value = {"RestApiResponse": {"state":
state}}
with pytest.raises(AirflowException, match=f".*{state}.*"):
MwaaTaskSensor(**SENSOR_TASK_KWARGS,
**SENSOR_STATE_KWARGS).poke({})
+
+ def test_execute_complete_error(self):
+ sensor = MwaaTaskSensor(**SENSOR_TASK_KWARGS, **SENSOR_STATE_KWARGS)
+ error_event = {"status": "error", "message": "Task failed", "task_id":
"test_task"}
+ with pytest.raises(AirflowException, match="Error in MWAA task"):
+ sensor.execute_complete({}, error_event)
+
+ def test_execute_complete_success(self):
+ sensor = MwaaTaskSensor(**SENSOR_TASK_KWARGS, **SENSOR_STATE_KWARGS)
+ success_event = {"status": "success", "task_id": "test_task"}
+ sensor.execute_complete({}, success_event) # should not raise
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_base.py
b/providers/amazon/tests/unit/amazon/aws/triggers/test_base.py
index 7e64cb38e1a..6423ff9c38c 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_base.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_base.py
@@ -22,6 +22,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
+from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
if TYPE_CHECKING:
@@ -125,3 +126,21 @@ class TestAwsBaseWaiterTrigger:
assert isinstance(res.payload, dict)
assert res.payload["status"] == "success"
assert res.payload["hello"] == "world"
+
+ @pytest.mark.asyncio
+ @mock.patch(
+ "airflow.providers.amazon.aws.triggers.base.async_wait",
+ side_effect=AirflowException("AWS Glue job failed.\nTerminal failure"),
+ )
+ async def test_run_error_yields_event(self, wait_mock: MagicMock):
+ self.trigger.return_key = "hello"
+ self.trigger.return_value = "world"
+
+ generator = self.trigger.run()
+ res: TriggerEvent = await generator.asend(None)
+
+ wait_mock.assert_called_once()
+ assert isinstance(res.payload, dict)
+ assert res.payload["status"] == "error"
+ assert "AWS Glue job failed." in res.payload["message"]
+ assert res.payload["hello"] == "world"
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_glue.py
b/providers/amazon/tests/unit/amazon/aws/triggers/test_glue.py
index 1c6de7f096a..4339d36ce38 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_glue.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_glue.py
@@ -30,7 +30,6 @@ from airflow.providers.amazon.aws.triggers.glue import (
GlueDataQualityRuleSetEvaluationRunCompleteTrigger,
GlueJobCompleteTrigger,
)
-from airflow.providers.common.compat.sdk import AirflowException
from airflow.triggers.base import TriggerEvent
from unit.amazon.aws.utils.test_waiter import assert_expected_waiter_type
@@ -85,10 +84,12 @@ class TestGlueJobTrigger:
waiter_delay=10,
)
generator = trigger.run()
+ event = await generator.asend(None)
- with pytest.raises(AirflowException):
- await generator.asend(None)
assert_expected_waiter_type(mock_get_waiter, "job_complete")
+ assert event.payload["status"] == "error"
+ assert "message" in event.payload
+ assert event.payload["run_id"] == "JobRunId"
def test_serialization(self):
trigger = GlueJobCompleteTrigger(
diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune.py
b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune.py
index 13b0c4b1dd8..4fa9d0ed38f 100644
--- a/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune.py
+++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_neptune.py
@@ -27,7 +27,6 @@ from airflow.providers.amazon.aws.triggers.neptune import (
NeptuneClusterInstancesAvailableTrigger,
NeptuneClusterStoppedTrigger,
)
-from airflow.providers.common.compat.sdk import AirflowException
from airflow.triggers.base import TriggerEvent
CLUSTER_ID = "test-cluster"
@@ -125,5 +124,6 @@ class TestNeptuneClusterInstancesAvailableTrigger:
db_cluster_id=CLUSTER_ID, waiter_delay=1, waiter_max_attempts=2
)
- with pytest.raises(AirflowException):
- await trigger.run().asend(None)
+ event = await trigger.run().asend(None)
+ assert event.payload["status"] == "error"
+ assert "message" in event.payload