This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 103865846 feat(csharp/src/Drivers/Databricks): Add option to enable
using direct results for statements (#2737)
103865846 is described below
commit 103865846d78821d90500b9c402efddf484d3b59
Author: Alex Guo <[email protected]>
AuthorDate: Mon Apr 28 14:18:04 2025 -0700
feat(csharp/src/Drivers/Databricks): Add option to enable using direct
results for statements (#2737)
- Add an EnableDirectResults connection option
  (`adbc.databricks.enable_direct_results`, default `true`); when enabled,
  getDirectResults is sent in the Thrift ExecuteStatement request
- When getDirectResults is set in the request, the server populates
  directResults on the response with the initial results (equivalent to the
  client also calling GetOperationStatus, GetResultSetMetadata, FetchResults,
  and CloseOperation)
- If directResults is present on the response and its operation status
  reports FINISHED_STATE, skip polling for the operation status; otherwise
  poll as before
- getDirectResults was already set on metadata command requests, but not on
  the execute statement request
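Tested end-to-end with `dotnet test --filter CloudFetchE2ETest`, which now runs
a small and a large query with every combination of CloudFetch and direct
results enabled/disabled: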
```
[xUnit.net 00:00:00.11] Starting:
Apache.Arrow.Adbc.Tests.Drivers.Databricks
[xUnit.net 00:01:27.27] Finished:
Apache.Arrow.Adbc.Tests.Drivers.Databricks
Apache.Arrow.Adbc.Tests.Drivers.Databricks test net8.0 succeeded (87.7s)
Test summary: total: 8, failed: 0, succeeded: 8, skipped: 0, duration: 87.7s
Build succeeded in 89.1s
```
---
.../Drivers/Apache/Hive2/HiveServer2Connection.cs | 26 +++++-----
.../Drivers/Apache/Hive2/HiveServer2Statement.cs | 37 ++++++++++++---
csharp/src/Drivers/Apache/Spark/SparkConnection.cs | 2 +-
.../Databricks/CloudFetch/CloudFetchReader.cs | 2 +-
.../CloudFetch/CloudFetchResultFetcher.cs | 28 +++++++++++
.../Databricks/CloudFetch/IHiveServer2Statement.cs | 11 +++++
.../src/Drivers/Databricks/DatabricksConnection.cs | 40 ++++++++++++++++
.../src/Drivers/Databricks/DatabricksParameters.cs | 6 +++
csharp/src/Drivers/Databricks/DatabricksReader.cs | 11 +++++
.../src/Drivers/Databricks/DatabricksStatement.cs | 16 +++++++
.../test/Drivers/Databricks/CloudFetchE2ETest.cs | 55 +++++++++++-----------
.../Drivers/Databricks/DatabricksConnectionTest.cs | 1 +
12 files changed, 186 insertions(+), 49 deletions(-)
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
index 9a17eb6c1..f4f0f7978 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
@@ -377,7 +377,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
if (depth == GetObjectsDepth.All || depth >=
GetObjectsDepth.Catalogs)
{
TGetCatalogsReq getCatalogsReq = new
TGetCatalogsReq(SessionHandle);
- if (AreResultsAvailableDirectly())
+ if (AreResultsAvailableDirectly)
{
SetDirectResults(getCatalogsReq);
}
@@ -416,7 +416,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
TGetSchemasReq getSchemasReq = new
TGetSchemasReq(SessionHandle);
getSchemasReq.CatalogName = catalogPattern;
getSchemasReq.SchemaName = dbSchemaPattern;
- if (AreResultsAvailableDirectly())
+ if (AreResultsAvailableDirectly)
{
SetDirectResults(getSchemasReq);
}
@@ -449,7 +449,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
getTablesReq.CatalogName = catalogPattern;
getTablesReq.SchemaName = dbSchemaPattern;
getTablesReq.TableName = tableNamePattern;
- if (AreResultsAvailableDirectly())
+ if (AreResultsAvailableDirectly)
{
SetDirectResults(getTablesReq);
}
@@ -486,7 +486,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
columnsReq.CatalogName = catalogPattern;
columnsReq.SchemaName = dbSchemaPattern;
columnsReq.TableName = tableNamePattern;
- if (AreResultsAvailableDirectly())
+ if (AreResultsAvailableDirectly)
{
SetDirectResults(columnsReq);
}
@@ -594,7 +594,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
SessionHandle = SessionHandle ?? throw new
InvalidOperationException("session not created"),
};
- if (AreResultsAvailableDirectly())
+ if (AreResultsAvailableDirectly)
{
SetDirectResults(req);
}
@@ -786,7 +786,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
protected abstract Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken
cancellationToken = default);
protected internal abstract Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetPrimaryKeysResp response, CancellationToken
cancellationToken = default);
- protected internal virtual bool AreResultsAvailableDirectly() => false;
+ protected internal virtual bool AreResultsAvailableDirectly => false;
protected virtual void SetDirectResults(TGetColumnsReq request) =>
throw new System.NotImplementedException();
@@ -923,7 +923,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
}
TGetCatalogsReq req = new TGetCatalogsReq(SessionHandle);
- if (AreResultsAvailableDirectly())
+ if (AreResultsAvailableDirectly)
{
SetDirectResults(req);
}
@@ -950,7 +950,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
}
TGetSchemasReq req = new(SessionHandle);
- if (AreResultsAvailableDirectly())
+ if (AreResultsAvailableDirectly)
{
SetDirectResults(req);
}
@@ -987,7 +987,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
}
TGetTablesReq req = new(SessionHandle);
- if (AreResultsAvailableDirectly())
+ if (AreResultsAvailableDirectly)
{
SetDirectResults(req);
}
@@ -1032,7 +1032,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
}
TGetColumnsReq req = new(SessionHandle);
- if (AreResultsAvailableDirectly())
+ if (AreResultsAvailableDirectly)
{
SetDirectResults(req);
}
@@ -1076,7 +1076,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
}
TGetPrimaryKeysReq req = new(SessionHandle);
- if (AreResultsAvailableDirectly())
+ if (AreResultsAvailableDirectly)
{
SetDirectResults(req);
}
@@ -1119,7 +1119,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
}
TGetCrossReferenceReq req = new(SessionHandle);
- if (AreResultsAvailableDirectly())
+ if (AreResultsAvailableDirectly)
{
SetDirectResults(req);
}
@@ -1255,7 +1255,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
getColumnsReq.CatalogName = catalog;
getColumnsReq.SchemaName = dbSchema;
getColumnsReq.TableName = tableName;
- if (AreResultsAvailableDirectly())
+ if (AreResultsAvailableDirectly)
{
SetDirectResults(getColumnsReq);
}
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
index 564a28e9f..bff8ca225 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
@@ -99,17 +99,28 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
return await ExecuteMetadataCommandQuery(cancellationToken);
}
+ _directResults = null;
+
// this could either:
// take QueryTimeoutSeconds * 3
// OR
// take QueryTimeoutSeconds (but this could be restricting)
await ExecuteStatementAsync(cancellationToken); // --> get
QueryTimeout +
- await HiveServer2Connection.PollForResponseAsync(OperationHandle!,
Connection.Client, PollTimeMilliseconds, cancellationToken); // + poll, up to
QueryTimeout
- TGetResultSetMetadataResp response = await
HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!,
Connection.Client, cancellationToken);
- Schema schema =
Connection.SchemaParser.GetArrowSchema(response.Schema,
Connection.DataTypeConversion);
+ TGetResultSetMetadataResp metadata;
+ if (_directResults?.OperationStatus?.OperationState ==
TOperationState.FINISHED_STATE)
+ {
+ // The initial response has result data so we don't need to
poll
+ metadata = _directResults.ResultSetMetadata;
+ }
+ else
+ {
+ await
HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client,
PollTimeMilliseconds, cancellationToken); // + poll, up to QueryTimeout
+ metadata = await
HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!,
Connection.Client, cancellationToken);
+ }
// Store metadata for use in readers
- return new QueryResult(-1, Connection.NewReader(this, schema,
response));
+ Schema schema =
Connection.SchemaParser.GetArrowSchema(metadata.Schema,
Connection.DataTypeConversion);
+ return new QueryResult(-1, Connection.NewReader(this, schema,
metadata));
}
public override async ValueTask<QueryResult> ExecuteQueryAsync()
@@ -257,6 +268,19 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
.SetNativeError(executeResponse.Status.ErrorCode);
}
OperationHandle = executeResponse.OperationHandle;
+
+ // Capture direct results if they're available
+ if (executeResponse.DirectResults != null)
+ {
+ _directResults = executeResponse.DirectResults;
+
+ if
(!string.IsNullOrEmpty(_directResults.OperationStatus?.DisplayMessage))
+ {
+ throw new
HiveServer2Exception(_directResults.OperationStatus!.DisplayMessage)
+ .SetSqlState(_directResults.OperationStatus.SqlState)
+
.SetNativeError(_directResults.OperationStatus.ErrorCode);
+ }
+ }
}
protected internal int PollTimeMilliseconds { get; private set; } =
HiveServer2Connection.PollTimeMillisecondsDefault;
@@ -279,6 +303,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
protected internal string? ForeignCatalogName { get; set; }
protected internal string? ForeignSchemaName { get; set; }
protected internal string? ForeignTableName { get; set; }
+ protected internal TSparkDirectResults? _directResults { get; set; }
public HiveServer2Connection Connection { get; private set; }
@@ -416,7 +441,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
TRowSet rowSet;
// For GetColumns, we need to enhance the result with
BASE_TYPE_NAME
- if (Connection.AreResultsAvailableDirectly() &&
resp.DirectResults?.ResultSet?.Results != null)
+ if (Connection.AreResultsAvailableDirectly &&
resp.DirectResults?.ResultSet?.Results != null)
{
// Get data from direct results
metadata = resp.DirectResults.ResultSetMetadata;
@@ -454,7 +479,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
private async Task<QueryResult> GetQueryResult(TSparkDirectResults?
directResults, CancellationToken cancellationToken)
{
Schema schema;
- if (Connection.AreResultsAvailableDirectly() &&
directResults?.ResultSet?.Results != null)
+ if (Connection.AreResultsAvailableDirectly &&
directResults?.ResultSet?.Results != null)
{
TGetResultSetMetadataResp resultSetMetadata =
directResults.ResultSetMetadata;
schema =
Connection.SchemaParser.GetArrowSchema(resultSetMetadata.Schema,
Connection.DataTypeConversion);
diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
index 925073e35..c7e25861e 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
@@ -117,7 +117,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
protected override bool IsColumnSizeValidForDecimal => false;
- protected internal override bool AreResultsAvailableDirectly() => true;
+ protected internal override bool AreResultsAvailableDirectly => true;
protected override void SetDirectResults(TGetColumnsReq request) =>
request.GetDirectResults = sparkGetDirectResults;
diff --git a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs
b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs
index abca66d37..1e3861833 100644
--- a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs
+++ b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchReader.cs
@@ -1,4 +1,4 @@
-/*
+/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
diff --git
a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs
b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs
index b0ad05a6a..3da5608ed 100644
--- a/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs
+++ b/csharp/src/Drivers/Databricks/CloudFetch/CloudFetchResultFetcher.cs
@@ -129,6 +129,14 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
{
try
{
+ // Process direct results first, if available
+ if (_statement.HasDirectResults &&
_statement.DirectResults?.ResultSet?.Results?.ResultLinks?.Count > 0)
+ {
+ // Yield execution so the download queue doesn't get
blocked before downloader is started
+ await Task.Yield();
+ ProcessDirectResultsAsync(cancellationToken);
+ }
+
// Continue fetching as needed
while (_hasMoreResults &&
!cancellationToken.IsCancellationRequested)
{
@@ -228,5 +236,25 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
_hasMoreResults = false;
}
}
+
+ private void ProcessDirectResultsAsync(CancellationToken
cancellationToken)
+ {
+ List<TSparkArrowResultLink> resultLinks =
_statement.DirectResults!.ResultSet.Results.ResultLinks;
+
+ foreach (var link in resultLinks)
+ {
+ var downloadResult = new DownloadResult(link, _memoryManager);
+ _downloadQueue.Add(downloadResult, cancellationToken);
+ }
+
+ // Update the start offset for the next fetch
+ if (resultLinks.Count > 0)
+ {
+ var lastLink = resultLinks[resultLinks.Count - 1];
+ _startOffset = lastLink.StartRowOffset + lastLink.RowCount;
+ }
+
+ _hasMoreResults = _statement.DirectResults!.ResultSet.HasMoreRows;
+ }
}
}
diff --git a/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs
b/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs
index cfc92b98b..ee77dce9d 100644
--- a/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs
+++ b/csharp/src/Drivers/Databricks/CloudFetch/IHiveServer2Statement.cs
@@ -33,5 +33,16 @@ namespace
Apache.Arrow.Adbc.Drivers.Apache.Databricks.CloudFetch
/// Gets the client.
/// </summary>
TCLIService.IAsync Client { get; }
+
+ /// <summary>
+ /// Gets the direct results.
+ /// </summary>
+ TSparkDirectResults? DirectResults { get; }
+
+ /// <summary>
+ /// Checks if direct results are available.
+ /// </summary>
+ /// <returns>True if direct results are available and contain result
data, false otherwise.</returns>
+ bool HasDirectResults { get; }
}
}
diff --git a/csharp/src/Drivers/Databricks/DatabricksConnection.cs
b/csharp/src/Drivers/Databricks/DatabricksConnection.cs
index aefe2df89..cd7fc02a0 100644
--- a/csharp/src/Drivers/Databricks/DatabricksConnection.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksConnection.cs
@@ -33,6 +33,13 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
internal class DatabricksConnection : SparkHttpConnection
{
private bool _applySSPWithQueries = false;
+ private bool _enableDirectResults = true;
+
+ internal static TSparkGetDirectResults defaultGetDirectResults = new()
+ {
+ MaxRows = 2000000,
+ MaxBytes = 404857600
+ };
// CloudFetch configuration
private const long DefaultMaxBytesPerFile = 20 * 1024 * 1024; // 20MB
@@ -62,6 +69,18 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
}
}
+ if
(Properties.TryGetValue(DatabricksParameters.EnableDirectResults, out string?
enableDirectResultsStr))
+ {
+ if (bool.TryParse(enableDirectResultsStr, out bool
enableDirectResultsValue))
+ {
+ _enableDirectResults = enableDirectResultsValue;
+ }
+ else
+ {
+ throw new ArgumentException($"Parameter
'{DatabricksParameters.EnableDirectResults}' value '{enableDirectResultsStr}'
could not be parsed. Valid values are 'true' and 'false'.");
+ }
+ }
+
// Parse CloudFetch options from connection properties
if (Properties.TryGetValue(DatabricksParameters.UseCloudFetch, out
string? useCloudFetchStr))
{
@@ -110,6 +129,11 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
/// </summary>
internal bool ApplySSPWithQueries => _applySSPWithQueries;
+ /// <summary>
+ /// Gets whether direct results are enabled.
+ /// </summary>
+ internal bool EnableDirectResults => _enableDirectResults;
+
/// <summary>
/// Gets whether CloudFetch is enabled.
/// </summary>
@@ -145,6 +169,22 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
return baseHandler;
}
+ protected internal override bool AreResultsAvailableDirectly =>
_enableDirectResults;
+
+ protected override void SetDirectResults(TGetColumnsReq request) =>
request.GetDirectResults = defaultGetDirectResults;
+
+ protected override void SetDirectResults(TGetCatalogsReq request) =>
request.GetDirectResults = defaultGetDirectResults;
+
+ protected override void SetDirectResults(TGetSchemasReq request) =>
request.GetDirectResults = defaultGetDirectResults;
+
+ protected override void SetDirectResults(TGetTablesReq request) =>
request.GetDirectResults = defaultGetDirectResults;
+
+ protected override void SetDirectResults(TGetTableTypesReq request) =>
request.GetDirectResults = defaultGetDirectResults;
+
+ protected override void SetDirectResults(TGetPrimaryKeysReq request)
=> request.GetDirectResults = defaultGetDirectResults;
+
+ protected override void SetDirectResults(TGetCrossReferenceReq
request) => request.GetDirectResults = defaultGetDirectResults;
+
internal override IArrowArrayStream NewReader<T>(T statement, Schema
schema, TGetResultSetMetadataResp? metadataResp = null)
{
// Get result format from metadata response if available
diff --git a/csharp/src/Drivers/Databricks/DatabricksParameters.cs
b/csharp/src/Drivers/Databricks/DatabricksParameters.cs
index f45350b72..7c6e9a69f 100644
--- a/csharp/src/Drivers/Databricks/DatabricksParameters.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksParameters.cs
@@ -61,6 +61,12 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
/// </summary>
public const string CloudFetchTimeoutMinutes =
"adbc.databricks.cloudfetch.timeout_minutes";
+ /// <summary>
+ /// Whether to enable the use of direct results when executing queries.
+ /// Default value is true if not specified.
+ /// </summary>
+ public const string EnableDirectResults =
"adbc.databricks.enable_direct_results";
+
/// <summary>
/// Whether to apply service side properties (SSP) with queries. If
false, SSP will be applied
/// by setting the Thrift configuration when the session is opened.
diff --git a/csharp/src/Drivers/Databricks/DatabricksReader.cs
b/csharp/src/Drivers/Databricks/DatabricksReader.cs
index 56abfbb20..cdd131111 100644
--- a/csharp/src/Drivers/Databricks/DatabricksReader.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksReader.cs
@@ -39,6 +39,17 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
this.statement = statement;
this.schema = schema;
this.isLz4Compressed = isLz4Compressed;
+
+ // If we have direct results, initialize the batches from them
+ if (statement.HasDirectResults)
+ {
+ this.batches =
statement.DirectResults!.ResultSet.Results.ArrowBatches;
+
+ if (!statement.DirectResults.ResultSet.HasMoreRows)
+ {
+ this.statement = null;
+ }
+ }
}
public Schema Schema { get { return schema; } }
diff --git a/csharp/src/Drivers/Databricks/DatabricksStatement.cs
b/csharp/src/Drivers/Databricks/DatabricksStatement.cs
index 447689e51..72cdb8ac7 100644
--- a/csharp/src/Drivers/Databricks/DatabricksStatement.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksStatement.cs
@@ -49,6 +49,22 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
statement.CanDownloadResult = useCloudFetch;
statement.CanDecompressLZ4Result = canDecompressLz4;
statement.MaxBytesPerFile = maxBytesPerFile;
+
+ if (Connection.AreResultsAvailableDirectly)
+ {
+ statement.GetDirectResults =
DatabricksConnection.defaultGetDirectResults;
+ }
+ }
+
+ /// <summary>
+ /// Checks if direct results are available.
+ /// </summary>
+ /// <returns>True if direct results are available and contain result
data, false otherwise.</returns>
+ public bool HasDirectResults => DirectResults?.ResultSet != null &&
DirectResults?.ResultSetMetadata != null;
+
+ public TSparkDirectResults? DirectResults
+ {
+ get { return _directResults; }
}
// Cast the Client to IAsync for CloudFetch compatibility
diff --git a/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs
b/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs
index 0d9bbfa90..96b040274 100644
--- a/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs
+++ b/csharp/test/Drivers/Databricks/CloudFetchE2ETest.cs
@@ -16,6 +16,7 @@
*/
using System;
+using System.Collections.Generic;
using System.Threading.Tasks;
using Apache.Arrow.Adbc.Drivers.Databricks;
using Xunit;
@@ -35,42 +36,40 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks
Skip.IfNot(Utils.CanExecuteTestConfig(TestConfigVariable));
}
- /// <summary>
- /// Integration test for running a large query against a real
Databricks cluster.
- /// </summary>
- [Fact]
- public async Task TestRealDatabricksCloudFetchSmallResultSet()
- {
- await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM
range(1000)", 1000);
- }
-
- [Fact]
- public async Task TestRealDatabricksCloudFetchLargeResultSet()
+ public static IEnumerable<object[]> TestCases()
{
- await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM
main.tpcds_sf10_delta.catalog_sales LIMIT 1000000", 1000000);
- }
+ // Test cases format: (query, expected row count, use cloud fetch,
enable direct results)
- [Fact]
- public async Task TestRealDatabricksNoCloudFetchSmallResultSet()
- {
- await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM
range(1000)", 1000, false);
- }
+ string smallQuery = $"SELECT * FROM range(1000)";
+ yield return new object[] { smallQuery, 1000, true, true };
+ yield return new object[] { smallQuery, 1000, false, true };
+ yield return new object[] { smallQuery, 1000, true, false };
+ yield return new object[] { smallQuery, 1000, false, false };
- [Fact]
- public async Task TestRealDatabricksNoCloudFetchLargeResultSet()
- {
- await TestRealDatabricksCloudFetchLargeQuery("SELECT * FROM
main.tpcds_sf10_delta.catalog_sales LIMIT 1000000", 1000000, false);
+ string largeQuery = $"SELECT * FROM
main.tpcds_sf10_delta.catalog_sales LIMIT 1000000";
+ yield return new object[] { largeQuery, 1000000, true, true };
+ yield return new object[] { largeQuery, 1000000, false, true };
+ yield return new object[] { largeQuery, 1000000, true, false };
+ yield return new object[] { largeQuery, 1000000, false, false };
}
- private async Task TestRealDatabricksCloudFetchLargeQuery(string
query, int rowCount, bool useCloudFetch = true)
+ /// <summary>
+ /// Integration test for running queries against a real Databricks
cluster with different CloudFetch settings.
+ /// </summary>
+ [Theory]
+ [MemberData(nameof(TestCases))]
+ private async Task TestRealDatabricksCloudFetch(string query, int
rowCount, bool useCloudFetch, bool enableDirectResults)
{
- // Create a statement with CloudFetch enabled
- var statement = Connection.CreateStatement();
- statement.SetOption(DatabricksParameters.UseCloudFetch,
useCloudFetch.ToString());
- statement.SetOption(DatabricksParameters.CanDecompressLz4, "true");
- statement.SetOption(DatabricksParameters.MaxBytesPerFile,
"10485760"); // 10MB
+ var connection = NewConnection(TestConfiguration, new
Dictionary<string, string>
+ {
+ [DatabricksParameters.UseCloudFetch] =
useCloudFetch.ToString(),
+ [DatabricksParameters.EnableDirectResults] =
enableDirectResults.ToString(),
+ [DatabricksParameters.CanDecompressLz4] = "true",
+ [DatabricksParameters.MaxBytesPerFile] = "10485760" // 10MB
+ });
// Execute a query that generates a large result set using range
function
+ var statement = connection.CreateStatement();
statement.SqlQuery = query;
// Execute the query and get the result
diff --git a/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs
b/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs
index 859ee7e84..5c334957f 100644
--- a/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs
+++ b/csharp/test/Drivers/Databricks/DatabricksConnectionTest.cs
@@ -300,6 +300,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks
Add(new(new() { [SparkParameters.HostName] =
"valid.server.com", [SparkParameters.Token] = "abcdef",
[DatabricksParameters.CanDecompressLz4] = "notabool"},
typeof(ArgumentException)));
Add(new(new() { [SparkParameters.HostName] =
"valid.server.com", [SparkParameters.Token] = "abcdef",
[DatabricksParameters.MaxBytesPerFile] = "notanumber" },
typeof(ArgumentException)));
Add(new(new() { [SparkParameters.HostName] =
"valid.server.com", [SparkParameters.Token] = "abcdef",
[DatabricksParameters.MaxBytesPerFile] = "-100" },
typeof(ArgumentOutOfRangeException)));
+ Add(new(new() { [SparkParameters.HostName] =
"valid.server.com", [SparkParameters.Token] = "abcdef",
[DatabricksParameters.EnableDirectResults] = "notabool" },
typeof(ArgumentException)));
Add(new(new() { /*[SparkParameters.Type] =
SparkServerTypeConstants.Databricks,*/ [SparkParameters.HostName] =
"valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port]
= "-1" }, typeof(ArgumentOutOfRangeException)));
Add(new(new() { /*[SparkParameters.Type] =
SparkServerTypeConstants.Databricks,*/ [SparkParameters.HostName] =
"valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port]
= IPEndPoint.MinPort.ToString(CultureInfo.InvariantCulture) },
typeof(ArgumentOutOfRangeException)));
Add(new(new() { /*[SparkParameters.Type] =
SparkServerTypeConstants.Databricks,*/ [SparkParameters.HostName] =
"valid.server.com", [SparkParameters.Token] = "abcdef", [SparkParameters.Port]
= (IPEndPoint.MaxPort + 1).ToString(CultureInfo.InvariantCulture) },
typeof(ArgumentOutOfRangeException)));