This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 9eee98deb feat(csharp): Implement CloudFetch for Databricks Spark driver (#2634)
9eee98deb is described below
commit 9eee98deb5d902dbe5322e5a2a9b5c922a80b220
Author: Jade Wang <[email protected]>
AuthorDate: Mon Mar 31 13:21:27 2025 -0700
feat(csharp): Implement CloudFetch for Databricks Spark driver (#2634)
Initial implementation of the CloudFetch feature in the Databricks Spark driver.
- Create a new SparkCloudFetchReader to handle CloudFetch file download and decompression.
- Add test cases for small and large result sets.
Changes coming after this:
- Add prefetching to the downloader
- Add renewal of expired presigned URLs
- Retries
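For context, a minimal consumer-side sketch of how the new statement options are meant to be used (not part of this patch; it assumes an already-open AdbcConnection from the Spark/Databricks driver, and CountRowsWithCloudFetch is a hypothetical helper name):

    using System.Threading.Tasks;
    using Apache.Arrow;
    using Apache.Arrow.Adbc;
    using Apache.Arrow.Adbc.Drivers.Apache.Spark;

    // Hypothetical helper: enable CloudFetch on a statement and count the rows returned.
    static async Task<long> CountRowsWithCloudFetch(AdbcConnection connection, string query)
    {
        AdbcStatement statement = connection.CreateStatement();
        statement.SetOption(SparkStatement.Options.UseCloudFetch, "true");       // adbc.spark.cloudfetch.enabled
        statement.SetOption(SparkStatement.Options.CanDecompressLz4, "true");    // adbc.spark.cloudfetch.lz4.enabled
        statement.SetOption(SparkStatement.Options.MaxBytesPerFile, "20971520"); // 20 MB per CloudFetch file
        statement.SqlQuery = query;

        QueryResult result = await statement.ExecuteQueryAsync();
        long rows = 0;
        RecordBatch? batch;
        while ((batch = await result.Stream!.ReadNextRecordBatchAsync()) != null)
        {
            rows += batch.Length;
        }
        return rows;
    }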
---
.../Apache/Apache.Arrow.Adbc.Drivers.Apache.csproj | 2 +
.../Drivers/Apache/Hive2/HiveServer2Connection.cs | 2 +-
.../Apache/Hive2/HiveServer2HttpConnection.cs | 2 +-
.../Drivers/Apache/Hive2/HiveServer2Statement.cs | 14 +-
.../Drivers/Apache/Impala/ImpalaHttpConnection.cs | 2 +-
.../Apache/Impala/ImpalaStandardConnection.cs | 2 +-
.../Spark/CloudFetch/SparkCloudFetchReader.cs | 318 +++++++++++++++++++++
csharp/src/Drivers/Apache/Spark/SparkConnection.cs | 3 +-
.../Apache/Spark/SparkDatabricksConnection.cs | 33 ++-
.../Drivers/Apache/Spark/SparkDatabricksReader.cs | 2 +-
.../Drivers/Apache/Spark/SparkHttpConnection.cs | 2 +-
csharp/src/Drivers/Apache/Spark/SparkParameters.cs | 19 ++
csharp/src/Drivers/Apache/Spark/SparkStatement.cs | 103 ++++++-
.../Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj | 1 +
.../test/Drivers/Apache/Spark/CloudFetchE2ETest.cs | 94 ++++++
15 files changed, 580 insertions(+), 19 deletions(-)
diff --git a/csharp/src/Drivers/Apache/Apache.Arrow.Adbc.Drivers.Apache.csproj b/csharp/src/Drivers/Apache/Apache.Arrow.Adbc.Drivers.Apache.csproj
index 7e4c7c096..2ad285410 100644
--- a/csharp/src/Drivers/Apache/Apache.Arrow.Adbc.Drivers.Apache.csproj
+++ b/csharp/src/Drivers/Apache/Apache.Arrow.Adbc.Drivers.Apache.csproj
@@ -6,6 +6,8 @@
<ItemGroup>
<PackageReference Include="ApacheThrift" Version="0.21.0" />
+ <PackageReference Include="K4os.Compression.LZ4" Version="1.3.8" />
+ <PackageReference Include="K4os.Compression.LZ4.Streams" Version="1.3.8" />
<PackageReference Include="System.Net.Http" Version="4.3.4" />
<PackageReference Include="System.Text.Json" Version="8.0.5" />
</ItemGroup>
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
index ab3efea31..990cb4774 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
@@ -354,7 +354,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
internal abstract SchemaParser SchemaParser { get; }
- internal abstract IArrowArrayStream NewReader<T>(T statement, Schema schema) where T : HiveServer2Statement;
+ internal abstract IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) where T : HiveServer2Statement;
public override IArrowArrayStream GetObjects(GetObjectsDepth depth, string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern, IReadOnlyList<string>? tableTypes, string? columnNamePattern)
{
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs
index 187e5712c..6ebcb9267 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2HttpConnection.cs
@@ -144,7 +144,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
return new HiveServer2Statement(this);
}
- internal override IArrowArrayStream NewReader<T>(T statement, Schema schema) => new HiveServer2Reader(
+ internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) => new HiveServer2Reader(
statement,
schema,
dataTypeConversion: statement.Connection.DataTypeConversion,
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
index c08f997ca..9042b4205 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
@@ -84,9 +84,11 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
// take QueryTimeoutSeconds (but this could be restricting)
await ExecuteStatementAsync(cancellationToken); // --> get QueryTimeout +
await HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client, PollTimeMilliseconds, cancellationToken); // + poll, up to QueryTimeout
- Schema schema = await GetResultSetSchemaAsync(OperationHandle!, Connection.Client, cancellationToken); // + get the result, up to QueryTimeout
+ TGetResultSetMetadataResp response = await HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!, Connection.Client, cancellationToken);
+ Schema schema = Connection.SchemaParser.GetArrowSchema(response.Schema, Connection.DataTypeConversion);
- return new QueryResult(-1, Connection.NewReader(this, schema));
+ // Store metadata for use in readers
+ return new QueryResult(-1, Connection.NewReader(this, schema, response));
}
public override async ValueTask<QueryResult> ExecuteQueryAsync()
@@ -108,12 +110,6 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
}
}
- private async Task<Schema> GetResultSetSchemaAsync(TOperationHandle operationHandle, TCLIService.IAsync client, CancellationToken cancellationToken = default)
- {
- TGetResultSetMetadataResp response = await HiveServer2Connection.GetResultSetMetadataAsync(operationHandle, client, cancellationToken);
- return Connection.SchemaParser.GetArrowSchema(response.Schema, Connection.DataTypeConversion);
- }
-
public async Task<UpdateResult> ExecuteUpdateAsyncInternal(CancellationToken cancellationToken = default)
{
const string NumberOfAffectedRowsColumnName = "num_affected_rows";
@@ -195,7 +191,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
protected async Task ExecuteStatementAsync(CancellationToken cancellationToken = default)
{
- TExecuteStatementReq executeRequest = new TExecuteStatementReq(Connection.SessionHandle, SqlQuery);
+ TExecuteStatementReq executeRequest = new TExecuteStatementReq(Connection.SessionHandle!, SqlQuery!);
SetStatementProperties(executeRequest);
TExecuteStatementResp executeResponse = await Connection.Client.ExecuteStatement(executeRequest, cancellationToken);
if (executeResponse.Status.StatusCode == TStatusCode.ERROR_STATUS)
diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs b/csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs
index 914ba9269..ef5c34166 100644
--- a/csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs
+++ b/csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs
@@ -123,7 +123,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala
TlsOptions = HiveServer2TlsImpl.GetHttpTlsOptions(Properties);
}
- internal override IArrowArrayStream NewReader<T>(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion);
+ internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion);
protected override TTransport CreateTransport()
{
diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaStandardConnection.cs b/csharp/src/Drivers/Apache/Impala/ImpalaStandardConnection.cs
index 99ac368be..2665070bb 100644
--- a/csharp/src/Drivers/Apache/Impala/ImpalaStandardConnection.cs
+++ b/csharp/src/Drivers/Apache/Impala/ImpalaStandardConnection.cs
@@ -149,7 +149,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala
return request;
}
- internal override IArrowArrayStream NewReader<T>(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion);
+ internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion);
internal override ImpalaServerType ServerType => ImpalaServerType.Standard;
diff --git a/csharp/src/Drivers/Apache/Spark/CloudFetch/SparkCloudFetchReader.cs b/csharp/src/Drivers/Apache/Spark/CloudFetch/SparkCloudFetchReader.cs
new file mode 100644
index 000000000..343bb5a0d
--- /dev/null
+++ b/csharp/src/Drivers/Apache/Spark/CloudFetch/SparkCloudFetchReader.cs
@@ -0,0 +1,318 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO;
+using System.Net.Http;
+using System.Threading;
+using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
+using Apache.Arrow.Ipc;
+using Apache.Hive.Service.Rpc.Thrift;
+using K4os.Compression.LZ4.Streams;
+
+namespace Apache.Arrow.Adbc.Drivers.Apache.Spark.CloudFetch
+{
+ /// <summary>
+ /// Reader for CloudFetch results from Databricks Spark Thrift server.
+ /// Handles downloading and processing URL-based result sets.
+ /// </summary>
+ internal sealed class SparkCloudFetchReader : IArrowArrayStream
+ {
+ // Default values used if not specified in connection properties
+ private const int DefaultMaxRetries = 3;
+ private const int DefaultRetryDelayMs = 500;
+ private const int DefaultTimeoutMinutes = 5;
+
+ private readonly int maxRetries;
+ private readonly int retryDelayMs;
+ private readonly int timeoutMinutes;
+
+ private HiveServer2Statement? statement;
+ private readonly Schema schema;
+ private List<TSparkArrowResultLink>? resultLinks;
+ private int linkIndex;
+ private ArrowStreamReader? currentReader;
+ private readonly bool isLz4Compressed;
+ private long startOffset;
+
+ // Lazy initialization of HttpClient
+ private readonly Lazy<HttpClient> httpClient;
+
+ /// <summary>
+ /// Initializes a new instance of the <see cref="SparkCloudFetchReader"/> class.
+ /// </summary>
+ /// <param name="statement">The HiveServer2 statement.</param>
+ /// <param name="schema">The Arrow schema.</param>
+ /// <param name="isLz4Compressed">Whether the results are LZ4 compressed.</param>
+ public SparkCloudFetchReader(HiveServer2Statement statement, Schema schema, bool isLz4Compressed)
+ {
+ this.statement = statement;
+ this.schema = schema;
+ this.isLz4Compressed = isLz4Compressed;
+
+ // Get configuration values from connection properties or use defaults
+ var connectionProps = statement.Connection.Properties;
+
+ // Parse max retries
+ int parsedMaxRetries = DefaultMaxRetries;
+ if (connectionProps.TryGetValue(SparkParameters.CloudFetchMaxRetries, out string? maxRetriesStr) &&
+ int.TryParse(maxRetriesStr, out parsedMaxRetries) &&
+ parsedMaxRetries > 0)
+ {
+ // Value was successfully parsed
+ }
+ else
+ {
+ parsedMaxRetries = DefaultMaxRetries;
+ }
+ this.maxRetries = parsedMaxRetries;
+
+ // Parse retry delay
+ int parsedRetryDelay = DefaultRetryDelayMs;
+ if (connectionProps.TryGetValue(SparkParameters.CloudFetchRetryDelayMs, out string? retryDelayStr) &&
+ int.TryParse(retryDelayStr, out parsedRetryDelay) &&
+ parsedRetryDelay > 0)
+ {
+ // Value was successfully parsed
+ }
+ else
+ {
+ parsedRetryDelay = DefaultRetryDelayMs;
+ }
+ this.retryDelayMs = parsedRetryDelay;
+
+ // Parse timeout minutes
+ int parsedTimeout = DefaultTimeoutMinutes;
+ if (connectionProps.TryGetValue(SparkParameters.CloudFetchTimeoutMinutes, out string? timeoutStr) &&
+ int.TryParse(timeoutStr, out parsedTimeout) &&
+ parsedTimeout > 0)
+ {
+ // Value was successfully parsed
+ }
+ else
+ {
+ parsedTimeout = DefaultTimeoutMinutes;
+ }
+ this.timeoutMinutes = parsedTimeout;
+
+ // Initialize HttpClient with the configured timeout
+ this.httpClient = new Lazy<HttpClient>(() =>
+ {
+ var client = new HttpClient();
+ client.Timeout = TimeSpan.FromMinutes(this.timeoutMinutes);
+ return client;
+ });
+ }
+
+ /// <summary>
+ /// Gets the Arrow schema.
+ /// </summary>
+ public Schema Schema { get { return schema; } }
+
+ private HttpClient HttpClient
+ {
+ get { return httpClient.Value; }
+ }
+
+ /// <summary>
+ /// Reads the next record batch from the result set.
+ /// </summary>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ /// <returns>The next record batch, or null if there are no more batches.</returns>
+ public async ValueTask<RecordBatch?> ReadNextRecordBatchAsync(CancellationToken cancellationToken = default)
+ {
+ while (true)
+ {
+ // If we have a current reader, try to read the next batch
+ if (this.currentReader != null)
+ {
+ RecordBatch? next = await this.currentReader.ReadNextRecordBatchAsync(cancellationToken);
+ if (next != null)
+ {
+ return next;
+ }
+ else
+ {
+ this.currentReader.Dispose();
+ this.currentReader = null;
+ }
+ }
+
+ // If we have more links to process, download and process the next one
+ if (this.resultLinks != null && this.linkIndex < this.resultLinks.Count)
+ {
+ var link = this.resultLinks[this.linkIndex++];
+ byte[]? fileData = null;
+
+ // Retry logic for downloading files
+ for (int retry = 0; retry < this.maxRetries; retry++)
+ {
+ try
+ {
+ fileData = await DownloadFileAsync(link.FileLink, cancellationToken);
+ break; // Success, exit retry loop
+ }
+ catch (Exception ex) when (retry < this.maxRetries - 1)
+ {
+ // Log the error and retry
+ Debug.WriteLine($"Error downloading file (attempt
{retry + 1}/{this.maxRetries}): {ex.Message}");
+ await Task.Delay(this.retryDelayMs * (retry + 1),
cancellationToken);
+ }
+ }
+
+ // Process the downloaded file data
+ MemoryStream dataStream;
+
+ // If the data is LZ4 compressed, decompress it
+ if (this.isLz4Compressed)
+ {
+ try
+ {
+ dataStream = new MemoryStream();
+ using (var inputStream = new MemoryStream(fileData!))
+ using (var decompressor = LZ4Stream.Decode(inputStream))
+ {
+ await decompressor.CopyToAsync(dataStream);
+ }
+ dataStream.Position = 0;
+ }
+ catch (Exception ex)
+ {
+ Debug.WriteLine($"Error decompressing data:
{ex.Message}");
+ continue; // Skip this link and try the next one
+ }
+ }
+ else
+ {
+ dataStream = new MemoryStream(fileData!);
+ }
+
+ try
+ {
+ this.currentReader = new ArrowStreamReader(dataStream);
+ continue;
+ }
+ catch (Exception ex)
+ {
+ Debug.WriteLine($"Error creating Arrow reader:
{ex.Message}");
+ dataStream.Dispose();
+ continue; // Skip this link and try the next one
+ }
+ }
+
+ this.resultLinks = null;
+ this.linkIndex = 0;
+
+ // If there's no statement, we're done
+ if (this.statement == null)
+ {
+ return null;
+ }
+
+ // Fetch more results from the server
+ TFetchResultsReq request = new TFetchResultsReq(this.statement.OperationHandle!, TFetchOrientation.FETCH_NEXT, this.statement.BatchSize);
+
+ // Set the start row offset if we have processed some links already
+ if (this.startOffset > 0)
+ {
+ request.StartRowOffset = this.startOffset;
+ }
+
+ TFetchResultsResp response;
+ try
+ {
+ response = await this.statement.Connection.Client!.FetchResults(request, cancellationToken);
+ }
+ catch (Exception ex)
+ {
+ Debug.WriteLine($"Error fetching results from server:
{ex.Message}");
+ this.statement = null; // Mark as done due to error
+ return null;
+ }
+
+ // Check if we have URL-based results
+ if (response.Results.__isset.resultLinks &&
+ response.Results.ResultLinks != null &&
+ response.Results.ResultLinks.Count > 0)
+ {
+ this.resultLinks = response.Results.ResultLinks;
+
+ // Update the start offset for the next fetch by calculating it from the links
+ if (this.resultLinks.Count > 0)
+ {
+ var lastLink = this.resultLinks[this.resultLinks.Count - 1];
+ this.startOffset = lastLink.StartRowOffset + lastLink.RowCount;
+ }
+
+ // If the server indicates there are no more rows, we can close the statement
+ if (!response.HasMoreRows)
+ {
+ this.statement = null;
+ }
+ }
+ else
+ {
+ // If there are no more results, we're done
+ this.statement = null;
+ return null;
+ }
+ }
+ }
+
+ /// <summary>
+ /// Downloads a file from a URL.
+ /// </summary>
+ /// <param name="url">The URL to download from.</param>
+ /// <param name="cancellationToken">The cancellation token.</param>
+ /// <returns>The downloaded file data.</returns>
+ private async Task<byte[]> DownloadFileAsync(string url, CancellationToken cancellationToken)
+ {
+ using HttpResponseMessage response = await HttpClient.GetAsync(url, HttpCompletionOption.ResponseHeadersRead, cancellationToken);
+ response.EnsureSuccessStatusCode();
+
+ // Get the content length if available
+ long? contentLength = response.Content.Headers.ContentLength;
+ if (contentLength.HasValue && contentLength.Value > 0)
+ {
+ Debug.WriteLine($"Downloading file of size:
{contentLength.Value / 1024.0 / 1024.0:F2} MB");
+ }
+
+ return await response.Content.ReadAsByteArrayAsync();
+ }
+
+ /// <summary>
+ /// Disposes the reader.
+ /// </summary>
+ public void Dispose()
+ {
+ if (this.currentReader != null)
+ {
+ this.currentReader.Dispose();
+ this.currentReader = null;
+ }
+
+ // Dispose the HttpClient if it was created
+ if (httpClient.IsValueCreated)
+ {
+ httpClient.Value.Dispose();
+ }
+ }
+ }
+}
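Aside (not part of the patch): the LZ4 branch in the reader above amounts to decoding an LZ4 frame with K4os.Compression.LZ4.Streams before handing the stream to ArrowStreamReader. A standalone sketch of just that step, assuming `compressed` holds the bytes of one downloaded CloudFetch file:

    using System.IO;
    using K4os.Compression.LZ4.Streams;

    static MemoryStream DecompressLz4Frame(byte[] compressed)
    {
        var output = new MemoryStream();
        using (var input = new MemoryStream(compressed))
        using (var decoder = LZ4Stream.Decode(input)) // reads the LZ4 frame format
        {
            decoder.CopyTo(output);
        }
        output.Position = 0; // rewind so ArrowStreamReader starts at the first IPC message
        return output;
    }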
diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
index b9b40dfd1..2eb11e941 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
@@ -63,7 +63,8 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
public override AdbcStatement CreateStatement()
{
- return new SparkStatement(this);
+ SparkStatement statement = new SparkStatement(this);
+ return statement;
}
protected internal override int PositionRequiredOffset => 1;
diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs
index 14d94acf3..a2413635b 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkDatabricksConnection.cs
@@ -18,6 +18,8 @@
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
+using Apache.Arrow.Adbc.Drivers.Apache.Spark.CloudFetch;
using Apache.Arrow.Ipc;
using Apache.Hive.Service.Rpc.Thrift;
@@ -29,7 +31,35 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
{
}
- internal override IArrowArrayStream NewReader<T>(T statement, Schema schema) => new SparkDatabricksReader(statement, schema);
+ internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null)
+ {
+ // Get result format from metadata response if available
+ TSparkRowSetType resultFormat = TSparkRowSetType.ARROW_BASED_SET;
+ bool isLz4Compressed = false;
+
+ if (metadataResp != null)
+ {
+ if (metadataResp.__isset.resultFormat)
+ {
+ resultFormat = metadataResp.ResultFormat;
+ }
+
+ if (metadataResp.__isset.lz4Compressed)
+ {
+ isLz4Compressed = metadataResp.Lz4Compressed;
+ }
+ }
+
+ // Choose the appropriate reader based on the result format
+ if (resultFormat == TSparkRowSetType.URL_BASED_SET)
+ {
+ return new SparkCloudFetchReader(statement, schema, isLz4Compressed);
+ }
+ else
+ {
+ return new SparkDatabricksReader(statement, schema);
+ }
+ }
internal override SchemaParser SchemaParser => new SparkDatabricksSchemaParser();
@@ -40,6 +70,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
var req = new TOpenSessionReq
{
Client_protocol = TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7,
+ Client_protocol_i64 = (long)TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V7,
CanUseMultipleCatalogs = true,
};
return req;
diff --git a/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs b/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs
index 059ab1690..0e0166926 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkDatabricksReader.cs
@@ -68,7 +68,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
return null;
}
- TFetchResultsReq request = new TFetchResultsReq(this.statement.OperationHandle, TFetchOrientation.FETCH_NEXT, this.statement.BatchSize);
+ TFetchResultsReq request = new TFetchResultsReq(this.statement.OperationHandle!, TFetchOrientation.FETCH_NEXT, this.statement.BatchSize);
TFetchResultsResp response = await this.statement.Connection.Client!.FetchResults(request, cancellationToken);
this.batches = response.Results.ArrowBatches;
diff --git a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
index e28ab4632..fd7f18097 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
@@ -139,7 +139,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
TlsOptions = HiveServer2TlsImpl.GetHttpTlsOptions(Properties);
}
- internal override IArrowArrayStream NewReader<T>(T statement, Schema schema) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion);
+ internal override IArrowArrayStream NewReader<T>(T statement, Schema schema, TGetResultSetMetadataResp? metadataResp = null) => new HiveServer2Reader(statement, schema, dataTypeConversion: statement.Connection.DataTypeConversion);
protected override TTransport CreateTransport()
{
diff --git a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs
index 8e75ae3f5..b5587197d 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkParameters.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkParameters.cs
@@ -33,6 +33,25 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
public const string Type = "adbc.spark.type";
public const string DataTypeConv = "adbc.spark.data_type_conv";
public const string ConnectTimeoutMilliseconds = "adbc.spark.connect_timeout_ms";
+
+ // CloudFetch configuration parameters
+ /// <summary>
+ /// Maximum number of retry attempts for CloudFetch downloads.
+ /// Default value is 3 if not specified.
+ /// </summary>
+ public const string CloudFetchMaxRetries = "adbc.spark.cloudfetch.max_retries";
+
+ /// <summary>
+ /// Delay in milliseconds between CloudFetch retry attempts.
+ /// Default value is 500ms if not specified.
+ /// </summary>
+ public const string CloudFetchRetryDelayMs = "adbc.spark.cloudfetch.retry_delay_ms";
+
+ /// <summary>
+ /// Timeout in minutes for CloudFetch HTTP operations.
+ /// Default value is 5 minutes if not specified.
+ /// </summary>
+ public const string CloudFetchTimeoutMinutes = "adbc.spark.cloudfetch.timeout_minutes";
}
public static class SparkAuthTypeConstants
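For reference, a sketch (not part of this patch) of how the three new CloudFetch keys might be supplied as connection properties when opening the driver; the host and authentication entries are placeholders and omitted here:

    using System.Collections.Generic;
    using Apache.Arrow.Adbc.Drivers.Apache.Spark;

    // Hypothetical property bag passed to the Spark driver when opening a database/connection.
    var properties = new Dictionary<string, string>
    {
        // ... usual adbc.spark.* host and authentication properties go here ...
        [SparkParameters.CloudFetchMaxRetries] = "5",       // download retry attempts (default 3)
        [SparkParameters.CloudFetchRetryDelayMs] = "1000",  // base delay between retries (default 500 ms)
        [SparkParameters.CloudFetchTimeoutMinutes] = "10",  // per-file HTTP timeout (default 5 minutes)
    };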
diff --git a/csharp/src/Drivers/Apache/Spark/SparkStatement.cs b/csharp/src/Drivers/Apache/Spark/SparkStatement.cs
index ffe491e72..4c4e61562 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkStatement.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkStatement.cs
@@ -15,6 +15,7 @@
* limitations under the License.
*/
+using System;
using System.Collections.Generic;
using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Hive.Service.Rpc.Thrift;
@@ -23,6 +24,14 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
{
internal class SparkStatement : HiveServer2Statement
{
+ // Default maximum bytes per file for CloudFetch
+ private const long DefaultMaxBytesPerFile = 20 * 1024 * 1024; // 20MB
+
+ // CloudFetch configuration
+ private bool useCloudFetch = true;
+ private bool canDecompressLz4 = true;
+ private long maxBytesPerFile = DefaultMaxBytesPerFile;
+
internal SparkStatement(SparkConnection connection)
: base(connection)
{
@@ -37,7 +46,12 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
// Set in combination with a CancellationToken.
statement.QueryTimeout = QueryTimeoutSeconds;
statement.CanReadArrowResult = true;
- statement.CanDownloadResult = true;
+
+ // Set CloudFetch capabilities
+ statement.CanDownloadResult = useCloudFetch;
+ statement.CanDecompressLZ4Result = canDecompressLz4;
+ statement.MaxBytesPerFile = maxBytesPerFile;
+
#pragma warning disable CS0618 // Type or member is obsolete
statement.ConfOverlay = SparkConnection.timestampConfig;
#pragma warning restore CS0618 // Type or member is obsolete
@@ -54,12 +68,97 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
};
}
+ public override void SetOption(string key, string value)
+ {
+ switch (key)
+ {
+ case Options.UseCloudFetch:
+ if (bool.TryParse(value, out bool useCloudFetchValue))
+ {
+ this.useCloudFetch = useCloudFetchValue;
+ }
+ else
+ {
+ throw new ArgumentException($"Invalid value for {key}: {value}. Expected a boolean value.");
+ }
+ break;
+ case Options.CanDecompressLz4:
+ if (bool.TryParse(value, out bool canDecompressLz4Value))
+ {
+ this.canDecompressLz4 = canDecompressLz4Value;
+ }
+ else
+ {
+ throw new ArgumentException($"Invalid value for {key}: {value}. Expected a boolean value.");
+ }
+ break;
+ case Options.MaxBytesPerFile:
+ if (long.TryParse(value, out long maxBytesPerFileValue))
+ {
+ this.maxBytesPerFile = maxBytesPerFileValue;
+ }
+ else
+ {
+ throw new ArgumentException($"Invalid value for {key}: {value}. Expected a long value.");
+ }
+ break;
+ default:
+ base.SetOption(key, value);
+ break;
+ }
+ }
+
+ /// <summary>
+ /// Sets whether to use CloudFetch for retrieving results.
+ /// </summary>
+ /// <param name="useCloudFetch">Whether to use CloudFetch.</param>
+ internal void SetUseCloudFetch(bool useCloudFetch)
+ {
+ this.useCloudFetch = useCloudFetch;
+ }
+
+ /// <summary>
+ /// Gets whether CloudFetch is enabled.
+ /// </summary>
+ public bool UseCloudFetch => useCloudFetch;
+
+ /// <summary>
+ /// Sets whether the client can decompress LZ4 compressed results.
+ /// </summary>
+ /// <param name="canDecompressLz4">Whether the client can decompress
LZ4.</param>
+ internal void SetCanDecompressLz4(bool canDecompressLz4)
+ {
+ this.canDecompressLz4 = canDecompressLz4;
+ }
+
+ /// <summary>
+ /// Gets whether LZ4 decompression is enabled.
+ /// </summary>
+ public bool CanDecompressLz4 => canDecompressLz4;
+
+ /// <summary>
+ /// Sets the maximum bytes per file for CloudFetch.
+ /// </summary>
+ /// <param name="maxBytesPerFile">The maximum bytes per file.</param>
+ internal void SetMaxBytesPerFile(long maxBytesPerFile)
+ {
+ this.maxBytesPerFile = maxBytesPerFile;
+ }
+
+ /// <summary>
+ /// Gets the maximum bytes per file for CloudFetch.
+ /// </summary>
+ public long MaxBytesPerFile => maxBytesPerFile;
+
/// <summary>
/// Provides the constant string key values to the <see cref="AdbcStatement.SetOption(string, string)" /> method.
/// </summary>
public sealed class Options : ApacheParameters
{
- // options specific to Spark go here
+ // CloudFetch options
+ public const string UseCloudFetch = "adbc.spark.cloudfetch.enabled";
+ public const string CanDecompressLz4 = "adbc.spark.cloudfetch.lz4.enabled";
+ public const string MaxBytesPerFile = "adbc.spark.cloudfetch.max_bytes_per_file";
}
}
}
diff --git a/csharp/test/Drivers/Apache/Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj b/csharp/test/Drivers/Apache/Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj
index 63365312e..af779a699 100644
--- a/csharp/test/Drivers/Apache/Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj
+++ b/csharp/test/Drivers/Apache/Apache.Arrow.Adbc.Tests.Drivers.Apache.csproj
@@ -14,6 +14,7 @@
</PackageReference>
<PackageReference Include="Xunit.SkippableFact" Version="1.5.23" />
<PackageReference Include="System.Net.Http" Version="4.3.4" />
+ <PackageReference Include="K4os.Compression.LZ4" Version="1.3.8" />
</ItemGroup>
<ItemGroup>
diff --git a/csharp/test/Drivers/Apache/Spark/CloudFetchE2ETest.cs b/csharp/test/Drivers/Apache/Spark/CloudFetchE2ETest.cs
new file mode 100644
index 000000000..0325c3e98
--- /dev/null
+++ b/csharp/test/Drivers/Apache/Spark/CloudFetchE2ETest.cs
@@ -0,0 +1,94 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Reflection;
+using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Apache.Spark;
+using Apache.Arrow.Adbc.Drivers.Apache.Spark.CloudFetch;
+using Apache.Arrow.Types;
+using Xunit;
+using Xunit.Abstractions;
+using Apache.Arrow.Adbc.Client;
+using Apache.Arrow.Adbc.Tests.Drivers.Apache.Common;
+
+namespace Apache.Arrow.Adbc.Tests.Drivers.Apache.Spark
+{
+ /// <summary>
+ /// End-to-end tests for the CloudFetch feature in the Spark ADBC driver.
+ /// </summary>
+ public class CloudFetchE2ETest : TestBase<SparkTestConfiguration, SparkTestEnvironment>
+ {
+ public CloudFetchE2ETest(ITestOutputHelper? outputHelper)
+ : base(outputHelper, new SparkTestEnvironment.Factory())
+ {
+ // Skip the test if the SPARK_TEST_CONFIG_FILE environment variable is not set
+ Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable));
+ }
+
+ /// <summary>
+ /// Integration test for running a large query against a real Databricks cluster.
+ /// </summary>
+ [Fact]
+ public async Task TestRealDatabricksCloudFetchSmallResultSet()
+ {
+ await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM range(1000)", 1000);
+ }
+
+ [Fact]
+ public async Task TestRealDatabricksCloudFetchLargeResultSet()
+ {
+ await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM main.tpcds_sf10_delta.catalog_sales LIMIT 1000000", 1000000);
+ }
+
+ private async Task TestRealDatabricksCloudFetchLargeQuery(string query, int rowCount)
+ {
+ // Create a statement with CloudFetch enabled
+ var statement = Connection.CreateStatement();
+ statement.SetOption(SparkStatement.Options.UseCloudFetch, "true");
+ statement.SetOption(SparkStatement.Options.CanDecompressLz4, "true");
+ statement.SetOption(SparkStatement.Options.MaxBytesPerFile, "10485760"); // 10MB
+
+ // Execute a query that generates a large result set using the range function
+ statement.SqlQuery = query;
+
+ // Execute the query and get the result
+ var result = await statement.ExecuteQueryAsync();
+
+
+ if (result.Stream == null)
+ {
+ throw new InvalidOperationException("Result stream is null");
+ }
+
+ // Read all the data and count rows
+ long totalRows = 0;
+ RecordBatch? batch;
+ while ((batch = await result.Stream.ReadNextRecordBatchAsync()) != null)
+ {
+ totalRows += batch.Length;
+ }
+
+ Assert.True(totalRows >= rowCount);
+
+ // Also log to the test output helper if available
+ OutputHelper?.WriteLine($"Read {totalRows} rows from range function");
+ }
+ }
+}