This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch v2-5-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 0be6263619985c6bf9047b140e1fd3d41f77a73f
Author: Hussein Awala <[email protected]>
AuthorDate: Sat Mar 18 22:03:41 2023 +0100

    fix update_mask in patch variable route (#29711)
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
    
    ---------
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
    (cherry picked from commit de8e07dc6fea620541e0daa67131e8fe21dbd5fe)
---
 .../api_connexion/endpoints/connection_endpoint.py | 10 +---
 airflow/api_connexion/endpoints/update_mask.py     | 34 +++++++++++++
 .../api_connexion/endpoints/variable_endpoint.py   | 24 ++++++----
 tests/api_connexion/endpoints/test_update_mask.py  | 56 ++++++++++++++++++++++
 .../endpoints/test_variable_endpoint.py            | 16 +++++--
 5 files changed, 119 insertions(+), 21 deletions(-)

diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py b/airflow/api_connexion/endpoints/connection_endpoint.py
index 40cb474bda..64db0711fd 100644
--- a/airflow/api_connexion/endpoints/connection_endpoint.py
+++ b/airflow/api_connexion/endpoints/connection_endpoint.py
@@ -26,6 +26,7 @@ from sqlalchemy import func
 from sqlalchemy.orm import Session
 
 from airflow.api_connexion import security
+from airflow.api_connexion.endpoints.update_mask import extract_update_mask_data
 from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound
 from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
 from airflow.api_connexion.schemas.connection_schema import (
@@ -132,14 +133,7 @@ def patch_connection(
     if data.get("conn_id") and connection.conn_id != data["conn_id"]:
         raise BadRequest(detail="The connection_id cannot be updated.")
     if update_mask:
-        update_mask = [i.strip() for i in update_mask]
-        data_ = {}
-        for field in update_mask:
-            if field in data and field not in non_update_fields:
-                data_[field] = data[field]
-            else:
-                raise BadRequest(detail=f"'{field}' is unknown or cannot be updated.")
-        data = data_
+        data = extract_update_mask_data(update_mask, non_update_fields, data)
     for key in data:
         setattr(connection, key, data[key])
     session.add(connection)
diff --git a/airflow/api_connexion/endpoints/update_mask.py b/airflow/api_connexion/endpoints/update_mask.py
new file mode 100644
index 0000000000..38fd255f51
--- /dev/null
+++ b/airflow/api_connexion/endpoints/update_mask.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any, Mapping, Sequence
+
+from airflow.api_connexion.exceptions import BadRequest
+
+
+def extract_update_mask_data(
+    update_mask: Sequence[str], non_update_fields: list[str], data: Mapping[str, Any]
+) -> Mapping[str, Any]:
+    """Return the subset of ``data`` selected by ``update_mask``.
+
+    Each mask entry is stripped of surrounding whitespace before lookup.
+
+    :param update_mask: names of the fields to update
+    :param non_update_fields: field names that must never be updated
+    :param data: the request payload to filter
+    :return: a new dict holding only the masked fields and their values
+    :raises BadRequest: if a masked field is missing from ``data`` or is
+        listed in ``non_update_fields``.
+    """
+    extracted_data: dict[str, Any] = {}
+    for field in update_mask:
+        field = field.strip()
+        if field in data and field not in non_update_fields:
+            extracted_data[field] = data[field]
+        else:
+            raise BadRequest(detail=f"'{field}' is unknown or cannot be updated.")
+    return extracted_data
diff --git a/airflow/api_connexion/endpoints/variable_endpoint.py b/airflow/api_connexion/endpoints/variable_endpoint.py
index 3111ff18d4..da8f35fcb8 100644
--- a/airflow/api_connexion/endpoints/variable_endpoint.py
+++ b/airflow/api_connexion/endpoints/variable_endpoint.py
@@ -25,6 +25,7 @@ from sqlalchemy.orm import Session
 
 from airflow.api_connexion import security
 from airflow.api_connexion.endpoints.request_dict import get_json_request_dict
+from airflow.api_connexion.endpoints.update_mask import extract_update_mask_data
 from airflow.api_connexion.exceptions import BadRequest, NotFound
 from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
 from airflow.api_connexion.schemas.variable_schema import variable_collection_schema, variable_schema
@@ -88,13 +89,19 @@ def get_variables(


 @security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_VARIABLE)])
+@provide_session
 @action_logging(
     event=action_event_from_permission(
         prefix=RESOURCE_EVENT_PREFIX,
         permission=permissions.ACTION_CAN_EDIT,
     ),
 )
-def patch_variable(*, variable_key: str, update_mask: UpdateMask = None) -> Response:
+def patch_variable(
+    *,
+    variable_key: str,
+    update_mask: UpdateMask = None,
+    session: Session = NEW_SESSION,
+) -> Response:
     """Update a variable by key."""
     try:
         data = variable_schema.load(get_json_request_dict())
@@ -103,15 +110,18 @@ def patch_variable(*, variable_key: str, update_mask: UpdateMask = None) -> Resp

     if data["key"] != variable_key:
         raise BadRequest("Invalid post body", detail="key from request body doesn't match uri parameter")
-
+    non_update_fields = ["key"]
+    variable = session.query(Variable).filter_by(key=variable_key).first()
+    if variable is None:
+        raise NotFound("Variable not found", detail=f"Variable with key '{variable_key}' does not exist")
     if update_mask:
-        if "key" in update_mask:
-            raise BadRequest("key is a ready only field")
-        if "value" not in update_mask:
-            raise BadRequest("No field to update")
-
-    Variable.set(data["key"], data["val"])
-    return variable_schema.dump(data)
+        # variable_schema loads the public "value" field under the "val" key.
+        update_mask = ["val" if field.strip() == "value" else field for field in update_mask]
+        data = extract_update_mask_data(update_mask, non_update_fields, data)
+    for key, val in data.items():
+        setattr(variable, key, val)
+    session.add(variable)
+    return variable_schema.dump(variable)
 
 
 @security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE)])
