Copilot commented on code in PR #62176:
URL: https://github.com/apache/airflow/pull/62176#discussion_r3066481404
##########
airflow-core/src/airflow/models/dagrun.py:
##########
@@ -884,9 +904,17 @@ def get_task_instances(
Keep this method because it is widely used across the code.
"""
task_ids = DagRun._get_partial_task_ids(self.dag)
- return DagRun.fetch_task_instances(
- dag_id=self.dag_id, run_id=self.run_id, task_ids=task_ids,
state=state, session=session
+ tis = DagRun._fetch_task_instances(
+ dag_id=self.dag_id,
+ run_id=self.run_id,
+ task_ids=task_ids,
+ state=state,
+ session=session,
+ load_dag_run=False,
)
+ for ti in tis:
+ ti.dag_run = self
+ return tis
Review Comment:
Setting `ti.dag_run = self` mutates ORM relationship state and can cause
cross-session attachment errors when `self` and the returned `TaskInstance`
objects belong to different sessions. It can also mark instances dirty or
trigger autoflush in what is intended to be a read-only accessor. To backfill the
relationship without affecting session state, use SQLAlchemy's committed-value
helper, e.g. `sqlalchemy.orm.attributes.set_committed_value(ti, "dag_run", self)`
(there's existing precedent in `TaskInstance` for setting `dag_run` this way).
##########
airflow-core/tests/unit/models/test_dagrun.py:
##########
@@ -847,6 +847,37 @@ def test_get_task_instance_on_empty_dagrun(self,
dag_maker, session):
ti = dag_run.get_task_instance("test_short_circuit_false")
assert ti is None
+ def test_get_task_instances_optimization(self, dag_maker, session):
+ """
+ Verify that get_task_instances (load_dag_run=False) avoids joinedload
on DagRun,
+ while fetch_task_instances (load_dag_run=True) uses it.
+ """
+ with dag_maker(dag_id="test_get_task_instances_optimization",
session=session):
+ EmptyOperator(task_id="t1")
+
+ dr = dag_maker.create_dagrun()
+
+
+
Review Comment:
There are two blank/whitespace-only lines here that should be removed to
avoid lint/style issues.
```suggestion
```
##########
airflow-core/tests/unit/models/test_dagrun.py:
##########
@@ -847,6 +847,37 @@ def test_get_task_instance_on_empty_dagrun(self,
dag_maker, session):
ti = dag_run.get_task_instance("test_short_circuit_false")
assert ti is None
+ def test_get_task_instances_optimization(self, dag_maker, session):
+ """
+ Verify that get_task_instances (load_dag_run=False) avoids joinedload
on DagRun,
+ while fetch_task_instances (load_dag_run=True) uses it.
+ """
+ with dag_maker(dag_id="test_get_task_instances_optimization",
session=session):
+ EmptyOperator(task_id="t1")
+
+ dr = dag_maker.create_dagrun()
+
+
+
+ # We verifying that it doesn't trigger extra queries when accessing
dag_run because it is manually set.
Review Comment:
Comment grammar: "We verifying" should be "We verify" (and consider wrapping
the line to match surrounding style).
```suggestion
# We verify that accessing dag_run doesn't trigger extra queries
because it is manually set.
```
##########
airflow-core/tests/unit/models/test_dagrun.py:
##########
@@ -847,6 +847,37 @@ def test_get_task_instance_on_empty_dagrun(self,
dag_maker, session):
ti = dag_run.get_task_instance("test_short_circuit_false")
assert ti is None
+ def test_get_task_instances_optimization(self, dag_maker, session):
+ """
+ Verify that get_task_instances (load_dag_run=False) avoids joinedload
on DagRun,
+ while fetch_task_instances (load_dag_run=True) uses it.
+ """
+ with dag_maker(dag_id="test_get_task_instances_optimization",
session=session):
+ EmptyOperator(task_id="t1")
+
+ dr = dag_maker.create_dagrun()
+
+
+
+ # We verifying that it doesn't trigger extra queries when accessing
dag_run because it is manually set.
+ with assert_queries_count(1):
+ tis = dr.get_task_instances(session=session)
+
+ assert len(tis) == 1
+ # ti.dag_run should be set (manually) to the exact same object
+ assert tis[0].dag_run is dr
+
+ # 2. fetch_task_instances (default: load_dag_run=True)
+ # Should issue 1 query to fetch TIs (with joinedload).
+ # We verify that accessing dag_run doesn't trigger extra queries.
+ session.expire_all()
+ with assert_queries_count(1):
+ tis_loaded = DagRun.fetch_task_instances(dag_id=dr.dag_id,
run_id=dr.run_id, session=session)
+
Review Comment:
For consistency, and to avoid counting unrelated queries from other
sessions/fixtures, consider passing the `session=` argument to
`assert_queries_count` here, as the other new tests do.
##########
airflow-core/tests/unit/models/test_dagrun.py:
##########
@@ -3203,3 +3234,84 @@ def on_failure(context):
assert context_received["ti"].task_id == "test_task"
assert context_received["ti"].dag_id == "test_dag"
assert context_received["ti"].run_id == dr.run_id
+
+ def
test_fetch_task_instances_does_not_load_dag_run_when_flag_is_false(self,
dag_maker, session):
+ """When load_dag_run=False, no dag_run rows are fetched from the DB.
+
+ Verified by:
+ 1. Exactly 1 SQL query (the TI SELECT only).
Review Comment:
These assertions describe "Exactly 1 SQL query", but `assert_queries_count`
only fails when the count exceeds the expected value (it does not enforce an
exact count). Consider adjusting the wording to "no more than 1" (or use a
stricter assertion helper if you need exactness).
```suggestion
1. No more than 1 SQL query (the TI SELECT only).
```
##########
airflow-core/src/airflow/models/dagrun.py:
##########
@@ -798,16 +798,36 @@ def fetch_task_instances(
session: Session = NEW_SESSION,
) -> list[TI]:
"""Return the task instances for this dag run."""
- tis = (
- select(TI)
- .options(joinedload(TI.dag_run))
- .where(
- TI.dag_id == dag_id,
- TI.run_id == run_id,
- )
- .order_by(TI.task_id, TI.map_index)
+ return DagRun._fetch_task_instances(
+ dag_id=dag_id,
+ run_id=run_id,
+ task_ids=task_ids,
+ state=state,
+ session=session,
+ load_dag_run=True,
)
+ @staticmethod
+ @provide_session
+ def _fetch_task_instances(
+ dag_id: str | None = None,
+ run_id: str | None = None,
+ task_ids: list[str] | None = None,
+ state: Iterable[TaskInstanceState | None] | None = None,
+ session: Session = NEW_SESSION,
+ load_dag_run: bool = True,
+ ) -> list[TI]:
+ """Return the task instances for this dag run."""
+ tis = select(TI).where(
+ TI.dag_id == dag_id,
+ TI.run_id == run_id,
+ )
+
+ if load_dag_run:
+ tis = tis.options(joinedload(TI.dag_run))
+
Review Comment:
`load_dag_run=False` currently does not actually prevent loading `DagRun`
because `TaskInstance.dag_run` is configured with `lazy="joined"` (see
`airflow/models/taskinstance.py`), so the relationship is still joined and
eager-loaded by default. As a result, the performance optimization won't
take effect. The new tests that expect `DetachedInstanceError` after
detaching also won't behave as intended: the relationship is already loaded,
so no error is raised. Consider explicitly overriding the loader strategy
when `load_dag_run` is false (e.g. `tis = tis.options(lazyload(TI.dag_run))`,
which disables joined eager loading for `TI.dag_run`).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]