pierrejeambrun commented on code in PR #62812:
URL: https://github.com/apache/airflow/pull/62812#discussion_r3063193582
##########
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py:
##########
@@ -5919,3 +5919,350 @@ def test_should_respond_403(self,
unauthorized_test_client):
def test_should_respond_422(self, test_client):
response = test_client.patch(self.ENDPOINT_URL, json={})
assert response.status_code == 422
+
+
+class TestPatchTaskGroup(TestTaskInstanceEndpoint):
+ DAG_ID = "example_task_group"
+ RUN_ID = "TEST_DAG_RUN_ID"
+ GROUP_ID = "section_1"
+ BASE_URL = f"/dags/{DAG_ID}/dagRuns/{RUN_ID}/taskGroupInstances"
+ ENDPOINT_URL = f"{BASE_URL}/{GROUP_ID}"
+
+
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+ def test_patch_task_group_success(self, mock_set_ti_state, test_client,
session):
+ """Test that patching a task group sets state for all tasks in the
group."""
+ self.create_task_instances(session, dag_id=self.DAG_ID)
+
+ tis = session.scalars(
+ select(TaskInstance).where(
+ TaskInstance.dag_id == self.DAG_ID,
+ TaskInstance.run_id == self.RUN_ID,
+ TaskInstance.task_id.in_(["section_1.task_1",
"section_1.task_2", "section_1.task_3"]),
+ )
+ ).all()
+
+ ti_map = {ti.task_id: ti for ti in tis}
+ mock_set_ti_state.side_effect = lambda task_id, **kwargs:
[ti_map[task_id]]
+
+ response = test_client.patch(
+ self.ENDPOINT_URL,
+ json={"new_state": "success"},
+ )
+ assert response.status_code == 200
+ response_data = response.json()
+ assert response_data["total_entries"] == mock_set_ti_state.call_count
+ assert mock_set_ti_state.call_count == 3
+ called_task_ids = sorted(call.kwargs["task_id"] for call in
mock_set_ti_state.call_args_list)
+ assert called_task_ids == ["section_1.task_1", "section_1.task_2",
"section_1.task_3"]
+ for call in mock_set_ti_state.call_args_list:
+ assert call.kwargs["state"] == "success"
+
+
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+ def test_patch_task_group_failed_state(self, mock_set_ti_state,
test_client, session):
+ """Test that patching a task group with failed state works."""
+ self.create_task_instances(session, dag_id=self.DAG_ID)
+
+ tis = session.scalars(
+ select(TaskInstance).where(
+ TaskInstance.dag_id == self.DAG_ID,
+ TaskInstance.run_id == self.RUN_ID,
+ TaskInstance.task_id.in_(["section_1.task_1",
"section_1.task_2", "section_1.task_3"]),
+ )
+ ).all()
+
+ ti_map = {ti.task_id: ti for ti in tis}
+ mock_set_ti_state.side_effect = lambda task_id, **kwargs:
[ti_map[task_id]]
+
+ response = test_client.patch(
+ self.ENDPOINT_URL,
+ json={"new_state": "failed"},
+ )
+ assert response.status_code == 200
+ for call in mock_set_ti_state.call_args_list:
+ assert call.kwargs["state"] == "failed"
+
+
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+ def test_patch_task_group_nested(self, mock_set_ti_state, test_client,
session):
+ """Test that patching a nested task group includes tasks from inner
groups."""
+ self.create_task_instances(session, dag_id=self.DAG_ID)
+
+ tis = session.scalars(
+ select(TaskInstance).where(
+ TaskInstance.dag_id == self.DAG_ID,
+ TaskInstance.run_id == self.RUN_ID,
+ )
+ ).all()
+
+ ti_map = {ti.task_id: ti for ti in tis}
+ mock_set_ti_state.side_effect = lambda task_id, **kwargs:
[ti_map[task_id]]
+
+ # section_2 contains task_1, and inner_section_2 which contains
task_2, task_3, task_4
+ url =
f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/section_2"
+ response = test_client.patch(
+ url,
+ json={"new_state": "success"},
+ )
+ assert response.status_code == 200
+ assert mock_set_ti_state.call_count == 4
+ called_task_ids = sorted(call.kwargs["task_id"] for call in
mock_set_ti_state.call_args_list)
+ assert called_task_ids == [
+ "section_2.inner_section_2.task_2",
+ "section_2.inner_section_2.task_3",
+ "section_2.inner_section_2.task_4",
+ "section_2.task_1",
+ ]
+
+ def test_patch_task_group_not_found(self, test_client, session):
+ """Test that requesting a non-existent task group returns 404."""
+ self.create_task_instances(session, dag_id=self.DAG_ID)
+
+ url =
f"/dags/{self.DAG_ID}/dagRuns/{self.RUN_ID}/taskGroupInstances/nonexistent_group"
+ response = test_client.patch(
+ url,
+ json={"new_state": "success"},
+ )
+ assert response.status_code == 404
+ assert "nonexistent_group" in response.json()["detail"]
+
+ def test_patch_task_group_invalid_state(self, test_client, session):
+ """Test that an invalid new_state returns 422."""
+ self.create_task_instances(session, dag_id=self.DAG_ID)
+
+ response = test_client.patch(
+ self.ENDPOINT_URL,
+ json={"new_state": "invalid_state"},
+ )
+ assert response.status_code == 422
+
+ def test_patch_task_group_dag_not_found(self, test_client, session):
+ """Test that requesting a non-existent DAG returns 404."""
+ url =
f"/dags/nonexistent_dag/dagRuns/{self.RUN_ID}/taskGroupInstances/{self.GROUP_ID}"
+ response = test_client.patch(
+ url,
+ json={"new_state": "success"},
+ )
+ assert response.status_code == 404
+
+ def test_should_respond_401(self, unauthenticated_test_client):
+ response = unauthenticated_test_client.patch(self.ENDPOINT_URL,
json={"new_state": "success"})
+ assert response.status_code == 401
+
+ def test_should_respond_403(self, unauthorized_test_client):
+ response = unauthorized_test_client.patch(self.ENDPOINT_URL,
json={"new_state": "success"})
+ assert response.status_code == 403
+
+
@mock.patch("airflow.serialization.definitions.dag.SerializedDAG.set_task_instance_state")
+ def test_query_count_does_not_scale_with_task_group_size(self,
mock_set_ti_state, test_client, session):
+ """Test that query count doesn't scale linearly with task group size -
single bulk query."""
Review Comment:
My bad @OscarLigthart, I think we have the same problem with patching all
map_indexes of a TI. This is already accepted for that endpoint, and I don't
think we can easily remove the multiple calls to `set_state`.
Remove the query guard for the patch tests and revert to your original draft,
which basically replicated the patterns we already have in other places.
I was confused because I didn't realize we had such an N+1 query problem when
patching a TI's mapped indexes.
So we're fine.
We just need a single call to `set_state` I believe.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]