pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3027370557


##########
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##########
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, 
unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, 
session):
+        """Test that patching a task group sets state for all tasks in the 
group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", 
"section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: 
[ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_set_ti_state.call_count
+        assert mock_set_ti_state.call_count == 3
+        called_task_ids = sorted(call.kwargs["task_id"] for call in 
mock_set_ti_state.call_args_list)
+        assert called_task_ids == ["section_1.task_1", "section_1.task_2", 
"section_1.task_3"]
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "success"
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_failed_state(self, mock_set_ti_state, 
test_client, session):
+        """Test that patching a task group with failed state works."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", 
"section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: 
[ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "failed"},
+        )
+        assert response.status_code == 200
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "failed"

Review Comment:
   Same here — please assert the API response.



##########
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##########
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, 
unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, 
session):
+        """Test that patching a task group sets state for all tasks in the 
group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", 
"section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: 
[ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_set_ti_state.call_count
+        assert mock_set_ti_state.call_count == 3
+        called_task_ids = sorted(call.kwargs["task_id"] for call in 
mock_set_ti_state.call_args_list)
+        assert called_task_ids == ["section_1.task_1", "section_1.task_2", 
"section_1.task_3"]
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "success"
+

Review Comment:
   Can you also assert the API response? 



##########
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##########
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, 
unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, 
session):
+        """Test that patching a task group sets state for all tasks in the 
group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", 
"section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: 
[ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_set_ti_state.call_count
+        assert mock_set_ti_state.call_count == 3
+        called_task_ids = sorted(call.kwargs["task_id"] for call in 
mock_set_ti_state.call_args_list)
+        assert called_task_ids == ["section_1.task_1", "section_1.task_2", 
"section_1.task_3"]
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "success"
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_failed_state(self, mock_set_ti_state, 
test_client, session):
+        """Test that patching a task group with failed state works."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", 
"section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: 
[ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "failed"},
+        )
+        assert response.status_code == 200
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "failed"
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_nested(self, mock_set_ti_state, test_client, 
session):
+        """Test that patching a nested task group includes tasks from inner 
groups."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: 
[ti_map[task_id]]
+
+        # section_2 contains task_1, and inner_section_2 which contains 
task_2, task_3, task_4
+        url = 
f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/section_2"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        assert mock_set_ti_state.call_count == 4
+        called_task_ids = sorted(call.kwargs["task_id"] for call in 
mock_set_ti_state.call_args_list)
+        assert called_task_ids == [
+            "section_2.inner_section_2.task_2",
+            "section_2.inner_section_2.task_3",
+            "section_2.inner_section_2.task_4",
+            "section_2.task_1",
+        ]
+
+    def test_patch_task_group_not_found(self, test_client, session):
+        """Test that requesting a non-existent task group returns 404."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        url = 
f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/nonexistent_group"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 404
+        assert "nonexistent_group" in response.json()["detail"]
+
+    def test_patch_task_group_invalid_state(self, test_client, session):
+        """Test that an invalid new_state returns 422."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "invalid_state"},
+        )
+        assert response.status_code == 422
+
+    def test_patch_task_group_dag_not_found(self, test_client, session):
+        """Test that requesting a non-existent DAG returns 404."""
+        url = 
f"/dags/nonexistent_dag/dagRuns/{self.RUN_ID}/taskGroupInstances/{self.GROUP_ID}"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 404
+
+    def test_should_respond_401(self, unauthenticated_test_client):
+        response = unauthenticated_test_client.patch(self.ENDPOINT_URL, 
json={"new_state": "success"})
+        assert response.status_code == 401
+
+    def test_should_respond_403(self, unauthorized_test_client):
+        response = unauthorized_test_client.patch(self.ENDPOINT_URL, 
json={"new_state": "success"})
+        assert response.status_code == 403
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_query_count_does_not_scale_with_task_group_size(self, 
mock_set_ti_state, test_client, session):
+        """Test that query count doesn't scale linearly with task group size - 
single bulk query."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        url_section_1 = f"{self.BASE_URL}/section_1"
+        url_section_2 = f"{self.BASE_URL}/section_2"
+
+        tis_section_1 = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", 
"section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        tis_section_2 = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(
+                    [
+                        "section_2.task_1",
+                        "section_2.inner_section_2.task_2",
+                        "section_2.inner_section_2.task_3",
+                        "section_2.inner_section_2.task_4",
+                    ]
+                ),
+            )
+        ).all()
+
+        # Warm-up call to populate caches (DAG deserialization, etc.)
+        mock_set_ti_state.return_value = tis_section_1[:1]
+        test_client.patch(url_section_1, json={"new_state": "success"})
+        mock_set_ti_state.reset_mock()

Review Comment:
   No need — please remove the warm-up call.



##########
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##########
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, 
unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, 
session):
+        """Test that patching a task group sets state for all tasks in the 
group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", 
"section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: 
[ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_set_ti_state.call_count
+        assert mock_set_ti_state.call_count == 3
+        called_task_ids = sorted(call.kwargs["task_id"] for call in 
mock_set_ti_state.call_args_list)
+        assert called_task_ids == ["section_1.task_1", "section_1.task_2", 
"section_1.task_3"]
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "success"
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_failed_state(self, mock_set_ti_state, 
test_client, session):
+        """Test that patching a task group with failed state works."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", 
"section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: 
[ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "failed"},
+        )
+        assert response.status_code == 200
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "failed"
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_nested(self, mock_set_ti_state, test_client, 
session):
+        """Test that patching a nested task group includes tasks from inner 
groups."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: 
[ti_map[task_id]]
+
+        # section_2 contains task_1, and inner_section_2 which contains 
task_2, task_3, task_4
+        url = 
f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/section_2"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        assert mock_set_ti_state.call_count == 4
+        called_task_ids = sorted(call.kwargs["task_id"] for call in 
mock_set_ti_state.call_args_list)
+        assert called_task_ids == [
+            "section_2.inner_section_2.task_2",
+            "section_2.inner_section_2.task_3",
+            "section_2.inner_section_2.task_4",
+            "section_2.task_1",
+        ]
+
+    def test_patch_task_group_not_found(self, test_client, session):
+        """Test that requesting a non-existent task group returns 404."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        url = 
f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/nonexistent_group"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 404
+        assert "nonexistent_group" in response.json()["detail"]
+
+    def test_patch_task_group_invalid_state(self, test_client, session):
+        """Test that an invalid new_state returns 422."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "invalid_state"},
+        )
+        assert response.status_code == 422
+
+    def test_patch_task_group_dag_not_found(self, test_client, session):
+        """Test that requesting a non-existent DAG returns 404."""
+        url = 
f"/dags/nonexistent_dag/dagRuns/{self.RUN_ID}/taskGroupInstances/{self.GROUP_ID}"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 404
+
+    def test_should_respond_401(self, unauthenticated_test_client):
+        response = unauthenticated_test_client.patch(self.ENDPOINT_URL, 
json={"new_state": "success"})
+        assert response.status_code == 401
+
+    def test_should_respond_403(self, unauthorized_test_client):
+        response = unauthorized_test_client.patch(self.ENDPOINT_URL, 
json={"new_state": "success"})
+        assert response.status_code == 403
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_query_count_does_not_scale_with_task_group_size(self, 
mock_set_ti_state, test_client, session):
+        """Test that query count doesn't scale linearly with task group size - 
single bulk query."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        url_section_1 = f"{self.BASE_URL}/section_1"
+        url_section_2 = f"{self.BASE_URL}/section_2"
+
+        tis_section_1 = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", 
"section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        tis_section_2 = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(
+                    [
+                        "section_2.task_1",
+                        "section_2.inner_section_2.task_2",
+                        "section_2.inner_section_2.task_3",
+                        "section_2.inner_section_2.task_4",
+                    ]
+                ),
+            )
+        ).all()
+
+        # Warm-up call to populate caches (DAG deserialization, etc.)
+        mock_set_ti_state.return_value = tis_section_1[:1]
+        test_client.patch(url_section_1, json={"new_state": "success"})
+        mock_set_ti_state.reset_mock()
+
+        # --- section_1 (3 tasks) ---
+        mock_set_ti_state.return_value = tis_section_1[:1]
+
+        with count_queries() as result_section_1:
+            response = test_client.patch(url_section_1, json={"new_state": 
"success"})
+        assert response.status_code == 200
+        assert mock_set_ti_state.call_count == 3
+        mock_set_ti_state.reset_mock()
+
+        # --- section_2 (4 tasks including nested inner_section_2) ---
+        mock_set_ti_state.return_value = tis_section_2[:1]
+
+        with count_queries() as result_section_2:
+            response = test_client.patch(url_section_2, json={"new_state": 
"success"})
+        assert response.status_code == 200
+        assert mock_set_ti_state.call_count == 4
+
+        # Query count should be identical regardless of task group size (3 vs 
4 tasks)
+        count_section_1 = sum(result_section_1.values())
+        count_section_2 = sum(result_section_2.values())
+        assert count_section_1 == count_section_2, (
+            f"Query counts should be equal across differently-sized task 
groups: "
+            f"section_1={count_section_1}, section_2={count_section_2}"
+        )
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_includes_upstream_downstream_parameters(self, mock_set_ti_state, 
test_client, session):
+        """Test that include_upstream and include_downstream parameters are 
passed through."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", 
"section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+        mock_set_ti_state.return_value = tis[:1]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={
+                "new_state": "success",
+                "include_upstream": True,
+                "include_downstream": True,
+                "include_future": True,
+                "include_past": True,
+            },
+        )
+        assert response.status_code == 200
+
+        # Verify the parameters were passed to set_task_instance_state
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["upstream"] is True
+            assert call.kwargs["downstream"] is True
+            assert call.kwargs["future"] is True
+            assert call.kwargs["past"] is True
+
+
+class TestPatchTaskGroupDryRun(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}/dry_run"
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_dry_run_returns_affected_tis_without_committing(self, 
mock_set_ti_state, test_client, session):
+        """Test that dry run returns TIs that would be affected without 
committing."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", 
"section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+        mock_set_ti_state.return_value = tis[:1]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        assert mock_set_ti_state.call_count == 3
+        # Verify commit=False was passed for dry run
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["commit"] is False
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_dry_run_query_count_does_not_scale(self, mock_set_ti_state, 
test_client, session):
+        """Test that dry_run query count doesn't scale with task group size - 
uses bulk query."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        url_section_1 = f"{self.BASE_URL}/section_1/dry_run"
+        url_section_2 = f"{self.BASE_URL}/section_2/dry_run"
+
+        tis_section_1 = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", 
"section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        tis_section_2 = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(
+                    [
+                        "section_2.task_1",
+                        "section_2.inner_section_2.task_2",
+                        "section_2.inner_section_2.task_3",
+                        "section_2.inner_section_2.task_4",
+                    ]
+                ),
+            )
+        ).all()
+
+        # Warm-up call to populate caches (DAG deserialization, etc.)
+        mock_set_ti_state.return_value = tis_section_1[:1]
+        test_client.patch(url_section_1, json={"new_state": "success"})
+        mock_set_ti_state.reset_mock()

Review Comment:
   Remove.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to