This is an automated email from the ASF dual-hosted git repository.

jinrongtong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/rocketmq-clients.git


The following commit(s) were added to refs/heads/master by this push:
     new 2306251e golang: when pushconsumer is shutdown, wait for messages that are already ready to be consumed (#1030)
2306251e is described below

commit 2306251e9c1875ef6ffed786f2e6bceca1d4c2d4
Author: guyinyou <[email protected]>
AuthorDate: Wed Jul 2 11:08:52 2025 +0800

    golang: when pushconsumer is shutdown, wait for messages that are already ready to be consumed (#1030)
    
    Co-authored-by: guyinyou <[email protected]>
---
 golang/push_consumer.go      | 99 ++++++++++++++++++++++++++++++++++++++++++--
 golang/simple_thread_pool.go | 41 +++++++++++++++---
 2 files changed, 131 insertions(+), 9 deletions(-)

diff --git a/golang/push_consumer.go b/golang/push_consumer.go
index f6a446a7..5833f17b 100644
--- a/golang/push_consumer.go
+++ b/golang/push_consumer.go
@@ -54,7 +54,6 @@ type defaultPushConsumer struct {
        cli *defaultClient
 
        groupName                    string
-       topicIndex                   atomic.Int32
        pcOpts                       pushConsumerOptions
        pcSettings                   *pushConsumerSettings
        awaitDuration                time.Duration
@@ -68,6 +67,9 @@ type defaultPushConsumer struct {
 
        consumptionOkQuantity    atomic.Int64
        consumptionErrorQuantity atomic.Int64
+
+       stopping                        atomic.Bool
+       inflightRequestCountInterceptor *defultInflightRequestCountInterceptor
 }
 
 func (pc *defaultPushConsumer) SetRequestTimeout(timeout time.Duration) {
@@ -341,9 +343,11 @@ var NewPushConsumer = func(config *Config, opts ...PushConsumerOption) (PushCons
                awaitDuration:           pcOpts.awaitDuration,
                subscriptionExpressions: pcOpts.subscriptionExpressions,
 
-               subTopicRouteDataResultCache: &sync.Map{},
-               cacheAssignments:             &sync.Map{},
-               processQueueTable:            &sync.Map{},
+               subTopicRouteDataResultCache:    &sync.Map{},
+               cacheAssignments:                &sync.Map{},
+               processQueueTable:               &sync.Map{},
+               stopping:                        *atomic.NewBool(false),
+               inflightRequestCountInterceptor: NewDefultInflightRequestCountInterceptor(),
        }
        pc.cli.initTopics = make([]string, 0)
        pcOpts.subscriptionExpressions.Range(func(key, value interface{}) bool {
@@ -375,6 +379,7 @@ var NewPushConsumer = func(config *Config, opts ...PushConsumerOption) (PushCons
        }
        pc.cli.settings = pc.pcSettings
        pc.cli.clientImpl = pc
+       pc.cli.registerMessageInterceptor(pc.inflightRequestCountInterceptor)
        return pc, nil
 }
 
@@ -406,6 +411,10 @@ func (pc *defaultPushConsumer) Start() error {
 }
 
 func (pc *defaultPushConsumer) scanAssignments() {
+       // When stopping in progress, return directly
+       if pc.stopping.Load() {
+               return
+       }
        pc.subscriptionExpressions.Range(func(key, value interface{}) bool {
                topic := key.(string)
                filterExpression := value.(*FilterExpression)
@@ -498,11 +507,59 @@ func (pc *defaultPushConsumer) dropProcessQueue(mqstr utils.MessageQueueStr) {
        }
 }
 
+/**
+ * PushConsumerImpl shutdown order
+ * 1. when begin shutdown, do not send any new receive request
+ * 2. cancel scanAssignmentsFuture, do not create new processQueue
+ * 3. waiting all inflight receive request finished or timeout
+ * 4. shutdown consumptionExecutor and waiting all message consumption finished
+ * 5. sleep 1s to ack message async
+ * 6. shutdown clientImpl
+ */
 func (pc *defaultPushConsumer) GracefulStop() error {
+       // step 1 and 2
+       pc.stopping.Store(true)
+
+       // step 3
+       pc.cli.log.Infof("Waiting for the inflight receive requests to be finished, clientId=%s", pc.cli.clientID)
+       pc.waitingReceiveRequestFinished()
+       pc.cli.log.Infof("Begin to Shutdown consumption executor, clientId=%s", pc.cli.clientID)
+
+       // step 4
        pc.consumerService.Shutdown()
+
+       // step 5
+       time.Sleep(time.Second)
+
+       // step 6
        return pc.cli.GracefulStop()
 }
 
+func (pc *defaultPushConsumer) waitingReceiveRequestFinished() error {
+       maxWaitingTime := pc.pcSettings.GetRequestTimeout() + pc.pcSettings.longPollingTimeout
+       endTime := time.Now().Add(maxWaitingTime)
+       defer func() {
+               if err := recover(); err != nil {
+                       pc.cli.log.Errorf("Unexpected exception while waiting for the inflight receive requests to be finished, "+
+                               "clientId=%s, err=%v", pc.cli.clientID, err)
+               }
+       }()
+
+       for {
+               inflightReceiveRequestCount := pc.inflightRequestCountInterceptor.getInflightReceiveRequestCount()
+               if inflightReceiveRequestCount <= 0 {
+                       pc.cli.log.Infof("All inflight receive requests have been finished, clientId=%s", pc.cli.clientID)
+                       break
+               } else if time.Now().After(endTime) {
+                       pc.cli.log.Warnf("Timeout waiting for all inflight receive requests to be finished, clientId=%s, "+
+                               "inflightReceiveRequestCount=%d", pc.cli.clientID, inflightReceiveRequestCount)
+                       break
+               }
+               time.Sleep(100 * time.Millisecond)
+       }
+       return nil
+}
+
 func (pc *defaultPushConsumer) getSubscriptionTopicRouteResult(ctx context.Context, topic string) (SubscriptionLoadBalancer, error) {
        item, ok := pc.subTopicRouteDataResultCache.Load(topic)
        if ok {
@@ -595,6 +652,10 @@ func (pc *defaultPushConsumer) forwardMessageToDeadLetterQueue0(ctx context.Cont
 }
 
 func (pc *defaultPushConsumer) isRunning() bool {
+       // graceful stop in pushConsumer
+       if pc.stopping.Load() {
+               return false
+       }
        return pc.cli.isRunning()
 }
 
@@ -619,3 +680,33 @@ func (pc *defaultPushConsumer) cacheMessageBytesThresholdPerQueue() int64 {
        }
        return int64(math.Max(1, float64(pc.pcOpts.maxCacheMessageSizeInBytes)/float64(size)))
 }
+
+type defultInflightRequestCountInterceptor struct {
+       inflightReceiveRequestCount atomic.Int64
+}
+
+var _ = MessageInterceptor(&defultInflightRequestCountInterceptor{})
+
+var NewDefultInflightRequestCountInterceptor = func() *defultInflightRequestCountInterceptor {
+       return &defultInflightRequestCountInterceptor{
+               inflightReceiveRequestCount: *atomic.NewInt64(0),
+       }
+}
+
+func (dirci *defultInflightRequestCountInterceptor) doBefore(messageHookPoints MessageHookPoints, messageCommons []*MessageCommon) error {
+       if messageHookPoints == MessageHookPoints_RECEIVE {
+               dirci.inflightReceiveRequestCount.Inc()
+       }
+       return nil
+}
+
+func (dirci *defultInflightRequestCountInterceptor) doAfter(messageHookPoints MessageHookPoints, messageCommons []*MessageCommon, duration time.Duration, status MessageHookPointsStatus) error {
+       if messageHookPoints == MessageHookPoints_RECEIVE {
+               dirci.inflightReceiveRequestCount.Dec()
+       }
+       return nil
+}
+
+func (dirci *defultInflightRequestCountInterceptor) getInflightReceiveRequestCount() int64 {
+       return dirci.inflightReceiveRequestCount.Load()
+}
diff --git a/golang/simple_thread_pool.go b/golang/simple_thread_pool.go
index d4509630..8f8d51d5 100644
--- a/golang/simple_thread_pool.go
+++ b/golang/simple_thread_pool.go
@@ -17,10 +17,19 @@
 
 package golang
 
+import (
+       "sync"
+
+       "go.uber.org/atomic"
+)
+
 type simpleThreadPool struct {
-       name     string
-       tasks    chan func()
-       shutdown chan any
+       name      string
+       tasks     chan func()
+       shutdown  chan any
+       waitGroup sync.WaitGroup
+       once      sync.Once
+       running   atomic.Bool
 }
 
 func NewSimpleThreadPool(poolName string, taskSize int, threadNum int) *simpleThreadPool {
@@ -28,14 +37,21 @@ func NewSimpleThreadPool(poolName string, taskSize int, threadNum int) *simpleTh
                name:     poolName,
                tasks:    make(chan func(), taskSize),
                shutdown: make(chan any),
+               running:  *atomic.NewBool(true),
        }
        for i := 0; i < threadNum; i++ {
+               r.waitGroup.Add(1)
                go func() {
+                       defer r.waitGroup.Done()
                        tp := r
                        for {
                                select {
                                case <-tp.shutdown:
                                        sugarBaseLogger.Infof("routine pool is shutdown, name=%s", tp.name)
+                                       // complete all remaining tasks
+                                       for t := range tp.tasks {
+                                               t()
+                                       }
                                        return
                                case t := <-tp.tasks:
                                        t()
@@ -47,9 +63,24 @@ func NewSimpleThreadPool(poolName string, taskSize int, threadNum int) *simpleTh
 }
 
 func (tp *simpleThreadPool) Submit(task func()) {
+       defer func() {
+               if r := recover(); r != nil {
+                       // the running flag may have concurrency security, here is a fallback
+                       sugarBaseLogger.Warnf("recover: simple thread pool [%s], task=%v, err=%v", tp.name, task, r)
+               }
+       }()
+       if !tp.running.Load() {
+               sugarBaseLogger.Warnf("simple thread pool [%s] is not running, task=%v", tp.name, task)
+               return
+       }
        tp.tasks <- task
 }
 func (tp *simpleThreadPool) Shutdown() {
-       tp.shutdown <- 0
-       close(tp.shutdown)
+       tp.running.Store(false)
+       tp.once.Do(func() {
+               close(tp.shutdown)
+               // do not accept other task
+               close(tp.tasks)
+               tp.waitGroup.Wait()
+       })
 }

Reply via email to