This is an automated email from the ASF dual-hosted git repository.
curth pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 9e1d1c22b chore(csharp/src/Drivers/Apache): Cleanup HiveServer2-based
code with shared Thrift request/response interfaces (#3256)
9e1d1c22b is described below
commit 9e1d1c22ba37285308333f9882e15a10a2be8a5c
Author: Bruce Irschick <[email protected]>
AuthorDate: Wed Aug 13 12:52:02 2025 -0700
chore(csharp/src/Drivers/Apache): Cleanup HiveServer2-based code with
shared Thrift request/response interfaces (#3256)
Refactor the API to improve the handling of requests and responses and to
reduce the number of overloads.
Refactor API to send the IResponse to the Reader (`IArrowArrayStream`).
- The Stream/Reader is now responsible for closing the operation.
- The Statement is no longer responsible for keeping a singleton
instance of the (most recent) response.
Replaces https://github.com/apache/arrow-adbc/pull/2797
---
csharp/src/Drivers/Apache/Attributes.cs | 28 ++++
.../Drivers/Apache/Hive2/HiveServer2Connection.cs | 85 ++++-------
.../Apache/Hive2/HiveServer2ExtendedConnection.cs | 38 ++---
.../src/Drivers/Apache/Hive2/HiveServer2Reader.cs | 50 +++++-
.../Drivers/Apache/Hive2/HiveServer2Statement.cs | 170 ++++++++-------------
.../Hive2}/IHiveServer2Statement.cs | 21 ++-
.../src/Drivers/Apache/Impala/ImpalaConnection.cs | 26 +---
.../Drivers/Apache/Impala/ImpalaHttpConnection.cs | 3 +-
.../Apache/Impala/ImpalaStandardConnection.cs | 3 +-
csharp/src/Drivers/Apache/Spark/SparkConnection.cs | 47 ++++--
.../Drivers/Apache/Spark/SparkHttpConnection.cs | 29 +---
.../Apache/Thrift/Service/Rpc/Thrift/IRequest.cs | 35 +++++
.../Apache/Thrift/Service/Rpc/Thrift/IResponse.cs | 37 +++++
.../src/Drivers/Databricks/DatabricksConnection.cs | 51 ++-----
.../src/Drivers/Databricks/DatabricksStatement.cs | 16 +-
.../Databricks/Reader/BaseDatabricksReader.cs | 48 +++++-
.../Reader/CloudFetch/CloudFetchDownloadManager.cs | 11 +-
.../Reader/CloudFetch/CloudFetchReader.cs | 16 +-
.../Reader/CloudFetch/CloudFetchResultFetcher.cs | 19 ++-
.../Databricks/Reader/DatabricksCompositeReader.cs | 59 +++++--
.../Reader/DatabricksOperationStatusPoller.cs | 6 +-
.../Drivers/Databricks/Reader/DatabricksReader.cs | 13 +-
.../E2E/CloudFetch/CloudFetchDownloaderTest.cs | 3 +-
.../E2E/CloudFetch/CloudFetchResultFetcherTest.cs | 38 +++--
.../Unit/DatabricksOperationStatusPollerTests.cs | 18 ++-
25 files changed, 482 insertions(+), 388 deletions(-)
diff --git a/csharp/src/Drivers/Apache/Attributes.cs
b/csharp/src/Drivers/Apache/Attributes.cs
new file mode 100644
index 000000000..510138e38
--- /dev/null
+++ b/csharp/src/Drivers/Apache/Attributes.cs
@@ -0,0 +1,28 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+#if !NET5_0_OR_GREATER
+
+namespace System.Diagnostics.CodeAnalysis
+{
+ sealed class MaybeNullWhenAttribute : Attribute
+ {
+ public MaybeNullWhenAttribute(bool returnValue) { }
+ }
+}
+
+#endif
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
index d07a92a3e..2456a7182 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Connection.cs
@@ -18,6 +18,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
+using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Reflection;
using System.Text;
@@ -374,7 +375,8 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
internal abstract IArrowArrayStream NewReader<T>(
T statement,
Schema schema,
- TGetResultSetMetadataResp? metadataResp = null) where T :
HiveServer2Statement;
+ IResponse response,
+ TGetResultSetMetadataResp? metadataResp = null) where T :
IHiveServer2Statement;
public override IArrowArrayStream GetObjects(GetObjectsDepth depth,
string? catalogPattern, string? dbSchemaPattern, string? tableNamePattern,
IReadOnlyList<string>? tableTypes, string? columnNamePattern)
{
@@ -577,11 +579,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
{
SessionHandle = SessionHandle ?? throw new
InvalidOperationException("session not created"),
};
-
- if (AreResultsAvailableDirectly)
- {
- SetDirectResults(req);
- }
+ TrySetGetDirectResults(req);
CancellationToken cancellationToken =
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds,
ApacheUtility.TimeUnit.Seconds);
try
@@ -770,33 +768,27 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
protected abstract int ColumnMapIndexOffset { get; }
- protected abstract Task<TRowSet> GetRowSetAsync(TGetTableTypesResp
response, CancellationToken cancellationToken = default);
- protected abstract Task<TRowSet> GetRowSetAsync(TGetColumnsResp
response, CancellationToken cancellationToken = default);
- protected abstract Task<TRowSet> GetRowSetAsync(TGetTablesResp
response, CancellationToken cancellationToken = default);
- protected abstract Task<TRowSet> GetRowSetAsync(TGetCatalogsResp
getCatalogsResp, CancellationToken cancellationToken = default);
- protected abstract Task<TRowSet> GetRowSetAsync(TGetSchemasResp
getSchemasResp, CancellationToken cancellationToken = default);
- protected internal abstract Task<TRowSet>
GetRowSetAsync(TGetPrimaryKeysResp response, CancellationToken
cancellationToken = default);
- protected abstract Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken
cancellationToken = default);
- protected abstract Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken
cancellationToken = default);
- protected abstract Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken
cancellationToken = default);
- protected abstract Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken
cancellationToken = default);
- protected internal abstract Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetPrimaryKeysResp response, CancellationToken
cancellationToken = default);
-
- protected internal virtual bool AreResultsAvailableDirectly => false;
-
- protected virtual void SetDirectResults(TGetColumnsReq request) =>
throw new System.NotImplementedException();
-
- protected virtual void SetDirectResults(TGetCatalogsReq request) =>
throw new System.NotImplementedException();
-
- protected virtual void SetDirectResults(TGetSchemasReq request) =>
throw new System.NotImplementedException();
+ protected abstract Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(IResponse response, CancellationToken
cancellationToken = default);
- protected virtual void SetDirectResults(TGetTablesReq request) =>
throw new System.NotImplementedException();
+ protected abstract Task<TRowSet> GetRowSetAsync(IResponse response,
CancellationToken cancellationToken = default);
- protected virtual void SetDirectResults(TGetTableTypesReq request) =>
throw new System.NotImplementedException();
+ protected internal virtual bool TrySetGetDirectResults(IRequest
request) => false;
- protected virtual void SetDirectResults(TGetPrimaryKeysReq request) =>
throw new System.NotImplementedException();
+ protected internal virtual bool
TryGetDirectResults(TSparkDirectResults? directResults, [MaybeNullWhen(false)]
out QueryResult result)
+ {
+ result = null;
+ return false;
+ }
- protected virtual void SetDirectResults(TGetCrossReferenceReq request)
=> throw new System.NotImplementedException();
+ protected internal virtual bool TryGetDirectResults(
+ TSparkDirectResults? directResults,
+ [MaybeNullWhen(false)] out TGetResultSetMetadataResp metadata,
+ [MaybeNullWhen(false)] out TRowSet rowSet)
+ {
+ metadata = null;
+ rowSet = null;
+ return false;
+ }
protected internal abstract int PositionRequiredOffset { get; }
@@ -944,10 +936,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
}
TGetCatalogsReq req = new TGetCatalogsReq(SessionHandle);
- if (AreResultsAvailableDirectly)
- {
- SetDirectResults(req);
- }
+ TrySetGetDirectResults(req);
TGetCatalogsResp resp = await Client.GetCatalogs(req,
cancellationToken);
HandleThriftResponse(resp.Status, activity);
@@ -969,10 +958,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
}
TGetSchemasReq req = new(SessionHandle);
- if (AreResultsAvailableDirectly)
- {
- SetDirectResults(req);
- }
+ TrySetGetDirectResults(req);
if (catalogName != null)
{
req.CatalogName = catalogName;
@@ -1004,10 +990,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
}
TGetTablesReq req = new(SessionHandle);
- if (AreResultsAvailableDirectly)
- {
- SetDirectResults(req);
- }
+ TrySetGetDirectResults(req);
if (catalogName != null)
{
req.CatalogName = catalogName;
@@ -1047,10 +1030,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
}
TGetColumnsReq req = new(SessionHandle);
- if (AreResultsAvailableDirectly)
- {
- SetDirectResults(req);
- }
+ TrySetGetDirectResults(req);
if (catalogName != null)
{
req.CatalogName = catalogName;
@@ -1089,10 +1069,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
}
TGetPrimaryKeysReq req = new(SessionHandle);
- if (AreResultsAvailableDirectly)
- {
- SetDirectResults(req);
- }
+ TrySetGetDirectResults(req);
if (catalogName != null)
{
req.CatalogName = catalogName!;
@@ -1130,10 +1107,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
}
TGetCrossReferenceReq req = new(SessionHandle);
- if (AreResultsAvailableDirectly)
- {
- SetDirectResults(req);
- }
+ TrySetGetDirectResults(req);
if (catalogName != null)
{
req.ParentCatalogName = catalogName!;
@@ -1264,10 +1238,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
getColumnsReq.CatalogName = catalog;
getColumnsReq.SchemaName = dbSchema;
getColumnsReq.TableName = tableName;
- if (AreResultsAvailableDirectly)
- {
- SetDirectResults(getColumnsReq);
- }
+ TrySetGetDirectResults(getColumnsReq);
CancellationToken cancellationToken =
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds,
ApacheUtility.TimeUnit.Seconds);
try
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2ExtendedConnection.cs
b/csharp/src/Drivers/Apache/Hive2/HiveServer2ExtendedConnection.cs
index 20ff8a6f8..c45ace928 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2ExtendedConnection.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2ExtendedConnection.cs
@@ -61,11 +61,13 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
return new HiveServer2Statement(this);
}
- internal override IArrowArrayStream NewReader<T>(T statement, Schema
schema, TGetResultSetMetadataResp? metadataResp = null) => new
HiveServer2Reader(
- statement,
- schema,
- dataTypeConversion: statement.Connection.DataTypeConversion,
- enableBatchSizeStopCondition: false);
+ internal override IArrowArrayStream NewReader<T>(T statement, Schema
schema, IResponse response, TGetResultSetMetadataResp? metadataResp = null) =>
+ new HiveServer2Reader(
+ statement,
+ schema,
+ response,
+ dataTypeConversion: statement.Connection.DataTypeConversion,
+ enableBatchSizeStopCondition: false);
internal override void SetPrecisionScaleAndTypeName(
short colType,
@@ -135,28 +137,10 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
};
}
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken
cancellationToken = default) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken
cancellationToken = default) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken
cancellationToken = default) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken
cancellationToken = default) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
- protected internal override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetPrimaryKeysResp response, CancellationToken
cancellationToken = default) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
- protected override Task<TRowSet> GetRowSetAsync(TGetTableTypesResp
response, CancellationToken cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
- protected override Task<TRowSet> GetRowSetAsync(TGetColumnsResp
response, CancellationToken cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
- protected override Task<TRowSet> GetRowSetAsync(TGetTablesResp
response, CancellationToken cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
- protected override Task<TRowSet> GetRowSetAsync(TGetCatalogsResp
response, CancellationToken cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
- protected override Task<TRowSet> GetRowSetAsync(TGetSchemasResp
response, CancellationToken cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
- protected internal override Task<TRowSet>
GetRowSetAsync(TGetPrimaryKeysResp response, CancellationToken
cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
+ protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(IResponse response, CancellationToken
cancellationToken = default) =>
+ GetResultSetMetadataAsync(response.OperationHandle!, Client,
cancellationToken);
+ protected override Task<TRowSet> GetRowSetAsync(IResponse response,
CancellationToken cancellationToken = default) =>
+ FetchResultsAsync(response.OperationHandle!, cancellationToken:
cancellationToken);
protected internal override int PositionRequiredOffset => 0;
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
index a1e3e6df9..64d0a2f51 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Reader.cs
@@ -55,7 +55,9 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
private const int SecondSubsecondSepIndex = 19;
private const int SubsecondIndex = 20;
private const int MillisecondDecimalPlaces = 3;
- private readonly HiveServer2Statement _statement;
+ private readonly IHiveServer2Statement _statement;
+ private readonly IResponse _response;
+ private bool _disposed;
private bool _hasNoMoreData = false;
private readonly DataTypeConversion _dataTypeConversion;
// Flag to enable/disable stopping reading based on batch size
condition
@@ -74,13 +76,15 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
};
public HiveServer2Reader(
- HiveServer2Statement statement,
+ IHiveServer2Statement statement,
Schema schema,
+ IResponse response,
DataTypeConversion dataTypeConversion,
bool enableBatchSizeStopCondition = true) : base(statement)
{
_statement = statement;
Schema = schema;
+ _response = response;
_dataTypeConversion = dataTypeConversion;
_enableBatchSizeStopCondition = enableBatchSizeStopCondition;
}
@@ -103,7 +107,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
try
{
// Await the fetch response
- TFetchResultsResp response = await FetchNext(_statement,
cancellationToken);
+ TFetchResultsResp response = await FetchNext(_statement,
_response, cancellationToken);
HiveServer2Connection.HandleThriftResponse(response.Status, activity);
int columnCount = GetColumnCount(response.Results);
@@ -159,9 +163,9 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
internal static int GetRowCount(TRowSet response, int columnCount) =>
columnCount > 0 ? GetArray(response.Columns[0]).Length : 0;
- private static async Task<TFetchResultsResp>
FetchNext(HiveServer2Statement statement, CancellationToken cancellationToken =
default)
+ private static async Task<TFetchResultsResp>
FetchNext(IHiveServer2Statement statement, IResponse response,
CancellationToken cancellationToken = default)
{
- var request = new TFetchResultsReq(statement.OperationHandle!,
TFetchOrientation.FETCH_NEXT, statement.BatchSize);
+ var request = new TFetchResultsReq(response.OperationHandle!,
TFetchOrientation.FETCH_NEXT, statement.BatchSize);
return await statement.Connection.Client.FetchResults(request,
cancellationToken);
}
@@ -397,5 +401,41 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
return true;
}
+
+ protected override void Dispose(bool disposing)
+ {
+ try
+ {
+ if (!_disposed)
+ {
+ if (disposing)
+ {
+ _ = CloseOperationAsync(_statement, _response)
+ .ConfigureAwait(false).GetAwaiter().GetResult();
+ }
+ }
+ }
+ finally
+ {
+ base.Dispose(disposing);
+ _disposed = true;
+ }
+ }
+
+ /// <summary>
+ /// Closes the operation contained in the response.
+ /// </summary>
+ /// <param name="statement">The associated statement used for timeout
properties.</param>
+ /// <param name="response">The response for the operation.</param>
+ /// <returns>The server response for the CloseOperation call.</returns>
+ /// <exception cref="HiveServer2Exception" />
+ internal static async Task<TCloseOperationResp>
CloseOperationAsync(IHiveServer2Statement statement, IResponse response)
+ {
+ CancellationToken cancellationToken =
ApacheUtility.GetCancellationToken(statement.QueryTimeoutSeconds,
ApacheUtility.TimeUnit.Seconds);
+ TCloseOperationReq request = new
TCloseOperationReq(response.OperationHandle!);
+ TCloseOperationResp resp = await
statement.Client.CloseOperation(request, cancellationToken);
+ HiveServer2Connection.HandleThriftResponse(resp.Status, activity:
null);
+ return resp;
+ }
}
}
diff --git a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
index 1151b2bbe..0ed673ea3 100644
--- a/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
+++ b/csharp/src/Drivers/Apache/Hive2/HiveServer2Statement.cs
@@ -28,7 +28,7 @@ using Thrift.Transport;
namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
{
- internal class HiveServer2Statement : TracingStatement
+ internal class HiveServer2Statement : TracingStatement,
IHiveServer2Statement
{
private const string GetPrimaryKeysCommandName = "getprimarykeys";
private const string GetCrossReferenceCommandName =
"getcrossreference";
@@ -122,27 +122,25 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
return await
ExecuteMetadataCommandQuery(cancellationToken);
}
- _directResults = null;
-
// this could either:
// take QueryTimeoutSeconds * 3
// OR
// take QueryTimeoutSeconds (but this could be restricting)
- await ExecuteStatementAsync(cancellationToken); // --> get
QueryTimeout +
+ IResponse response = await
ExecuteStatementAsync(cancellationToken); // --> get QueryTimeout +
TGetResultSetMetadataResp metadata;
- if (_directResults?.OperationStatus?.OperationState ==
TOperationState.FINISHED_STATE)
+ if (response.DirectResults?.OperationStatus?.OperationState ==
TOperationState.FINISHED_STATE)
{
// The initial response has result data so we don't need
to poll
- metadata = _directResults.ResultSetMetadata;
+ metadata = response.DirectResults.ResultSetMetadata;
}
else
{
- await
HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client,
PollTimeMilliseconds, cancellationToken); // + poll, up to QueryTimeout
- metadata = await
HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!,
Connection.Client, cancellationToken);
+ await
HiveServer2Connection.PollForResponseAsync(response.OperationHandle!,
Connection.Client, PollTimeMilliseconds, cancellationToken); // + poll, up to
QueryTimeout
+ metadata = await
HiveServer2Connection.GetResultSetMetadataAsync(response.OperationHandle!,
Connection.Client, cancellationToken);
}
Schema schema = GetSchemaFromMetadata(metadata);
- return new QueryResult(-1, Connection.NewReader(this, schema,
metadata));
+ return new QueryResult(-1, Connection.NewReader(this, schema,
response, metadata));
});
}
@@ -298,9 +296,9 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
}
}
- protected async Task ExecuteStatementAsync(CancellationToken
cancellationToken = default)
+ protected async Task<IResponse>
ExecuteStatementAsync(CancellationToken cancellationToken = default)
{
- await this.TraceActivityAsync(async activity =>
+ return await this.TraceActivityAsync(async activity =>
{
if (Connection.SessionHandle == null)
{
@@ -310,32 +308,29 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
activity?.AddTag(SemanticConventions.Db.Client.Connection.SessionId,
Connection.SessionHandle.SessionId.Guid, "N");
TExecuteStatementReq executeRequest = new
TExecuteStatementReq(Connection.SessionHandle, SqlQuery!);
SetStatementProperties(executeRequest);
- TExecuteStatementResp executeResponse = await
Connection.Client.ExecuteStatement(executeRequest, cancellationToken);
-
HiveServer2Connection.HandleThriftResponse(executeResponse.Status, activity);
- activity?.AddTag(SemanticConventions.Db.Response.OperationId,
executeResponse.OperationHandle.OperationId.Guid, "N");
-
- OperationHandle = executeResponse.OperationHandle;
+ IResponse response = await
Connection.Client.ExecuteStatement(executeRequest, cancellationToken);
+ HiveServer2Connection.HandleThriftResponse(response.Status!,
activity);
+ activity?.AddTag(SemanticConventions.Db.Response.OperationId,
response.OperationHandle!.OperationId.Guid, "N");
// Capture direct results if they're available
- if (executeResponse.DirectResults != null)
+ if (response.DirectResults != null)
{
- _directResults = executeResponse.DirectResults;
-
- if
(!string.IsNullOrEmpty(_directResults.OperationStatus?.DisplayMessage))
+ if
(!string.IsNullOrEmpty(response.DirectResults.OperationStatus.DisplayMessage))
{
- throw new
HiveServer2Exception(_directResults.OperationStatus!.DisplayMessage)
-
.SetSqlState(_directResults.OperationStatus.SqlState)
-
.SetNativeError(_directResults.OperationStatus.ErrorCode);
+ throw new
HiveServer2Exception(response.DirectResults.OperationStatus.DisplayMessage)
+
.SetSqlState(response.DirectResults.OperationStatus.SqlState)
+
.SetNativeError(response.DirectResults.OperationStatus.ErrorCode);
}
}
+ return response;
});
}
protected internal int PollTimeMilliseconds { get; private set; } =
HiveServer2Connection.PollTimeMillisecondsDefault;
- protected internal long BatchSize { get; private set; } =
HiveServer2Connection.BatchSizeDefault;
+ public long BatchSize { get; private set; } =
HiveServer2Connection.BatchSizeDefault;
- protected internal int QueryTimeoutSeconds
+ public int QueryTimeoutSeconds
{
// Coordinate updates with the connection
get => Connection.QueryTimeoutSeconds;
@@ -352,12 +347,9 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
protected internal string? ForeignSchemaName { get; set; }
protected internal string? ForeignTableName { get; set; }
protected internal bool EscapePatternWildcards { get; set; } = false;
- protected internal TSparkDirectResults? _directResults { get; set; }
public HiveServer2Connection Connection { get; private set; }
- public TOperationHandle? OperationHandle { get; private set; }
-
// Keep the original Client property for internal use
public TCLIService.IAsync Client => Connection.Client;
@@ -381,24 +373,6 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
return name.Replace("_", "\\_").Replace("%", "\\%");
}
- public override void Dispose()
- {
- this.TraceActivity(activity =>
- {
- if (OperationHandle != null &&
_directResults?.CloseOperation?.Status?.StatusCode !=
TStatusCode.SUCCESS_STATUS)
- {
- CancellationToken cancellationToken =
ApacheUtility.GetCancellationToken(QueryTimeoutSeconds,
ApacheUtility.TimeUnit.Seconds);
-
activity?.AddTag(SemanticConventions.Db.Operation.OperationId,
OperationHandle.OperationId.Guid, "N");
- TCloseOperationReq request = new
TCloseOperationReq(OperationHandle);
- TCloseOperationResp resp =
Connection.Client.CloseOperation(request, cancellationToken).Result;
- HiveServer2Connection.HandleThriftResponse(resp.Status,
activity);
- OperationHandle = null;
- }
-
- base.Dispose();
- });
- }
-
protected void ValidateOptions(IReadOnlyDictionary<string, string>
properties)
{
foreach (KeyValuePair<string, string> kvp in properties)
@@ -438,7 +412,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
/// since the backend treats these as exact match queries rather than
pattern matches.
protected virtual async Task<QueryResult>
GetCrossReferenceAsForeignTableAsync(CancellationToken cancellationToken =
default)
{
- TGetCrossReferenceResp resp = await
Connection.GetCrossReferenceAsync(
+ IResponse response = await Connection.GetCrossReferenceAsync(
null,
null,
null,
@@ -446,10 +420,8 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
SchemaName,
TableName,
cancellationToken);
- OperationHandle = resp.OperationHandle;
- _directResults = resp.DirectResults;
- return await GetQueryResult(resp.DirectResults, cancellationToken);
+ return await GetQueryResult(response, cancellationToken);
}
/// <summary>
@@ -459,7 +431,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
/// </summary>
protected virtual async Task<QueryResult>
GetCrossReferenceAsync(CancellationToken cancellationToken = default)
{
- TGetCrossReferenceResp resp = await
Connection.GetCrossReferenceAsync(
+ IResponse response = await Connection.GetCrossReferenceAsync(
CatalogName,
SchemaName,
TableName,
@@ -467,10 +439,8 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
ForeignSchemaName,
ForeignTableName,
cancellationToken);
- OperationHandle = resp.OperationHandle;
- _directResults = resp.DirectResults;
- return await GetQueryResult(resp.DirectResults, cancellationToken);
+ return await GetQueryResult(response, cancellationToken);
}
/// <summary>
@@ -480,93 +450,69 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
/// </summary>
protected virtual async Task<QueryResult>
GetPrimaryKeysAsync(CancellationToken cancellationToken = default)
{
- TGetPrimaryKeysResp resp = await Connection.GetPrimaryKeysAsync(
+ IResponse response = await Connection.GetPrimaryKeysAsync(
CatalogName,
SchemaName,
TableName,
cancellationToken);
- OperationHandle = resp.OperationHandle;
- _directResults = resp.DirectResults;
- return await GetQueryResult(resp.DirectResults, cancellationToken);
+ return await GetQueryResult(response, cancellationToken);
}
protected virtual async Task<QueryResult>
GetCatalogsAsync(CancellationToken cancellationToken = default)
{
- TGetCatalogsResp resp = await
Connection.GetCatalogsAsync(cancellationToken);
- OperationHandle = resp.OperationHandle;
- _directResults = resp.DirectResults;
+ IResponse response = await
Connection.GetCatalogsAsync(cancellationToken);
- return await GetQueryResult(resp.DirectResults, cancellationToken);
+ return await GetQueryResult(response, cancellationToken);
}
protected virtual async Task<QueryResult>
GetSchemasAsync(CancellationToken cancellationToken = default)
{
- TGetSchemasResp resp = await Connection.GetSchemasAsync(
+ IResponse response = await Connection.GetSchemasAsync(
EscapePatternWildcardsInName(CatalogName),
EscapePatternWildcardsInName(SchemaName),
cancellationToken);
- OperationHandle = resp.OperationHandle;
- _directResults = resp.DirectResults;
- return await GetQueryResult(resp.DirectResults, cancellationToken);
+ return await GetQueryResult(response, cancellationToken);
}
protected virtual async Task<QueryResult>
GetTablesAsync(CancellationToken cancellationToken = default)
{
List<string>? tableTypesList =
this.TableTypes?.Split(',').ToList();
- TGetTablesResp resp = await Connection.GetTablesAsync(
+ IResponse response = await Connection.GetTablesAsync(
EscapePatternWildcardsInName(CatalogName),
EscapePatternWildcardsInName(SchemaName),
EscapePatternWildcardsInName(TableName),
tableTypesList,
cancellationToken);
- OperationHandle = resp.OperationHandle;
- _directResults = resp.DirectResults;
- return await GetQueryResult(resp.DirectResults, cancellationToken);
+ return await GetQueryResult(response, cancellationToken);
}
protected virtual async Task<QueryResult>
GetColumnsAsync(CancellationToken cancellationToken = default)
{
- TGetColumnsResp resp = await Connection.GetColumnsAsync(
+ IResponse response = await Connection.GetColumnsAsync(
EscapePatternWildcardsInName(CatalogName),
EscapePatternWildcardsInName(SchemaName),
EscapePatternWildcardsInName(TableName),
EscapePatternWildcardsInName(ColumnName),
cancellationToken);
- OperationHandle = resp.OperationHandle;
-
- // Set _directResults so that dispose logic can check if operation
was already closed
- _directResults = resp.DirectResults;
-
- // Common variables declared upfront
- TGetResultSetMetadataResp metadata;
- Schema schema;
- TRowSet rowSet;
// For GetColumns, we need to enhance the result with
BASE_TYPE_NAME
- if (Connection.AreResultsAvailableDirectly &&
resp.DirectResults?.ResultSet?.Results != null)
- {
- // Get data from direct results
- metadata = resp.DirectResults.ResultSetMetadata;
- schema = GetSchemaFromMetadata(metadata);
- rowSet = resp.DirectResults.ResultSet.Results;
- }
- else
+ if (!Connection.TryGetDirectResults(response.DirectResults, out
TGetResultSetMetadataResp? metadata, out TRowSet? rowSet))
{
// Poll and fetch results
- await
HiveServer2Connection.PollForResponseAsync(OperationHandle!, Connection.Client,
PollTimeMilliseconds, cancellationToken);
+ await
HiveServer2Connection.PollForResponseAsync(response.OperationHandle!,
Connection.Client, PollTimeMilliseconds, cancellationToken);
// Get metadata
- metadata = await
HiveServer2Connection.GetResultSetMetadataAsync(OperationHandle!,
Connection.Client, cancellationToken);
- schema = GetSchemaFromMetadata(metadata);
+ metadata = await
HiveServer2Connection.GetResultSetMetadataAsync(response.OperationHandle!,
Connection.Client, cancellationToken);
// Fetch the results
- rowSet = await Connection.FetchResultsAsync(OperationHandle!,
BatchSize, cancellationToken);
+ rowSet = await
Connection.FetchResultsAsync(response.OperationHandle!, BatchSize,
cancellationToken);
}
// Common processing for both paths
+ Schema schema =
Connection.SchemaParser.GetArrowSchema(metadata!.Schema,
Connection.DataTypeConversion);
int columnCount = HiveServer2Reader.GetColumnCount(rowSet);
int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount);
IReadOnlyList<IArrowArray> data =
HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema,
Connection.DataTypeConversion);
@@ -581,28 +527,17 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
return GetSchemaFromMetadata(response);
}
- private async Task<QueryResult> GetQueryResult(TSparkDirectResults?
directResults, CancellationToken cancellationToken)
+ private async Task<QueryResult> GetQueryResult(IResponse response,
CancellationToken cancellationToken)
{
- // Set _directResults so that dispose logic can check if operation
was already closed
- _directResults = directResults;
-
- Schema schema;
- if (Connection.AreResultsAvailableDirectly &&
directResults?.ResultSet?.Results != null)
+ if (Connection.TryGetDirectResults(response.DirectResults, out
QueryResult? result))
{
- TGetResultSetMetadataResp resultSetMetadata =
directResults.ResultSetMetadata;
- schema = GetSchemaFromMetadata(resultSetMetadata);
- TRowSet rowSet = directResults.ResultSet.Results;
- int columnCount = HiveServer2Reader.GetColumnCount(rowSet);
- int rowCount = HiveServer2Reader.GetRowCount(rowSet,
columnCount);
- IReadOnlyList<IArrowArray> data =
HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema,
Connection.DataTypeConversion);
-
- return new QueryResult(rowCount, new
HiveServer2Connection.HiveInfoArrowStream(schema, data));
+ return result!;
}
- await HiveServer2Connection.PollForResponseAsync(OperationHandle!,
Connection.Client, PollTimeMilliseconds, cancellationToken);
- schema = await GetResultSetSchemaAsync(OperationHandle!,
Connection.Client, cancellationToken);
+ await
HiveServer2Connection.PollForResponseAsync(response.OperationHandle!,
Connection.Client, PollTimeMilliseconds, cancellationToken);
+ Schema schema = await
GetResultSetSchemaAsync(response.OperationHandle!, Connection.Client,
cancellationToken);
- return new QueryResult(-1, Connection.NewReader(this, schema));
+ return new QueryResult(-1, Connection.NewReader(this, schema,
response));
}
protected internal QueryResult EnhanceGetColumnsResult(Schema
originalSchema, IReadOnlyList<IArrowArray> originalData,
@@ -863,7 +798,7 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
combinedData.Add(new FloatArray.Builder().Build());
break;
case ArrowTypeId.Double:
- combinedData.Add(new DoubleArray.Builder().Build());
+ combinedData.Add(new DoubleArray.Builder().Build());
break;
case ArrowTypeId.Date32:
combinedData.Add(new Date32Array.Builder().Build());
@@ -1056,5 +991,20 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
}
}
}
+
+ /// <inheritdoc/>
+ public virtual bool HasDirectResults(IResponse response) =>
response?.DirectResults?.ResultSet != null &&
response.DirectResults.ResultSetMetadata != null;
+
+ /// <inheritdoc/>
+ public bool TryGetDirectResults(IResponse response, out
TSparkDirectResults? directResults)
+ {
+ if (HasDirectResults(response))
+ {
+ directResults = response!.DirectResults;
+ return true;
+ }
+ directResults = null;
+ return false;
+ }
}
}
diff --git a/csharp/src/Drivers/Databricks/IHiveServer2Statement.cs
b/csharp/src/Drivers/Apache/Hive2/IHiveServer2Statement.cs
similarity index 76%
rename from csharp/src/Drivers/Databricks/IHiveServer2Statement.cs
rename to csharp/src/Drivers/Apache/Hive2/IHiveServer2Statement.cs
index 5c1a44425..c405f40b6 100644
--- a/csharp/src/Drivers/Databricks/IHiveServer2Statement.cs
+++ b/csharp/src/Drivers/Apache/Hive2/IHiveServer2Statement.cs
@@ -16,36 +16,33 @@
*/
using Apache.Arrow.Adbc.Tracing;
-using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Hive.Service.Rpc.Thrift;
-namespace Apache.Arrow.Adbc.Drivers.Databricks
+namespace Apache.Arrow.Adbc.Drivers.Apache.Hive2
{
/// <summary>
/// Interface for accessing HiveServer2Statement properties needed by
CloudFetchResultFetcher.
/// </summary>
internal interface IHiveServer2Statement : ITracingStatement
{
- /// <summary>
- /// Gets the operation handle.
- /// </summary>
- TOperationHandle? OperationHandle { get; }
-
/// <summary>
/// Gets the client.
/// </summary>
TCLIService.IAsync Client { get; }
/// <summary>
- /// Gets the direct results.
+ /// Checks if direct results are available.
/// </summary>
- TSparkDirectResults? DirectResults { get; }
+ /// <returns>True if direct results are available and contain result
data, false otherwise.</returns>
+ bool HasDirectResults(IResponse response);
/// <summary>
- /// Checks if direct results are available.
+ /// Tries to get the direct results <see cref="TSparkDirectResults"/>
if available.
/// </summary>
- /// <returns>True if direct results are available and contain result
data, false otherwise.</returns>
- bool HasDirectResults { get; }
+ /// <param name="response">The <see cref="IResponse"/> object to
check.</param>
+ /// <param name="directResults">The <see cref="TSparkDirectResults"/>
object if the response has direct results.</param>
+ /// <returns>True if direct results are available, false
otherwise.</returns>
+ bool TryGetDirectResults(IResponse response, out TSparkDirectResults?
directResults);
/// <summary>
/// Gets the query timeout in seconds.
diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
index c5d8e45c8..a44132933 100644
--- a/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
+++ b/csharp/src/Drivers/Apache/Impala/ImpalaConnection.cs
@@ -63,28 +63,10 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala
return new ImpalaStatement(this);
}
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken
cancellationToken = default) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken
cancellationToken = default) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken
cancellationToken = default) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken
cancellationToken = default) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
- protected internal override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetPrimaryKeysResp response, CancellationToken
cancellationToken = default) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
- protected override Task<TRowSet> GetRowSetAsync(TGetTableTypesResp
response, CancellationToken cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
- protected override Task<TRowSet> GetRowSetAsync(TGetColumnsResp
response, CancellationToken cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
- protected override Task<TRowSet> GetRowSetAsync(TGetTablesResp
response, CancellationToken cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
- protected override Task<TRowSet> GetRowSetAsync(TGetCatalogsResp
response, CancellationToken cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
- protected override Task<TRowSet> GetRowSetAsync(TGetSchemasResp
response, CancellationToken cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
- protected internal override Task<TRowSet>
GetRowSetAsync(TGetPrimaryKeysResp response, CancellationToken
cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
+ protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(IResponse response, CancellationToken
cancellationToken = default) =>
+ GetResultSetMetadataAsync(response.OperationHandle!, Client,
cancellationToken);
+ protected override Task<TRowSet> GetRowSetAsync(IResponse response,
CancellationToken cancellationToken = default) =>
+ FetchResultsAsync(response.OperationHandle!, cancellationToken:
cancellationToken);
internal override void SetPrecisionScaleAndTypeName(
short colType,
diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs
b/csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs
index 53945572d..91dd3aee9 100644
--- a/csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs
+++ b/csharp/src/Drivers/Apache/Impala/ImpalaHttpConnection.cs
@@ -125,7 +125,8 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala
TlsOptions = HiveServer2TlsImpl.GetHttpTlsOptions(Properties);
}
- internal override IArrowArrayStream NewReader<T>(T statement, Schema
schema, TGetResultSetMetadataResp? metadataResp = null) => new
HiveServer2Reader(statement, schema, dataTypeConversion:
statement.Connection.DataTypeConversion);
+ internal override IArrowArrayStream NewReader<T>(T statement, Schema
schema, IResponse response, TGetResultSetMetadataResp? metadataResp = null) =>
+ new HiveServer2Reader(statement, schema, response,
dataTypeConversion: statement.Connection.DataTypeConversion);
protected override TTransport CreateTransport()
{
diff --git a/csharp/src/Drivers/Apache/Impala/ImpalaStandardConnection.cs
b/csharp/src/Drivers/Apache/Impala/ImpalaStandardConnection.cs
index a87a7973b..0e1bef111 100644
--- a/csharp/src/Drivers/Apache/Impala/ImpalaStandardConnection.cs
+++ b/csharp/src/Drivers/Apache/Impala/ImpalaStandardConnection.cs
@@ -193,7 +193,8 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Impala
return request;
}
- internal override IArrowArrayStream NewReader<T>(T statement, Schema
schema, TGetResultSetMetadataResp? metadataResp = null) => new
HiveServer2Reader(statement, schema, dataTypeConversion:
statement.Connection.DataTypeConversion);
+ internal override IArrowArrayStream NewReader<T>(T statement, Schema
schema, IResponse response, TGetResultSetMetadataResp? metadataResp = null) =>
+ new HiveServer2Reader(statement, schema, response,
dataTypeConversion: statement.Connection.DataTypeConversion);
internal override ImpalaServerType ServerType =>
ImpalaServerType.Standard;
diff --git a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
index c7e25861e..151e93f11 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkConnection.cs
@@ -17,6 +17,7 @@
using System;
using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
using System.Threading;
using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Hive.Service.Rpc.Thrift;
@@ -117,21 +118,47 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
protected override bool IsColumnSizeValidForDecimal => false;
- protected internal override bool AreResultsAvailableDirectly => true;
-
- protected override void SetDirectResults(TGetColumnsReq request) =>
request.GetDirectResults = sparkGetDirectResults;
-
- protected override void SetDirectResults(TGetCatalogsReq request) =>
request.GetDirectResults = sparkGetDirectResults;
+ protected internal override bool TrySetGetDirectResults(IRequest
request)
+ {
+ request.GetDirectResults = sparkGetDirectResults;
+ return true;
+ }
- protected override void SetDirectResults(TGetSchemasReq request) =>
request.GetDirectResults = sparkGetDirectResults;
+ protected internal override bool
TryGetDirectResults(TSparkDirectResults? directResults, [MaybeNullWhen(false)]
out QueryResult result)
+ {
+ if (directResults?.ResultSet?.Results == null)
+ {
+ result = null;
+ return false;
+ }
- protected override void SetDirectResults(TGetTablesReq request) =>
request.GetDirectResults = sparkGetDirectResults;
+ TGetResultSetMetadataResp resultSetMetadata =
directResults.ResultSetMetadata;
+ Schema schema =
SchemaParser.GetArrowSchema(resultSetMetadata.Schema, DataTypeConversion);
+ TRowSet rowSet = directResults.ResultSet.Results;
+ int columnCount = HiveServer2Reader.GetColumnCount(rowSet);
+ int rowCount = HiveServer2Reader.GetRowCount(rowSet, columnCount);
+ IReadOnlyList<IArrowArray> data =
HiveServer2Reader.GetArrowArrayData(rowSet, columnCount, schema,
DataTypeConversion);
- protected override void SetDirectResults(TGetTableTypesReq request) =>
request.GetDirectResults = sparkGetDirectResults;
+ result = new QueryResult(rowCount, new
HiveServer2Connection.HiveInfoArrowStream(schema, data));
+ return true;
+ }
- protected override void SetDirectResults(TGetPrimaryKeysReq request)
=> request.GetDirectResults = sparkGetDirectResults;
+ protected internal override bool TryGetDirectResults(
+ TSparkDirectResults? directResults,
+ [MaybeNullWhen(false)] out TGetResultSetMetadataResp metadata,
+ [MaybeNullWhen(false)] out TRowSet rowSet)
+ {
+ if (directResults?.ResultSet?.Results == null)
+ {
+ metadata = null;
+ rowSet = null;
+ return false;
+ }
- protected override void SetDirectResults(TGetCrossReferenceReq
request) => request.GetDirectResults = sparkGetDirectResults;
+ metadata = directResults.ResultSetMetadata;
+ rowSet = directResults.ResultSet.Results;
+ return true;
+ }
protected abstract void ValidateConnection();
protected abstract void ValidateAuthentication();
diff --git a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
index aead49343..d887e4ae3 100644
--- a/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
+++ b/csharp/src/Drivers/Apache/Spark/SparkHttpConnection.cs
@@ -150,7 +150,8 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
internal override IArrowArrayStream NewReader<T>(
T statement,
Schema schema,
- TGetResultSetMetadataResp? metadataResp = null) => new
HiveServer2Reader(statement, schema, dataTypeConversion:
statement.Connection.DataTypeConversion);
+ IResponse response,
+ TGetResultSetMetadataResp? metadataResp = null) => new
HiveServer2Reader(statement, schema, response, dataTypeConversion:
statement.Connection.DataTypeConversion);
protected virtual HttpMessageHandler CreateHttpHandler()
{
@@ -240,28 +241,10 @@ namespace Apache.Arrow.Adbc.Drivers.Apache.Spark
return req;
}
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken
cancellationToken = default) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken
cancellationToken = default) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken
cancellationToken = default) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken
cancellationToken = default) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
- protected internal override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetPrimaryKeysResp response, CancellationToken
cancellationToken = default) =>
- GetResultSetMetadataAsync(response.OperationHandle, Client,
cancellationToken);
- protected override Task<TRowSet> GetRowSetAsync(TGetTableTypesResp
response, CancellationToken cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
- protected override Task<TRowSet> GetRowSetAsync(TGetColumnsResp
response, CancellationToken cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
- protected override Task<TRowSet> GetRowSetAsync(TGetTablesResp
response, CancellationToken cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
- protected override Task<TRowSet> GetRowSetAsync(TGetCatalogsResp
response, CancellationToken cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
- protected override Task<TRowSet> GetRowSetAsync(TGetSchemasResp
response, CancellationToken cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
- protected internal override Task<TRowSet>
GetRowSetAsync(TGetPrimaryKeysResp response, CancellationToken
cancellationToken = default) =>
- FetchResultsAsync(response.OperationHandle, cancellationToken:
cancellationToken);
+ protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(IResponse response, CancellationToken
cancellationToken = default) =>
+ GetResultSetMetadataAsync(response.OperationHandle!, Client,
cancellationToken);
+ protected override Task<TRowSet> GetRowSetAsync(IResponse response,
CancellationToken cancellationToken = default) =>
+ FetchResultsAsync(response.OperationHandle!, cancellationToken:
cancellationToken);
internal override SchemaParser SchemaParser => new
HiveServer2SchemaParser();
diff --git a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/IRequest.cs
b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/IRequest.cs
new file mode 100644
index 000000000..84a32e0e1
--- /dev/null
+++ b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/IRequest.cs
@@ -0,0 +1,35 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using Thrift.Protocol;
+
+namespace Apache.Hive.Service.Rpc.Thrift
+{
+ internal interface IRequest
+ {
+ TSparkGetDirectResults? GetDirectResults { get; set; }
+ }
+
+ internal partial class TExecuteStatementReq : TBase, IRequest { }
+ internal partial class TGetCatalogsReq : TBase, IRequest { }
+ internal partial class TGetColumnsReq : TBase, IRequest { }
+ internal partial class TGetCrossReferenceReq : TBase, IRequest { }
+ internal partial class TGetPrimaryKeysReq : TBase, IRequest { }
+ internal partial class TGetSchemasReq : TBase, IRequest { }
+ internal partial class TGetTablesReq : TBase, IRequest { }
+ internal partial class TGetTableTypesReq : TBase, IRequest { }
+}
diff --git a/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/IResponse.cs
b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/IResponse.cs
new file mode 100644
index 000000000..15818ee4c
--- /dev/null
+++ b/csharp/src/Drivers/Apache/Thrift/Service/Rpc/Thrift/IResponse.cs
@@ -0,0 +1,37 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using Thrift.Protocol;
+
+namespace Apache.Hive.Service.Rpc.Thrift
+{
+ internal interface IResponse
+ {
+ TStatus? Status { get; set; }
+ TOperationHandle? OperationHandle { get; set; }
+ TSparkDirectResults? DirectResults { get; set; }
+ }
+
+ internal partial class TExecuteStatementResp : TBase, IResponse { }
+ internal partial class TGetCatalogsResp : TBase, IResponse { }
+ internal partial class TGetColumnsResp : TBase, IResponse { }
+ internal partial class TGetCrossReferenceResp : TBase, IResponse { }
+ internal partial class TGetPrimaryKeysResp : TBase, IResponse { }
+ internal partial class TGetSchemasResp : TBase, IResponse { }
+ internal partial class TGetTablesResp : TBase, IResponse { }
+ internal partial class TGetTableTypesResp : TBase, IResponse { }
+}
diff --git a/csharp/src/Drivers/Databricks/DatabricksConnection.cs
b/csharp/src/Drivers/Databricks/DatabricksConnection.cs
index 207539142..feaa42dfe 100644
--- a/csharp/src/Drivers/Databricks/DatabricksConnection.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksConnection.cs
@@ -427,25 +427,9 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
return baseHandler;
}
- protected internal override bool AreResultsAvailableDirectly =>
_enableDirectResults;
-
protected override bool GetObjectsPatternsRequireLowerCase => true;
- protected override void SetDirectResults(TGetColumnsReq request) =>
request.GetDirectResults = defaultGetDirectResults;
-
- protected override void SetDirectResults(TGetCatalogsReq request) =>
request.GetDirectResults = defaultGetDirectResults;
-
- protected override void SetDirectResults(TGetSchemasReq request) =>
request.GetDirectResults = defaultGetDirectResults;
-
- protected override void SetDirectResults(TGetTablesReq request) =>
request.GetDirectResults = defaultGetDirectResults;
-
- protected override void SetDirectResults(TGetTableTypesReq request) =>
request.GetDirectResults = defaultGetDirectResults;
-
- protected override void SetDirectResults(TGetPrimaryKeysReq request)
=> request.GetDirectResults = defaultGetDirectResults;
-
- protected override void SetDirectResults(TGetCrossReferenceReq
request) => request.GetDirectResults = defaultGetDirectResults;
-
- internal override IArrowArrayStream NewReader<T>(T statement, Schema
schema, TGetResultSetMetadataResp? metadataResp = null)
+ internal override IArrowArrayStream NewReader<T>(T statement, Schema
schema, IResponse response, TGetResultSetMetadataResp? metadataResp = null)
{
bool isLz4Compressed = false;
@@ -461,7 +445,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
isLz4Compressed = metadataResp.Lz4Compressed;
}
- return new DatabricksCompositeReader(databricksStatement, schema,
isLz4Compressed, TlsOptions, _proxyConfigurator);
+ return new DatabricksCompositeReader(databricksStatement, schema,
response, isLz4Compressed, TlsOptions, _proxyConfigurator);
}
internal override SchemaParser SchemaParser => new
DatabricksSchemaParser();
@@ -506,7 +490,8 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
if (session != null)
{
var version = session.ServerProtocolVersion;
- if
(!FeatureVersionNegotiator.IsDatabricksProtocolVersion(version)) {
+ if
(!FeatureVersionNegotiator.IsDatabricksProtocolVersion(version))
+ {
throw new DatabricksException("Attempted to use databricks
driver with a non-databricks server");
}
_enablePKFK = _enablePKFK &&
FeatureVersionNegotiator.SupportsPKFK(version);
@@ -638,29 +623,11 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
}
}
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetSchemasResp response, CancellationToken
cancellationToken = default) =>
- Task.FromResult(response.DirectResults.ResultSetMetadata);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetCatalogsResp response, CancellationToken
cancellationToken = default) =>
- Task.FromResult(response.DirectResults.ResultSetMetadata);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetColumnsResp response, CancellationToken
cancellationToken = default) =>
- Task.FromResult(response.DirectResults.ResultSetMetadata);
- protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetTablesResp response, CancellationToken
cancellationToken = default) =>
- Task.FromResult(response.DirectResults.ResultSetMetadata);
- protected internal override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(TGetPrimaryKeysResp response, CancellationToken
cancellationToken = default) =>
- Task.FromResult(response.DirectResults.ResultSetMetadata);
-
- protected override Task<TRowSet> GetRowSetAsync(TGetTableTypesResp
response, CancellationToken cancellationToken = default) =>
- Task.FromResult(response.DirectResults.ResultSet.Results);
- protected override Task<TRowSet> GetRowSetAsync(TGetColumnsResp
response, CancellationToken cancellationToken = default) =>
- Task.FromResult(response.DirectResults.ResultSet.Results);
- protected override Task<TRowSet> GetRowSetAsync(TGetTablesResp
response, CancellationToken cancellationToken = default) =>
- Task.FromResult(response.DirectResults.ResultSet.Results);
- protected override Task<TRowSet> GetRowSetAsync(TGetCatalogsResp
response, CancellationToken cancellationToken = default) =>
- Task.FromResult(response.DirectResults.ResultSet.Results);
- protected override Task<TRowSet> GetRowSetAsync(TGetSchemasResp
response, CancellationToken cancellationToken = default) =>
- Task.FromResult(response.DirectResults.ResultSet.Results);
- protected internal override Task<TRowSet>
GetRowSetAsync(TGetPrimaryKeysResp response, CancellationToken
cancellationToken = default) =>
- Task.FromResult(response.DirectResults.ResultSet.Results);
+ protected override Task<TGetResultSetMetadataResp>
GetResultSetMetadataAsync(IResponse response, CancellationToken
cancellationToken = default) =>
+ Task.FromResult(response.DirectResults!.ResultSetMetadata);
+
+ protected override Task<TRowSet> GetRowSetAsync(IResponse response,
CancellationToken cancellationToken = default) =>
+ Task.FromResult(response.DirectResults!.ResultSet.Results);
protected override AuthenticationHeaderValue?
GetAuthenticationHeaderValue(SparkAuthType authType)
{
diff --git a/csharp/src/Drivers/Databricks/DatabricksStatement.cs
b/csharp/src/Drivers/Databricks/DatabricksStatement.cs
index 63d6103f3..84ae63406 100644
--- a/csharp/src/Drivers/Databricks/DatabricksStatement.cs
+++ b/csharp/src/Drivers/Databricks/DatabricksStatement.cs
@@ -120,21 +120,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks
statement.MaxBytesPerFile = maxBytesPerFile;
statement.RunAsync = runAsyncInThrift;
- if (Connection.AreResultsAvailableDirectly)
- {
- statement.GetDirectResults =
DatabricksConnection.defaultGetDirectResults;
- }
- }
-
- /// <summary>
- /// Checks if direct results are available.
- /// </summary>
- /// <returns>True if direct results are available and contain result
data, false otherwise.</returns>
- public bool HasDirectResults => DirectResults?.ResultSet != null &&
DirectResults?.ResultSetMetadata != null;
-
- public TSparkDirectResults? DirectResults
- {
- get { return _directResults; }
+ Connection.TrySetGetDirectResults(statement);
}
// Cast the Client to IAsync for CloudFetch compatibility
diff --git a/csharp/src/Drivers/Databricks/Reader/BaseDatabricksReader.cs
b/csharp/src/Drivers/Databricks/Reader/BaseDatabricksReader.cs
index 657e93523..d6246878d 100644
--- a/csharp/src/Drivers/Databricks/Reader/BaseDatabricksReader.cs
+++ b/csharp/src/Drivers/Databricks/Reader/BaseDatabricksReader.cs
@@ -16,8 +16,10 @@
*/
using System;
-using Apache.Arrow.Adbc.Drivers.Databricks;
+using System.Threading.Tasks;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Arrow.Adbc.Tracing;
+using Apache.Hive.Service.Rpc.Thrift;
namespace Apache.Arrow.Adbc.Drivers.Databricks.Reader
{
@@ -28,14 +30,17 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Reader
{
protected IHiveServer2Statement statement;
protected readonly Schema schema;
+ protected readonly IResponse response;
protected readonly bool isLz4Compressed;
protected bool hasNoMoreRows = false;
private bool isDisposed;
+ private bool isClosed;
- protected BaseDatabricksReader(IHiveServer2Statement statement, Schema
schema, bool isLz4Compressed)
+ protected BaseDatabricksReader(IHiveServer2Statement statement, Schema
schema, IResponse response, bool isLz4Compressed)
: base(statement)
{
this.schema = schema;
+ this.response = response;
this.isLz4Compressed = isLz4Compressed;
this.statement = statement;
}
@@ -44,8 +49,43 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Reader
protected override void Dispose(bool disposing)
{
- base.Dispose(disposing);
- isDisposed = true;
+ try
+ {
+ if (!isDisposed)
+ {
+ if (disposing)
+ {
+ _ = CloseOperationAsync().Result;
+ }
+ }
+ }
+ finally
+ {
+ base.Dispose(disposing);
+ isDisposed = true;
+ }
+ }
+
+ /// <summary>
+ /// Closes the current operation.
+ /// </summary>
+ /// <returns>Returns true if the close operation completes
successfully, false otherwise.</returns>
+ /// <exception cref="HiveServer2Exception" />
+ public async Task<bool> CloseOperationAsync()
+ {
+ try
+ {
+ if (!isClosed)
+ {
+ _ = await
HiveServer2Reader.CloseOperationAsync(this.statement, this.response);
+ return true;
+ }
+ return false;
+ }
+ finally
+ {
+ isClosed = true;
+ }
}
protected void ThrowIfDisposed()
diff --git
a/csharp/src/Drivers/Databricks/Reader/CloudFetch/CloudFetchDownloadManager.cs
b/csharp/src/Drivers/Databricks/Reader/CloudFetch/CloudFetchDownloadManager.cs
index 75f10a4bb..a7a98648f 100644
---
a/csharp/src/Drivers/Databricks/Reader/CloudFetch/CloudFetchDownloadManager.cs
+++
b/csharp/src/Drivers/Databricks/Reader/CloudFetch/CloudFetchDownloadManager.cs
@@ -17,12 +17,10 @@
using System;
using System.Collections.Concurrent;
-using System.Collections.Generic;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
-using Apache.Arrow.Adbc.Drivers.Databricks;
using Apache.Hive.Service.Rpc.Thrift;
namespace Apache.Arrow.Adbc.Drivers.Databricks.Reader.CloudFetch
@@ -61,7 +59,13 @@ namespace
Apache.Arrow.Adbc.Drivers.Databricks.Reader.CloudFetch
/// <param name="statement">The HiveServer2 statement.</param>
/// <param name="schema">The Arrow schema.</param>
/// <param name="isLz4Compressed">Whether the results are LZ4
compressed.</param>
- public CloudFetchDownloadManager(IHiveServer2Statement statement,
Schema schema, TFetchResultsResp? initialResults, bool isLz4Compressed,
HttpClient httpClient)
+ public CloudFetchDownloadManager(
+ IHiveServer2Statement statement,
+ Schema schema,
+ IResponse response,
+ TFetchResultsResp? initialResults,
+ bool isLz4Compressed,
+ HttpClient httpClient)
{
_statement = statement ?? throw new
ArgumentNullException(nameof(statement));
_schema = schema ?? throw new
ArgumentNullException(nameof(schema));
@@ -195,6 +199,7 @@ namespace
Apache.Arrow.Adbc.Drivers.Databricks.Reader.CloudFetch
// Initialize the result fetcher with URL management capabilities
_resultFetcher = new CloudFetchResultFetcher(
_statement,
+ response,
initialResults,
_memoryManager,
_downloadQueue,
diff --git
a/csharp/src/Drivers/Databricks/Reader/CloudFetch/CloudFetchReader.cs
b/csharp/src/Drivers/Databricks/Reader/CloudFetch/CloudFetchReader.cs
index d101cfcf8..8c1730833 100644
--- a/csharp/src/Drivers/Databricks/Reader/CloudFetch/CloudFetchReader.cs
+++ b/csharp/src/Drivers/Databricks/Reader/CloudFetch/CloudFetchReader.cs
@@ -20,7 +20,7 @@ using System.Diagnostics;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
-using Apache.Arrow.Adbc.Drivers.Databricks;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Arrow.Adbc.Tracing;
using Apache.Arrow.Ipc;
using Apache.Hive.Service.Rpc.Thrift;
@@ -44,8 +44,14 @@ namespace
Apache.Arrow.Adbc.Drivers.Databricks.Reader.CloudFetch
/// <param name="statement">The Databricks statement.</param>
/// <param name="schema">The Arrow schema.</param>
/// <param name="isLz4Compressed">Whether the results are LZ4
compressed.</param>
- public CloudFetchReader(IHiveServer2Statement statement, Schema
schema, TFetchResultsResp? initialResults, bool isLz4Compressed, HttpClient
httpClient)
- : base(statement, schema, isLz4Compressed)
+ public CloudFetchReader(
+ IHiveServer2Statement statement,
+ Schema schema,
+ IResponse response,
+ TFetchResultsResp? initialResults,
+ bool isLz4Compressed,
+ HttpClient httpClient)
+ : base(statement, schema, response, isLz4Compressed)
{
// Check if prefetch is enabled
var connectionProps = statement.Connection.Properties;
@@ -65,14 +71,14 @@ namespace
Apache.Arrow.Adbc.Drivers.Databricks.Reader.CloudFetch
// Initialize the download manager
if (isPrefetchEnabled)
{
- downloadManager = new CloudFetchDownloadManager(statement,
schema, initialResults, isLz4Compressed, httpClient);
+ downloadManager = new CloudFetchDownloadManager(statement,
schema, response, initialResults, isLz4Compressed, httpClient);
downloadManager.StartAsync().Wait();
}
else
{
// For now, we only support the prefetch implementation
// This flag is reserved for future use if we need to support
a non-prefetch mode
- downloadManager = new CloudFetchDownloadManager(statement,
schema, initialResults, isLz4Compressed, httpClient);
+ downloadManager = new CloudFetchDownloadManager(statement,
schema, response, initialResults, isLz4Compressed, httpClient);
downloadManager.StartAsync().Wait();
}
}
diff --git
a/csharp/src/Drivers/Databricks/Reader/CloudFetch/CloudFetchResultFetcher.cs
b/csharp/src/Drivers/Databricks/Reader/CloudFetch/CloudFetchResultFetcher.cs
index 4ef40a0d9..5b18af2ea 100644
--- a/csharp/src/Drivers/Databricks/Reader/CloudFetch/CloudFetchResultFetcher.cs
+++ b/csharp/src/Drivers/Databricks/Reader/CloudFetch/CloudFetchResultFetcher.cs
@@ -23,7 +23,7 @@ using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Adbc.Drivers.Apache;
-using Apache.Arrow.Adbc.Drivers.Databricks;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Hive.Service.Rpc.Thrift;
namespace Apache.Arrow.Adbc.Drivers.Databricks.Reader.CloudFetch
@@ -34,6 +34,7 @@ namespace
Apache.Arrow.Adbc.Drivers.Databricks.Reader.CloudFetch
internal class CloudFetchResultFetcher : ICloudFetchResultFetcher
{
private readonly IHiveServer2Statement _statement;
+ private readonly IResponse _response;
private readonly TFetchResultsResp? _initialResults;
private readonly ICloudFetchMemoryBufferManager _memoryManager;
private readonly BlockingCollection<IDownloadResult> _downloadQueue;
@@ -60,6 +61,7 @@ namespace
Apache.Arrow.Adbc.Drivers.Databricks.Reader.CloudFetch
/// <param name="clock">Clock implementation for time operations. If
null, uses system clock.</param>
public CloudFetchResultFetcher(
IHiveServer2Statement statement,
+ IResponse response,
TFetchResultsResp? initialResults,
ICloudFetchMemoryBufferManager memoryManager,
BlockingCollection<IDownloadResult> downloadQueue,
@@ -68,6 +70,7 @@ namespace
Apache.Arrow.Adbc.Drivers.Databricks.Reader.CloudFetch
IClock? clock = null)
{
_statement = statement ?? throw new
ArgumentNullException(nameof(statement));
+ _response = response;
_initialResults = initialResults;
_memoryManager = memoryManager ?? throw new
ArgumentNullException(nameof(memoryManager));
_downloadQueue = downloadQueue ?? throw new
ArgumentNullException(nameof(downloadQueue));
@@ -156,7 +159,7 @@ namespace
Apache.Arrow.Adbc.Drivers.Databricks.Reader.CloudFetch
{
// Create fetch request for the specific offset
TFetchResultsReq request = new TFetchResultsReq(
- _statement.OperationHandle!,
+ _response.OperationHandle!,
TFetchOrientation.FETCH_NEXT,
1);
@@ -218,8 +221,9 @@ namespace
Apache.Arrow.Adbc.Drivers.Databricks.Reader.CloudFetch
try
{
// Process direct results first, if available
- if ((_statement.HasDirectResults &&
_statement.DirectResults?.ResultSet?.Results?.ResultLinks?.Count > 0) ||
- _initialResults?.Results?.ResultLinks?.Count > 0)
+ if ((_statement.TryGetDirectResults(_response, out
TSparkDirectResults? directResults)
+ && directResults!.ResultSet?.Results?.ResultLinks?.Count >
0)
+ || _initialResults?.Results?.ResultLinks?.Count > 0)
{
// Yield execution so the download queue doesn't get
blocked before downloader is started
await Task.Yield();
@@ -274,7 +278,7 @@ namespace
Apache.Arrow.Adbc.Drivers.Databricks.Reader.CloudFetch
private async Task FetchNextResultBatchAsync(long? offset,
CancellationToken cancellationToken)
{
// Create fetch request
- TFetchResultsReq request = new
TFetchResultsReq(_statement.OperationHandle!, TFetchOrientation.FETCH_NEXT,
_batchSize);
+ TFetchResultsReq request = new
TFetchResultsReq(_response.OperationHandle!, TFetchOrientation.FETCH_NEXT,
_batchSize);
// Set the start row offset
long startOffset = offset ?? _startOffset;
@@ -340,9 +344,10 @@ namespace
Apache.Arrow.Adbc.Drivers.Databricks.Reader.CloudFetch
private void ProcessDirectResultsAsync(CancellationToken
cancellationToken)
{
TFetchResultsResp fetchResults;
- if (_statement.HasDirectResults &&
_statement.DirectResults?.ResultSet?.Results?.ResultLinks?.Count > 0)
+ if (_statement.TryGetDirectResults(_response, out
TSparkDirectResults? directResults)
+ && directResults!.ResultSet?.Results?.ResultLinks?.Count > 0)
{
- fetchResults = _statement.DirectResults!.ResultSet;
+ fetchResults = directResults.ResultSet;
}
else
{
diff --git a/csharp/src/Drivers/Databricks/Reader/DatabricksCompositeReader.cs
b/csharp/src/Drivers/Databricks/Reader/DatabricksCompositeReader.cs
index 0bac863e0..a2218d163 100644
--- a/csharp/src/Drivers/Databricks/Reader/DatabricksCompositeReader.cs
+++ b/csharp/src/Drivers/Databricks/Reader/DatabricksCompositeReader.cs
@@ -20,7 +20,6 @@ using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
-using Apache.Arrow.Adbc.Drivers.Databricks;
using Apache.Arrow.Adbc.Drivers.Databricks.Reader.CloudFetch;
using Apache.Arrow.Adbc.Tracing;
using Apache.Hive.Service.Rpc.Thrift;
@@ -42,11 +41,13 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Reader
private BaseDatabricksReader? _activeReader;
private readonly IHiveServer2Statement _statement;
private readonly Schema _schema;
+ private readonly IResponse _response;
private readonly bool _isLz4Compressed;
private readonly TlsProperties _tlsOptions;
private readonly HiveServer2ProxyConfigurator _proxyConfigurator;
private IOperationStatusPoller? operationStatusPoller;
+ private bool _disposed;
/// <summary>
/// Initializes a new instance of the <see
cref="DatabricksCompositeReader"/> class.
@@ -55,22 +56,32 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Reader
/// <param name="schema">The Arrow schema.</param>
/// <param name="isLz4Compressed">Whether the results are LZ4
compressed.</param>
/// <param name="httpClient">The HTTP client for CloudFetch
operations.</param>
- internal DatabricksCompositeReader(IHiveServer2Statement statement,
Schema schema, bool isLz4Compressed, TlsProperties tlsOptions,
HiveServer2ProxyConfigurator proxyConfigurator): base(statement)
+ internal DatabricksCompositeReader(
+ IHiveServer2Statement statement,
+ Schema schema,
+ IResponse response,
+ bool isLz4Compressed,
+ TlsProperties tlsOptions,
+ HiveServer2ProxyConfigurator proxyConfigurator)
+ : base(statement)
{
_statement = statement ?? throw new
ArgumentNullException(nameof(statement));
_schema = schema ?? throw new
ArgumentNullException(nameof(schema));
+ _response = response;
_isLz4Compressed = isLz4Compressed;
_tlsOptions = tlsOptions;
_proxyConfigurator = proxyConfigurator;
// use direct results if available
- if (_statement.HasDirectResults && _statement.DirectResults !=
null && _statement.DirectResults.__isset.resultSet &&
statement.DirectResults?.ResultSet != null)
+ if (_statement.TryGetDirectResults(_response, out
TSparkDirectResults? directResults)
+ && directResults!.__isset.resultSet
+ && directResults.ResultSet != null)
{
- _activeReader =
DetermineReader(_statement.DirectResults.ResultSet);
+ _activeReader = DetermineReader(directResults.ResultSet);
}
- if (_statement.DirectResults?.ResultSet?.HasMoreRows ?? true)
+ if (_response.DirectResults?.ResultSet?.HasMoreRows ?? true)
{
- operationStatusPoller = new
DatabricksOperationStatusPoller(statement);
+ operationStatusPoller = new
DatabricksOperationStatusPoller(statement, _response);
operationStatusPoller.Start();
}
}
@@ -83,11 +94,11 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Reader
initialResults.Results.ResultLinks?.Count > 0)
{
HttpClient cloudFetchHttpClient = new
HttpClient(HiveServer2TlsImpl.NewHttpClientHandler(_tlsOptions,
_proxyConfigurator));
- return new CloudFetchReader(_statement, _schema,
initialResults, _isLz4Compressed, cloudFetchHttpClient);
+ return new CloudFetchReader(_statement, _schema, _response,
initialResults, _isLz4Compressed, cloudFetchHttpClient);
}
else
{
- return new DatabricksReader(_statement, _schema,
initialResults, _isLz4Compressed);
+ return new DatabricksReader(_statement, _schema, _response,
initialResults, _isLz4Compressed);
}
}
@@ -104,7 +115,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Reader
// if no reader, we did not have direct results
// Make a FetchResults call to get the initial result set
// and determine the reader based on the result set
- TFetchResultsReq request = new
TFetchResultsReq(this._statement.OperationHandle!,
TFetchOrientation.FETCH_NEXT, this._statement.BatchSize);
+ TFetchResultsReq request = new
TFetchResultsReq(_response.OperationHandle!, TFetchOrientation.FETCH_NEXT,
this._statement.BatchSize);
TFetchResultsResp response = await
this._statement.Connection.Client!.FetchResults(request, cancellationToken);
_activeReader = DetermineReader(response);
}
@@ -125,13 +136,33 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Reader
protected override void Dispose(bool disposing)
{
- if (disposing)
+ try
{
- _activeReader?.Dispose();
- StopOperationStatusPoller();
+ if (!_disposed)
+ {
+ if (disposing)
+ {
+ StopOperationStatusPoller();
+ if (_activeReader == null)
+ {
+ _ =
HiveServer2Reader.CloseOperationAsync(_statement, _response)
+
.ConfigureAwait(false).GetAwaiter().GetResult();
+ }
+ else
+ {
+ // Note: Have the contained reader close the
operation to avoid duplicate calls.
+ _ = _activeReader.CloseOperationAsync()
+
.ConfigureAwait(false).GetAwaiter().GetResult();
+ _activeReader = null;
+ }
+ }
+ }
+ }
+ finally
+ {
+ base.Dispose(disposing);
+ _disposed = true;
}
- _activeReader = null;
- base.Dispose(disposing);
}
private void StopOperationStatusPoller()
diff --git
a/csharp/src/Drivers/Databricks/Reader/DatabricksOperationStatusPoller.cs
b/csharp/src/Drivers/Databricks/Reader/DatabricksOperationStatusPoller.cs
index b5e10dc9b..27a85fb79 100644
--- a/csharp/src/Drivers/Databricks/Reader/DatabricksOperationStatusPoller.cs
+++ b/csharp/src/Drivers/Databricks/Reader/DatabricksOperationStatusPoller.cs
@@ -19,6 +19,7 @@ using System;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Adbc.Drivers.Apache;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Hive.Service.Rpc.Thrift;
namespace Apache.Arrow.Adbc.Drivers.Databricks.Reader
@@ -32,16 +33,19 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Reader
private readonly IHiveServer2Statement _statement;
private readonly int _heartbeatIntervalSeconds;
private readonly int _requestTimeoutSeconds;
+ private readonly IResponse _response;
// internal cancellation token source - won't affect the external token
private CancellationTokenSource? _internalCts;
private Task? _operationStatusPollingTask;
public DatabricksOperationStatusPoller(
IHiveServer2Statement statement,
+ IResponse response,
int heartbeatIntervalSeconds =
DatabricksConstants.DefaultOperationStatusPollingIntervalSeconds,
int requestTimeoutSeconds =
DatabricksConstants.DefaultOperationStatusRequestTimeoutSeconds)
{
_statement = statement ?? throw new
ArgumentNullException(nameof(statement));
+ _response = response;
_heartbeatIntervalSeconds = heartbeatIntervalSeconds;
_requestTimeoutSeconds = requestTimeoutSeconds;
}
@@ -69,7 +73,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Reader
{
while (!cancellationToken.IsCancellationRequested)
{
- var operationHandle = _statement.OperationHandle;
+ TOperationHandle? operationHandle = _response.OperationHandle;
if (operationHandle == null) break;
CancellationToken GetOperationStatusTimeoutToken =
ApacheUtility.GetCancellationToken(_requestTimeoutSeconds,
ApacheUtility.TimeUnit.Seconds);
diff --git a/csharp/src/Drivers/Databricks/Reader/DatabricksReader.cs
b/csharp/src/Drivers/Databricks/Reader/DatabricksReader.cs
index 80bd1f98d..b1aea2bda 100644
--- a/csharp/src/Drivers/Databricks/Reader/DatabricksReader.cs
+++ b/csharp/src/Drivers/Databricks/Reader/DatabricksReader.cs
@@ -20,7 +20,7 @@ using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Adbc.Drivers.Apache;
-using Apache.Arrow.Adbc.Drivers.Databricks;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Arrow.Adbc.Tracing;
using Apache.Arrow.Ipc;
using Apache.Hive.Service.Rpc.Thrift;
@@ -33,13 +33,14 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Reader
int index;
IArrowReader? reader;
- public DatabricksReader(IHiveServer2Statement statement, Schema
schema, TFetchResultsResp? initialResults, bool isLz4Compressed) :
base(statement, schema, isLz4Compressed)
+ public DatabricksReader(IHiveServer2Statement statement, Schema
schema, IResponse response, TFetchResultsResp? initialResults, bool
isLz4Compressed)
+ : base(statement, schema, response, isLz4Compressed)
{
// If we have direct results, initialize the batches from them
- if (statement.HasDirectResults)
+ if (statement.TryGetDirectResults(this.response, out
TSparkDirectResults? directResults))
{
- this.batches =
statement.DirectResults!.ResultSet.Results.ArrowBatches;
- this.hasNoMoreRows =
!statement.DirectResults.ResultSet.HasMoreRows;
+ this.batches = directResults!.ResultSet.Results.ArrowBatches;
+ this.hasNoMoreRows = !directResults.ResultSet.HasMoreRows;
}
else if (initialResults != null)
{
@@ -81,7 +82,7 @@ namespace Apache.Arrow.Adbc.Drivers.Databricks.Reader
return null;
}
// TODO: use an expiring cancellationtoken
- TFetchResultsReq request = new
TFetchResultsReq(this.statement.OperationHandle!, TFetchOrientation.FETCH_NEXT,
this.statement.BatchSize);
+ TFetchResultsReq request = new
TFetchResultsReq(this.response.OperationHandle!, TFetchOrientation.FETCH_NEXT,
this.statement.BatchSize);
TFetchResultsResp response = await
this.statement.Connection.Client!.FetchResults(request, cancellationToken);
// Make sure we get the arrowBatches
diff --git
a/csharp/test/Drivers/Databricks/E2E/CloudFetch/CloudFetchDownloaderTest.cs
b/csharp/test/Drivers/Databricks/E2E/CloudFetch/CloudFetchDownloaderTest.cs
index 80341ce37..42c20c2c2 100644
--- a/csharp/test/Drivers/Databricks/E2E/CloudFetch/CloudFetchDownloaderTest.cs
+++ b/csharp/test/Drivers/Databricks/E2E/CloudFetch/CloudFetchDownloaderTest.cs
@@ -17,14 +17,13 @@
using System;
using System.Collections.Concurrent;
-using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
-using Apache.Arrow.Adbc.Drivers.Databricks;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Arrow.Adbc.Drivers.Databricks.Reader.CloudFetch;
using Apache.Hive.Service.Rpc.Thrift;
using Moq;
diff --git
a/csharp/test/Drivers/Databricks/E2E/CloudFetch/CloudFetchResultFetcherTest.cs
b/csharp/test/Drivers/Databricks/E2E/CloudFetch/CloudFetchResultFetcherTest.cs
index d4abc811f..765e46ae4 100644
---
a/csharp/test/Drivers/Databricks/E2E/CloudFetch/CloudFetchResultFetcherTest.cs
+++
b/csharp/test/Drivers/Databricks/E2E/CloudFetch/CloudFetchResultFetcherTest.cs
@@ -18,11 +18,9 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
-using System.Diagnostics;
-using System.Linq;
using System.Threading;
using System.Threading.Tasks;
-using Apache.Arrow.Adbc.Drivers.Databricks;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Arrow.Adbc.Drivers.Databricks.Reader.CloudFetch;
using Apache.Hive.Service.Rpc.Thrift;
using Moq;
@@ -36,8 +34,8 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Databricks.CloudFetch
public class CloudFetchResultFetcherTest
{
private readonly Mock<IHiveServer2Statement> _mockStatement;
+ private readonly Mock<IResponse> _mockResponse;
private readonly Mock<TCLIService.IAsync> _mockClient;
- private readonly TOperationHandle _operationHandle;
private readonly MockClock _mockClock;
private readonly CloudFetchResultFetcherWithMockClock _resultFetcher;
private readonly BlockingCollection<IDownloadResult> _downloadQueue;
@@ -47,15 +45,9 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Databricks.CloudFetch
{
_mockClient = new Mock<TCLIService.IAsync>();
_mockStatement = new Mock<IHiveServer2Statement>();
- _operationHandle = new TOperationHandle
- {
- OperationId = new THandleIdentifier { Guid = new byte[] { 1,
2, 3, 4 } },
- OperationType = TOperationType.EXECUTE_STATEMENT,
- HasResultSet = true
- };
+ _mockResponse = CreateResponse();
_mockStatement.Setup(s => s.Client).Returns(_mockClient.Object);
- _mockStatement.Setup(s =>
s.OperationHandle).Returns(_operationHandle);
_mockClock = new MockClock();
_downloadQueue = new BlockingCollection<IDownloadResult>(new
ConcurrentQueue<IDownloadResult>(), 10);
@@ -63,6 +55,7 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Databricks.CloudFetch
_resultFetcher = new CloudFetchResultFetcherWithMockClock(
_mockStatement.Object,
+ _mockResponse.Object,
_mockMemoryManager.Object,
_downloadQueue,
100, // batchSize
@@ -543,6 +536,7 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Databricks.CloudFetch
{
return new CloudFetchResultFetcherWithMockClock(
_mockStatement.Object,
+ _mockResponse.Object,
initialResults,
_mockMemoryManager.Object,
_downloadQueue,
@@ -601,6 +595,22 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Databricks.CloudFetch
};
}
+ private Mock<IResponse> CreateResponse()
+ {
+ var mockResponse = new Mock<IResponse>();
+ mockResponse.Setup(r => r.OperationHandle).Returns(new
TOperationHandle
+ {
+ OperationId = new THandleIdentifier
+ {
+ Guid = new byte[16],
+ Secret = new byte[16]
+ },
+ OperationType = TOperationType.EXECUTE_STATEMENT,
+ HasResultSet = true
+ });
+ return mockResponse;
+ }
+
#endregion
}
@@ -636,24 +646,26 @@ namespace
Apache.Arrow.Adbc.Tests.Drivers.Databricks.CloudFetch
{
public CloudFetchResultFetcherWithMockClock(
IHiveServer2Statement statement,
+ IResponse response,
ICloudFetchMemoryBufferManager memoryManager,
BlockingCollection<IDownloadResult> downloadQueue,
long batchSize,
IClock clock,
int expirationBufferSeconds = 60)
- : base(statement, null, memoryManager, downloadQueue, batchSize,
expirationBufferSeconds, clock)
+ : base(statement, response, null, memoryManager, downloadQueue,
batchSize, expirationBufferSeconds, clock)
{
}
public CloudFetchResultFetcherWithMockClock(
IHiveServer2Statement statement,
+ IResponse response,
TFetchResultsResp? initialResults,
ICloudFetchMemoryBufferManager memoryManager,
BlockingCollection<IDownloadResult> downloadQueue,
long batchSize,
IClock clock,
int expirationBufferSeconds = 60)
- : base(statement, initialResults, memoryManager, downloadQueue,
batchSize, expirationBufferSeconds, clock)
+ : base(statement, response, initialResults, memoryManager,
downloadQueue, batchSize, expirationBufferSeconds, clock)
{
}
}
diff --git
a/csharp/test/Drivers/Databricks/Unit/DatabricksOperationStatusPollerTests.cs
b/csharp/test/Drivers/Databricks/Unit/DatabricksOperationStatusPollerTests.cs
index 9c05dba64..02eb2ad35 100644
---
a/csharp/test/Drivers/Databricks/Unit/DatabricksOperationStatusPollerTests.cs
+++
b/csharp/test/Drivers/Databricks/Unit/DatabricksOperationStatusPollerTests.cs
@@ -18,7 +18,7 @@
using System;
using System.Threading;
using System.Threading.Tasks;
-using Apache.Arrow.Adbc.Drivers.Databricks;
+using Apache.Arrow.Adbc.Drivers.Apache.Hive2;
using Apache.Arrow.Adbc.Drivers.Databricks.Reader;
using Apache.Hive.Service.Rpc.Thrift;
using Moq;
@@ -33,6 +33,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Unit
private readonly Mock<IHiveServer2Statement> _mockStatement;
private readonly Mock<TCLIService.IAsync> _mockClient;
private readonly TOperationHandle _operationHandle;
+ private readonly Mock<IResponse> _mockResponse;
private readonly int _heartbeatIntervalSeconds = 1;
@@ -41,6 +42,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Unit
_outputHelper = outputHelper;
_mockClient = new Mock<TCLIService.IAsync>();
_mockStatement = new Mock<IHiveServer2Statement>();
+ _mockResponse = new Mock<IResponse>();
_operationHandle = new TOperationHandle
{
OperationId = new THandleIdentifier { Guid = new byte[] { 1,
2, 3, 4 } },
@@ -48,14 +50,14 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Unit
};
_mockStatement.Setup(s => s.Client).Returns(_mockClient.Object);
- _mockStatement.Setup(s =>
s.OperationHandle).Returns(_operationHandle);
+ _mockResponse.Setup(r =>
r.OperationHandle!).Returns(_operationHandle);
}
[Fact]
public async Task StartPollsOperationStatusAtInterval()
{
// Arrange
- using var poller = new
DatabricksOperationStatusPoller(_mockStatement.Object,
_heartbeatIntervalSeconds);
+ using var poller = new
DatabricksOperationStatusPoller(_mockStatement.Object, _mockResponse.Object,
_heartbeatIntervalSeconds);
var pollCount = 0;
_mockClient.Setup(c =>
c.GetOperationStatus(It.IsAny<TGetOperationStatusReq>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(new TGetOperationStatusResp())
@@ -74,7 +76,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Unit
public async Task DisposeStopsPolling()
{
// Arrange
- using var poller = new
DatabricksOperationStatusPoller(_mockStatement.Object,
_heartbeatIntervalSeconds);
+ using var poller = new
DatabricksOperationStatusPoller(_mockStatement.Object, _mockResponse.Object,
_heartbeatIntervalSeconds);
var pollCount = 0;
_mockClient.Setup(c =>
c.GetOperationStatus(It.IsAny<TGetOperationStatusReq>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(new TGetOperationStatusResp())
@@ -96,7 +98,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Unit
public async Task StopStopsPolling()
{
// Arrange
- using var poller = new
DatabricksOperationStatusPoller(_mockStatement.Object,
_heartbeatIntervalSeconds);
+ using var poller = new
DatabricksOperationStatusPoller(_mockStatement.Object, _mockResponse.Object,
_heartbeatIntervalSeconds);
var pollCount = 0;
_mockClient.Setup(c =>
c.GetOperationStatus(It.IsAny<TGetOperationStatusReq>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(new TGetOperationStatusResp())
@@ -129,7 +131,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Unit
foreach (var terminalState in terminalStates)
{
// Arrange
- using var poller = new
DatabricksOperationStatusPoller(_mockStatement.Object,
_heartbeatIntervalSeconds);
+ using var poller = new
DatabricksOperationStatusPoller(_mockStatement.Object, _mockResponse.Object,
_heartbeatIntervalSeconds);
var pollCount = 0;
_mockClient.Setup(c =>
c.GetOperationStatus(It.IsAny<TGetOperationStatusReq>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(new TGetOperationStatusResp { OperationState
= terminalState })
@@ -148,7 +150,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Unit
public async Task ContinuesPollingOnFinishedState()
{
// Arrange
- using var poller = new
DatabricksOperationStatusPoller(_mockStatement.Object,
_heartbeatIntervalSeconds);
+ using var poller = new
DatabricksOperationStatusPoller(_mockStatement.Object, _mockResponse.Object,
_heartbeatIntervalSeconds);
var pollCount = 0;
_mockClient.Setup(c =>
c.GetOperationStatus(It.IsAny<TGetOperationStatusReq>(),
It.IsAny<CancellationToken>()))
.ReturnsAsync(new TGetOperationStatusResp { OperationState =
TOperationState.FINISHED_STATE })
@@ -167,7 +169,7 @@ namespace Apache.Arrow.Adbc.Tests.Drivers.Databricks.Unit
public async Task StopsPollingOnException()
{
// Arrange
- var poller = new
DatabricksOperationStatusPoller(_mockStatement.Object,
_heartbeatIntervalSeconds);
+ var poller = new
DatabricksOperationStatusPoller(_mockStatement.Object, _mockResponse.Object,
_heartbeatIntervalSeconds);
var pollCount = 0;
_mockClient.Setup(c =>
c.GetOperationStatus(It.IsAny<TGetOperationStatusReq>(),
It.IsAny<CancellationToken>()))
.ThrowsAsync(new Exception("Test exception"))