This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fd03dc2933 Fix reraise outside of try block in `AthenaHook.get_output_location` (#36008)
fd03dc2933 is described below

commit fd03dc29336e1331d20de0113993dd5a35353ee0
Author: Andrey Anshin <[email protected]>
AuthorDate: Fri Dec 1 20:53:04 2023 +0400

    Fix reraise outside of try block in `AthenaHook.get_output_location` (#36008)
---
 airflow/providers/amazon/aws/hooks/athena.py    | 25 +++++++++++--------------
 tests/providers/amazon/aws/hooks/test_athena.py | 25 +++++++++++++++++++++++++
 2 files changed, 36 insertions(+), 14 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/athena.py b/airflow/providers/amazon/aws/hooks/athena.py
index 3715b4a1bc..04853621f1 100644
--- a/airflow/providers/amazon/aws/hooks/athena.py
+++ b/airflow/providers/amazon/aws/hooks/athena.py
@@ -292,20 +292,17 @@ class AthenaHook(AwsBaseHook):
 
         :param query_execution_id: Id of submitted athena query
         """
-        if query_execution_id:
-            response = self.get_query_info(query_execution_id=query_execution_id, use_cache=True)
-
-            if response:
-                try:
-                    return response["QueryExecution"]["ResultConfiguration"]["OutputLocation"]
-                except KeyError:
-                    self.log.error(
-                        "Error retrieving OutputLocation. Query execution id: %s", query_execution_id
-                    )
-                    raise
-            else:
-                raise
-        raise ValueError("Invalid Query execution id. Query execution id: %s", query_execution_id)
+        if not query_execution_id:
+            raise ValueError(f"Invalid Query execution id. Query execution id: {query_execution_id}")
+
+        if not (response := self.get_query_info(query_execution_id=query_execution_id, use_cache=True)):
+            raise ValueError(f"Unable to get query information for execution id: {query_execution_id}")
+
+        try:
+            return response["QueryExecution"]["ResultConfiguration"]["OutputLocation"]
+        except KeyError:
+            self.log.error("Error retrieving OutputLocation. Query execution id: %s", query_execution_id)
+            raise
 
     def stop_query(self, query_execution_id: str) -> dict:
         """Cancel the submitted query.
diff --git a/tests/providers/amazon/aws/hooks/test_athena.py b/tests/providers/amazon/aws/hooks/test_athena.py
index 8f224f0b2d..a61663a8fb 100644
--- a/tests/providers/amazon/aws/hooks/test_athena.py
+++ b/tests/providers/amazon/aws/hooks/test_athena.py
@@ -18,6 +18,8 @@ from __future__ import annotations
 
 from unittest import mock
 
+import pytest
+
 from airflow.providers.amazon.aws.hooks.athena import AthenaHook
 
 MOCK_DATA = {
@@ -197,6 +199,29 @@ class TestAthenaHook:
         result = self.athena.get_output_location(query_execution_id=MOCK_DATA["query_execution_id"])
         assert result == "s3://test_bucket/test.csv"
 
+    @pytest.mark.parametrize(
+        "query_execution_id", [pytest.param("", id="empty-string"), pytest.param(None, id="none")]
+    )
+    def test_hook_get_output_location_empty_execution_id(self, query_execution_id):
+        with pytest.raises(ValueError, match="Invalid Query execution id"):
+            self.athena.get_output_location(query_execution_id=query_execution_id)
+
+    @pytest.mark.parametrize("response", [pytest.param({}, id="empty-dict"), pytest.param(None, id="none")])
+    def test_hook_get_output_location_no_response(self, response):
+        with mock.patch.object(AthenaHook, "get_query_info", return_value=response) as m:
+            with pytest.raises(ValueError, match="Unable to get query information"):
+                self.athena.get_output_location(query_execution_id="PLACEHOLDER")
+            m.assert_called_once_with(query_execution_id="PLACEHOLDER", use_cache=True)
+
+    def test_hook_get_output_location_invalid_response(self, caplog):
+        with mock.patch.object(AthenaHook, "get_query_info") as m:
+            m.return_value = {"foo": "bar"}
+            caplog.clear()
+            caplog.set_level("ERROR")
+            with pytest.raises(KeyError):
+                self.athena.get_output_location(query_execution_id="PLACEHOLDER")
+            assert "Error retrieving OutputLocation" in caplog.text
+
     @mock.patch.object(AthenaHook, "get_conn")
     def test_hook_get_query_info_caching(self, mock_conn):
         mock_conn.return_value.get_query_execution.return_value = MOCK_QUERY_EXECUTION_OUTPUT

Reply via email to