This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git


The following commit(s) were added to refs/heads/main by this push:
     new a8a345497 feat(go/adbc): add context.Context support for uniform 
OpenTelemetry instrumentation (#4009)
a8a345497 is described below

commit a8a345497fbefe0f81e07c96911b2c407470b6e1
Author: Matt Topol <[email protected]>
AuthorDate: Wed Feb 25 22:00:35 2026 -0500

    feat(go/adbc): add context.Context support for uniform OpenTelemetry 
instrumentation (#4009)
    
    ## Summary
    
    Implements context.Context support for all Go ADBC methods to enable
    uniform OpenTelemetry instrumentation, cancellation, and deadline
    propagation (addresses #2772).
    
    - Add `DatabaseWithContext`, `ConnectionWithContext`, and
    `StatementWithContext` interfaces that mirror the existing interfaces but
    require `context.Context` for all methods
    - Add context-aware extension interfaces: `GetSetOptionsWithContext`,
    `ConnectionGetStatisticsWithContext`, and
    `StatementExecuteSchemaWithContext` (a context-aware `PostInitOptions`
    variant is not part of this change)
    - Implement adapter functions (`AsDatabaseContext`,
    `AsConnectionContext`, `AsStatementContext`) to wrap non-context
    implementations for backward compatibility
    - Add `IngestStreamContext` helper for context-aware bulk ingestion
    - Comprehensive test suite (9 tests, all passing)
    
    ## Changes
    
    **New Interfaces** (`go/adbc/adbc.go`):
    - `DatabaseWithContext` - Database with context support for all operations
    - `ConnectionWithContext` - Connection with context support for all
    operations
    - `StatementWithContext` - Statement with context support for all
    operations
    - Extension interfaces with context support (`*WithContext` variants)
    
    **Adapter Implementation** (`go/adbc/context_adapters.go`):
    - Three adapter types that wrap non-context implementations
    - Nil handling (a nil input returns nil); double-wrap detection is not
    possible because the legacy and context-aware `Close` signatures conflict
    - Context pass-through for methods that already accept context
    - Clear documentation about context limitations for legacy methods
    
    **Tests** (`go/adbc/context_adapters_test.go`):
    - Comprehensive test coverage for all adapters
    - Tests for error handling, nil handling, and context pass-through
    - All tests passing with race detector
    
    **Documentation & Examples**:
    - Updated package documentation explaining Context interfaces
    - Added `IngestStreamContext` helper function
    - Example code demonstrating adapter usage
    
    ## Migration Path
    
    **For Applications:**
    ```go
    // Wrap existing Database to add context support
    db, err := driver.NewDatabase(opts)
    if err != nil {
        // handle error
    }
    dbCtx := adbc.AsDatabaseContext(db)
    
    ctx := context.Background()
    conn, err := dbCtx.Open(ctx)
    ```
    
    **For Driver Implementers:**
    - Existing drivers work unchanged via adapters
    - New drivers should implement Context interfaces directly
    - Gradual migration supported - no breaking changes
    
    ## Notes
    
    This is the first step toward uniform OpenTelemetry support. Future work
    includes:
    - Updating individual drivers to implement Context interfaces natively
    - Adding context-aware tests to validation suite
    - Enhanced tracing instrumentation using context propagation
    
    Closes #2772.
---
 .gitignore                       |   3 +
 go/adbc/adbc.go                  | 115 ++++++++++++++
 go/adbc/context_adapters.go      | 193 +++++++++++++++++++++++
 go/adbc/context_adapters_test.go | 320 +++++++++++++++++++++++++++++++++++++++
 go/adbc/example_context_test.go  |  74 +++++++++
 go/adbc/ext.go                   |  82 ++++++++--
 6 files changed, 772 insertions(+), 15 deletions(-)

diff --git a/.gitignore b/.gitignore
index 68ffd9040..dbc278a0f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -83,6 +83,9 @@ local/
 
 site/
 
+# Design plans
+docs/plans/
+
 # Python
 dist/
 .hypothesis/
diff --git a/go/adbc/adbc.go b/go/adbc/adbc.go
index e64decf26..8f0cac1b8 100644
--- a/go/adbc/adbc.go
+++ b/go/adbc/adbc.go
@@ -33,6 +33,18 @@
 // safely from multiple goroutines, but not necessarily concurrent
 // access. Specific implementations may allow concurrent access.
 //
+// # Context Support
+//
// Context-aware interfaces are available: DatabaseWithContext,
// ConnectionWithContext, and StatementWithContext. These interfaces
// require a context.Context for all methods to enable uniform OpenTelemetry
// instrumentation, cancellation, and deadline propagation.
//
// Applications can use the adapter functions (AsDatabaseContext,
// AsConnectionContext, AsStatementContext) to wrap non-context
// implementations, allowing gradual migration without breaking changes.
// New drivers should implement the WithContext interfaces directly.
+//
 // EXPERIMENTAL. Interface subject to change.
 package adbc
 
@@ -341,6 +353,19 @@ type Driver interface {
        NewDatabase(opts map[string]string) (Database, error)
 }
 
+// DriverWithContext is an extension interface to allow the creation of a 
database
+// by providing an existing [context.Context] to initialize OpenTelemetry 
tracing.
+// It is similar to [database/sql.Driver] taking a map of keys and values as 
options
+// to initialize a [Connection] to the database. Any common connection
+// state can live in the Driver itself, for example an in-memory database
+// can place ownership of the actual database in this driver.
+//
+// Any connection specific options should be set using SetOptions before
+// calling Open.
+type DriverWithContext interface {
+       NewDatabaseWithContext(ctx context.Context, opts map[string]string) 
(Database, error)
+}
+
 type Database interface {
        SetOptions(map[string]string) error
        Open(ctx context.Context) (Connection, error)
@@ -349,6 +374,22 @@ type Database interface {
        Close() error
 }
 
+// DatabaseWithContext is a Database that supports context.Context for all 
operations.
+//
+// This interface mirrors Database but requires context.Context for all methods
+// that may perform I/O or long-running operations. This enables uniform
+// OpenTelemetry instrumentation, cancellation, and deadline propagation.
+type DatabaseWithContext interface {
+       // SetOptions sets options for the database.
+       SetOptions(ctx context.Context, opts map[string]string) error
+
+       // Open opens a connection to the database.
+       Open(ctx context.Context) (ConnectionWithContext, error)
+
+       // Close closes the database and releases associated resources.
+       Close(ctx context.Context) error
+}
+
 type InfoCode uint32
 
 const (
@@ -581,6 +622,30 @@ type Connection interface {
        ReadPartition(ctx context.Context, serializedPartition []byte) 
(array.RecordReader, error)
 }
 
+// ConnectionWithContext is a Connection that supports context.Context for all 
operations.
+//
+// This interface mirrors Connection but requires context.Context for all 
methods
+// that may perform I/O or long-running operations. Methods that already 
accepted
+// context in Connection maintain their signatures here.
+type ConnectionWithContext interface {
+       // Metadata methods (these already accept context in Connection)
+       GetInfo(ctx context.Context, infoCodes []InfoCode) (array.RecordReader, 
error)
+       GetObjects(ctx context.Context, depth ObjectDepth, catalog, dbSchema, 
tableName, columnName *string, tableType []string) (array.RecordReader, error)
+       GetTableSchema(ctx context.Context, catalog, dbSchema *string, 
tableName string) (*arrow.Schema, error)
+       GetTableTypes(ctx context.Context) (array.RecordReader, error)
+
+       // Transaction methods (these already accept context in Connection)
+       Commit(ctx context.Context) error
+       Rollback(ctx context.Context) error
+
+       // Methods that now require context
+       NewStatement(ctx context.Context) (StatementWithContext, error)
+       Close(ctx context.Context) error
+
+       // Partition method (already accepts context in Connection)
+       ReadPartition(ctx context.Context, serializedPartition []byte) 
(array.RecordReader, error)
+}
+
 // PostInitOptions is an optional interface which can be implemented by
 // drivers which allow modifying and setting options after initializing
 // a connection or statement.
@@ -721,6 +786,30 @@ type Statement interface {
        ExecutePartitions(context.Context) (*arrow.Schema, Partitions, int64, 
error)
 }
 
+// StatementWithContext is a Statement that supports context.Context for all 
operations.
+//
+// This interface mirrors Statement but requires context.Context for all 
methods
+// that may perform I/O or long-running operations. Methods that already 
accepted
+// context in Statement maintain their signatures here.
+type StatementWithContext interface {
+       // Methods that now require context
+       Close(ctx context.Context) error
+       SetOption(ctx context.Context, key, val string) error
+       SetSqlQuery(ctx context.Context, query string) error
+       SetSubstraitPlan(ctx context.Context, plan []byte) error
+       GetParameterSchema(ctx context.Context) (*arrow.Schema, error)
+
+       // Execute methods (these already accept context in Statement)
+       ExecuteQuery(ctx context.Context) (array.RecordReader, int64, error)
+       ExecuteUpdate(ctx context.Context) (int64, error)
+       Prepare(ctx context.Context) error
+       ExecutePartitions(ctx context.Context) (*arrow.Schema, Partitions, 
int64, error)
+
+       // Bind methods (these already accept context in Statement)
+       Bind(ctx context.Context, values arrow.RecordBatch) error
+       BindStream(ctx context.Context, stream array.RecordReader) error
+}
+
 // ConnectionGetStatistics is a Connection that supports getting
 // statistics on data in the database.
 //
@@ -818,3 +907,29 @@ type GetSetOptions interface {
        GetOptionInt(key string) (int64, error)
        GetOptionDouble(key string) (float64, error)
 }
+
// GetSetOptionsWithContext is the context-aware counterpart of GetSetOptions:
// every getter and setter additionally takes a context.Context.
//
// Unsupported options are reported as follows:
//   - GetOption functions return an error with StatusNotFound.
//   - SetOption functions return an error with StatusNotImplemented.
type GetSetOptionsWithContext interface {
	GetOption(ctx context.Context, key string) (string, error)
	GetOptionBytes(ctx context.Context, key string) ([]byte, error)
	GetOptionInt(ctx context.Context, key string) (int64, error)
	GetOptionDouble(ctx context.Context, key string) (float64, error)

	SetOption(ctx context.Context, key, value string) error
	SetOptionBytes(ctx context.Context, key string, value []byte) error
	SetOptionInt(ctx context.Context, key string, value int64) error
	SetOptionDouble(ctx context.Context, key string, value float64) error
}
+
+// ConnectionGetStatisticsWithContext is a ConnectionGetStatistics that 
supports context.Context.
+type ConnectionGetStatisticsWithContext interface {
+       GetStatistics(ctx context.Context, catalog, dbSchema, tableName 
*string, approximate bool) (array.RecordReader, error)
+       GetStatisticNames(ctx context.Context) (array.RecordReader, error)
+}
+
+// StatementExecuteSchemaWithContext is a StatementExecuteSchema that supports 
context.Context.
+type StatementExecuteSchemaWithContext interface {
+       ExecuteSchema(ctx context.Context) (*arrow.Schema, error)
+}
diff --git a/go/adbc/context_adapters.go b/go/adbc/context_adapters.go
new file mode 100644
index 000000000..b0b55ddec
--- /dev/null
+++ b/go/adbc/context_adapters.go
@@ -0,0 +1,193 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package adbc
+
+import (
+       "context"
+
+       "github.com/apache/arrow-go/v18/arrow"
+       "github.com/apache/arrow-go/v18/arrow/array"
+)
+
+// databaseContextAdapter wraps a Database to implement DatabaseContext.
+type databaseContextAdapter struct {
+       db Database
+}
+
+// AsDatabaseContext wraps a Database to implement DatabaseContext.
+// This adapter allows using a non-context Database implementation with
+// context-aware code. The context parameter is used for methods that
+// already accept context (like Open), but is effectively ignored for
+// methods that don't (like Close, SetOptions) since the underlying
+// implementation cannot respond to cancellation or deadlines.
+func AsDatabaseContext(db Database) DatabaseWithContext {
+       if db == nil {
+               return nil
+       }
+       return &databaseContextAdapter{db: db}
+}
+
+func (d *databaseContextAdapter) SetOptions(ctx context.Context, opts 
map[string]string) error {
+       // Context cannot be propagated to SetOptions since it doesn't accept 
context
+       return d.db.SetOptions(opts)
+}
+
+func (d *databaseContextAdapter) Open(ctx context.Context) 
(ConnectionWithContext, error) {
+       // Pass context through since Open already accepts it
+       conn, err := d.db.Open(ctx)
+       if err != nil {
+               return nil, err
+       }
+       // Wrap the returned Connection as ConnectionContext
+       return AsConnectionContext(conn), nil
+}
+
+func (d *databaseContextAdapter) Close(ctx context.Context) error {
+       // Context cannot be propagated to Close since it doesn't accept context
+       return d.db.Close()
+}
+
+// connectionContextAdapter wraps a Connection to implement ConnectionContext.
+type connectionContextAdapter struct {
+       conn Connection
+}
+
+// AsConnectionContext wraps a Connection to implement ConnectionContext.
+// This adapter allows using a non-context Connection implementation with
+// context-aware code. The context parameter is passed through for methods
+// that already accept context, but is ignored for methods that don't.
+func AsConnectionContext(conn Connection) ConnectionWithContext {
+       if conn == nil {
+               return nil
+       }
+       // Note: We cannot check if conn already implements ConnectionContext
+       // because Connection and ConnectionContext have conflicting Close 
method signatures.
+       // Connection.Close() vs ConnectionContext.Close(ctx)
+       return &connectionContextAdapter{conn: conn}
+}
+
+func (c *connectionContextAdapter) GetInfo(ctx context.Context, infoCodes 
[]InfoCode) (array.RecordReader, error) {
+       return c.conn.GetInfo(ctx, infoCodes)
+}
+
+func (c *connectionContextAdapter) GetObjects(ctx context.Context, depth 
ObjectDepth, catalog, dbSchema, tableName, columnName *string, tableType 
[]string) (array.RecordReader, error) {
+       return c.conn.GetObjects(ctx, depth, catalog, dbSchema, tableName, 
columnName, tableType)
+}
+
+func (c *connectionContextAdapter) GetTableSchema(ctx context.Context, 
catalog, dbSchema *string, tableName string) (*arrow.Schema, error) {
+       return c.conn.GetTableSchema(ctx, catalog, dbSchema, tableName)
+}
+
+func (c *connectionContextAdapter) GetTableTypes(ctx context.Context) 
(array.RecordReader, error) {
+       return c.conn.GetTableTypes(ctx)
+}
+
+func (c *connectionContextAdapter) Commit(ctx context.Context) error {
+       return c.conn.Commit(ctx)
+}
+
+func (c *connectionContextAdapter) Rollback(ctx context.Context) error {
+       return c.conn.Rollback(ctx)
+}
+
+func (c *connectionContextAdapter) NewStatement(ctx context.Context) 
(StatementWithContext, error) {
+       // Context cannot be propagated to NewStatement since it doesn't accept 
context
+       stmt, err := c.conn.NewStatement()
+       if err != nil {
+               return nil, err
+       }
+       // Wrap the returned Statement as StatementContext
+       return AsStatementContext(stmt), nil
+}
+
+func (c *connectionContextAdapter) Close(ctx context.Context) error {
+       // Context cannot be propagated to Close since it doesn't accept context
+       return c.conn.Close()
+}
+
+func (c *connectionContextAdapter) ReadPartition(ctx context.Context, 
serializedPartition []byte) (array.RecordReader, error) {
+       return c.conn.ReadPartition(ctx, serializedPartition)
+}
+
+// statementContextAdapter wraps a Statement to implement StatementContext.
+type statementContextAdapter struct {
+       stmt Statement
+}
+
+// AsStatementContext wraps a Statement to implement StatementContext.
+// This adapter allows using a non-context Statement implementation with
+// context-aware code. The context parameter is passed through for methods
+// that already accept context, but is ignored for methods that don't.
+func AsStatementContext(stmt Statement) StatementWithContext {
+       if stmt == nil {
+               return nil
+       }
+       // Note: We cannot check if stmt already implements StatementContext
+       // because Statement and StatementContext have conflicting Close method 
signatures.
+       // Statement.Close() vs StatementContext.Close(ctx)
+       return &statementContextAdapter{stmt: stmt}
+}
+
+func (s *statementContextAdapter) Close(ctx context.Context) error {
+       // Context cannot be propagated to Close since it doesn't accept context
+       return s.stmt.Close()
+}
+
+func (s *statementContextAdapter) SetOption(ctx context.Context, key, val 
string) error {
+       // Context cannot be propagated to SetOption since it doesn't accept 
context
+       return s.stmt.SetOption(key, val)
+}
+
+func (s *statementContextAdapter) SetSqlQuery(ctx context.Context, query 
string) error {
+       // Context cannot be propagated to SetSqlQuery since it doesn't accept 
context
+       return s.stmt.SetSqlQuery(query)
+}
+
+func (s *statementContextAdapter) SetSubstraitPlan(ctx context.Context, plan 
[]byte) error {
+       // Context cannot be propagated to SetSubstraitPlan since it doesn't 
accept context
+       return s.stmt.SetSubstraitPlan(plan)
+}
+
+func (s *statementContextAdapter) GetParameterSchema(ctx context.Context) 
(*arrow.Schema, error) {
+       // Context cannot be propagated to GetParameterSchema since it doesn't 
accept context
+       return s.stmt.GetParameterSchema()
+}
+
+func (s *statementContextAdapter) ExecuteQuery(ctx context.Context) 
(array.RecordReader, int64, error) {
+       return s.stmt.ExecuteQuery(ctx)
+}
+
+func (s *statementContextAdapter) ExecuteUpdate(ctx context.Context) (int64, 
error) {
+       return s.stmt.ExecuteUpdate(ctx)
+}
+
+func (s *statementContextAdapter) Prepare(ctx context.Context) error {
+       return s.stmt.Prepare(ctx)
+}
+
+func (s *statementContextAdapter) ExecutePartitions(ctx context.Context) 
(*arrow.Schema, Partitions, int64, error) {
+       return s.stmt.ExecutePartitions(ctx)
+}
+
+func (s *statementContextAdapter) Bind(ctx context.Context, values 
arrow.RecordBatch) error {
+       return s.stmt.Bind(ctx, values)
+}
+
+func (s *statementContextAdapter) BindStream(ctx context.Context, stream 
array.RecordReader) error {
+       return s.stmt.BindStream(ctx, stream)
+}
diff --git a/go/adbc/context_adapters_test.go b/go/adbc/context_adapters_test.go
new file mode 100644
index 000000000..61362c5e8
--- /dev/null
+++ b/go/adbc/context_adapters_test.go
@@ -0,0 +1,320 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package adbc_test
+
+import (
+       "context"
+       "errors"
+       "testing"
+
+       "github.com/apache/arrow-adbc/go/adbc"
+       "github.com/apache/arrow-go/v18/arrow"
+       "github.com/apache/arrow-go/v18/arrow/array"
+       "github.com/stretchr/testify/assert"
+       "github.com/stretchr/testify/mock"
+)
+
+// mockConnection is a mock implementation of adbc.Connection for testing
+type mockConnection struct {
+       mock.Mock
+}
+
+func (m *mockConnection) GetInfo(ctx context.Context, infoCodes 
[]adbc.InfoCode) (array.RecordReader, error) {
+       args := m.Called(ctx, infoCodes)
+       if args.Get(0) == nil {
+               return nil, args.Error(1)
+       }
+       return args.Get(0).(array.RecordReader), args.Error(1)
+}
+
+func (m *mockConnection) GetObjects(ctx context.Context, depth 
adbc.ObjectDepth, catalog, dbSchema, tableName, columnName *string, tableType 
[]string) (array.RecordReader, error) {
+       args := m.Called(ctx, depth, catalog, dbSchema, tableName, columnName, 
tableType)
+       if args.Get(0) == nil {
+               return nil, args.Error(1)
+       }
+       return args.Get(0).(array.RecordReader), args.Error(1)
+}
+
+func (m *mockConnection) GetTableSchema(ctx context.Context, catalog, dbSchema 
*string, tableName string) (*arrow.Schema, error) {
+       args := m.Called(ctx, catalog, dbSchema, tableName)
+       if args.Get(0) == nil {
+               return nil, args.Error(1)
+       }
+       return args.Get(0).(*arrow.Schema), args.Error(1)
+}
+
+func (m *mockConnection) GetTableTypes(ctx context.Context) 
(array.RecordReader, error) {
+       args := m.Called(ctx)
+       if args.Get(0) == nil {
+               return nil, args.Error(1)
+       }
+       return args.Get(0).(array.RecordReader), args.Error(1)
+}
+
+func (m *mockConnection) Commit(ctx context.Context) error {
+       args := m.Called(ctx)
+       return args.Error(0)
+}
+
+func (m *mockConnection) Rollback(ctx context.Context) error {
+       args := m.Called(ctx)
+       return args.Error(0)
+}
+
+func (m *mockConnection) NewStatement() (adbc.Statement, error) {
+       args := m.Called()
+       if args.Get(0) == nil {
+               return nil, args.Error(1)
+       }
+       return args.Get(0).(adbc.Statement), args.Error(1)
+}
+
+func (m *mockConnection) Close() error {
+       args := m.Called()
+       return args.Error(0)
+}
+
+func (m *mockConnection) ReadPartition(ctx context.Context, 
serializedPartition []byte) (array.RecordReader, error) {
+       args := m.Called(ctx, serializedPartition)
+       if args.Get(0) == nil {
+               return nil, args.Error(1)
+       }
+       return args.Get(0).(array.RecordReader), args.Error(1)
+}
+
+// mockStatement is a mock implementation of adbc.Statement for testing
+type mockStatement struct {
+       mock.Mock
+}
+
+func (m *mockStatement) Close() error {
+       args := m.Called()
+       return args.Error(0)
+}
+
+func (m *mockStatement) SetOption(key, val string) error {
+       args := m.Called(key, val)
+       return args.Error(0)
+}
+
+func (m *mockStatement) SetSqlQuery(query string) error {
+       args := m.Called(query)
+       return args.Error(0)
+}
+
+func (m *mockStatement) ExecuteQuery(ctx context.Context) (array.RecordReader, 
int64, error) {
+       args := m.Called(ctx)
+       if args.Get(0) == nil {
+               return nil, args.Get(1).(int64), args.Error(2)
+       }
+       return args.Get(0).(array.RecordReader), args.Get(1).(int64), 
args.Error(2)
+}
+
+func (m *mockStatement) ExecuteUpdate(ctx context.Context) (int64, error) {
+       args := m.Called(ctx)
+       return args.Get(0).(int64), args.Error(1)
+}
+
+func (m *mockStatement) Prepare(ctx context.Context) error {
+       args := m.Called(ctx)
+       return args.Error(0)
+}
+
+func (m *mockStatement) SetSubstraitPlan(plan []byte) error {
+       args := m.Called(plan)
+       return args.Error(0)
+}
+
+func (m *mockStatement) Bind(ctx context.Context, values arrow.RecordBatch) 
error {
+       args := m.Called(ctx, values)
+       return args.Error(0)
+}
+
+func (m *mockStatement) BindStream(ctx context.Context, stream 
array.RecordReader) error {
+       args := m.Called(ctx, stream)
+       return args.Error(0)
+}
+
+func (m *mockStatement) GetParameterSchema() (*arrow.Schema, error) {
+       args := m.Called()
+       if args.Get(0) == nil {
+               return nil, args.Error(1)
+       }
+       return args.Get(0).(*arrow.Schema), args.Error(1)
+}
+
+func (m *mockStatement) ExecutePartitions(ctx context.Context) (*arrow.Schema, 
adbc.Partitions, int64, error) {
+       args := m.Called(ctx)
+       if args.Get(0) == nil {
+               return nil, args.Get(1).(adbc.Partitions), args.Get(2).(int64), 
args.Error(3)
+       }
+       return args.Get(0).(*arrow.Schema), args.Get(1).(adbc.Partitions), 
args.Get(2).(int64), args.Error(3)
+}
+
+// mockDatabase is a mock implementation of adbc.Database for testing
+type mockDatabase struct {
+       mock.Mock
+}
+
+func (m *mockDatabase) SetOptions(opts map[string]string) error {
+       args := m.Called(opts)
+       return args.Error(0)
+}
+
+func (m *mockDatabase) Open(ctx context.Context) (adbc.Connection, error) {
+       args := m.Called(ctx)
+       if args.Get(0) == nil {
+               return nil, args.Error(1)
+       }
+       return args.Get(0).(adbc.Connection), args.Error(1)
+}
+
+func (m *mockDatabase) Close() error {
+       args := m.Called()
+       return args.Error(0)
+}
+
+func TestAsDatabaseContext(t *testing.T) {
+       db := &mockDatabase{}
+       dbCtx := adbc.AsDatabaseContext(db)
+
+       assert.NotNil(t, dbCtx)
+
+       // Test SetOptions with context
+       opts := map[string]string{"key": "value"}
+       db.On("SetOptions", opts).Return(nil)
+
+       err := dbCtx.SetOptions(context.Background(), opts)
+       assert.NoError(t, err)
+       db.AssertExpectations(t)
+
+       // Test Close with context
+       db.On("Close").Return(nil)
+       err = dbCtx.Close(context.Background())
+       assert.NoError(t, err)
+       db.AssertExpectations(t)
+
+       // Test Open with context
+       mockConn := &mockConnection{}
+       db.On("Open", mock.Anything).Return(mockConn, nil)
+
+       conn, err := dbCtx.Open(context.Background())
+       assert.NoError(t, err)
+       assert.NotNil(t, conn)
+       db.AssertExpectations(t)
+}
+
+func TestAsDatabaseContext_ErrorHandling(t *testing.T) {
+       db := &mockDatabase{}
+       dbCtx := adbc.AsDatabaseContext(db)
+
+       expectedErr := errors.New("test error")
+       db.On("Close").Return(expectedErr)
+
+       err := dbCtx.Close(context.Background())
+       assert.Equal(t, expectedErr, err)
+       db.AssertExpectations(t)
+}
+
+func TestAsConnectionContext(t *testing.T) {
+       conn := &mockConnection{}
+       connCtx := adbc.AsConnectionContext(conn)
+
+       assert.NotNil(t, connCtx)
+
+       // Test NewStatement with context
+       mockStmt := &mockStatement{}
+       conn.On("NewStatement").Return(mockStmt, nil)
+
+       stmt, err := connCtx.NewStatement(context.Background())
+       assert.NoError(t, err)
+       assert.NotNil(t, stmt)
+       conn.AssertExpectations(t)
+
+       // Test Close with context
+       conn.On("Close").Return(nil)
+       err = connCtx.Close(context.Background())
+       assert.NoError(t, err)
+       conn.AssertExpectations(t)
+}
+
+func TestAsConnectionContext_PassThrough(t *testing.T) {
+       conn := &mockConnection{}
+       connCtx := adbc.AsConnectionContext(conn)
+
+       // Test that context is passed through for methods that already accept 
it
+       ctx := context.Background()
+       conn.On("Commit", ctx).Return(nil)
+
+       err := connCtx.Commit(ctx)
+       assert.NoError(t, err)
+       conn.AssertExpectations(t)
+}
+
+func TestAsStatementContext(t *testing.T) {
+       stmt := &mockStatement{}
+       stmtCtx := adbc.AsStatementContext(stmt)
+
+       assert.NotNil(t, stmtCtx)
+
+       // Test SetOption with context
+       stmt.On("SetOption", "key", "value").Return(nil)
+       err := stmtCtx.SetOption(context.Background(), "key", "value")
+       assert.NoError(t, err)
+       stmt.AssertExpectations(t)
+
+       // Test SetSqlQuery with context
+       stmt.On("SetSqlQuery", "SELECT 1").Return(nil)
+       err = stmtCtx.SetSqlQuery(context.Background(), "SELECT 1")
+       assert.NoError(t, err)
+       stmt.AssertExpectations(t)
+
+       // Test Close with context
+       stmt.On("Close").Return(nil)
+       err = stmtCtx.Close(context.Background())
+       assert.NoError(t, err)
+       stmt.AssertExpectations(t)
+}
+
+func TestAsStatementContext_ContextPassThrough(t *testing.T) {
+       stmt := &mockStatement{}
+       stmtCtx := adbc.AsStatementContext(stmt)
+
+       // Test that context is passed through for methods that already accept 
it
+       ctx := context.Background()
+       stmt.On("Prepare", ctx).Return(nil)
+
+       err := stmtCtx.Prepare(ctx)
+       assert.NoError(t, err)
+       stmt.AssertExpectations(t)
+}
+
+func TestAsDatabaseContext_Nil(t *testing.T) {
+       dbCtx := adbc.AsDatabaseContext(nil)
+       assert.Nil(t, dbCtx)
+}
+
+func TestAsConnectionContext_Nil(t *testing.T) {
+       connCtx := adbc.AsConnectionContext(nil)
+       assert.Nil(t, connCtx)
+}
+
+func TestAsStatementContext_Nil(t *testing.T) {
+       stmtCtx := adbc.AsStatementContext(nil)
+       assert.Nil(t, stmtCtx)
+}
diff --git a/go/adbc/example_context_test.go b/go/adbc/example_context_test.go
new file mode 100644
index 000000000..831a79a32
--- /dev/null
+++ b/go/adbc/example_context_test.go
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package adbc_test
+
+import (
+       "context"
+       "fmt"
+       "time"
+
+       "github.com/apache/arrow-adbc/go/adbc"
+)
+
+// ExampleAsDatabaseContext demonstrates wrapping a non-context Database
+// to use with context-aware code.
+func ExampleAsDatabaseContext() {
+       var db adbc.Database // obtained from driver
+
+       // Wrap the database to add context support
+       dbCtx := adbc.AsDatabaseContext(db)
+
+       // Now you can use context with all operations
+       ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+       defer cancel()
+
+       // Open connection with context
+       conn, err := dbCtx.Open(ctx)
+       if err != nil {
+               fmt.Printf("Error: %v\n", err)
+               return
+       }
+       defer func() { _ = conn.Close(ctx) }()
+
+       fmt.Println("Connected successfully")
+}
+
+// ExampleConnectionContext demonstrates using context for cancellation.
+func ExampleConnectionWithContext() {
+       var connCtx adbc.ConnectionWithContext // obtained from 
DatabaseContext.Open
+
+       // Create a context with cancellation
+       ctx, cancel := context.WithCancel(context.Background())
+
+       // Cancel after some condition
+       go func() {
+               // Simulate cancellation after some event
+               time.Sleep(100 * time.Millisecond)
+               cancel()
+       }()
+
+       // This operation will be cancelled
+       stmt, err := connCtx.NewStatement(ctx)
+       if err != nil {
+               fmt.Printf("Error: %v\n", err)
+               return
+       }
+       defer func() { _ = stmt.Close(ctx) }()
+
+       fmt.Println("Statement created")
+}
diff --git a/go/adbc/ext.go b/go/adbc/ext.go
index 7fcbadcdb..dca009709 100644
--- a/go/adbc/ext.go
+++ b/go/adbc/ext.go
@@ -45,21 +45,6 @@ type OTelTracingInit interface {
        InitTracing(ctx context.Context, driverName string, driverVersion 
string) error
 }
 
-// DriverWithContext is an extension interface to allow the creation of a 
database
-// by providing an existing [context.Context] to initialize OpenTelemetry 
tracing.
-// It is similar to [database/sql.Driver] taking a map of keys and values as 
options
-// to initialize a [Connection] to the database. Any common connection
-// state can live in the Driver itself, for example an in-memory database
-// can place ownership of the actual database in this driver.
-//
-// Any connection specific options should be set using SetOptions before
-// calling Open.
-//
-// EXPERIMENTAL. Not formally part of the ADBC APIs.
-type DriverWithContext interface {
-       NewDatabaseWithContext(ctx context.Context, opts map[string]string) 
(Database, error)
-}
-
 // OTelTracing is an interface that supports instrumentation of 
[OpenTelementry tracing].
 //
 // EXPERIMENTAL. Not formally part of the ADBC APIs.
@@ -162,6 +147,73 @@ func IngestStream(ctx context.Context, cnxn Connection, 
reader array.RecordReade
        return count, nil
 }
 
+// IngestStreamContext is a helper for executing a bulk ingestion with context 
support.
+// This is a wrapper around the five-step boilerplate of NewStatement, 
SetOption, Bind,
+// Execute, and Close, with context propagation throughout.
+//
+// This version uses ConnectionContext and StatementContext for uniform 
context propagation.
+// For backward compatibility with non-context connections, use IngestStream.
+//
+// This is not part of the ADBC API specification.
+//
+// Since ADBC API revision 1.2.0 (Experimental).
+func IngestStreamContext(ctx context.Context, cnxn ConnectionWithContext, 
reader array.RecordReader, targetTable, ingestMode string, opt 
IngestStreamOptions) (int64, error) {
+       // Create a new statement
+       stmt, err := cnxn.NewStatement(ctx)
+       if err != nil {
+               return -1, fmt.Errorf("error during ingestion: NewStatement: 
%w", err)
+       }
+       defer func() {
+               err = errors.Join(err, stmt.Close(ctx))
+       }()
+
+       // Bind the record batch stream
+       if err = stmt.BindStream(ctx, reader); err != nil {
+               return -1, fmt.Errorf("error during ingestion: BindStream: %w", 
err)
+       }
+
+       // Set required options
+       if err = stmt.SetOption(ctx, OptionKeyIngestTargetTable, targetTable); 
err != nil {
+               return -1, fmt.Errorf("error during ingestion: 
SetOption(target_table=%s): %w", targetTable, err)
+       }
+       if err = stmt.SetOption(ctx, OptionKeyIngestMode, ingestMode); err != 
nil {
+               return -1, fmt.Errorf("error during ingestion: 
SetOption(mode=%s): %w", ingestMode, err)
+       }
+
+       // Set other options if provided
+       if opt.Catalog != "" {
+               if err = stmt.SetOption(ctx, OptionValueIngestTargetCatalog, 
opt.Catalog); err != nil {
+                       return -1, fmt.Errorf("error during ingestion: 
target_catalog=%s: %w", opt.Catalog, err)
+               }
+       }
+       if opt.DBSchema != "" {
+               if err = stmt.SetOption(ctx, OptionValueIngestTargetDBSchema, 
opt.DBSchema); err != nil {
+                       return -1, fmt.Errorf("error during ingestion: 
target_db_schema=%s: %w", opt.DBSchema, err)
+               }
+       }
+       if opt.Temporary {
+               if err = stmt.SetOption(ctx, OptionValueIngestTemporary, 
OptionValueEnabled); err != nil {
+                       return -1, fmt.Errorf("error during ingestion: 
temporary=true: %w", err)
+               }
+       }
+
+       // Set driver specific options
+       for k, v := range opt.Extra {
+               if err = stmt.SetOption(ctx, k, v); err != nil {
+                       return -1, fmt.Errorf("error during ingestion: 
SetOption(%s=%s): %w", k, v, err)
+               }
+       }
+
+       // Execute the update
+       var count int64
+       count, err = stmt.ExecuteUpdate(ctx)
+       if err != nil {
+               return -1, fmt.Errorf("error during ingestion: ExecuteUpdate: 
%w", err)
+       }
+
+       return count, nil
+}
+
 // DriverInfo library info map keys for auxiliary information
 //
 // NOTE: If in the future any of these InfoCodes are promoted to top-level 
fields


Reply via email to