This is an automated email from the ASF dual-hosted git repository.
wusheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/skywalking-banyandb.git
The following commit(s) were added to refs/heads/main by this push:
new 2c63ba3b4 fix the memPart Leak (#1012)
2c63ba3b4 is described below
commit 2c63ba3b4e310a0ed769be9c2f3ec06a7a0b165d
Author: Gao Hongtao <[email protected]>
AuthorDate: Wed Mar 18 15:43:30 2026 +0800
fix the memPart Leak (#1012)
---
CHANGES.md | 2 ++
banyand/internal/wqueue/wqueue.go | 14 ++++++++++++--
banyand/measure/tstable.go | 1 +
banyand/measure/tstable_test.go | 33 +++++++++++++++++++++++++++++++++
banyand/measure/write_data.go | 11 +++++++++--
banyand/queue/sub/chunked_sync.go | 2 ++
banyand/queue/sub/server.go | 36 ++++++++++++++++++++++++++++++++++++
banyand/stream/tstable.go | 1 +
banyand/stream/tstable_test.go | 33 +++++++++++++++++++++++++++++++++
banyand/stream/write_data.go | 11 +++++++++--
banyand/trace/tstable_test.go | 34 ++++++++++++++++++++++++++++++++++
banyand/trace/write_data.go | 7 ++++++-
pkg/pool/pool.go | 21 +++++++++++----------
13 files changed, 189 insertions(+), 17 deletions(-)
diff --git a/CHANGES.md b/CHANGES.md
index 50071ac63..55bd0b4fe 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -44,6 +44,8 @@ Release Notes.
- Fix sidx tag filter range check returning inverted skip decision and use correct int64 encoding for block min/max.
- Ignore take snapshot when no data.
- Fix measure standalone write handler resetting accumulated groups on error, which dropped all successfully processed events in the batch.
+- Fix memory part reference leak in mustAddMemPart when tsTable loop closes.
+- Fix memory part leak in syncPartContext Close and prevent double-release in FinishSync.
### Document
diff --git a/banyand/internal/wqueue/wqueue.go b/banyand/internal/wqueue/wqueue.go
index b24c56059..37d82d1cd 100644
--- a/banyand/internal/wqueue/wqueue.go
+++ b/banyand/internal/wqueue/wqueue.go
@@ -42,8 +42,11 @@ const (
lockFilename = "lock"
)
-// ErrUnknownShard indicates that the shard is not found.
-var ErrUnknownShard = errors.New("unknown shard")
+var (
+ // ErrUnknownShard indicates that the shard is not found.
+ ErrUnknownShard = errors.New("unknown shard")
+ errQueueClosed = errors.New("queue is closed")
+)
// Metrics is the interface of metrics.
type Metrics interface {
@@ -167,6 +170,9 @@ func Open[S SubQueue, O any](ctx context.Context, opts Opts[S, O], _ string) (*Q
// If the shard already exists, it returns it without locking.
// If the shard doesn't exist, it creates a new one with proper locking.
func (q *Queue[S, O]) GetOrCreateShard(shardID common.ShardID) (*Shard[S], error) {
+ if q.closed.Load() {
+ return nil, errQueueClosed
+ }
// First check if shard exists without locking
if shard := q.getShard(shardID); shard != nil {
return shard, nil
@@ -176,6 +182,10 @@ func (q *Queue[S, O]) GetOrCreateShard(shardID common.ShardID) (*Shard[S], error
q.Lock()
defer q.Unlock()
+ if q.closed.Load() {
+ return nil, errQueueClosed
+ }
+
// Double-check after acquiring lock
if shard := q.getShard(shardID); shard != nil {
return shard, nil
diff --git a/banyand/measure/tstable.go b/banyand/measure/tstable.go
index d213cc989..5358b1634 100644
--- a/banyand/measure/tstable.go
+++ b/banyand/measure/tstable.go
@@ -322,6 +322,7 @@ func (tst *tsTable) mustAddMemPart(mp *memPart) {
case tst.introductions <- ind:
case <-tst.loopCloser.CloseNotify():
tst.addPendingDataCount(-int64(totalCount))
+ ind.memPart.decRef()
return
}
select {
diff --git a/banyand/measure/tstable_test.go b/banyand/measure/tstable_test.go
index afb44b918..68c39d223 100644
--- a/banyand/measure/tstable_test.go
+++ b/banyand/measure/tstable_test.go
@@ -29,12 +29,15 @@ import (
"github.com/apache/skywalking-banyandb/api/common"
"github.com/apache/skywalking-banyandb/banyand/internal/storage"
+ "github.com/apache/skywalking-banyandb/banyand/protector"
"github.com/apache/skywalking-banyandb/pkg/convert"
"github.com/apache/skywalking-banyandb/pkg/fs"
+ "github.com/apache/skywalking-banyandb/pkg/logger"
pbv1 "github.com/apache/skywalking-banyandb/pkg/pb/v1"
"github.com/apache/skywalking-banyandb/pkg/query/model"
"github.com/apache/skywalking-banyandb/pkg/run"
"github.com/apache/skywalking-banyandb/pkg/test"
+ "github.com/apache/skywalking-banyandb/pkg/timestamp"
"github.com/apache/skywalking-banyandb/pkg/watcher"
)
@@ -297,6 +300,36 @@ var fieldProjections = map[int][]string{
3: {"intField"},
}
+func Test_mustAddMemPart_closeNotifyReleasesMemPart(t *testing.T) {
+ tmpPath, defFn := test.Space(require.New(t))
+ defer defFn()
+ tst, err := newTSTable(
+ fs.NewLocalFileSystem(),
+ tmpPath,
+ common.Position{},
+ logger.GetLogger("test"),
+ timestamp.TimeRange{},
+ option{
+ flushTimeout: 0,
+ protector: protector.Nop{},
+ },
+ nil,
+ )
+ require.NoError(t, err)
+
+ mp := generateMemPart()
+ mp.mustInitFromDataPoints(dpsTS1)
+ originalCount := mp.partMetadata.TotalCount
+ require.Greater(t, originalCount, uint64(0))
+
+ tst.Close()
+ tst.mustAddMemPart(mp)
+
+ // Verify the memPart was released by checking the count is reset to 0
+ // after mustAddMemPart returns (which handles the release via decRef)
+ require.Equal(t, uint64(0), mp.partMetadata.TotalCount)
+}
+
var dpsTS1 = &dataPoints{
seriesIDs: []common.SeriesID{1, 2, 3},
timestamps: []int64{1, 1, 1},
diff --git a/banyand/measure/write_data.go b/banyand/measure/write_data.go
index b83ff65f6..ef4fe160a 100644
--- a/banyand/measure/write_data.go
+++ b/banyand/measure/write_data.go
@@ -43,7 +43,9 @@ func (s *syncPartContext) NewPartType(_ *queue.ChunkedSyncPartContext) error {
}
func (s *syncPartContext) FinishSync() error {
- s.tsTable.mustAddMemPart(s.memPart)
+ mp := s.memPart
+ s.memPart = nil
+ s.tsTable.mustAddMemPart(mp)
return s.Close()
}
@@ -51,7 +53,12 @@ func (s *syncPartContext) Close() error {
s.writers.MustClose()
releaseWriters(s.writers)
s.writers = nil
- s.memPart = nil
+ if s.memPart != nil {
+ // syncPartContext owns the memPart directly without partWrapper refcounting.
+ // It must release via releaseMemPart, not decRef, which is used in mustAddMemPart.
+ releaseMemPart(s.memPart)
+ s.memPart = nil
+ }
s.tsTable = nil
return nil
}
diff --git a/banyand/queue/sub/chunked_sync.go b/banyand/queue/sub/chunked_sync.go
index 1153bc44b..63d6137fa 100644
--- a/banyand/queue/sub/chunked_sync.go
+++ b/banyand/queue/sub/chunked_sync.go
@@ -128,6 +128,7 @@ func (s *server) SyncPart(stream clusterv1.ChunkedSyncService_SyncPartServer) er
var sessionID string
defer func() {
if currentSession != nil {
+ s.unregisterSession(currentSession.sessionID)
if currentSession.partCtx != nil {
currentSession.partCtx.Close()
}
@@ -160,6 +161,7 @@ func (s *server) SyncPart(stream clusterv1.ChunkedSyncService_SyncPartServer) er
chunksReceived: 0,
partsProgress: make(map[int]*partProgress),
}
+ s.registerSession(sessionID, currentSession)
if dl := s.log.Debug(); dl.Enabled() {
dl.Str("session_id", sessionID).
Str("topic", req.GetMetadata().Topic).
diff --git a/banyand/queue/sub/server.go b/banyand/queue/sub/server.go
index 2f36afa8d..69da73ffc 100644
--- a/banyand/queue/sub/server.go
+++ b/banyand/queue/sub/server.go
@@ -85,6 +85,7 @@ type server struct {
listeners map[bus.Topic][]bus.MessageListener
topicMap map[string]bus.Topic
chunkedSyncHandlers map[bus.Topic]queue.ChunkedSyncHandler
+ activeSessions map[string]*syncSession
log *logger.Logger
httpSrv *http.Server
tlsReloader *pkgtls.Reloader
@@ -100,6 +101,7 @@ type server struct {
maxRecvMsgSize run.Bytes
listenersLock sync.RWMutex
routeTableProviderMu sync.RWMutex
+ activeSessionsMu sync.Mutex
port uint32
httpPort uint32
maxChunkBufferSize uint32
@@ -119,6 +121,7 @@ func NewServerWithPorts(omr observability.MetricsRegistry, flagNamePrefix string
listeners: make(map[bus.Topic][]bus.MessageListener),
topicMap: make(map[string]bus.Topic),
chunkedSyncHandlers: make(map[bus.Topic]queue.ChunkedSyncHandler),
+ activeSessions: make(map[string]*syncSession),
omr: omr,
maxRecvMsgSize: defaultRecvSize,
flagNamePrefix: flagNamePrefix,
@@ -378,6 +381,39 @@ func (s *server) GracefulStop() {
t.Stop()
s.log.Info().Msg("stopped gracefully")
}
+
+ s.closeAllSessions()
+}
+
+// registerSession adds a session to the active sessions map.
+func (s *server) registerSession(id string, session *syncSession) {
+ s.activeSessionsMu.Lock()
+ s.activeSessions[id] = session
+ s.activeSessionsMu.Unlock()
+}
+
+// unregisterSession removes a session from the active sessions map.
+func (s *server) unregisterSession(id string) {
+ s.activeSessionsMu.Lock()
+ delete(s.activeSessions, id)
+ s.activeSessionsMu.Unlock()
+}
+
+// closeAllSessions closes the partCtx of every remaining active session.
+// It is called after the gRPC server has fully stopped as a safety net.
+func (s *server) closeAllSessions() {
+ s.activeSessionsMu.Lock()
+ sessions := s.activeSessions
+ s.activeSessions = make(map[string]*syncSession)
+ s.activeSessionsMu.Unlock()
+
+ for id, session := range sessions {
+ if session.partCtx != nil {
+ if closeErr := session.partCtx.Close(); closeErr != nil {
+ s.log.Error().Err(closeErr).Str("session_id", id).Msg("failed to close session partCtx during shutdown")
+ }
+ }
+ }
}
// RegisterChunkedSyncHandler implements queue.Server.
diff --git a/banyand/stream/tstable.go b/banyand/stream/tstable.go
index cbd4ce81b..c980a2e9c 100644
--- a/banyand/stream/tstable.go
+++ b/banyand/stream/tstable.go
@@ -333,6 +333,7 @@ func (tst *tsTable) mustAddMemPart(mp *memPart) {
case tst.introductions <- ind:
case <-tst.loopCloser.CloseNotify():
tst.addPendingDataCount(-int64(totalCount))
+ ind.memPart.decRef()
return
}
select {
diff --git a/banyand/stream/tstable_test.go b/banyand/stream/tstable_test.go
index ecdb2758a..ffbd31e35 100644
--- a/banyand/stream/tstable_test.go
+++ b/banyand/stream/tstable_test.go
@@ -29,11 +29,14 @@ import (
"github.com/stretchr/testify/require"
"github.com/apache/skywalking-banyandb/api/common"
+ "github.com/apache/skywalking-banyandb/banyand/protector"
"github.com/apache/skywalking-banyandb/pkg/convert"
"github.com/apache/skywalking-banyandb/pkg/fs"
+ "github.com/apache/skywalking-banyandb/pkg/logger"
pbv1 "github.com/apache/skywalking-banyandb/pkg/pb/v1"
"github.com/apache/skywalking-banyandb/pkg/run"
"github.com/apache/skywalking-banyandb/pkg/test"
+ "github.com/apache/skywalking-banyandb/pkg/timestamp"
"github.com/apache/skywalking-banyandb/pkg/watcher"
)
@@ -250,6 +253,36 @@ func Test_tstIter(t *testing.T) {
})
}
+func Test_mustAddMemPart_closeNotifyReleasesMemPart(t *testing.T) {
+ tmpPath, defFn := test.Space(require.New(t))
+ defer defFn()
+ tst, err := newTSTable(
+ fs.NewLocalFileSystem(),
+ tmpPath,
+ common.Position{},
+ logger.GetLogger("test"),
+ timestamp.TimeRange{},
+ option{
+ flushTimeout: 0,
+ protector: protector.Nop{},
+ },
+ nil,
+ )
+ require.NoError(t, err)
+
+ mp := generateMemPart()
+ mp.mustInitFromElements(esTS1)
+ originalCount := mp.partMetadata.TotalCount
+ require.Greater(t, originalCount, uint64(0))
+
+ tst.Close()
+ tst.mustAddMemPart(mp)
+
+ // Verify the memPart was released by checking the count is reset to 0
+ // after mustAddMemPart returns (which handles the release via decRef)
+ require.Equal(t, uint64(0), mp.partMetadata.TotalCount)
+}
+
var esTS1 = &elements{
seriesIDs: []common.SeriesID{1, 2, 3},
timestamps: []int64{1, 1, 1},
diff --git a/banyand/stream/write_data.go b/banyand/stream/write_data.go
index e9bd52cc1..93c5dfe08 100644
--- a/banyand/stream/write_data.go
+++ b/banyand/stream/write_data.go
@@ -42,7 +42,9 @@ func (s *syncPartContext) NewPartType(_ *queue.ChunkedSyncPartContext) error {
}
func (s *syncPartContext) FinishSync() error {
- s.tsTable.mustAddMemPart(s.memPart)
+ mp := s.memPart
+ s.memPart = nil
+ s.tsTable.mustAddMemPart(mp)
return s.Close()
}
@@ -50,7 +52,12 @@ func (s *syncPartContext) Close() error {
s.writers.MustClose()
releaseWriters(s.writers)
s.writers = nil
- s.memPart = nil
+ if s.memPart != nil {
+ // syncPartContext owns the memPart directly without partWrapper refcounting.
+ // It must release via releaseMemPart, not decRef, which is used in mustAddMemPart.
+ releaseMemPart(s.memPart)
+ s.memPart = nil
+ }
s.tsTable = nil
return nil
}
diff --git a/banyand/trace/tstable_test.go b/banyand/trace/tstable_test.go
index d6a3cdf19..2b437aafe 100644
--- a/banyand/trace/tstable_test.go
+++ b/banyand/trace/tstable_test.go
@@ -28,12 +28,16 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
+ "github.com/apache/skywalking-banyandb/api/common"
+ "github.com/apache/skywalking-banyandb/banyand/protector"
"github.com/apache/skywalking-banyandb/pkg/convert"
"github.com/apache/skywalking-banyandb/pkg/fs"
+ "github.com/apache/skywalking-banyandb/pkg/logger"
pbv1 "github.com/apache/skywalking-banyandb/pkg/pb/v1"
"github.com/apache/skywalking-banyandb/pkg/query/model"
"github.com/apache/skywalking-banyandb/pkg/run"
"github.com/apache/skywalking-banyandb/pkg/test"
+ "github.com/apache/skywalking-banyandb/pkg/timestamp"
"github.com/apache/skywalking-banyandb/pkg/watcher"
)
@@ -235,6 +239,36 @@ var testSchemaTagTypes = map[string]pbv1.ValueType{
"intTag": pbv1.ValueTypeInt64,
}
+func Test_mustAddMemPart_closeNotifyReleasesMemPart(t *testing.T) {
+ tmpPath, defFn := test.Space(require.New(t))
+ defer defFn()
+ tst, err := newTSTable(
+ fs.NewLocalFileSystem(),
+ tmpPath,
+ common.Position{},
+ logger.GetLogger("test"),
+ timestamp.TimeRange{},
+ option{
+ flushTimeout: 0,
+ protector: protector.Nop{},
+ },
+ nil,
+ )
+ require.NoError(t, err)
+
+ mp := generateMemPart()
+ mp.mustInitFromTraces(tsTS1)
+ originalCount := mp.partMetadata.TotalCount
+ require.Greater(t, originalCount, uint64(0))
+
+ tst.Close()
+ tst.mustAddMemPart(mp, nil)
+
+ // Verify the memPart was released by checking the count is reset to 0
+ // after mustAddMemPart returns (which handles the release via decRef)
+ require.Equal(t, uint64(0), mp.partMetadata.TotalCount)
+}
+
var tsTS1 = &traces{
traceIDs: []string{"trace1", "trace2", "trace3"},
timestamps: []int64{1, 1, 1},
diff --git a/banyand/trace/write_data.go b/banyand/trace/write_data.go
index 6e44cbb7a..63c56c2df 100644
--- a/banyand/trace/write_data.go
+++ b/banyand/trace/write_data.go
@@ -87,11 +87,13 @@ func (s *syncPartContext) FinishSync() error {
}
if s.memPart != nil {
+ mp := s.memPart
+ s.memPart = nil
sidxPartContexts := make(map[string]*sidx.MemPart, len(s.sidxPartContexts))
for _, sidxPartContext := range s.sidxPartContexts {
sidxPartContexts[sidxPartContext.Name()] = sidxPartContext.GetMemPart()
}
- s.tsTable.mustAddMemPart(s.memPart, sidxPartContexts)
+ s.tsTable.mustAddMemPart(mp, sidxPartContexts)
}
return s.Close()
}
@@ -103,6 +105,9 @@ func (s *syncPartContext) Close() error {
s.writers = nil
}
if s.memPart != nil {
+ // syncPartContext owns the memPart directly without partWrapper refcounting.
+ // It must release via releaseMemPart, not decRef, which is used in mustAddMemPart.
+ releaseMemPart(s.memPart)
s.memPart = nil
}
if s.sidxPartContexts != nil {
diff --git a/pkg/pool/pool.go b/pkg/pool/pool.go
index 677cdfd6a..c2b275fd8 100644
--- a/pkg/pool/pool.go
+++ b/pkg/pool/pool.go
@@ -102,20 +102,19 @@ func (p *Synced[T]) Get() T {
result = v.(T)
}
- // Capture stack trace if tracking is enabled
- if stackTrackingEnabled.Load() {
- // Lazy initialize maps on first use
+ // Capture stack trace if tracking is enabled.
+ // Skip tracking when the pool returns nil because the caller will create
+ // a new object whose pointer won't match the nil key in idMap,
+ // so Put() would never clean up the entry.
+ if v != nil && stackTrackingEnabled.Load() {
p.stacksMutex.Lock()
if p.stacks == nil {
p.stacks = make(map[uint64]string)
p.idMap = make(map[any]uint64)
}
-
- // Generate unique ID and capture stack trace
id := p.idCounter.Add(1)
buf := make([]byte, 4096)
n := runtime.Stack(buf, false)
-
p.idMap[any(result)] = id
p.stacks[id] = "Pool.Get() called:\n" + string(buf[:n])
p.stacksMutex.Unlock()
@@ -126,10 +125,10 @@ func (p *Synced[T]) Get() T {
// Put puts an object back to the pool.
func (p *Synced[T]) Put(v T) {
- p.Pool.Put(v)
- p.refs.Add(-1)
-
- // Remove the stack trace for this object if tracking is enabled
+ // Remove the stack trace BEFORE returning the object to the pool.
+ // Otherwise another goroutine's Get() can reuse the pointer and
+ // overwrite its idMap entry, causing this Put to delete the wrong
+ // stack and orphan the original one.
if stackTrackingEnabled.Load() {
p.stacksMutex.Lock()
if p.idMap != nil {
@@ -140,6 +139,8 @@ func (p *Synced[T]) Put(v T) {
}
p.stacksMutex.Unlock()
}
+ p.Pool.Put(v)
+ p.refs.Add(-1)
}
// RefsCount returns the reference count of the pool.