This is an automated email from the ASF dual-hosted git repository.
weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 85ccf1accf2 Add overridable metadata engine creation hooks in `settings.py` (#62184)
85ccf1accf2 is described below
commit 85ccf1accf2aca3f09020bafc2557958de698b53
Author: Arthur <[email protected]>
AuthorDate: Wed Apr 1 17:25:29 2026 +0200
Add overridable metadata engine creation hooks in `settings.py` (#62184)
* Add overridable metadata engine creation hooks in settings.py
* retrigger CI
* Add tests and docs for metadata engine creation hooks
Address review feedback: add unit tests proving configure_orm() and
_configure_async_session() delegate to the overridable hooks, verify
default implementations forward args correctly, and confirm overrides
via airflow_local_settings.py take effect. Document the new hooks in
cluster-policies.rst with function signatures and a JWT example.
* Fix ruff formatting and restore settings attrs in teardown
Ruff format required single-line function signatures. More importantly,
teardown_method now restores the original create_metadata_engine and
create_async_metadata_engine on the settings module to prevent the
override from test_override_via_local_settings from leaking into
subsequent test classes (e.g. test_sqlalchemy_config.py).
* Update PR with Lee's comments
---
.../cluster-policies.rst | 44 +++++++
airflow-core/src/airflow/settings.py | 46 ++++++-
airflow-core/tests/unit/core/test_settings.py | 142 +++++++++++++++++++++
3 files changed, 227 insertions(+), 5 deletions(-)
diff --git a/airflow-core/docs/administration-and-deployment/cluster-policies.rst b/airflow-core/docs/administration-and-deployment/cluster-policies.rst
index 2d410b5a507..fd7d0b622f5 100644
--- a/airflow-core/docs/administration-and-deployment/cluster-policies.rst
+++ b/airflow-core/docs/administration-and-deployment/cluster-policies.rst
@@ -185,3 +185,47 @@ Here's an example of re-routing tasks that are on their second (or greater) retr
:end-before: [END example_task_mutation_hook]
Note that since priority weight is determined dynamically using weight rules,
you cannot alter the ``priority_weight`` of a task instance within the mutation
hook.
+
+
+Metadata Engine Hooks
+---------------------
+
+In addition to cluster policies, ``airflow_local_settings.py`` can override how Airflow creates its metadata
+database engines. This is useful when you need per-connection logic that cannot be expressed through static
+configuration — for example, injecting short-lived JWT tokens or IAM credentials via a SQLAlchemy
+``do_connect`` event handler.
+
+Two functions can be overridden:
+
+* ``create_metadata_engine(sql_alchemy_conn, *, engine_args, connect_args) -> Engine`` — called by
+  ``configure_orm()`` to create the synchronous metadata engine.
+* ``create_async_metadata_engine(sql_alchemy_conn_async, *, connect_args) -> AsyncEngine`` — called by
+  ``_configure_async_session()`` to create the asynchronous metadata engine.
+
+The default implementations call ``sqlalchemy.create_engine`` / ``sqlalchemy.ext.asyncio.create_async_engine``
+with the same arguments Airflow has always used, so there is **no behavioral change** unless you provide an
+override.
+
+Example: registering a ``do_connect`` handler that refreshes a JWT token before every new physical connection:
+
+.. code-block:: python
+
+ # airflow_local_settings.py
+ from sqlalchemy import create_engine, event
+
+
+ def _refresh_jwt(dbapi_connection, connection_record):
+        """Called before every physical connection (including after pool recycle)."""
+ token = my_token_provider.get_token()
+ dbapi_connection.execute(f"SET SESSION AUTHORIZATION '{token}'")
+
+
+ def create_metadata_engine(sql_alchemy_conn, *, engine_args, connect_args):
+ engine = create_engine(
+ sql_alchemy_conn,
+ connect_args=connect_args,
+ **engine_args,
+ future=True,
+ )
+ event.listen(engine, "do_connect", _refresh_jwt)
+ return engine
diff --git a/airflow-core/src/airflow/settings.py b/airflow-core/src/airflow/settings.py
index 223ec291b4d..a00339603be 100644
--- a/airflow-core/src/airflow/settings.py
+++ b/airflow-core/src/airflow/settings.py
@@ -349,6 +349,44 @@ def _get_connect_args(mode: Literal["sync", "async"]) -> Any:
return {}
+def create_metadata_engine(
+ sql_alchemy_conn: str,
+ *,
+ engine_args: dict[str, Any],
+ connect_args: dict[str, Any],
+) -> Engine:
+ """
+ Create the SQLAlchemy Engine for the Airflow metadata database.
+
+ Override in ``airflow_local_settings.py`` to customize engine creation,
+    e.g. to register ``do_connect`` event handlers for token-based authentication.
+ """
+ return create_engine(
+ sql_alchemy_conn,
+ connect_args=connect_args,
+ **engine_args,
+ future=True,
+ )
+
+
+def create_async_metadata_engine(
+ sql_alchemy_conn_async: str,
+ *,
+ connect_args: dict[str, Any],
+) -> AsyncEngine:
+ """
+ Create the async SQLAlchemy Engine for the Airflow metadata database.
+
+    Override in ``airflow_local_settings.py`` to customize async engine creation.
+ For ``do_connect`` handlers, register on ``engine.sync_engine``.
+ """
+ return create_async_engine(
+ sql_alchemy_conn_async,
+ connect_args=connect_args,
+ future=True,
+ )
+
+
def _configure_async_session() -> None:
"""
Configure async SQLAlchemy session.
@@ -364,10 +402,9 @@ def _configure_async_session() -> None:
AsyncSession = None
return
- async_engine = create_async_engine(
+ async_engine = create_async_metadata_engine(
SQL_ALCHEMY_CONN_ASYNC,
connect_args=_get_connect_args("async"),
- future=True,
)
AsyncSession = async_sessionmaker(
bind=async_engine,
@@ -408,11 +445,10 @@ def configure_orm(disable_connection_pool=False, pool_class=None):
# to so the `test` thread and the tested endpoints can use common objects.
connect_args["check_same_thread"] = False
- engine = create_engine(
+ engine = create_metadata_engine(
SQL_ALCHEMY_CONN,
+ engine_args=engine_args,
connect_args=connect_args,
- **engine_args,
- future=True,
)
_configure_async_session()
mask_secret(engine.url.password)
diff --git a/airflow-core/tests/unit/core/test_settings.py b/airflow-core/tests/unit/core/test_settings.py
index 24a5485402d..a865094a574 100644
--- a/airflow-core/tests/unit/core/test_settings.py
+++ b/airflow-core/tests/unit/core/test_settings.py
@@ -28,6 +28,21 @@ import pytest
from airflow.exceptions import AirflowClusterPolicyViolation, AirflowConfigException
+SETTINGS_FILE_CUSTOM_ENGINE = """
+from sqlalchemy import create_engine, event
+
+_engine_created = False
+
+def create_metadata_engine(sql_alchemy_conn, *, engine_args, connect_args):
+ global _engine_created
+ _engine_created = True
+ engine = create_engine(
+        sql_alchemy_conn, connect_args=connect_args, **engine_args, future=True,
+ )
+ event.listen(engine, "do_connect", lambda *a, **kw: None)
+ return engine
+"""
+
SETTINGS_FILE_POLICY = """
def test_policy(task_instance):
task_instance.run_as_user = "myself"
@@ -196,6 +211,133 @@ class TestLocalSettings:
settings.task_must_have_owners(task_instance)
+class TestMetadataEngineHooks:
+    """Tests for the overridable create_metadata_engine / create_async_metadata_engine hooks."""
+
+ def setup_method(self):
+ self.old_modules = dict(sys.modules)
+ from airflow import settings
+
+ self._orig_create_metadata_engine = settings.create_metadata_engine
+        self._orig_create_async_metadata_engine = settings.create_async_metadata_engine
+
+ def teardown_method(self):
+ from airflow import settings
+
+ settings.create_metadata_engine = self._orig_create_metadata_engine
+        settings.create_async_metadata_engine = self._orig_create_async_metadata_engine
+ for mod in [m for m in sys.modules if m not in self.old_modules]:
+ del sys.modules[mod]
+
+ @patch("airflow.settings.create_metadata_engine")
+ @patch("airflow.settings._configure_async_session")
+    def test_configure_orm_delegates_to_create_metadata_engine(self, mock_async_session, mock_create_engine):
+        """configure_orm() must call create_metadata_engine, not create_engine directly."""
+ from airflow import settings
+
+ mock_create_engine.return_value = MagicMock()
+
+ with (
+ patch("os.environ", {"_AIRFLOW_SKIP_DB_TESTS": "false"}),
+ patch("airflow.settings.SQL_ALCHEMY_CONN", "sqlite://"),
+ patch("airflow.settings.Session"),
+ patch("airflow.settings.engine"),
+ patch("airflow.settings.setup_event_handlers"),
+ patch("airflow.settings.mask_secret", create=True),
+ patch("airflow._shared.secrets_masker.mask_secret"),
+ ):
+ settings.configure_orm()
+
+ assert len(mock_create_engine.mock_calls) == 1
+ assert mock_async_session.mock_calls == [call()]
+ call_kwargs = mock_create_engine.call_args
+ assert call_kwargs[0][0] == "sqlite://"
+ assert "engine_args" in call_kwargs[1]
+ assert "connect_args" in call_kwargs[1]
+
+ @patch("airflow.settings.create_async_metadata_engine")
+ def test_configure_async_session_delegates_to_create_async_metadata_engine(
+ self, mock_create_async_engine
+ ):
+        """_configure_async_session() must call create_async_metadata_engine."""
+ from airflow import settings
+
+ mock_create_async_engine.return_value = MagicMock()
+
+        with patch("airflow.settings.SQL_ALCHEMY_CONN_ASYNC", "sqlite+aiosqlite://"):
+ settings._configure_async_session()
+
+ mock_create_async_engine.assert_called_once()
+ call_kwargs = mock_create_async_engine.call_args
+ assert call_kwargs[0][0] == "sqlite+aiosqlite://"
+ assert "connect_args" in call_kwargs[1]
+
+ @patch("airflow.settings.create_async_metadata_engine")
+    def test_configure_async_session_skips_when_no_async_conn(self, mock_create_async_engine):
+        """_configure_async_session() must not call the hook when SQL_ALCHEMY_CONN_ASYNC is empty."""
+ from airflow import settings
+
+ with patch("airflow.settings.SQL_ALCHEMY_CONN_ASYNC", ""):
+ settings._configure_async_session()
+
+ assert mock_create_async_engine.mock_calls == []
+
+ @patch("airflow.settings.create_engine")
+    def test_default_create_metadata_engine_forwards_args(self, mock_sa_create_engine):
+        """Default create_metadata_engine must forward all args to sqlalchemy.create_engine."""
+ from airflow import settings
+
+ mock_sa_create_engine.return_value = MagicMock()
+ engine_args = {"pool_size": 5, "pool_recycle": 1800}
+ connect_args = {"timeout": 30}
+
+        settings.create_metadata_engine("sqlite://", engine_args=engine_args, connect_args=connect_args)
+
+ assert mock_sa_create_engine.mock_calls == [
+ call(
+ "sqlite://",
+ connect_args={"timeout": 30},
+ pool_size=5,
+ pool_recycle=1800,
+ future=True,
+ )
+ ]
+
+ @patch("airflow.settings.create_async_engine")
+    def test_default_create_async_metadata_engine_forwards_args(self, mock_sa_create_async):
+        """Default create_async_metadata_engine must forward args to sqlalchemy.create_async_engine."""
+ from airflow import settings
+
+ mock_sa_create_async.return_value = MagicMock()
+ connect_args = {"timeout": 30}
+
+        settings.create_async_metadata_engine("sqlite+aiosqlite://", connect_args=connect_args)
+
+ mock_sa_create_async.assert_called_once_with(
+ "sqlite+aiosqlite://",
+ connect_args={"timeout": 30},
+ future=True,
+ )
+
+ def test_override_via_local_settings(self):
+        """An override in airflow_local_settings.py replaces the default create_metadata_engine."""
+        with SettingsContext(SETTINGS_FILE_CUSTOM_ENGINE, "airflow_local_settings"):
+ from airflow import settings
+
+ settings.import_local_settings()
+
+ import airflow_local_settings
+
+ # Verify the override is wired in
+        assert settings.create_metadata_engine is airflow_local_settings.create_metadata_engine
+ assert not airflow_local_settings._engine_created
+
+ # Actually call the override and verify it runs the custom code
+        engine = settings.create_metadata_engine("sqlite://", engine_args={}, connect_args={})
+ assert airflow_local_settings._engine_created
+ assert engine is not None
+
+
_local_db_path_error = pytest.raises(AirflowConfigException, match=r"Cannot use relative path:")