This is an automated email from the ASF dual-hosted git repository.

weilee pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 85ccf1accf2 Add overridable metadata engine creation hooks in `settings.py` (#62184)
85ccf1accf2 is described below

commit 85ccf1accf2aca3f09020bafc2557958de698b53
Author: Arthur <[email protected]>
AuthorDate: Wed Apr 1 17:25:29 2026 +0200

    Add overridable metadata engine creation hooks in `settings.py` (#62184)
    
    * Add overridable metadata engine creation hooks in settings.py
    
    * retrigger CI
    
    * Add tests and docs for metadata engine creation hooks
    
    Address review feedback: add unit tests proving configure_orm() and
    _configure_async_session() delegate to the overridable hooks, verify
    default implementations forward args correctly, and confirm overrides
    via airflow_local_settings.py take effect. Document the new hooks in
    cluster-policies.rst with function signatures and a JWT example.
    
    * Fix ruff formatting and restore settings attrs in teardown
    
    Ruff format required single-line function signatures. More importantly,
    teardown_method now restores the original create_metadata_engine and
    create_async_metadata_engine on the settings module to prevent the
    override from test_override_via_local_settings from leaking into
    subsequent test classes (e.g. test_sqlalchemy_config.py).
    
    * Update PR with Lee's comments
---
 .../cluster-policies.rst                           |  44 +++++++
 airflow-core/src/airflow/settings.py               |  46 ++++++-
 airflow-core/tests/unit/core/test_settings.py      | 142 +++++++++++++++++++++
 3 files changed, 227 insertions(+), 5 deletions(-)

diff --git a/airflow-core/docs/administration-and-deployment/cluster-policies.rst b/airflow-core/docs/administration-and-deployment/cluster-policies.rst
index 2d410b5a507..fd7d0b622f5 100644
--- a/airflow-core/docs/administration-and-deployment/cluster-policies.rst
+++ b/airflow-core/docs/administration-and-deployment/cluster-policies.rst
@@ -185,3 +185,47 @@ Here's an example of re-routing tasks that are on their second (or greater) retr
         :end-before: [END example_task_mutation_hook]
 
 Note that since priority weight is determined dynamically using weight rules, you cannot alter the ``priority_weight`` of a task instance within the mutation hook.
+
+
+Metadata Engine Hooks
+---------------------
+
+In addition to cluster policies, ``airflow_local_settings.py`` can override how Airflow creates its metadata
+database engines. This is useful when you need per-connection logic that cannot be expressed through static
+configuration — for example, injecting short-lived JWT tokens or IAM credentials via a SQLAlchemy
+``do_connect`` event handler.
+
+Two functions can be overridden:
+
+* ``create_metadata_engine(sql_alchemy_conn, *, engine_args, connect_args) -> Engine`` — called by
+  ``configure_orm()`` to create the synchronous metadata engine.
+* ``create_async_metadata_engine(sql_alchemy_conn_async, *, connect_args) -> AsyncEngine`` — called by
+  ``_configure_async_session()`` to create the asynchronous metadata engine.
+
+The default implementations call ``sqlalchemy.create_engine`` / ``sqlalchemy.ext.asyncio.create_async_engine``
+with the same arguments Airflow has always used, so there is **no behavioral change** unless you provide an
+override.
+
+Example: registering a ``do_connect`` handler that refreshes a JWT token before every new physical connection:
+
+.. code-block:: python
+
+    # airflow_local_settings.py
+    from sqlalchemy import create_engine, event
+
+
+    def _refresh_jwt(dialect, conn_rec, cargs, cparams):
+        """Called before every new physical connection is made (including after pool recycle)."""
+        # do_connect fires before the DBAPI connection exists, so inject the token via the connect parameters.
+        cparams["password"] = my_token_provider.get_token()
+
+
+    def create_metadata_engine(sql_alchemy_conn, *, engine_args, connect_args):
+        engine = create_engine(
+            sql_alchemy_conn,
+            connect_args=connect_args,
+            **engine_args,
+            future=True,
+        )
+        event.listen(engine, "do_connect", _refresh_jwt)
+        return engine
diff --git a/airflow-core/src/airflow/settings.py b/airflow-core/src/airflow/settings.py
index 223ec291b4d..a00339603be 100644
--- a/airflow-core/src/airflow/settings.py
+++ b/airflow-core/src/airflow/settings.py
@@ -349,6 +349,44 @@ def _get_connect_args(mode: Literal["sync", "async"]) -> Any:
     return {}
 
 
+def create_metadata_engine(
+    sql_alchemy_conn: str,
+    *,
+    engine_args: dict[str, Any],
+    connect_args: dict[str, Any],
+) -> Engine:
+    """
+    Create the SQLAlchemy Engine for the Airflow metadata database.
+
+    Override in ``airflow_local_settings.py`` to customize engine creation,
+    e.g. to register ``do_connect`` event handlers for token-based authentication.
+    """
+    return create_engine(
+        sql_alchemy_conn,
+        connect_args=connect_args,
+        **engine_args,
+        future=True,
+    )
+
+
+def create_async_metadata_engine(
+    sql_alchemy_conn_async: str,
+    *,
+    connect_args: dict[str, Any],
+) -> AsyncEngine:
+    """
+    Create the async SQLAlchemy Engine for the Airflow metadata database.
+
+    Override in ``airflow_local_settings.py`` to customize async engine creation.
+    For ``do_connect`` handlers, register on ``engine.sync_engine``.
+    """
+    return create_async_engine(
+        sql_alchemy_conn_async,
+        connect_args=connect_args,
+        future=True,
+    )
+
+
 def _configure_async_session() -> None:
     """
     Configure async SQLAlchemy session.
@@ -364,10 +402,9 @@ def _configure_async_session() -> None:
         AsyncSession = None
         return
 
-    async_engine = create_async_engine(
+    async_engine = create_async_metadata_engine(
         SQL_ALCHEMY_CONN_ASYNC,
         connect_args=_get_connect_args("async"),
-        future=True,
     )
     AsyncSession = async_sessionmaker(
         bind=async_engine,
@@ -408,11 +445,10 @@ def configure_orm(disable_connection_pool=False, pool_class=None):
         # to so the `test` thread and the tested endpoints can use common objects.
         connect_args["check_same_thread"] = False
 
-    engine = create_engine(
+    engine = create_metadata_engine(
         SQL_ALCHEMY_CONN,
+        engine_args=engine_args,
         connect_args=connect_args,
-        **engine_args,
-        future=True,
     )
     _configure_async_session()
     mask_secret(engine.url.password)
diff --git a/airflow-core/tests/unit/core/test_settings.py b/airflow-core/tests/unit/core/test_settings.py
index 24a5485402d..a865094a574 100644
--- a/airflow-core/tests/unit/core/test_settings.py
+++ b/airflow-core/tests/unit/core/test_settings.py
@@ -28,6 +28,21 @@ import pytest
 
 from airflow.exceptions import AirflowClusterPolicyViolation, AirflowConfigException
 
+SETTINGS_FILE_CUSTOM_ENGINE = """
+from sqlalchemy import create_engine, event
+
+_engine_created = False
+
+def create_metadata_engine(sql_alchemy_conn, *, engine_args, connect_args):
+    global _engine_created
+    _engine_created = True
+    engine = create_engine(
+        sql_alchemy_conn, connect_args=connect_args, **engine_args, future=True,
+    )
+    event.listen(engine, "do_connect", lambda *a, **kw: None)
+    return engine
+"""
+
 SETTINGS_FILE_POLICY = """
 def test_policy(task_instance):
     task_instance.run_as_user = "myself"
@@ -196,6 +211,133 @@ class TestLocalSettings:
                 settings.task_must_have_owners(task_instance)
 
 
+class TestMetadataEngineHooks:
+    """Tests for the overridable create_metadata_engine / 
create_async_metadata_engine hooks."""
+
+    def setup_method(self):
+        self.old_modules = dict(sys.modules)
+        from airflow import settings
+
+        self._orig_create_metadata_engine = settings.create_metadata_engine
+        self._orig_create_async_metadata_engine = settings.create_async_metadata_engine
+
+    def teardown_method(self):
+        from airflow import settings
+
+        settings.create_metadata_engine = self._orig_create_metadata_engine
+        settings.create_async_metadata_engine = self._orig_create_async_metadata_engine
+        for mod in [m for m in sys.modules if m not in self.old_modules]:
+            del sys.modules[mod]
+
+    @patch("airflow.settings.create_metadata_engine")
+    @patch("airflow.settings._configure_async_session")
+    def test_configure_orm_delegates_to_create_metadata_engine(self, mock_async_session, mock_create_engine):
+        """configure_orm() must call create_metadata_engine, not create_engine directly."""
+        from airflow import settings
+
+        mock_create_engine.return_value = MagicMock()
+
+        with (
+            patch("os.environ", {"_AIRFLOW_SKIP_DB_TESTS": "false"}),
+            patch("airflow.settings.SQL_ALCHEMY_CONN", "sqlite://"),
+            patch("airflow.settings.Session"),
+            patch("airflow.settings.engine"),
+            patch("airflow.settings.setup_event_handlers"),
+            patch("airflow.settings.mask_secret", create=True),
+            patch("airflow._shared.secrets_masker.mask_secret"),
+        ):
+            settings.configure_orm()
+
+        assert len(mock_create_engine.mock_calls) == 1
+        assert mock_async_session.mock_calls == [call()]
+        call_kwargs = mock_create_engine.call_args
+        assert call_kwargs[0][0] == "sqlite://"
+        assert "engine_args" in call_kwargs[1]
+        assert "connect_args" in call_kwargs[1]
+
+    @patch("airflow.settings.create_async_metadata_engine")
+    def test_configure_async_session_delegates_to_create_async_metadata_engine(
+        self, mock_create_async_engine
+    ):
+        """_configure_async_session() must call 
create_async_metadata_engine."""
+        from airflow import settings
+
+        mock_create_async_engine.return_value = MagicMock()
+
+        with patch("airflow.settings.SQL_ALCHEMY_CONN_ASYNC", 
"sqlite+aiosqlite://"):
+            settings._configure_async_session()
+
+        mock_create_async_engine.assert_called_once()
+        call_kwargs = mock_create_async_engine.call_args
+        assert call_kwargs[0][0] == "sqlite+aiosqlite://"
+        assert "connect_args" in call_kwargs[1]
+
+    @patch("airflow.settings.create_async_metadata_engine")
+    def test_configure_async_session_skips_when_no_async_conn(self, mock_create_async_engine):
+        """_configure_async_session() must not call the hook when SQL_ALCHEMY_CONN_ASYNC is empty."""
+        from airflow import settings
+
+        with patch("airflow.settings.SQL_ALCHEMY_CONN_ASYNC", ""):
+            settings._configure_async_session()
+
+        assert mock_create_async_engine.mock_calls == []
+
+    @patch("airflow.settings.create_engine")
+    def test_default_create_metadata_engine_forwards_args(self, mock_sa_create_engine):
+        """Default create_metadata_engine must forward all args to sqlalchemy.create_engine."""
+        from airflow import settings
+
+        mock_sa_create_engine.return_value = MagicMock()
+        engine_args = {"pool_size": 5, "pool_recycle": 1800}
+        connect_args = {"timeout": 30}
+
+        settings.create_metadata_engine("sqlite://", engine_args=engine_args, 
connect_args=connect_args)
+
+        assert mock_sa_create_engine.mock_calls == [
+            call(
+                "sqlite://",
+                connect_args={"timeout": 30},
+                pool_size=5,
+                pool_recycle=1800,
+                future=True,
+            )
+        ]
+
+    @patch("airflow.settings.create_async_engine")
+    def test_default_create_async_metadata_engine_forwards_args(self, mock_sa_create_async):
+        """Default create_async_metadata_engine must forward args to sqlalchemy.create_async_engine."""
+        from airflow import settings
+
+        mock_sa_create_async.return_value = MagicMock()
+        connect_args = {"timeout": 30}
+
+        settings.create_async_metadata_engine("sqlite+aiosqlite://", 
connect_args=connect_args)
+
+        mock_sa_create_async.assert_called_once_with(
+            "sqlite+aiosqlite://",
+            connect_args={"timeout": 30},
+            future=True,
+        )
+
+    def test_override_via_local_settings(self):
+        """An override in airflow_local_settings.py replaces the default 
create_metadata_engine."""
+        with SettingsContext(SETTINGS_FILE_CUSTOM_ENGINE, 
"airflow_local_settings"):
+            from airflow import settings
+
+            settings.import_local_settings()
+
+            import airflow_local_settings
+
+            # Verify the override is wired in
+            assert settings.create_metadata_engine is airflow_local_settings.create_metadata_engine
+            assert not airflow_local_settings._engine_created
+
+            # Actually call the override and verify it runs the custom code
+            engine = settings.create_metadata_engine("sqlite://", 
engine_args={}, connect_args={})
+            assert airflow_local_settings._engine_created
+            assert engine is not None
+
+
 _local_db_path_error = pytest.raises(AirflowConfigException, match=r"Cannot use relative path:")
 
 

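The docs hunk above only shows a synchronous override. For the async hook, the new create_async_metadata_engine docstring notes that do_connect handlers must be registered on engine.sync_engine. A minimal sketch of such an override in airflow_local_settings.py, assuming a deployment-specific token source (the _TokenProvider placeholder below is purely illustrative and not part of this commit), could look like:

    # airflow_local_settings.py -- illustrative sketch only, not part of this commit
    from sqlalchemy import event
    from sqlalchemy.ext.asyncio import create_async_engine


    class _TokenProvider:
        """Placeholder for whatever issues short-lived credentials in your deployment."""

        def get_token(self) -> str:
            return "example-token"


    my_token_provider = _TokenProvider()


    def _refresh_jwt(dialect, conn_rec, cargs, cparams):
        # do_connect fires before the DBAPI connection exists, so inject the
        # freshly issued token into the connect parameters.
        cparams["password"] = my_token_provider.get_token()


    def create_async_metadata_engine(sql_alchemy_conn_async, *, connect_args):
        engine = create_async_engine(
            sql_alchemy_conn_async,
            connect_args=connect_args,
            future=True,
        )
        # AsyncEngine wraps a synchronous Engine; SQLAlchemy events are registered there.
        event.listen(engine.sync_engine, "do_connect", _refresh_jwt)
        return engine

Because the default implementation already passes connect_args and future=True, an override like this only needs to add the event registration; everything else mirrors the stock behavior.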