This is an automated email from the ASF dual-hosted git repository.

zeroshade pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-go.git


The following commit(s) were added to refs/heads/main by this push:
     new 2d0962ed fix(flight): make StreamChunksFromReader ctx aware and 
cancellation-safe (#615)
2d0962ed is described below

commit 2d0962ed55074f050193b82aaf80f8b5995a2ffc
Author: Arnold Wakim <[email protected]>
AuthorDate: Mon Dec 29 16:44:46 2025 +0100

    fix(flight): make StreamChunksFromReader ctx aware and cancellation-safe 
(#615)
    
    ### Rationale for this change
    
    `StreamChunksFromReader` previously did not observe context
    cancellation. As a result, if a client disconnected early, the reader
    could keep producing data indefinitely, potentially blocking forever on
    channel sends, leaking `RecordBatch` objects, leaking the reader, and
    consuming unbounded memory and CPU. Observing this behavior in practice
    is what prompted this PR.
    
    This fix ensures that streaming stops promptly when the client
    disconnects, and that the reader and any batch not yet sent are released.
    
    ### What changes are included in this PR?
    
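    - `StreamChunksFromReader` now accepts a `context.Context` and stops
    sending once that context is canceled, releasing the reader and any
    pending batch.
    - `DoGet` now selects on `stream.Context().Done()` while draining the
    chunk channel, so it returns the stream's context error as soon as the
    client disconnects instead of continuing to write chunks.
    - Callers (`BaseServer.DoGetSqlInfo` and the SQLite example server's
    `DoGetTables` and `doGetQuery`) now pass their context through.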
    
    ### Are these changes tested?
    
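    - Yes. `arrow/flight/flight_test.go` adds tests covering normal
    completion (`TestStreamChunksFromReader_OK`), reader release on context
    cancellation (`TestStreamChunksFromReader_HandlesCancellation`), and
    batch accounting after cancellation
    (`TestStreamChunksFromReader_CancellationReleasesBatches`). These tests
    were a bit tricky to write, similar to those for #437.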
    
    ### Are there any user-facing changes?
    
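    - `StreamChunksFromReader` now accepts a `context.Context` as its first
    parameter (`go flight.StreamChunksFromReader(ctx, rdr, ch)`). This is a
    breaking change: existing callers must be updated to pass a context.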
    
    ---------
    
    Co-authored-by: awakim <[email protected]>
---
 arrow/flight/flight_test.go                     | 250 ++++++++++++++++++++++++
 arrow/flight/flightsql/example/sqlite_server.go |   4 +-
 arrow/flight/flightsql/server.go                |  31 +--
 arrow/flight/record_batch_reader.go             |  23 ++-
 4 files changed, 289 insertions(+), 19 deletions(-)

diff --git a/arrow/flight/flight_test.go b/arrow/flight/flight_test.go
index 98d1734c..8d75aac2 100644
--- a/arrow/flight/flight_test.go
+++ b/arrow/flight/flight_test.go
@@ -21,6 +21,8 @@ import (
        "errors"
        "fmt"
        "io"
+       "sync"
+       "sync/atomic"
        "testing"
 
        "github.com/apache/arrow-go/v18/arrow"
@@ -484,3 +486,251 @@ type flightStreamWriter struct{}
 func (f *flightStreamWriter) Send(data *flight.FlightData) error { return nil }
 
 var _ flight.DataStreamWriter = (*flightStreamWriter)(nil)
+
+// callbackRecordReader wraps a record reader and invokes a callback on each 
Next() call.
+// It tracks whether batches are properly released and the reader itself is 
released.
+type callbackRecordReader struct {
+       mem            memory.Allocator
+       schema         *arrow.Schema
+       numBatches     int
+       currentBatch   atomic.Int32
+       onNext         func(batchIndex int) // callback invoked before 
returning from Next()
+       released       atomic.Bool
+       batchesCreated atomic.Int32
+       totalRetains   atomic.Int32
+       totalReleases  atomic.Int32
+       createdBatches []arrow.RecordBatch // track all created batches for 
cleanup
+       mu             sync.Mutex
+}
+
+func newCallbackRecordReader(mem memory.Allocator, schema *arrow.Schema, 
numBatches int, onNext func(int)) *callbackRecordReader {
+       return &callbackRecordReader{
+               mem:        mem,
+               schema:     schema,
+               numBatches: numBatches,
+               onNext:     onNext,
+       }
+}
+
+func (r *callbackRecordReader) Schema() *arrow.Schema {
+       return r.schema
+}
+
+func (r *callbackRecordReader) Next() bool {
+       current := r.currentBatch.Load()
+       if int(current) >= r.numBatches {
+               return false
+       }
+       r.currentBatch.Add(1)
+
+       if r.onNext != nil {
+               r.onNext(int(current))
+       }
+
+       return true
+}
+
+func (r *callbackRecordReader) RecordBatch() arrow.RecordBatch {
+       bldr := array.NewInt64Builder(r.mem)
+       defer bldr.Release()
+
+       currentBatch := r.currentBatch.Load()
+       bldr.AppendValues([]int64{int64(currentBatch)}, nil)
+       arr := bldr.NewArray()
+
+       rec := array.NewRecordBatch(r.schema, []arrow.Array{arr}, 1)
+       arr.Release()
+
+       tracked := &trackedRecordBatch{
+               RecordBatch: rec,
+               onRetain: func() {
+                       r.totalRetains.Add(1)
+               },
+               onRelease: func() {
+                       r.totalReleases.Add(1)
+               },
+       }
+
+       r.mu.Lock()
+       r.createdBatches = append(r.createdBatches, tracked)
+       r.mu.Unlock()
+
+       r.batchesCreated.Add(1)
+       return tracked
+}
+
+func (r *callbackRecordReader) ReleaseAll() {
+       r.mu.Lock()
+       defer r.mu.Unlock()
+       for _, batch := range r.createdBatches {
+               batch.Release()
+       }
+       r.createdBatches = nil
+}
+
+func (r *callbackRecordReader) Retain() {}
+
+func (r *callbackRecordReader) Release() {
+       r.released.Store(true)
+}
+
+func (r *callbackRecordReader) Record() arrow.RecordBatch {
+       return r.RecordBatch()
+}
+
+func (r *callbackRecordReader) Err() error {
+       return nil
+}
+
+// trackedRecordBatch wraps a RecordBatch to track Retain/Release calls.
+type trackedRecordBatch struct {
+       arrow.RecordBatch
+       onRetain  func()
+       onRelease func()
+}
+
+func (t *trackedRecordBatch) Retain() {
+       if t.onRetain != nil {
+               t.onRetain()
+       }
+       t.RecordBatch.Retain()
+}
+
+func (t *trackedRecordBatch) Release() {
+       if t.onRelease != nil {
+               t.onRelease()
+       }
+       t.RecordBatch.Release()
+}
+
+func TestStreamChunksFromReader_OK(t *testing.T) {
+       mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+       defer mem.AssertSize(t, 0)
+
+       schema := arrow.NewSchema([]arrow.Field{{Name: "value", Type: 
arrow.PrimitiveTypes.Int64}}, nil)
+
+       rdr := newCallbackRecordReader(mem, schema, 5, nil)
+       defer rdr.ReleaseAll()
+
+       ch := make(chan flight.StreamChunk, 5)
+
+       ctx := context.Background()
+
+       go flight.StreamChunksFromReader(ctx, rdr, ch)
+
+       var chunksReceived int
+       for chunk := range ch {
+               if chunk.Err != nil {
+                       t.Errorf("unexpected error chunk: %v", chunk.Err)
+                       continue
+               }
+               if chunk.Data != nil {
+                       chunksReceived++
+                       chunk.Data.Release()
+               }
+       }
+
+       require.Equal(t, 5, chunksReceived, "should receive all 5 batches")
+       require.True(t, rdr.released.Load(), "reader should be released")
+
+}
+
+// TestStreamChunksFromReader_HandlesCancellation verifies that context 
cancellation
+// causes StreamChunksFromReader to exit cleanly and release the reader.
+func TestStreamChunksFromReader_HandlesCancellation(t *testing.T) {
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+
+       mem := memory.DefaultAllocator
+       schema := arrow.NewSchema([]arrow.Field{{Name: "value", Type: 
arrow.PrimitiveTypes.Int64}}, nil)
+
+       rdr := newCallbackRecordReader(mem, schema, 10, nil)
+       defer rdr.ReleaseAll()
+       ch := make(chan flight.StreamChunk) // unbuffered channel
+
+       go flight.StreamChunksFromReader(ctx, rdr, ch)
+
+       chunksReceived := 0
+       for chunk := range ch {
+               if chunk.Data != nil {
+                       chunksReceived++
+                       chunk.Data.Release()
+               }
+
+               // Cancel context after 2 batches (simulating server detecting 
client disconnect)
+               if chunksReceived == 2 {
+                       cancel()
+               }
+       }
+
+       // After canceling context, StreamChunksFromReader exits and closes the 
channel.
+       // The for-range loop above exits when the channel closes.
+       // By the time we reach here, the channel is closed, which means 
StreamChunksFromReader's
+       // defer stack has already executed, so the reader must be released.
+
+       require.True(t, rdr.released.Load(), "reader must be released when 
context is canceled")
+
+}
+
+// TestStreamChunksFromReader_CancellationReleasesBatches verifies that 
batches are
+// properly tracked and demonstrates memory leaks without cleanup, then proves 
cleanup fixes it.
+func TestStreamChunksFromReader_CancellationReleasesBatches(t *testing.T) {
+       mem := memory.NewCheckedAllocator(memory.DefaultAllocator)
+
+       schema := arrow.NewSchema([]arrow.Field{{Name: "value", Type: 
arrow.PrimitiveTypes.Int64}}, nil)
+
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+
+       // Create reader that will produce 10 batches, but we'll cancel after 3
+       reader := newCallbackRecordReader(mem, schema, 10, func(batchIndex int) 
{
+               if batchIndex == 2 {
+                       cancel()
+               }
+       })
+
+       ch := make(chan flight.StreamChunk, 5)
+
+       // Start streaming
+       go flight.StreamChunksFromReader(ctx, reader, ch)
+
+       // Consume chunks until channel closes
+       var chunksReceived int
+       for chunk := range ch {
+               if chunk.Err != nil {
+                       t.Errorf("unexpected error chunk: %v", chunk.Err)
+                       continue
+               }
+               if chunk.Data != nil {
+                       chunksReceived++
+                       chunk.Data.Release()
+               }
+       }
+
+       // Verify the reader was released
+       require.True(t, reader.released.Load(), "reader should be released")
+
+       // We should have received at most 3-4 chunks (depending on timing)
+       // The important part is we didn't receive all 10
+       require.LessOrEqual(t, chunksReceived, 4, "should not receive all 10 
chunks, got %d", chunksReceived)
+       require.Greater(t, chunksReceived, 0, "should receive at least 1 chunk")
+
+       // Check that Retain and Release don't balance - proving there's a leak 
without manual cleanup
+       retains := reader.totalRetains.Load()
+       releases := reader.totalReleases.Load()
+       batchesCreated := reader.batchesCreated.Load()
+
+       // Each batch starts with refcount=1, then StreamChunksFromReader calls 
Retain() (refcount=2)
+       // For sent batches: we call Release() (refcount=1), batch still has 
initial ref
+       // For unsent batches due to cancellation: they keep refcount=1 from 
creation
+       // So we expect: releases < retains + batchesCreated
+       require.Less(t, releases, retains+batchesCreated,
+               "without cleanup, releases should be less than retains+created: 
retains=%d, releases=%d, created=%d",
+               retains, releases, batchesCreated)
+
+       // Now manually release all created batches to show proper cleanup 
fixes the leak
+       reader.ReleaseAll()
+
+       // After cleanup, memory should be freed
+       mem.AssertSize(t, 0)
+}
diff --git a/arrow/flight/flightsql/example/sqlite_server.go 
b/arrow/flight/flightsql/example/sqlite_server.go
index fc7d76a2..dca7b2d6 100644
--- a/arrow/flight/flightsql/example/sqlite_server.go
+++ b/arrow/flight/flightsql/example/sqlite_server.go
@@ -354,7 +354,7 @@ func (s *SQLiteFlightSQLServer) DoGetTables(ctx 
context.Context, cmd flightsql.G
        }
 
        schema := rdr.Schema()
-       go flight.StreamChunksFromReader(rdr, ch)
+       go flight.StreamChunksFromReader(ctx, rdr, ch)
        return schema, ch, nil
 }
 
@@ -485,7 +485,7 @@ func doGetQuery(ctx context.Context, mem memory.Allocator, 
db dbQueryCtx, query
        }
 
        ch := make(chan flight.StreamChunk)
-       go flight.StreamChunksFromReader(rdr, ch)
+       go flight.StreamChunksFromReader(ctx, rdr, ch)
        return schema, ch, nil
 }
 
diff --git a/arrow/flight/flightsql/server.go b/arrow/flight/flightsql/server.go
index d5102a27..25c89bf5 100644
--- a/arrow/flight/flightsql/server.go
+++ b/arrow/flight/flightsql/server.go
@@ -381,7 +381,7 @@ func (b *BaseServer) GetFlightInfoSqlInfo(_ 
context.Context, _ GetSqlInfo, desc
 }
 
 // DoGetSqlInfo returns a flight stream containing the list of sqlinfo results
-func (b *BaseServer) DoGetSqlInfo(_ context.Context, cmd GetSqlInfo) 
(*arrow.Schema, <-chan flight.StreamChunk, error) {
+func (b *BaseServer) DoGetSqlInfo(ctx context.Context, cmd GetSqlInfo) 
(*arrow.Schema, <-chan flight.StreamChunk, error) {
        if b.Alloc == nil {
                b.Alloc = memory.DefaultAllocator
        }
@@ -430,7 +430,7 @@ func (b *BaseServer) DoGetSqlInfo(_ context.Context, cmd 
GetSqlInfo) (*arrow.Sch
        }
 
        // StreamChunksFromReader will call release on the reader when done
-       go flight.StreamChunksFromReader(rdr, ch)
+       go flight.StreamChunksFromReader(ctx, rdr, ch)
        return schema_ref.SqlInfo, ch, nil
 }
 
@@ -927,19 +927,24 @@ func (f *flightSqlServer) DoGet(request *flight.Ticket, 
stream flight.FlightServ
        wr := flight.NewRecordWriter(stream, ipc.WithSchema(sc))
        defer wr.Close()
 
-       for chunk := range cc {
-               if chunk.Err != nil {
-                       return chunk.Err
-               }
-
-               wr.SetFlightDescriptor(chunk.Desc)
-               if err = wr.WriteWithAppMetadata(chunk.Data, 
chunk.AppMetadata); err != nil {
-                       return err
+       for {
+               select {
+               case <-stream.Context().Done():
+                       return stream.Context().Err()
+               case chunk, ok := <-cc:
+                       if !ok {
+                               return nil
+                       }
+                       if chunk.Err != nil {
+                               return chunk.Err
+                       }
+                       wr.SetFlightDescriptor(chunk.Desc)
+                       if err := wr.WriteWithAppMetadata(chunk.Data, 
chunk.AppMetadata); err != nil {
+                               return err
+                       }
+                       chunk.Data.Release()
                }
-               chunk.Data.Release()
        }
-
-       return err
 }
 
 type putMetadataWriter struct {
diff --git a/arrow/flight/record_batch_reader.go 
b/arrow/flight/record_batch_reader.go
index 7b744075..e6990a57 100644
--- a/arrow/flight/record_batch_reader.go
+++ b/arrow/flight/record_batch_reader.go
@@ -18,6 +18,7 @@ package flight
 
 import (
        "bytes"
+       "context"
        "errors"
        "fmt"
        "io"
@@ -212,24 +213,38 @@ type haserr interface {
 
 // StreamChunksFromReader is a convenience function to populate a channel
 // from a record reader. It is intended to be run using a separate goroutine
-// by calling `go flight.StreamChunksFromReader(rdr, ch)`.
+// by calling `go flight.StreamChunksFromReader(ctx, rdr, ch)`.
 //
 // If the record reader panics, an error chunk will get sent on the channel.
 //
 // This will close the channel and release the reader when it completes.
-func StreamChunksFromReader(rdr array.RecordReader, ch chan<- StreamChunk) {
+func StreamChunksFromReader(ctx context.Context, rdr array.RecordReader, ch 
chan<- StreamChunk) {
        defer close(ch)
        defer func() {
                if err := recover(); err != nil {
-                       ch <- StreamChunk{Err: 
utils.FormatRecoveredError("panic while reading", err)}
+                       select {
+                       case ch <- StreamChunk{Err: 
utils.FormatRecoveredError("panic while reading", err)}:
+                       case <-ctx.Done():
+                       }
                }
        }()
 
        defer rdr.Release()
        for rdr.Next() {
+               select {
+               case <-ctx.Done():
+                       return
+               default:
+               }
+
                rec := rdr.RecordBatch()
                rec.Retain()
-               ch <- StreamChunk{Data: rec}
+               select {
+               case ch <- StreamChunk{Data: rec}:
+               case <-ctx.Done():
+                       rec.Release()
+                       return
+               }
        }
 
        if e, ok := rdr.(haserr); ok {

Reply via email to