This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 7ff3364bf feat(csharp/test/Drivers/Databricks): Add mandatory token
exchange (#3192)
7ff3364bf is described below
commit 7ff3364bf449e7496e4b32a0a6f7aa499c75a4fc
Author: Alex Guo <[email protected]>
AuthorDate: Thu Jul 31 13:08:38 2025 -0700
feat(csharp/test/Drivers/Databricks): Add mandatory token exchange (#3192)
## Motivation
Databricks will eventually require that all non-inhouse OAuth tokens be
exchanged for Databricks OAuth tokens before accessing resources. This
change implements mandatory token exchange before sending Thrift
requests. For now, the check and the exchange are performed in the
background to reduce latency, but they will eventually need to become
blocking once non-inhouse OAuth tokens can no longer access Databricks
resources.
## Key Components
1. JWT Token Decoder - Decodes JWT tokens to inspect the issuer claim
and determine if token exchange is necessary
2. MandatoryTokenExchangeDelegatingHandler - HTTP handler that
intercepts requests and performs token exchange when required
3. TokenExchangeClient - Handles the token exchange logic with the same
/oidc/v1/token endpoint as token refresh, with slightly different
parameters
## Changes
- Added new connection string parameter: IdentityFederationClientId for
service principal workload identity federation scenarios
- Implemented token exchange logic that checks JWT issuer against
workspace host
- Introduced fallback behavior to maintain backward compatibility if
token exchange fails
## Testing
`dotnet test --filter
"FullyQualifiedName~MandatoryTokenExchangeDelegatingHandlerTests"`
```
[xUnit.net 00:00:00.00] xUnit.net VSTest Adapter v3.1.1+bf6400fd51 (64-bit
.NET 8.0.7)
[xUnit.net 00:00:00.06] Discovering:
Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:00.15] Discovered:
Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:00.16] Starting:
Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:01.77] Finished:
Apache.Arrow.Adbc.Tests.Drivers.Databricks
Apache.Arrow.Adbc.Tests.Drivers.Databricks test net8.0 succeeded (2.6s)
Test summary: total: 11, failed: 0, succeeded: 11, skipped: 0, duration:
2.6s
```
`dotnet test --filter "FullyQualifiedName~TokenExchangeClientTests"`
```
[xUnit.net 00:00:00.00] xUnit.net VSTest Adapter v3.1.1+bf6400fd51 (64-bit
.NET 8.0.7)
[xUnit.net 00:00:00.06] Discovering:
Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:00.14] Discovered:
Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:00.15] Starting:
Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:00.23] Finished:
Apache.Arrow.Adbc.Tests.Drivers.Databricks
Apache.Arrow.Adbc.Tests.Drivers.Databricks test net8.0 succeeded (0.8s)
Test summary: total: 19, failed: 0, succeeded: 19, skipped: 0, duration:
0.8s
```
`dotnet test --filter "FullyQualifiedName~JwtTokenDecoderTests"`
```
[xUnit.net 00:00:00.00] xUnit.net VSTest Adapter v3.1.1+bf6400fd51 (64-bit
.NET 8.0.7)
[xUnit.net 00:00:00.06] Discovering:
Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:00.14] Discovered:
Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:00.15] Starting:
Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:00:00.19] Finished:
Apache.Arrow.Adbc.Tests.Drivers.Databricks
Apache.Arrow.Adbc.Tests.Drivers.Databricks test net8.0 succeeded (0.8s)
Test summary: total: 10, failed: 0, succeeded: 10, skipped: 0, duration:
0.8s
```
Also tested E2E manually with AAD tokens for Azure Databricks
workspaces, AAD tokens for AWS Databricks workspaces, and workload
identity federation tokens for service principals.
---
.../src/Drivers/Databricks/Auth/JwtTokenDecoder.cs | 35 ++
... => MandatoryTokenExchangeDelegatingHandler.cs} | 132 ++---
.../Drivers/Databricks/Auth/TokenExchangeClient.cs | 70 ++-
...Handler.cs => TokenRefreshDelegatingHandler.cs} | 8 +-
.../src/Drivers/Databricks/DatabricksConnection.cs | 111 ++--
.../src/Drivers/Databricks/DatabricksParameters.cs | 6 +
.../Databricks/E2E/Auth/TokenExchangeTests.cs | 6 +-
.../Databricks/Unit/Auth/JwtTokenDecoderTests.cs | 105 +++-
...MandatoryTokenExchangeDelegatingHandlerTests.cs | 578 +++++++++++++++++++++
.../Unit/Auth/TokenExchangeClientTests.cs | 285 ++++++++--
...ts.cs => TokenRefreshDelegatingHandlerTests.cs} | 84 +--
11 files changed, 1210 insertions(+), 210 deletions(-)
diff --git a/csharp/src/Drivers/Databricks/Auth/JwtTokenDecoder.cs
b/csharp/src/Drivers/Databricks/Auth/JwtTokenDecoder.cs
index ed68aeeb8..acb21b3ea 100644
--- a/csharp/src/Drivers/Databricks/Auth/JwtTokenDecoder.cs
+++ b/csharp/src/Drivers/Databricks/Auth/JwtTokenDecoder.cs
@@ -69,6 +69,41 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
}
}
+ /// <summary>
+ /// Tries to extract the issuer (iss) claim from a JWT token.
+ /// </summary>
+ /// <param name="token">The JWT token to parse.</param>
+ /// <param name="issuer">The extracted issuer, if successful.</param>
+ /// <returns>True if the issuer was successfully extracted, false
otherwise.</returns>
+ public static bool TryGetIssuer(string token, out string issuer)
+ {
+ issuer = string.Empty;
+
+ try
+ {
+ string[] parts = token.Split('.');
+ if (parts.Length != 3)
+ {
+ return false;
+ }
+
+ string payload = DecodeBase64Url(parts[1]);
+ using JsonDocument jsonDoc = JsonDocument.Parse(payload);
+
+ if (!jsonDoc.RootElement.TryGetProperty("iss", out JsonElement
issElement))
+ {
+ return false;
+ }
+
+ issuer = issElement.GetString() ?? string.Empty;
+ return !string.IsNullOrEmpty(issuer);
+ }
+ catch
+ {
+ return false;
+ }
+ }
+
/// <summary>
/// Decodes a base64url encoded string to a regular string.
/// </summary>
diff --git
a/csharp/src/Drivers/Databricks/Auth/TokenExchangeDelegatingHandler.cs
b/csharp/src/Drivers/Databricks/Auth/MandatoryTokenExchangeDelegatingHandler.cs
similarity index 51%
copy from csharp/src/Drivers/Databricks/Auth/TokenExchangeDelegatingHandler.cs
copy to
csharp/src/Drivers/Databricks/Auth/MandatoryTokenExchangeDelegatingHandler.cs
index 2748dc951..edda03f53 100644
--- a/csharp/src/Drivers/Databricks/Auth/TokenExchangeDelegatingHandler.cs
+++
b/csharp/src/Drivers/Databricks/Auth/MandatoryTokenExchangeDelegatingHandler.cs
@@ -24,109 +24,118 @@ using System.Threading.Tasks;
namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
{
/// <summary>
- /// HTTP message handler that automatically refreshes OAuth tokens before
they expire.
- /// Uses a non-blocking approach to refresh tokens in the background.
+ /// HTTP message handler that performs mandatory token exchange for
non-Databricks tokens.
+ /// Uses a non-blocking approach to exchange tokens in the background.
/// </summary>
- internal class TokenExchangeDelegatingHandler : DelegatingHandler
+ internal class MandatoryTokenExchangeDelegatingHandler : DelegatingHandler
{
- private readonly string _initialToken;
- private readonly int _tokenRenewLimitMinutes;
+ private readonly string? _identityFederationClientId;
private readonly object _tokenLock = new object();
private readonly ITokenExchangeClient _tokenExchangeClient;
+ private string? _currentToken;
+ private string? _lastSeenToken;
- private string _currentToken;
- private DateTime _tokenExpiryTime;
- private bool _tokenExchangeAttempted = false;
- private Task? _pendingTokenTask = null;
+ protected Task? _pendingTokenTask = null;
/// <summary>
- /// Initializes a new instance of the <see
cref="TokenExchangeDelegatingHandler"/> class.
+ /// Initializes a new instance of the <see
cref="MandatoryTokenExchangeDelegatingHandler"/> class.
/// </summary>
/// <param name="innerHandler">The inner handler to delegate
to.</param>
/// <param name="tokenExchangeClient">The client for token exchange
operations.</param>
- /// <param name="initialToken">The initial token from the connection
string.</param>
- /// <param name="tokenExpiryTime">The expiry time of the initial
token.</param>
- /// <param name="tokenRenewLimitMinutes">The minutes before token
expiration when we should start renewing the token.</param>
- public TokenExchangeDelegatingHandler(
+ /// <param name="identityFederationClientId">Optional identity
federation client ID.</param>
+ public MandatoryTokenExchangeDelegatingHandler(
HttpMessageHandler innerHandler,
ITokenExchangeClient tokenExchangeClient,
- string initialToken,
- DateTime tokenExpiryTime,
- int tokenRenewLimitMinutes)
+ string? identityFederationClientId = null)
: base(innerHandler)
{
_tokenExchangeClient = tokenExchangeClient ?? throw new
ArgumentNullException(nameof(tokenExchangeClient));
- _initialToken = initialToken ?? throw new
ArgumentNullException(nameof(initialToken));
- _tokenExpiryTime = tokenExpiryTime;
- _tokenRenewLimitMinutes = tokenRenewLimitMinutes;
- _currentToken = initialToken;
+ _identityFederationClientId = identityFederationClientId;
}
/// <summary>
- /// Checks if the token needs to be renewed.
+ /// Determines if token exchange is needed by checking if the token is
a Databricks token.
/// </summary>
- /// <returns>True if the token needs to be renewed, false
otherwise.</returns>
- private bool NeedsTokenRenewal()
+ /// <returns>True if token exchange is needed, false
otherwise.</returns>
+ private bool NeedsTokenExchange(string bearerToken)
{
- // Only renew if:
- // 1. We haven't already attempted token exchange (a token can
only be renewed once)
- // 2. The token will expire within the renewal limit
- // 3. We don't already have a pending refresh task
- return !_tokenExchangeAttempted &&
- DateTime.UtcNow.AddMinutes(_tokenRenewLimitMinutes) >=
_tokenExpiryTime &&
- _pendingTokenTask == null;
+ // If we already started exchange for this token, no need to check
again
+ if (_lastSeenToken == bearerToken)
+ {
+ return false;
+ }
+
+ // If we already have a pending token task, don't start another
exchange
+ if (_pendingTokenTask != null)
+ {
+ return false;
+ }
+
+ // If we can't parse the token as JWT, default to use existing
token
+ if (!JwtTokenDecoder.TryGetIssuer(bearerToken, out string issuer))
+ {
+ return false;
+ }
+
+ // Check if the issuer matches the current workspace host
+ // If the issuer is from the same host, it's already a Databricks
token
+ string normalizedHost =
_tokenExchangeClient.TokenExchangeEndpoint.Replace("/v1/token",
"").ToLowerInvariant();
+ string normalizedIssuer = issuer.TrimEnd('/').ToLowerInvariant();
+
+ return normalizedIssuer != normalizedHost;
}
/// <summary>
- /// Starts token renewal in the background if needed.
+ /// Starts token exchange in the background if needed.
/// </summary>
+ /// <param name="bearerToken">The bearer token to potentially
exchange.</param>
/// <param name="cancellationToken">A cancellation token.</param>
- private void StartTokenRenewalIfNeeded(CancellationToken
cancellationToken)
+ private void StartTokenExchangeIfNeeded(string bearerToken,
CancellationToken cancellationToken)
{
- if (!NeedsTokenRenewal())
+ if (_lastSeenToken == bearerToken)
{
return;
}
- bool needsRenewal;
+ bool needsExchange;
lock (_tokenLock)
{
- // Double-check pattern in case another thread renewed while
we were waiting
- needsRenewal = NeedsTokenRenewal();
- if (needsRenewal)
- {
- // Mark that we've attempted token exchange to prevent
multiple attempts
- // Specifically, NeedsTokenRenewal checks this flag
- _tokenExchangeAttempted = true;
- }
+ needsExchange = NeedsTokenExchange(bearerToken);
+
+ _lastSeenToken = bearerToken;
}
- if (!needsRenewal)
+ if (!needsExchange)
{
return;
}
- // Start token refresh in the background
+ // Start token exchange in the background
_pendingTokenTask = Task.Run(async () =>
{
try
{
- TokenExchangeResponse response = await
_tokenExchangeClient.ExchangeTokenAsync(_initialToken, cancellationToken);
+ TokenExchangeResponse response = await
_tokenExchangeClient.ExchangeTokenAsync(
+ bearerToken,
+ _identityFederationClientId,
+ cancellationToken);
- // Update the token atomically when ready
lock (_tokenLock)
{
_currentToken = response.AccessToken;
- _tokenExpiryTime = response.ExpiryTime;
}
}
catch (Exception ex)
{
- // Log the error but continue with the current token
- // This is to avoid interrupting the operation if token
exchange fails
- System.Diagnostics.Debug.WriteLine($"Token exchange
failed: {ex.Message}");
+ System.Diagnostics.Debug.WriteLine($"Mandatory token
exchange failed: {ex.Message}");
+ }
+ }, cancellationToken).ContinueWith(_ =>
+ {
+ lock (_tokenLock)
+ {
+ _pendingTokenTask = null;
}
- }, cancellationToken);
+ }, TaskScheduler.Default);
}
/// <summary>
@@ -137,16 +146,20 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
/// <returns>The HTTP response message.</returns>
protected override async Task<HttpResponseMessage>
SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
- StartTokenRenewalIfNeeded(cancellationToken);
-
- // Use the current token (which might be the old one while refresh
is in progress)
- string tokenToUse;
- lock (_tokenLock)
+ string? bearerToken = request.Headers.Authorization?.Parameter;
+ if (!string.IsNullOrEmpty(bearerToken))
{
- tokenToUse = _currentToken;
+ StartTokenExchangeIfNeeded(bearerToken!, cancellationToken);
+
+ string tokenToUse;
+ lock (_tokenLock)
+ {
+ tokenToUse = _currentToken ?? bearerToken!;
+ }
+
+ request.Headers.Authorization = new
AuthenticationHeaderValue("Bearer", tokenToUse);
}
- request.Headers.Authorization = new
AuthenticationHeaderValue("Bearer", tokenToUse);
return await base.SendAsync(request, cancellationToken);
}
@@ -164,6 +177,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
}
catch (Exception ex)
{
+ // Log any exceptions during disposal
System.Diagnostics.Debug.WriteLine($"Exception during
token task cleanup: {ex.Message}");
}
}
diff --git a/csharp/src/Drivers/Databricks/Auth/TokenExchangeClient.cs
b/csharp/src/Drivers/Databricks/Auth/TokenExchangeClient.cs
index 5246c16fa..0f28ce224 100644
--- a/csharp/src/Drivers/Databricks/Auth/TokenExchangeClient.cs
+++ b/csharp/src/Drivers/Databricks/Auth/TokenExchangeClient.cs
@@ -56,12 +56,26 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
internal interface ITokenExchangeClient
{
/// <summary>
- /// Exchanges the provided token for a new token.
+ /// Gets the token exchange endpoint URL.
+ /// </summary>
+ string TokenExchangeEndpoint { get; }
+
+ /// <summary>
+ /// Refreshes the provided token to extend the lifetime.
+ /// </summary>
+ /// <param name="token">The token to refresh.</param>
+ /// <param name="cancellationToken">A cancellation token.</param>
+ /// <returns>The response from the token exchange API.</returns>
+ Task<TokenExchangeResponse> RefreshTokenAsync(string token,
CancellationToken cancellationToken);
+
+ /// <summary>
+ /// Exchanges the provided token for a Databricks OAuth token.
/// </summary>
/// <param name="token">The token to exchange.</param>
+ /// <param name="identityFederationClientId">Optional identity
federation client ID.</param>
/// <param name="cancellationToken">A cancellation token.</param>
/// <returns>The response from the token exchange API.</returns>
- Task<TokenExchangeResponse> ExchangeTokenAsync(string token,
CancellationToken cancellationToken);
+ Task<TokenExchangeResponse> ExchangeTokenAsync(string token, string?
identityFederationClientId, CancellationToken cancellationToken);
}
/// <summary>
@@ -72,6 +86,8 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
private readonly HttpClient _httpClient;
private readonly string _tokenExchangeEndpoint;
+ public string TokenExchangeEndpoint => _tokenExchangeEndpoint;
+
/// <summary>
/// Initializes a new instance of the <see
cref="TokenExchangeClient"/> class.
/// </summary>
@@ -93,12 +109,12 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
}
/// <summary>
- /// Exchanges the provided token for a new token.
+ /// Refreshes the provided token to extend the lifetime.
/// </summary>
- /// <param name="token">The token to exchange.</param>
+ /// <param name="token">The token to refresh.</param>
/// <param name="cancellationToken">A cancellation token.</param>
/// <returns>The response from the token exchange API.</returns>
- public async Task<TokenExchangeResponse> ExchangeTokenAsync(string
token, CancellationToken cancellationToken)
+ public async Task<TokenExchangeResponse> RefreshTokenAsync(string
token, CancellationToken cancellationToken)
{
var content = new FormUrlEncodedContent(new[]
{
@@ -120,6 +136,50 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
return ParseTokenResponse(responseContent);
}
+ /// <summary>
+ /// Exchanges the provided token for a Databricks OAuth token.
+ /// </summary>
+ /// <param name="token">The token to exchange.</param>
+ /// <param name="identityFederationClientId">Optional identity
federation client ID.</param>
+ /// <param name="cancellationToken">A cancellation token.</param>
+ /// <returns>The response from the token exchange API.</returns>
+ public async Task<TokenExchangeResponse> ExchangeTokenAsync(
+ string token,
+ string? identityFederationClientId,
+ CancellationToken cancellationToken)
+ {
+ var formData = new List<KeyValuePair<string, string>>
+ {
+ new KeyValuePair<string, string>("grant_type",
"urn:ietf:params:oauth:grant-type:jwt-bearer"),
+ new KeyValuePair<string, string>("assertion", token),
+ new KeyValuePair<string, string>("scope", "sql")
+ };
+
+ if (!string.IsNullOrEmpty(identityFederationClientId))
+ {
+ formData.Add(new KeyValuePair<string,
string>("identity_federation_client_id", identityFederationClientId!));
+ }
+ else
+ {
+ formData.Add(new KeyValuePair<string,
string>("return_original_token_if_authenticated", "true"));
+ }
+
+ var content = new FormUrlEncodedContent(formData);
+
+ var request = new HttpRequestMessage(HttpMethod.Post,
_tokenExchangeEndpoint)
+ {
+ Content = content
+ };
+ request.Headers.Accept.Add(new
System.Net.Http.Headers.MediaTypeWithQualityHeaderValue("*/*"));
+
+ HttpResponseMessage response = await
_httpClient.SendAsync(request, cancellationToken);
+
+ response.EnsureSuccessStatusCode();
+
+ string responseContent = await
response.Content.ReadAsStringAsync();
+ return ParseTokenResponse(responseContent);
+ }
+
/// <summary>
/// Parses the token exchange API response.
/// </summary>
diff --git
a/csharp/src/Drivers/Databricks/Auth/TokenExchangeDelegatingHandler.cs
b/csharp/src/Drivers/Databricks/Auth/TokenRefreshDelegatingHandler.cs
similarity index 96%
rename from csharp/src/Drivers/Databricks/Auth/TokenExchangeDelegatingHandler.cs
rename to csharp/src/Drivers/Databricks/Auth/TokenRefreshDelegatingHandler.cs
index 2748dc951..7608478a0 100644
--- a/csharp/src/Drivers/Databricks/Auth/TokenExchangeDelegatingHandler.cs
+++ b/csharp/src/Drivers/Databricks/Auth/TokenRefreshDelegatingHandler.cs
@@ -27,7 +27,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
/// HTTP message handler that automatically refreshes OAuth tokens before
they expire.
/// Uses a non-blocking approach to refresh tokens in the background.
/// </summary>
- internal class TokenExchangeDelegatingHandler : DelegatingHandler
+ internal class TokenRefreshDelegatingHandler : DelegatingHandler
{
private readonly string _initialToken;
private readonly int _tokenRenewLimitMinutes;
@@ -40,14 +40,14 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
private Task? _pendingTokenTask = null;
/// <summary>
- /// Initializes a new instance of the <see
cref="TokenExchangeDelegatingHandler"/> class.
+ /// Initializes a new instance of the <see
cref="TokenRefreshDelegatingHandler"/> class.
/// </summary>
/// <param name="innerHandler">The inner handler to delegate
to.</param>
/// <param name="tokenExchangeClient">The client for token exchange
operations.</param>
/// <param name="initialToken">The initial token from the connection
string.</param>
/// <param name="tokenExpiryTime">The expiry time of the initial
token.</param>
/// <param name="tokenRenewLimitMinutes">The minutes before token
expiration when we should start renewing the token.</param>
- public TokenExchangeDelegatingHandler(
+ public TokenRefreshDelegatingHandler(
HttpMessageHandler innerHandler,
ITokenExchangeClient tokenExchangeClient,
string initialToken,
@@ -111,7 +111,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Auth
{
try
{
- TokenExchangeResponse response = await
_tokenExchangeClient.ExchangeTokenAsync(_initialToken, cancellationToken);
+ TokenExchangeResponse response = await
_tokenExchangeClient.RefreshTokenAsync(_initialToken, cancellationToken);
// Update the token atomically when ready
lock (_tokenLock)
diff --git a/csharp/src/Drivers/Databricks/DatabricksConnection.cs
b/csharp/src/Drivers/Databricks/DatabricksConnection.cs
index b5c11588f..14a0eef08 100644
--- a/csharp/src/Drivers/Databricks/DatabricksConnection.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksConnection.cs
@@ -66,6 +66,9 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
private string _traceParentHeaderName = "traceparent";
private bool _traceStateEnabled = false;
+ // Identity federation client ID for token exchange
+ private string? _identityFederationClientId;
+
// Default namespace
private TNamespace? _defaultNamespace;
@@ -273,6 +276,11 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
// Default QueryTimeSeconds in Hive2Connection is only 60s,
which is too small for lots of long running query
QueryTimeoutSeconds = DefaultQueryTimeSeconds;
}
+
+ if
(Properties.TryGetValue(DatabricksParameters.IdentityFederationClientId, out
string? identityFederationClientId))
+ {
+ _identityFederationClientId = identityFederationClientId;
+ }
}
/// <summary>
@@ -354,61 +362,64 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
baseAuthHandler = new RetryHttpHandler(baseAuthHandler,
TemporarilyUnavailableRetryTimeout);
}
- Debug.Assert(_authHttpClient == null, "Auth HttpClient should not
be initialized yet.");
- _authHttpClient = new HttpClient(baseAuthHandler);
-
- // Add OAuth client credentials handler if OAuth M2M
authentication is being used
if (Properties.TryGetValue(SparkParameters.AuthType, out string?
authType) &&
SparkAuthTypeParser.TryParse(authType, out SparkAuthType
authTypeValue) &&
- authTypeValue == SparkAuthType.OAuth &&
- Properties.TryGetValue(DatabricksParameters.OAuthGrantType,
out string? grantTypeStr) &&
- DatabricksOAuthGrantTypeParser.TryParse(grantTypeStr, out
DatabricksOAuthGrantType grantType) &&
- grantType == DatabricksOAuthGrantType.ClientCredentials)
+ authTypeValue == SparkAuthType.OAuth)
{
- string host = GetHost();
-
- Properties.TryGetValue(DatabricksParameters.OAuthClientId, out
string? clientId);
- Properties.TryGetValue(DatabricksParameters.OAuthClientSecret,
out string? clientSecret);
- Properties.TryGetValue(DatabricksParameters.OAuthScope, out
string? scope);
-
- var tokenProvider = new OAuthClientCredentialsProvider(
- _authHttpClient,
- clientId!,
- clientSecret!,
- host!,
- scope: scope ?? "sql",
- timeoutMinutes: 1
- );
+ Debug.Assert(_authHttpClient == null, "Auth HttpClient should
not be initialized yet.");
+ _authHttpClient = new HttpClient(baseAuthHandler);
- baseHandler = new OAuthDelegatingHandler(baseHandler,
tokenProvider);
- }
- // Add token exchange handler if token renewal is enabled and the
auth type is OAuth access token
- else if
(Properties.TryGetValue(DatabricksParameters.TokenRenewLimit, out string?
tokenRenewLimitStr) &&
- int.TryParse(tokenRenewLimitStr, out int tokenRenewLimit) &&
- tokenRenewLimit > 0 &&
- Properties.TryGetValue(SparkParameters.AuthType, out string?
authTypeForToken) &&
- SparkAuthTypeParser.TryParse(authTypeForToken, out
SparkAuthType authTypeValueForToken) &&
- authTypeValueForToken == SparkAuthType.OAuth &&
- Properties.TryGetValue(SparkParameters.AccessToken, out
string? accessToken))
- {
- if (string.IsNullOrEmpty(accessToken))
- {
- throw new ArgumentException("Access token is required for
OAuth authentication with token renewal.");
- }
-
- // Check if token is a JWT token by trying to decode it
- if (JwtTokenDecoder.TryGetExpirationTime(accessToken, out
DateTime expiryTime))
+ string host = GetHost();
+ ITokenExchangeClient tokenExchangeClient = new
TokenExchangeClient(_authHttpClient, host);
+
+ // Mandatory token exchange should be the inner handler so
that it happens
+ // AFTER the OAuth handlers (e.g. after M2M sets the access
token)
+ baseHandler = new MandatoryTokenExchangeDelegatingHandler(
+ baseHandler,
+ tokenExchangeClient,
+ _identityFederationClientId);
+
+ // Add OAuth client credentials handler if OAuth M2M
authentication is being used
+ if
(Properties.TryGetValue(DatabricksParameters.OAuthGrantType, out string?
grantTypeStr) &&
+ DatabricksOAuthGrantTypeParser.TryParse(grantTypeStr, out
DatabricksOAuthGrantType grantType) &&
+ grantType == DatabricksOAuthGrantType.ClientCredentials)
{
- string host = GetHost();
-
- var tokenExchangeClient = new
TokenExchangeClient(_authHttpClient, host);
-
- baseHandler = new TokenExchangeDelegatingHandler(
- baseHandler,
- tokenExchangeClient,
- accessToken,
- expiryTime,
- tokenRenewLimit);
+ Properties.TryGetValue(DatabricksParameters.OAuthClientId,
out string? clientId);
+
Properties.TryGetValue(DatabricksParameters.OAuthClientSecret, out string?
clientSecret);
+ Properties.TryGetValue(DatabricksParameters.OAuthScope,
out string? scope);
+
+ var tokenProvider = new OAuthClientCredentialsProvider(
+ _authHttpClient,
+ clientId!,
+ clientSecret!,
+ host!,
+ scope: scope ?? "sql",
+ timeoutMinutes: 1
+ );
+
+ baseHandler = new OAuthDelegatingHandler(baseHandler,
tokenProvider);
+ }
+ // Add token renewal handler for OAuth access token
+ else if
(Properties.TryGetValue(DatabricksParameters.TokenRenewLimit, out string?
tokenRenewLimitStr) &&
+ int.TryParse(tokenRenewLimitStr, out int tokenRenewLimit)
&&
+ tokenRenewLimit > 0 &&
+ Properties.TryGetValue(SparkParameters.AccessToken, out
string? accessToken))
+ {
+ if (string.IsNullOrEmpty(accessToken))
+ {
+ throw new ArgumentException("Access token is required
for OAuth authentication with token renewal.");
+ }
+
+ // Check if token is a JWT token by trying to decode it
+ if (JwtTokenDecoder.TryGetExpirationTime(accessToken, out
DateTime expiryTime))
+ {
+ baseHandler = new TokenRefreshDelegatingHandler(
+ baseHandler,
+ tokenExchangeClient,
+ accessToken,
+ expiryTime,
+ tokenRenewLimit);
+ }
}
}
diff --git a/csharp/src/Drivers/Databricks/DatabricksParameters.cs
b/csharp/src/Drivers/Databricks/DatabricksParameters.cs
index 56030b4e3..5db14314d 100644
--- a/csharp/src/Drivers/Databricks/DatabricksParameters.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksParameters.cs
@@ -212,6 +212,12 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
/// Default value is 0 (disabled) if not specified.
/// </summary>
public const string TokenRenewLimit =
"adbc.databricks.token_renew_limit";
+
+ /// <summary>
+ /// The client ID of the service principal when using workload
identity federation.
+ /// Default value is empty if not specified.
+ /// </summary>
+ public const string IdentityFederationClientId =
"adbc.databricks.identity_federation_client_id";
}
/// <summary>
diff --git a/csharp/test/Drivers/Databricks/E2E/Auth/TokenExchangeTests.cs
b/csharp/test/Drivers/Databricks/E2E/Auth/TokenExchangeTests.cs
index c574abe0c..12208712f 100644
--- a/csharp/test/Drivers/Databricks/E2E/Auth/TokenExchangeTests.cs
+++ b/csharp/test/Drivers/Databricks/E2E/Auth/TokenExchangeTests.cs
@@ -92,7 +92,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Auth
string host = GetHost();
var tokenExchangeClient = new TokenExchangeClient(_httpClient,
host);
- var response = await
tokenExchangeClient.ExchangeTokenAsync(TestConfiguration.AccessToken,
CancellationToken.None);
+ var response = await
tokenExchangeClient.RefreshTokenAsync(TestConfiguration.AccessToken,
CancellationToken.None);
Assert.NotNull(response);
Assert.NotEmpty(response.AccessToken);
@@ -119,7 +119,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Auth
// Create a token capturing handler to intercept the actual tokens
being sent
var tokenCapturingHandler = new TokenCapturingHandler(new
HttpClientHandler());
- var handler = new TokenExchangeDelegatingHandler(
+ var handler = new TokenRefreshDelegatingHandler(
tokenCapturingHandler,
tokenExchangeClient,
TestConfiguration.AccessToken,
@@ -180,7 +180,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Auth
var tokenCapturingHandler = new TokenCapturingHandler(new
HttpClientHandler());
// Create a handler that should not refresh the token (token not
near expiry)
- var handler = new TokenExchangeDelegatingHandler(
+ var handler = new TokenRefreshDelegatingHandler(
tokenCapturingHandler,
tokenExchangeClient,
TestConfiguration.AccessToken,
diff --git a/csharp/test/Drivers/Databricks/Unit/Auth/JwtTokenDecoderTests.cs
b/csharp/test/Drivers/Databricks/Unit/Auth/JwtTokenDecoderTests.cs
index bd630ce49..8c7100e6b 100644
--- a/csharp/test/Drivers/Databricks/Unit/Auth/JwtTokenDecoderTests.cs
+++ b/csharp/test/Drivers/Databricks/Unit/Auth/JwtTokenDecoderTests.cs
@@ -16,6 +16,7 @@
*/
using System;
+using System.Collections.Generic;
using System.Text;
using System.Text.Json;
using Apache.Arrow.Adbc.Drivers.Databricks.Auth;
@@ -28,7 +29,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Unit.Auth
[Fact]
public void TryGetExpirationTime_ValidToken_ReturnsTrue()
{
- string token = CreateTestToken(DateTime.UtcNow.AddMinutes(30));
+ string token = CreateTestToken(expiryTime:
DateTime.UtcNow.AddMinutes(30));
bool result = JwtTokenDecoder.TryGetExpirationTime(token, out
DateTime expiryTime);
@@ -40,7 +41,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Unit.Auth
[Fact]
public void TryGetExpirationTime_ExpiredToken_ReturnsTrue()
{
- string token = CreateTestToken(DateTime.UtcNow.AddMinutes(-30));
+ string token = CreateTestToken(expiryTime:
DateTime.UtcNow.AddMinutes(-30));
bool result = JwtTokenDecoder.TryGetExpirationTime(token, out
DateTime expiryTime);
@@ -63,7 +64,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Unit.Auth
[Fact]
public void TryGetExpirationTime_MissingExpClaim_ReturnsFalse()
{
- string token = CreateTestTokenWithoutExpClaim();
+ string token = CreateTestToken(expiryTime: null);
bool result = JwtTokenDecoder.TryGetExpirationTime(token, out
DateTime expiryTime);
@@ -71,39 +72,95 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Unit.Auth
Assert.Equal(DateTime.MinValue, expiryTime);
}
- private string CreateTestToken(DateTime expiryTime)
+ [Fact]
+ public void TryGetIssuer_ValidToken_ReturnsTrue()
{
- // Create a simple JWT token with expiration claim
- var header = new { alg = "HS256", typ = "JWT" };
- var payload = new { exp =
((DateTimeOffset)expiryTime).ToUnixTimeSeconds() };
+ string expectedIssuer = "https://test.databricks.com/oidc";
+ string token = CreateTestToken(issuer: expectedIssuer);
- string headerJson = JsonSerializer.Serialize(header);
- string payloadJson = JsonSerializer.Serialize(payload);
+ bool result = JwtTokenDecoder.TryGetIssuer(token, out string
issuer);
- string headerBase64 =
Convert.ToBase64String(Encoding.UTF8.GetBytes(headerJson))
- .Replace('+', '-')
- .Replace('/', '_')
- .TrimEnd('=');
+ Assert.True(result);
+ Assert.Equal(expectedIssuer, issuer);
+ }
- string payloadBase64 =
Convert.ToBase64String(Encoding.UTF8.GetBytes(payloadJson))
- .Replace('+', '-')
- .Replace('/', '_')
- .TrimEnd('=');
+ [Fact]
+ public void TryGetIssuer_InvalidToken_ReturnsFalse()
+ {
+ string token = "invalid.token.format";
- // For testing purposes, we don't need a valid signature
- string signature = "signature";
+ bool result = JwtTokenDecoder.TryGetIssuer(token, out string
issuer);
- return $"{headerBase64}.{payloadBase64}.{signature}";
+ Assert.False(result);
+ Assert.Empty(issuer);
}
- private string CreateTestTokenWithoutExpClaim()
+ [Fact]
+ public void TryGetIssuer_MissingIssClaim_ReturnsFalse()
{
- // Create a simple JWT token without expiration claim
+ string token = CreateTestToken(issuer: null);
+
+ bool result = JwtTokenDecoder.TryGetIssuer(token, out string
issuer);
+
+ Assert.False(result);
+ Assert.Empty(issuer);
+ }
+
+ [Fact]
+ public void TryGetIssuer_EmptyIssuer_ReturnsFalse()
+ {
+ string token = CreateTestToken(issuer: "");
+
+ bool result = JwtTokenDecoder.TryGetIssuer(token, out string
issuer);
+
+ Assert.False(result);
+ Assert.Empty(issuer);
+ }
+
+ [Fact]
+ public void TryGetIssuer_TokenWithOnlyTwoParts_ReturnsFalse()
+ {
+ string token = "header.payload"; // Missing signature part
+
+ bool result = JwtTokenDecoder.TryGetIssuer(token, out string
issuer);
+
+ Assert.False(result);
+ Assert.Empty(issuer);
+ }
+
+ [Fact]
+ public void TryGetIssuer_MalformedBase64_ReturnsFalse()
+ {
+ string token = "invalid!base64.invalid!base64.signature";
+
+ bool result = JwtTokenDecoder.TryGetIssuer(token, out string
issuer);
+
+ Assert.False(result);
+ Assert.Empty(issuer);
+ }
+
+ private string CreateTestToken(
+ string? issuer = null,
+ DateTime? expiryTime = null)
+ {
+ // Create a simple JWT token with optional claims
var header = new { alg = "HS256", typ = "JWT" };
- var payload = new { sub = "test" };
+
+ // Build payload dynamically based on parameters
+ var payloadDict = new Dictionary<string, object> { { "sub", "test"
} };
+
+ if (issuer != null)
+ {
+ payloadDict["iss"] = issuer;
+ }
+
+ if (expiryTime.HasValue)
+ {
+ payloadDict["exp"] =
((DateTimeOffset)expiryTime.Value).ToUnixTimeSeconds();
+ }
string headerJson = JsonSerializer.Serialize(header);
- string payloadJson = JsonSerializer.Serialize(payload);
+ string payloadJson = JsonSerializer.Serialize(payloadDict);
string headerBase64 =
Convert.ToBase64String(Encoding.UTF8.GetBytes(headerJson))
.Replace('+', '-')
diff --git a/csharp/test/Drivers/Databricks/Unit/Auth/MandatoryTokenExchangeDelegatingHandlerTests.cs b/csharp/test/Drivers/Databricks/Unit/Auth/MandatoryTokenExchangeDelegatingHandlerTests.cs
new file mode 100644
index 000000000..13fe196e8
--- /dev/null
+++ b/csharp/test/Drivers/Databricks/Unit/Auth/MandatoryTokenExchangeDelegatingHandlerTests.cs
@@ -0,0 +1,578 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Net;
+using System.Net.Http;
+using System.Text;
+using System.Text.Json;
+using System.Threading;
+using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Databricks.Auth;
+using Moq;
+using Moq.Protected;
+using Xunit;
+
+namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Unit.Auth
+{
+ public class MandatoryTokenExchangeDelegatingHandlerTests : IDisposable
+ {
+ private readonly Mock<HttpMessageHandler> _mockInnerHandler;
+ private readonly Mock<ITokenExchangeClient> _mockTokenExchangeClient;
+ private readonly string _identityFederationClientId = "test-client-id";
+
+ // Real JWT tokens for testing (these are valid JWT structure but not real credentials)
+ private readonly string _databricksToken;
+ private readonly string _externalToken;
+ private readonly string _exchangedToken = "exchanged-databricks-token";
+
+ public MandatoryTokenExchangeDelegatingHandlerTests()
+ {
+ _mockInnerHandler = new Mock<HttpMessageHandler>();
+ _mockTokenExchangeClient = new Mock<ITokenExchangeClient>();
+
+ // Setup token exchange endpoint for host comparison
+ _mockTokenExchangeClient.Setup(x => x.TokenExchangeEndpoint)
+ .Returns("https://databricks-workspace.cloud.databricks.com/v1/token");
+
+ // Create real JWT tokens with proper issuers
+ _databricksToken =
CreateJwtToken("https://databricks-workspace.cloud.databricks.com",
DateTime.UtcNow.AddHours(1));
+ _externalToken = CreateJwtToken("https://external-provider.com",
DateTime.UtcNow.AddHours(1));
+ }
+
+ [Fact]
+ public void Constructor_WithValidParameters_InitializesCorrectly()
+ {
+ var handler = new MandatoryTokenExchangeDelegatingHandler(
+ _mockInnerHandler.Object,
+ _mockTokenExchangeClient.Object,
+ _identityFederationClientId);
+
+ Assert.NotNull(handler);
+ }
+
+ [Fact]
+ public void
Constructor_WithNullTokenExchangeClient_ThrowsArgumentNullException()
+ {
+ Assert.Throws<ArgumentNullException>(() => new
MandatoryTokenExchangeDelegatingHandler(
+ _mockInnerHandler.Object,
+ null!,
+ _identityFederationClientId));
+ }
+
+ [Fact]
+ public void
Constructor_WithoutIdentityFederationClientId_InitializesCorrectly()
+ {
+ var handler = new MandatoryTokenExchangeDelegatingHandler(
+ _mockInnerHandler.Object,
+ _mockTokenExchangeClient.Object);
+
+ Assert.NotNull(handler);
+ }
+
+ [Fact]
+ public async Task
SendAsync_WithDatabricksToken_UsesTokenDirectlyWithoutExchange()
+ {
+ var handler = new MandatoryTokenExchangeDelegatingHandler(
+ _mockInnerHandler.Object,
+ _mockTokenExchangeClient.Object,
+ _identityFederationClientId);
+
+ var request = new HttpRequestMessage(HttpMethod.Get,
"https://example.com");
+ request.Headers.Authorization = new
System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", _databricksToken);
+ var expectedResponse = new HttpResponseMessage(HttpStatusCode.OK);
+
+ HttpRequestMessage? capturedRequest = null;
+
+ _mockInnerHandler.Protected()
+ .Setup<Task<HttpResponseMessage>>(
+ "SendAsync",
+ ItExpr.IsAny<HttpRequestMessage>(),
+ ItExpr.IsAny<CancellationToken>())
+ .Callback<HttpRequestMessage, CancellationToken>((req, ct) =>
capturedRequest = req)
+ .ReturnsAsync(expectedResponse);
+
+ var httpClient = new HttpClient(handler);
+
+ var response = await httpClient.SendAsync(request);
+
+ Assert.Equal(expectedResponse, response);
+ Assert.NotNull(capturedRequest);
+ Assert.Equal("Bearer",
capturedRequest.Headers.Authorization?.Scheme);
+ Assert.Equal(_databricksToken,
capturedRequest.Headers.Authorization?.Parameter);
+
+ // Wait for any background tasks
+ await Task.Delay(1000);
+
+ // Verify no token exchange was attempted
+ _mockTokenExchangeClient.Verify(
+ x => x.ExchangeTokenAsync(It.IsAny<string>(),
It.IsAny<string>(), It.IsAny<CancellationToken>()),
+ Times.Never);
+ }
+
+ [Fact]
+ public async Task
SendAsync_WithExternalToken_StartsTokenExchangeInBackground()
+ {
+ var tokenExchangeDelay = TimeSpan.FromMilliseconds(500);
+ var handler = new MandatoryTokenExchangeDelegatingHandler(
+ _mockInnerHandler.Object,
+ _mockTokenExchangeClient.Object,
+ _identityFederationClientId);
+
+ var request = new HttpRequestMessage(HttpMethod.Get,
"https://example.com");
+ request.Headers.Authorization = new
System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", _externalToken);
+ var expectedResponse = new HttpResponseMessage(HttpStatusCode.OK);
+
+ var tokenExchangeResponse = new TokenExchangeResponse
+ {
+ AccessToken = _exchangedToken,
+ TokenType = "Bearer",
+ ExpiresIn = 3600,
+ ExpiryTime = DateTime.UtcNow.AddHours(1)
+ };
+
+ _mockTokenExchangeClient
+ .Setup(x => x.ExchangeTokenAsync(_externalToken,
_identityFederationClientId, It.IsAny<CancellationToken>()))
+ .Returns(async (string token, string clientId,
CancellationToken ct) =>
+ {
+ await Task.Delay(tokenExchangeDelay, ct);
+ return tokenExchangeResponse;
+ });
+
+ HttpRequestMessage? capturedRequest = null;
+
+ _mockInnerHandler.Protected()
+ .Setup<Task<HttpResponseMessage>>(
+ "SendAsync",
+ ItExpr.IsAny<HttpRequestMessage>(),
+ ItExpr.IsAny<CancellationToken>())
+ .Callback<HttpRequestMessage, CancellationToken>((req, ct) =>
capturedRequest = req)
+ .ReturnsAsync(expectedResponse);
+
+ var httpClient = new HttpClient(handler);
+
+ // First request should use original token and start background exchange
+ var startTime = DateTime.UtcNow;
+ var response = await httpClient.SendAsync(request);
+ var requestDuration = DateTime.UtcNow - startTime;
+
+ Assert.Equal(expectedResponse, response);
+ Assert.True(requestDuration < tokenExchangeDelay,
+ $"Request took {requestDuration.TotalMilliseconds}ms, which is longer than the token exchange delay of {tokenExchangeDelay.TotalMilliseconds}ms");
+
+ Assert.NotNull(capturedRequest);
+ Assert.Equal("Bearer",
capturedRequest.Headers.Authorization?.Scheme);
+ Assert.Equal(_externalToken, capturedRequest.Headers.Authorization?.Parameter); // First request uses original token
+
+ // Wait for background task to complete
+ await Task.Delay(tokenExchangeDelay +
TimeSpan.FromMilliseconds(1000));
+
+ // Make a second request - this should use the exchanged token
+ var request2 = new HttpRequestMessage(HttpMethod.Get,
"https://example.com/2");
+ request2.Headers.Authorization = new
System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", _externalToken);
+ HttpRequestMessage? capturedRequest2 = null;
+
+ _mockInnerHandler.Protected()
+ .Setup<Task<HttpResponseMessage>>(
+ "SendAsync",
+ ItExpr.Is<HttpRequestMessage>(r =>
r.RequestUri!.PathAndQuery == "/2"),
+ ItExpr.IsAny<CancellationToken>())
+ .Callback<HttpRequestMessage, CancellationToken>((req, ct) =>
capturedRequest2 = req)
+ .ReturnsAsync(new HttpResponseMessage(HttpStatusCode.OK));
+
+ await httpClient.SendAsync(request2);
+
+ Assert.NotNull(capturedRequest2);
+ Assert.Equal("Bearer",
capturedRequest2.Headers.Authorization?.Scheme);
+ Assert.Equal(_exchangedToken, capturedRequest2.Headers.Authorization?.Parameter); // Second request uses exchanged token
+
+ _mockTokenExchangeClient.Verify(
+ x => x.ExchangeTokenAsync(_externalToken,
_identityFederationClientId, It.IsAny<CancellationToken>()),
+ Times.Once);
+ }
+
+ [Fact]
+ public async Task
SendAsync_WithTokenExchangeFailure_ContinuesWithOriginalToken()
+ {
+ var handler = new MandatoryTokenExchangeDelegatingHandler(
+ _mockInnerHandler.Object,
+ _mockTokenExchangeClient.Object,
+ _identityFederationClientId);
+
+ var request = new HttpRequestMessage(HttpMethod.Get,
"https://example.com");
+ request.Headers.Authorization = new
System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", _externalToken);
+ var expectedResponse = new HttpResponseMessage(HttpStatusCode.OK);
+
+ // Setup token exchange to fail
+ _mockTokenExchangeClient
+ .Setup(x => x.ExchangeTokenAsync(_externalToken,
_identityFederationClientId, It.IsAny<CancellationToken>()))
+ .ThrowsAsync(new Exception("Token exchange failed"));
+
+ HttpRequestMessage? capturedRequest = null;
+
+ _mockInnerHandler.Protected()
+ .Setup<Task<HttpResponseMessage>>(
+ "SendAsync",
+ ItExpr.IsAny<HttpRequestMessage>(),
+ ItExpr.IsAny<CancellationToken>())
+ .Callback<HttpRequestMessage, CancellationToken>((req, ct) =>
capturedRequest = req)
+ .ReturnsAsync(expectedResponse);
+
+ var httpClient = new HttpClient(handler);
+ var response = await httpClient.SendAsync(request);
+
+ Assert.Equal(expectedResponse, response);
+ Assert.NotNull(capturedRequest);
+ Assert.Equal("Bearer",
capturedRequest.Headers.Authorization?.Scheme);
+ Assert.Equal(_externalToken, capturedRequest.Headers.Authorization?.Parameter); // Should still use original token
+
+ // Wait for background task to complete
+ await Task.Delay(1000);
+
+ var request2 = new HttpRequestMessage(HttpMethod.Get,
"https://example.com/2");
+ request2.Headers.Authorization = new
System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", _externalToken);
+ HttpRequestMessage? capturedRequest2 = null;
+
+ _mockInnerHandler.Protected()
+ .Setup<Task<HttpResponseMessage>>(
+ "SendAsync",
+ ItExpr.Is<HttpRequestMessage>(r =>
r.RequestUri!.PathAndQuery == "/2"),
+ ItExpr.IsAny<CancellationToken>())
+ .Callback<HttpRequestMessage, CancellationToken>((req, ct) =>
capturedRequest2 = req)
+ .ReturnsAsync(new HttpResponseMessage(HttpStatusCode.OK));
+
+ await httpClient.SendAsync(request2);
+
+ Assert.NotNull(capturedRequest2);
+ Assert.Equal("Bearer",
capturedRequest2.Headers.Authorization?.Scheme);
+ Assert.Equal(_externalToken, capturedRequest2.Headers.Authorization?.Parameter); // Should still use original token
+
+ // Verify token exchange was attempted
+ _mockTokenExchangeClient.Verify(
+ x => x.ExchangeTokenAsync(_externalToken,
_identityFederationClientId, It.IsAny<CancellationToken>()),
+ Times.Once);
+ }
+
+ [Fact]
+ public async Task
SendAsync_WithSameExternalTokenMultipleTimes_ExchangesTokenOnlyOnce()
+ {
+ var handler = new MandatoryTokenExchangeDelegatingHandler(
+ _mockInnerHandler.Object,
+ _mockTokenExchangeClient.Object,
+ _identityFederationClientId);
+
+ var tokenExchangeResponse = new TokenExchangeResponse
+ {
+ AccessToken = _exchangedToken,
+ TokenType = "Bearer",
+ ExpiresIn = 3600,
+ ExpiryTime = DateTime.UtcNow.AddHours(1)
+ };
+
+ _mockTokenExchangeClient
+ .Setup(x => x.ExchangeTokenAsync(_externalToken,
_identityFederationClientId, It.IsAny<CancellationToken>()))
+ .ReturnsAsync(tokenExchangeResponse);
+
+ _mockInnerHandler.Protected()
+ .Setup<Task<HttpResponseMessage>>(
+ "SendAsync",
+ ItExpr.IsAny<HttpRequestMessage>(),
+ ItExpr.IsAny<CancellationToken>())
+ .ReturnsAsync(new HttpResponseMessage(HttpStatusCode.OK));
+
+ var httpClient = new HttpClient(handler);
+
+ // Make multiple requests with the same token
+ for (int i = 0; i < 3; i++)
+ {
+ var request = new HttpRequestMessage(HttpMethod.Get,
$"https://example.com/{i}");
+ request.Headers.Authorization = new
System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", _externalToken);
+ await httpClient.SendAsync(request);
+ }
+
+ // Wait for background exchange to complete
+ await Task.Delay(1000);
+
+ // Token exchange should only be called once
+ _mockTokenExchangeClient.Verify(
+ x => x.ExchangeTokenAsync(_externalToken,
_identityFederationClientId, It.IsAny<CancellationToken>()),
+ Times.Once);
+ }
+
+ [Fact]
+ public async Task
SendAsync_WithDifferentExternalTokens_ExchangesEachTokenOnce()
+ {
+ var handler = new MandatoryTokenExchangeDelegatingHandler(
+ _mockInnerHandler.Object,
+ _mockTokenExchangeClient.Object,
+ _identityFederationClientId);
+
+ var externalToken1 =
CreateJwtToken("https://external-provider.com", DateTime.UtcNow.AddHours(1));
+ var externalToken2 =
CreateJwtToken("https://another-provider.com", DateTime.UtcNow.AddHours(1));
+ var exchangedToken1 = "exchanged-token-1";
+ var exchangedToken2 = "exchanged-token-2";
+
+ _mockTokenExchangeClient
+ .Setup(x => x.ExchangeTokenAsync(externalToken1,
_identityFederationClientId, It.IsAny<CancellationToken>()))
+ .ReturnsAsync(new TokenExchangeResponse
+ {
+ AccessToken = exchangedToken1,
+ TokenType = "Bearer",
+ ExpiresIn = 3600,
+ ExpiryTime = DateTime.UtcNow.AddHours(1)
+ });
+
+ _mockTokenExchangeClient
+ .Setup(x => x.ExchangeTokenAsync(externalToken2,
_identityFederationClientId, It.IsAny<CancellationToken>()))
+ .ReturnsAsync(new TokenExchangeResponse
+ {
+ AccessToken = exchangedToken2,
+ TokenType = "Bearer",
+ ExpiresIn = 3600,
+ ExpiryTime = DateTime.UtcNow.AddHours(1)
+ });
+
+ _mockInnerHandler.Protected()
+ .Setup<Task<HttpResponseMessage>>(
+ "SendAsync",
+ ItExpr.IsAny<HttpRequestMessage>(),
+ ItExpr.IsAny<CancellationToken>())
+ .ReturnsAsync(new HttpResponseMessage(HttpStatusCode.OK));
+
+ var httpClient = new HttpClient(handler);
+
+ // Make request with first token
+ var request1 = new HttpRequestMessage(HttpMethod.Get,
"https://example.com/1");
+ request1.Headers.Authorization = new
System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", externalToken1);
+ await httpClient.SendAsync(request1);
+
+ // Wait for first exchange to complete
+ await Task.Delay(1000);
+
+ // Make request with second token
+ var request2 = new HttpRequestMessage(HttpMethod.Get,
"https://example.com/2");
+ request2.Headers.Authorization = new
System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", externalToken2);
+ await httpClient.SendAsync(request2);
+
+ // Wait for second exchange to complete
+ await Task.Delay(1000);
+
+ // Verify both tokens were exchanged
+ _mockTokenExchangeClient.Verify(
+ x => x.ExchangeTokenAsync(externalToken1,
_identityFederationClientId, It.IsAny<CancellationToken>()),
+ Times.Once);
+ _mockTokenExchangeClient.Verify(
+ x => x.ExchangeTokenAsync(externalToken2,
_identityFederationClientId, It.IsAny<CancellationToken>()),
+ Times.Once);
+ }
+
+ [Fact]
+ public async Task
SendAsync_WithConcurrentRequestsSameToken_ExchangesTokenOnlyOnce()
+ {
+ var handler = new MandatoryTokenExchangeDelegatingHandler(
+ _mockInnerHandler.Object,
+ _mockTokenExchangeClient.Object,
+ _identityFederationClientId);
+
+ var tokenExchangeResponse = new TokenExchangeResponse
+ {
+ AccessToken = _exchangedToken,
+ TokenType = "Bearer",
+ ExpiresIn = 3600,
+ ExpiryTime = DateTime.UtcNow.AddHours(1)
+ };
+
+ // Add a small delay to token exchange to simulate concurrent access
+ _mockTokenExchangeClient
+ .Setup(x => x.ExchangeTokenAsync(_externalToken,
_identityFederationClientId, It.IsAny<CancellationToken>()))
+ .Returns(async () =>
+ {
+ await Task.Delay(200);
+ return tokenExchangeResponse;
+ });
+
+ _mockInnerHandler.Protected()
+ .Setup<Task<HttpResponseMessage>>(
+ "SendAsync",
+ ItExpr.IsAny<HttpRequestMessage>(),
+ ItExpr.IsAny<CancellationToken>())
+ .ReturnsAsync(new HttpResponseMessage(HttpStatusCode.OK));
+
+ var httpClient = new HttpClient(handler);
+
+ // Make concurrent requests with the same token
+ var tasks = new[]
+ {
+ CreateAndSendRequest(httpClient, _externalToken,
"https://example.com/1"),
+ CreateAndSendRequest(httpClient, _externalToken,
"https://example.com/2"),
+ CreateAndSendRequest(httpClient, _externalToken,
"https://example.com/3")
+ };
+
+ await Task.WhenAll(tasks);
+
+ // Wait for any background token exchange to complete
+ await Task.Delay(1000);
+
+ // Token exchange should only be called once despite concurrent requests
+ _mockTokenExchangeClient.Verify(
+ x => x.ExchangeTokenAsync(_externalToken,
_identityFederationClientId, It.IsAny<CancellationToken>()),
+ Times.Once);
+ }
+
+ [Fact]
+ public async Task
SendAsync_WithInvalidJwtToken_UsesTokenDirectlyWithoutExchange()
+ {
+ var handler = new MandatoryTokenExchangeDelegatingHandler(
+ _mockInnerHandler.Object,
+ _mockTokenExchangeClient.Object,
+ _identityFederationClientId);
+
+ var invalidToken = "invalid-jwt-token";
+ var request = new HttpRequestMessage(HttpMethod.Get,
"https://example.com");
+ request.Headers.Authorization = new
System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", invalidToken);
+ var expectedResponse = new HttpResponseMessage(HttpStatusCode.OK);
+
+ HttpRequestMessage? capturedRequest = null;
+
+ _mockInnerHandler.Protected()
+ .Setup<Task<HttpResponseMessage>>(
+ "SendAsync",
+ ItExpr.IsAny<HttpRequestMessage>(),
+ ItExpr.IsAny<CancellationToken>())
+ .Callback<HttpRequestMessage, CancellationToken>((req, ct) =>
capturedRequest = req)
+ .ReturnsAsync(expectedResponse);
+
+ var httpClient = new HttpClient(handler);
+ var response = await httpClient.SendAsync(request);
+
+ Assert.Equal(expectedResponse, response);
+ Assert.NotNull(capturedRequest);
+ Assert.Equal("Bearer",
capturedRequest.Headers.Authorization?.Scheme);
+ Assert.Equal(invalidToken,
capturedRequest.Headers.Authorization?.Parameter);
+
+ // Wait for any background tasks
+ await Task.Delay(1000);
+
+ // Verify no token exchange was attempted
+ _mockTokenExchangeClient.Verify(
+ x => x.ExchangeTokenAsync(It.IsAny<string>(),
It.IsAny<string>(), It.IsAny<CancellationToken>()),
+ Times.Never);
+ }
+
+ [Fact]
+ public async Task
SendAsync_WithNoAuthorizationHeader_PassesThroughWithoutModification()
+ {
+ var handler = new MandatoryTokenExchangeDelegatingHandler(
+ _mockInnerHandler.Object,
+ _mockTokenExchangeClient.Object,
+ _identityFederationClientId);
+
+ var request = new HttpRequestMessage(HttpMethod.Get,
"https://example.com");
+ var expectedResponse = new HttpResponseMessage(HttpStatusCode.OK);
+
+ HttpRequestMessage? capturedRequest = null;
+
+ _mockInnerHandler.Protected()
+ .Setup<Task<HttpResponseMessage>>(
+ "SendAsync",
+ ItExpr.IsAny<HttpRequestMessage>(),
+ ItExpr.IsAny<CancellationToken>())
+ .Callback<HttpRequestMessage, CancellationToken>((req, ct) =>
capturedRequest = req)
+ .ReturnsAsync(expectedResponse);
+
+ var httpClient = new HttpClient(handler);
+ var response = await httpClient.SendAsync(request);
+
+ Assert.Equal(expectedResponse, response);
+ Assert.NotNull(capturedRequest);
+ Assert.Null(capturedRequest.Headers.Authorization);
+
+ // Verify no token exchange was attempted
+ _mockTokenExchangeClient.Verify(
+ x => x.ExchangeTokenAsync(It.IsAny<string>(),
It.IsAny<string>(), It.IsAny<CancellationToken>()),
+ Times.Never);
+ }
+
+ private async Task<HttpResponseMessage>
CreateAndSendRequest(HttpClient httpClient, string token, string url)
+ {
+ var request = new HttpRequestMessage(HttpMethod.Get, url);
+ request.Headers.Authorization = new
System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", token);
+ return await httpClient.SendAsync(request);
+ }
+
+ /// <summary>
+ /// Creates a valid JWT token with the specified issuer and expiration time.
+ /// This is for testing purposes only and creates a properly formatted JWT.
+ /// </summary>
+ private static string CreateJwtToken(string issuer, DateTime
expiryTime)
+ {
+ // Create header
+ var header = new
+ {
+ alg = "HS256",
+ typ = "JWT"
+ };
+
+ // Create payload
+ var payload = new
+ {
+ iss = issuer,
+ exp = new DateTimeOffset(expiryTime).ToUnixTimeSeconds(),
+ iat = new DateTimeOffset(DateTime.UtcNow).ToUnixTimeSeconds(),
+ sub = "test-subject"
+ };
+
+ // Encode header and payload
+ string encodedHeader =
EncodeBase64Url(JsonSerializer.Serialize(header));
+ string encodedPayload =
EncodeBase64Url(JsonSerializer.Serialize(payload));
+
+ // For testing, we'll use a dummy signature
+ string signature = EncodeBase64Url("dummy-signature");
+
+ return $"{encodedHeader}.{encodedPayload}.{signature}";
+ }
+
+ /// <summary>
+ /// Encodes a string to base64url format.
+ /// </summary>
+ private static string EncodeBase64Url(string input)
+ {
+ byte[] bytes = Encoding.UTF8.GetBytes(input);
+ string base64 = Convert.ToBase64String(bytes);
+
+ // Convert base64 to base64url
+ return base64
+ .Replace('+', '-')
+ .Replace('/', '_')
+ .TrimEnd('=');
+ }
+
+ protected virtual void Dispose(bool disposing)
+ {
+ if (disposing)
+ {
+ _mockInnerHandler?.Object?.Dispose();
+ }
+ }
+
+ public void Dispose()
+ {
+ Dispose(true);
+ GC.SuppressFinalize(this);
+ }
+ }
+}
diff --git a/csharp/test/Drivers/Databricks/Unit/Auth/TokenExchangeClientTests.cs b/csharp/test/Drivers/Databricks/Unit/Auth/TokenExchangeClientTests.cs
index 59d99bead..b7306256b 100644
--- a/csharp/test/Drivers/Databricks/Unit/Auth/TokenExchangeClientTests.cs
+++ b/csharp/test/Drivers/Databricks/Unit/Auth/TokenExchangeClientTests.cs
@@ -22,11 +22,12 @@ using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Adbc.Drivers.Databricks.Auth;
+using Apache.Arrow.Adbc.Drivers.Databricks;
using Moq;
using Moq.Protected;
using Xunit;
-namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
+namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Unit.Auth
{
public class TokenExchangeClientTests : IDisposable
{
@@ -45,6 +46,15 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
{
var client = new TokenExchangeClient(_httpClient, _testHost);
Assert.NotNull(client);
+ Assert.Equal($"https://{_testHost}/oidc/v1/token",
client.TokenExchangeEndpoint);
+ }
+
+ [Fact]
+ public void Constructor_WithHostTrailingSlash_RemovesTrailingSlash()
+ {
+ var hostWithSlash = "test.databricks.com/";
+ var client = new TokenExchangeClient(_httpClient, hostWithSlash);
+ Assert.Equal("https://test.databricks.com/oidc/v1/token",
client.TokenExchangeEndpoint);
}
[Fact]
@@ -54,7 +64,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
}
[Fact]
- public async Task
ExchangeTokenAsync_WithValidResponse_ReturnsTokenExchangeResponse()
+ public async Task
RefreshTokenAsync_WithValidResponse_ReturnsTokenExchangeResponse()
{
var testToken = "test-jwt-token";
var expectedAccessToken = "new-access-token";
@@ -85,7 +95,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
var client = new TokenExchangeClient(_httpClient, _testHost);
- var result = await client.ExchangeTokenAsync(testToken,
CancellationToken.None);
+ var result = await client.RefreshTokenAsync(testToken,
CancellationToken.None);
Assert.NotNull(result);
Assert.Equal(expectedAccessToken, result.AccessToken);
@@ -96,7 +106,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
}
[Fact]
- public async Task ExchangeTokenAsync_SendsCorrectRequestFormat()
+ public async Task RefreshTokenAsync_SendsCorrectRequestFormat()
{
var testToken = "test-jwt-token";
var responseJson = JsonSerializer.Serialize(new
@@ -112,34 +122,39 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
};
HttpRequestMessage? capturedRequest = null;
+ string? capturedFormContent = null;
_mockHttpMessageHandler.Protected()
.Setup<Task<HttpResponseMessage>>(
"SendAsync",
ItExpr.IsAny<HttpRequestMessage>(),
ItExpr.IsAny<CancellationToken>())
- .Callback<HttpRequestMessage, CancellationToken>((req, ct) =>
capturedRequest = req)
- .ReturnsAsync(httpResponseMessage);
+ .Returns<HttpRequestMessage, CancellationToken>(async (req,
ct) =>
+ {
+ capturedRequest = req;
+ if (req.Content != null)
+ {
+ capturedFormContent = await
req.Content.ReadAsStringAsync();
+ }
+ return httpResponseMessage;
+ });
var client = new TokenExchangeClient(_httpClient, _testHost);
- await client.ExchangeTokenAsync(testToken, CancellationToken.None);
+ await client.RefreshTokenAsync(testToken, CancellationToken.None);
Assert.NotNull(capturedRequest);
Assert.Equal(HttpMethod.Post, capturedRequest.Method);
Assert.Equal($"https://{_testHost}/oidc/v1/token",
capturedRequest?.RequestUri?.ToString());
Assert.True(capturedRequest?.Headers.Accept.Contains(new
System.Net.Http.Headers.MediaTypeWithQualityHeaderValue("*/*")));
- var content = capturedRequest?.Content as FormUrlEncodedContent;
- Assert.NotNull(content);
-
- var formContent = await content.ReadAsStringAsync();
- Assert.Contains("grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer", formContent);
- Assert.Contains($"assertion={testToken}", formContent);
+ Assert.NotNull(capturedFormContent);
+ Assert.Contains("grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer", capturedFormContent);
+ Assert.Contains($"assertion={testToken}", capturedFormContent);
}
[Fact]
- public async Task
ExchangeTokenAsync_WithHttpError_ThrowsHttpRequestException()
+ public async Task
RefreshTokenAsync_WithHttpError_ThrowsHttpRequestException()
{
var testToken = "test-jwt-token";
var httpResponseMessage = new
HttpResponseMessage(HttpStatusCode.Unauthorized)
@@ -157,11 +172,11 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
var client = new TokenExchangeClient(_httpClient, _testHost);
await Assert.ThrowsAsync<HttpRequestException>(() =>
- client.ExchangeTokenAsync(testToken, CancellationToken.None));
+ client.RefreshTokenAsync(testToken, CancellationToken.None));
}
[Fact]
- public async Task
ExchangeTokenAsync_WithMissingAccessToken_ThrowsDatabricksException()
+ public async Task
RefreshTokenAsync_WithMissingAccessToken_ThrowsDatabricksException()
{
var testToken = "test-jwt-token";
var responseJson = JsonSerializer.Serialize(new
@@ -185,13 +200,13 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
var client = new TokenExchangeClient(_httpClient, _testHost);
var exception = await Assert.ThrowsAsync<DatabricksException>(() =>
- client.ExchangeTokenAsync(testToken, CancellationToken.None));
+ client.RefreshTokenAsync(testToken, CancellationToken.None));
Assert.Contains("access_token", exception.Message);
}
[Fact]
- public async Task
ExchangeTokenAsync_WithEmptyAccessToken_ThrowsDatabricksException()
+ public async Task
RefreshTokenAsync_WithEmptyAccessToken_ThrowsDatabricksException()
{
var testToken = "test-jwt-token";
var responseJson = JsonSerializer.Serialize(new
@@ -216,13 +231,13 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
var client = new TokenExchangeClient(_httpClient, _testHost);
var exception = await Assert.ThrowsAsync<DatabricksException>(() =>
- client.ExchangeTokenAsync(testToken, CancellationToken.None));
+ client.RefreshTokenAsync(testToken, CancellationToken.None));
Assert.Contains("access_token was null or empty",
exception.Message);
}
[Fact]
- public async Task
ExchangeTokenAsync_WithMissingTokenType_ThrowsDatabricksException()
+ public async Task
RefreshTokenAsync_WithMissingTokenType_ThrowsDatabricksException()
{
var testToken = "test-jwt-token";
var responseJson = JsonSerializer.Serialize(new
@@ -246,13 +261,13 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
var client = new TokenExchangeClient(_httpClient, _testHost);
var exception = await Assert.ThrowsAsync<DatabricksException>(() =>
- client.ExchangeTokenAsync(testToken, CancellationToken.None));
+ client.RefreshTokenAsync(testToken, CancellationToken.None));
Assert.Contains("token_type", exception.Message);
}
[Fact]
- public async Task
ExchangeTokenAsync_WithMissingExpiresIn_ThrowsDatabricksException()
+ public async Task
RefreshTokenAsync_WithMissingExpiresIn_ThrowsDatabricksException()
{
var testToken = "test-jwt-token";
var responseJson = JsonSerializer.Serialize(new
@@ -276,13 +291,13 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
var client = new TokenExchangeClient(_httpClient, _testHost);
var exception = await Assert.ThrowsAsync<DatabricksException>(() =>
- client.ExchangeTokenAsync(testToken, CancellationToken.None));
+ client.RefreshTokenAsync(testToken, CancellationToken.None));
Assert.Contains("expires_in", exception.Message);
}
[Fact]
- public async Task
ExchangeTokenAsync_WithNegativeExpiresIn_ThrowsDatabricksException()
+ public async Task
RefreshTokenAsync_WithNegativeExpiresIn_ThrowsDatabricksException()
{
var testToken = "test-jwt-token";
var responseJson = JsonSerializer.Serialize(new
@@ -307,13 +322,13 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
var client = new TokenExchangeClient(_httpClient, _testHost);
var exception = await Assert.ThrowsAsync<DatabricksException>(() =>
- client.ExchangeTokenAsync(testToken, CancellationToken.None));
+ client.RefreshTokenAsync(testToken, CancellationToken.None));
Assert.Contains("expires_in value must be positive",
exception.Message);
}
[Fact]
- public async Task
ExchangeTokenAsync_WithInvalidJson_ThrowsJsonException()
+ public async Task
RefreshTokenAsync_WithInvalidJson_ThrowsJsonException()
{
var testToken = "test-jwt-token";
var invalidJson = "{ invalid json }";
@@ -333,11 +348,11 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
var client = new TokenExchangeClient(_httpClient, _testHost);
await Assert.ThrowsAnyAsync<JsonException>(() =>
- client.ExchangeTokenAsync(testToken, CancellationToken.None));
+ client.RefreshTokenAsync(testToken, CancellationToken.None));
}
[Fact]
- public async Task
ExchangeTokenAsync_WithCancellationToken_PropagatesCancellation()
+ public async Task
RefreshTokenAsync_WithCancellationToken_PropagatesCancellation()
{
var testToken = "test-jwt-token";
var cts = new CancellationTokenSource();
@@ -353,11 +368,11 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
var client = new TokenExchangeClient(_httpClient, _testHost);
await Assert.ThrowsAsync<TaskCanceledException>(() =>
- client.ExchangeTokenAsync(testToken, cts.Token));
+ client.RefreshTokenAsync(testToken, cts.Token));
}
[Fact]
- public async Task ExchangeTokenAsync_CalculatesExpiryTimeCorrectly()
+ public async Task RefreshTokenAsync_CalculatesExpiryTimeCorrectly()
{
var testToken = "test-jwt-token";
var expiresIn = 1800; // 30 minutes
@@ -384,7 +399,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
var client = new TokenExchangeClient(_httpClient, _testHost);
- var result = await client.ExchangeTokenAsync(testToken,
CancellationToken.None);
+ var result = await client.RefreshTokenAsync(testToken,
CancellationToken.None);
var afterCall = DateTime.UtcNow;
var expectedMinExpiry = beforeCall.AddSeconds(expiresIn);
@@ -394,6 +409,214 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
Assert.True(result.ExpiryTime <= expectedMaxExpiry);
}
+ [Fact]
+ public async Task ExchangeTokenAsync_WithoutIdentityFederationClientId_SendsCorrectRequest()
+ {
+     // Arrange: the mocked handler answers with a successful token payload and
+     // records the outgoing request together with its form-encoded body.
+     var subjectToken = "test-jwt-token";
+     var okResponse = new HttpResponseMessage(HttpStatusCode.OK)
+     {
+         Content = new StringContent(JsonSerializer.Serialize(new
+         {
+             access_token = "exchanged-token",
+             token_type = "Bearer",
+             expires_in = 3600
+         }))
+     };
+
+     HttpRequestMessage? sentRequest = null;
+     string? sentForm = null;
+
+     _mockHttpMessageHandler.Protected()
+         .Setup<Task<HttpResponseMessage>>(
+             "SendAsync",
+             ItExpr.IsAny<HttpRequestMessage>(),
+             ItExpr.IsAny<CancellationToken>())
+         .Returns<HttpRequestMessage, CancellationToken>(async (outgoing, cancellation) =>
+         {
+             sentRequest = outgoing;
+             if (outgoing.Content is HttpContent form)
+             {
+                 sentForm = await form.ReadAsStringAsync();
+             }
+             return okResponse;
+         });
+
+     var exchangeClient = new TokenExchangeClient(_httpClient, _testHost);
+
+     // Act: exchange with a null identity federation client id.
+     var exchanged = await exchangeClient.ExchangeTokenAsync(subjectToken, null, CancellationToken.None);
+
+     // Assert: the response was parsed and a request with a body was sent.
+     Assert.NotNull(exchanged);
+     Assert.Equal("exchanged-token", exchanged.AccessToken);
+
+     Assert.NotNull(sentRequest);
+     Assert.NotNull(sentForm);
+
+     // The form must ask for the original token back and must not name a federation client.
+     Assert.Contains("grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer", sentForm);
+     Assert.Contains($"assertion={subjectToken}", sentForm);
+     Assert.Contains("scope=sql", sentForm);
+     Assert.Contains("return_original_token_if_authenticated=true", sentForm);
+     Assert.DoesNotContain("identity_federation_client_id", sentForm);
+ }
+
+ [Fact]
+ public async Task ExchangeTokenAsync_WithIdentityFederationClientId_SendsCorrectRequest()
+ {
+     // Arrange: successful token payload, plus capture of the outgoing request and form body.
+     var subjectToken = "test-jwt-token";
+     var federationClientId = "test-client-id";
+     var okResponse = new HttpResponseMessage(HttpStatusCode.OK)
+     {
+         Content = new StringContent(JsonSerializer.Serialize(new
+         {
+             access_token = "exchanged-token",
+             token_type = "Bearer",
+             expires_in = 3600
+         }))
+     };
+
+     HttpRequestMessage? sentRequest = null;
+     string? sentForm = null;
+
+     _mockHttpMessageHandler.Protected()
+         .Setup<Task<HttpResponseMessage>>(
+             "SendAsync",
+             ItExpr.IsAny<HttpRequestMessage>(),
+             ItExpr.IsAny<CancellationToken>())
+         .Returns<HttpRequestMessage, CancellationToken>(async (outgoing, cancellation) =>
+         {
+             sentRequest = outgoing;
+             if (outgoing.Content is HttpContent form)
+             {
+                 sentForm = await form.ReadAsStringAsync();
+             }
+             return okResponse;
+         });
+
+     var exchangeClient = new TokenExchangeClient(_httpClient, _testHost);
+
+     // Act: exchange while supplying a non-empty identity federation client id.
+     var exchanged = await exchangeClient.ExchangeTokenAsync(subjectToken, federationClientId, CancellationToken.None);
+
+     // Assert: the response was parsed and a request with a body was sent.
+     Assert.NotNull(exchanged);
+     Assert.Equal("exchanged-token", exchanged.AccessToken);
+
+     Assert.NotNull(sentRequest);
+     Assert.NotNull(sentForm);
+
+     // The client id is forwarded and the "return original token" flag is left out.
+     Assert.Contains("grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer", sentForm);
+     Assert.Contains($"assertion={subjectToken}", sentForm);
+     Assert.Contains("scope=sql", sentForm);
+     Assert.Contains($"identity_federation_client_id={federationClientId}", sentForm);
+     Assert.DoesNotContain("return_original_token_if_authenticated", sentForm);
+ }
+
+ [Fact]
+ public async Task ExchangeTokenAsync_WithEmptyIdentityFederationClientId_SendsReturnOriginalToken()
+ {
+     // Arrange: successful token payload, plus capture of the outgoing request and form body.
+     var subjectToken = "test-jwt-token";
+     var federationClientId = "";
+     var okResponse = new HttpResponseMessage(HttpStatusCode.OK)
+     {
+         Content = new StringContent(JsonSerializer.Serialize(new
+         {
+             access_token = "exchanged-token",
+             token_type = "Bearer",
+             expires_in = 3600
+         }))
+     };
+
+     HttpRequestMessage? sentRequest = null;
+     string? sentForm = null;
+
+     _mockHttpMessageHandler.Protected()
+         .Setup<Task<HttpResponseMessage>>(
+             "SendAsync",
+             ItExpr.IsAny<HttpRequestMessage>(),
+             ItExpr.IsAny<CancellationToken>())
+         .Returns<HttpRequestMessage, CancellationToken>(async (outgoing, cancellation) =>
+         {
+             sentRequest = outgoing;
+             if (outgoing.Content is HttpContent form)
+             {
+                 sentForm = await form.ReadAsStringAsync();
+             }
+             return okResponse;
+         });
+
+     var exchangeClient = new TokenExchangeClient(_httpClient, _testHost);
+
+     // Act: exchange while supplying an empty-string client id.
+     var exchanged = await exchangeClient.ExchangeTokenAsync(subjectToken, federationClientId, CancellationToken.None);
+
+     // Assert: the empty id yields the same form shape this test class expects for a null id.
+     Assert.NotNull(exchanged);
+     Assert.NotNull(sentRequest);
+     Assert.NotNull(sentForm);
+     Assert.Contains("return_original_token_if_authenticated=true", sentForm);
+     Assert.DoesNotContain("identity_federation_client_id", sentForm);
+ }
+
+ [Fact]
+ public async Task ExchangeTokenAsync_WithHttpError_ThrowsHttpRequestException()
+ {
+     // Arrange: every request sent through the mocked handler gets a 401 reply.
+     var subjectToken = "test-jwt-token";
+     var unauthorized = new HttpResponseMessage(HttpStatusCode.Unauthorized)
+     {
+         Content = new StringContent("Unauthorized")
+     };
+
+     _mockHttpMessageHandler.Protected()
+         .Setup<Task<HttpResponseMessage>>(
+             "SendAsync",
+             ItExpr.IsAny<HttpRequestMessage>(),
+             ItExpr.IsAny<CancellationToken>())
+         .ReturnsAsync(unauthorized);
+
+     var exchangeClient = new TokenExchangeClient(_httpClient, _testHost);
+
+     // Act + Assert: the non-success status must surface as HttpRequestException.
+     Func<Task> exchange = () => exchangeClient.ExchangeTokenAsync(subjectToken, null, CancellationToken.None);
+     await Assert.ThrowsAsync<HttpRequestException>(exchange);
+ }
+
+ [Fact]
+ public async Task ExchangeTokenAsync_UsesCorrectEndpoint()
+ {
+     // Arrange: successful token payload; the handler only needs to record the request.
+     var subjectToken = "test-jwt-token";
+     var okResponse = new HttpResponseMessage(HttpStatusCode.OK)
+     {
+         Content = new StringContent(JsonSerializer.Serialize(new
+         {
+             access_token = "token",
+             token_type = "Bearer",
+             expires_in = 3600
+         }))
+     };
+
+     HttpRequestMessage? sentRequest = null;
+
+     _mockHttpMessageHandler.Protected()
+         .Setup<Task<HttpResponseMessage>>(
+             "SendAsync",
+             ItExpr.IsAny<HttpRequestMessage>(),
+             ItExpr.IsAny<CancellationToken>())
+         .Returns<HttpRequestMessage, CancellationToken>((outgoing, cancellation) =>
+         {
+             sentRequest = outgoing;
+             return Task.FromResult(okResponse);
+         });
+
+     var exchangeClient = new TokenExchangeClient(_httpClient, _testHost);
+
+     // Act
+     await exchangeClient.ExchangeTokenAsync(subjectToken, null, CancellationToken.None);
+
+     // Assert: a POST was issued to the /oidc/v1/token path on the test host.
+     Assert.NotNull(sentRequest);
+     Assert.Equal(HttpMethod.Post, sentRequest.Method);
+     var actualEndpoint = sentRequest?.RequestUri?.ToString();
+     Assert.Equal($"https://{_testHost}/oidc/v1/token", actualEndpoint);
+ }
+
protected virtual void Dispose(bool disposing)
{
if (disposing)
diff --git
a/csharp/test/Drivers/Databricks/Unit/Auth/TokenExchangeDelegatingHandlerTests.cs
b/csharp/test/Drivers/Databricks/Unit/Auth/TokenRefreshDelegatingHandlerTests.cs
similarity index 86%
rename from
csharp/test/Drivers/Databricks/Unit/Auth/TokenExchangeDelegatingHandlerTests.cs
rename to
csharp/test/Drivers/Databricks/Unit/Auth/TokenRefreshDelegatingHandlerTests.cs
index 40b83c156..99d4db9b8 100644
---
a/csharp/test/Drivers/Databricks/Unit/Auth/TokenExchangeDelegatingHandlerTests.cs
+++
b/csharp/test/Drivers/Databricks/Unit/Auth/TokenRefreshDelegatingHandlerTests.cs
@@ -25,9 +25,9 @@ using Moq;
using Moq.Protected;
using Xunit;
-namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
+namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Unit.Auth
{
- public class TokenExchangeDelegatingHandlerTests : IDisposable
+ public class TokenRefreshDelegatingHandlerTests : IDisposable
{
private readonly Mock<HttpMessageHandler> _mockInnerHandler;
private readonly Mock<ITokenExchangeClient> _mockTokenExchangeClient;
@@ -35,7 +35,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
private readonly int _tokenRenewLimitMinutes = 10;
private readonly DateTime _initialTokenExpiry =
DateTime.UtcNow.AddHours(1);
- public TokenExchangeDelegatingHandlerTests()
+ public TokenRefreshDelegatingHandlerTests()
{
_mockInnerHandler = new Mock<HttpMessageHandler>();
_mockTokenExchangeClient = new Mock<ITokenExchangeClient>();
@@ -44,7 +44,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
[Fact]
public void Constructor_WithValidParameters_InitializesCorrectly()
{
- var handler = new TokenExchangeDelegatingHandler(
+ var handler = new TokenRefreshDelegatingHandler(
_mockInnerHandler.Object,
_mockTokenExchangeClient.Object,
_initialToken,
@@ -57,7 +57,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
[Fact]
public void
Constructor_WithNullTokenExchangeClient_ThrowsArgumentNullException()
{
- Assert.Throws<ArgumentNullException>(() => new
TokenExchangeDelegatingHandler(
+ Assert.Throws<ArgumentNullException>(() => new
TokenRefreshDelegatingHandler(
_mockInnerHandler.Object,
null!,
_initialToken,
@@ -68,7 +68,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
[Fact]
public void
Constructor_WithNullInitialToken_ThrowsArgumentNullException()
{
- Assert.Throws<ArgumentNullException>(() => new
TokenExchangeDelegatingHandler(
+ Assert.Throws<ArgumentNullException>(() => new
TokenRefreshDelegatingHandler(
_mockInnerHandler.Object,
_mockTokenExchangeClient.Object,
null!,
@@ -80,7 +80,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
public async Task
SendAsync_WithValidTokenNotNearExpiry_UsesInitialTokenWithoutRenewal()
{
var futureExpiry = DateTime.UtcNow.AddHours(2); // Well beyond
renewal limit
- var handler = new TokenExchangeDelegatingHandler(
+ var handler = new TokenRefreshDelegatingHandler(
_mockInnerHandler.Object,
_mockTokenExchangeClient.Object,
_initialToken,
@@ -109,10 +109,10 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
Assert.Equal(_initialToken,
capturedRequest.Headers.Authorization?.Parameter);
// Wait for background task to complete
- await Task.Delay(100);
+ await Task.Delay(1000);
_mockTokenExchangeClient.Verify(
- x => x.ExchangeTokenAsync(It.IsAny<string>(),
It.IsAny<CancellationToken>()),
+ x => x.RefreshTokenAsync(It.IsAny<string>(),
It.IsAny<CancellationToken>()),
Times.Never);
}
@@ -125,7 +125,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
var newExpiry = DateTime.UtcNow.AddHours(1);
var tokenExchangeDelay = TimeSpan.FromMilliseconds(500);
- var handler = new TokenExchangeDelegatingHandler(
+ var handler = new TokenRefreshDelegatingHandler(
_mockInnerHandler.Object,
_mockTokenExchangeClient.Object,
_initialToken,
@@ -144,7 +144,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
};
_mockTokenExchangeClient
- .Setup(x => x.ExchangeTokenAsync(_initialToken,
It.IsAny<CancellationToken>()))
+ .Setup(x => x.RefreshTokenAsync(_initialToken,
It.IsAny<CancellationToken>()))
.Returns(async (string token, CancellationToken ct) =>
{
await Task.Delay(tokenExchangeDelay, ct);
@@ -177,7 +177,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
Assert.Equal(_initialToken,
capturedRequest.Headers.Authorization?.Parameter); // First request uses
original token
// Wait a bit for the background task to complete
- await Task.Delay(tokenExchangeDelay +
TimeSpan.FromMilliseconds(100));
+ await Task.Delay(tokenExchangeDelay +
TimeSpan.FromMilliseconds(1000));
// Make a second request - this should use the new token
var request2 = new HttpRequestMessage(HttpMethod.Get,
"https://example.com/2");
@@ -198,7 +198,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
Assert.Equal(newToken,
capturedRequest2.Headers.Authorization?.Parameter); // Second request uses new
token
_mockTokenExchangeClient.Verify(
- x => x.ExchangeTokenAsync(_initialToken,
It.IsAny<CancellationToken>()),
+ x => x.RefreshTokenAsync(_initialToken,
It.IsAny<CancellationToken>()),
Times.Once);
}
@@ -207,7 +207,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
{
var nearExpiryTime = DateTime.UtcNow.AddMinutes(5); // Within
renewal limit
- var handler = new TokenExchangeDelegatingHandler(
+ var handler = new TokenRefreshDelegatingHandler(
_mockInnerHandler.Object,
_mockTokenExchangeClient.Object,
_initialToken,
@@ -219,7 +219,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
// Setup token exchange to fail
_mockTokenExchangeClient
- .Setup(x => x.ExchangeTokenAsync(_initialToken,
It.IsAny<CancellationToken>()))
+ .Setup(x => x.RefreshTokenAsync(_initialToken,
It.IsAny<CancellationToken>()))
.ThrowsAsync(new Exception("Token exchange failed"));
HttpRequestMessage? capturedRequest = null;
@@ -241,11 +241,27 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
Assert.Equal(_initialToken,
capturedRequest.Headers.Authorization?.Parameter); // Should still use original
token
// Wait for background task to complete
- await Task.Delay(100);
+ await Task.Delay(1000);
+
+ var request2 = new HttpRequestMessage(HttpMethod.Get,
"https://example.com/2");
+ HttpRequestMessage? capturedRequest2 = null;
+
+ _mockInnerHandler.Protected()
+ .Setup<Task<HttpResponseMessage>>(
+ "SendAsync",
+ ItExpr.Is<HttpRequestMessage>(r =>
r.RequestUri!.PathAndQuery == "/2"),
+ ItExpr.IsAny<CancellationToken>())
+ .Callback<HttpRequestMessage, CancellationToken>((req, ct) =>
capturedRequest2 = req)
+ .ReturnsAsync(new HttpResponseMessage(HttpStatusCode.OK));
+
+ await httpClient.SendAsync(request2);
+
+ Assert.NotNull(capturedRequest2);
+ Assert.Equal("Bearer",
capturedRequest2.Headers.Authorization?.Scheme);
+ Assert.Equal(_initialToken,
capturedRequest2.Headers.Authorization?.Parameter);
- // Verify token exchange was attempted
_mockTokenExchangeClient.Verify(
- x => x.ExchangeTokenAsync(_initialToken,
It.IsAny<CancellationToken>()),
+ x => x.RefreshTokenAsync(_initialToken,
It.IsAny<CancellationToken>()),
Times.Once);
}
@@ -256,7 +272,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
var newToken = "new-renewed-token";
var newExpiry = DateTime.UtcNow.AddMinutes(3); // New token also
near expiry
- var handler = new TokenExchangeDelegatingHandler(
+ var handler = new TokenRefreshDelegatingHandler(
_mockInnerHandler.Object,
_mockTokenExchangeClient.Object,
_initialToken,
@@ -272,7 +288,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
};
_mockTokenExchangeClient
- .Setup(x => x.ExchangeTokenAsync(_initialToken,
It.IsAny<CancellationToken>()))
+ .Setup(x => x.RefreshTokenAsync(_initialToken,
It.IsAny<CancellationToken>()))
.ReturnsAsync(tokenExchangeResponse);
_mockInnerHandler.Protected()
@@ -288,14 +304,14 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
await httpClient.SendAsync(new HttpRequestMessage(HttpMethod.Get,
"https://example.com/1"));
// Wait for background renewal to complete
- await Task.Delay(100);
+ await Task.Delay(1000);
// Make second request
await httpClient.SendAsync(new HttpRequestMessage(HttpMethod.Get,
"https://example.com/2"));
// Token exchange should only be called once (renewed tokens
cannot be renewed again)
_mockTokenExchangeClient.Verify(
- x => x.ExchangeTokenAsync(_initialToken,
It.IsAny<CancellationToken>()),
+ x => x.RefreshTokenAsync(_initialToken,
It.IsAny<CancellationToken>()),
Times.Once);
}
@@ -306,7 +322,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
var newToken = "new-renewed-token";
var newExpiry = DateTime.UtcNow.AddHours(1);
- var handler = new TokenExchangeDelegatingHandler(
+ var handler = new TokenRefreshDelegatingHandler(
_mockInnerHandler.Object,
_mockTokenExchangeClient.Object,
_initialToken,
@@ -323,7 +339,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
// Add a small delay to token exchange to simulate concurrent
access
_mockTokenExchangeClient
- .Setup(x => x.ExchangeTokenAsync(_initialToken,
It.IsAny<CancellationToken>()))
+ .Setup(x => x.RefreshTokenAsync(_initialToken,
It.IsAny<CancellationToken>()))
.Returns(async () =>
{
await Task.Delay(200);
@@ -350,18 +366,18 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
await Task.WhenAll(tasks);
// Wait for any background token renewal to complete
- await Task.Delay(300);
+ await Task.Delay(1000);
// Token exchange should only be called once despite concurrent
requests
_mockTokenExchangeClient.Verify(
- x => x.ExchangeTokenAsync(_initialToken,
It.IsAny<CancellationToken>()),
+ x => x.RefreshTokenAsync(_initialToken,
It.IsAny<CancellationToken>()),
Times.Once);
}
[Fact]
public async Task
SendAsync_WithCancellationToken_PropagatesCancellation()
{
- var handler = new TokenExchangeDelegatingHandler(
+ var handler = new TokenRefreshDelegatingHandler(
_mockInnerHandler.Object,
_mockTokenExchangeClient.Object,
_initialToken,
@@ -392,7 +408,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
public async Task
SendAsync_WithTokenRenewalAndCancellation_HandlesCancellationGracefully()
{
var nearExpiryTime = DateTime.UtcNow.AddMinutes(5); // Within
renewal limit
- var handler = new TokenExchangeDelegatingHandler(
+ var handler = new TokenRefreshDelegatingHandler(
_mockInnerHandler.Object,
_mockTokenExchangeClient.Object,
_initialToken,
@@ -403,7 +419,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
var cts = new CancellationTokenSource();
_mockTokenExchangeClient
- .Setup(x => x.ExchangeTokenAsync(_initialToken,
It.IsAny<CancellationToken>()))
+ .Setup(x => x.RefreshTokenAsync(_initialToken,
It.IsAny<CancellationToken>()))
.Returns<string, CancellationToken>((token, ct) =>
{
ct.ThrowIfCancellationRequested();
@@ -436,7 +452,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
[Fact]
public void Dispose_ReleasesResources()
{
- var handler = new TokenExchangeDelegatingHandler(
+ var handler = new TokenRefreshDelegatingHandler(
_mockInnerHandler.Object,
_mockTokenExchangeClient.Object,
_initialToken,
@@ -454,7 +470,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
public async Task
SendAsync_WithDifferentRenewalLimits_RenewsTokenAppropriately(int
renewalLimitMinutes)
{
var tokenExpiryTime =
DateTime.UtcNow.AddMinutes(renewalLimitMinutes / 2); // Half the renewal limit
- var handler = new TokenExchangeDelegatingHandler(
+ var handler = new TokenRefreshDelegatingHandler(
_mockInnerHandler.Object,
_mockTokenExchangeClient.Object,
_initialToken,
@@ -464,7 +480,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
var request = new HttpRequestMessage(HttpMethod.Get,
"https://example.com");
_mockTokenExchangeClient
- .Setup(x => x.ExchangeTokenAsync(_initialToken,
It.IsAny<CancellationToken>()))
+ .Setup(x => x.RefreshTokenAsync(_initialToken,
It.IsAny<CancellationToken>()))
.ReturnsAsync(new TokenExchangeResponse
{
AccessToken = "new-token",
@@ -484,10 +500,10 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Tests.Auth
await httpClient.SendAsync(request);
// Wait for background renewal to complete
- await Task.Delay(100);
+ await Task.Delay(1000);
_mockTokenExchangeClient.Verify(
- x => x.ExchangeTokenAsync(_initialToken,
It.IsAny<CancellationToken>()),
+ x => x.RefreshTokenAsync(_initialToken,
It.IsAny<CancellationToken>()),
Times.Once);
}