This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch backport-ef00040-v3-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 18b8eaefd0ca9c98f146a71e42f6af4c9bd4bbc5
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Mon Apr 13 08:18:58 2026 +0100

    Reduce per-DAG queries during DAG serialization with bulk prefetch (#64929)
    
    * Reduce per-DAG queries during DAG serialization with bulk prefetch
    
    Replaces 3 SELECTs per DAG in write_dag (update interval check, hash
    comparison, version fetch) with 2 bulk queries via a new
    _prefetch_dag_write_metadata classmethod. Also fixes DagCode.update_source_code
    to reuse the caller's session and eagerly loads dag_owner_links to prevent
    N+1 queries.
    
    * fixup! Reduce per-DAG queries during DAG serialization with bulk prefetch
    
    * fixup! fixup! Reduce per-DAG queries during DAG serialization with bulk prefetch
    
    (cherry picked from commit ef0004035edb27507c6899b11bd24166ce3a08c0)
---
 .../src/airflow/dag_processing/collection.py       |  21 ++++-
 airflow-core/src/airflow/models/serialized_dag.py  | 102 +++++++++++++++++----
 .../tests/unit/dag_processing/test_collection.py   |   1 +
 .../tests/unit/models/test_serialized_dag.py       |  42 +++++++++
 4 files changed, 146 insertions(+), 20 deletions(-)

diff --git a/airflow-core/src/airflow/dag_processing/collection.py b/airflow-core/src/airflow/dag_processing/collection.py
index 96f3c89f862..06e4900d816 100644
--- a/airflow-core/src/airflow/dag_processing/collection.py
+++ b/airflow-core/src/airflow/dag_processing/collection.py
@@ -53,6 +53,7 @@ from airflow.models.dag import DagModel, DagOwnerAttributes, DagTag
 from airflow.models.dagrun import DagRun
 from airflow.models.dagwarning import DagWarningType
 from airflow.models.errors import ParseImportError
+from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.trigger import Trigger
 from airflow.serialization.definitions.assets import (
     SerializedAsset,
@@ -75,6 +76,7 @@ if TYPE_CHECKING:
     from sqlalchemy.sql import Select
 
     from airflow.models.dagwarning import DagWarning
+    from airflow.models.serialized_dag import DagWriteMetadata
     from airflow.typing_compat import Self, Unpack
 
     AssetT = TypeVar("AssetT", SerializedAsset, SerializedAssetAlias)
@@ -256,7 +258,11 @@ def _update_dag_owner_links(dag_owner_links: dict[str, str], dm: DagModel, *, se
 
 
 def _serialize_dag_capturing_errors(
-    dag: LazyDeserializedDAG, bundle_name, session: Session, bundle_version: str | None
+    dag: LazyDeserializedDAG,
+    bundle_name,
+    session: Session,
+    bundle_version: str | None,
+    _prefetched: DagWriteMetadata | None = None,
 ):
     """
     Try to serialize the dag to the DB, but make a note of any errors.
@@ -264,7 +270,6 @@ def _serialize_dag_capturing_errors(
     We can't place them directly in import_errors, as this may be retried, and work the next time
     """
     from airflow.models.dagcode import DagCode
-    from airflow.models.serialized_dag import SerializedDagModel
 
     # Updating serialized DAG can not be faster than a minimum interval to reduce database write rate.
     MIN_SERIALIZED_DAG_UPDATE_INTERVAL = conf.getint(
@@ -279,10 +284,11 @@ def _serialize_dag_capturing_errors(
             bundle_version=bundle_version,
             min_update_interval=MIN_SERIALIZED_DAG_UPDATE_INTERVAL,
             session=session,
+            _prefetched=_prefetched,
         )
         if not dag_was_updated:
             # Check and update DagCode
-            DagCode.update_source_code(dag.dag_id, dag.fileloc)
+            DagCode.update_source_code(dag.dag_id, dag.fileloc, session=session)
         if "FabAuthManager" in conf.get("core", "auth_manager"):
             _sync_dag_perms(dag, session=session)
 
@@ -473,6 +479,13 @@ def update_dag_parsing_results_in_db(
                 SerializedDAG.bulk_write_to_db(
                     bundle_name, bundle_version, dags, parse_duration, session=session
                 )
+                # Bulk prefetch metadata for all DAGs to avoid the standard per-DAG
+                # metadata lookups in write_dag. This replaces the update-interval,
+                # hash, and version queries with 2 bulk queries total; DAGs with
+                # deadlines may still do an additional lookup for deadline UUID reuse.
+                prefetched_metadata = SerializedDagModel._prefetch_dag_write_metadata(
+                    [dag.dag_id for dag in dags], session=session
+                )
                 # Write Serialized DAGs to DB, capturing errors
                 for dag in dags:
                     serialize_errors.extend(
@@ -481,6 +494,7 @@ def update_dag_parsing_results_in_db(
                             bundle_name=bundle_name,
                             bundle_version=bundle_version,
                             session=session,
+                            _prefetched=prefetched_metadata.get(dag.dag_id),
                         )
                     )
             except OperationalError:
@@ -526,6 +540,7 @@ class DagModelOperation(NamedTuple):
                 .options(joinedload(DagModel.schedule_asset_references))
                 .options(joinedload(DagModel.schedule_asset_alias_references))
                 .options(joinedload(DagModel.task_outlet_asset_references))
+                .options(joinedload(DagModel.dag_owner_links))
             ),
             of=DagModel,
             session=session,
diff --git a/airflow-core/src/airflow/models/serialized_dag.py b/airflow-core/src/airflow/models/serialized_dag.py
index 23b93f2dae5..0aaf295a58f 100644
--- a/airflow-core/src/airflow/models/serialized_dag.py
+++ b/airflow-core/src/airflow/models/serialized_dag.py
@@ -23,7 +23,7 @@ import logging
 import zlib
 from collections.abc import Callable, Iterable, Iterator, Sequence
 from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, NamedTuple
 from uuid import UUID
 
 import uuid6
@@ -70,6 +70,14 @@ log = logging.getLogger(__name__)
 _COMPRESS_SERIALIZED_DAGS = conf.getboolean("core", "compress_serialized_dags", fallback=False)
 
 
+class DagWriteMetadata(NamedTuple):
+    """Pre-fetched metadata for write_dag to avoid per-DAG queries."""
+
+    last_updated: datetime | None
+    dag_hash: str | None
+    dag_version: DagVersion | None
+
+
 class _DagDependenciesResolver:
     """Resolver that resolves dag dependencies to include asset id and assets link to asset aliases."""
 
@@ -508,6 +516,70 @@ class SerializedDagModel(Base):
             )
             serialized_dag.deadline_alerts.append(alert)
 
+    @classmethod
+    def _prefetch_dag_write_metadata(
+        cls, dag_ids: Iterable[str], *, session: Session
+    ) -> dict[str, DagWriteMetadata]:
+        """
+        Bulk-fetch metadata needed by write_dag for multiple DAGs in two queries.
+
+        Instead of running 3 SELECTs per DAG in write_dag (update interval check,
+        hash comparison, version fetch), this fetches all needed data upfront.
+
+        :param dag_ids: DAG IDs to prefetch metadata for
+        :param session: ORM Session
+        :returns: dict mapping dag_id to DagWriteMetadata
+        """
+        dag_id_list = list(set(dag_ids))
+        if not dag_id_list:
+            return {}
+
+        # Fetch latest serialized_dag (last_updated, dag_hash) per dag_id
+        # using a window function to pick the most recent row.
+        sd_subq = (
+            select(
+                cls.dag_id.label("dag_id"),
+                cls.last_updated.label("last_updated"),
+                cls.dag_hash.label("dag_hash"),
+                func.row_number().over(partition_by=cls.dag_id, order_by=cls.created_at.desc()).label("rn"),
+            )
+            .where(cls.dag_id.in_(dag_id_list))
+            .subquery()
+        )
+        sd_rows = session.execute(
+            select(sd_subq.c.dag_id, sd_subq.c.last_updated, sd_subq.c.dag_hash).where(sd_subq.c.rn == 1)
+        ).all()
+        sd_by_dag_id: dict[str, tuple[datetime, str]] = {
+            row.dag_id: (row.last_updated, row.dag_hash) for row in sd_rows
+        }
+
+        # Fetch latest DagVersion per dag_id using a window function,
+        # matching the original write_dag ordering (ORDER BY created_at DESC).
+        dv_subq = (
+            select(
+                DagVersion.id.label("id"),
+                DagVersion.dag_id.label("dag_id"),
+                func.row_number()
+                .over(partition_by=DagVersion.dag_id, order_by=DagVersion.created_at.desc())
+                .label("rn"),
+            )
+            .where(DagVersion.dag_id.in_(dag_id_list))
+            .subquery()
+        )
+        dag_versions = session.scalars(
+            select(DagVersion).join(dv_subq, DagVersion.id == dv_subq.c.id).where(dv_subq.c.rn == 1)
+        ).all()
+        dv_by_dag_id: dict[str, DagVersion] = {dv.dag_id: dv for dv in dag_versions}
+
+        return {
+            dag_id: DagWriteMetadata(
+                last_updated=sd_by_dag_id[dag_id][0] if dag_id in sd_by_dag_id else None,
+                dag_hash=sd_by_dag_id[dag_id][1] if dag_id in sd_by_dag_id else None,
+                dag_version=dv_by_dag_id.get(dag_id),
+            )
+            for dag_id in dag_id_list
+        }
+
     @classmethod
     @provide_session
     def write_dag(
@@ -517,6 +589,7 @@ class SerializedDagModel(Base):
         bundle_version: str | None = None,
         min_update_interval: int | None = None,
         session: Session = NEW_SESSION,
+        _prefetched: DagWriteMetadata | None = None,
     ) -> bool:
         """
         Serialize a DAG and writes it into database.
@@ -529,33 +602,28 @@ class SerializedDagModel(Base):
         :param bundle_version: bundle version of the DAG
         :param min_update_interval: minimal interval in seconds to update serialized DAG
         :param session: ORM Session
+        :param _prefetched: Pre-fetched metadata to skip per-DAG queries; used by bulk callers
 
         :returns: Boolean indicating if the DAG was written to the DB
         """
+        if _prefetched is None:
+            _prefetched = cls._prefetch_dag_write_metadata([dag.dag_id], session=session).get(
+                dag.dag_id, DagWriteMetadata(last_updated=None, dag_hash=None, dag_version=None)
+            )
+
         # Checks if (Current Time - Time when the DAG was written to DB) < min_update_interval
         # If Yes, does nothing
         # If No or the DAG does not exists, updates / writes Serialized DAG to DB
         if min_update_interval is not None:
-            if session.scalar(
-                select(literal(True))
-                .where(
-                    cls.dag_id == dag.dag_id,
-                    (timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated,
-                )
-                .select_from(cls)
+            if (
+                _prefetched.last_updated is not None
+                and (timezone.utcnow() - timedelta(seconds=min_update_interval)) < _prefetched.last_updated
             ):
                 return False
 
         log.debug("Checking if DAG (%s) changed", dag.dag_id)
-        serialized_dag_hash = session.scalars(
-            select(cls.dag_hash).where(cls.dag_id == dag.dag_id).order_by(cls.created_at.desc())
-        ).first()
-        dag_version = session.scalar(
-            select(DagVersion)
-            .where(DagVersion.dag_id == dag.dag_id)
-            .order_by(DagVersion.created_at.desc())
-            .limit(1)
-        )
+        serialized_dag_hash = _prefetched.dag_hash
+        dag_version = _prefetched.dag_version
 
         if dag.data.get("dag", {}).get("deadline"):
             # Try to reuse existing deadline UUIDs if the deadline definitions haven't changed.
diff --git a/airflow-core/tests/unit/dag_processing/test_collection.py b/airflow-core/tests/unit/dag_processing/test_collection.py
index 6a0aef00eaa..1792c9b38c0 100644
--- a/airflow-core/tests/unit/dag_processing/test_collection.py
+++ b/airflow-core/tests/unit/dag_processing/test_collection.py
@@ -537,6 +537,7 @@ class TestUpdateDagParsingResults:
                     bundle_version=None,
                     min_update_interval=mock.ANY,
                     session=mock_session,
+                    _prefetched=mock.ANY,
                 ),
             ]
         )
diff --git a/airflow-core/tests/unit/models/test_serialized_dag.py b/airflow-core/tests/unit/models/test_serialized_dag.py
index de0464dea8c..2185635590e 100644
--- a/airflow-core/tests/unit/models/test_serialized_dag.py
+++ b/airflow-core/tests/unit/models/test_serialized_dag.py
@@ -524,6 +524,48 @@ class TestSerializedDagModel:
         )
         assert did_write is should_write
 
+    def test_prefetch_dag_write_metadata_multiple_dags(self, dag_maker, session):
+        """Test that _prefetch_dag_write_metadata returns correct metadata for multiple DAGs."""
+        with dag_maker("prefetch_multi_dag1"):
+            EmptyOperator(task_id="task1")
+        with dag_maker("prefetch_multi_dag2"):
+            EmptyOperator(task_id="task1")
+
+        result = SDM._prefetch_dag_write_metadata(
+            ["prefetch_multi_dag1", "prefetch_multi_dag2"], session=session
+        )
+
+        assert len(result) == 2
+        for dag_id in ("prefetch_multi_dag1", "prefetch_multi_dag2"):
+            metadata = result[dag_id]
+            assert metadata.last_updated is not None
+            assert metadata.dag_hash is not None
+            assert metadata.dag_version is not None
+            assert metadata.dag_version.dag_id == dag_id
+
+    def test_prefetch_dag_write_metadata_returns_latest_version(self, dag_maker, session):
+        """Test that _prefetch_dag_write_metadata returns the latest DagVersion."""
+        with dag_maker("prefetch_version_dag") as dag:
+            PythonOperator(task_id="task1", python_callable=lambda: None)
+        # Create a dagrun so that writing a changed DAG creates a new version
+        dag_maker.create_dagrun(run_id="run1", logical_date=pendulum.datetime(2025, 1, 1))
+
+        # Modify the DAG (add a task) and write again to create version 2
+        PythonOperator(task_id="task2", python_callable=lambda: None, dag=dag)
+        SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="dag_maker")
+
+        assert (
+            session.scalar(
+                select(func.count()).select_from(DagVersion).where(DagVersion.dag_id == dag.dag_id)
+            )
+            == 2
+        )
+
+        result = SDM._prefetch_dag_write_metadata([dag.dag_id], session=session)
+        metadata = result[dag.dag_id]
+        assert metadata.dag_version is not None
+        assert metadata.dag_version.version_number == 2
+
     def test_new_dag_version_created_when_bundle_name_changes_and_hash_unchanged(self, dag_maker, session):
         """Test that new dag_version is created if bundle_name changes but DAG is unchanged."""
         # Create and write initial DAG
