OscarLigthart commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3034320062


##########
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##########
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group sets state for all tasks in the group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_set_ti_state.call_count
+        assert mock_set_ti_state.call_count == 3
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == ["section_1.task_1", "section_1.task_2", "section_1.task_3"]
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "success"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_failed_state(self, mock_set_ti_state, test_client, session):
+        """Test that patching a task group with failed state works."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", "section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "failed"},
+        )
+        assert response.status_code == 200
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "failed"
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_nested(self, mock_set_ti_state, test_client, session):
+        """Test that patching a nested task group includes tasks from inner groups."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: [ti_map[task_id]]
+
+        # section_2 contains task_1, and inner_section_2 which contains task_2, task_3, task_4
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/section_2"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        assert mock_set_ti_state.call_count == 4
+        called_task_ids = sorted(call.kwargs["task_id"] for call in mock_set_ti_state.call_args_list)
+        assert called_task_ids == [
+            "section_2.inner_section_2.task_2",
+            "section_2.inner_section_2.task_3",
+            "section_2.inner_section_2.task_4",
+            "section_2.task_1",
+        ]
+
+    def test_patch_task_group_not_found(self, test_client, session):
+        """Test that requesting a non-existent task group returns 404."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        url = f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/nonexistent_group"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 404
+        assert "nonexistent_group" in response.json()["detail"]
+
+    def test_patch_task_group_invalid_state(self, test_client, session):
+        """Test that an invalid new_state returns 422."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "invalid_state"},
+        )
+        assert response.status_code == 422
+
+    def test_patch_task_group_dag_not_found(self, test_client, session):
+        """Test that requesting a non-existent DAG returns 404."""
+        url = f"/dags/nonexistent_dag/dagRuns/{self.RUN_ID}/taskGroupInstances/{self.GROUP_ID}"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 404
+
+    def test_should_respond_401(self, unauthenticated_test_client):
+        response = unauthenticated_test_client.patch(self.ENDPOINT_URL, json={"new_state": "success"})
+        assert response.status_code == 401
+
+    def test_should_respond_403(self, unauthorized_test_client):
+        response = unauthorized_test_client.patch(self.ENDPOINT_URL, json={"new_state": "success"})
+        assert response.status_code == 403
+
+    @mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_query_count_does_not_scale_with_task_group_size(self, mock_set_ti_state, test_client, session):
+        """Test that query count doesn't scale linearly with task group size - single bulk query."""

Review Comment:
   This was a bit tricky. After removing the mock, I was again stopped by the N+1 query problem. I consulted Claude, and the result is [this commit](https://github.com/apache/airflow/pull/62812/commits/bcf31b62fb72bd2f6b2f46e5e57cc7468487ca28). What I'm seeing there makes sense but, frankly, I am a bit out of my depth here. I would be happy to discuss this with you and get some guidelines on how to properly eliminate the problem.
   
   Claude tells me this:
   
   ># N+1 Query Fix for Task Group Endpoints
   >
   >## Before (N+1)
   >
   >For N tasks in a group, the route called `set_task_instance_state()` N times.
   >Each call ran the full `mark_tasks.set_state()` chain: run_id lookup, task query,
   >per-TI update, downstream clear. ~15+ queries × N = O(N) full round trips.
   >
   >## After (batch)
   >
   >One call to `set_multiple_task_instances_state()` with all tasks bundled.
   >`get_run_ids()`, `get_all_dag_task_query()`, and `partial_subset+clear()` each
   >run once. Only the per-TI UPDATE+FLUSH remains inside the loop.
   >
   >Total: ~15 base queries + ~2 per additional task.
   >
   >## Remaining per-TI overhead
   >
   >`TaskInstance.set_state()` calls `session.flush()` after each TI update.
   >We could inline this and flush once, but it duplicates model internals that
   >would diverge silently on changes. 
   >
   >The clean fix would be adding a `flush=False` parameter to
   >`TaskInstance.set_state()`, but that's a change to core model code used
   >across the entire codebase — out of scope here. The gain is also marginal:
   >just N-1 small in-transaction UPDATEs on already-locked rows, not full
   >query chains.
   >
   >The test asserts ≤5 queries overhead per additional task, confirming the
   >expensive O(N) operations are eliminated.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to