This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 103df61bde6 Fix the way to get STS endpoint in EKS hook (#45520)
103df61bde6 is described below
commit 103df61bde6e98de2466f42e76b4ac3bcc4ab8b5
Author: Vincent <[email protected]>
AuthorDate: Thu Jan 9 13:20:12 2025 -0500
Fix the way to get STS endpoint in EKS hook (#45520)
---
providers/src/airflow/providers/amazon/aws/hooks/eks.py | 4 ++--
providers/tests/amazon/aws/hooks/test_eks.py | 8 +++-----
2 files changed, 5 insertions(+), 7 deletions(-)
diff --git a/providers/src/airflow/providers/amazon/aws/hooks/eks.py b/providers/src/airflow/providers/amazon/aws/hooks/eks.py
index d48c103505a..4e8c0ca7ad6 100644
--- a/providers/src/airflow/providers/amazon/aws/hooks/eks.py
+++ b/providers/src/airflow/providers/amazon/aws/hooks/eks.py
@@ -32,6 +32,7 @@ from botocore.exceptions import ClientError
from botocore.signers import RequestSigner
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.hooks.sts import StsHook
from airflow.utils import yaml
from airflow.utils.json import AirflowJsonEncoder
@@ -612,8 +613,7 @@ class EksHook(AwsBaseHook):
def fetch_access_token_for_cluster(self, eks_cluster_name: str) -> str:
session = self.get_session()
service_id = self.conn.meta.service_model.service_id
- sts_client = session.client("sts")
- sts_url =
f"{sts_client.meta.endpoint_url}/?Action=GetCallerIdentity&Version=2011-06-15"
+ sts_url =
f"{StsHook().conn_client_meta.endpoint_url}/?Action=GetCallerIdentity&Version=2011-06-15"
signer = RequestSigner(
service_id=service_id,
diff --git a/providers/tests/amazon/aws/hooks/test_eks.py b/providers/tests/amazon/aws/hooks/test_eks.py
index 06cc7ddab53..10a93790ac6 100644
--- a/providers/tests/amazon/aws/hooks/test_eks.py
+++ b/providers/tests/amazon/aws/hooks/test_eks.py
@@ -22,7 +22,6 @@ from copy import deepcopy
from pathlib import Path
from typing import TYPE_CHECKING
from unittest import mock
-from unittest.mock import Mock
from urllib.parse import urlsplit
import pytest
@@ -1284,14 +1283,13 @@ class TestEksHook:
}
@mock.patch("airflow.providers.amazon.aws.hooks.eks.RequestSigner")
+ @mock.patch("airflow.providers.amazon.aws.hooks.eks.StsHook")
@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.conn")
@mock.patch("airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook.get_session")
- def test_fetch_access_token_for_cluster(self, mock_get_session, mock_conn,
mock_signer):
+ def test_fetch_access_token_for_cluster(self, mock_get_session, mock_conn,
mock_sts_hook, mock_signer):
mock_signer.return_value.generate_presigned_url.return_value =
"http://example.com"
mock_get_session.return_value.region_name = "us-east-1"
- client = Mock()
- client.meta.endpoint_url = "https://sts.us-east-1.amazonaws.com"
- mock_get_session.return_value.client.return_value = client
+ mock_sts_hook.return_value.conn_client_meta.endpoint_url =
"https://sts.us-east-1.amazonaws.com"
hook = EksHook()
token =
hook.fetch_access_token_for_cluster(eks_cluster_name="test-cluster")
mock_signer.assert_called_once_with(