This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b83fcf9a31c AIP-86 - Add async support for Notifiers (#53831)
b83fcf9a31c is described below
commit b83fcf9a31c84b25feb4bd7b816bb64f226c52d3
Author: Ramit Kataria <[email protected]>
AuthorDate: Tue Aug 26 09:23:53 2025 -0700
AIP-86 - Add async support for Notifiers (#53831)
This makes the `BaseNotifier` awaitable and implements the necessary
changes required for the notifiers to work in non-blocking mode
(required for DeadlineAlerts). Since notifiers use hooks which may need
to use the TaskSDK API if they will be fetching a `Connection`, I've added
async counterparts to relevant TaskSDK functions as well, while avoiding
as much code duplication as I could.
Includes changes needed to `SlackWebhookNotifier` as an example.
---------
Co-authored-by: ferruzzi <[email protected]>
---
.../logging-monitoring/callbacks.rst | 81 ++++++++----
airflow-core/docs/howto/notifications.rst | 47 ++-----
airflow-core/src/airflow/triggers/deadline.py | 6 +-
airflow-core/tests/unit/models/test_deadline.py | 119 ++++++++---------
airflow-core/tests/unit/triggers/test_deadline.py | 40 +++++-
devel-common/src/tests_common/pytest_plugin.py | 5 +-
.../airflow/providers/slack/hooks/slack_webhook.py | 147 +++++++++++++++++----
.../providers/slack/notifications/slack_webhook.py | 21 ++-
.../tests/unit/slack/hooks/test_slack_webhook.py | 133 ++++++++++++++++++-
.../unit/slack/notifications/test_slack_webhook.py | 70 ++++++++++
task-sdk/src/airflow/sdk/bases/hook.py | 14 ++
task-sdk/src/airflow/sdk/bases/notifier.py | 68 ++++++++--
task-sdk/src/airflow/sdk/definitions/connection.py | 20 ++-
task-sdk/src/airflow/sdk/execution_time/comms.py | 4 +
task-sdk/src/airflow/sdk/execution_time/context.py | 31 +++--
task-sdk/tests/task_sdk/bases/test_hook.py | 32 +++++
.../tests/task_sdk/execution_time/test_context.py | 4 +-
17 files changed, 658 insertions(+), 184 deletions(-)
diff --git
a/airflow-core/docs/administration-and-deployment/logging-monitoring/callbacks.rst
b/airflow-core/docs/administration-and-deployment/logging-monitoring/callbacks.rst
index c2201921cd1..00eca8bd0e8 100644
---
a/airflow-core/docs/administration-and-deployment/logging-monitoring/callbacks.rst
+++
b/airflow-core/docs/administration-and-deployment/logging-monitoring/callbacks.rst
@@ -20,19 +20,19 @@
Callbacks
=========
-A valuable component of logging and monitoring is the use of task callbacks to
act upon changes in state of a given DAG or task, or across all tasks in a
given DAG.
-For example, you may wish to alert when certain tasks have failed, or invoke a
callback when your DAG succeeds.
+A valuable component of logging and monitoring is the use of task callbacks to
act upon changes in state of a given Dag or task, or across all tasks in a
given Dag.
+For example, you may wish to alert when certain tasks have failed, or invoke a
callback when your Dag succeeds.
There are three different places where callbacks can be defined.
-- Callbacks set in the DAG definition will be applied at the DAG level.
-- Using ``default_args``, callbacks can be set for each task in a DAG.
+- Callbacks set in the Dag definition will be applied at the Dag level.
+- Using ``default_args``, callbacks can be set for each task in a Dag.
- Individual callbacks can be set for a task by setting that callback within
the task definition itself.
.. note::
- Callback functions are only invoked when the DAG or task state changes due
to execution by a worker.
- As such, DAG and task changes set by the command line interface (:doc:`CLI
<../../howto/usage-cli>`) or user interface (:doc:`UI <../../ui>`) do not
+ Callback functions are only invoked when the Dag or task state changes due
to execution by a worker.
+ As such, Dag and task changes set by the command line interface (:doc:`CLI
<../../howto/usage-cli>`) or user interface (:doc:`UI <../../ui>`) do not
execute callback functions.
.. warning::
@@ -42,6 +42,12 @@ There are three different places where callbacks can be
defined.
By default, scheduler logs do not show up in the UI and instead can be
found in
``$AIRFLOW_HOME/logs/scheduler/latest/DAG_FILE.py.log``
+.. note::
+ As of Airflow 2.6.0, callbacks now supports a list of callback functions,
allowing users to specify multiple functions
+ to be executed in the desired event. Simply pass a list of callback
functions to the callback args when defining your Dag/task
+ callbacks: e.g ``on_failure_callback=[callback_func_1, callback_func_2]``
+
+
Callback Types
--------------
@@ -50,33 +56,33 @@ There are six types of events that can trigger a callback:
===========================================
================================================================
Name Description
===========================================
================================================================
-``on_success_callback`` Invoked when the :ref:`DAG
succeeds <dag-run:dag-run-status>` or :ref:`task succeeds
<concepts:task-instances>`.
- Available at the DAG or task level.
+``on_success_callback`` Invoked when the :ref:`Dag
succeeds <dag-run:dag-run-status>` or :ref:`task succeeds
<concepts:task-instances>`.
+ Available at the Dag or task level.
``on_failure_callback`` Invoked when the task :ref:`fails
<concepts:task-instances>`.
- Available at the DAG or task level.
+ Available at the Dag or task level.
``on_retry_callback`` Invoked when the task is :ref:`up
for retry <concepts:task-instances>`.
Available only at the task level.
``on_execute_callback`` Invoked right before the task
begins executing.
Available only at the task level.
``on_skipped_callback`` Invoked when the task is
:ref:`running <concepts:task-instances>` and AirflowSkipException raised.
Explicitly it is NOT called if a
task is not started to be executed because of a preceding branching
- decision in the DAG or a trigger
rule which causes execution to skip so that the task execution
+ decision in the Dag or a trigger
rule which causes execution to skip so that the task execution
is never scheduled.
Available only at the task level.
===========================================
================================================================
-Example
--------
+Examples
+--------
+
+Using Custom Callback Methods
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-In the following example, failures in ``task1`` call the
``task_failure_alert`` function, and success at DAG level calls the
``dag_success_alert`` function.
+In the following example, failures in ``task1`` call the
``task_failure_alert`` function, and success at Dag level calls the
``dag_success_alert`` function.
Before each task begins to execute, the ``task_execute_callback`` function
will be called:
.. code-block:: python
- import datetime
- import pendulum
-
from airflow.sdk import DAG
from airflow.providers.standard.operators.empty import EmptyOperator
@@ -90,27 +96,48 @@ Before each task begins to execute, the
``task_execute_callback`` function will
def dag_success_alert(context):
- print(f"DAG has succeeded, run_id: {context['run_id']}")
+ print(f"Dag has succeeded, run_id: {context['run_id']}")
with DAG(
dag_id="example_callback",
- schedule=None,
- start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
- dagrun_timeout=datetime.timedelta(minutes=60),
- catchup=False,
on_success_callback=dag_success_alert,
default_args={"on_execute_callback": task_execute_callback},
- tags=["example"],
):
task1 = EmptyOperator(task_id="task1",
on_failure_callback=[task_failure_alert])
task2 = EmptyOperator(task_id="task2")
task3 = EmptyOperator(task_id="task3")
task1 >> task2 >> task3
-.. note::
- As of Airflow 2.6.0, callbacks now supports a list of callback functions,
allowing users to specify multiple functions
- to be executed in the desired event. Simply pass a list of callback
functions to the callback args when defining your DAG/task
- callbacks: e.g ``on_failure_callback=[callback_func_1, callback_func_2]``
-
Full list of variables available in ``context`` in :doc:`docs
<../../templates-ref>` and `code
<https://github.com/apache/airflow/blob/main/task-sdk/src/airflow/sdk/definitions/context.py>`_.
+
+
+Using Notifiers
+^^^^^^^^^^^^^^^
+
+You can use Notifiers in your Dag definition by passing it as an argument to
the ``on_*_callbacks``.
+For example, you can use it with ``on_success_callback`` or
``on_failure_callback`` to send notifications based
+on the status of a task or a Dag run.
+
+Here's an example of using a custom notifier:
+
+.. code-block:: python
+
+ from airflow.sdk import DAG
+ from airflow.providers.standard.operators.bash import BashOperator
+
+ from myprovider.notifier import MyNotifier
+
+ with DAG(
+ dag_id="example_notifier",
+ on_success_callback=MyNotifier(message="Success!"),
+ on_failure_callback=MyNotifier(message="Failure!"),
+ ):
+ task = BashOperator(
+ task_id="example_task",
+ bash_command="exit 1",
+ on_success_callback=MyNotifier(message="Task Succeeded!"),
+ )
+
+For a list of community-managed Notifiers, see
:doc:`apache-airflow-providers:core-extensions/notifications`.
+For more information on writing a custom Notifier, see the :doc:`Notifiers
<../../howto/notifications>` how-to page.
diff --git a/airflow-core/docs/howto/notifications.rst
b/airflow-core/docs/howto/notifications.rst
index 38a545bb4d8..1705ccf25dc 100644
--- a/airflow-core/docs/howto/notifications.rst
+++ b/airflow-core/docs/howto/notifications.rst
@@ -15,8 +15,9 @@
specific language governing permissions and limitations
under the License.
-Creating a notifier
+Creating a Notifier
===================
+
The :class:`~airflow.sdk.definitions.notifier.BaseNotifier` is an abstract
class that provides a basic
structure for sending notifications in Airflow using the various
``on_*__callback``.
It is intended for providers to extend and customize for their specific needs.
@@ -32,49 +33,29 @@ Here's an example of how you can create a Notifier class:
.. code-block:: python
from airflow.sdk import BaseNotifier
- from my_provider import send_message
+ from my_provider import async_send_message, send_message
class MyNotifier(BaseNotifier):
template_fields = ("message",)
- def __init__(self, message):
+ def __init__(self, message: str):
self.message = message
- def notify(self, context):
- # Send notification here, below is an example
+ def notify(self, context: Context) -> None:
+ # Send notification here. For example:
title = f"Task {context['task_instance'].task_id} failed"
send_message(title, self.message)
-Using a notifier
-----------------
-Once you have a notifier implementation, you can use it in your ``DAG``
definition by passing it as an argument to
-the ``on_*_callbacks``. For example, you can use it with
``on_success_callback`` or ``on_failure_callback`` to send
-notifications based on the status of a task or a DAG run.
-
-Here's an example of using the above notifier:
-
-.. code-block:: python
-
- from datetime import datetime
+ async def async_notify(self, context: Context) -> None:
+ # Only required if your Notifier is going to support asynchronous
code. For example:
+ title = f"Task {context['task_instance'].task_id} failed"
+ await async_send_message(title, self.message)
- from airflow.sdk import DAG
- from airflow.providers.standard.operators.bash import BashOperator
- from myprovider.notifier import MyNotifier
+For a list of community-managed notifiers, see
:doc:`apache-airflow-providers:core-extensions/notifications`.
- with DAG(
- dag_id="example_notifier",
- start_date=datetime(2022, 1, 1),
- schedule=None,
- on_success_callback=MyNotifier(message="Success!"),
- on_failure_callback=MyNotifier(message="Failure!"),
- ):
- task = BashOperator(
- task_id="example_task",
- bash_command="exit 1",
- on_success_callback=MyNotifier(message="Task Succeeded!"),
- )
+Using Notifiers
+===============
-For a list of community-managed notifiers, see
-:doc:`apache-airflow-providers:core-extensions/notifications`.
+For using Notifiers in event-based DAG callbacks, see
:doc:`../administration-and-deployment/logging-monitoring/callbacks`.
diff --git a/airflow-core/src/airflow/triggers/deadline.py
b/airflow-core/src/airflow/triggers/deadline.py
index 8b70015c76a..bcff27fd1b2 100644
--- a/airflow-core/src/airflow/triggers/deadline.py
+++ b/airflow-core/src/airflow/triggers/deadline.py
@@ -51,7 +51,11 @@ class DeadlineCallbackTrigger(BaseTrigger):
try:
callback = import_string(self.callback_path)
yield TriggerEvent({PAYLOAD_STATUS_KEY:
DeadlineCallbackState.RUNNING})
- result = await callback(**self.callback_kwargs)
+
+ # TODO: get airflow context
+ context: dict = {}
+
+ result = await callback(**self.callback_kwargs, context=context)
log.info("Deadline callback completed with return value: %s",
result)
yield TriggerEvent({PAYLOAD_STATUS_KEY:
DeadlineCallbackState.SUCCESS, PAYLOAD_BODY_KEY: result})
except Exception as e:
diff --git a/airflow-core/tests/unit/models/test_deadline.py
b/airflow-core/tests/unit/models/test_deadline.py
index 981adf196ba..73b68d76d0c 100644
--- a/airflow-core/tests/unit/models/test_deadline.py
+++ b/airflow-core/tests/unit/models/test_deadline.py
@@ -21,7 +21,6 @@ from unittest import mock
import pytest
import time_machine
-from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from airflow.models import DagRun, Trigger
@@ -36,9 +35,8 @@ from tests_common.test_utils import db
from unit.models import DEFAULT_DATE
DAG_ID = "dag_id_1"
-RUN_ID = 1
INVALID_DAG_ID = "invalid_dag_id"
-INVALID_RUN_ID = 2
+INVALID_RUN_ID = -1
REFERENCE_TYPES = [
pytest.param(DeadlineReference.DAGRUN_LOGICAL_DATE, id="logical_date"),
@@ -77,6 +75,18 @@ def dagrun(session, dag_maker):
return session.query(DagRun).one()
[email protected]
+def deadline_orm(dagrun, session):
+ deadline = Deadline(
+ deadline_time=DEFAULT_DATE,
+ callback=AsyncCallback(TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS),
+ dagrun_id=dagrun.id,
+ )
+ session.add(deadline)
+ session.flush()
+ return deadline
+
+
@pytest.mark.db_test
class TestDeadline:
@staticmethod
@@ -87,42 +97,32 @@ class TestDeadline:
def teardown_method():
_clean_db()
- def test_add_deadline(self, dagrun, session):
- assert session.query(Deadline).count() == 0
- deadline_orm = Deadline(
- deadline_time=DEFAULT_DATE,
- callback=TEST_ASYNC_CALLBACK,
- dagrun_id=dagrun.id,
- )
-
- session.add(deadline_orm)
- session.flush()
-
- assert session.query(Deadline).count() == 1
-
- result = session.scalars(select(Deadline)).first()
- assert result.dagrun_id == deadline_orm.dagrun_id
- assert result.deadline_time == deadline_orm.deadline_time
- assert result.callback == deadline_orm.callback
-
@pytest.mark.parametrize(
"conditions",
[
pytest.param({}, id="empty_conditions"),
- pytest.param({Deadline.dagrun_id: INVALID_RUN_ID},
id="no_matches"),
- pytest.param({Deadline.dagrun_id: RUN_ID}, id="single_condition"),
+ pytest.param({Deadline.dagrun_id: -1}, id="no_matches"),
+ pytest.param({Deadline.dagrun_id: "valid_placeholder"},
id="single_condition"),
pytest.param(
- {Deadline.dagrun_id: RUN_ID, Deadline.deadline_time:
datetime.now() + timedelta(days=365)},
+ {
+ Deadline.dagrun_id: "valid_placeholder",
+ Deadline.deadline_time: datetime.now() +
timedelta(days=365),
+ },
id="multiple_conditions",
),
pytest.param(
- {Deadline.dagrun_id: RUN_ID, Deadline.callback_state:
"invalid"}, id="mixed_conditions"
+ {Deadline.dagrun_id: "valid_placeholder",
Deadline.callback_state: "invalid"},
+ id="mixed_conditions",
),
],
)
@mock.patch("sqlalchemy.orm.Session")
- def test_prune_deadlines(self, mock_session, conditions):
+ def test_prune_deadlines(self, mock_session, conditions, dagrun):
"""Test deadline resolution with various conditions."""
+ if Deadline.dagrun_id in conditions:
+ if conditions[Deadline.dagrun_id] == "valid_placeholder":
+ conditions[Deadline.dagrun_id] = dagrun.id
+
expected_result = 1 if conditions else 0
# Set up the query chain to return a list of (Deadline, DagRun) pairs
mock_dagrun = mock.Mock(spec=DagRun, end_date=datetime.now())
@@ -142,32 +142,13 @@ class TestDeadline:
else:
mock_session.query.assert_not_called()
- def test_orm(self):
- deadline_orm = Deadline(
- deadline_time=DEFAULT_DATE,
- callback=TEST_ASYNC_CALLBACK,
- dagrun_id=RUN_ID,
- )
-
- assert deadline_orm.deadline_time == DEFAULT_DATE
- assert deadline_orm.callback == TEST_ASYNC_CALLBACK
- assert deadline_orm.dagrun_id == RUN_ID
-
- def test_repr_with_callback_kwargs(self, dagrun, session):
- deadline_orm = Deadline(
- deadline_time=DEFAULT_DATE,
- callback=TEST_ASYNC_CALLBACK,
- dagrun_id=dagrun.id,
- )
- session.add(deadline_orm)
- session.flush()
-
+ def test_repr_with_callback_kwargs(self, deadline_orm, dagrun):
assert (
repr(deadline_orm) == f"[DagRun Deadline] Dag: {DAG_ID} Run:
{dagrun.id} needed by "
- f"{deadline_orm.deadline_time} or run:
{TEST_CALLBACK_PATH}({TEST_CALLBACK_KWARGS})"
+ f"{DEFAULT_DATE} or run:
{TEST_CALLBACK_PATH}({TEST_CALLBACK_KWARGS})"
)
- def test_repr_without_callback_kwargs(self, dagrun, session):
+ def test_repr_without_callback_kwargs(self, deadline_orm, dagrun, session):
deadline_orm = Deadline(
deadline_time=DEFAULT_DATE,
callback=AsyncCallback(TEST_CALLBACK_PATH),
@@ -179,19 +160,11 @@ class TestDeadline:
assert deadline_orm.callback.kwargs is None
assert (
repr(deadline_orm) == f"[DagRun Deadline] Dag: {DAG_ID} Run:
{dagrun.id} needed by "
- f"{deadline_orm.deadline_time} or run: {TEST_CALLBACK_PATH}()"
+ f"{DEFAULT_DATE} or run: {TEST_CALLBACK_PATH}()"
)
@pytest.mark.db_test
- def test_handle_miss_async_callback(self, dagrun, session):
- deadline_orm = Deadline(
- deadline_time=DEFAULT_DATE,
- callback=TEST_ASYNC_CALLBACK,
- dagrun_id=dagrun.id,
- )
- session.add(deadline_orm)
- session.flush()
-
+ def test_handle_miss_async_callback(self, dagrun, deadline_orm, session):
deadline_orm.handle_miss(session=session)
session.flush()
@@ -248,15 +221,7 @@ class TestDeadline:
pytest.param(TriggerEvent({PAYLOAD_STATUS_KEY: "unknown_state"}),
False, id="unknown_event"),
],
)
- def test_handle_callback_event(self, dagrun, session, event,
none_trigger_expected):
- deadline_orm = Deadline(
- deadline_time=DEFAULT_DATE,
- callback=TEST_ASYNC_CALLBACK,
- dagrun_id=dagrun.id,
- )
- session.add(deadline_orm)
- session.flush()
-
+ def test_handle_callback_event(self, dagrun, deadline_orm, session, event,
none_trigger_expected):
deadline_orm.handle_miss(session=session)
session.flush()
@@ -271,6 +236,26 @@ class TestDeadline:
else:
assert deadline_orm.callback_state == DeadlineCallbackState.QUEUED
+ def test_handle_miss_creates_trigger(self, dagrun, deadline_orm, session):
+ """Test that handle_miss creates a trigger with correct parameters."""
+ deadline_orm.handle_miss(session)
+ session.flush()
+
+ # Check trigger was created
+ trigger = session.query(Trigger).first()
+ assert trigger is not None
+ assert deadline_orm.trigger_id == trigger.id
+
+ # Check trigger has correct kwargs
+ assert trigger.kwargs["callback_path"] == TEST_CALLBACK_PATH
+ assert trigger.kwargs["callback_kwargs"] == TEST_CALLBACK_KWARGS
+
+ def test_handle_miss_sets_callback_state(self, dagrun, deadline_orm,
session):
+ """Test that handle_miss sets the callback state to QUEUED."""
+ deadline_orm.handle_miss(session)
+
+ assert deadline_orm.callback_state == DeadlineCallbackState.QUEUED
+
@pytest.mark.db_test
class TestCalculatedDeadlineDatabaseCalls:
diff --git a/airflow-core/tests/unit/triggers/test_deadline.py
b/airflow-core/tests/unit/triggers/test_deadline.py
index 137f40e63e7..72bea33f188 100644
--- a/airflow-core/tests/unit/triggers/test_deadline.py
+++ b/airflow-core/tests/unit/triggers/test_deadline.py
@@ -22,13 +22,29 @@ from unittest import mock
import pytest
from airflow.models.deadline import DeadlineCallbackState
+from airflow.sdk import BaseNotifier
from airflow.triggers.deadline import PAYLOAD_BODY_KEY, PAYLOAD_STATUS_KEY,
DeadlineCallbackTrigger
+TEST_MESSAGE = "test_message"
TEST_CALLBACK_PATH = "classpath.test_callback_for_deadline"
-TEST_CALLBACK_KWARGS = {"arg1": "value1"}
+TEST_CALLBACK_KWARGS = {"message": TEST_MESSAGE}
TEST_TRIGGER = DeadlineCallbackTrigger(callback_path=TEST_CALLBACK_PATH,
callback_kwargs=TEST_CALLBACK_KWARGS)
+class ExampleAsyncNotifier(BaseNotifier):
+ """Example of a properly implemented async notifier."""
+
+ def __init__(self, message, **kwargs):
+ super().__init__(**kwargs)
+ self.message = message
+
+ async def async_notify(self, context):
+ return f"Async notification: {self.message}, context: {context}"
+
+ def notify(self, context):
+ return f"Sync notification: {self.message}, context: {context}"
+
+
class TestDeadlineCallbackTrigger:
@pytest.fixture
def mock_import_string(self):
@@ -56,7 +72,8 @@ class TestDeadlineCallbackTrigger:
}
@pytest.mark.asyncio
- async def test_run_success(self, mock_import_string):
+ async def test_run_success_with_async_function(self, mock_import_string):
+ """Test trigger handles async functions correctly."""
callback_return_value = "some value"
mock_callback = mock.AsyncMock(return_value=callback_return_value)
mock_import_string.return_value = mock_callback
@@ -68,10 +85,25 @@ class TestDeadlineCallbackTrigger:
success_event = await anext(trigger_gen)
mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
- mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS)
+ mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS,
context=mock.ANY)
assert success_event.payload[PAYLOAD_STATUS_KEY] ==
DeadlineCallbackState.SUCCESS
assert success_event.payload[PAYLOAD_BODY_KEY] == callback_return_value
+ @pytest.mark.asyncio
+ async def test_run_success_with_notifier(self, mock_import_string):
+ """Test trigger handles async notifier classes correctly."""
+ mock_import_string.return_value = ExampleAsyncNotifier
+
+ trigger_gen = TEST_TRIGGER.run()
+
+ running_event = await anext(trigger_gen)
+ assert running_event.payload[PAYLOAD_STATUS_KEY] ==
DeadlineCallbackState.RUNNING
+
+ success_event = await anext(trigger_gen)
+ mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
+ assert success_event.payload[PAYLOAD_STATUS_KEY] ==
DeadlineCallbackState.SUCCESS
+ assert success_event.payload[PAYLOAD_BODY_KEY] == f"Async
notification: {TEST_MESSAGE}, context: {{}}"
+
@pytest.mark.asyncio
async def test_run_failure(self, mock_import_string):
exc_msg = "Something went wrong"
@@ -85,6 +117,6 @@ class TestDeadlineCallbackTrigger:
failure_event = await anext(trigger_gen)
mock_import_string.assert_called_once_with(TEST_CALLBACK_PATH)
- mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS)
+ mock_callback.assert_called_once_with(**TEST_CALLBACK_KWARGS,
context=mock.ANY)
assert failure_event.payload[PAYLOAD_STATUS_KEY] ==
DeadlineCallbackState.FAILED
assert all(s in failure_event.payload[PAYLOAD_BODY_KEY] for s in
["raise", "RuntimeError", exc_msg])
diff --git a/devel-common/src/tests_common/pytest_plugin.py
b/devel-common/src/tests_common/pytest_plugin.py
index 848a633403e..dc98efe31e7 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -2127,7 +2127,10 @@ def sdk_connection_not_found(mock_supervisor_comms):
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import ErrorResponse
- mock_supervisor_comms.send.return_value =
ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND)
+ error_response = ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND)
+ mock_supervisor_comms.send.return_value = error_response
+ if hasattr(mock_supervisor_comms, "asend"):
+ mock_supervisor_comms.asend.return_value = error_response
yield mock_supervisor_comms
diff --git a/providers/slack/src/airflow/providers/slack/hooks/slack_webhook.py
b/providers/slack/src/airflow/providers/slack/hooks/slack_webhook.py
index 2e6a0f74ae1..9c2260b2936 100644
--- a/providers/slack/src/airflow/providers/slack/hooks/slack_webhook.py
+++ b/providers/slack/src/airflow/providers/slack/hooks/slack_webhook.py
@@ -24,6 +24,7 @@ from functools import cached_property, wraps
from typing import TYPE_CHECKING, Any
from slack_sdk import WebhookClient
+from slack_sdk.webhook.async_client import AsyncWebhookClient
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.slack.utils import ConnectionExtraConfig
@@ -35,17 +36,34 @@ if TYPE_CHECKING:
LEGACY_INTEGRATION_PARAMS = ("channel", "username", "icon_emoji", "icon_url")
+def _validate_response(resp):
+ """Validate webhook response and raise error if status code != 200."""
+ if resp.status_code != 200:
+ raise AirflowException(
+ f"Response body: {resp.body!r}, Status Code: {resp.status_code}. "
+ "See: https://api.slack.com/messaging/webhooks#handling_errors"
+ )
+
+
def check_webhook_response(func: Callable) -> Callable:
"""Check WebhookResponse and raise an error if status code != 200."""
@wraps(func)
def wrapper(*args, **kwargs) -> Callable:
resp = func(*args, **kwargs)
- if resp.status_code != 200:
- raise AirflowException(
- f"Response body: {resp.body!r}, Status Code:
{resp.status_code}. "
- "See: https://api.slack.com/messaging/webhooks#handling_errors"
- )
+ _validate_response(resp)
+ return resp
+
+ return wrapper
+
+
+def async_check_webhook_response(func: Callable) -> Callable:
+ """Check WebhookResponse and raise an error if status code != 200
(async)."""
+
+ @wraps(func)
+ async def wrapper(*args, **kwargs) -> Callable:
+ resp = await func(*args, **kwargs)
+ _validate_response(resp)
return resp
return wrapper
@@ -134,13 +152,27 @@ class SlackWebhookHook(BaseHook):
"""Get the underlying slack_sdk.webhook.WebhookClient (cached)."""
return WebhookClient(**self._get_conn_params())
+ @cached_property
+ async def async_client(self) -> AsyncWebhookClient:
+ """Get the underlying
`slack_sdk.webhook.async_client.AsyncWebhookClient` (cached)."""
+ return AsyncWebhookClient(**await self._async_get_conn_params())
+
def get_conn(self) -> WebhookClient:
- """Get the underlying slack_sdk.webhook.WebhookClient (cached)."""
+ """Get the underlying `slack_sdk.webhook.WebhookClient` (cached)."""
return self.client
def _get_conn_params(self) -> dict[str, Any]:
"""Fetch connection params as a dict and merge it with hook
parameters."""
conn = self.get_connection(self.slack_webhook_conn_id)
+ return self._build_conn_params(conn)
+
+ async def _async_get_conn_params(self) -> dict[str, Any]:
+ """Fetch connection params as a dict and merge it with hook parameters
(async)."""
+ conn = await self.aget_connection(self.slack_webhook_conn_id)
+ return self._build_conn_params(conn)
+
+ def _build_conn_params(self, conn) -> dict[str, Any]:
+ """Build connection parameters from connection object."""
if not conn.password or not conn.password.strip():
raise AirflowNotFoundException(
f"Connection ID {self.slack_webhook_conn_id!r} does not
contain password "
@@ -173,14 +205,8 @@ class SlackWebhookHook(BaseHook):
conn_params.update(self.extra_client_args)
return {k: v for k, v in conn_params.items() if v is not None}
- @check_webhook_response
- def send_dict(self, body: dict[str, Any] | str, *, headers: dict[str, str]
| None = None):
- """
- Perform a Slack Incoming Webhook request with given JSON data block.
-
- :param body: JSON data structure, expected dict or JSON-string.
- :param headers: Request headers for this request.
- """
+ def _process_body(self, body: dict[str, Any] | str) -> dict[str, Any]:
+ """Validate and process the request body."""
if isinstance(body, str):
try:
body = json.loads(body)
@@ -203,9 +229,31 @@ class SlackWebhookHook(BaseHook):
UserWarning,
stacklevel=2,
)
+ return body
+ @check_webhook_response
+ def send_dict(self, body: dict[str, Any] | str, *, headers: dict[str, str]
| None = None):
+ """
+ Perform a Slack Incoming Webhook request with given JSON data block.
+
+ :param body: JSON data structure, expected dict or JSON-string.
+ :param headers: Request headers for this request.
+ """
+ body = self._process_body(body)
return self.client.send_dict(body, headers=headers)
+ @async_check_webhook_response
+ async def async_send_dict(self, body: dict[str, Any] | str, *, headers:
dict[str, str] | None = None):
+ """
+ Perform a Slack Incoming Webhook request with given JSON data block
(async).
+
+ :param body: JSON data structure, expected dict or JSON-string.
+ :param headers: Request headers for this request.
+ """
+ body = self._process_body(body)
+ async_client = await self.async_client
+ return await async_client.send_dict(body, headers=headers)
+
def send(
self,
*,
@@ -235,20 +283,69 @@ class SlackWebhookHook(BaseHook):
:param attachments: (legacy) A collection of attachments.
"""
body = {
- "text": text,
- "attachments": attachments,
- "blocks": blocks,
- "response_type": response_type,
- "replace_original": replace_original,
- "delete_original": delete_original,
- "unfurl_links": unfurl_links,
- "unfurl_media": unfurl_media,
- # Legacy Integration Parameters
- **kwargs,
+ k: v
+ for k, v in {
+ "text": text,
+ "attachments": attachments,
+ "blocks": blocks,
+ "response_type": response_type,
+ "replace_original": replace_original,
+ "delete_original": delete_original,
+ "unfurl_links": unfurl_links,
+ "unfurl_media": unfurl_media,
+ # Legacy Integration Parameters
+ **kwargs,
+ }.items()
+ if v is not None
}
- body = {k: v for k, v in body.items() if v is not None}
return self.send_dict(body=body, headers=headers)
+ async def async_send(
+ self,
+ *,
+ text: str | None = None,
+ blocks: list[dict[str, Any]] | None = None,
+ response_type: str | None = None,
+ replace_original: bool | None = None,
+ delete_original: bool | None = None,
+ unfurl_links: bool | None = None,
+ unfurl_media: bool | None = None,
+ headers: dict[str, str] | None = None,
+ attachments: list[dict[str, Any]] | None = None,
+ **kwargs,
+ ):
+ """
+ Perform a Slack Incoming Webhook request with given arguments (async).
+
+ :param text: The text message
+ (even when having blocks, setting this as well is recommended as
it works as fallback).
+ :param blocks: A collection of Block Kit UI components.
+ :param response_type: The type of message (either 'in_channel' or
'ephemeral').
+ :param replace_original: True if you use this option for response_url
requests.
+ :param delete_original: True if you use this option for response_url
requests.
+ :param unfurl_links: Option to indicate whether text url should unfurl.
+ :param unfurl_media: Option to indicate whether media url should
unfurl.
+ :param headers: Request headers for this request.
+ :param attachments: (legacy) A collection of attachments.
+ """
+ body = {
+ k: v
+ for k, v in {
+ "text": text,
+ "attachments": attachments,
+ "blocks": blocks,
+ "response_type": response_type,
+ "replace_original": replace_original,
+ "delete_original": delete_original,
+ "unfurl_links": unfurl_links,
+ "unfurl_media": unfurl_media,
+ # Legacy Integration Parameters
+ **kwargs,
+ }.items()
+ if v is not None
+ }
+ return await self.async_send_dict(body=body, headers=headers)
+
def send_text(
self,
text: str,
diff --git
a/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py
b/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py
index 36b7ccd851d..d1125224ecf 100644
--- a/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py
+++ b/providers/slack/src/airflow/providers/slack/notifications/slack_webhook.py
@@ -62,8 +62,9 @@ class SlackWebhookNotifier(BaseNotifier):
timeout: int | None = None,
attachments: list | None = None,
retry_handlers: list[RetryHandler] | None = None,
+ **kwargs,
):
- super().__init__()
+ super().__init__(**kwargs)
self.slack_webhook_conn_id = slack_webhook_conn_id
self.text = text
self.attachments = attachments
@@ -86,13 +87,29 @@ class SlackWebhookNotifier(BaseNotifier):
def notify(self, context):
"""Send a message to a Slack Incoming Webhook."""
- self.hook.send(
+ resp = self.hook.send(
text=self.text,
blocks=self.blocks,
unfurl_links=self.unfurl_links,
unfurl_media=self.unfurl_media,
attachments=self.attachments,
)
+ self.log.debug(
+ "Slack webhook notification sent using notify(): %s %s",
resp.status_code, resp.api_url
+ )
+
+ async def async_notify(self, context):
+ """Send a message to a Slack Incoming Webhook (async)."""
+ resp = await self.hook.async_send(
+ text=self.text,
+ blocks=self.blocks,
+ unfurl_links=self.unfurl_links,
+ unfurl_media=self.unfurl_media,
+ attachments=self.attachments,
+ )
+ self.log.debug(
+ "Slack webhook notification sent using async_notify(): %s %s",
resp.status_code, resp.api_url
+ )
send_slack_webhook_notification = SlackWebhookNotifier
diff --git a/providers/slack/tests/unit/slack/hooks/test_slack_webhook.py
b/providers/slack/tests/unit/slack/hooks/test_slack_webhook.py
index 8dfa021b911..43c97045b40 100644
--- a/providers/slack/tests/unit/slack/hooks/test_slack_webhook.py
+++ b/providers/slack/tests/unit/slack/hooks/test_slack_webhook.py
@@ -27,11 +27,16 @@ from unittest.mock import patch
import pytest
from slack_sdk.http_retry.builtin_handlers import ConnectionErrorRetryHandler,
RateLimitErrorRetryHandler
+from slack_sdk.webhook.async_client import AsyncWebhookClient
from slack_sdk.webhook.webhook_response import WebhookResponse
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.models.connection import Connection
-from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook,
check_webhook_response
+from airflow.providers.slack.hooks.slack_webhook import (
+ SlackWebhookHook,
+ async_check_webhook_response,
+ check_webhook_response,
+)
TEST_TOKEN = "T00000000/B00000000/XXXXXXXXXXXXXXXXXXXXXXXX"
TEST_WEBHOOK_URL = f"https://hooks.slack.com/services/{TEST_TOKEN}"
@@ -172,6 +177,42 @@ class TestCheckWebhookResponseDecorator:
assert decorated()
+class TestAsyncCheckWebhookResponseDecorator:
+ @pytest.mark.asyncio
+ async def test_ok_response(self):
+ """Test async decorator with OK response."""
+
+ @async_check_webhook_response
+ async def decorated():
+ return MOCK_WEBHOOK_RESPONSE
+
+ assert await decorated() is MOCK_WEBHOOK_RESPONSE
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "status_code,body",
+ [
+ (400, "invalid_payload"),
+ (403, "action_prohibited"),
+ (404, "channel_not_found"),
+ (410, "channel_is_archived"),
+ (500, "rollup_error"),
+ (418, "i_am_teapot"),
+ ],
+ )
+ async def test_error_response(self, status_code, body):
+ """Test async decorator with error response."""
+ test_response = WebhookResponse(url="foo://bar",
status_code=status_code, body=body, headers={})
+
+ @async_check_webhook_response
+ async def decorated():
+ return test_response
+
+ error_message = rf"Response body: '{body}', Status Code:
{status_code}\."
+ with pytest.raises(AirflowException, match=error_message):
+ await decorated()
+
+
class TestSlackWebhookHook:
@pytest.mark.parametrize(
"conn_id",
@@ -432,6 +473,7 @@ class TestSlackWebhookHook:
{"text": "Test Text"},
{"text": "Fallback Text", "blocks": ["Dummy Block"]},
{"text": "Fallback Text", "blocks": ["Dummy Block"],
"unfurl_media": True, "unfurl_links": True},
+ {"legacy": "value"},
],
)
@mock.patch("airflow.providers.slack.hooks.slack_webhook.SlackWebhookHook.send_dict")
@@ -503,3 +545,92 @@ class TestSlackWebhookHook:
hook = SlackWebhookHook(slack_webhook_conn_id="my_conn")
params = hook._get_conn_params()
assert "proxy" not in params
+
+
+class TestSlackWebhookHookAsync:
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.slack.hooks.slack_webhook.SlackWebhookHook._async_get_conn_params")
+ async def test_async_client(self, mock_async_get_conn_params):
+ """Test async_client property creates AsyncWebhookClient with correct
params."""
+ mock_async_get_conn_params.return_value = {"url": TEST_WEBHOOK_URL}
+
+ hook = SlackWebhookHook(slack_webhook_conn_id=TEST_CONN_ID)
+ client = await hook.async_client
+
+ assert isinstance(client, AsyncWebhookClient)
+ assert client.url == TEST_WEBHOOK_URL
+ mock_async_get_conn_params.assert_called_once()
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("headers", [None, {"User-Agent": "Airflow"}])
+ @pytest.mark.parametrize(
+ "send_body",
+ [
+ {"text": "Test Text"},
+ {"text": "Fallback Text", "blocks": ["Dummy Block"]},
+ {"text": "Fallback Text", "blocks": ["Dummy Block"],
"unfurl_media": True, "unfurl_links": True},
+ ],
+ )
+
@mock.patch("airflow.providers.slack.hooks.slack_webhook.AsyncWebhookClient")
+
@mock.patch("airflow.providers.slack.hooks.slack_webhook.SlackWebhookHook._async_get_conn_params")
+ async def test_async_send_dict(
+ self, mock_async_get_conn_params, mock_async_webhook_client_cls,
send_body, headers
+ ):
+ """Test async_send_dict method with dict input."""
+ mock_async_get_conn_params.return_value = {"url": TEST_WEBHOOK_URL}
+ mock_async_client = mock_async_webhook_client_cls.return_value
+ mock_async_client.send_dict =
mock.AsyncMock(return_value=MOCK_WEBHOOK_RESPONSE)
+
+ hook = SlackWebhookHook(slack_webhook_conn_id=TEST_CONN_ID)
+ resp = await hook.async_send_dict(body=send_body, headers=headers)
+
+ assert resp == MOCK_WEBHOOK_RESPONSE
+ mock_async_client.send_dict.assert_called_once_with(send_body,
headers=headers)
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("headers", [None, {"User-Agent": "Airflow"}])
+ @pytest.mark.parametrize(
+ "send_body",
+ [
+ {"text": "Test Text"},
+ {"text": "Fallback Text", "blocks": ["Dummy Block"]},
+ {"text": "Fallback Text", "blocks": ["Dummy Block"],
"unfurl_media": True, "unfurl_links": True},
+ ],
+ )
+
@mock.patch("airflow.providers.slack.hooks.slack_webhook.AsyncWebhookClient")
+
@mock.patch("airflow.providers.slack.hooks.slack_webhook.SlackWebhookHook._async_get_conn_params")
+ async def test_async_send_dict_json_string(
+ self, mock_async_get_conn_params, mock_async_webhook_client_cls,
send_body, headers
+ ):
+ """Test async_send_dict method with JSON string input."""
+ mock_async_get_conn_params.return_value = {"url": TEST_WEBHOOK_URL}
+ mock_async_client = mock_async_webhook_client_cls.return_value
+ mock_async_client.send_dict =
mock.AsyncMock(return_value=MOCK_WEBHOOK_RESPONSE)
+
+ hook = SlackWebhookHook(slack_webhook_conn_id=TEST_CONN_ID)
+ resp = await hook.async_send_dict(body=json.dumps(send_body),
headers=headers)
+
+ assert resp == MOCK_WEBHOOK_RESPONSE
+ mock_async_client.send_dict.assert_called_once_with(send_body,
headers=headers)
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("headers", [None, {"User-Agent": "Airflow"}])
+ @pytest.mark.parametrize(
+ "send_params",
+ [
+ {"text": "Test Text"},
+ {"text": "Fallback Text", "blocks": ["Dummy Block"]},
+ {"text": "Fallback Text", "blocks": ["Dummy Block"],
"unfurl_media": True, "unfurl_links": True},
+ {"legacy": "value"},
+ ],
+ )
+
@mock.patch("airflow.providers.slack.hooks.slack_webhook.SlackWebhookHook.async_send_dict")
+ async def test_async_send(self, mock_async_send_dict, send_params,
headers):
+ """Test async_send method."""
+ mock_async_send_dict.return_value = MOCK_WEBHOOK_RESPONSE
+
+ hook = SlackWebhookHook(slack_webhook_conn_id=TEST_CONN_ID)
+ resp = await hook.async_send(**send_params, headers=headers)
+
+ assert resp == MOCK_WEBHOOK_RESPONSE
+ mock_async_send_dict.assert_called_once_with(body=send_params,
headers=headers)
diff --git
a/providers/slack/tests/unit/slack/notifications/test_slack_webhook.py
b/providers/slack/tests/unit/slack/notifications/test_slack_webhook.py
index 7897ef2674b..ed12eb5794b 100644
--- a/providers/slack/tests/unit/slack/notifications/test_slack_webhook.py
+++ b/providers/slack/tests/unit/slack/notifications/test_slack_webhook.py
@@ -65,6 +65,31 @@ class TestSlackNotifier:
)
mock_slack_hook.assert_called_once_with(slack_webhook_conn_id="test_conn_id",
**hook_extra_kwargs)
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.slack.notifications.slack_webhook.SlackWebhookHook")
+ async def test_async_slack_webhook_notifier(self, mock_slack_hook):
+ mock_hook = mock_slack_hook.return_value
+ mock_hook.async_send = mock.AsyncMock()
+
+ notifier = send_slack_webhook_notification(
+ slack_webhook_conn_id="test_conn_id",
+ text="foo-bar",
+ blocks="spam-egg",
+ attachments="baz-qux",
+ unfurl_links=True,
+ unfurl_media=False,
+ )
+
+ await notifier.async_notify({})
+
+ mock_hook.async_send.assert_called_once_with(
+ text="foo-bar",
+ blocks="spam-egg",
+ unfurl_links=True,
+ unfurl_media=False,
+ attachments="baz-qux",
+ )
+
@mock.patch("airflow.providers.slack.notifications.slack_webhook.SlackWebhookHook")
def test_slack_webhook_templated(self, mock_slack_hook,
create_dag_without_db):
notifier = send_slack_webhook_notification(
@@ -90,3 +115,48 @@ class TestSlackNotifier:
unfurl_links=None,
unfurl_media=None,
)
+
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.slack.notifications.slack_webhook.SlackWebhookHook")
+ async def test_async_slack_webhook_templated(self, mock_slack_hook,
create_dag_without_db):
+ """Test async notification with template rendering."""
+ mock_hook = mock_slack_hook.return_value
+ mock_hook.async_send = mock.AsyncMock()
+
+ notifier = send_slack_webhook_notification(
+ text="Who am I? {{ username }}",
+ blocks=[{"type": "header", "text": {"type": "plain_text", "text":
"{{ dag.dag_id }}"}}],
+ attachments=[{"image_url": "{{ dag.dag_id }}.png"}],
+ )
+
+ # Call notifier first to handle template rendering
+ notifier(
+ {
+ "dag":
create_dag_without_db("test_async_send_slack_webhook_notification_templated"),
+ "username": "not-a-root",
+ }
+ )
+
+ # Then call async_notify with rendered templates
+ await notifier.async_notify(
+ {
+ "dag":
create_dag_without_db("test_async_send_slack_webhook_notification_templated"),
+ "username": "not-a-root",
+ }
+ )
+
+ mock_hook.async_send.assert_called_once_with(
+ text="Who am I? not-a-root",
+ blocks=[
+ {
+ "type": "header",
+ "text": {
+ "type": "plain_text",
+ "text":
"test_async_send_slack_webhook_notification_templated",
+ },
+ }
+ ],
+ attachments=[{"image_url":
"test_async_send_slack_webhook_notification_templated.png"}],
+ unfurl_links=None,
+ unfurl_media=None,
+ )
diff --git a/task-sdk/src/airflow/sdk/bases/hook.py
b/task-sdk/src/airflow/sdk/bases/hook.py
index 2de50c443c5..8aaba17ca78 100644
--- a/task-sdk/src/airflow/sdk/bases/hook.py
+++ b/task-sdk/src/airflow/sdk/bases/hook.py
@@ -62,6 +62,20 @@ class BaseHook(LoggingMixin):
log.debug("Connection Retrieved '%s' (via task-sdk)", conn.conn_id)
return conn
+ @classmethod
+ async def aget_connection(cls, conn_id: str) -> Connection:
+ """
+ Get connection (async), given connection id.
+
+ :param conn_id: connection id
+ :return: connection
+ """
+ from airflow.sdk.definitions.connection import Connection
+
+ conn = await Connection.async_get(conn_id)
+ log.debug("Connection Retrieved '%s' (via task-sdk)", conn.conn_id)
+ return conn
+
@classmethod
def get_hook(cls, conn_id: str, hook_params: dict | None = None):
"""
diff --git a/task-sdk/src/airflow/sdk/bases/notifier.py
b/task-sdk/src/airflow/sdk/bases/notifier.py
index df4023d043a..6772e406f0b 100644
--- a/task-sdk/src/airflow/sdk/bases/notifier.py
+++ b/task-sdk/src/airflow/sdk/bases/notifier.py
@@ -17,8 +17,7 @@
from __future__ import annotations
-from abc import abstractmethod
-from collections.abc import Sequence
+from collections.abc import Generator, Sequence
from typing import TYPE_CHECKING
from airflow.sdk.definitions._internal.templater import Templater
@@ -33,13 +32,32 @@ if TYPE_CHECKING:
class BaseNotifier(LoggingMixin, Templater):
- """BaseNotifier class for sending notifications."""
+ """
+ BaseNotifier class for sending notifications.
+
+ It can be used asynchronously (preferred) if `async_notify` is implemented
and/or
+ synchronously if `notify` is implemented.
+
+ Currently, the DAG/Task state change callbacks run on the DAG Processor
and only support sync usage.
+
+ Usage::
+ # Asynchronous usage
+ await Notifier(context)
+
+ # Synchronous usage
+ notifier = Notifier()
+ notifier(context)
+ """
template_fields: Sequence[str] = ()
template_ext: Sequence[str] = ()
- def __init__(self):
+ # Context stored as attribute here because parameters can't be passed to
__await__
+ context: Context
+
+ def __init__(self, context: Context | None = None):
super().__init__()
+ self.context = context or {}
self.resolve_template_files()
def _update_context(self, context: Context) -> Context:
@@ -53,7 +71,7 @@ class BaseNotifier(LoggingMixin, Templater):
return context
def _render(self, template, context, dag: DAG | None = None):
- dag = dag or context["dag"]
+ dag = dag or context.get("dag")
return super()._render(template, context, dag)
def render_template_fields(
@@ -69,19 +87,34 @@ class BaseNotifier(LoggingMixin, Templater):
:param context: Context dict with values to apply on content.
:param jinja_env: Jinja environment to use for rendering.
"""
- dag = context["dag"]
+ dag = context.get("dag")
if not jinja_env:
jinja_env = self.get_template_env(dag=dag)
self._do_render_template_fields(self, self.template_fields, context,
jinja_env, set())
- @abstractmethod
+ async def async_notify(self, context: Context) -> None:
+ """
+ Send a notification (async).
+
+ Implementing this is a requirement for running this notifier in the
triggerer, which is the
+ recommended approach for using Deadline Alerts.
+
+ :param context: The airflow context
+
+ Note: the context is not available in the current version.
+ """
+ raise NotImplementedError
+
def notify(self, context: Context) -> None:
"""
- Send a notification.
+ Send a notification (sync).
+
+ Implementing this is a requirement for running this notifier in the
DAG processor, which is where the
+ `on_success_callback` and `on_failure_callback` run.
:param context: The airflow context
"""
- ...
+ raise NotImplementedError
def __call__(self, *args) -> None:
"""
@@ -104,4 +137,19 @@ class BaseNotifier(LoggingMixin, Templater):
try:
self.notify(context)
except Exception as e:
- self.log.exception("Failed to send notification: %s", e)
+ self.log.error("Failed to send notification (sync): %s", e)
+ raise
+
+ def __await__(self) -> Generator:
+ """
+ Make the notifier awaitable.
+
+ Context must be provided as an attribute.
+ """
+ self._update_context(self.context)
+ self.render_template_fields(self.context)
+ try:
+ return self.async_notify(self.context).__await__()
+ except Exception as e:
+ self.log.error("Failed to send notification (async): %s", e)
+ raise
diff --git a/task-sdk/src/airflow/sdk/definitions/connection.py
b/task-sdk/src/airflow/sdk/definitions/connection.py
index 39ea645395a..cc2e92a41aa 100644
--- a/task-sdk/src/airflow/sdk/definitions/connection.py
+++ b/task-sdk/src/airflow/sdk/definitions/connection.py
@@ -188,6 +188,13 @@ class Connection:
hook_params = {}
return hook_class(**{hook.connection_id_attribute_name: self.conn_id},
**hook_params)
+ @classmethod
+ def _handle_connection_error(cls, e: AirflowRuntimeError, conn_id: str) ->
None:
+ """Handle connection retrieval errors."""
+ if e.error.error == ErrorType.CONNECTION_NOT_FOUND:
+ raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't
defined") from None
+ raise
+
@classmethod
def get(cls, conn_id: str) -> Any:
from airflow.sdk.execution_time.context import _get_connection
@@ -195,9 +202,16 @@ class Connection:
try:
return _get_connection(conn_id)
except AirflowRuntimeError as e:
- if e.error.error == ErrorType.CONNECTION_NOT_FOUND:
- raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't
defined") from None
- raise
+ cls._handle_connection_error(e, conn_id)
+
+ @classmethod
+ async def async_get(cls, conn_id: str) -> Any:
+ from airflow.sdk.execution_time.context import _async_get_connection
+
+ try:
+ return await _async_get_connection(conn_id)
+ except AirflowRuntimeError as e:
+ cls._handle_connection_error(e, conn_id)
@property
def extra_dejson(self) -> dict:
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index f2ee0672299..27e4b0a754e 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -206,6 +206,10 @@ class CommsDecoder(Generic[ReceiveMsgType, SendMsgType]):
return self._get_response()
+ async def asend(self, msg: SendMsgType) -> ReceiveMsgType | None:
+ """Send a request to the parent without blocking."""
+ raise NotImplementedError
+
@overload
def _read_frame(self, maxfds: None = None) -> _ResponseFrame: ...
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py
b/task-sdk/src/airflow/sdk/execution_time/context.py
index caf586cb018..570cd25d9a3 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -56,6 +56,7 @@ if TYPE_CHECKING:
ConnectionResult,
OKResponse,
PrevSuccessfulDagRunResponse,
+ ReceiveMsgType,
VariableResult,
)
from airflow.sdk.types import OutletEventAccessorsProtocol
@@ -101,8 +102,15 @@ log = structlog.get_logger(logger_name="task")
T = TypeVar("T")
-def _convert_connection_result_conn(conn_result: ConnectionResult) ->
Connection:
+def _process_connection_result_conn(conn_result: ReceiveMsgType | None) ->
Connection:
from airflow.sdk.definitions.connection import Connection
+ from airflow.sdk.execution_time.comms import ErrorResponse
+
+ if isinstance(conn_result, ErrorResponse):
+ raise AirflowRuntimeError(conn_result)
+
+ if TYPE_CHECKING:
+ assert isinstance(conn_result, ConnectionResult)
# `by_alias=True` is used to convert the `schema` field to `schema_` in
the Connection model
return Connection(**conn_result.model_dump(exclude={"type"},
by_alias=True))
@@ -121,7 +129,7 @@ def _convert_variable_result_to_variable(var_result:
VariableResult, deserialize
def _get_connection(conn_id: str) -> Connection:
from airflow.sdk.execution_time.supervisor import
ensure_secrets_backend_loaded
- # TODO: check cache first
+ # TODO: check cache first (also in _async_get_connection)
# enabled only if SecretCache.init() has been called first
# iterate over configured backends if not in cache (or expired)
@@ -154,17 +162,24 @@ def _get_connection(conn_id: str) -> Connection:
# A reason to not move it to `airflow.sdk.execution_time.comms` is that
it
# will make that module depend on Task SDK, which is not ideal because
we intend to
# keep Task SDK as a separate package than execution time mods.
- from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection
+ # Also applies to _async_get_connection.
+ from airflow.sdk.execution_time.comms import GetConnection
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
msg = SUPERVISOR_COMMS.send(GetConnection(conn_id=conn_id))
- if isinstance(msg, ErrorResponse):
- raise AirflowRuntimeError(msg)
+ return _process_connection_result_conn(msg)
- if TYPE_CHECKING:
- assert isinstance(msg, ConnectionResult)
- return _convert_connection_result_conn(msg)
+
+async def _async_get_connection(conn_id: str) -> Connection:
+ # TODO: add async support for secrets backends
+
+ from airflow.sdk.execution_time.comms import GetConnection
+ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+ msg = await SUPERVISOR_COMMS.asend(GetConnection(conn_id=conn_id))
+
+ return _process_connection_result_conn(msg)
def _get_variable(key: str, deserialize_json: bool) -> Any:
diff --git a/task-sdk/tests/task_sdk/bases/test_hook.py
b/task-sdk/tests/task_sdk/bases/test_hook.py
index 4b15ab5e013..4c691266f82 100644
--- a/task-sdk/tests/task_sdk/bases/test_hook.py
+++ b/task-sdk/tests/task_sdk/bases/test_hook.py
@@ -59,6 +59,28 @@ class TestBaseHook:
msg=GetConnection(conn_id="test_conn"),
)
+ @pytest.mark.asyncio
+ async def test_aget_connection(self, mock_supervisor_comms):
+ """Test async connection retrieval in task sdk context."""
+ conn = ConnectionResult(
+ conn_id="test_conn",
+ conn_type="mysql",
+ host="mysql",
+ schema="airflow",
+ login="login",
+ password="password",
+ port=1234,
+ extra='{"extra_key": "extra_value"}',
+ )
+
+ mock_supervisor_comms.asend.return_value = conn
+
+ hook = BaseHook(logger_name="")
+ await hook.aget_connection(conn_id="test_conn")
+ mock_supervisor_comms.asend.assert_called_once_with(
+ msg=GetConnection(conn_id="test_conn"),
+ )
+
def test_get_connection_not_found(self, sdk_connection_not_found):
conn_id = "test_conn"
hook = BaseHook()
@@ -67,6 +89,16 @@ class TestBaseHook:
with pytest.raises(AirflowNotFoundException, match="The conn_id
`test_conn` isn't defined"):
hook.get_connection(conn_id=conn_id)
+ @pytest.mark.asyncio
+ async def test_aget_connection_not_found(self, sdk_connection_not_found):
+ """Test async connection not found error."""
+ conn_id = "test_conn"
+ hook = BaseHook()
+ sdk_connection_not_found
+
+ with pytest.raises(AirflowNotFoundException, match="The conn_id
`test_conn` isn't defined"):
+ await hook.aget_connection(conn_id=conn_id)
+
def test_get_connection_secrets_backend_configured(self,
mock_supervisor_comms, tmp_path):
path = tmp_path / "conn.env"
path.write_text("CONN_A=mysql://host_a")
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py
b/task-sdk/tests/task_sdk/execution_time/test_context.py
index 1def08086cc..54e2c66bee8 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -58,8 +58,8 @@ from airflow.sdk.execution_time.context import (
TriggeringAssetEventsAccessor,
VariableAccessor,
_AssetRefResolutionMixin,
- _convert_connection_result_conn,
_convert_variable_result_to_variable,
+ _process_connection_result_conn,
context_to_airflow_vars,
set_current_context,
)
@@ -77,7 +77,7 @@ def test_convert_connection_result_conn():
port=1234,
extra='{"extra_key": "extra_value"}',
)
- conn = _convert_connection_result_conn(conn)
+ conn = _process_connection_result_conn(conn)
assert conn == Connection(
conn_id="test_conn",
conn_type="mysql",