pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3027386658


##########
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##########
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self, 
unauthorized_test_client):
     def test_should_respond_422(self, test_client):
         response = test_client.patch(self.ENDPOINT_URL, json={})
         assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+    DAG_ID = "example_task_group"
+    RUN_ID = "TEST_DAG_RUN_ID"
+    GROUP_ID = "section_1"
+    BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+    ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_success(self, mock_set_ti_state, test_client, 
session):
+        """Test that patching a task group sets state for all tasks in the 
group."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", 
"section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: 
[ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        response_data = response.json()
+        assert response_data["total_entries"] == mock_set_ti_state.call_count
+        assert mock_set_ti_state.call_count == 3
+        called_task_ids = sorted(call.kwargs["task_id"] for call in 
mock_set_ti_state.call_args_list)
+        assert called_task_ids == ["section_1.task_1", "section_1.task_2", 
"section_1.task_3"]
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "success"
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_failed_state(self, mock_set_ti_state, 
test_client, session):
+        """Test that patching a task group with failed state works."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+                TaskInstance.task_id.in_(["section_1.task_1", 
"section_1.task_2", "section_1.task_3"]),
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: 
[ti_map[task_id]]
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "failed"},
+        )
+        assert response.status_code == 200
+        for call in mock_set_ti_state.call_args_list:
+            assert call.kwargs["state"] == "failed"
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_patch_task_group_nested(self, mock_set_ti_state, test_client, 
session):
+        """Test that patching a nested task group includes tasks from inner 
groups."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        tis = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.run_id == self.RUN_ID,
+            )
+        ).all()
+
+        ti_map = {ti.task_id: ti for ti in tis}
+        mock_set_ti_state.side_effect = lambda task_id, **kwargs: 
[ti_map[task_id]]
+
+        # section_2 contains task_1, and inner_section_2 which contains 
task_2, task_3, task_4
+        url = 
f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/section_2"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 200
+        assert mock_set_ti_state.call_count == 4
+        called_task_ids = sorted(call.kwargs["task_id"] for call in 
mock_set_ti_state.call_args_list)
+        assert called_task_ids == [
+            "section_2.inner_section_2.task_2",
+            "section_2.inner_section_2.task_3",
+            "section_2.inner_section_2.task_4",
+            "section_2.task_1",
+        ]
+
+    def test_patch_task_group_not_found(self, test_client, session):
+        """Test that requesting a non-existent task group returns 404."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        url = 
f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/nonexistent_group"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 404
+        assert "nonexistent_group" in response.json()["detail"]
+
+    def test_patch_task_group_invalid_state(self, test_client, session):
+        """Test that an invalid new_state returns 422."""
+        self.create_task_instances(session, dag_id=self.DAG_ID)
+
+        response = test_client.patch(
+            self.ENDPOINT_URL,
+            json={"new_state": "invalid_state"},
+        )
+        assert response.status_code == 422
+
+    def test_patch_task_group_dag_not_found(self, test_client, session):
+        """Test that requesting a non-existent DAG returns 404."""
+        url = 
f"/dags/nonexistent_dag/dagRuns/{self.RUN_ID}/taskGroupInstances/{self.GROUP_ID}"
+        response = test_client.patch(
+            url,
+            json={"new_state": "success"},
+        )
+        assert response.status_code == 404
+
+    def test_should_respond_401(self, unauthenticated_test_client):
+        response = unauthenticated_test_client.patch(self.ENDPOINT_URL, 
json={"new_state": "success"})
+        assert response.status_code == 401
+
+    def test_should_respond_403(self, unauthorized_test_client):
+        response = unauthorized_test_client.patch(self.ENDPOINT_URL, 
json={"new_state": "success"})
+        assert response.status_code == 403
+
+    
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+    def test_query_count_does_not_scale_with_task_group_size(self, 
mock_set_ti_state, test_client, session):
+        """Test that query count doesn't scale linearly with task group size - 
single bulk query."""

Review Comment:
   Don't mock anything; assert that the number of DB queries doesn't scale linearly.



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -852,6 +856,120 @@ def _collect_relatives(run_id: str, direction: 
Literal["upstream", "downstream"]
     )
 
 
