This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 816969dddeffce435260df2188a6396eb553f646
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Wed Mar 15 20:05:12 2023 +0000

    Ensure that `dag.partial_subset` doesn't mutate task group properties (#30129)
    
    We had a few task group properties that we failed to copy but should
    have. To fix that in a way that we don't miss things in the future,
    I've converted it to deep-copy every attribute by default and exclude
    `children` and `parent_group`.
    
    This also made me notice that we were no longer setting `parent_group`
    correctly after `partial_subset` -- it clearly hasn't mattered, but we
    were setting a now-unused `_parent_group` attribute instead.
    
    (cherry picked from commit 76a884c552a78bfb273fe8b65def58125fc7961a)
---
 airflow/models/dag.py    | 23 +++++++++++++++++------
 tests/models/test_dag.py | 35 +++++++++++++++++++++++++++++++----
 2 files changed, 48 insertions(+), 10 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index a2de4d4176..57f1b82415 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -2210,21 +2210,32 @@ class DAG(LoggingMixin):
 
         def filter_task_group(group, parent_group):
             """Exclude tasks not included in the subdag from the given 
TaskGroup."""
+            # We want to deepcopy _most but not all_ attributes of the task 
group, so we create a shallow copy
+            # and then manually deep copy the instances. (memo argument to 
deepcopy only works for instances
+            # of classes, not "native" properties of an instance)
             copied = copy.copy(group)
-            copied.used_group_ids = set(copied.used_group_ids)
-            copied._parent_group = parent_group
 
-            copied.children = {}
+            memo[id(group.children)] = {}
+            if parent_group:
+                memo[id(group.parent_group)] = parent_group
+            for attr, value in copied.__dict__.items():
+                if id(value) in memo:
+                    value = memo[id(value)]
+                else:
+                    value = copy.deepcopy(value, memo)
+                copied.__dict__[attr] = value
+
+            proxy = weakref.proxy(copied)
 
             for child in group.children.values():
                 if isinstance(child, AbstractOperator):
                     if child.task_id in dag.task_dict:
                         task = copied.children[child.task_id] = 
dag.task_dict[child.task_id]
-                        task.task_group = weakref.proxy(copied)
+                        task.task_group = proxy
                     else:
                         copied.used_group_ids.discard(child.task_id)
                 else:
-                    filtered_child = filter_task_group(child, copied)
+                    filtered_child = filter_task_group(child, proxy)
 
                     # Only include this child TaskGroup if it is non-empty.
                     if filtered_child.children:
@@ -2232,7 +2243,7 @@ class DAG(LoggingMixin):
 
             return copied
 
-        dag._task_group = filter_task_group(self._task_group, None)
+        dag._task_group = filter_task_group(self.task_group, None)
 
         # Removing upstream/downstream references to tasks and TaskGroups that 
did not make
         # the cut.
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 4940d0d562..47f5f6b145 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -24,6 +24,7 @@ import os
 import pickle
 import re
 import sys
+import weakref
 from contextlib import redirect_stdout
 from datetime import timedelta
 from pathlib import Path
@@ -1340,7 +1341,7 @@ class TestDag:
         assert dag.task_dict == {op1.task_id: op1, op3.task_id: op3}
         assert dag.task_dict == {op2.task_id: op2, op3.task_id: op3}
 
-    def test_sub_dag_updates_all_references_while_deepcopy(self):
+    def test_partial_subset_updates_all_references_while_deepcopy(self):
         with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
             op1 = EmptyOperator(task_id="t1")
             op2 = EmptyOperator(task_id="t2")
@@ -1348,11 +1349,37 @@ class TestDag:
             op1 >> op2
             op2 >> op3
 
-        sub_dag = dag.partial_subset("t2", include_upstream=True, 
include_downstream=False)
-        assert id(sub_dag.task_dict["t1"].downstream_list[0].dag) == 
id(sub_dag)
+        partial = dag.partial_subset("t2", include_upstream=True, 
include_downstream=False)
+        assert id(partial.task_dict["t1"].downstream_list[0].dag) == 
id(partial)
 
         # Copied DAG should not include unused task IDs in used_group_ids
-        assert "t3" not in sub_dag._task_group.used_group_ids
+        assert "t3" not in partial.task_group.used_group_ids
+
+    def test_partial_subset_taskgroup_join_ids(self):
+        with DAG("test_dag", start_date=DEFAULT_DATE) as dag:
+            start = EmptyOperator(task_id="start")
+            with TaskGroup(group_id="outer", prefix_group_id=False) as 
outer_group:
+                with TaskGroup(group_id="tg1", prefix_group_id=False) as tg1:
+                    EmptyOperator(task_id="t1")
+                with TaskGroup(group_id="tg2", prefix_group_id=False) as tg2:
+                    EmptyOperator(task_id="t2")
+
+                start >> tg1 >> tg2
+
+        # Pre-condition checks
+        task = dag.get_task("t2")
+        assert task.task_group.upstream_group_ids == {"tg1"}
+        assert isinstance(task.task_group.parent_group, weakref.ProxyType)
+        assert task.task_group.parent_group == outer_group
+
+        partial = dag.partial_subset(["t2"], include_upstream=True, 
include_downstream=False)
+        copied_task = partial.get_task("t2")
+        assert copied_task.task_group.upstream_group_ids == {"tg1"}
+        assert isinstance(copied_task.task_group.parent_group, 
weakref.ProxyType)
+        assert copied_task.task_group.parent_group
+
+        # Make sure we don't affect the original!
+        assert task.task_group.upstream_group_ids is not 
copied_task.task_group.upstream_group_ids
 
     def test_schedule_dag_no_previous_runs(self):
         """

Reply via email to