This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 317c9c90c feat(csharp/src/Drivers/Databricks): Implement
CloudFetchUrlManager to handle presigned URL expiration in CloudFetch (#2855)
317c9c90c is described below
commit 317c9c90c690a8dedb2f4557337c12ad9cd673df
Author: Jade Wang <[email protected]>
AuthorDate: Wed May 28 16:47:11 2025 -0700
feat(csharp/src/Drivers/Databricks): Implement CloudFetchUrlManager to
handle presigned URL expiration in CloudFetch (#2855)
### Problem
The Databricks driver's CloudFetch functionality was not properly
handling expired cloud file URLs, which could lead to failed downloads
and errors during query execution. The system needed a way to track,
cache, and refresh presigned URLs before they expire.
### Solution
- Improved the `CloudFetchResultFetcher` class so that it:
- Manages a cache of cloud file URLs with their expiration times
- Proactively refreshes URLs that are about to expire
- Provides thread-safe access to URL information
- Added an `IClock` interface and implementations to facilitate testing
with controlled time
- Extended the `IDownloadResult` interface to support URL refreshing and
expiration checking
- Updated namespace from
`Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch` to
`Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch` for better
organization
---
.../{IHiveServer2Statement.cs => Clock.cs} | 34 +-
.../CloudFetch/CloudFetchDownloadManager.cs | 43 ++-
.../Databricks/CloudFetch/CloudFetchDownloader.cs | 61 ++-
.../CloudFetch/CloudFetchMemoryBufferManager.cs | 2 +-
.../Databricks/CloudFetch/CloudFetchReader.cs | 2 -
.../CloudFetch/CloudFetchResultFetcher.cs | 138 +++++--
.../Databricks/CloudFetch/DownloadResult.cs | 34 +-
.../Databricks/CloudFetch/EndOfResultsGuard.cs | 11 +-
.../Databricks/CloudFetch/ICloudFetchInterfaces.cs | 28 +-
.../Databricks/CloudFetch/IHiveServer2Statement.cs | 2 +-
.../Databricks/DatabricksOperationStatusPoller.cs | 2 +-
.../src/Drivers/Databricks/DatabricksParameters.cs | 12 +
.../src/Drivers/Databricks/DatabricksStatement.cs | 2 +-
.../CloudFetch/CloudFetchDownloaderTest.cs | 151 +++++++-
.../CloudFetch/CloudFetchResultFetcherTest.cs | 410 ++++++++++++++-------
.../test/Drivers/Databricks/CloudFetchE2ETest.cs | 3 +-
.../DatabricksOperationStatusPollerTests.cs | 2 +-
17 files changed, 731 insertions(+), 206 deletions(-)
diff --git a/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs
b/csharp/src/Drivers/Databricks/CloudFetch/Clock.cs
similarity index 51%
copy from csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs
copy to csharp/src/Drivers/Databricks/CloudFetch/Clock.cs
index ee77dce9d..3d836115c 100644
--- a/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs
+++ b/csharp/src/Drivers/Databricks/CloudFetch/Clock.cs
@@ -15,34 +15,26 @@
* limitations under the License.
*/
-using Apache.Hive.Service.Rpc.Thrift;
+using System;
-namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
+namespace Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch
{
/// <summary>
- /// Interface for accessing HiveServer2Statement properties needed by
CloudFetchResultFetcher.
+ /// Abstraction for time operations to enable testing with controlled time.
/// </summary>
- internal interface IHiveServer2Statement
+ internal interface IClock
{
/// <summary>
- /// Gets the operation handle.
+ /// Gets the current UTC time.
/// </summary>
- TOperationHandle? OperationHandle { get; }
-
- /// <summary>
- /// Gets the client.
- /// </summary>
- TCLIService.IAsync Client { get; }
-
- /// <summary>
- /// Gets the direct results.
- /// </summary>
- TSparkDirectResults? DirectResults { get; }
+ DateTime UtcNow { get; }
+ }
- /// <summary>
- /// Checks if direct results are available.
- /// </summary>
- /// <returns>True if direct results are available and contain result
data, false otherwise.</returns>
- bool HasDirectResults { get; }
+ /// <summary>
+ /// Default implementation that uses system time.
+ /// </summary>
+ internal class SystemClock : IClock
+ {
+ public DateTime UtcNow => DateTime.UtcNow;
}
}
diff --git
a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloadManager.cs
b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloadManager.cs
index a64e6ffdc..97190dfec 100644
--- a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloadManager.cs
+++ b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloadManager.cs
@@ -22,9 +22,8 @@ using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
-using Apache.Arrow.Adbc.Drivers.Databricks;
-namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
+namespace Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch
{
/// <summary>
/// Manages the CloudFetch download pipeline.
@@ -38,6 +37,8 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
private const bool DefaultPrefetchEnabled = true;
private const int DefaultFetchBatchSize = 2000000;
private const int DefaultTimeoutMinutes = 5;
+ private const int DefaultMaxUrlRefreshAttempts = 3;
+ private const int DefaultUrlExpirationBufferSeconds = 60;
private readonly DatabricksStatement _statement;
private readonly Schema _schema;
@@ -151,6 +152,34 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
}
}
+ // Parse URL expiration buffer seconds
+ int urlExpirationBufferSeconds = DefaultUrlExpirationBufferSeconds;
+ if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchUrlExpirationBufferSeconds,
out string? urlExpirationBufferStr))
+ {
+ if (int.TryParse(urlExpirationBufferStr, out int
parsedUrlExpirationBuffer) && parsedUrlExpirationBuffer > 0)
+ {
+ urlExpirationBufferSeconds = parsedUrlExpirationBuffer;
+ }
+ else
+ {
+ throw new ArgumentException($"Invalid value for
{DatabricksParameters.CloudFetchUrlExpirationBufferSeconds}:
{urlExpirationBufferStr}. Expected a positive integer.");
+ }
+ }
+
+ // Parse max URL refresh attempts
+ int maxUrlRefreshAttempts = DefaultMaxUrlRefreshAttempts;
+ if
(connectionProps.TryGetValue(DatabricksParameters.CloudFetchMaxUrlRefreshAttempts,
out string? maxUrlRefreshAttemptsStr))
+ {
+ if (int.TryParse(maxUrlRefreshAttemptsStr, out int
parsedMaxUrlRefreshAttempts) && parsedMaxUrlRefreshAttempts > 0)
+ {
+ maxUrlRefreshAttempts = parsedMaxUrlRefreshAttempts;
+ }
+ else
+ {
+ throw new ArgumentException($"Invalid value for
{DatabricksParameters.CloudFetchMaxUrlRefreshAttempts}:
{maxUrlRefreshAttemptsStr}. Expected a positive integer.");
+ }
+ }
+
// Initialize the memory manager
_memoryManager = new
CloudFetchMemoryBufferManager(memoryBufferSizeMB);
@@ -161,12 +190,13 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
_httpClient = httpClient;
_httpClient.Timeout = TimeSpan.FromMinutes(timeoutMinutes);
- // Initialize the result fetcher
+ // Initialize the result fetcher with URL management capabilities
_resultFetcher = new CloudFetchResultFetcher(
_statement,
_memoryManager,
_downloadQueue,
- DefaultFetchBatchSize);
+ DefaultFetchBatchSize,
+ urlExpirationBufferSeconds);
// Initialize the downloader
_downloader = new CloudFetchDownloader(
@@ -174,10 +204,13 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
_resultQueue,
_memoryManager,
_httpClient,
+ _resultFetcher,
parallelDownloads,
_isLz4Compressed,
maxRetries,
- retryDelayMs);
+ retryDelayMs,
+ maxUrlRefreshAttempts,
+ urlExpirationBufferSeconds);
}
/// <summary>
diff --git a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloader.cs
b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloader.cs
index 8aadf58f2..d7f110601 100644
--- a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloader.cs
+++ b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchDownloader.cs
@@ -24,7 +24,7 @@ using System.Threading;
using System.Threading.Tasks;
using K4os.Compression.LZ4.Streams;
-namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
+namespace Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch
{
/// <summary>
/// Downloads files from URLs.
@@ -35,10 +35,13 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
private readonly BlockingCollection<IDownloadResult> _resultQueue;
private readonly ICloudFetchMemoryBufferManager _memoryManager;
private readonly HttpClient _httpClient;
+ private readonly ICloudFetchResultFetcher _resultFetcher;
private readonly int _maxParallelDownloads;
private readonly bool _isLz4Compressed;
private readonly int _maxRetries;
private readonly int _retryDelayMs;
+ private readonly int _maxUrlRefreshAttempts;
+ private readonly int _urlExpirationBufferSeconds;
private readonly SemaphoreSlim _downloadSemaphore;
private Task? _downloadTask;
private CancellationTokenSource? _cancellationTokenSource;
@@ -53,29 +56,37 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
/// <param name="resultQueue">The queue to add completed downloads
to.</param>
/// <param name="memoryManager">The memory buffer manager.</param>
/// <param name="httpClient">The HTTP client to use for
downloads.</param>
+ /// <param name="resultFetcher">The result fetcher that manages
URLs.</param>
/// <param name="maxParallelDownloads">The maximum number of parallel
downloads.</param>
/// <param name="isLz4Compressed">Whether the results are LZ4
compressed.</param>
- /// <param name="logger">The logger instance.</param>
/// <param name="maxRetries">The maximum number of retry
attempts.</param>
/// <param name="retryDelayMs">The delay between retry attempts in
milliseconds.</param>
+ /// <param name="maxUrlRefreshAttempts">The maximum number of URL
refresh attempts.</param>
+ /// <param name="urlExpirationBufferSeconds">Buffer time in seconds
before URL expiration to trigger refresh.</param>
public CloudFetchDownloader(
BlockingCollection<IDownloadResult> downloadQueue,
BlockingCollection<IDownloadResult> resultQueue,
ICloudFetchMemoryBufferManager memoryManager,
HttpClient httpClient,
+ ICloudFetchResultFetcher resultFetcher,
int maxParallelDownloads,
bool isLz4Compressed,
int maxRetries = 3,
- int retryDelayMs = 500)
+ int retryDelayMs = 500,
+ int maxUrlRefreshAttempts = 3,
+ int urlExpirationBufferSeconds = 60)
{
_downloadQueue = downloadQueue ?? throw new
ArgumentNullException(nameof(downloadQueue));
_resultQueue = resultQueue ?? throw new
ArgumentNullException(nameof(resultQueue));
_memoryManager = memoryManager ?? throw new
ArgumentNullException(nameof(memoryManager));
_httpClient = httpClient ?? throw new
ArgumentNullException(nameof(httpClient));
+ _resultFetcher = resultFetcher ?? throw new
ArgumentNullException(nameof(resultFetcher));
_maxParallelDownloads = maxParallelDownloads > 0 ?
maxParallelDownloads : throw new
ArgumentOutOfRangeException(nameof(maxParallelDownloads));
_isLz4Compressed = isLz4Compressed;
_maxRetries = maxRetries > 0 ? maxRetries : throw new
ArgumentOutOfRangeException(nameof(maxRetries));
_retryDelayMs = retryDelayMs > 0 ? retryDelayMs : throw new
ArgumentOutOfRangeException(nameof(retryDelayMs));
+ _maxUrlRefreshAttempts = maxUrlRefreshAttempts > 0 ?
maxUrlRefreshAttempts : throw new
ArgumentOutOfRangeException(nameof(maxUrlRefreshAttempts));
+ _urlExpirationBufferSeconds = urlExpirationBufferSeconds > 0 ?
urlExpirationBufferSeconds : throw new
ArgumentOutOfRangeException(nameof(urlExpirationBufferSeconds));
_downloadSemaphore = new SemaphoreSlim(_maxParallelDownloads,
_maxParallelDownloads);
_isCompleted = false;
}
@@ -237,6 +248,19 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
break;
}
+ // Check if the URL is expired or about to expire
+ if
(downloadResult.IsExpiredOrExpiringSoon(_urlExpirationBufferSeconds))
+ {
+ // Get a refreshed URL before starting the download
+ var refreshedLink = await
_resultFetcher.GetUrlAsync(downloadResult.Link.StartRowOffset,
cancellationToken);
+ if (refreshedLink != null)
+ {
+ // Update the download result with the refreshed
link
+
downloadResult.UpdateWithRefreshedLink(refreshedLink);
+ Trace.TraceInformation($"Updated URL for file at
offset {refreshedLink.StartRowOffset} before download");
+ }
+ }
+
// Acquire a download slot
await
_downloadSemaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
@@ -341,6 +365,37 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
HttpCompletionOption.ResponseHeadersRead,
cancellationToken).ConfigureAwait(false);
+ // Check if the response indicates an expired URL
(typically 403 or 401)
+ if (response.StatusCode ==
System.Net.HttpStatusCode.Forbidden ||
+ response.StatusCode ==
System.Net.HttpStatusCode.Unauthorized)
+ {
+ // If we've already tried refreshing too many times,
fail
+ if (downloadResult.RefreshAttempts >=
_maxUrlRefreshAttempts)
+ {
+ throw new InvalidOperationException($"Failed to
download file after {downloadResult.RefreshAttempts} URL refresh attempts.");
+ }
+
+ // Try to refresh the URL
+ var refreshedLink = await
_resultFetcher.GetUrlAsync(downloadResult.Link.StartRowOffset,
cancellationToken);
+ if (refreshedLink != null)
+ {
+ // Update the download result with the refreshed
link
+
downloadResult.UpdateWithRefreshedLink(refreshedLink);
+ url = refreshedLink.FileLink;
+ sanitizedUrl = SanitizeUrl(url);
+
+ Trace.TraceInformation($"URL for file at offset
{refreshedLink.StartRowOffset} was refreshed after expired URL response");
+
+ // Continue to the next retry attempt with the
refreshed URL
+ continue;
+ }
+ else
+ {
+ // If refresh failed, throw an exception
+ throw new InvalidOperationException("Failed to
refresh expired URL.");
+ }
+ }
+
response.EnsureSuccessStatusCode();
// Log the download size if available from response headers
diff --git
a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchMemoryBufferManager.cs
b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchMemoryBufferManager.cs
index 7f5a13e10..584201abd 100644
--- a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchMemoryBufferManager.cs
+++ b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchMemoryBufferManager.cs
@@ -19,7 +19,7 @@ using System;
using System.Threading;
using System.Threading.Tasks;
-namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
+namespace Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch
{
/// <summary>
/// Manages memory allocation for prefetched files.
diff --git a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs
b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs
index d59e68bbc..8389faad9 100644
--- a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs
+++ b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs
@@ -21,8 +21,6 @@ using System.Diagnostics;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
-using Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch;
-using Apache.Arrow.Adbc.Drivers.Databricks;
using Apache.Arrow.Ipc;
using Apache.Hive.Service.Rpc.Thrift;
diff --git
a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs
b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs
index 3da5608ed..3167fd93a 100644
--- a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs
+++ b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs
@@ -19,20 +19,25 @@ using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
+using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Apache.Hive.Service.Rpc.Thrift;
-namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
+namespace Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch
{
/// <summary>
- /// Fetches result chunks from the Thrift server.
+ /// Fetches result chunks from the Thrift server and manages URL caching
and refreshing.
/// </summary>
- internal sealed class CloudFetchResultFetcher : ICloudFetchResultFetcher
+ internal class CloudFetchResultFetcher : ICloudFetchResultFetcher
{
private readonly IHiveServer2Statement _statement;
private readonly ICloudFetchMemoryBufferManager _memoryManager;
private readonly BlockingCollection<IDownloadResult> _downloadQueue;
+ private readonly SemaphoreSlim _fetchLock = new SemaphoreSlim(1, 1);
+ private readonly ConcurrentDictionary<long, IDownloadResult>
_urlsByOffset = new ConcurrentDictionary<long, IDownloadResult>();
+ private readonly int _expirationBufferSeconds;
+ private readonly IClock _clock;
private long _startOffset;
private bool _hasMoreResults;
private bool _isCompleted;
@@ -47,19 +52,25 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
/// <param name="statement">The HiveServer2 statement
interface.</param>
/// <param name="memoryManager">The memory buffer manager.</param>
/// <param name="downloadQueue">The queue to add download tasks
to.</param>
- /// <param name="prefetchCount">The number of result chunks to
prefetch.</param>
+ /// <param name="batchSize">The number of rows to fetch in each
batch.</param>
+ /// <param name="expirationBufferSeconds">Buffer time in seconds
before URL expiration to trigger refresh.</param>
+ /// <param name="clock">Clock implementation for time operations. If
null, uses system clock.</param>
public CloudFetchResultFetcher(
IHiveServer2Statement statement,
ICloudFetchMemoryBufferManager memoryManager,
BlockingCollection<IDownloadResult> downloadQueue,
- long batchSize)
+ long batchSize,
+ int expirationBufferSeconds = 60,
+ IClock? clock = null)
{
_statement = statement ?? throw new
ArgumentNullException(nameof(statement));
_memoryManager = memoryManager ?? throw new
ArgumentNullException(nameof(memoryManager));
_downloadQueue = downloadQueue ?? throw new
ArgumentNullException(nameof(downloadQueue));
+ _batchSize = batchSize;
+ _expirationBufferSeconds = expirationBufferSeconds;
+ _clock = clock ?? new SystemClock();
_hasMoreResults = true;
_isCompleted = false;
- _batchSize = batchSize;
}
/// <inheritdoc />
@@ -87,6 +98,7 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
_hasMoreResults = true;
_isCompleted = false;
_error = null;
+ _urlsByOffset.Clear();
_cancellationTokenSource =
CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
_fetchTask = FetchResultsAsync(_cancellationTokenSource.Token);
@@ -124,6 +136,74 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
_fetchTask = null;
}
}
+ /// <inheritdoc />
+ public async Task<TSparkArrowResultLink?> GetUrlAsync(long offset,
CancellationToken cancellationToken)
+ {
+ // Check if we have a non-expired URL in the cache
+ if (_urlsByOffset.TryGetValue(offset, out var cachedResult) &&
!cachedResult.IsExpiredOrExpiringSoon(_expirationBufferSeconds))
+ {
+ return cachedResult.Link;
+ }
+
+ // Need to fetch or refresh the URL
+ await _fetchLock.WaitAsync(cancellationToken);
+ try
+ {
+ // Create fetch request for the specific offset
+ TFetchResultsReq request = new TFetchResultsReq(
+ _statement.OperationHandle!,
+ TFetchOrientation.FETCH_NEXT,
+ 1);
+
+ request.StartRowOffset = offset;
+
+ // Fetch results
+ TFetchResultsResp response = await
_statement.Client.FetchResults(request, cancellationToken);
+
+ // Process the results
+ if (response.Status.StatusCode == TStatusCode.SUCCESS_STATUS &&
+ response.Results.__isset.resultLinks &&
+ response.Results.ResultLinks != null &&
+ response.Results.ResultLinks.Count > 0)
+ {
+ var refreshedLink =
response.Results.ResultLinks.FirstOrDefault(l => l.StartRowOffset == offset);
+ if (refreshedLink != null)
+ {
+ Trace.TraceInformation($"Successfully fetched URL for
offset {offset}");
+
+ // Create a download result for the refreshed link
+ var downloadResult = new DownloadResult(refreshedLink,
_memoryManager);
+ _urlsByOffset[offset] = downloadResult;
+
+ return refreshedLink;
+ }
+ }
+
+ Trace.TraceWarning($"Failed to fetch URL for offset {offset}");
+ return null;
+ }
+ finally
+ {
+ _fetchLock.Release();
+ }
+ }
+
+ /// <summary>
+ /// Gets all currently cached URLs.
+ /// </summary>
+ /// <returns>A dictionary mapping offsets to their URL links.</returns>
+ public Dictionary<long, TSparkArrowResultLink> GetAllCachedUrls()
+ {
+ return _urlsByOffset.ToDictionary(kvp => kvp.Key, kvp =>
kvp.Value.Link);
+ }
+
+ /// <summary>
+ /// Clears all cached URLs.
+ /// </summary>
+ public void ClearCache()
+ {
+ _urlsByOffset.Clear();
+ }
private async Task FetchResultsAsync(CancellationToken
cancellationToken)
{
@@ -143,7 +223,7 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
try
{
// Fetch more results from the server
- await
FetchNextResultBatchAsync(cancellationToken).ConfigureAwait(false);
+ await FetchNextResultBatchAsync(null,
cancellationToken).ConfigureAwait(false);
}
catch (OperationCanceledException) when
(cancellationToken.IsCancellationRequested)
{
@@ -155,7 +235,7 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
Debug.WriteLine($"Error fetching results:
{ex.Message}");
_error = ex;
_hasMoreResults = false;
- break;
+ throw;
}
}
@@ -182,15 +262,16 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
}
}
- private async Task FetchNextResultBatchAsync(CancellationToken
cancellationToken)
+ private async Task FetchNextResultBatchAsync(long? offset,
CancellationToken cancellationToken)
{
// Create fetch request
TFetchResultsReq request = new
TFetchResultsReq(_statement.OperationHandle!, TFetchOrientation.FETCH_NEXT,
_batchSize);
- // Set the start row offset if we have processed some links already
- if (_startOffset > 0)
+ // Set the start row offset
+ long startOffset = offset ?? _startOffset;
+ if (startOffset > 0)
{
- request.StartRowOffset = _startOffset;
+ request.StartRowOffset = startOffset;
}
// Fetch results
@@ -212,19 +293,27 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
response.Results.ResultLinks.Count > 0)
{
List<TSparkArrowResultLink> resultLinks =
response.Results.ResultLinks;
+ long maxOffset = 0;
- // Add each link to the download queue
+ // Process each link
foreach (var link in resultLinks)
{
+ // Create download result
var downloadResult = new DownloadResult(link,
_memoryManager);
+
+ // Add to download queue and cache
_downloadQueue.Add(downloadResult, cancellationToken);
+ _urlsByOffset[link.StartRowOffset] = downloadResult;
+
+ // Track the maximum offset for future fetches
+ long endOffset = link.StartRowOffset + link.RowCount;
+ maxOffset = Math.Max(maxOffset, endOffset);
}
// Update the start offset for the next fetch
- if (resultLinks.Count > 0)
+ if (!offset.HasValue) // Only update if this was a sequential
fetch
{
- var lastLink = resultLinks[resultLinks.Count - 1];
- _startOffset = lastLink.StartRowOffset + lastLink.RowCount;
+ _startOffset = maxOffset;
}
// Update whether there are more results
@@ -240,20 +329,27 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
private void ProcessDirectResultsAsync(CancellationToken
cancellationToken)
{
List<TSparkArrowResultLink> resultLinks =
_statement.DirectResults!.ResultSet.Results.ResultLinks;
+ long maxOffset = 0;
+ // Process each link
foreach (var link in resultLinks)
{
+ // Create download result
var downloadResult = new DownloadResult(link, _memoryManager);
+
+ // Add to download queue and cache
_downloadQueue.Add(downloadResult, cancellationToken);
+ _urlsByOffset[link.StartRowOffset] = downloadResult;
+
+ // Track the maximum offset for future fetches
+ long endOffset = link.StartRowOffset + link.RowCount;
+ maxOffset = Math.Max(maxOffset, endOffset);
}
// Update the start offset for the next fetch
- if (resultLinks.Count > 0)
- {
- var lastLink = resultLinks[resultLinks.Count - 1];
- _startOffset = lastLink.StartRowOffset + lastLink.RowCount;
- }
+ _startOffset = maxOffset;
+ // Update whether there are more results
_hasMoreResults = _statement.DirectResults!.ResultSet.HasMoreRows;
}
}
diff --git a/csharp/src/Drivers/Databricks/CloudFetch/DownloadResult.cs
b/csharp/src/Drivers/Databricks/CloudFetch/DownloadResult.cs
index eb2736c33..55d2fdc4a 100644
--- a/csharp/src/Drivers/Databricks/CloudFetch/DownloadResult.cs
+++ b/csharp/src/Drivers/Databricks/CloudFetch/DownloadResult.cs
@@ -20,7 +20,7 @@ using System.IO;
using System.Threading.Tasks;
using Apache.Hive.Service.Rpc.Thrift;
-namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
+namespace Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch
{
/// <summary>
/// Represents a downloaded result file with its associated metadata.
@@ -47,7 +47,7 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
}
/// <inheritdoc />
- public TSparkArrowResultLink Link { get; }
+ public TSparkArrowResultLink Link { get; private set; }
/// <inheritdoc />
public Stream DataStream
@@ -72,6 +72,36 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
/// <inheritdoc />
public bool IsCompleted => _downloadCompletionSource.Task.IsCompleted
&& !_downloadCompletionSource.Task.IsFaulted;
+ /// <summary>
+ /// Gets the number of URL refresh attempts for this download.
+ /// </summary>
+ public int RefreshAttempts { get; private set; } = 0;
+
+ /// <summary>
+ /// Checks if the URL is expired or about to expire.
+ /// </summary>
+ /// <param name="expirationBufferSeconds">Buffer time in seconds
before expiration to consider a URL as expiring soon.</param>
+ /// <returns>True if the URL is expired or about to expire, false
otherwise.</returns>
+ public bool IsExpiredOrExpiringSoon(int expirationBufferSeconds = 60)
+ {
+ // Convert expiry time to DateTime
+ var expiryTime =
DateTimeOffset.FromUnixTimeMilliseconds(Link.ExpiryTime).UtcDateTime;
+
+ // Check if the URL is already expired or will expire soon
+ return DateTime.UtcNow.AddSeconds(expirationBufferSeconds) >=
expiryTime;
+ }
+
+ /// <summary>
+ /// Updates this download result with a refreshed link.
+ /// </summary>
+ /// <param name="refreshedLink">The refreshed link information.</param>
+ public void UpdateWithRefreshedLink(TSparkArrowResultLink
refreshedLink)
+ {
+ ThrowIfDisposed();
+ Link = refreshedLink ?? throw new
ArgumentNullException(nameof(refreshedLink));
+ RefreshAttempts++;
+ }
+
/// <inheritdoc />
public void SetCompleted(Stream dataStream, long size)
{
diff --git a/csharp/src/Drivers/Databricks/CloudFetch/EndOfResultsGuard.cs
b/csharp/src/Drivers/Databricks/CloudFetch/EndOfResultsGuard.cs
index d305082cf..d4b94426e 100644
--- a/csharp/src/Drivers/Databricks/CloudFetch/EndOfResultsGuard.cs
+++ b/csharp/src/Drivers/Databricks/CloudFetch/EndOfResultsGuard.cs
@@ -20,7 +20,7 @@ using System.IO;
using System.Threading.Tasks;
using Apache.Hive.Service.Rpc.Thrift;
-namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
+namespace Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch
{
/// <summary>
/// Special marker class that indicates the end of results in the download
queue.
@@ -54,12 +54,21 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
/// <inheritdoc />
public bool IsCompleted => true;
+ /// <inheritdoc />
+ public int RefreshAttempts => 0;
+
/// <inheritdoc />
public void SetCompleted(Stream dataStream, long size) => throw new
NotSupportedException("EndOfResultsGuard cannot be completed.");
/// <inheritdoc />
public void SetFailed(Exception exception) => throw new
NotSupportedException("EndOfResultsGuard cannot fail.");
+ /// <inheritdoc />
+ public void UpdateWithRefreshedLink(TSparkArrowResultLink
refreshedLink) => throw new NotSupportedException("EndOfResultsGuard cannot be
updated with a refreshed link.");
+
+ /// <inheritdoc />
+ public bool IsExpiredOrExpiringSoon(int expirationBufferSeconds = 60)
=> false;
+
/// <inheritdoc />
public void Dispose()
{
diff --git a/csharp/src/Drivers/Databricks/CloudFetch/ICloudFetchInterfaces.cs
b/csharp/src/Drivers/Databricks/CloudFetch/ICloudFetchInterfaces.cs
index 444213087..1103f391c 100644
--- a/csharp/src/Drivers/Databricks/CloudFetch/ICloudFetchInterfaces.cs
+++ b/csharp/src/Drivers/Databricks/CloudFetch/ICloudFetchInterfaces.cs
@@ -22,7 +22,7 @@ using System.Threading;
using System.Threading.Tasks;
using Apache.Hive.Service.Rpc.Thrift;
-namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
+namespace Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch
{
/// <summary>
/// Represents a downloaded result file with its associated metadata.
@@ -54,6 +54,11 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
/// </summary>
bool IsCompleted { get; }
+ /// <summary>
+ /// Gets the number of URL refresh attempts for this download.
+ /// </summary>
+ int RefreshAttempts { get; }
+
/// <summary>
/// Sets the download as completed with the provided data stream.
/// </summary>
@@ -66,6 +71,19 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
/// </summary>
/// <param name="exception">The exception that caused the
failure.</param>
void SetFailed(Exception exception);
+
+ /// <summary>
+ /// Updates this download result with a refreshed link.
+ /// </summary>
+ /// <param name="refreshedLink">The refreshed link information.</param>
+ void UpdateWithRefreshedLink(TSparkArrowResultLink refreshedLink);
+
+ /// <summary>
+ /// Checks if the URL is expired or about to expire.
+ /// </summary>
+ /// <param name="expirationBufferSeconds">Buffer time in seconds
before expiration to consider a URL as expiring soon.</param>
+ /// <returns>True if the URL is expired or about to expire, false
otherwise.</returns>
+ bool IsExpiredOrExpiringSoon(int expirationBufferSeconds = 60);
}
/// <summary>
@@ -142,6 +160,14 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
/// Gets the error encountered by the fetcher, if any.
/// </summary>
Exception? Error { get; }
+
+ /// <summary>
+ /// Gets a URL for the specified offset, fetching or refreshing as
needed.
+ /// </summary>
+ /// <param name="offset">The row offset for which to get a URL.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ /// <returns>The URL link for the specified offset, or null if not
available.</returns>
+ Task<TSparkArrowResultLink?> GetUrlAsync(long offset,
CancellationToken cancellationToken);
}
/// <summary>
diff --git a/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs
b/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs
index ee77dce9d..89359bb1d 100644
--- a/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs
+++ b/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs
@@ -17,7 +17,7 @@
using Apache.Hive.Service.Rpc.Thrift;
-namespace Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
+namespace Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch
{
/// <summary>
/// Interface for accessing HiveServer2Statement properties needed by
CloudFetchResultFetcher.
diff --git a/csharp/src/Drivers/Databricks/DatabricksOperationStatusPoller.cs
b/csharp/src/Drivers/Databricks/DatabricksOperationStatusPoller.cs
index 1963e10c9..279c7b5a6 100644
--- a/csharp/src/Drivers/Databricks/DatabricksOperationStatusPoller.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksOperationStatusPoller.cs
@@ -18,7 +18,7 @@
using System;
using System.Threading;
using System.Threading.Tasks;
-using Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch;
+using Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch;
using Apache.Hive.Service.Rpc.Thrift;
namespace Apache.Arrow.Adbc.Drivers.Databricks
diff --git a/csharp/src/Drivers/Databricks/DatabricksParameters.cs
b/csharp/src/Drivers/Databricks/DatabricksParameters.cs
index db62c04b2..3c5b65716 100644
--- a/csharp/src/Drivers/Databricks/DatabricksParameters.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksParameters.cs
@@ -62,6 +62,18 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
/// </summary>
public const string CloudFetchTimeoutMinutes =
"adbc.databricks.cloudfetch.timeout_minutes";
+ /// <summary>
+ /// Buffer time in seconds before URL expiration to trigger refresh.
+ /// Default value is 60 seconds if not specified.
+ /// </summary>
+ public const string CloudFetchUrlExpirationBufferSeconds =
"adbc.databricks.cloudfetch.url_expiration_buffer_seconds";
+
+ /// <summary>
+ /// Maximum number of URL refresh attempts for CloudFetch downloads.
+ /// Default value is 3 if not specified.
+ /// </summary>
+ public const string CloudFetchMaxUrlRefreshAttempts =
"adbc.databricks.cloudfetch.max_url_refresh_attempts";
+
/// <summary>
/// Whether to enable the use of direct results when executing queries.
/// Default value is true if not specified.
diff --git a/csharp/src/Drivers/Databricks/DatabricksStatement.cs
b/csharp/src/Drivers/Databricks/DatabricksStatement.cs
index cb92cdd5e..3d7d0a94c 100644
--- a/csharp/src/Drivers/Databricks/DatabricksStatement.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksStatement.cs
@@ -20,7 +20,7 @@ using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Adbc.Drivers.Apache;
using Apache.Arrow.Adbc.Drivers.Apache.Spark;
-using Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch;
+using Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch;
using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Arrow.Types;
using Apache.Hive.Service.Rpc.Thrift;
diff --git
a/csharp/test/Drivers/Databricks/CloudFetch/CloudFetchDownloaderTest.cs
b/csharp/test/Drivers/Databricks/CloudFetch/CloudFetchDownloaderTest.cs
index 350ddf219..dddccba2d 100644
--- a/csharp/test/Drivers/Databricks/CloudFetch/CloudFetchDownloaderTest.cs
+++ b/csharp/test/Drivers/Databricks/CloudFetch/CloudFetchDownloaderTest.cs
@@ -24,30 +24,47 @@ using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
-using Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch;
+using Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch;
using Apache.Hive.Service.Rpc.Thrift;
using Moq;
using Moq.Protected;
using Xunit;
-namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
+namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.CloudFetch
{
public class CloudFetchDownloaderTest
{
private readonly BlockingCollection<IDownloadResult> _downloadQueue;
private readonly BlockingCollection<IDownloadResult> _resultQueue;
private readonly Mock<ICloudFetchMemoryBufferManager>
_mockMemoryManager;
+ private readonly Mock<IHiveServer2Statement> _mockStatement;
+ private readonly Mock<ICloudFetchResultFetcher> _mockResultFetcher;
public CloudFetchDownloaderTest()
{
_downloadQueue = new BlockingCollection<IDownloadResult>(new
ConcurrentQueue<IDownloadResult>(), 10);
_resultQueue = new BlockingCollection<IDownloadResult>(new
ConcurrentQueue<IDownloadResult>(), 10);
_mockMemoryManager = new Mock<ICloudFetchMemoryBufferManager>();
+ _mockStatement = new Mock<IHiveServer2Statement>();
+ _mockResultFetcher = new Mock<ICloudFetchResultFetcher>();
// Set up memory manager defaults
_mockMemoryManager.Setup(m =>
m.TryAcquireMemory(It.IsAny<long>())).Returns(true);
_mockMemoryManager.Setup(m =>
m.AcquireMemoryAsync(It.IsAny<long>(), It.IsAny<CancellationToken>()))
.Returns(Task.CompletedTask);
+
+ // Set up result fetcher defaults
+ _mockResultFetcher.Setup(f => f.GetUrlAsync(It.IsAny<long>(),
It.IsAny<CancellationToken>()))
+ .ReturnsAsync((long offset, CancellationToken token) =>
+ {
+ // Return a URL with the same offset
+ return new TSparkArrowResultLink
+ {
+ StartRowOffset = offset,
+ FileLink = $"http://test.com/file{offset}",
+ ExpiryTime =
DateTimeOffset.UtcNow.AddMinutes(30).ToUnixTimeMilliseconds()
+ };
+ });
}
[Fact]
@@ -77,6 +94,7 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
_resultQueue,
_mockMemoryManager.Object,
httpClient,
+ _mockResultFetcher.Object,
3, // maxParallelDownloads
false); // isLz4Compressed
@@ -108,9 +126,14 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
// Create a test download result
var mockDownloadResult = new Mock<IDownloadResult>();
- var resultLink = new TSparkArrowResultLink { FileLink =
"http://test.com/file1" };
+ var resultLink = new TSparkArrowResultLink {
+ FileLink = "http://test.com/file1",
+ ExpiryTime =
DateTimeOffset.UtcNow.AddMinutes(30).ToUnixTimeMilliseconds() // Set expiry 30
minutes in the future
+ };
mockDownloadResult.Setup(r => r.Link).Returns(resultLink);
mockDownloadResult.Setup(r =>
r.Size).Returns(testContentBytes.Length);
+ mockDownloadResult.Setup(r => r.RefreshAttempts).Returns(0);
+ mockDownloadResult.Setup(r =>
r.IsExpiredOrExpiringSoon(It.IsAny<int>())).Returns(false);
// Capture the stream and size passed to SetCompleted
Stream? capturedStream = null;
@@ -128,6 +151,7 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
_resultQueue,
_mockMemoryManager.Object,
httpClient,
+ _mockResultFetcher.Object,
1, // maxParallelDownloads
false, // isLz4Compressed
1, // maxRetries
@@ -189,9 +213,14 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
// Create a test download result
var mockDownloadResult = new Mock<IDownloadResult>();
- var resultLink = new TSparkArrowResultLink { FileLink =
"http://test.com/file1" };
+ var resultLink = new TSparkArrowResultLink {
+ FileLink = "http://test.com/file1",
+ ExpiryTime =
DateTimeOffset.UtcNow.AddMinutes(30).ToUnixTimeMilliseconds() // Set expiry 30
minutes in the future
+ };
mockDownloadResult.Setup(r => r.Link).Returns(resultLink);
mockDownloadResult.Setup(r => r.Size).Returns(1000); // Some
arbitrary size
+ mockDownloadResult.Setup(r => r.RefreshAttempts).Returns(0);
+ mockDownloadResult.Setup(r =>
r.IsExpiredOrExpiringSoon(It.IsAny<int>())).Returns(false);
// Capture when SetFailed is called
Exception? capturedException = null;
@@ -204,6 +233,7 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
_resultQueue,
_mockMemoryManager.Object,
httpClient,
+ _mockResultFetcher.Object,
1, // maxParallelDownloads
false, // isLz4Compressed
1, // maxRetries
@@ -256,9 +286,14 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
// Create test download results
var mockDownloadResult = new Mock<IDownloadResult>();
- var resultLink = new TSparkArrowResultLink { FileLink =
"http://test.com/file1" };
+ var resultLink = new TSparkArrowResultLink {
+ FileLink = "http://test.com/file1",
+ ExpiryTime =
DateTimeOffset.UtcNow.AddMinutes(30).ToUnixTimeMilliseconds() // Set expiry 30
minutes in the future
+ };
mockDownloadResult.Setup(r => r.Link).Returns(resultLink);
mockDownloadResult.Setup(r => r.Size).Returns(100);
+ mockDownloadResult.Setup(r => r.RefreshAttempts).Returns(0);
+ mockDownloadResult.Setup(r =>
r.IsExpiredOrExpiringSoon(It.IsAny<int>())).Returns(false);
// Capture when SetFailed is called
Exception? capturedException = null;
@@ -271,6 +306,7 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
_resultQueue,
_mockMemoryManager.Object,
httpClient,
+ _mockResultFetcher.Object,
1, // maxParallelDownloads
false, // isLz4Compressed
1, // maxRetries
@@ -345,9 +381,14 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
// Create a test download result
var mockDownloadResult = new Mock<IDownloadResult>();
- var resultLink = new TSparkArrowResultLink { FileLink =
"http://test.com/file1" };
+ var resultLink = new TSparkArrowResultLink {
+ FileLink = "http://test.com/file1",
+ ExpiryTime =
DateTimeOffset.UtcNow.AddMinutes(30).ToUnixTimeMilliseconds() // Set expiry 30
minutes in the future
+ };
mockDownloadResult.Setup(r => r.Link).Returns(resultLink);
mockDownloadResult.Setup(r => r.Size).Returns(100);
+ mockDownloadResult.Setup(r => r.RefreshAttempts).Returns(0);
+ mockDownloadResult.Setup(r =>
r.IsExpiredOrExpiringSoon(It.IsAny<int>())).Returns(false);
// Create the downloader and add the download to the queue
var downloader = new CloudFetchDownloader(
@@ -355,6 +396,7 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
_resultQueue,
_mockMemoryManager.Object,
httpClient,
+ _mockResultFetcher.Object,
1, // maxParallelDownloads
false); // isLz4Compressed
@@ -429,9 +471,14 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
for (int i = 0; i < totalDownloads; i++)
{
var mockDownloadResult = new Mock<IDownloadResult>();
- var resultLink = new TSparkArrowResultLink { FileLink =
$"http://test.com/file{i}" };
+ var resultLink = new TSparkArrowResultLink {
+ FileLink = $"http://test.com/file{i}",
+ ExpiryTime =
DateTimeOffset.UtcNow.AddMinutes(30).ToUnixTimeMilliseconds() // Set expiry 30
minutes in the future
+ };
mockDownloadResult.Setup(r => r.Link).Returns(resultLink);
mockDownloadResult.Setup(r => r.Size).Returns(100);
+ mockDownloadResult.Setup(r => r.RefreshAttempts).Returns(0);
+ mockDownloadResult.Setup(r =>
r.IsExpiredOrExpiringSoon(It.IsAny<int>())).Returns(false);
mockDownloadResult.Setup(r =>
r.SetCompleted(It.IsAny<Stream>(), It.IsAny<long>()))
.Callback<Stream, long>((_, _) => { });
downloadResults[i] = mockDownloadResult.Object;
@@ -443,6 +490,7 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
_resultQueue,
_mockMemoryManager.Object,
httpClient,
+ _mockResultFetcher.Object,
maxParallelDownloads,
false); // isLz4Compressed
@@ -484,6 +532,95 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
await downloader.StopAsync();
}
+ [Fact]
+ public async Task
DownloadFileAsync_RefreshesExpiredUrl_WhenHttpErrorOccurs()
+ {
+ // Arrange
+ // Create a mock HTTP handler that returns a 403 error for the
first request and success for the second
+ var mockHttpMessageHandler = new Mock<HttpMessageHandler>();
+ var requestCount = 0;
+
+ mockHttpMessageHandler
+ .Protected()
+ .Setup<Task<HttpResponseMessage>>(
+ "SendAsync",
+ ItExpr.IsAny<HttpRequestMessage>(),
+ ItExpr.IsAny<CancellationToken>())
+ .Returns<HttpRequestMessage, CancellationToken>(async
(request, token) =>
+ {
+ await Task.Delay(1, token); // Small delay to simulate
network
+
+ // First request fails with 403 Forbidden (expired URL)
+ if (requestCount == 0)
+ {
+ requestCount++;
+ return new
HttpResponseMessage(HttpStatusCode.Forbidden);
+ }
+
+ // Second request succeeds with the refreshed URL
+ return new HttpResponseMessage(HttpStatusCode.OK)
+ {
+ Content = new StringContent("Test content")
+ };
+ });
+
+ var httpClient = new HttpClient(mockHttpMessageHandler.Object);
+
+ // Create a test download result
+ var mockDownloadResult = new Mock<IDownloadResult>();
+ var resultLink = new TSparkArrowResultLink {
+ StartRowOffset = 0,
+ FileLink = "http://test.com/file1",
+ ExpiryTime =
DateTimeOffset.UtcNow.AddMinutes(-5).ToUnixTimeMilliseconds() // Set expiry in
the past
+ };
+ mockDownloadResult.Setup(r => r.Link).Returns(resultLink);
+ mockDownloadResult.Setup(r => r.Size).Returns(100);
+ mockDownloadResult.Setup(r => r.RefreshAttempts).Returns(0);
+ // Important: Set this to false so the initial URL refresh doesn't
happen
+ mockDownloadResult.Setup(r =>
r.IsExpiredOrExpiringSoon(It.IsAny<int>())).Returns(false);
+
+ // Setup URL refreshing - expect it to be called once during the
HTTP 403 error handling
+ var refreshedLink = new TSparkArrowResultLink {
+ StartRowOffset = 0,
+ FileLink = "http://test.com/file1-refreshed",
+ ExpiryTime =
DateTimeOffset.UtcNow.AddMinutes(30).ToUnixTimeMilliseconds() // Set new expiry
in the future
+ };
+ _mockResultFetcher.Setup(f => f.GetUrlAsync(0,
It.IsAny<CancellationToken>()))
+ .ReturnsAsync(refreshedLink);
+
+ // Create the downloader and add the download to the queue
+ var downloader = new CloudFetchDownloader(
+ _downloadQueue,
+ _resultQueue,
+ _mockMemoryManager.Object,
+ httpClient,
+ _mockResultFetcher.Object,
+ 1, // maxParallelDownloads
+ false, // isLz4Compressed
+ 2, // maxRetries
+ 10); // retryDelayMs
+
+ // Act
+ await downloader.StartAsync(CancellationToken.None);
+ _downloadQueue.Add(mockDownloadResult.Object);
+
+ // Wait for the download to be processed
+ await Task.Delay(200);
+
+ // Add the end of results guard to complete the downloader
+ _downloadQueue.Add(EndOfResultsGuard.Instance);
+
+ // Assert
+ // Verify that GetUrlAsync was called exactly once to refresh the
URL
+ _mockResultFetcher.Verify(f => f.GetUrlAsync(0,
It.IsAny<CancellationToken>()), Times.Once);
+
+ // Verify that UpdateWithRefreshedLink was called with the
refreshed link
+ mockDownloadResult.Verify(r =>
r.UpdateWithRefreshedLink(refreshedLink), Times.Once);
+
+ // Cleanup
+ await downloader.StopAsync();
+ }
+
private static Mock<HttpMessageHandler> CreateMockHttpMessageHandler(
byte[]? content,
HttpStatusCode statusCode = HttpStatusCode.OK,
diff --git
a/csharp/test/Drivers/Databricks/CloudFetch/CloudFetchResultFetcherTest.cs
b/csharp/test/Drivers/Databricks/CloudFetch/CloudFetchResultFetcherTest.cs
index 32f9a1a81..d70a378e0 100644
--- a/csharp/test/Drivers/Databricks/CloudFetch/CloudFetchResultFetcherTest.cs
+++ b/csharp/test/Drivers/Databricks/CloudFetch/CloudFetchResultFetcherTest.cs
@@ -18,53 +18,195 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
using System.Threading;
using System.Threading.Tasks;
-using Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch;
+using Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch;
using Apache.Hive.Service.Rpc.Thrift;
using Moq;
using Xunit;
-namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
+namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.CloudFetch
{
/// <summary>
/// Tests for CloudFetchResultFetcher
/// </summary>
public class CloudFetchResultFetcherTest
{
- private readonly Mock<ICloudFetchMemoryBufferManager>
_mockMemoryManager;
+ private readonly Mock<IHiveServer2Statement> _mockStatement;
+ private readonly Mock<TCLIService.IAsync> _mockClient;
+ private readonly TOperationHandle _operationHandle;
+ private readonly MockClock _mockClock;
+ private readonly CloudFetchResultFetcherWithMockClock _resultFetcher;
private readonly BlockingCollection<IDownloadResult> _downloadQueue;
+ private readonly Mock<ICloudFetchMemoryBufferManager>
_mockMemoryManager;
public CloudFetchResultFetcherTest()
{
- _mockMemoryManager = new Mock<ICloudFetchMemoryBufferManager>();
+ _mockClient = new Mock<TCLIService.IAsync>();
+ _mockStatement = new Mock<IHiveServer2Statement>();
+ _operationHandle = new TOperationHandle
+ {
+ OperationId = new THandleIdentifier { Guid = new byte[] { 1,
2, 3, 4 } },
+ OperationType = TOperationType.EXECUTE_STATEMENT,
+ HasResultSet = true
+ };
+
+ _mockStatement.Setup(s => s.Client).Returns(_mockClient.Object);
+ _mockStatement.Setup(s =>
s.OperationHandle).Returns(_operationHandle);
+
+ _mockClock = new MockClock();
_downloadQueue = new BlockingCollection<IDownloadResult>(new
ConcurrentQueue<IDownloadResult>(), 10);
+ _mockMemoryManager = new Mock<ICloudFetchMemoryBufferManager>();
+
+ _resultFetcher = new CloudFetchResultFetcherWithMockClock(
+ _mockStatement.Object,
+ _mockMemoryManager.Object,
+ _downloadQueue,
+ 100, // batchSize
+ _mockClock,
+ 60); // expirationBufferSeconds
}
+ #region URL Management Tests
+
[Fact]
- public async Task StartAsync_CalledTwice_ThrowsException()
+ public async Task GetUrlAsync_FetchesNewUrl_WhenNotCached()
{
// Arrange
- var mockClient = new Mock<TCLIService.IAsync>();
- mockClient.Setup(c => c.FetchResults(It.IsAny<TFetchResultsReq>(),
It.IsAny<CancellationToken>()))
- .ReturnsAsync(CreateFetchResultsResponse(new
List<TSparkArrowResultLink>(), false));
+ long offset = 0;
+ var resultLink = CreateTestResultLink(offset, 100,
"http://test.com/file1", 3600);
+ SetupMockClientFetchResults(new List<TSparkArrowResultLink> {
resultLink }, true);
- var mockStatement = new Mock<IHiveServer2Statement>();
- mockStatement.Setup(s =>
s.OperationHandle).Returns(CreateOperationHandle());
- mockStatement.Setup(s => s.Client).Returns(mockClient.Object);
+ // Act
+ var result = await _resultFetcher.GetUrlAsync(offset,
CancellationToken.None);
- var fetcher = new CloudFetchResultFetcher(
- mockStatement.Object,
- _mockMemoryManager.Object,
- _downloadQueue,
- 5); // batchSize
+ // Assert
+ Assert.NotNull(result);
+ Assert.Equal(offset, result.StartRowOffset);
+ Assert.Equal("http://test.com/file1", result.FileLink);
+ _mockClient.Verify(c =>
c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()),
Times.Once);
+ }
+
+ [Fact]
+ public async Task GetUrlRangeAsync_FetchesMultipleUrls()
+ {
+ // Arrange
+ var resultLinks = new List<TSparkArrowResultLink>
+ {
+ CreateTestResultLink(0, 100, "http://test.com/file1", 3600),
+ CreateTestResultLink(100, 100, "http://test.com/file2", 3600),
+ CreateTestResultLink(200, 100, "http://test.com/file3", 3600)
+ };
+
+ // Set hasMoreRows to false so the fetcher doesn't keep trying to
fetch more results
+ SetupMockClientFetchResults(resultLinks, false);
+
+ // Act
+ await _resultFetcher.StartAsync(CancellationToken.None);
+
+ // Wait for the fetcher to process the links and complete
+ await Task.Delay(200);
+
+ // Get all cached URLs
+ var cachedUrls = _resultFetcher.GetAllCachedUrls();
+
+ // Assert
+ Assert.Equal(3, cachedUrls.Count);
+ Assert.Equal("http://test.com/file1", cachedUrls[0].FileLink);
+ Assert.Equal("http://test.com/file2", cachedUrls[100].FileLink);
+ Assert.Equal("http://test.com/file3", cachedUrls[200].FileLink);
+ _mockClient.Verify(c =>
c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()),
Times.Once);
+
+ // Verify the fetcher completed
+ Assert.True(_resultFetcher.IsCompleted);
+ Assert.False(_resultFetcher.HasMoreResults);
+
+ // No need to stop explicitly as it should have completed
naturally,
+ // but it's good practice to clean up
+ await _resultFetcher.StopAsync();
+ }
+
+ [Fact]
+ public async Task ClearCache_RemovesAllCachedUrls()
+ {
+ // Arrange
+ var resultLinks = new List<TSparkArrowResultLink>
+ {
+ CreateTestResultLink(0, 100, "http://test.com/file1", 3600),
+ CreateTestResultLink(100, 100, "http://test.com/file2", 3600)
+ };
+
+ // Set hasMoreRows to false so the fetcher doesn't keep trying to
fetch more results
+ SetupMockClientFetchResults(resultLinks, false);
+
+ // Cache the URLs
+ await _resultFetcher.StartAsync(CancellationToken.None);
+
+ // Wait for the fetcher to process the links and complete
+ await Task.Delay(200);
+
+ // Act
+ _resultFetcher.ClearCache();
+ var cachedUrls = _resultFetcher.GetAllCachedUrls();
+
+ // Assert
+ Assert.Empty(cachedUrls);
+
+ // Verify the fetcher completed
+ Assert.True(_resultFetcher.IsCompleted);
+ Assert.False(_resultFetcher.HasMoreResults);
+
+ // Cleanup
+ await _resultFetcher.StopAsync();
+ }
+
+ [Fact]
+ public async Task GetUrlAsync_RefreshesExpiredUrl()
+ {
+ // Arrange
+ long offset = 0;
+ // Create a URL that will expire soon
+ var expiredLink = CreateTestResultLink(offset, 100,
"http://test.com/expired", 30);
+ var refreshedLink = CreateTestResultLink(offset, 100,
"http://test.com/refreshed", 3600);
+
+ // First return the expired link, then the refreshed one
+ _mockClient.SetupSequence(c =>
c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()))
+ .ReturnsAsync(CreateFetchResultsResponse(new
List<TSparkArrowResultLink> { expiredLink }, true))
+ .ReturnsAsync(CreateFetchResultsResponse(new
List<TSparkArrowResultLink> { refreshedLink }, true));
+
+ // First fetch to cache the soon-to-expire URL
+ await _resultFetcher.GetUrlAsync(offset, CancellationToken.None);
+
+ // Advance time so the URL is now expired
+ _mockClock.AdvanceTime(TimeSpan.FromSeconds(40));
+
+ // Act - This should refresh the URL
+ var result = await _resultFetcher.GetUrlAsync(offset,
CancellationToken.None);
+
+ // Assert
+ Assert.NotNull(result);
+ Assert.Equal("http://test.com/refreshed", result.FileLink);
+ _mockClient.Verify(c =>
c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()),
Times.Exactly(2));
+ }
+
+ #endregion
+
+ #region Core Functionality Tests (Restored)
+
+ [Fact]
+ public async Task StartAsync_CalledTwice_ThrowsException()
+ {
+ // Arrange
+ SetupMockClientFetchResults(new List<TSparkArrowResultLink>(),
false);
// Act & Assert
- await fetcher.StartAsync(CancellationToken.None);
- await Assert.ThrowsAsync<InvalidOperationException>(() =>
fetcher.StartAsync(CancellationToken.None));
+ await _resultFetcher.StartAsync(CancellationToken.None);
+ await Assert.ThrowsAsync<InvalidOperationException>(() =>
_resultFetcher.StartAsync(CancellationToken.None));
// Cleanup
- await fetcher.StopAsync();
+ await _resultFetcher.StopAsync();
}
[Fact]
@@ -73,34 +215,21 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
// Arrange
var resultLinks = new List<TSparkArrowResultLink>
{
- CreateTestResultLink(0, 100, "http://test.com/file1"),
- CreateTestResultLink(100, 100, "http://test.com/file2"),
- CreateTestResultLink(200, 100, "http://test.com/file3")
+ CreateTestResultLink(0, 100, "http://test.com/file1", 3600),
+ CreateTestResultLink(100, 100, "http://test.com/file2", 3600),
+ CreateTestResultLink(200, 100, "http://test.com/file3", 3600)
};
- var mockClient = new Mock<TCLIService.IAsync>();
- mockClient.Setup(c => c.FetchResults(It.IsAny<TFetchResultsReq>(),
It.IsAny<CancellationToken>()))
- .ReturnsAsync(CreateFetchResultsResponse(resultLinks, false));
-
- var mockStatement = new Mock<IHiveServer2Statement>();
- mockStatement.Setup(s =>
s.OperationHandle).Returns(CreateOperationHandle());
- mockStatement.Setup(s => s.Client).Returns(mockClient.Object);
-
- var fetcher = new CloudFetchResultFetcher(
- mockStatement.Object,
- _mockMemoryManager.Object,
- _downloadQueue,
- 5); // batchSize
+ SetupMockClientFetchResults(resultLinks, false);
// Act
- await fetcher.StartAsync(CancellationToken.None);
+ await _resultFetcher.StartAsync(CancellationToken.None);
// Wait for the fetcher to process the results
await Task.Delay(100);
// Assert
// The download queue should contain our result links
- // Note: With prefetch, there might be more items in the queue
than just our result links
Assert.True(_downloadQueue.Count >= resultLinks.Count,
$"Expected at least {resultLinks.Count} items in queue, but
found {_downloadQueue.Count}");
@@ -127,13 +256,13 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
}
// Verify the fetcher state
- Assert.False(fetcher.HasMoreResults);
- Assert.True(fetcher.IsCompleted);
- Assert.False(fetcher.HasError);
- Assert.Null(fetcher.Error);
+ Assert.False(_resultFetcher.HasMoreResults);
+ Assert.True(_resultFetcher.IsCompleted);
+ Assert.False(_resultFetcher.HasError);
+ Assert.Null(_resultFetcher.Error);
// Cleanup
- await fetcher.StopAsync();
+ await _resultFetcher.StopAsync();
}
[Fact]
@@ -142,40 +271,28 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
// Arrange
var firstBatchLinks = new List<TSparkArrowResultLink>
{
- CreateTestResultLink(0, 100, "http://test.com/file1"),
- CreateTestResultLink(100, 100, "http://test.com/file2")
+ CreateTestResultLink(0, 100, "http://test.com/file1", 3600),
+ CreateTestResultLink(100, 100, "http://test.com/file2", 3600)
};
var secondBatchLinks = new List<TSparkArrowResultLink>
{
- CreateTestResultLink(200, 100, "http://test.com/file3"),
- CreateTestResultLink(300, 100, "http://test.com/file4")
+ CreateTestResultLink(200, 100, "http://test.com/file3", 3600),
+ CreateTestResultLink(300, 100, "http://test.com/file4", 3600)
};
- var mockClient = new Mock<TCLIService.IAsync>();
- mockClient.SetupSequence(c =>
c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()))
+ _mockClient.SetupSequence(c =>
c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()))
.ReturnsAsync(CreateFetchResultsResponse(firstBatchLinks,
true))
.ReturnsAsync(CreateFetchResultsResponse(secondBatchLinks,
false));
- var mockStatement = new Mock<IHiveServer2Statement>();
- mockStatement.Setup(s =>
s.OperationHandle).Returns(CreateOperationHandle());
- mockStatement.Setup(s => s.Client).Returns(mockClient.Object);
-
- var fetcher = new CloudFetchResultFetcher(
- mockStatement.Object,
- _mockMemoryManager.Object,
- _downloadQueue,
- 5); // batchSize
-
// Act
- await fetcher.StartAsync(CancellationToken.None);
+ await _resultFetcher.StartAsync(CancellationToken.None);
// Wait for the fetcher to process all results
await Task.Delay(200);
// Assert
// The download queue should contain all result links (both
batches)
- // Note: With prefetch, there might be more items in the queue
than just our result links
Assert.True(_downloadQueue.Count >= firstBatchLinks.Count +
secondBatchLinks.Count,
$"Expected at least {firstBatchLinks.Count +
secondBatchLinks.Count} items in queue, but found {_downloadQueue.Count}");
@@ -194,34 +311,22 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
Assert.Equal(firstBatchLinks.Count + secondBatchLinks.Count,
downloadResults.Count);
// Verify the fetcher state
- Assert.False(fetcher.HasMoreResults);
- Assert.True(fetcher.IsCompleted);
- Assert.False(fetcher.HasError);
+ Assert.False(_resultFetcher.HasMoreResults);
+ Assert.True(_resultFetcher.IsCompleted);
+ Assert.False(_resultFetcher.HasError);
// Cleanup
- await fetcher.StopAsync();
+ await _resultFetcher.StopAsync();
}
[Fact]
public async Task
FetchResultsAsync_WithEmptyResults_CompletesGracefully()
{
// Arrange
- var mockClient = new Mock<TCLIService.IAsync>();
- mockClient.Setup(c => c.FetchResults(It.IsAny<TFetchResultsReq>(),
It.IsAny<CancellationToken>()))
- .ReturnsAsync(CreateFetchResultsResponse(new
List<TSparkArrowResultLink>(), false));
-
- var mockStatement = new Mock<IHiveServer2Statement>();
- mockStatement.Setup(s =>
s.OperationHandle).Returns(CreateOperationHandle());
- mockStatement.Setup(s => s.Client).Returns(mockClient.Object);
-
- var fetcher = new CloudFetchResultFetcher(
- mockStatement.Object,
- _mockMemoryManager.Object,
- _downloadQueue,
- 5); // batchSize
+ SetupMockClientFetchResults(new List<TSparkArrowResultLink>(),
false);
// Act
- await fetcher.StartAsync(CancellationToken.None);
+ await _resultFetcher.StartAsync(CancellationToken.None);
// Wait for the fetcher to process the results
await Task.Delay(100);
@@ -239,53 +344,40 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
Assert.Empty(nonGuardItems);
// Verify the fetcher state
- Assert.False(fetcher.HasMoreResults);
- Assert.True(fetcher.IsCompleted);
- Assert.False(fetcher.HasError);
+ Assert.False(_resultFetcher.HasMoreResults);
+ Assert.True(_resultFetcher.IsCompleted);
+ Assert.False(_resultFetcher.HasError);
// Cleanup
- await fetcher.StopAsync();
+ await _resultFetcher.StopAsync();
}
[Fact]
public async Task FetchResultsAsync_WithServerError_SetsErrorState()
{
// Arrange
- var mockClient = new Mock<TCLIService.IAsync>();
- mockClient.Setup(c => c.FetchResults(It.IsAny<TFetchResultsReq>(),
It.IsAny<CancellationToken>()))
+ _mockClient.Setup(c =>
c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()))
.ThrowsAsync(new InvalidOperationException("Test server
error"));
- var mockStatement = new Mock<IHiveServer2Statement>();
- mockStatement.Setup(s =>
s.OperationHandle).Returns(CreateOperationHandle());
- mockStatement.Setup(s => s.Client).Returns(mockClient.Object);
-
- var fetcher = new CloudFetchResultFetcher(
- mockStatement.Object,
- _mockMemoryManager.Object,
- _downloadQueue,
- 5); // batchSize
-
// Act
- await fetcher.StartAsync(CancellationToken.None);
+ await _resultFetcher.StartAsync(CancellationToken.None);
// Wait for the fetcher to process the error
await Task.Delay(100);
// Assert
// Verify the fetcher state
- Assert.False(fetcher.HasMoreResults);
- Assert.True(fetcher.IsCompleted);
- Assert.True(fetcher.HasError);
- Assert.NotNull(fetcher.Error);
- Assert.IsType<InvalidOperationException>(fetcher.Error);
+ Assert.False(_resultFetcher.HasMoreResults);
+ Assert.True(_resultFetcher.IsCompleted);
+ Assert.True(_resultFetcher.HasError);
+ Assert.NotNull(_resultFetcher.Error);
+ Assert.IsType<InvalidOperationException>(_resultFetcher.Error);
// The download queue should have the end guard
- Assert.Single(_downloadQueue);
- var result = _downloadQueue.Take();
- Assert.Same(EndOfResultsGuard.Instance, result);
+ Assert.True(_downloadQueue.Count <= 1, "Expected at most 1 item
(end guard) in queue");
// Cleanup
- await fetcher.StopAsync();
+ await _resultFetcher.StopAsync();
}
[Fact]
@@ -295,8 +387,7 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
var fetchStarted = new TaskCompletionSource<bool>();
var fetchCancelled = new TaskCompletionSource<bool>();
- var mockClient = new Mock<TCLIService.IAsync>();
- mockClient.Setup(c => c.FetchResults(It.IsAny<TFetchResultsReq>(),
It.IsAny<CancellationToken>()))
+ _mockClient.Setup(c =>
c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()))
.Returns(async (TFetchResultsReq req, CancellationToken token)
=>
{
fetchStarted.TrySetResult(true);
@@ -316,24 +407,14 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
return CreateFetchResultsResponse(new
List<TSparkArrowResultLink>(), false);
});
- var mockStatement = new Mock<IHiveServer2Statement>();
- mockStatement.Setup(s =>
s.OperationHandle).Returns(CreateOperationHandle());
- mockStatement.Setup(s => s.Client).Returns(mockClient.Object);
-
- var fetcher = new CloudFetchResultFetcher(
- mockStatement.Object,
- _mockMemoryManager.Object,
- _downloadQueue,
- 5); // batchSize
-
// Act
- await fetcher.StartAsync(CancellationToken.None);
+ await _resultFetcher.StartAsync(CancellationToken.None);
// Wait for the fetch to start
await fetchStarted.Task;
// Stop the fetcher
- await fetcher.StopAsync();
+ await _resultFetcher.StopAsync();
// Assert
// Wait a short time for cancellation to propagate
@@ -341,46 +422,101 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Apache.Databricks.CloudFetch
Assert.True(cancelled, "Fetch operation should have been
cancelled");
// Verify the fetcher state
- Assert.True(fetcher.IsCompleted);
+ Assert.True(_resultFetcher.IsCompleted);
}
- private TOperationHandle CreateOperationHandle()
+ #endregion
+
+ #region Helper Methods
+
+ private TSparkArrowResultLink CreateTestResultLink(long
startRowOffset, int rowCount, string fileLink, int expirySeconds)
{
- return new TOperationHandle
+ return new TSparkArrowResultLink
{
- OperationId = new THandleIdentifier
- {
- Guid = new byte[16],
- Secret = new byte[16]
- },
- OperationType = TOperationType.EXECUTE_STATEMENT,
- HasResultSet = true
+ StartRowOffset = startRowOffset,
+ RowCount = rowCount,
+ FileLink = fileLink,
+ ExpiryTime = new
DateTimeOffset(_mockClock.UtcNow.AddSeconds(expirySeconds)).ToUnixTimeMilliseconds()
};
}
- private TFetchResultsResp
CreateFetchResultsResponse(List<TSparkArrowResultLink> resultLinks, bool
hasMoreRows)
+ private void SetupMockClientFetchResults(List<TSparkArrowResultLink>
resultLinks, bool hasMoreRows)
{
- var results = new TRowSet();
- results.__isset.resultLinks = true;
+ var results = new TRowSet { __isset = { resultLinks = true } };
results.ResultLinks = resultLinks;
- return new TFetchResultsResp
+ var response = new TFetchResultsResp
{
Status = new TStatus { StatusCode = TStatusCode.SUCCESS_STATUS
},
HasMoreRows = hasMoreRows,
Results = results,
__isset = { results = true, hasMoreRows = true }
};
+
+ // Clear any previous setups
+ _mockClient.Reset();
+
+ // Setup for any fetch request
+ _mockClient.Setup(c =>
c.FetchResults(It.IsAny<TFetchResultsReq>(), It.IsAny<CancellationToken>()))
+ .ReturnsAsync(response);
}
- private TSparkArrowResultLink CreateTestResultLink(long
startRowOffset, int rowCount, string fileLink)
+ private TFetchResultsResp
CreateFetchResultsResponse(List<TSparkArrowResultLink> resultLinks, bool
hasMoreRows)
{
- return new TSparkArrowResultLink
+ var results = new TRowSet { __isset = { resultLinks = true } };
+ results.ResultLinks = resultLinks;
+
+ return new TFetchResultsResp
{
- StartRowOffset = startRowOffset,
- RowCount = rowCount,
- FileLink = fileLink
+ Status = new TStatus { StatusCode = TStatusCode.SUCCESS_STATUS
},
+ HasMoreRows = hasMoreRows,
+ Results = results,
+ __isset = { results = true, hasMoreRows = true }
};
}
+
+ #endregion
+ }
+
+ /// <summary>
+ /// Mock clock implementation for testing time-dependent behavior.
+ /// </summary>
+ public class MockClock : IClock
+ {
+ private DateTimeOffset _now;
+
+ public MockClock()
+ {
+ _now = DateTimeOffset.UtcNow;
+ }
+
+ public DateTime UtcNow => _now.UtcDateTime;
+
+ public void AdvanceTime(TimeSpan timeSpan)
+ {
+ _now = _now.Add(timeSpan);
+ }
+
+ public void SetTime(DateTimeOffset time)
+ {
+ _now = time;
+ }
+ }
+
+ /// <summary>
+ /// Extension of CloudFetchResultFetcher that uses a mock clock for
testing.
+ /// </summary>
+ internal class CloudFetchResultFetcherWithMockClock :
CloudFetchResultFetcher
+ {
+ public CloudFetchResultFetcherWithMockClock(
+ IHiveServer2Statement statement,
+ ICloudFetchMemoryBufferManager memoryManager,
+ BlockingCollection<IDownloadResult> downloadQueue,
+ long batchSize,
+ IClock clock,
+ int expirationBufferSeconds = 60)
+ : base(statement, memoryManager, downloadQueue, batchSize,
expirationBufferSeconds, clock)
+ {
+ }
}
}
diff --git a/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs
b/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs
index a96b88cbf..31bd2e87d 100644
--- a/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs
+++ b/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs
@@ -65,7 +65,8 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks
[DatabricksParameters.UseCloudFetch] =
useCloudFetch.ToString(),
[DatabricksParameters.EnableDirectResults] =
enableDirectResults.ToString(),
[DatabricksParameters.CanDecompressLz4] = "true",
- [DatabricksParameters.MaxBytesPerFile] = "10485760" // 10MB
+ [DatabricksParameters.MaxBytesPerFile] = "10485760", // 10MB
+ [DatabricksParameters.CloudFetchUrlExpirationBufferSeconds] =
(15 * 60 - 2).ToString(),
});
// Execute a query that generates a large result set using range
function
diff --git
a/csharp/test/Drivers/Databricks/DatabricksOperationStatusPollerTests.cs
b/csharp/test/Drivers/Databricks/DatabricksOperationStatusPollerTests.cs
index 9d91ed590..c72326a62 100644
--- a/csharp/test/Drivers/Databricks/DatabricksOperationStatusPollerTests.cs
+++ b/csharp/test/Drivers/Databricks/DatabricksOperationStatusPollerTests.cs
@@ -18,7 +18,7 @@
using System;
using System.Threading;
using System.Threading.Tasks;
-using Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch;
+using Apache.Arrow.Adbc.Drivers.Databricks.CloudFetch;
using Apache.Arrow.Adbc.Drivers.Databricks;
using Apache.Hive.Service.Rpc.Thrift;
using Moq;