+@task_instances_router.patch(
+    "/dagRuns/{dag_run_id}/taskGroupInstances/{group_id}",
+    responses=create_openapi_http_exception_doc(
+        [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, 
status.HTTP_409_CONFLICT],
+    ),
+    dependencies=[
+        Depends(action_logging()),
+        Depends(requires_access_dag(method="PUT", 
access_entity=DagAccessEntity.TASK_INSTANCE)),
+    ],
+    operation_id="patch_task_group_instances",
+)
+def patch_task_group_instances(
+    dag_id: str,
+    dag_run_id: str,
+    group_id: str,
+    dag_bag: DagBagDep,
+    body: PatchTaskGroupBody,
+    session: SessionDep,
+    user: GetUserDep,
+    update_mask: list[str] | None = Query(None),
+) -> TaskInstanceCollectionResponse:
+    """Update the state of all task instances in a task group."""
+    dag, tis, data = _patch_ti_group_validate_request(
+        dag_id, dag_run_id, group_id, dag_bag, body, session, update_mask
+    )
+    affected_tis_dict: dict[tuple[str, str, str, int], TI] = {}
+
+    for key, _ in data.items():
+        if key == "new_state":
+            for ti in tis:
+                bulk_ti_body = BulkTaskInstanceBody(
+                    task_id=ti.task_id,
+                    map_index=ti.map_index,
+                    new_state=body.new_state,
+                    note=body.note,
+                    include_upstream=body.include_upstream,
+                    include_downstream=body.include_downstream,
+                    include_future=body.include_future,
+                    include_past=body.include_past,
+                )
+
+                updated_tis = _patch_task_instance_state(
+                    task_id=ti.task_id,
+                    dag_run_id=dag_run_id,
+                    dag=dag,
+                    task_instance_body=bulk_ti_body,
+                    data=data,
+                    session=session,
+                )
+
+                _collect_unique_tis(affected_tis_dict, updated_tis)
+
+        elif key == "note":
+            _patch_task_instance_note(
+                task_instance_body=body,
+                tis=tis,
+                user=user,
+                update_mask=update_mask,
+            )
+            _collect_unique_tis(affected_tis_dict, tis)
+
+    return TaskInstanceCollectionResponse(
+        task_instances=[TaskInstanceResponse.model_validate(ti) for ti in 
affected_tis_dict.values()],
+        total_entries=len(affected_tis_dict),
+    )
+
+
+@task_instances_router.patch(
+    "/dagRuns/{dag_run_id}/taskGroupInstances/{group_id}/dry_run",
+    responses=create_openapi_http_exception_doc(
+        [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST],
+    ),
+    dependencies=[Depends(requires_access_dag(method="PUT", 
access_entity=DagAccessEntity.TASK_INSTANCE))],
+    operation_id="patch_task_group_instances_dry_run",
+)
+def patch_task_group_instances_dry_run(
+    dag_id: str,
+    dag_run_id: str,
+    group_id: str,
+    dag_bag: DagBagDep,
+    body: PatchTaskGroupBody,
+    session: SessionDep,
+) -> TaskInstanceCollectionResponse:
+    """Dry-run of updating the state of all task instances in a task group."""
+    dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+    tis = _get_task_group_task_instances(dag_id, dag_run_id, group_id, dag, 
session)
+
+    if body.new_state:
+        all_tis: list[TI] = []
+        for ti in tis:
+            affected_tis = (
+                dag.set_task_instance_state(
+                    task_id=ti.task_id,
+                    run_id=dag_run_id,
+                    map_indexes=[ti.map_index],
+                    state=body.new_state,
+                    upstream=body.include_upstream,
+                    downstream=body.include_downstream,
+                    future=body.include_future,
+                    past=body.include_past,
+                    commit=False,
+                    session=session,
+                )
+                or []
+            )
+            all_tis.extend(affected_tis)
+        tis = all_tis

Review Comment:
   Should we deduplicate here as well? I'm afraid that a downstream TI could
appear multiple times in the result array.



##########
airflow-core/src/airflow/api_fastapi/core_api/datamodels/task_instances.py:
##########
@@ -215,6 +215,14 @@ def validate_new_state(cls, ns: str | None) -> str:
         return ns
 
 
+class PatchTaskInstanceBody(PatchTaskInstanceBaseBody):
+    """Request body for Clear Task Instances endpoint."""
+
+
+class PatchTaskGroupBody(PatchTaskInstanceBaseBody):
+    """Request body for patching the state of all task instances in a task 
group."""
+

Review Comment:
   Just re-use PatchTaskInstanceBody, no need to create an abstract base class 
and two empty subclasses.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to