This is an automated email from the ASF dual-hosted git repository. pierrejeambrun pushed a commit to branch v2-5-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 816969dddeffce435260df2188a6396eb553f646 Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Wed Mar 15 20:05:12 2023 +0000 Ensure that `dag.partial_subset` doesn't mutate task group properties (#30129) We had a few properties that we failed to copy that we should have. To fix that in such a way that we don't miss things in the future I've converted it to deepcopy everything by default, excluding `children` and `parent_group`. This also made me notice that we were not correctly setting `parent_group` after partial_subset anymore -- it clearly hasn't mattered, but we were setting a now-unused `_parent_group` attribute. (cherry picked from commit 76a884c552a78bfb273fe8b65def58125fc7961a) --- airflow/models/dag.py | 23 +++++++++++++++++------ tests/models/test_dag.py | 35 +++++++++++++++++++++++++++++++---- 2 files changed, 48 insertions(+), 10 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index a2de4d4176..57f1b82415 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -2210,21 +2210,32 @@ class DAG(LoggingMixin): def filter_task_group(group, parent_group): """Exclude tasks not included in the subdag from the given TaskGroup.""" + # We want to deepcopy _most but not all_ attributes of the task group, so we create a shallow copy + # and then manually deep copy the instances. 
(memo argument to deepcopy only works for instances + # of classes, not "native" properties of an instance) copied = copy.copy(group) - copied.used_group_ids = set(copied.used_group_ids) - copied._parent_group = parent_group - copied.children = {} + memo[id(group.children)] = {} + if parent_group: + memo[id(group.parent_group)] = parent_group + for attr, value in copied.__dict__.items(): + if id(value) in memo: + value = memo[id(value)] + else: + value = copy.deepcopy(value, memo) + copied.__dict__[attr] = value + + proxy = weakref.proxy(copied) for child in group.children.values(): if isinstance(child, AbstractOperator): if child.task_id in dag.task_dict: task = copied.children[child.task_id] = dag.task_dict[child.task_id] - task.task_group = weakref.proxy(copied) + task.task_group = proxy else: copied.used_group_ids.discard(child.task_id) else: - filtered_child = filter_task_group(child, copied) + filtered_child = filter_task_group(child, proxy) # Only include this child TaskGroup if it is non-empty. if filtered_child.children: @@ -2232,7 +2243,7 @@ class DAG(LoggingMixin): return copied - dag._task_group = filter_task_group(self._task_group, None) + dag._task_group = filter_task_group(self.task_group, None) # Removing upstream/downstream references to tasks and TaskGroups that did not make # the cut. 
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 4940d0d562..47f5f6b145 100644 --- a/tests/models/test_dag.py +++ b/tests/models/test_dag.py @@ -24,6 +24,7 @@ import os import pickle import re import sys +import weakref from contextlib import redirect_stdout from datetime import timedelta from pathlib import Path @@ -1340,7 +1341,7 @@ class TestDag: assert dag.task_dict == {op1.task_id: op1, op3.task_id: op3} assert dag.task_dict == {op2.task_id: op2, op3.task_id: op3} - def test_sub_dag_updates_all_references_while_deepcopy(self): + def test_partial_subset_updates_all_references_while_deepcopy(self): with DAG("test_dag", start_date=DEFAULT_DATE) as dag: op1 = EmptyOperator(task_id="t1") op2 = EmptyOperator(task_id="t2") @@ -1348,11 +1349,37 @@ class TestDag: op1 >> op2 op2 >> op3 - sub_dag = dag.partial_subset("t2", include_upstream=True, include_downstream=False) - assert id(sub_dag.task_dict["t1"].downstream_list[0].dag) == id(sub_dag) + partial = dag.partial_subset("t2", include_upstream=True, include_downstream=False) + assert id(partial.task_dict["t1"].downstream_list[0].dag) == id(partial) # Copied DAG should not include unused task IDs in used_group_ids - assert "t3" not in sub_dag._task_group.used_group_ids + assert "t3" not in partial.task_group.used_group_ids + + def test_partial_subset_taskgroup_join_ids(self): + with DAG("test_dag", start_date=DEFAULT_DATE) as dag: + start = EmptyOperator(task_id="start") + with TaskGroup(group_id="outer", prefix_group_id=False) as outer_group: + with TaskGroup(group_id="tg1", prefix_group_id=False) as tg1: + EmptyOperator(task_id="t1") + with TaskGroup(group_id="tg2", prefix_group_id=False) as tg2: + EmptyOperator(task_id="t2") + + start >> tg1 >> tg2 + + # Pre-condition checks + task = dag.get_task("t2") + assert task.task_group.upstream_group_ids == {"tg1"} + assert isinstance(task.task_group.parent_group, weakref.ProxyType) + assert task.task_group.parent_group == outer_group + + partial = 
dag.partial_subset(["t2"], include_upstream=True, include_downstream=False) + copied_task = partial.get_task("t2") + assert copied_task.task_group.upstream_group_ids == {"tg1"} + assert isinstance(copied_task.task_group.parent_group, weakref.ProxyType) + assert copied_task.task_group.parent_group + + # Make sure we don't affect the original! + assert task.task_group.upstream_group_ids is not copied_task.task_group.upstream_group_ids def test_schedule_dag_no_previous_runs(self): """
