This is an automated email from the ASF dual-hosted git repository.

vavila pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new fa346099523 feat: Support OAuth2 single-use refresh tokens (#38364)
fa346099523 is described below

commit fa3460995239786ca96351ba5f03c964f894c698
Author: Vitor Avila <[email protected]>
AuthorDate: Tue Mar 3 16:07:15 2026 -0300

    feat: Support OAuth2 single-use refresh tokens (#38364)
---
 superset/db_engine_specs/base.py              |  6 ++
 superset/utils/oauth2.py                      |  4 ++
 tests/unit_tests/db_engine_specs/test_base.py | 91 +++++++++++++++++++++++++++
 tests/unit_tests/utils/oauth2_tests.py        | 56 +++++++++++++++++
 4 files changed, 157 insertions(+)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index fb0e26e77e7..965eec46b12 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -572,6 +572,10 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
     oauth2_token_request_uri: str | None = None
     oauth2_token_request_type = "data"  # noqa: S105
 
+    # Driver-specific query params to be included in `get_oauth2_authorization_uri`
+    oauth2_additional_auth_uri_query_params: dict[str, Any] = {}
+    # Driver-specific params to be included in the `get_oauth2_token` request body
+    oauth2_additional_token_request_params: dict[str, Any] = {}
     # Driver-specific exception that should be mapped to OAuth2RedirectError
     oauth2_exception = OAuth2RedirectError
 
@@ -754,6 +758,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             "state": encode_oauth2_state(state),
             "redirect_uri": config["redirect_uri"],
             "client_id": config["id"],
+            **cls.oauth2_additional_auth_uri_query_params,
         }
 
         # Add PKCE parameters (RFC 7636) if code_verifier is provided
@@ -784,6 +789,7 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             "client_secret": config["secret"],
             "redirect_uri": config["redirect_uri"],
             "grant_type": "authorization_code",
+            **cls.oauth2_additional_token_request_params,
         }
         # Add PKCE code_verifier if present (RFC 7636)
         if code_verifier:
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index 57cc0a25ce9..4978c0af5c5 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -167,6 +167,10 @@ def refresh_oauth2_token(
         token.access_token_expiration = datetime.now() + timedelta(
             seconds=token_response["expires_in"]
         )
+        # Support single-use refresh tokens
+        if new_refresh_token := token_response.get("refresh_token"):
+            token.refresh_token = new_refresh_token
+
         db.session.add(token)
 
     return token.access_token
diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py
index 6c6b98e0593..14fa82b55af 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -1052,6 +1052,97 @@ def test_get_oauth2_token_with_pkce(mocker: MockerFixture) -> None:
     assert request_body["code_verifier"] == code_verifier
 
 
+def test_get_oauth2_authorization_uri_additional_params(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that a subclass can inject additional query params into the authorization URI
+    via `oauth2_additional_auth_uri_query_params`.
+    """
+    from superset.db_engine_specs.base import BaseEngineSpec
+
+    class CustomEngineSpec(BaseEngineSpec):
+        oauth2_additional_auth_uri_query_params = {
+            "prompt": "consent",
+            "access_type": "offline",
+        }
+
+    config: OAuth2ClientConfig = {
+        "id": "client-id",
+        "secret": "client-secret",
+        "scope": "read write",
+        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
+        "authorization_request_uri": "https://oauth.example.com/authorize",
+        "token_request_uri": "https://oauth.example.com/token",
+        "request_content_type": "json",
+    }
+
+    state: OAuth2State = {
+        "database_id": 1,
+        "user_id": 1,
+        "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
+        "tab_id": "1234",
+    }
+
+    url = CustomEngineSpec.get_oauth2_authorization_uri(config, state)
+    parsed = urlparse(url)
+    query = parse_qs(parsed.query)
+
+    # Standard params still present
+    assert query["response_type"][0] == "code"
+    assert query["client_id"][0] == "client-id"
+
+    # Additional params included
+    assert query["prompt"][0] == "consent"
+    assert query["access_type"][0] == "offline"
+
+
+def test_get_oauth2_token_additional_params(mocker: MockerFixture) -> None:
+    """
+    Test that a subclass can inject additional params into the token request body
+    via `oauth2_additional_token_request_params`.
+    """
+    from superset.db_engine_specs.base import BaseEngineSpec
+
+    class CustomEngineSpec(BaseEngineSpec):
+        oauth2_additional_token_request_params = {
+            "audience": "https://api.example.com",
+        }
+
+    mocker.patch(
+        "flask.current_app.config",
+        {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)},
+    )
+    mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
+    mock_post.return_value.json.return_value = {
+        "access_token": "test-access-token",  # noqa: S105
+        "expires_in": 3600,
+    }
+
+    config: OAuth2ClientConfig = {
+        "id": "client-id",
+        "secret": "client-secret",
+        "scope": "read write",
+        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
+        "authorization_request_uri": "https://oauth.example.com/authorize",
+        "token_request_uri": "https://oauth.example.com/token",
+        "request_content_type": "json",
+    }
+
+    result = CustomEngineSpec.get_oauth2_token(config, "auth-code")
+
+    assert result["access_token"] == "test-access-token"  # noqa: S105
+    call_kwargs = mock_post.call_args
+    request_body = call_kwargs.kwargs.get("json") or call_kwargs.kwargs.get("data")
+
+    # Standard params still present
+    assert request_body["grant_type"] == "authorization_code"
+    assert request_body["client_id"] == "client-id"
+
+    # Additional param included
+    assert request_body["audience"] == "https://api.example.com"
+
+
 def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> None:
     """
     Test that start_oauth2_dance uses DATABASE_OAUTH2_REDIRECT_URI config if set.
diff --git a/tests/unit_tests/utils/oauth2_tests.py b/tests/unit_tests/utils/oauth2_tests.py
index 08b7cc9c6e7..f04ae26e7c2 100644
--- a/tests/unit_tests/utils/oauth2_tests.py
+++ b/tests/unit_tests/utils/oauth2_tests.py
@@ -188,6 +188,62 @@ def test_refresh_oauth2_token_no_access_token_in_response(
     assert result is None
 
 
+def test_refresh_oauth2_token_updates_refresh_token(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that refresh_oauth2_token updates the refresh token when a new one is returned.
+
+    Some OAuth2 providers issue single-use refresh tokens, where each token refresh
+    response includes a new refresh token that replaces the previous one.
+    """
+    db = mocker.patch("superset.utils.oauth2.db")
+    mocker.patch("superset.utils.oauth2.DistributedLock")
+    db_engine_spec = mocker.MagicMock()
+    db_engine_spec.get_oauth2_fresh_token.return_value = {
+        "access_token": "new-access-token",
+        "expires_in": 3600,
+        "refresh_token": "new-refresh-token",
+    }
+    token = mocker.MagicMock()
+    token.refresh_token = "old-refresh-token"  # noqa: S105
+
+    with freeze_time("2024-01-01"):
+        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
+
+    assert token.access_token == "new-access-token"  # noqa: S105
+    assert token.access_token_expiration == datetime(2024, 1, 1, 1)
+    assert token.refresh_token == "new-refresh-token"  # noqa: S105
+    db.session.add.assert_called_with(token)
+
+
+def test_refresh_oauth2_token_keeps_refresh_token(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that refresh_oauth2_token keeps the existing refresh token when none returned.
+
+    When the OAuth2 provider does not issue a new refresh token in the response,
+    the original refresh token should be preserved.
+    """
+    db = mocker.patch("superset.utils.oauth2.db")
+    mocker.patch("superset.utils.oauth2.DistributedLock")
+    db_engine_spec = mocker.MagicMock()
+    db_engine_spec.get_oauth2_fresh_token.return_value = {
+        "access_token": "new-access-token",
+        "expires_in": 3600,
+    }
+    token = mocker.MagicMock()
+    token.refresh_token = "original-refresh-token"  # noqa: S105
+
+    with freeze_time("2024-01-01"):
+        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
+
+    assert token.access_token == "new-access-token"  # noqa: S105
+    assert token.refresh_token == "original-refresh-token"  # noqa: S105
+    db.session.add.assert_called_with(token)
+
+
 def test_generate_code_verifier_length() -> None:
     """
     Test that generate_code_verifier produces a string of valid length (RFC 7636).

Reply via email to