This is an automated email from the ASF dual-hosted git repository.

wusheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/skywalking-banyandb.git


The following commit(s) were added to refs/heads/main by this push:
     new 2c63ba3b4 fix the memPart Leak (#1012)
2c63ba3b4 is described below

commit 2c63ba3b4e310a0ed769be9c2f3ec06a7a0b165d
Author: Gao Hongtao <[email protected]>
AuthorDate: Wed Mar 18 15:43:30 2026 +0800

    fix the memPart Leak (#1012)
---
 CHANGES.md                        |  2 ++
 banyand/internal/wqueue/wqueue.go | 14 ++++++++++++--
 banyand/measure/tstable.go        |  1 +
 banyand/measure/tstable_test.go   | 33 +++++++++++++++++++++++++++++++++
 banyand/measure/write_data.go     | 11 +++++++++--
 banyand/queue/sub/chunked_sync.go |  2 ++
 banyand/queue/sub/server.go       | 36 ++++++++++++++++++++++++++++++++++++
 banyand/stream/tstable.go         |  1 +
 banyand/stream/tstable_test.go    | 33 +++++++++++++++++++++++++++++++++
 banyand/stream/write_data.go      | 11 +++++++++--
 banyand/trace/tstable_test.go     | 34 ++++++++++++++++++++++++++++++++++
 banyand/trace/write_data.go       |  7 ++++++-
 pkg/pool/pool.go                  | 21 +++++++++++----------
 13 files changed, 189 insertions(+), 17 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 50071ac63..55bd0b4fe 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -44,6 +44,8 @@ Release Notes.
 - Fix sidx tag filter range check returning inverted skip decision and use correct int64 encoding for block min/max.
 - Ignore take snapshot when no data.
 - Fix measure standalone write handler resetting accumulated groups on error, which dropped all successfully processed events in the batch.
+- Fix memory part reference leak in mustAddMemPart when tsTable loop closes.
+- Fix memory part leak in syncPartContext Close and prevent double-release in FinishSync.
 
 ### Document
 
diff --git a/banyand/internal/wqueue/wqueue.go b/banyand/internal/wqueue/wqueue.go
index b24c56059..37d82d1cd 100644
--- a/banyand/internal/wqueue/wqueue.go
+++ b/banyand/internal/wqueue/wqueue.go
@@ -42,8 +42,11 @@ const (
        lockFilename    = "lock"
 )
 
-// ErrUnknownShard indicates that the shard is not found.
-var ErrUnknownShard = errors.New("unknown shard")
+var (
+       // ErrUnknownShard indicates that the shard is not found.
+       ErrUnknownShard = errors.New("unknown shard")
+       errQueueClosed  = errors.New("queue is closed")
+)
 
 // Metrics is the interface of metrics.
 type Metrics interface {
@@ -167,6 +170,9 @@ func Open[S SubQueue, O any](ctx context.Context, opts Opts[S, O], _ string) (*Q
 // If the shard already exists, it returns it without locking.
 // If the shard doesn't exist, it creates a new one with proper locking.
 func (q *Queue[S, O]) GetOrCreateShard(shardID common.ShardID) (*Shard[S], 
error) {
+       if q.closed.Load() {
+               return nil, errQueueClosed
+       }
        // First check if shard exists without locking
        if shard := q.getShard(shardID); shard != nil {
                return shard, nil
@@ -176,6 +182,10 @@ func (q *Queue[S, O]) GetOrCreateShard(shardID common.ShardID) (*Shard[S], error
        q.Lock()
        defer q.Unlock()
 
+       if q.closed.Load() {
+               return nil, errQueueClosed
+       }
+
        // Double-check after acquiring lock
        if shard := q.getShard(shardID); shard != nil {
                return shard, nil
diff --git a/banyand/measure/tstable.go b/banyand/measure/tstable.go
index d213cc989..5358b1634 100644
--- a/banyand/measure/tstable.go
+++ b/banyand/measure/tstable.go
@@ -322,6 +322,7 @@ func (tst *tsTable) mustAddMemPart(mp *memPart) {
        case tst.introductions <- ind:
        case <-tst.loopCloser.CloseNotify():
                tst.addPendingDataCount(-int64(totalCount))
+               ind.memPart.decRef()
                return
        }
        select {
diff --git a/banyand/measure/tstable_test.go b/banyand/measure/tstable_test.go
index afb44b918..68c39d223 100644
--- a/banyand/measure/tstable_test.go
+++ b/banyand/measure/tstable_test.go
@@ -29,12 +29,15 @@ import (
 
        "github.com/apache/skywalking-banyandb/api/common"
        "github.com/apache/skywalking-banyandb/banyand/internal/storage"
+       "github.com/apache/skywalking-banyandb/banyand/protector"
        "github.com/apache/skywalking-banyandb/pkg/convert"
        "github.com/apache/skywalking-banyandb/pkg/fs"
+       "github.com/apache/skywalking-banyandb/pkg/logger"
        pbv1 "github.com/apache/skywalking-banyandb/pkg/pb/v1"
        "github.com/apache/skywalking-banyandb/pkg/query/model"
        "github.com/apache/skywalking-banyandb/pkg/run"
        "github.com/apache/skywalking-banyandb/pkg/test"
+       "github.com/apache/skywalking-banyandb/pkg/timestamp"
        "github.com/apache/skywalking-banyandb/pkg/watcher"
 )
 
@@ -297,6 +300,36 @@ var fieldProjections = map[int][]string{
        3: {"intField"},
 }
 
+func Test_mustAddMemPart_closeNotifyReleasesMemPart(t *testing.T) {
+       tmpPath, defFn := test.Space(require.New(t))
+       defer defFn()
+       tst, err := newTSTable(
+               fs.NewLocalFileSystem(),
+               tmpPath,
+               common.Position{},
+               logger.GetLogger("test"),
+               timestamp.TimeRange{},
+               option{
+                       flushTimeout: 0,
+                       protector:    protector.Nop{},
+               },
+               nil,
+       )
+       require.NoError(t, err)
+
+       mp := generateMemPart()
+       mp.mustInitFromDataPoints(dpsTS1)
+       originalCount := mp.partMetadata.TotalCount
+       require.Greater(t, originalCount, uint64(0))
+
+       tst.Close()
+       tst.mustAddMemPart(mp)
+
+       // Verify the memPart was released by checking the count is reset to 0
+       // after mustAddMemPart returns (which handles the release via decRef)
+       require.Equal(t, uint64(0), mp.partMetadata.TotalCount)
+}
+
 var dpsTS1 = &dataPoints{
        seriesIDs:  []common.SeriesID{1, 2, 3},
        timestamps: []int64{1, 1, 1},
diff --git a/banyand/measure/write_data.go b/banyand/measure/write_data.go
index b83ff65f6..ef4fe160a 100644
--- a/banyand/measure/write_data.go
+++ b/banyand/measure/write_data.go
@@ -43,7 +43,9 @@ func (s *syncPartContext) NewPartType(_ *queue.ChunkedSyncPartContext) error {
 }
 
 func (s *syncPartContext) FinishSync() error {
-       s.tsTable.mustAddMemPart(s.memPart)
+       mp := s.memPart
+       s.memPart = nil
+       s.tsTable.mustAddMemPart(mp)
        return s.Close()
 }
 
@@ -51,7 +53,12 @@ func (s *syncPartContext) Close() error {
        s.writers.MustClose()
        releaseWriters(s.writers)
        s.writers = nil
-       s.memPart = nil
+       if s.memPart != nil {
+               // syncPartContext owns the memPart directly without 
partWrapper refcounting.
+               // It must release via releaseMemPart, not decRef, which is 
used in mustAddMemPart.
+               releaseMemPart(s.memPart)
+               s.memPart = nil
+       }
        s.tsTable = nil
        return nil
 }
diff --git a/banyand/queue/sub/chunked_sync.go b/banyand/queue/sub/chunked_sync.go
index 1153bc44b..63d6137fa 100644
--- a/banyand/queue/sub/chunked_sync.go
+++ b/banyand/queue/sub/chunked_sync.go
@@ -128,6 +128,7 @@ func (s *server) SyncPart(stream clusterv1.ChunkedSyncService_SyncPartServer) er
        var sessionID string
        defer func() {
                if currentSession != nil {
+                       s.unregisterSession(currentSession.sessionID)
                        if currentSession.partCtx != nil {
                                currentSession.partCtx.Close()
                        }
@@ -160,6 +161,7 @@ func (s *server) SyncPart(stream clusterv1.ChunkedSyncService_SyncPartServer) er
                                chunksReceived: 0,
                                partsProgress:  make(map[int]*partProgress),
                        }
+                       s.registerSession(sessionID, currentSession)
                        if dl := s.log.Debug(); dl.Enabled() {
                                dl.Str("session_id", sessionID).
                                        Str("topic", req.GetMetadata().Topic).
diff --git a/banyand/queue/sub/server.go b/banyand/queue/sub/server.go
index 2f36afa8d..69da73ffc 100644
--- a/banyand/queue/sub/server.go
+++ b/banyand/queue/sub/server.go
@@ -85,6 +85,7 @@ type server struct {
        listeners             map[bus.Topic][]bus.MessageListener
        topicMap              map[string]bus.Topic
        chunkedSyncHandlers   map[bus.Topic]queue.ChunkedSyncHandler
+       activeSessions        map[string]*syncSession
        log                   *logger.Logger
        httpSrv               *http.Server
        tlsReloader           *pkgtls.Reloader
@@ -100,6 +101,7 @@ type server struct {
        maxRecvMsgSize        run.Bytes
        listenersLock         sync.RWMutex
        routeTableProviderMu  sync.RWMutex
+       activeSessionsMu      sync.Mutex
        port                  uint32
        httpPort              uint32
        maxChunkBufferSize    uint32
@@ -119,6 +121,7 @@ func NewServerWithPorts(omr observability.MetricsRegistry, flagNamePrefix string
                listeners:           make(map[bus.Topic][]bus.MessageListener),
                topicMap:            make(map[string]bus.Topic),
                chunkedSyncHandlers: 
make(map[bus.Topic]queue.ChunkedSyncHandler),
+               activeSessions:      make(map[string]*syncSession),
                omr:                 omr,
                maxRecvMsgSize:      defaultRecvSize,
                flagNamePrefix:      flagNamePrefix,
@@ -378,6 +381,39 @@ func (s *server) GracefulStop() {
                t.Stop()
                s.log.Info().Msg("stopped gracefully")
        }
+
+       s.closeAllSessions()
+}
+
+// registerSession adds a session to the active sessions map.
+func (s *server) registerSession(id string, session *syncSession) {
+       s.activeSessionsMu.Lock()
+       s.activeSessions[id] = session
+       s.activeSessionsMu.Unlock()
+}
+
+// unregisterSession removes a session from the active sessions map.
+func (s *server) unregisterSession(id string) {
+       s.activeSessionsMu.Lock()
+       delete(s.activeSessions, id)
+       s.activeSessionsMu.Unlock()
+}
+
+// closeAllSessions closes the partCtx of every remaining active session.
+// It is called after the gRPC server has fully stopped as a safety net.
+func (s *server) closeAllSessions() {
+       s.activeSessionsMu.Lock()
+       sessions := s.activeSessions
+       s.activeSessions = make(map[string]*syncSession)
+       s.activeSessionsMu.Unlock()
+
+       for id, session := range sessions {
+               if session.partCtx != nil {
+                       if closeErr := session.partCtx.Close(); closeErr != nil 
{
+                               s.log.Error().Err(closeErr).Str("session_id", 
id).Msg("failed to close session partCtx during shutdown")
+                       }
+               }
+       }
 }
 
 // RegisterChunkedSyncHandler implements queue.Server.
diff --git a/banyand/stream/tstable.go b/banyand/stream/tstable.go
index cbd4ce81b..c980a2e9c 100644
--- a/banyand/stream/tstable.go
+++ b/banyand/stream/tstable.go
@@ -333,6 +333,7 @@ func (tst *tsTable) mustAddMemPart(mp *memPart) {
        case tst.introductions <- ind:
        case <-tst.loopCloser.CloseNotify():
                tst.addPendingDataCount(-int64(totalCount))
+               ind.memPart.decRef()
                return
        }
        select {
diff --git a/banyand/stream/tstable_test.go b/banyand/stream/tstable_test.go
index ecdb2758a..ffbd31e35 100644
--- a/banyand/stream/tstable_test.go
+++ b/banyand/stream/tstable_test.go
@@ -29,11 +29,14 @@ import (
        "github.com/stretchr/testify/require"
 
        "github.com/apache/skywalking-banyandb/api/common"
+       "github.com/apache/skywalking-banyandb/banyand/protector"
        "github.com/apache/skywalking-banyandb/pkg/convert"
        "github.com/apache/skywalking-banyandb/pkg/fs"
+       "github.com/apache/skywalking-banyandb/pkg/logger"
        pbv1 "github.com/apache/skywalking-banyandb/pkg/pb/v1"
        "github.com/apache/skywalking-banyandb/pkg/run"
        "github.com/apache/skywalking-banyandb/pkg/test"
+       "github.com/apache/skywalking-banyandb/pkg/timestamp"
        "github.com/apache/skywalking-banyandb/pkg/watcher"
 )
 
@@ -250,6 +253,36 @@ func Test_tstIter(t *testing.T) {
        })
 }
 
+func Test_mustAddMemPart_closeNotifyReleasesMemPart(t *testing.T) {
+       tmpPath, defFn := test.Space(require.New(t))
+       defer defFn()
+       tst, err := newTSTable(
+               fs.NewLocalFileSystem(),
+               tmpPath,
+               common.Position{},
+               logger.GetLogger("test"),
+               timestamp.TimeRange{},
+               option{
+                       flushTimeout: 0,
+                       protector:    protector.Nop{},
+               },
+               nil,
+       )
+       require.NoError(t, err)
+
+       mp := generateMemPart()
+       mp.mustInitFromElements(esTS1)
+       originalCount := mp.partMetadata.TotalCount
+       require.Greater(t, originalCount, uint64(0))
+
+       tst.Close()
+       tst.mustAddMemPart(mp)
+
+       // Verify the memPart was released by checking the count is reset to 0
+       // after mustAddMemPart returns (which handles the release via decRef)
+       require.Equal(t, uint64(0), mp.partMetadata.TotalCount)
+}
+
 var esTS1 = &elements{
        seriesIDs:  []common.SeriesID{1, 2, 3},
        timestamps: []int64{1, 1, 1},
diff --git a/banyand/stream/write_data.go b/banyand/stream/write_data.go
index e9bd52cc1..93c5dfe08 100644
--- a/banyand/stream/write_data.go
+++ b/banyand/stream/write_data.go
@@ -42,7 +42,9 @@ func (s *syncPartContext) NewPartType(_ *queue.ChunkedSyncPartContext) error {
 }
 
 func (s *syncPartContext) FinishSync() error {
-       s.tsTable.mustAddMemPart(s.memPart)
+       mp := s.memPart
+       s.memPart = nil
+       s.tsTable.mustAddMemPart(mp)
        return s.Close()
 }
 
@@ -50,7 +52,12 @@ func (s *syncPartContext) Close() error {
        s.writers.MustClose()
        releaseWriters(s.writers)
        s.writers = nil
-       s.memPart = nil
+       if s.memPart != nil {
+               // syncPartContext owns the memPart directly without 
partWrapper refcounting.
+               // It must release via releaseMemPart, not decRef, which is 
used in mustAddMemPart.
+               releaseMemPart(s.memPart)
+               s.memPart = nil
+       }
        s.tsTable = nil
        return nil
 }
diff --git a/banyand/trace/tstable_test.go b/banyand/trace/tstable_test.go
index d6a3cdf19..2b437aafe 100644
--- a/banyand/trace/tstable_test.go
+++ b/banyand/trace/tstable_test.go
@@ -28,12 +28,16 @@ import (
        "github.com/stretchr/testify/assert"
        "github.com/stretchr/testify/require"
 
+       "github.com/apache/skywalking-banyandb/api/common"
+       "github.com/apache/skywalking-banyandb/banyand/protector"
        "github.com/apache/skywalking-banyandb/pkg/convert"
        "github.com/apache/skywalking-banyandb/pkg/fs"
+       "github.com/apache/skywalking-banyandb/pkg/logger"
        pbv1 "github.com/apache/skywalking-banyandb/pkg/pb/v1"
        "github.com/apache/skywalking-banyandb/pkg/query/model"
        "github.com/apache/skywalking-banyandb/pkg/run"
        "github.com/apache/skywalking-banyandb/pkg/test"
+       "github.com/apache/skywalking-banyandb/pkg/timestamp"
        "github.com/apache/skywalking-banyandb/pkg/watcher"
 )
 
@@ -235,6 +239,36 @@ var testSchemaTagTypes = map[string]pbv1.ValueType{
        "intTag":    pbv1.ValueTypeInt64,
 }
 
+func Test_mustAddMemPart_closeNotifyReleasesMemPart(t *testing.T) {
+       tmpPath, defFn := test.Space(require.New(t))
+       defer defFn()
+       tst, err := newTSTable(
+               fs.NewLocalFileSystem(),
+               tmpPath,
+               common.Position{},
+               logger.GetLogger("test"),
+               timestamp.TimeRange{},
+               option{
+                       flushTimeout: 0,
+                       protector:    protector.Nop{},
+               },
+               nil,
+       )
+       require.NoError(t, err)
+
+       mp := generateMemPart()
+       mp.mustInitFromTraces(tsTS1)
+       originalCount := mp.partMetadata.TotalCount
+       require.Greater(t, originalCount, uint64(0))
+
+       tst.Close()
+       tst.mustAddMemPart(mp, nil)
+
+       // Verify the memPart was released by checking the count is reset to 0
+       // after mustAddMemPart returns (which handles the release via decRef)
+       require.Equal(t, uint64(0), mp.partMetadata.TotalCount)
+}
+
 var tsTS1 = &traces{
        traceIDs:   []string{"trace1", "trace2", "trace3"},
        timestamps: []int64{1, 1, 1},
diff --git a/banyand/trace/write_data.go b/banyand/trace/write_data.go
index 6e44cbb7a..63c56c2df 100644
--- a/banyand/trace/write_data.go
+++ b/banyand/trace/write_data.go
@@ -87,11 +87,13 @@ func (s *syncPartContext) FinishSync() error {
        }
 
        if s.memPart != nil {
+               mp := s.memPart
+               s.memPart = nil
                sidxPartContexts := make(map[string]*sidx.MemPart, 
len(s.sidxPartContexts))
                for _, sidxPartContext := range s.sidxPartContexts {
                        sidxPartContexts[sidxPartContext.Name()] = 
sidxPartContext.GetMemPart()
                }
-               s.tsTable.mustAddMemPart(s.memPart, sidxPartContexts)
+               s.tsTable.mustAddMemPart(mp, sidxPartContexts)
        }
        return s.Close()
 }
@@ -103,6 +105,9 @@ func (s *syncPartContext) Close() error {
                s.writers = nil
        }
        if s.memPart != nil {
+               // syncPartContext owns the memPart directly without 
partWrapper refcounting.
+               // It must release via releaseMemPart, not decRef, which is 
used in mustAddMemPart.
+               releaseMemPart(s.memPart)
                s.memPart = nil
        }
        if s.sidxPartContexts != nil {
diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go
index 677cdfd6a..c2b275fd8 100644
--- a/pkg/pool/pool.go
+++ b/pkg/pool/pool.go
@@ -102,20 +102,19 @@ func (p *Synced[T]) Get() T {
                result = v.(T)
        }
 
-       // Capture stack trace if tracking is enabled
-       if stackTrackingEnabled.Load() {
-               // Lazy initialize maps on first use
+       // Capture stack trace if tracking is enabled.
+       // Skip tracking when the pool returns nil because the caller will 
create
+       // a new object whose pointer won't match the nil key in idMap,
+       // so Put() would never clean up the entry.
+       if v != nil && stackTrackingEnabled.Load() {
                p.stacksMutex.Lock()
                if p.stacks == nil {
                        p.stacks = make(map[uint64]string)
                        p.idMap = make(map[any]uint64)
                }
-
-               // Generate unique ID and capture stack trace
                id := p.idCounter.Add(1)
                buf := make([]byte, 4096)
                n := runtime.Stack(buf, false)
-
                p.idMap[any(result)] = id
                p.stacks[id] = "Pool.Get() called:\n" + string(buf[:n])
                p.stacksMutex.Unlock()
@@ -126,10 +125,10 @@ func (p *Synced[T]) Get() T {
 
 // Put puts an object back to the pool.
 func (p *Synced[T]) Put(v T) {
-       p.Pool.Put(v)
-       p.refs.Add(-1)
-
-       // Remove the stack trace for this object if tracking is enabled
+       // Remove the stack trace BEFORE returning the object to the pool.
+       // Otherwise another goroutine's Get() can reuse the pointer and
+       // overwrite its idMap entry, causing this Put to delete the wrong
+       // stack and orphan the original one.
        if stackTrackingEnabled.Load() {
                p.stacksMutex.Lock()
                if p.idMap != nil {
@@ -140,6 +139,8 @@ func (p *Synced[T]) Put(v T) {
                }
                p.stacksMutex.Unlock()
        }
+       p.Pool.Put(v)
+       p.refs.Add(-1)
 }
 
 // RefsCount returns the reference count of the pool.

Reply via email to