diff --git a/tests/api_connexion/endpoints/test_update_mask.py b/tests/api_connexion/endpoints/test_update_mask.py
new file mode 100644
index 0000000000..4221f11a13
--- /dev/null
+++ b/tests/api_connexion/endpoints/test_update_mask.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pytest
+
+from airflow.api_connexion.endpoints.update_mask import extract_update_mask_data
+from airflow.api_connexion.exceptions import BadRequest
+
+
+class TestUpdateMask:
+    def test_should_extract_data(self):
+        non_update_fields = ["field_1"]
+        update_mask = ["field_2"]
+        data = {
+            "field_1": "value_1",
+            "field_2": "value_2",
+            "field_3": "value_3",
+        }
+        data = extract_update_mask_data(update_mask, non_update_fields, data)
+        assert data == {"field_2": "value_2"}
+
+    def test_update_forbid_field_should_raise_exception(self):
+        non_update_fields = ["field_1"]
+        update_mask = ["field_1", "field_2"]
+        data = {
+            "field_1": "value_1",
+            "field_2": "value_2",
+            "field_3": "value_3",
+        }
+        with pytest.raises(BadRequest):
+            extract_update_mask_data(update_mask, non_update_fields, data)
+
+    def test_update_unknown_field_should_raise_exception(self):
+        non_update_fields = ["field_1"]
+        update_mask = ["field_2", "field_3"]
+        data = {
+            "field_1": "value_1",
+            "field_2": "value_2",
+        }
+        with pytest.raises(BadRequest):
+            extract_update_mask_data(update_mask, non_update_fields, data)
+
+    def test_should_strip_whitespace_around_mask_fields(self):
+        non_update_fields = ["field_1"]
+        update_mask = [" field_2", "field_3 "]
+        data = {
+            "field_1": "value_1",
+            "field_2": "value_2",
+            "field_3": "value_3",
+        }
+        data = extract_update_mask_data(update_mask, non_update_fields, data)
+        assert data == {"field_2": "value_2", "field_3": "value_3"}
diff --git a/tests/api_connexion/endpoints/test_variable_endpoint.py b/tests/api_connexion/endpoints/test_variable_endpoint.py
index 83d2fa57d4..6c1537bc89 100644
--- a/tests/api_connexion/endpoints/test_variable_endpoint.py
+++ b/tests/api_connexion/endpoints/test_variable_endpoint.py
@@ -229,10 +229,18 @@ class TestPatchVariable(TestVariableEndpoint):
             environ_overrides={"REMOTE_USER": "test"},
         )
         assert response.status_code == 200
-        assert response.json == {
-            "key": "var1",
-            "value": "updated",
-        }
+        assert response.json == {"key": "var1", "value": "updated", "description": None}
+        _check_last_log(session, dag_id=None, event="variable.edit", execution_date=None)
+
+    def test_should_update_variable_with_mask(self, session):
+        Variable.set("var1", "foo", description="before update")
+        response = self.client.patch(
+            "/api/v1/variables/var1?update_mask=description",
+            json={"key": "var1", "value": "updated", "description": "after_update"},
+            environ_overrides={"REMOTE_USER": "test"},
+        )
+        assert response.status_code == 200
+        assert response.json == {"key": "var1", "value": "foo", "description": "after_update"}
         _check_last_log(session, dag_id=None, event="variable.edit", execution_date=None)
 
     def test_should_reject_invalid_update(self):

Reply via email to