This is an automated email from the ASF dual-hosted git repository.
jinrongtong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/rocketmq-clients.git
The following commit(s) were added to refs/heads/master by this push:
new 2306251e golang: when pushconsumer is shutdown, wait for messages that
are already ready to be consumed (#1030)
2306251e is described below
commit 2306251e9c1875ef6ffed786f2e6bceca1d4c2d4
Author: guyinyou <[email protected]>
AuthorDate: Wed Jul 2 11:08:52 2025 +0800
golang: when pushconsumer is shutdown, wait for messages that are already
ready to be consumed (#1030)
Co-authored-by: guyinyou <[email protected]>
---
golang/push_consumer.go | 99 ++++++++++++++++++++++++++++++++++++++++++--
golang/simple_thread_pool.go | 41 +++++++++++++++---
2 files changed, 131 insertions(+), 9 deletions(-)
diff --git a/golang/push_consumer.go b/golang/push_consumer.go
index f6a446a7..5833f17b 100644
--- a/golang/push_consumer.go
+++ b/golang/push_consumer.go
@@ -54,7 +54,6 @@ type defaultPushConsumer struct {
cli *defaultClient
groupName string
- topicIndex atomic.Int32
pcOpts pushConsumerOptions
pcSettings *pushConsumerSettings
awaitDuration time.Duration
@@ -68,6 +67,9 @@ type defaultPushConsumer struct {
consumptionOkQuantity atomic.Int64
consumptionErrorQuantity atomic.Int64
+
+ stopping atomic.Bool
+ inflightRequestCountInterceptor *defultInflightRequestCountInterceptor
}
func (pc *defaultPushConsumer) SetRequestTimeout(timeout time.Duration) {
@@ -341,9 +343,11 @@ var NewPushConsumer = func(config *Config, opts
...PushConsumerOption) (PushCons
awaitDuration: pcOpts.awaitDuration,
subscriptionExpressions: pcOpts.subscriptionExpressions,
- subTopicRouteDataResultCache: &sync.Map{},
- cacheAssignments: &sync.Map{},
- processQueueTable: &sync.Map{},
+ subTopicRouteDataResultCache: &sync.Map{},
+ cacheAssignments: &sync.Map{},
+ processQueueTable: &sync.Map{},
+ stopping: *atomic.NewBool(false),
+ inflightRequestCountInterceptor:
NewDefultInflightRequestCountInterceptor(),
}
pc.cli.initTopics = make([]string, 0)
pcOpts.subscriptionExpressions.Range(func(key, value interface{}) bool {
@@ -375,6 +379,7 @@ var NewPushConsumer = func(config *Config, opts
...PushConsumerOption) (PushCons
}
pc.cli.settings = pc.pcSettings
pc.cli.clientImpl = pc
+ pc.cli.registerMessageInterceptor(pc.inflightRequestCountInterceptor)
return pc, nil
}
@@ -406,6 +411,10 @@ func (pc *defaultPushConsumer) Start() error {
}
func (pc *defaultPushConsumer) scanAssignments() {
+ // When stopping is in progress, return directly
+ if pc.stopping.Load() {
+ return
+ }
pc.subscriptionExpressions.Range(func(key, value interface{}) bool {
topic := key.(string)
filterExpression := value.(*FilterExpression)
@@ -498,11 +507,59 @@ func (pc *defaultPushConsumer) dropProcessQueue(mqstr
utils.MessageQueueStr) {
}
}
+/**
+ * PushConsumerImpl shutdown order
+ * 1. when shutdown begins, do not send any new receive requests
+ * 2. cancel scanAssignmentsFuture, do not create new processQueues
+ * 3. wait until all inflight receive requests finish or a timeout elapses
+ * 4. shut down consumptionExecutor and wait until all message consumption finishes
+ * 5. sleep 1s to allow async message acks to complete
+ * 6. shut down clientImpl
+ */
func (pc *defaultPushConsumer) GracefulStop() error {
+ // step 1 and 2
+ pc.stopping.Store(true)
+
+ // step 3
+ pc.cli.log.Infof("Waiting for the inflight receive requests to be
finished, clientId=%s", pc.cli.clientID)
+ pc.waitingReceiveRequestFinished()
+ pc.cli.log.Infof("Begin to Shutdown consumption executor, clientId=%s",
pc.cli.clientID)
+
+ // step 4
pc.consumerService.Shutdown()
+
+ // step 5
+ time.Sleep(time.Second)
+
+ // step 6
return pc.cli.GracefulStop()
}
+func (pc *defaultPushConsumer) waitingReceiveRequestFinished() error {
+ maxWaitingTime := pc.pcSettings.GetRequestTimeout() +
pc.pcSettings.longPollingTimeout
+ endTime := time.Now().Add(maxWaitingTime)
+ defer func() {
+ if err := recover(); err != nil {
+ pc.cli.log.Errorf("Unexpected exception while waiting
for the inflight receive requests to be finished, "+
+ "clientId=%s, err=%v", pc.cli.clientID, err)
+ }
+ }()
+
+ for {
+ inflightReceiveRequestCount :=
pc.inflightRequestCountInterceptor.getInflightReceiveRequestCount()
+ if inflightReceiveRequestCount <= 0 {
+ pc.cli.log.Infof("All inflight receive requests have
been finished, clientId=%s", pc.cli.clientID)
+ break
+ } else if time.Now().After(endTime) {
+ pc.cli.log.Warnf("Timeout waiting for all inflight
receive requests to be finished, clientId=%s, "+
+ "inflightReceiveRequestCount=%d",
pc.cli.clientID, inflightReceiveRequestCount)
+ break
+ }
+ time.Sleep(100 * time.Millisecond)
+ }
+ return nil
+}
+
func (pc *defaultPushConsumer) getSubscriptionTopicRouteResult(ctx
context.Context, topic string) (SubscriptionLoadBalancer, error) {
item, ok := pc.subTopicRouteDataResultCache.Load(topic)
if ok {
@@ -595,6 +652,10 @@ func (pc *defaultPushConsumer)
forwardMessageToDeadLetterQueue0(ctx context.Cont
}
func (pc *defaultPushConsumer) isRunning() bool {
+ // graceful stop in pushConsumer
+ if pc.stopping.Load() {
+ return false
+ }
return pc.cli.isRunning()
}
@@ -619,3 +680,33 @@ func (pc *defaultPushConsumer)
cacheMessageBytesThresholdPerQueue() int64 {
}
return int64(math.Max(1,
float64(pc.pcOpts.maxCacheMessageSizeInBytes)/float64(size)))
}
+
+type defultInflightRequestCountInterceptor struct {
+ inflightReceiveRequestCount atomic.Int64
+}
+
+var _ = MessageInterceptor(&defultInflightRequestCountInterceptor{})
+
+var NewDefultInflightRequestCountInterceptor = func()
*defultInflightRequestCountInterceptor {
+ return &defultInflightRequestCountInterceptor{
+ inflightReceiveRequestCount: *atomic.NewInt64(0),
+ }
+}
+
+func (dirci *defultInflightRequestCountInterceptor) doBefore(messageHookPoints
MessageHookPoints, messageCommons []*MessageCommon) error {
+ if messageHookPoints == MessageHookPoints_RECEIVE {
+ dirci.inflightReceiveRequestCount.Inc()
+ }
+ return nil
+}
+
+func (dirci *defultInflightRequestCountInterceptor) doAfter(messageHookPoints
MessageHookPoints, messageCommons []*MessageCommon, duration time.Duration,
status MessageHookPointsStatus) error {
+ if messageHookPoints == MessageHookPoints_RECEIVE {
+ dirci.inflightReceiveRequestCount.Dec()
+ }
+ return nil
+}
+
+func (dirci *defultInflightRequestCountInterceptor)
getInflightReceiveRequestCount() int64 {
+ return dirci.inflightReceiveRequestCount.Load()
+}
diff --git a/golang/simple_thread_pool.go b/golang/simple_thread_pool.go
index d4509630..8f8d51d5 100644
--- a/golang/simple_thread_pool.go
+++ b/golang/simple_thread_pool.go
@@ -17,10 +17,19 @@
package golang
+import (
+ "sync"
+
+ "go.uber.org/atomic"
+)
+
type simpleThreadPool struct {
- name string
- tasks chan func()
- shutdown chan any
+ name string
+ tasks chan func()
+ shutdown chan any
+ waitGroup sync.WaitGroup
+ once sync.Once
+ running atomic.Bool
}
func NewSimpleThreadPool(poolName string, taskSize int, threadNum int)
*simpleThreadPool {
@@ -28,14 +37,21 @@ func NewSimpleThreadPool(poolName string, taskSize int,
threadNum int) *simpleTh
name: poolName,
tasks: make(chan func(), taskSize),
shutdown: make(chan any),
+ running: *atomic.NewBool(true),
}
for i := 0; i < threadNum; i++ {
+ r.waitGroup.Add(1)
go func() {
+ defer r.waitGroup.Done()
tp := r
for {
select {
case <-tp.shutdown:
sugarBaseLogger.Infof("routine pool is
shutdown, name=%s", tp.name)
+ // complete all remaining tasks
+ for t := range tp.tasks {
+ t()
+ }
return
case t := <-tp.tasks:
t()
@@ -47,9 +63,24 @@ func NewSimpleThreadPool(poolName string, taskSize int,
threadNum int) *simpleTh
}
func (tp *simpleThreadPool) Submit(task func()) {
+ defer func() {
+ if r := recover(); r != nil {
+ // the running-flag check may race with Shutdown; recovering here is a fallback
+ sugarBaseLogger.Warnf("recover: simple thread pool
[%s], task=%v, err=%v", tp.name, task, r)
+ }
+ }()
+ if !tp.running.Load() {
+ sugarBaseLogger.Warnf("simple thread pool [%s] is not running,
task=%v", tp.name, task)
+ return
+ }
tp.tasks <- task
}
func (tp *simpleThreadPool) Shutdown() {
- tp.shutdown <- 0
- close(tp.shutdown)
+ tp.running.Store(false)
+ tp.once.Do(func() {
+ close(tp.shutdown)
+ // do not accept further tasks
+ close(tp.tasks)
+ tp.waitGroup.Wait()
+ })
}