This is an automated email from the ASF dual-hosted git repository.
alexstocks pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/dubbo-go.git
The following commit(s) were added to refs/heads/develop by this push:
new c356735e9 feat: graceful shutdown #1977 (#3235)
c356735e9 is described below
commit c356735e93e3982a82e02a3f91829ddbb859b82b
Author: Oxidaner <[email protected]>
AuthorDate: Sun Mar 29 21:13:09 2026 +0800
feat: graceful shutdown #1977 (#3235)
* Comments related to signal listening (信号监听相关注释)
* graceful_shutdown
* adjust fmt
* adjust fmt
* adjust fmt
* feat: graceful_shutdown
* adjust fmt
* Restore annotations and adjust fmt
* 1. Fix the type-assertion problem in tripleProtocol = tp.(*TripleProtocol)
2. Add sync.RWMutex protection; GetAllGracefulShutdownCallbacks returns a copy
rather than the original map.
* 1. Eliminate bugs that could decrement the active-request counters without a matching increment
* Add an optional interface, ClosingInstanceRemover, and implement
RemoveClosingInstance(instanceKey string) bool on RegistryDirectory
* feat: add closing tombstones for registry directory
* feat: register closing removers by service key
* feat: dispatch passive closing events to directory handlers
* feat: route closing connection errors through event handler
* feat: add health-watch active graceful notice
* refactor: isolate graceful shutdown notify sequencing
* refactor: harden graceful shutdown protocol teardown
* refactor: narrow graceful shutdown unregister phase
* feat: add lightweight active notice ack stats
* test: cover active health-watch closing flows
* Refine graceful shutdown timeout and concurrency handling
* Set default offline request window timeout
* chore: run fmt
* chore: add ASF header to graceful shutdown test
* test: fix lint issues in graceful shutdown coverage
* refactor: tighten graceful shutdown compatibility
* refactor: remove duplicate grpc start flag assignment
---
cluster/directory/directory.go | 6 +
common/constant/key.go | 3 +
common/extension/graceful_shutdown.go | 80 ++++-
common/extension/graceful_shutdown_test.go | 62 ++++
common/extension/otel_trace.go | 2 +-
compat.go | 2 +
config/graceful_shutdown_config.go | 25 +-
config/root_config.go | 4 +-
filter/graceful_shutdown/compat.go | 1 +
filter/graceful_shutdown/consumer_filter.go | 165 ++++++++++-
filter/graceful_shutdown/consumer_filter_test.go | 155 ++++++++++
filter/graceful_shutdown/provider_filter.go | 53 +++-
filter/graceful_shutdown/provider_filter_test.go | 33 +++
global/shutdown_config.go | 18 +-
graceful_shutdown/closing_ack.go | 98 ++++++
.../closing_handler.go | 37 ++-
graceful_shutdown/closing_registry.go | 153 ++++++++++
graceful_shutdown/closing_registry_test.go | 128 ++++++++
graceful_shutdown/options.go | 6 +
graceful_shutdown/options_test.go | 10 +-
graceful_shutdown/shutdown.go | 190 +++++++++++-
graceful_shutdown/shutdown_test.go | 329 +++++++++++++++++++++
internal/internal.go | 3 +
protocol/base/base_invoker.go | 10 +
protocol/base/base_protocol.go | 7 +
protocol/grpc/active_notify_test.go | 139 +++++++++
protocol/grpc/grpc_invoker.go | 63 +++-
protocol/grpc/grpc_protocol.go | 54 +++-
protocol/grpc/grpc_protocol_test.go | 58 ++++
protocol/grpc/server.go | 38 ++-
protocol/triple/active_notify_test.go | 154 ++++++++++
protocol/triple/client.go | 29 +-
protocol/triple/health/healthServer.go | 1 +
protocol/triple/triple.go | 45 ++-
protocol/triple/triple_invoker.go | 64 +++-
protocol/triple/triple_test.go | 78 +++++
registry/directory/directory.go | 129 +++++++-
registry/directory/directory_test.go | 104 +++++++
registry/protocol/protocol.go | 13 +
39 files changed, 2461 insertions(+), 88 deletions(-)
diff --git a/cluster/directory/directory.go b/cluster/directory/directory.go
index 1f84c6a5e..ce2f48861 100644
--- a/cluster/directory/directory.go
+++ b/cluster/directory/directory.go
@@ -35,3 +35,9 @@ type Directory interface {
// Subscribe listen to registry instances
Subscribe(url *common.URL) error
}
+
+// ClosingInstanceRemover is an optional directory capability used by graceful
shutdown.
+// Implementations remove a single cached service instance identified by
instanceKey.
+type ClosingInstanceRemover interface {
+ RemoveClosingInstance(instanceKey string) bool
+}
diff --git a/common/constant/key.go b/common/constant/key.go
index b6d32a88c..98aab2c65 100644
--- a/common/constant/key.go
+++ b/common/constant/key.go
@@ -109,6 +109,9 @@ const (
GracefulShutdownProviderFilterKey = "pshutdown"
GracefulShutdownConsumerFilterKey = "cshutdown"
GracefulShutdownFilterShutdownConfig =
"GracefulShutdownFilterShutdownConfig"
+ GracefulShutdownClosingKey = "closing"
+ HystrixConsumerFilterKey = "hystrix_consumer"
+ HystrixProviderFilterKey = "hystrix_provider"
MetricsFilterKey = "metrics"
SeataFilterKey = "seata"
SentinelProviderFilterKey = "sentinel-provider"
diff --git a/common/extension/graceful_shutdown.go
b/common/extension/graceful_shutdown.go
index 1de3e583f..253b43c6f 100644
--- a/common/extension/graceful_shutdown.go
+++ b/common/extension/graceful_shutdown.go
@@ -19,11 +19,26 @@ package extension
import (
"container/list"
+ "context"
"sync"
)
-var customShutdownCallbacks = list.New()
-var customShutdownCallbacksLock sync.Mutex
+import (
+ "github.com/dubbogo/gost/log/logger"
+)
+
+// GracefulShutdownCallback is the callback for graceful shutdown
+// name: protocol name such as "grpc", "tri", "dubbo"
+// returns error if notify failed
+type GracefulShutdownCallback func(ctx context.Context) error
+
+var (
+ customShutdownCallbacks = list.New()
+ customShutdownCallbacksLock sync.RWMutex
+ customShutdownCallbacksMu = &customShutdownCallbacksLock
+ gracefulShutdownCallbacksMu sync.RWMutex
+ gracefulShutdownCallbacks = make(map[string]GracefulShutdownCallback)
+)
/**
* AddCustomShutdownCallback
@@ -45,21 +60,62 @@ var customShutdownCallbacksLock sync.Mutex
* the benefit of that mechanism is low.
* And it may introduce much complication for another users.
*/
+// AddCustomShutdownCallback adds custom shutdown callback
func AddCustomShutdownCallback(callback func()) {
- customShutdownCallbacksLock.Lock()
- defer customShutdownCallbacksLock.Unlock()
-
+ customShutdownCallbacksMu.Lock()
+ defer customShutdownCallbacksMu.Unlock()
customShutdownCallbacks.PushBack(callback)
}
-// GetAllCustomShutdownCallbacks gets all custom shutdown callbacks
+// GetAllCustomShutdownCallbacks returns all custom shutdown callbacks
func GetAllCustomShutdownCallbacks() *list.List {
- customShutdownCallbacksLock.Lock()
- defer customShutdownCallbacksLock.Unlock()
+ customShutdownCallbacksMu.RLock()
+ defer customShutdownCallbacksMu.RUnlock()
+
+ callbacks := list.New()
+ for callback := customShutdownCallbacks.Front(); callback != nil;
callback = callback.Next() {
+ callbacks.PushBack(callback.Value)
+ }
+ return callbacks
+}
+
+// RegisterGracefulShutdownCallback registers a protocol-level graceful
shutdown callback.
+func RegisterGracefulShutdownCallback(name string, f GracefulShutdownCallback)
{
+ gracefulShutdownCallbacksMu.Lock()
+ defer gracefulShutdownCallbacksMu.Unlock()
- ret := list.New()
- for e := customShutdownCallbacks.Front(); e != nil; e = e.Next() {
- ret.PushBack(e.Value)
+ if _, exists := gracefulShutdownCallbacks[name]; exists {
+ logger.Warnf("graceful shutdown callback %q already registered,
duplicate registration ignored", name)
+ return
}
- return ret
+
+ gracefulShutdownCallbacks[name] = f
+}
+
+// LookupGracefulShutdownCallback returns a protocol graceful shutdown
callback by name.
+func LookupGracefulShutdownCallback(name string) (GracefulShutdownCallback,
bool) {
+ gracefulShutdownCallbacksMu.RLock()
+ defer gracefulShutdownCallbacksMu.RUnlock()
+ f, ok := gracefulShutdownCallbacks[name]
+ return f, ok
+}
+
+// UnregisterGracefulShutdownCallback removes a protocol graceful shutdown
callback by name.
+func UnregisterGracefulShutdownCallback(name string) {
+ gracefulShutdownCallbacksMu.Lock()
+ defer gracefulShutdownCallbacksMu.Unlock()
+ delete(gracefulShutdownCallbacks, name)
+}
+
+// GracefulShutdownCallbacks returns a snapshot of all protocol graceful
shutdown callbacks.
+func GracefulShutdownCallbacks() map[string]GracefulShutdownCallback {
+ gracefulShutdownCallbacksMu.RLock()
+ defer gracefulShutdownCallbacksMu.RUnlock()
+
+ callbacks := make(map[string]GracefulShutdownCallback,
len(gracefulShutdownCallbacks))
+ for name, callback := range gracefulShutdownCallbacks {
+ callbacks[name] = callback
+ }
+
+ return callbacks
}
diff --git a/common/extension/graceful_shutdown_test.go
b/common/extension/graceful_shutdown_test.go
new file mode 100644
index 000000000..311530936
--- /dev/null
+++ b/common/extension/graceful_shutdown_test.go
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package extension
+
+import (
+ "container/list"
+ "context"
+ "testing"
+)
+
+func TestGracefulShutdownCallbacksReturnsSnapshot(t *testing.T) {
+ t.Cleanup(func() {
+ UnregisterGracefulShutdownCallback("snapshot-test")
+ })
+
+ RegisterGracefulShutdownCallback("snapshot-test", func(context.Context)
error {
+ return nil
+ })
+
+ callbacks := GracefulShutdownCallbacks()
+ delete(callbacks, "snapshot-test")
+
+ if _, ok := LookupGracefulShutdownCallback("snapshot-test"); !ok {
+ t.Fatal("expected stored callback to remain after mutating
returned snapshot")
+ }
+}
+
+func TestGetAllCustomShutdownCallbacksReturnsSnapshot(t *testing.T) {
+ customShutdownCallbacksMu.Lock()
+ original := customShutdownCallbacks
+ customShutdownCallbacks = list.New()
+ customShutdownCallbacksMu.Unlock()
+
+ t.Cleanup(func() {
+ customShutdownCallbacksMu.Lock()
+ customShutdownCallbacks = original
+ customShutdownCallbacksMu.Unlock()
+ })
+
+ AddCustomShutdownCallback(func() {})
+ callbacks := GetAllCustomShutdownCallbacks()
+ callbacks.PushBack(func() {})
+
+ if got := GetAllCustomShutdownCallbacks().Len(); got != 1 {
+ t.Fatalf("expected custom callback snapshot mutation not to
affect stored callbacks, got %d", got)
+ }
+}
diff --git a/common/extension/otel_trace.go b/common/extension/otel_trace.go
index e84ed3981..447ff8a8f 100644
--- a/common/extension/otel_trace.go
+++ b/common/extension/otel_trace.go
@@ -48,7 +48,7 @@ func GetTraceShutdownCallback() func() {
for name, createFunc := range traceExporterMap.Snapshot() {
if exporter, err := createFunc(nil); err == nil {
if err :=
exporter.GetTracerProvider().Shutdown(context.Background()); err != nil {
- logger.Errorf("Graceful shutdown ---
Failed to shutdown trace provider %s, error: %s", name, err.Error())
+ logger.Errorf("Graceful shutdown ---
Failed to shutdown trace provider %s, error --- %s", name, err.Error())
} else {
logger.Infof("Graceful shutdown ---
Tracer provider of %s", name)
}
diff --git a/compat.go b/compat.go
index 5af9d3913..013e54f78 100644
--- a/compat.go
+++ b/compat.go
@@ -447,6 +447,7 @@ func compatShutdownConfig(c *global.ShutdownConfig)
*config.ShutdownConfig {
cfg := &config.ShutdownConfig{
Timeout: c.Timeout,
StepTimeout: c.StepTimeout,
+ NotifyTimeout: c.NotifyTimeout,
ConsumerUpdateWaitTime: c.ConsumerUpdateWaitTime,
RejectRequestHandler: c.RejectRequestHandler,
InternalSignal: c.InternalSignal,
@@ -1043,6 +1044,7 @@ func compatGlobalShutdownConfig(c *config.ShutdownConfig)
*global.ShutdownConfig
cfg := &global.ShutdownConfig{
Timeout: c.Timeout,
StepTimeout: c.StepTimeout,
+ NotifyTimeout: c.NotifyTimeout,
ConsumerUpdateWaitTime: c.ConsumerUpdateWaitTime,
RejectRequestHandler: c.RejectRequestHandler,
InternalSignal: c.InternalSignal,
diff --git a/config/graceful_shutdown_config.go
b/config/graceful_shutdown_config.go
index 044087be7..a905ab6a9 100644
--- a/config/graceful_shutdown_config.go
+++ b/config/graceful_shutdown_config.go
@@ -36,6 +36,7 @@ import (
const (
defaultTimeout = 60 * time.Second
defaultStepTimeout = 3 * time.Second
+ defaultNotifyTimeout = 5 * time.Second
defaultConsumerUpdateWaitTime = 3 * time.Second
defaultOfflineRequestWindowTimeout = 3 * time.Second
)
@@ -58,6 +59,13 @@ type ShutdownConfig struct {
*/
StepTimeout string `default:"3s" yaml:"step-timeout"
json:"step.timeout,omitempty" property:"step.timeout"`
+ /*
+ * NotifyTimeout means the timeout budget for actively notifying
long-connection consumers
+ * during graceful shutdown. It only controls the notify step and
should not be coupled to
+ * request draining timeouts.
+ */
+ NotifyTimeout string `default:"5s" yaml:"notify-timeout"
json:"notify.timeout,omitempty" property:"notify.timeout"`
+
/*
* ConsumerUpdateWaitTime means when provider is shutting down, after
the unregister, time to wait for client to
* update invokers. During this time, incoming invocation can be
treated normally.
@@ -68,7 +76,7 @@ type ShutdownConfig struct {
// internal listen kill signal,the default is true.
InternalSignal *bool `default:"true" yaml:"internal-signal"
json:"internal.signal,omitempty" property:"internal.signal"`
// offline request window length
- OfflineRequestWindowTimeout string
`yaml:"offline-request-window-timeout"
json:"offlineRequestWindowTimeout,omitempty"
property:"offlineRequestWindowTimeout"`
+ OfflineRequestWindowTimeout string `default:"3s"
yaml:"offline-request-window-timeout"
json:"offlineRequestWindowTimeout,omitempty"
property:"offlineRequestWindowTimeout"`
// true -> new request will be rejected.
RejectRequest atomic.Bool
// active invocation
@@ -104,6 +112,16 @@ func (config *ShutdownConfig) GetStepTimeout()
time.Duration {
return result
}
+func (config *ShutdownConfig) GetNotifyTimeout() time.Duration {
+ result, err := time.ParseDuration(config.NotifyTimeout)
+ if err != nil {
+ logger.Errorf("The NotifyTimeout configuration is invalid: %s,
and we will use the default value: %s, err: %v",
+ config.NotifyTimeout, defaultNotifyTimeout.String(),
err)
+ return defaultNotifyTimeout
+ }
+ return result
+}
+
func (config *ShutdownConfig) GetOfflineRequestWindowTimeout() time.Duration {
result, err := time.ParseDuration(config.OfflineRequestWindowTimeout)
if err != nil {
@@ -153,6 +171,11 @@ func (scb *ShutdownConfigBuilder)
SetStepTimeout(stepTimeout string) *ShutdownCo
return scb
}
+func (scb *ShutdownConfigBuilder) SetNotifyTimeout(notifyTimeout string)
*ShutdownConfigBuilder {
+ scb.shutdownConfig.NotifyTimeout = notifyTimeout
+ return scb
+}
+
func (scb *ShutdownConfigBuilder) SetRejectRequestHandler(rejectRequestHandler
string) *ShutdownConfigBuilder {
scb.shutdownConfig.RejectRequestHandler = rejectRequestHandler
return scb
diff --git a/config/root_config.go b/config/root_config.go
index 18934826e..edbe7317b 100644
--- a/config/root_config.go
+++ b/config/root_config.go
@@ -135,7 +135,6 @@ func (rc *RootConfig) Init() error {
return err
}
}
-
if err := rc.Application.Init(); err != nil {
return err
}
@@ -175,6 +174,7 @@ func (rc *RootConfig) Init() error {
if err := rc.MetadataReport.Init(rc); err != nil {
return err
}
+
if err := rc.Otel.Init(rc.Application); err != nil {
return err
}
@@ -189,6 +189,7 @@ func (rc *RootConfig) Init() error {
if err := initRouterConfig(rc); err != nil {
return err
}
+
// provider、consumer must last init
if err := rc.Provider.Init(rc); err != nil {
return err
@@ -208,6 +209,7 @@ func (rc *RootConfig) Init() error {
func (rc *RootConfig) Start() {
startOnce.Do(func() {
gracefulShutdownInit()
+
rc.Consumer.Load()
rc.Provider.Load()
if err := initMetadata(rc); err != nil {
diff --git a/filter/graceful_shutdown/compat.go
b/filter/graceful_shutdown/compat.go
index ef9dc16e0..0bdffbbe2 100644
--- a/filter/graceful_shutdown/compat.go
+++ b/filter/graceful_shutdown/compat.go
@@ -33,6 +33,7 @@ func compatGlobalShutdownConfig(c *config.ShutdownConfig)
*global.ShutdownConfig
cfg := &global.ShutdownConfig{
Timeout: c.Timeout,
StepTimeout: c.StepTimeout,
+ NotifyTimeout: c.NotifyTimeout,
ConsumerUpdateWaitTime: c.ConsumerUpdateWaitTime,
RejectRequestHandler: c.RejectRequestHandler,
InternalSignal: c.InternalSignal,
diff --git a/filter/graceful_shutdown/consumer_filter.go
b/filter/graceful_shutdown/consumer_filter.go
index de5ca3148..e0dfce8a6 100644
--- a/filter/graceful_shutdown/consumer_filter.go
+++ b/filter/graceful_shutdown/consumer_filter.go
@@ -19,11 +19,18 @@ package graceful_shutdown
import (
"context"
+ "errors"
+ "strings"
"sync"
+ "time"
)
import (
"github.com/dubbogo/gost/log/logger"
+
+ "google.golang.org/grpc/codes"
+
+ "google.golang.org/grpc/status"
)
import (
@@ -32,6 +39,7 @@ import (
"dubbo.apache.org/dubbo-go/v3/config"
"dubbo.apache.org/dubbo-go/v3/filter"
"dubbo.apache.org/dubbo-go/v3/global"
+ gracefulshutdown "dubbo.apache.org/dubbo-go/v3/graceful_shutdown"
"dubbo.apache.org/dubbo-go/v3/protocol/base"
"dubbo.apache.org/dubbo-go/v3/protocol/result"
)
@@ -50,27 +58,61 @@ func init() {
}
type consumerGracefulShutdownFilter struct {
- shutdownConfig *global.ShutdownConfig
+ shutdownConfig *global.ShutdownConfig
+ closingEventHandler gracefulshutdown.ClosingEventHandler
+ closingInvokers sync.Map // map[string]time.Time (url key -> expire
time)
}
+const consumerCountMarkedKey = "dubbo-go-graceful-shutdown-consumer-counted"
+
func newConsumerGracefulShutdownFilter() filter.Filter {
if csf == nil {
csfOnce.Do(func() {
- csf = &consumerGracefulShutdownFilter{}
+ csf = &consumerGracefulShutdownFilter{
+ closingEventHandler:
gracefulshutdown.DefaultClosingEventHandler(),
+ }
})
}
return csf
}
-// Invoke adds the requests count and block the new requests if application is
closing
+// Invoke adds the requests count and checks if invoker is closing
func (f *consumerGracefulShutdownFilter) Invoke(ctx context.Context, invoker
base.Invoker, invocation base.Invocation) result.Result {
- f.shutdownConfig.ConsumerActiveCount.Inc()
- return invoker.Invoke(ctx, invocation)
+ // check if invoker is closing
+ if f.isClosingInvoker(invoker) {
+ logger.Warnf("Graceful shutdown --- Skipping closing invoker
--- %s", invoker.GetURL().String())
+ return &result.RPCResult{Err: errors.New("provider is closing")}
+ }
+
+ if f.shutdownConfig != nil {
+ f.shutdownConfig.ConsumerActiveCount.Inc()
+ }
+
+ res := invoker.Invoke(ctx, invocation)
+ if f.shutdownConfig == nil {
+ return res
+ }
+
+ return markCountedResult(res)
}
// OnResponse reduces the number of active processes then return the process
result
func (f *consumerGracefulShutdownFilter) OnResponse(ctx context.Context,
result result.Result, invoker base.Invoker, invocation base.Invocation)
result.Result {
- f.shutdownConfig.ConsumerActiveCount.Dec()
+ if f.shutdownConfig != nil && shouldDecrementConsumerActive(result) {
+ f.shutdownConfig.ConsumerActiveCount.Dec()
+ }
+
+ // check closing flag in response
+ if f.isClosingResponse(result) {
+ f.markClosingInvoker(invoker)
+ f.handleClosingEvent(invoker, "passive-attachment")
+ }
+
+ // handle request error
+ if result.Error() != nil {
+ f.handleRequestError(invoker, result.Error())
+ }
+
return result
}
@@ -91,3 +133,114 @@ func (f *consumerGracefulShutdownFilter) Set(name string,
conf any) {
// do nothing
}
}
+
+// isClosingInvoker checks if invoker is in closing list
+func (f *consumerGracefulShutdownFilter) isClosingInvoker(invoker
base.Invoker) bool {
+ key := invoker.GetURL().String()
+ if expireTime, ok := f.closingInvokers.Load(key); ok {
+ if time.Now().Before(expireTime.(time.Time)) {
+ return true
+ }
+ f.closingInvokers.Delete(key)
+ if setter, ok := invoker.(base.AvailabilitySetter); ok {
+ setter.SetAvailable(true)
+ logger.Infof("Graceful shutdown --- Recovered invoker
availability after closing TTL --- %s", key)
+ }
+ }
+ return false
+}
+
+// isClosingResponse checks if response contains closing flag
+func (f *consumerGracefulShutdownFilter) isClosingResponse(result
result.Result) bool {
+ if result != nil && result.Attachments() != nil {
+ if v, ok :=
result.Attachments()[constant.GracefulShutdownClosingKey]; ok {
+ if v == "true" {
+ return true
+ }
+ }
+ }
+ return false
+}
+
+// markClosingInvoker marks invoker as closing and sets available=false
+func (f *consumerGracefulShutdownFilter) markClosingInvoker(invoker
base.Invoker) {
+ key := invoker.GetURL().String()
+ expireTime := time.Now().Add(f.getClosingInvokerExpireTime())
+ f.closingInvokers.Store(key, expireTime)
+
+ logger.Infof("Graceful shutdown --- Marked invoker as closing --- %s,
will expire at %v, IsAvailable=%v",
+ key, expireTime, invoker.IsAvailable())
+
+ if setter, ok := invoker.(base.AvailabilitySetter); ok {
+ setter.SetAvailable(false)
+ logger.Infof("Graceful shutdown --- Set invoker unavailable ---
%s, IsAvailable now=%v",
+ key, invoker.IsAvailable())
+ }
+}
+
+func (f *consumerGracefulShutdownFilter) handleClosingEvent(invoker
base.Invoker, source string) {
+ if f.closingEventHandler == nil || invoker == nil || invoker.GetURL()
== nil {
+ return
+ }
+
+ f.closingEventHandler.HandleClosingEvent(gracefulshutdown.ClosingEvent{
+ Source: source,
+ InstanceKey: invoker.GetURL().GetCacheInvokerMapKey(),
+ ServiceKey: invoker.GetURL().ServiceKey(),
+ Address: invoker.GetURL().Location,
+ })
+}
+
+func (f *consumerGracefulShutdownFilter) getClosingInvokerExpireTime()
time.Duration {
+ if f.shutdownConfig != nil && f.shutdownConfig.ClosingInvokerExpireTime
!= "" {
+ if duration, err :=
time.ParseDuration(f.shutdownConfig.ClosingInvokerExpireTime); err == nil &&
duration > 0 {
+ return duration
+ }
+ }
+ return 30 * time.Second
+}
+
+// handleRequestError handles request errors and marks invoker as unavailable
for connection errors
+func (f *consumerGracefulShutdownFilter) handleRequestError(invoker
base.Invoker, err error) {
+ if err == nil {
+ return
+ }
+
+ if isClosingError(err) {
+ f.markClosingInvoker(invoker)
+ f.handleClosingEvent(invoker, "connection-closing-error")
+ }
+}
+
+func isClosingError(err error) bool {
+ if errors.Is(err, base.ErrClientClosed) || errors.Is(err,
base.ErrDestroyedInvoker) {
+ return true
+ }
+
+ if grpcStatus, ok := status.FromError(err); ok {
+ switch grpcStatus.Code() {
+ case codes.Unavailable, codes.Canceled:
+ return true
+ }
+ }
+
+ errMsg := strings.ToLower(err.Error())
+ return strings.Contains(errMsg, "transport is closing") ||
+ strings.Contains(errMsg, "client connection is closing")
+}
+
+func markCountedResult(res result.Result) result.Result {
+ if res == nil {
+ res = &result.RPCResult{}
+ }
+ res.AddAttachment(consumerCountMarkedKey, true)
+ return res
+}
+
+func shouldDecrementConsumerActive(res result.Result) bool {
+ if res == nil {
+ return false
+ }
+ marked, ok := res.Attachment(consumerCountMarkedKey, false).(bool)
+ return ok && marked
+}
diff --git a/filter/graceful_shutdown/consumer_filter_test.go
b/filter/graceful_shutdown/consumer_filter_test.go
index a1e2b9457..910d02956 100644
--- a/filter/graceful_shutdown/consumer_filter_test.go
+++ b/filter/graceful_shutdown/consumer_filter_test.go
@@ -19,12 +19,19 @@ package graceful_shutdown
import (
"context"
+ "errors"
"net/url"
"testing"
+ "time"
)
import (
"github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ "google.golang.org/grpc/codes"
+
+ "google.golang.org/grpc/status"
)
import (
@@ -34,8 +41,32 @@ import (
"dubbo.apache.org/dubbo-go/v3/graceful_shutdown"
"dubbo.apache.org/dubbo-go/v3/protocol/base"
"dubbo.apache.org/dubbo-go/v3/protocol/invocation"
+ "dubbo.apache.org/dubbo-go/v3/protocol/result"
)
+type testEmbeddedInvoker struct {
+ base.BaseInvoker
+}
+
+type testClosingEventHandler struct {
+ events []graceful_shutdown.ClosingEvent
+}
+
+func (h *testClosingEventHandler) HandleClosingEvent(event
graceful_shutdown.ClosingEvent) bool {
+ h.events = append(h.events, event)
+ return true
+}
+
+func newTestEmbeddedInvoker(rawURL *common.URL) *testEmbeddedInvoker {
+ return &testEmbeddedInvoker{
+ BaseInvoker: *base.NewBaseInvoker(rawURL),
+ }
+}
+
+func (i *testEmbeddedInvoker) Invoke(ctx context.Context, invocation
base.Invocation) result.Result {
+ return &result.RPCResult{}
+}
+
func TestConsumerFilterInvokeWithGlobalPackage(t *testing.T) {
var (
baseUrl =
common.NewURLWithOptions(common.WithParams(url.Values{}))
@@ -54,3 +85,127 @@ func TestConsumerFilterInvokeWithGlobalPackage(t
*testing.T) {
assert.NotNil(t, result)
assert.NoError(t, result.Error())
}
+
+func TestIsClosingError(t *testing.T) {
+ assert.True(t, isClosingError(base.ErrClientClosed))
+ assert.True(t, isClosingError(status.Error(codes.Unavailable, "server
shutting down")))
+ assert.True(t, isClosingError(status.Error(codes.Canceled, "request
canceled during shutdown")))
+ assert.True(t, isClosingError(errors.New("rpc error: code = Unavailable
desc = transport is closing")))
+ assert.False(t, isClosingError(errors.New("EOF")))
+ assert.False(t, isClosingError(errors.New("read tcp: connection reset
by peer")))
+}
+
+func TestMarkClosingInvokerSetsEmbeddedInvokerUnavailable(t *testing.T) {
+ filter := &consumerGracefulShutdownFilter{
+ shutdownConfig: graceful_shutdown.NewOptions().Shutdown,
+ }
+ invoker :=
newTestEmbeddedInvoker(common.NewURLWithOptions(common.WithParams(url.Values{})))
+
+ assert.True(t, invoker.IsAvailable())
+
+ filter.markClosingInvoker(invoker)
+
+ assert.False(t, invoker.IsAvailable())
+ expireTime, ok := filter.closingInvokers.Load(invoker.GetURL().String())
+ assert.True(t, ok)
+ assert.True(t, expireTime.(time.Time).After(time.Now()))
+}
+
+func TestConsumerFilterDoesNotDecrementWithoutIncrement(t *testing.T) {
+ filter := &consumerGracefulShutdownFilter{
+ shutdownConfig: graceful_shutdown.NewOptions().Shutdown,
+ }
+ invoker :=
newTestEmbeddedInvoker(common.NewURLWithOptions(common.WithParams(url.Values{})))
+ rpcInvocation := invocation.NewRPCInvocation("GetUser", []any{"OK"},
make(map[string]any))
+
+ filter.markClosingInvoker(invoker)
+
+ res := filter.Invoke(context.Background(), invoker, rpcInvocation)
+ require.Error(t, res.Error())
+ assert.Equal(t, "provider is closing", res.Error().Error())
+
+ filter.OnResponse(context.Background(), res, invoker, rpcInvocation)
+ assert.Equal(t, int32(0),
filter.shutdownConfig.ConsumerActiveCount.Load())
+}
+
+func TestClosingInvokerExpiryRestoresAvailability(t *testing.T) {
+ opt := graceful_shutdown.NewOptions()
+ opt.Shutdown.ClosingInvokerExpireTime = "20ms"
+
+ filter := &consumerGracefulShutdownFilter{
+ shutdownConfig: opt.Shutdown,
+ }
+ invoker :=
newTestEmbeddedInvoker(common.NewURLWithOptions(common.WithParams(url.Values{})))
+
+ filter.markClosingInvoker(invoker)
+ assert.False(t, invoker.IsAvailable())
+ assert.True(t, filter.isClosingInvoker(invoker))
+
+ time.Sleep(40 * time.Millisecond)
+ assert.False(t, filter.isClosingInvoker(invoker))
+ assert.True(t, invoker.IsAvailable())
+}
+
+func TestHandleRequestErrorDoesNotMarkNonClosingErrors(t *testing.T) {
+ filter := &consumerGracefulShutdownFilter{
+ shutdownConfig: graceful_shutdown.NewOptions().Shutdown,
+ }
+ invoker :=
newTestEmbeddedInvoker(common.NewURLWithOptions(common.WithParams(url.Values{})))
+
+ filter.handleRequestError(invoker, errors.New("EOF"))
+ _, ok := filter.closingInvokers.Load(invoker.GetURL().String())
+ assert.False(t, ok)
+ assert.True(t, invoker.IsAvailable())
+
+ filter.handleRequestError(invoker, errors.New("read tcp: connection
reset by peer"))
+ _, ok = filter.closingInvokers.Load(invoker.GetURL().String())
+ assert.False(t, ok)
+ assert.True(t, invoker.IsAvailable())
+}
+
+func TestClosingResponseDispatchesClosingEvent(t *testing.T) {
+ handler := &testClosingEventHandler{}
+ filter := &consumerGracefulShutdownFilter{
+ shutdownConfig: graceful_shutdown.NewOptions().Shutdown,
+ closingEventHandler: handler,
+ }
+ invokerURL, _ := common.NewURL(
+ "dubbo://127.0.0.1:20000/org.apache.dubbo-go.mockService",
+ common.WithParamsValue(constant.GroupKey, "group"),
+ common.WithParamsValue(constant.VersionKey, "1.0.0"),
+ )
+ invoker := newTestEmbeddedInvoker(invokerURL)
+ res := &result.RPCResult{}
+ res.AddAttachment(constant.GracefulShutdownClosingKey, "true")
+
+ filter.OnResponse(context.Background(), res, invoker,
invocation.NewRPCInvocation("GetUser", []any{"OK"}, map[string]any{}))
+
+ if assert.Len(t, handler.events, 1) {
+ assert.Equal(t, invokerURL.GetCacheInvokerMapKey(),
handler.events[0].InstanceKey)
+ assert.Equal(t, invokerURL.ServiceKey(),
handler.events[0].ServiceKey)
+ assert.Equal(t, invokerURL.Location, handler.events[0].Address)
+ assert.Equal(t, "passive-attachment", handler.events[0].Source)
+ }
+}
+
+func TestClosingErrorDispatchesClosingEvent(t *testing.T) {
+ handler := &testClosingEventHandler{}
+ filter := &consumerGracefulShutdownFilter{
+ shutdownConfig: graceful_shutdown.NewOptions().Shutdown,
+ closingEventHandler: handler,
+ }
+ invokerURL, _ := common.NewURL(
+ "dubbo://127.0.0.1:20000/org.apache.dubbo-go.mockService",
+ common.WithParamsValue(constant.GroupKey, "group"),
+ common.WithParamsValue(constant.VersionKey, "1.0.0"),
+ )
+ invoker := newTestEmbeddedInvoker(invokerURL)
+
+ filter.handleRequestError(invoker, errors.New("rpc error: code =
Unavailable desc = transport is closing"))
+
+ if assert.Len(t, handler.events, 1) {
+ assert.Equal(t, invokerURL.GetCacheInvokerMapKey(),
handler.events[0].InstanceKey)
+ assert.Equal(t, invokerURL.ServiceKey(),
handler.events[0].ServiceKey)
+ assert.Equal(t, "connection-closing-error",
handler.events[0].Source)
+ }
+}
diff --git a/filter/graceful_shutdown/provider_filter.go
b/filter/graceful_shutdown/provider_filter.go
index 9a9b40bba..76bff152c 100644
--- a/filter/graceful_shutdown/provider_filter.go
+++ b/filter/graceful_shutdown/provider_filter.go
@@ -43,7 +43,6 @@ var (
)
func init() {
- // `init()` is performed before config.Load(), so shutdownConfig will
be retrieved after config was loaded.
extension.SetFilter(constant.GracefulShutdownProviderFilterKey, func()
filter.Filter {
return newProviderGracefulShutdownFilter()
})
@@ -53,6 +52,8 @@ type providerGracefulShutdownFilter struct {
shutdownConfig *global.ShutdownConfig
}
+const providerCountMarkedKey = "dubbo-go-graceful-shutdown-provider-counted"
+
func newProviderGracefulShutdownFilter() filter.Filter {
if psf == nil {
psfOnce.Do(func() {
@@ -64,6 +65,10 @@ func newProviderGracefulShutdownFilter() filter.Filter {
// Invoke adds the requests count and blocks the new requests if application
is closing
func (f *providerGracefulShutdownFilter) Invoke(ctx context.Context, invoker
base.Invoker, invocation base.Invocation) result.Result {
+ if f.shutdownConfig == nil {
+ return invoker.Invoke(ctx, invocation)
+ }
+
if f.rejectNewRequest() {
logger.Info("The application is closing, new request will be
rejected.")
handler := constant.DefaultKey
@@ -79,13 +84,28 @@ func (f *providerGracefulShutdownFilter) Invoke(ctx
context.Context, invoker bas
}
f.shutdownConfig.ProviderActiveCount.Inc()
f.shutdownConfig.ProviderLastReceivedRequestTime.Store(time.Now())
- return invoker.Invoke(ctx, invocation)
+ return markProviderCountedResult(invoker.Invoke(ctx, invocation))
}
// OnResponse reduces the number of active processes then return the process
result
-func (f *providerGracefulShutdownFilter) OnResponse(ctx context.Context,
result result.Result, invoker base.Invoker, invocation base.Invocation)
result.Result {
- f.shutdownConfig.ProviderActiveCount.Dec()
- return result
+func (f *providerGracefulShutdownFilter) OnResponse(ctx context.Context, res
result.Result, invoker base.Invoker, invocation base.Invocation) result.Result {
+ if f.shutdownConfig == nil {
+ return res
+ }
+
+ if shouldDecrementProviderActive(res) {
+ f.shutdownConfig.ProviderActiveCount.Dec()
+ }
+
+ // add closing flag to response
+ if f.isClosing() {
+ if res == nil {
+ res = &result.RPCResult{}
+ }
+ res.AddAttachment(constant.GracefulShutdownClosingKey, "true")
+ }
+
+ return res
}
func (f *providerGracefulShutdownFilter) Set(name string, conf any) {
@@ -112,3 +132,26 @@ func (f *providerGracefulShutdownFilter)
rejectNewRequest() bool {
}
return f.shutdownConfig.RejectRequest.Load()
}
+
+// isClosing reports whether the provider has entered the graceful-shutdown
+// closing phase. A filter that was never given a shutdown config is never closing.
+func (f *providerGracefulShutdownFilter) isClosing() bool {
+	cfg := f.shutdownConfig
+	return cfg != nil && cfg.Closing.Load()
+}
+
+// markProviderCountedResult tags res with providerCountMarkedKey so OnResponse
+// knows this request incremented ProviderActiveCount. A nil result is replaced by
+// an empty RPCResult so the marker always has somewhere to live.
+func markProviderCountedResult(res result.Result) result.Result {
+	if res == nil {
+		res = &result.RPCResult{}
+	}
+	res.AddAttachment(providerCountMarkedKey, true)
+	return res
+}
+
+// shouldDecrementProviderActive reports whether res carries the counted marker
+// set by markProviderCountedResult, and consumes the marker when it does.
+//
+// Consuming the marker makes the decrement happen at most once per result, even
+// if OnResponse were to observe the same result twice. It also removes this
+// internal bookkeeping key from the result's attachments, which would otherwise
+// stay there (and may be propagated back to the consumer, depending on protocol).
+func shouldDecrementProviderActive(res result.Result) bool {
+	if res == nil {
+		return false
+	}
+	marked, ok := res.Attachment(providerCountMarkedKey, false).(bool)
+	if !ok || !marked {
+		return false
+	}
+	// NOTE(review): assumes Attachments() exposes the live attachment map (as
+	// RPCResult does). If an implementation returns a copy, this delete is a
+	// harmless no-op and behavior matches the previous version.
+	delete(res.Attachments(), providerCountMarkedKey)
+	return true
+}
diff --git a/filter/graceful_shutdown/provider_filter_test.go
b/filter/graceful_shutdown/provider_filter_test.go
index 4863b1f5e..393b61261 100644
--- a/filter/graceful_shutdown/provider_filter_test.go
+++ b/filter/graceful_shutdown/provider_filter_test.go
@@ -69,6 +69,25 @@ func TestProviderFilterInvokeWithGlobalPackage(t *testing.T)
{
assert.NotNil(t, invokeResult.Error().Error(), "Rejected")
}
+// TestProviderFilterOnResponseDoesNotDecrementRejectedRequest verifies that a
+// request rejected during shutdown is never counted, so OnResponse must leave
+// ProviderActiveCount at zero.
+func TestProviderFilterOnResponseDoesNotDecrementRejectedRequest(t *testing.T) {
+	const handlerName = "provider-count-test"
+	extension.SetRejectedExecutionHandler(handlerName, func() filter.RejectedExecutionHandler {
+		return &TestRejectedExecutionHandler{}
+	})
+
+	opt := graceful_shutdown.NewOptions()
+	opt.Shutdown.RejectRequestHandler = handlerName
+
+	f := newProviderGracefulShutdownFilter().(*providerGracefulShutdownFilter)
+	f.Set(constant.GracefulShutdownFilterShutdownConfig, opt.Shutdown)
+	opt.Shutdown.RejectRequest.Store(true)
+
+	inv := invocation.NewRPCInvocation("GetUser", []any{"OK"}, make(map[string]any))
+	u := common.NewURLWithOptions(common.WithParams(url.Values{}))
+
+	res := f.Invoke(context.Background(), base.NewBaseInvoker(u), inv)
+	f.OnResponse(context.Background(), res, base.NewBaseInvoker(u), inv)
+
+	assert.Equal(t, int32(0), opt.Shutdown.ProviderActiveCount.Load())
+}
+
type TestRejectedExecutionHandler struct{}
// RejectedExecution will do nothing, it only log the invocation.
@@ -79,3 +98,17 @@ func (handler *TestRejectedExecutionHandler)
RejectedExecution(url *common.URL,
Err: perrors.New("Rejected"),
}
}
+
+// TestProviderFilterWithoutShutdownConfigPassThrough checks that a filter with no
+// shutdown config simply delegates to the invoker and tolerates OnResponse.
+func TestProviderFilterWithoutShutdownConfigPassThrough(t *testing.T) {
+	f := &providerGracefulShutdownFilter{}
+	u := common.NewURLWithOptions(common.WithParams(url.Values{}))
+	inv := invocation.NewRPCInvocation("GetUser", []any{"OK"}, make(map[string]any))
+
+	res := f.Invoke(context.Background(), base.NewBaseInvoker(u), inv)
+	assert.NotNil(t, res)
+	require.NoError(t, res.Error())
+
+	assert.NotPanics(t, func() {
+		f.OnResponse(context.Background(), res, base.NewBaseInvoker(u), inv)
+	})
+}
diff --git a/global/shutdown_config.go b/global/shutdown_config.go
index 18f3ecc4d..8374e2578 100644
--- a/global/shutdown_config.go
+++ b/global/shutdown_config.go
@@ -41,6 +41,13 @@ type ShutdownConfig struct {
*/
StepTimeout string `default:"3s" yaml:"step-timeout"
json:"step.timeout,omitempty" property:"step.timeout"`
+ /*
+ * NotifyTimeout means the timeout budget for actively notifying
long-connection consumers
+ * during graceful shutdown. It only controls the notify step and
should not be coupled to
+ * request draining timeouts.
+ */
+ NotifyTimeout string `default:"5s" yaml:"notify-timeout"
json:"notify.timeout,omitempty" property:"notify.timeout"`
+
/*
* ConsumerUpdateWaitTime means when provider is shutting down, after
the unregister, time to wait for client to
* update invokers. During this time, incoming invocation can be
treated normally.
@@ -51,7 +58,7 @@ type ShutdownConfig struct {
// internal listen kill signal,the default is true.
InternalSignal *bool `default:"true" yaml:"internal-signal"
json:"internal.signal,omitempty" property:"internal.signal"`
// offline request window length
- OfflineRequestWindowTimeout string
`yaml:"offline-request-window-timeout"
json:"offlineRequestWindowTimeout,omitempty"
property:"offlineRequestWindowTimeout"`
+ OfflineRequestWindowTimeout string `default:"3s"
yaml:"offline-request-window-timeout"
json:"offlineRequestWindowTimeout,omitempty"
property:"offlineRequestWindowTimeout"`
// true -> new request will be rejected.
RejectRequest atomic.Bool
// active invocation
@@ -60,6 +67,12 @@ type ShutdownConfig struct {
// provider last received request timestamp
ProviderLastReceivedRequestTime atomic.Time
+
+ Closing atomic.Bool
+
+ // ClosingInvokerExpireTime controls how long the consumer keeps an
invoker
+ // marked as closing after receiving an active/passive closing signal.
+ ClosingInvokerExpireTime string `default:"30s"
yaml:"closing-invoker-expire-time" json:"closingInvokerExpireTime,omitempty"
property:"closingInvokerExpireTime"`
}
func DefaultShutdownConfig() *ShutdownConfig {
@@ -84,16 +97,19 @@ func (c *ShutdownConfig) Clone() *ShutdownConfig {
newShutdownConfig := &ShutdownConfig{
Timeout: c.Timeout,
StepTimeout: c.StepTimeout,
+ NotifyTimeout: c.NotifyTimeout,
ConsumerUpdateWaitTime: c.ConsumerUpdateWaitTime,
RejectRequestHandler: c.RejectRequestHandler,
InternalSignal: newInternalSignal,
OfflineRequestWindowTimeout: c.OfflineRequestWindowTimeout,
+ ClosingInvokerExpireTime: c.ClosingInvokerExpireTime,
}
newShutdownConfig.RejectRequest.Store(c.RejectRequest.Load())
newShutdownConfig.ConsumerActiveCount.Store(c.ConsumerActiveCount.Load())
newShutdownConfig.ProviderActiveCount.Store(c.ProviderActiveCount.Load())
newShutdownConfig.ProviderLastReceivedRequestTime.Store(c.ProviderLastReceivedRequestTime.Load())
+ newShutdownConfig.Closing.Store(c.Closing.Load())
return newShutdownConfig
}
diff --git a/graceful_shutdown/closing_ack.go b/graceful_shutdown/closing_ack.go
new file mode 100644
index 000000000..5dadc50c0
--- /dev/null
+++ b/graceful_shutdown/closing_ack.go
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package graceful_shutdown
+
+import (
+ "strings"
+ "sync"
+ "sync/atomic"
+)
+
+// ClosingAckStats is a lightweight, consumer-side statistic for active closing notices.
+// Received means the consumer observed an active notice; Removed means directory removal
+// succeeded; Missed means the notice was observed but did not match a local directory entry.
+// Only events whose Source ends in "-health-watch" are counted (see isActiveClosingSource).
+type ClosingAckStats struct {
+ Received uint64
+ Removed uint64
+ Missed uint64
+}
+
+// closingAckCounter holds the live, atomically updated counters for one event source.
+type closingAckCounter struct {
+ received atomic.Uint64
+ removed atomic.Uint64
+ missed atomic.Uint64
+}
+
+// closingAckTracker keeps one closingAckCounter per event source.
+type closingAckTracker struct {
+ counters sync.Map // map[string]*closingAckCounter
+}
+
+// defaultClosingAckTracker is the process-wide tracker updated by
+// closingEventHandler and read by DefaultClosingAckStats.
+var defaultClosingAckTracker = &closingAckTracker{}
+
+// isActiveClosingSource reports whether an event came from an active
+// health-watch notice, identified by the "-health-watch" source suffix.
+func isActiveClosingSource(source string) bool {
+	const activeSuffix = "-health-watch"
+	return strings.HasSuffix(source, activeSuffix)
+}
+
+// record counts one closing event against its source. Passive sources are
+// ignored; for active ones, received is always bumped and then exactly one of
+// removed or missed.
+func (t *closingAckTracker) record(event ClosingEvent, removed bool) {
+	if isActiveClosingSource(event.Source) {
+		c := t.getOrCreateCounter(event.Source)
+		c.received.Add(1)
+		if removed {
+			c.removed.Add(1)
+		} else {
+			c.missed.Add(1)
+		}
+	}
+}
+
+// getOrCreateCounter returns the counter for source, creating it on first use.
+// The plain Load fast path avoids allocating when the counter already exists.
+func (t *closingAckTracker) getOrCreateCounter(source string) *closingAckCounter {
+	existing, found := t.counters.Load(source)
+	if !found {
+		existing, _ = t.counters.LoadOrStore(source, &closingAckCounter{})
+	}
+	return existing.(*closingAckCounter)
+}
+
+// snapshot copies the current counters into plain values keyed by source.
+func (t *closingAckTracker) snapshot() map[string]ClosingAckStats {
+	out := make(map[string]ClosingAckStats)
+	t.counters.Range(func(k, v any) bool {
+		c := v.(*closingAckCounter)
+		out[k.(string)] = ClosingAckStats{
+			Received: c.received.Load(),
+			Removed:  c.removed.Load(),
+			Missed:   c.missed.Load(),
+		}
+		return true
+	})
+	return out
+}
+
+// reset forgets every tracked source.
+func (t *closingAckTracker) reset() {
+	t.counters.Range(func(k, _ any) bool {
+		t.counters.Delete(k)
+		return true
+	})
+}
+
+// DefaultClosingAckStats returns a snapshot of the process-wide active-notice
+// ack statistics, keyed by event source.
+func DefaultClosingAckStats() map[string]ClosingAckStats {
+	return defaultClosingAckTracker.snapshot()
+}
diff --git a/filter/graceful_shutdown/compat.go
b/graceful_shutdown/closing_handler.go
similarity index 52%
copy from filter/graceful_shutdown/compat.go
copy to graceful_shutdown/closing_handler.go
index ef9dc16e0..b217983ab 100644
--- a/filter/graceful_shutdown/compat.go
+++ b/graceful_shutdown/closing_handler.go
@@ -18,28 +18,25 @@
package graceful_shutdown
import (
- "go.uber.org/atomic"
+ clusterdirectory "dubbo.apache.org/dubbo-go/v3/cluster/directory"
)
-import (
- "dubbo.apache.org/dubbo-go/v3/config"
- "dubbo.apache.org/dubbo-go/v3/global"
-)
+// ClosingEvent represents a single closing signal on the consumer side.
+type ClosingEvent struct {
+ Source string
+ InstanceKey string
+ ServiceKey string
+ Address string
+}
-func compatGlobalShutdownConfig(c *config.ShutdownConfig)
*global.ShutdownConfig {
- if c == nil {
- return nil
- }
- cfg := &global.ShutdownConfig{
- Timeout: c.Timeout,
- StepTimeout: c.StepTimeout,
- ConsumerUpdateWaitTime: c.ConsumerUpdateWaitTime,
- RejectRequestHandler: c.RejectRequestHandler,
- InternalSignal: c.InternalSignal,
- OfflineRequestWindowTimeout: c.OfflineRequestWindowTimeout,
- RejectRequest: atomic.Bool{},
- }
- cfg.RejectRequest.Store(c.RejectRequest.Load())
+// ClosingEventHandler handles closing signals from active and passive
notification paths.
+type ClosingEventHandler interface {
+ HandleClosingEvent(event ClosingEvent) bool
+}
- return cfg
+// ClosingDirectoryRegistry resolves directory-level removers by service key.
+type ClosingDirectoryRegistry interface {
+ Register(serviceKey string, remover
clusterdirectory.ClosingInstanceRemover)
+ Unregister(serviceKey string, remover
clusterdirectory.ClosingInstanceRemover)
+ Find(serviceKey string) []clusterdirectory.ClosingInstanceRemover
}
diff --git a/graceful_shutdown/closing_registry.go
b/graceful_shutdown/closing_registry.go
new file mode 100644
index 000000000..66d91e9ef
--- /dev/null
+++ b/graceful_shutdown/closing_registry.go
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package graceful_shutdown
+
+import (
+ "reflect"
+ "sync"
+)
+
+import (
+ "github.com/dubbogo/gost/log/logger"
+)
+
+import (
+ clusterdirectory "dubbo.apache.org/dubbo-go/v3/cluster/directory"
+)
+
+// closingDirectoryRegistry is the default ClosingDirectoryRegistry implementation.
+// It indexes removers by service key and then by the remover's pointer identity
+// (see removerKey), so registering the same remover twice simply overwrites it.
+type closingDirectoryRegistry struct {
+ // mu protects removers.
+ mu sync.RWMutex
+ // removers maps serviceKey -> removerKey -> remover.
+ removers map[string]map[uintptr]clusterdirectory.ClosingInstanceRemover
+}
+
+// closingEventHandler dispatches each ClosingEvent to every remover that its
+// registry holds for the event's service key.
+type closingEventHandler struct {
+ registry ClosingDirectoryRegistry
+}
+
+// Process-wide defaults: the default handler dispatches through the default registry.
+var (
+ defaultClosingDirectoryRegistry ClosingDirectoryRegistry =
newClosingDirectoryRegistry()
+ defaultClosingEventHandler ClosingEventHandler =
&closingEventHandler{registry: defaultClosingDirectoryRegistry}
+)
+
+// newClosingDirectoryRegistry returns an empty registry.
+func newClosingDirectoryRegistry() ClosingDirectoryRegistry {
+ return &closingDirectoryRegistry{
+ removers:
make(map[string]map[uintptr]clusterdirectory.ClosingInstanceRemover),
+ }
+}
+
+// removerKey derives an identity key for remover from its pointer address.
+// It returns 0, which Register/Unregister treat as "ignore", for a nil
+// interface, a non-pointer implementation, or a nil pointer (Value.Pointer is 0).
+// NOTE(review): pointers to zero-size types may share an address, so two such
+// removers could collide on one key; confirm no remover type is zero-size.
+func removerKey(remover clusterdirectory.ClosingInstanceRemover) uintptr {
+ if remover == nil {
+ return 0
+ }
+ value := reflect.ValueOf(remover)
+ // Only pointer receivers have a stable identity usable as a map key.
+ if value.Kind() != reflect.Pointer {
+ return 0
+ }
+ return value.Pointer()
+}
+
+// DefaultClosingDirectoryRegistry returns the process-wide closing directory registry.
+func DefaultClosingDirectoryRegistry() ClosingDirectoryRegistry {
+ return defaultClosingDirectoryRegistry
+}
+
+// DefaultClosingEventHandler returns the default closing event handler, which
+// dispatches through DefaultClosingDirectoryRegistry.
+func DefaultClosingEventHandler() ClosingEventHandler {
+ return defaultClosingEventHandler
+}
+
+// Register associates remover with serviceKey. Registering the same remover
+// pointer again for the same service key overwrites the existing entry.
+//
+// Removers that removerKey cannot identify (nil, nil pointer, or a non-pointer
+// implementation) are not stored. For the non-nil cases a warning is logged,
+// because such a remover would otherwise silently never receive closing events.
+func (r *closingDirectoryRegistry) Register(serviceKey string, remover clusterdirectory.ClosingInstanceRemover) {
+	if serviceKey == "" {
+		return
+	}
+	key := removerKey(remover)
+	if key == 0 {
+		if remover != nil {
+			logger.Warnf("Graceful shutdown --- Closing instance remover %T for service %s cannot be registered: it must be a non-nil pointer", remover, serviceKey)
+		}
+		return
+	}
+
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	entries, ok := r.removers[serviceKey]
+	if !ok {
+		entries = make(map[uintptr]clusterdirectory.ClosingInstanceRemover)
+		r.removers[serviceKey] = entries
+	}
+	entries[key] = remover
+}
+
+// Unregister drops remover from serviceKey's set and removes the service entry
+// entirely once no removers remain. Unknown or unkeyable removers are ignored.
+func (r *closingDirectoryRegistry) Unregister(serviceKey string, remover clusterdirectory.ClosingInstanceRemover) {
+	id := removerKey(remover)
+	if id == 0 || serviceKey == "" {
+		return
+	}
+
+	r.mu.Lock()
+	defer r.mu.Unlock()
+
+	if byKey, found := r.removers[serviceKey]; found {
+		delete(byKey, id)
+		if len(byKey) == 0 {
+			delete(r.removers, serviceKey)
+		}
+	}
+}
+
+// Find returns a snapshot slice of the removers registered for serviceKey, or
+// nil when the key is empty or unknown. Returning a copy lets callers invoke the
+// removers without holding the registry lock. Order is unspecified.
+func (r *closingDirectoryRegistry) Find(serviceKey string) []clusterdirectory.ClosingInstanceRemover {
+	if serviceKey == "" {
+		return nil
+	}
+
+	r.mu.RLock()
+	defer r.mu.RUnlock()
+
+	byKey, found := r.removers[serviceKey]
+	if !found {
+		return nil
+	}
+
+	snapshot := make([]clusterdirectory.ClosingInstanceRemover, 0, len(byKey))
+	for _, rm := range byKey {
+		snapshot = append(snapshot, rm)
+	}
+	return snapshot
+}
+
+// HandleClosingEvent asks every remover registered for event.ServiceKey to drop
+// event.InstanceKey and reports whether at least one of them did. Every remover
+// is invoked (no short-circuit). Incomplete events, or a handler without a
+// registry, return false and are recorded as misses (for active sources). Events
+// from active (health-watch) sources are also logged with their outcome.
+func (h *closingEventHandler) HandleClosingEvent(event ClosingEvent) bool {
+	incomplete := h.registry == nil || event.InstanceKey == "" || event.ServiceKey == ""
+	if incomplete {
+		defaultClosingAckTracker.record(event, false)
+		return false
+	}
+
+	anyRemoved := false
+	for _, remover := range h.registry.Find(event.ServiceKey) {
+		// Call RemoveClosingInstance first so every remover runs regardless of anyRemoved.
+		anyRemoved = remover.RemoveClosingInstance(event.InstanceKey) || anyRemoved
+	}
+
+	defaultClosingAckTracker.record(event, anyRemoved)
+	if !isActiveClosingSource(event.Source) {
+		return anyRemoved
+	}
+	if anyRemoved {
+		logger.Infof("Graceful shutdown --- Active closing ack handled, source=%s service=%s address=%s instance=%s",
+			event.Source, event.ServiceKey, event.Address, event.InstanceKey)
+	} else {
+		logger.Warnf("Graceful shutdown --- Active closing ack missed local directory, source=%s service=%s address=%s instance=%s",
+			event.Source, event.ServiceKey, event.Address, event.InstanceKey)
+	}
+	return anyRemoved
+}
diff --git a/graceful_shutdown/closing_registry_test.go
b/graceful_shutdown/closing_registry_test.go
new file mode 100644
index 000000000..7f4b07243
--- /dev/null
+++ b/graceful_shutdown/closing_registry_test.go
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package graceful_shutdown
+
+import (
+ "testing"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+)
+
+// testClosingInstanceRemover records every instance key it is asked to remove
+// and answers with a fixed result.
+type testClosingInstanceRemover struct {
+	removedInstanceKeys []string
+	result              bool
+}
+
+func (r *testClosingInstanceRemover) RemoveClosingInstance(instanceKey string) bool {
+	r.removedInstanceKeys = append(r.removedInstanceKeys, instanceKey)
+	return r.result
+}
+
+func TestClosingDirectoryRegistryRegisterFindAndUnregister(t *testing.T) {
+	const svc = "org.apache.dubbo-go.TestService:1.0.0"
+	registry := newClosingDirectoryRegistry()
+	first := &testClosingInstanceRemover{result: true}
+	second := &testClosingInstanceRemover{result: true}
+
+	registry.Register(svc, first)
+	registry.Register(svc, second)
+	assert.Len(t, registry.Find(svc), 2)
+
+	registry.Unregister(svc, first)
+	assert.Len(t, registry.Find(svc), 1)
+
+	registry.Unregister(svc, second)
+	assert.Empty(t, registry.Find(svc))
+}
+
+func TestClosingEventHandlerDispatchesByServiceKey(t *testing.T) {
+	const (
+		targetSvc = "org.apache.dubbo-go.TargetService:1.0.0"
+		otherSvc  = "org.apache.dubbo-go.OtherService:1.0.0"
+	)
+	registry := newClosingDirectoryRegistry()
+	handler := &closingEventHandler{registry: registry}
+
+	target := &testClosingInstanceRemover{result: true}
+	other := &testClosingInstanceRemover{result: true}
+	registry.Register(targetSvc, target)
+	registry.Register(otherSvc, other)
+
+	ok := handler.HandleClosingEvent(ClosingEvent{
+		ServiceKey:  targetSvc,
+		InstanceKey: "target-instance",
+	})
+
+	assert.True(t, ok)
+	assert.Equal(t, []string{"target-instance"}, target.removedInstanceKeys)
+	assert.Empty(t, other.removedInstanceKeys)
+}
+
+// TestClosingEventHandlerRejectsIncompleteEvent checks that events missing a
+// service or instance key, or a handler without a registry, are rejected, and
+// that such passive (non health-watch) events leave the ack statistics untouched.
+func TestClosingEventHandlerRejectsIncompleteEvent(t *testing.T) {
+	registry := newClosingDirectoryRegistry()
+	handler := &closingEventHandler{registry: registry}
+	defaultClosingAckTracker.reset()
+
+	assert.False(t, handler.HandleClosingEvent(ClosingEvent{}))
+	assert.False(t, handler.HandleClosingEvent(ClosingEvent{ServiceKey: "svc"}))
+	assert.False(t, handler.HandleClosingEvent(ClosingEvent{InstanceKey: "instance"}))
+
+	// A handler without a registry rejects even a complete event.
+	noRegistry := &closingEventHandler{}
+	assert.False(t, noRegistry.HandleClosingEvent(ClosingEvent{ServiceKey: "svc", InstanceKey: "instance"}))
+
+	// None of these events has an active ("-health-watch") source, so nothing
+	// should have been recorded.
+	assert.Empty(t, DefaultClosingAckStats())
+}
+
+func TestClosingEventHandlerRecordsActiveAckStats(t *testing.T) {
+	const (
+		svc    = "org.apache.dubbo-go.TargetService:1.0.0"
+		source = "grpc-health-watch"
+	)
+	registry := newClosingDirectoryRegistry()
+	handler := &closingEventHandler{registry: registry}
+	defaultClosingAckTracker.reset()
+
+	registry.Register(svc, &testClosingInstanceRemover{result: true})
+
+	event := ClosingEvent{
+		Source:      source,
+		ServiceKey:  svc,
+		InstanceKey: "target-instance",
+		Address:     "127.0.0.1:20000",
+	}
+	assert.True(t, handler.HandleClosingEvent(event))
+
+	want := ClosingAckStats{Received: 1, Removed: 1, Missed: 0}
+	assert.Equal(t, want, DefaultClosingAckStats()[source])
+}
+
+func TestClosingEventHandlerRecordsActiveAckMisses(t *testing.T) {
+	const source = "triple-health-watch"
+	handler := &closingEventHandler{registry: newClosingDirectoryRegistry()}
+	defaultClosingAckTracker.reset()
+
+	event := ClosingEvent{
+		Source:      source,
+		ServiceKey:  "org.apache.dubbo-go.TargetService:1.0.0",
+		InstanceKey: "missing-instance",
+		Address:     "127.0.0.1:20000",
+	}
+	assert.False(t, handler.HandleClosingEvent(event))
+
+	want := ClosingAckStats{Received: 1, Removed: 0, Missed: 1}
+	assert.Equal(t, want, DefaultClosingAckStats()[source])
+}
diff --git a/graceful_shutdown/options.go b/graceful_shutdown/options.go
index 522cc9ef6..456d789af 100644
--- a/graceful_shutdown/options.go
+++ b/graceful_shutdown/options.go
@@ -58,6 +58,12 @@ func WithStepTimeout(timeout time.Duration) Option {
}
}
+// WithNotifyTimeout sets the time budget for actively notifying long-connection
+// consumers during graceful shutdown. The duration is stored in its String() form.
+func WithNotifyTimeout(timeout time.Duration) Option {
+	return func(o *Options) {
+		o.Shutdown.NotifyTimeout = timeout.String()
+	}
+}
+
func WithConsumerUpdateWaitTime(duration time.Duration) Option {
return func(opts *Options) {
opts.Shutdown.ConsumerUpdateWaitTime = duration.String()
diff --git a/graceful_shutdown/options_test.go
b/graceful_shutdown/options_test.go
index c4a18731f..d66b676d1 100644
--- a/graceful_shutdown/options_test.go
+++ b/graceful_shutdown/options_test.go
@@ -36,8 +36,9 @@ func TestDefaultOptions(t *testing.T) {
assert.NotNil(t, opts.Shutdown)
assert.Equal(t, "60s", opts.Shutdown.Timeout)
assert.Equal(t, "3s", opts.Shutdown.StepTimeout)
+ assert.Equal(t, "5s", opts.Shutdown.NotifyTimeout)
assert.Equal(t, "3s", opts.Shutdown.ConsumerUpdateWaitTime)
- assert.Empty(t, opts.Shutdown.OfflineRequestWindowTimeout) // No
default value
+ assert.Equal(t, "3s", opts.Shutdown.OfflineRequestWindowTimeout)
assert.True(t, *opts.Shutdown.InternalSignal)
}
@@ -50,12 +51,14 @@ func TestNewOptions(t *testing.T) {
// Test with custom options
customTimeout := 120 * time.Second
customStepTimeout := 10 * time.Second
+ customNotifyTimeout := 4 * time.Second
customConsumerUpdateWaitTime := 5 * time.Second
customOfflineRequestWindowTimeout := 2 * time.Second
opts = NewOptions(
WithTimeout(customTimeout),
WithStepTimeout(customStepTimeout),
+ WithNotifyTimeout(customNotifyTimeout),
WithConsumerUpdateWaitTime(customConsumerUpdateWaitTime),
WithOfflineRequestWindowTimeout(customOfflineRequestWindowTimeout),
WithoutInternalSignal(),
@@ -63,6 +66,7 @@ func TestNewOptions(t *testing.T) {
assert.Equal(t, customTimeout.String(), opts.Shutdown.Timeout)
assert.Equal(t, customStepTimeout.String(), opts.Shutdown.StepTimeout)
+ assert.Equal(t, customNotifyTimeout.String(),
opts.Shutdown.NotifyTimeout)
assert.Equal(t, customConsumerUpdateWaitTime.String(),
opts.Shutdown.ConsumerUpdateWaitTime)
assert.Equal(t, customOfflineRequestWindowTimeout.String(),
opts.Shutdown.OfflineRequestWindowTimeout)
assert.False(t, *opts.Shutdown.InternalSignal)
@@ -78,6 +82,10 @@ func TestOptionFunctions(t *testing.T) {
WithStepTimeout(5 * time.Second)(opts)
assert.Equal(t, "5s", opts.Shutdown.StepTimeout)
+ // Test WithNotifyTimeout
+ WithNotifyTimeout(7 * time.Second)(opts)
+ assert.Equal(t, "7s", opts.Shutdown.NotifyTimeout)
+
// Test WithConsumerUpdateWaitTime
WithConsumerUpdateWaitTime(10 * time.Second)(opts)
assert.Equal(t, "10s", opts.Shutdown.ConsumerUpdateWaitTime)
diff --git a/graceful_shutdown/shutdown.go b/graceful_shutdown/shutdown.go
index 0f4ae3389..310377ecc 100644
--- a/graceful_shutdown/shutdown.go
+++ b/graceful_shutdown/shutdown.go
@@ -18,6 +18,8 @@
package graceful_shutdown
import (
+ "context"
+ "fmt"
"os"
"os/signal"
"runtime/debug"
@@ -26,6 +28,8 @@ import (
)
import (
+ "github.com/cenkalti/backoff/v4"
+
"github.com/dubbogo/gost/log/logger"
)
@@ -34,17 +38,25 @@ import (
"dubbo.apache.org/dubbo-go/v3/common/extension"
"dubbo.apache.org/dubbo-go/v3/config"
"dubbo.apache.org/dubbo-go/v3/global"
+ protocolbase "dubbo.apache.org/dubbo-go/v3/protocol/base"
)
const (
// todo(DMwangnima): these descriptions and defaults could be wrapped
by functions of Options
defaultTimeout = 60 * time.Second
defaultStepTimeout = 3 * time.Second
+ defaultNotifyTimeout = 5 * time.Second
defaultConsumerUpdateWaitTime = 3 * time.Second
defaultOfflineRequestWindowTimeout = 3 * time.Second
+ // retry config
+ defaultMaxRetries = 3
+ defaultRetryBaseDelay = 500 * time.Millisecond
+ defaultRetryMaxDelay = 2 * time.Second
+
timeoutDesc = "Timeout"
stepTimeoutDesc = "StepTimeout"
+ notifyTimeoutDesc = "NotifyTimeout"
consumerUpdateWaitTimeDesc = "ConsumerUpdateWaitTime"
offlineRequestWindowTimeoutDesc = "OfflineRequestWindowTimeout"
)
@@ -88,7 +100,7 @@ func Init(opts ...Option) {
go func() {
sig := <-signals
logger.Infof("get signal %s, applicationConfig
will shutdown.", sig)
- // gracefulShutdownOnce.Do(func() {
+ // fallback timeout
time.AfterFunc(totalTimeout(newOpts.Shutdown),
func() {
logger.Warn("Shutdown gracefully
timeout, applicationConfig will shutdown immediately. ")
os.Exit(0)
@@ -111,6 +123,9 @@ func Init(opts ...Option) {
// function would not make any sense.
func RegisterProtocol(name string) {
proMu.Lock()
+ // Lazily create the set: writing to a nil map panics, and this package-level
+ // map can be reset to nil (the tests do exactly that).
+ if protocols == nil {
+ protocols = make(map[string]struct{})
+ }
protocols[name] = struct{}{}
proMu.Unlock()
}
@@ -125,31 +140,125 @@ func totalTimeout(shutdown *global.ShutdownConfig)
time.Duration {
}
func beforeShutdown(shutdown *global.ShutdownConfig) {
- destroyRegistries()
+ // 1. mark closing state
+ logger.Info("Graceful shutdown --- Mark closing state.")
+ shutdown.Closing.Store(true)
+
+ // 2. unregister services from registries
+ unregisterRegistries()
+
+ // 3. notify long connection consumers
+ notifyLongConnectionConsumers(shutdown)
+
+ // 4. wait and accept new requests
// waiting for a short time so that the clients have enough time to get
the notification that server shutdowns
// The value of configuration depends on how long the clients will get
notification.
waitAndAcceptNewRequests(shutdown)
+ // 5. reject new requests and wait for in-flight requests
// reject sending/receiving the new request but keeping waiting for
accepting requests
waitForSendingAndReceivingRequests(shutdown)
- // destroy all protocols
+ // 6. destroy protocols
destroyProtocols()
- logger.Info("Graceful shutdown --- Execute the custom callbacks.")
- customCallbacks := extension.GetAllCustomShutdownCallbacks()
- for callback := customCallbacks.Front(); callback != nil; callback =
callback.Next() {
- callback.Value.(func())()
- }
+ // 7. execute custom callbacks
+ executeCustomShutdownCallbacks(shutdown)
}
-// destroyRegistries destroys RegistryProtocol directly.
-func destroyRegistries() {
- logger.Info("Graceful shutdown --- Destroy all registriesConfig. ")
- registryProtocol := extension.GetProtocol(constant.RegistryProtocol)
+// unregisterRegistries unregisters exported services from registries during graceful shutdown.
+// If the registry protocol does not implement protocolbase.RegistryUnregisterer, it falls
+// back to Destroy(). A missing registry protocol is logged and skipped, not treated as fatal.
+func unregisterRegistries() {
+ logger.Info("Graceful shutdown --- Unregister exported services from
registries.")
+ registryProtocol, ok := getProtocolSafely(constant.RegistryProtocol)
+ if !ok {
+ logger.Warnf("Graceful shutdown --- Registry protocol %s is not
registered, skip unregistering registries.", constant.RegistryProtocol)
+ return
+ }
+
+ // Prefer the narrower unregister-only capability when the protocol offers it;
+ // Destroy() below is the broader fallback.
+ if unregisterer, ok :=
registryProtocol.(protocolbase.RegistryUnregisterer); ok {
+ unregisterer.UnregisterRegistries()
+ return
+ }
+
+ logger.Warnf("Graceful shutdown --- Registry protocol %s does not
support unregister-only shutdown, falling back to Destroy().",
constant.RegistryProtocol)
registryProtocol.Destroy()
}
+// notifyLongConnectionConsumers runs every registered graceful-shutdown callback
+// concurrently, each with its own NotifyTimeout budget and retry policy, and
+// returns once all of them have finished or given up.
+func notifyLongConnectionConsumers(shutdown *global.ShutdownConfig) {
+	logger.Info("Graceful shutdown --- Notify long connection consumers.")
+
+	perCallbackTimeout := parseDuration(shutdown.NotifyTimeout, notifyTimeoutDesc, defaultNotifyTimeout)
+
+	var pending sync.WaitGroup
+	for cbName, cb := range extension.GracefulShutdownCallbacks() {
+		pending.Add(1)
+		go func(cbName string, cb extension.GracefulShutdownCallback) {
+			defer pending.Done()
+			// Independent deadline per callback: one slow protocol cannot
+			// consume another's budget.
+			ctx, cancel := context.WithTimeout(context.Background(), perCallbackTimeout)
+			defer cancel()
+			notifyWithRetry(ctx, cbName, cb)
+		}(cbName, cb)
+	}
+	pending.Wait()
+}
+
+// notifyWithRetry invokes callback with exponential-backoff retries until it
+// succeeds, ctx ends, or defaultMaxRetries retries (defaultMaxRetries+1 total
+// attempts) have failed. Failures are logged, never returned, so that a failed
+// notice does not block the rest of the shutdown sequence.
+func notifyWithRetry(ctx context.Context, name string, callback extension.GracefulShutdownCallback) {
+	// backoff.WithMaxRetries bounds *retries*, so the callback runs at most once
+	// more than defaultMaxRetries. Previously the retry log used defaultMaxRetries
+	// as the denominator and could report "3/3" right before a fourth attempt.
+	const maxAttempts = defaultMaxRetries + 1
+
+	backOff := backoff.NewExponentialBackOff()
+	backOff.InitialInterval = defaultRetryBaseDelay
+	backOff.MaxInterval = defaultRetryMaxDelay
+	// No elapsed-time cap: ctx (NotifyTimeout) is the only deadline.
+	backOff.MaxElapsedTime = 0
+
+	var attempts int
+	operation := func() error {
+		attempts++
+		err := invokeGracefulShutdownCallback(ctx, name, callback)
+		if err == nil {
+			logger.Infof("Graceful shutdown --- Notify %s completed", name)
+			return nil
+		}
+
+		logger.Warnf("Graceful shutdown --- Notify %s attempt %d failed --- %v", name, attempts, err)
+		return err
+	}
+
+	notify := func(_ error, delay time.Duration) {
+		logger.Infof("Graceful shutdown --- Notify %s retrying in %v (next attempt %d/%d)", name, delay, attempts+1, maxAttempts)
+	}
+
+	retryPolicy := backoff.WithContext(backoff.WithMaxRetries(backOff, uint64(defaultMaxRetries)), ctx)
+	err := backoff.RetryNotify(operation, retryPolicy, notify)
+	switch {
+	case err == nil:
+		return
+	case ctx.Err() != nil:
+		logger.Warnf("Graceful shutdown --- Notify %s timeout after %d attempts, continuing...", name, attempts)
+	default:
+		logger.Warnf("Graceful shutdown --- Notify %s failed after %d attempts --- %v", name, attempts, err)
+	}
+}
+
+// invokeGracefulShutdownCallback runs callback in its own goroutine and waits for
+// its result or for ctx to end, whichever comes first. A panic inside callback is
+// logged and converted into an error. The result channel has capacity 1, so the
+// goroutine's final send never blocks even after the caller stopped waiting.
+func invokeGracefulShutdownCallback(ctx context.Context, name string, callback extension.GracefulShutdownCallback) error {
+	resultCh := make(chan error, 1)
+	go func() {
+		defer func() {
+			r := recover()
+			if r == nil {
+				return
+			}
+			logger.Warnf("Graceful shutdown --- Notify %s panicked --- %v", name, r)
+			resultCh <- fmt.Errorf("graceful shutdown callback panic: %v", r)
+		}()
+		resultCh <- callback(ctx)
+	}()
+
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case err := <-resultCh:
+		return err
+	}
+}
+
func waitAndAcceptNewRequests(shutdown *global.ShutdownConfig) {
logger.Info("Graceful shutdown --- Keep waiting and accept new requests
for a short time. ")
@@ -205,10 +314,63 @@ func waitingConsumerProcessedTimeout(shutdown
*global.ShutdownConfig) {
func destroyProtocols() {
logger.Info("Graceful shutdown --- Destroy protocols. ")
+ for _, name := range registeredProtocolsSnapshot() {
+ protocol, ok := getProtocolSafely(name)
+ if !ok {
+ logger.Warnf("Graceful shutdown --- Protocol %s is not
registered, skip destroying it.", name)
+ continue
+ }
+ protocol.Destroy()
+ }
+}
+
+func registeredProtocolsSnapshot() []string {
proMu.Lock()
- // extension.GetProtocol might panic
defer proMu.Unlock()
+
+ names := make([]string, 0, len(protocols))
for name := range protocols {
- extension.GetProtocol(name).Destroy()
+ names = append(names, name)
+ }
+ return names
+}
+
+// executeCustomShutdownCallbacks runs the user-registered custom shutdown callbacks
+// one after another, in list order. Each callback is individually bounded by
+// totalTimeout(shutdown), so the budget applies per callback, not to the whole loop.
+func executeCustomShutdownCallbacks(shutdown *global.ShutdownConfig) {
+ logger.Info("Graceful shutdown --- Execute the custom callbacks.")
+ callbackTimeout := totalTimeout(shutdown)
+ customCallbacks := extension.GetAllCustomShutdownCallbacks()
+ // NOTE(review): callback.Value is asserted to func(); this assumes only func()
+ // values are stored in the list. Confirm against the registration API.
+ for callback := customCallbacks.Front(); callback != nil; callback =
callback.Next() {
+ invokeCustomShutdownCallback(callbackTimeout,
callback.Value.(func()))
}
}
+
+// invokeCustomShutdownCallback runs callback in a goroutine and waits up to
+// timeout for it to return. Panics are recovered and logged. On timeout the
+// callback keeps running in the background; the buffered channel lets it signal
+// completion later without blocking.
+func invokeCustomShutdownCallback(timeout time.Duration, callback func()) {
+	finished := make(chan struct{}, 1)
+	go func() {
+		defer func() {
+			if r := recover(); r != nil {
+				logger.Warnf("Graceful shutdown --- Custom shutdown callback panicked --- %v", r)
+			}
+			finished <- struct{}{}
+		}()
+		callback()
+	}()
+
+	select {
+	case <-time.After(timeout):
+		logger.Warnf("Graceful shutdown --- Custom shutdown callback timed out after %v", timeout)
+	case <-finished:
+	}
+}
+
+// getProtocolSafely looks up the named protocol extension and reports whether a
+// non-nil protocol was found. Any panic raised by extension.GetProtocol is
+// recovered and reported as ok == false.
+func getProtocolSafely(name string) (protocol protocolbase.Protocol, ok bool) {
+	defer func() {
+		if r := recover(); r != nil {
+			protocol, ok = nil, false
+		}
+	}()
+	if protocol = extension.GetProtocol(name); protocol == nil {
+		return nil, false
+	}
+	return protocol, true
+}
diff --git a/graceful_shutdown/shutdown_test.go
b/graceful_shutdown/shutdown_test.go
index 8aa2e665e..8288f431c 100644
--- a/graceful_shutdown/shutdown_test.go
+++ b/graceful_shutdown/shutdown_test.go
@@ -19,7 +19,9 @@ package graceful_shutdown
import (
"context"
+ "errors"
"sync"
+ "sync/atomic"
"testing"
"time"
)
@@ -30,6 +32,7 @@ import (
)
import (
+ "dubbo.apache.org/dubbo-go/v3/common"
"dubbo.apache.org/dubbo-go/v3/common/constant"
"dubbo.apache.org/dubbo-go/v3/common/extension"
"dubbo.apache.org/dubbo-go/v3/filter"
@@ -43,6 +46,47 @@ type MockFilter struct {
mock.Mock
}
+type testProtocol struct {
+ destroy func()
+}
+
+func (p *testProtocol) Export(invoker base.Invoker) base.Exporter {
+ return nil
+}
+
+func (p *testProtocol) Refer(url *common.URL) base.Invoker {
+ return nil
+}
+
+func (p *testProtocol) Destroy() {
+ if p.destroy != nil {
+ p.destroy()
+ }
+}
+
+type testRegistryProtocol struct {
+ testProtocol
+ unregister func()
+}
+
+func (p *testRegistryProtocol) UnregisterRegistries() {
+ if p.unregister != nil {
+ p.unregister()
+ }
+}
+
+func getProtocolIfPresent(name string) (protocol base.Protocol, ok bool) {
+ defer func() {
+ if recover() != nil {
+ protocol = nil
+ ok = false
+ }
+ }()
+ protocol = extension.GetProtocol(name)
+ ok = protocol != nil
+ return protocol, ok
+}
+
func (m *MockFilter) Set(key string, value any) {
m.Called(key, value)
}
@@ -109,6 +153,19 @@ func TestRegisterProtocol(t *testing.T) {
assert.Len(t, protocols, 3)
}
+func TestRegisterProtocolInitializesMapWhenNeeded(t *testing.T) {
+ protocols = nil
+ proMu = sync.Mutex{}
+
+ assert.NotPanics(t, func() {
+ RegisterProtocol("grpc")
+ })
+
+ proMu.Lock()
+ defer proMu.Unlock()
+ assert.Contains(t, protocols, "grpc")
+}
+
func TestTotalTimeout(t *testing.T) {
// Test with default timeout
config := global.DefaultShutdownConfig()
@@ -190,3 +247,275 @@ func TestWaitForSendingAndReceivingRequests(t *testing.T)
{
// Should return immediately
assert.Less(t, elapsed, 50*time.Millisecond)
}
+
+func TestNotifyLongConnectionConsumersUsesIndependentTimeouts(t *testing.T) {
+ firstName := "shutdown-timeout-first"
+ secondName := "shutdown-timeout-second"
+
+ originalFirst, firstExists :=
extension.LookupGracefulShutdownCallback(firstName)
+ originalSecond, secondExists :=
extension.LookupGracefulShutdownCallback(secondName)
+ t.Cleanup(func() {
+ extension.UnregisterGracefulShutdownCallback(firstName)
+ extension.UnregisterGracefulShutdownCallback(secondName)
+ if firstExists {
+ extension.RegisterGracefulShutdownCallback(firstName,
originalFirst)
+ }
+ if secondExists {
+ extension.RegisterGracefulShutdownCallback(secondName,
originalSecond)
+ }
+ })
+
+ var secondCalled atomic.Bool
+ extension.RegisterGracefulShutdownCallback(firstName, func(ctx
context.Context) error {
+ <-ctx.Done()
+ return ctx.Err()
+ })
+ extension.RegisterGracefulShutdownCallback(secondName, func(ctx
context.Context) error {
+ secondCalled.Store(true)
+ return nil
+ })
+
+ config := global.DefaultShutdownConfig()
+ config.NotifyTimeout = "100ms"
+
+ notifyLongConnectionConsumers(config)
+
+ assert.True(t, secondCalled.Load())
+}
+
+func TestNotifyLongConnectionConsumersUsesShutdownNotifyTimeout(t *testing.T) {
+ name := "shutdown-step-timeout"
+
+ original, exists := extension.LookupGracefulShutdownCallback(name)
+ t.Cleanup(func() {
+ extension.UnregisterGracefulShutdownCallback(name)
+ if exists {
+ extension.RegisterGracefulShutdownCallback(name,
original)
+ }
+ })
+
+ extension.RegisterGracefulShutdownCallback(name, func(ctx
context.Context) error {
+ <-ctx.Done()
+ return ctx.Err()
+ })
+
+ config := global.DefaultShutdownConfig()
+ config.NotifyTimeout = "100ms"
+
+ start := time.Now()
+ notifyLongConnectionConsumers(config)
+ elapsed := time.Since(start)
+
+ assert.GreaterOrEqual(t, elapsed, 100*time.Millisecond)
+ assert.Less(t, elapsed, time.Second)
+}
+
+func TestNotifyLongConnectionConsumersRunsCallbacksInParallel(t *testing.T) {
+ firstName := "shutdown-parallel-first"
+ secondName := "shutdown-parallel-second"
+
+ originalFirst, firstExists :=
extension.LookupGracefulShutdownCallback(firstName)
+ originalSecond, secondExists :=
extension.LookupGracefulShutdownCallback(secondName)
+ t.Cleanup(func() {
+ extension.UnregisterGracefulShutdownCallback(firstName)
+ extension.UnregisterGracefulShutdownCallback(secondName)
+ if firstExists {
+ extension.RegisterGracefulShutdownCallback(firstName,
originalFirst)
+ }
+ if secondExists {
+ extension.RegisterGracefulShutdownCallback(secondName,
originalSecond)
+ }
+ })
+
+ started := make(chan struct{}, 2)
+ release := make(chan struct{})
+ extension.RegisterGracefulShutdownCallback(firstName, func(ctx
context.Context) error {
+ started <- struct{}{}
+ select {
+ case <-release:
+ return nil
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+ })
+ extension.RegisterGracefulShutdownCallback(secondName, func(ctx
context.Context) error {
+ started <- struct{}{}
+ select {
+ case <-release:
+ return nil
+ case <-ctx.Done():
+ return errors.New("unexpected timeout")
+ }
+ })
+
+ done := make(chan struct{})
+ go func() {
+ config := global.DefaultShutdownConfig()
+ config.NotifyTimeout = "1s"
+ notifyLongConnectionConsumers(config)
+ close(done)
+ }()
+
+ for i := 0; i < 2; i++ {
+ select {
+ case <-started:
+ case <-time.After(time.Second):
+ t.Fatal("callbacks did not start in parallel")
+ }
+ }
+ close(release)
+
+ select {
+ case <-done:
+ case <-time.After(time.Second):
+ t.Fatal("notifyLongConnectionConsumers did not finish")
+ }
+}
+
+func TestBeforeShutdownNotifiesProtocolsBeforeDestroy(t *testing.T) {
+ initOnce = sync.Once{}
+ protocols = make(map[string]struct{})
+ proMu = sync.Mutex{}
+
+ events := make([]string, 0, 3)
+ var eventsMu sync.Mutex
+ record := func(event string) {
+ eventsMu.Lock()
+ defer eventsMu.Unlock()
+ events = append(events, event)
+ }
+
+ originalRegistryProtocol, registryProtocolExists :=
getProtocolIfPresent(constant.RegistryProtocol)
+ extension.SetProtocol(constant.RegistryProtocol, func() base.Protocol {
+ return &testRegistryProtocol{unregister: func() {
record("unregister-registry") }}
+ })
+ t.Cleanup(func() {
+ if registryProtocolExists {
+ extension.SetProtocol(constant.RegistryProtocol, func()
base.Protocol { return originalRegistryProtocol })
+ return
+ }
+ extension.UnregisterProtocol(constant.RegistryProtocol)
+ })
+
+ testProtocolName := "shutdown-order-test-protocol"
+ extension.SetProtocol(testProtocolName, func() base.Protocol {
+ return &testProtocol{destroy: func() {
record("destroy-protocol") }}
+ })
+ t.Cleanup(func() {
+ extension.UnregisterProtocol(testProtocolName)
+ })
+
+ originalCallback, callbackExists :=
extension.LookupGracefulShutdownCallback(testProtocolName)
+ extension.RegisterGracefulShutdownCallback(testProtocolName, func(ctx
context.Context) error {
+ record("notify-protocol")
+ return nil
+ })
+ t.Cleanup(func() {
+ extension.UnregisterGracefulShutdownCallback(testProtocolName)
+ if callbackExists {
+
extension.RegisterGracefulShutdownCallback(testProtocolName, originalCallback)
+ }
+ })
+
+ RegisterProtocol(testProtocolName)
+
+ config := global.DefaultShutdownConfig()
+ config.ConsumerUpdateWaitTime = "0s"
+ config.StepTimeout = "100ms"
+ config.OfflineRequestWindowTimeout = "0s"
+ config.ProviderActiveCount.Store(0)
+ config.ConsumerActiveCount.Store(0)
+
+ beforeShutdown(config)
+
+ assert.Equal(t, []string{"unregister-registry", "notify-protocol",
"destroy-protocol"}, events)
+}
+
+func TestUnregisterRegistriesSkipsMissingRegistryProtocol(t *testing.T) {
+ originalRegistryProtocol, registryProtocolExists :=
getProtocolIfPresent(constant.RegistryProtocol)
+ extension.UnregisterProtocol(constant.RegistryProtocol)
+ t.Cleanup(func() {
+ if registryProtocolExists {
+ extension.SetProtocol(constant.RegistryProtocol, func()
base.Protocol { return originalRegistryProtocol })
+ }
+ })
+
+ assert.NotPanics(t, func() {
+ unregisterRegistries()
+ })
+}
+
+func TestUnregisterRegistriesPrefersUnregisterOnlyCapability(t *testing.T) {
+ originalRegistryProtocol, registryProtocolExists :=
getProtocolIfPresent(constant.RegistryProtocol)
+ defer func() {
+ if registryProtocolExists {
+ extension.SetProtocol(constant.RegistryProtocol, func()
base.Protocol { return originalRegistryProtocol })
+ return
+ }
+ extension.UnregisterProtocol(constant.RegistryProtocol)
+ }()
+
+ called := make([]string, 0, 2)
+ extension.SetProtocol(constant.RegistryProtocol, func() base.Protocol {
+ return &testRegistryProtocol{
+ testProtocol: testProtocol{
+ destroy: func() {
+ called = append(called, "destroy")
+ },
+ },
+ unregister: func() {
+ called = append(called, "unregister")
+ },
+ }
+ })
+
+ unregisterRegistries()
+
+ assert.Equal(t, []string{"unregister"}, called)
+}
+
+func TestUnregisterRegistriesFallsBackToDestroy(t *testing.T) {
+ originalRegistryProtocol, registryProtocolExists :=
getProtocolIfPresent(constant.RegistryProtocol)
+ defer func() {
+ if registryProtocolExists {
+ extension.SetProtocol(constant.RegistryProtocol, func()
base.Protocol { return originalRegistryProtocol })
+ return
+ }
+ extension.UnregisterProtocol(constant.RegistryProtocol)
+ }()
+
+ called := false
+ extension.SetProtocol(constant.RegistryProtocol, func() base.Protocol {
+ return &testProtocol{
+ destroy: func() {
+ called = true
+ },
+ }
+ })
+
+ unregisterRegistries()
+
+ assert.True(t, called)
+}
+
+func TestDestroyProtocolsSkipsMissingProtocol(t *testing.T) {
+ protocols = map[string]struct{}{"missing-shutdown-protocol": {}}
+ proMu = sync.Mutex{}
+
+ assert.NotPanics(t, func() {
+ destroyProtocols()
+ })
+}
+
+func TestInvokeCustomShutdownCallbackDoesNotBlockForever(t *testing.T) {
+ block := make(chan struct{})
+ callback := func() {
+ <-block
+ }
+
+ start := time.Now()
+ invokeCustomShutdownCallback(100*time.Millisecond, callback)
+ elapsed := time.Since(start)
+
+ assert.Less(t, elapsed, time.Second)
+}
diff --git a/internal/internal.go b/internal/internal.go
index 93aafad61..52d3e40d8 100644
--- a/internal/internal.go
+++ b/internal/internal.go
@@ -28,6 +28,9 @@ var (
// HealthSetServingStatusServing is used to set service serving status
// the initialization place is in
/protocol/triple/health/healthServer.go
HealthSetServingStatusServing = func(service string) {}
+ // HealthSetServingStatusNotServing is used to publish a NOT_SERVING
health status
+ // the initialization place is in
/protocol/triple/health/healthServer.go
+ HealthSetServingStatusNotServing = func(service string) {}
// ReflectionRegister is used to register reflection service provider
// the initialization place is in
/protocol/triple/reflection/serverreflection.go
ReflectionRegister = func(reflection reflection.ServiceInfoProvider) {}
diff --git a/protocol/base/base_invoker.go b/protocol/base/base_invoker.go
index c7071a926..bed90b362 100644
--- a/protocol/base/base_invoker.go
+++ b/protocol/base/base_invoker.go
@@ -58,6 +58,11 @@ type Invoker interface {
Invoke(context.Context, Invocation) result.Result
}
+// AvailabilitySetter is implemented by invokers that support toggling
availability.
+type AvailabilitySetter interface {
+ SetAvailable(bool)
+}
+
// BaseInvoker provides default invoker implements Invoker
type BaseInvoker struct {
url uatomic.Pointer[common.URL]
@@ -89,6 +94,11 @@ func (bi *BaseInvoker) IsAvailable() bool {
return bi.available.Load()
}
+// SetAvailable sets available flag
+func (bi *BaseInvoker) SetAvailable(available bool) {
+ bi.available.Store(available)
+}
+
// IsDestroyed gets destroyed flag
func (bi *BaseInvoker) IsDestroyed() bool {
return bi.destroyed.Load()
diff --git a/protocol/base/base_protocol.go b/protocol/base/base_protocol.go
index 811048d47..964e10620 100644
--- a/protocol/base/base_protocol.go
+++ b/protocol/base/base_protocol.go
@@ -35,6 +35,13 @@ type Protocol interface {
Destroy()
}
+// RegistryUnregisterer is an optional protocol capability used during
graceful shutdown.
+// Implementations should only unregister exported services from the registry
and must not
+// destroy protocol servers or unexport providers as part of this step.
+type RegistryUnregisterer interface {
+ UnregisterRegistries()
+}
+
// BaseProtocol is default protocol implement.
type BaseProtocol struct {
exporterMap *sync.Map
diff --git a/protocol/grpc/active_notify_test.go
b/protocol/grpc/active_notify_test.go
new file mode 100644
index 000000000..8afb6d915
--- /dev/null
+++ b/protocol/grpc/active_notify_test.go
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package grpc
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "sync"
+ "testing"
+ "time"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ grpcgo "google.golang.org/grpc"
+ "google.golang.org/grpc/credentials/insecure"
+ grpc_health "google.golang.org/grpc/health"
+ grpc_health_v1 "google.golang.org/grpc/health/grpc_health_v1"
+)
+
+import (
+ "dubbo.apache.org/dubbo-go/v3/common"
+ "dubbo.apache.org/dubbo-go/v3/common/constant"
+ gracefulshutdown "dubbo.apache.org/dubbo-go/v3/graceful_shutdown"
+ "dubbo.apache.org/dubbo-go/v3/protocol/base"
+)
+
+type testClosingEventHandler struct {
+ events []gracefulshutdown.ClosingEvent
+}
+
+func (h *testClosingEventHandler) HandleClosingEvent(event
gracefulshutdown.ClosingEvent) bool {
+ h.events = append(h.events, event)
+ return true
+}
+
+func TestGrpcInvokerHandleHealthStatusNotServing(t *testing.T) {
+ url, err := common.NewURL(helloworldURL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ invoker := &GrpcInvoker{BaseInvoker: *base.NewBaseInvoker(url)}
+ handler := &testClosingEventHandler{}
+
+ handled :=
invoker.handleHealthStatus(grpc_health_v1.HealthCheckResponse_NOT_SERVING,
handler)
+
+ assert.True(t, handled)
+ if assert.Len(t, handler.events, 1) {
+ assert.Equal(t, "grpc-health-watch", handler.events[0].Source)
+ assert.Equal(t, url.GetCacheInvokerMapKey(),
handler.events[0].InstanceKey)
+ assert.Equal(t, url.ServiceKey(), handler.events[0].ServiceKey)
+ }
+}
+
+func TestGrpcServerSetAllServicesNotServing(t *testing.T) {
+ server := NewServer()
+ server.SetServingStatus("svc-a",
grpc_health_v1.HealthCheckResponse_SERVING)
+ server.SetServingStatus("svc-b",
grpc_health_v1.HealthCheckResponse_SERVING)
+
+ server.SetAllServicesNotServing()
+
+ respA, errA := server.healthServer.Check(context.TODO(),
&grpc_health_v1.HealthCheckRequest{Service: "svc-a"})
+ respB, errB := server.healthServer.Check(context.TODO(),
&grpc_health_v1.HealthCheckRequest{Service: "svc-b"})
+ assert.NoError(t, errA)
+ assert.NoError(t, errB)
+ assert.Equal(t, grpc_health_v1.HealthCheckResponse_NOT_SERVING,
respA.GetStatus())
+ assert.Equal(t, grpc_health_v1.HealthCheckResponse_NOT_SERVING,
respB.GetStatus())
+}
+
+func TestGrpcHealthWatchEmitsClosingEvent(t *testing.T) {
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+ defer listener.Close()
+
+ serviceKey := common.ServiceKey(constant.HealthCheckServiceInterface,
"group", "1.0.0")
+ healthServer := grpc_health.NewServer()
+ healthServer.SetServingStatus(serviceKey,
grpc_health_v1.HealthCheckResponse_SERVING)
+
+ server := grpcgo.NewServer()
+ grpc_health_v1.RegisterHealthServer(server, healthServer)
+ go func() {
+ _ = server.Serve(listener)
+ }()
+ defer server.Stop()
+
+ url, err := common.NewURL(fmt.Sprintf(
+ "grpc://%s/%s?interface=%s&group=group&version=1.0.0",
+ listener.Addr().String(),
+ constant.HealthCheckServiceInterface,
+ constant.HealthCheckServiceInterface,
+ ))
+ require.NoError(t, err)
+
+ conn, err := grpcgo.Dial(
+ listener.Addr().String(),
+ grpcgo.WithTransportCredentials(insecure.NewCredentials()),
+ grpcgo.WithBlock(),
+ grpcgo.WithTimeout(3*time.Second),
+ )
+ require.NoError(t, err)
+ defer conn.Close()
+
+ invoker := &GrpcInvoker{
+ BaseInvoker: *base.NewBaseInvoker(url),
+ clientGuard: &sync.RWMutex{},
+ client: &Client{ClientConn: conn},
+ }
+ handler := &testClosingEventHandler{}
+ invoker.startHealthWatch(handler)
+ defer invoker.Destroy()
+
+ healthServer.SetServingStatus(serviceKey,
grpc_health_v1.HealthCheckResponse_NOT_SERVING)
+
+ require.Eventually(t, func() bool {
+ return len(handler.events) == 1
+ }, 3*time.Second, 20*time.Millisecond)
+
+ assert.Equal(t, "grpc-health-watch", handler.events[0].Source)
+ assert.Equal(t, url.GetCacheInvokerMapKey(),
handler.events[0].InstanceKey)
+ assert.Equal(t, serviceKey, handler.events[0].ServiceKey)
+}
diff --git a/protocol/grpc/grpc_invoker.go b/protocol/grpc/grpc_invoker.go
index 5e53c82bd..25b76aac4 100644
--- a/protocol/grpc/grpc_invoker.go
+++ b/protocol/grpc/grpc_invoker.go
@@ -32,10 +32,13 @@ import (
"github.com/pkg/errors"
"google.golang.org/grpc/connectivity"
+
+ grpc_health_v1 "google.golang.org/grpc/health/grpc_health_v1"
)
import (
"dubbo.apache.org/dubbo-go/v3/common"
+ gracefulshutdown "dubbo.apache.org/dubbo-go/v3/graceful_shutdown"
"dubbo.apache.org/dubbo-go/v3/protocol/base"
"dubbo.apache.org/dubbo-go/v3/protocol/result"
)
@@ -49,15 +52,18 @@ type GrpcInvoker struct {
quitOnce sync.Once
clientGuard *sync.RWMutex
client *Client
+ watchCancel context.CancelFunc
}
// NewGrpcInvoker returns a Grpc invoker instance
func NewGrpcInvoker(url *common.URL, client *Client) *GrpcInvoker {
- return &GrpcInvoker{
+ invoker := &GrpcInvoker{
BaseInvoker: *base.NewBaseInvoker(url),
clientGuard: &sync.RWMutex{},
client: client,
}
+ invoker.startHealthWatch(gracefulshutdown.DefaultClosingEventHandler())
+ return invoker
}
func (gi *GrpcInvoker) setClient(client *Client) {
@@ -148,6 +154,9 @@ func (gi *GrpcInvoker) IsDestroyed() bool {
// Destroy will destroy gRPC's invoker and client, so it is only called once
func (gi *GrpcInvoker) Destroy() {
gi.quitOnce.Do(func() {
+ if gi.watchCancel != nil {
+ gi.watchCancel()
+ }
gi.BaseInvoker.Destroy()
client := gi.getClient()
if client != nil {
@@ -156,3 +165,55 @@ func (gi *GrpcInvoker) Destroy() {
}
})
}
+
+func (gi *GrpcInvoker) startHealthWatch(handler
gracefulshutdown.ClosingEventHandler) {
+ if handler == nil || gi.GetURL() == nil || gi.GetURL().ServiceKey() ==
"" || gi.GetURL().Interface() == "" {
+ return
+ }
+
+ client := gi.getClient()
+ if client == nil {
+ return
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ gi.watchCancel = cancel
+
+ go func() {
+ healthClient :=
grpc_health_v1.NewHealthClient(client.ClientConn)
+ stream, err := healthClient.Watch(ctx,
&grpc_health_v1.HealthCheckRequest{Service: gi.GetURL().ServiceKey()})
+ if err != nil {
+ logger.Debugf("[GRPC Protocol] health watch start
failed for %s: %v", gi.GetURL().String(), err)
+ return
+ }
+
+ for {
+ resp, recvErr := stream.Recv()
+ if recvErr != nil {
+ if ctx.Err() == nil {
+ logger.Debugf("[GRPC Protocol] health
watch recv failed for %s: %v", gi.GetURL().String(), recvErr)
+ }
+ return
+ }
+ if gi.handleHealthStatus(resp.GetStatus(), handler) {
+ return
+ }
+ }
+ }()
+}
+
+func (gi *GrpcInvoker) handleHealthStatus(status
grpc_health_v1.HealthCheckResponse_ServingStatus, handler
gracefulshutdown.ClosingEventHandler) bool {
+ if handler == nil || gi.GetURL() == nil {
+ return false
+ }
+ if status != grpc_health_v1.HealthCheckResponse_NOT_SERVING {
+ return false
+ }
+
+ return handler.HandleClosingEvent(gracefulshutdown.ClosingEvent{
+ Source: "grpc-health-watch",
+ InstanceKey: gi.GetURL().GetCacheInvokerMapKey(),
+ ServiceKey: gi.GetURL().ServiceKey(),
+ Address: gi.GetURL().Location,
+ })
+}
diff --git a/protocol/grpc/grpc_protocol.go b/protocol/grpc/grpc_protocol.go
index 30f855cfc..0955bf2bb 100644
--- a/protocol/grpc/grpc_protocol.go
+++ b/protocol/grpc/grpc_protocol.go
@@ -18,11 +18,14 @@
package grpc
import (
+ "context"
"sync"
)
import (
"github.com/dubbogo/gost/log/logger"
+
+ grpc_health_v1 "google.golang.org/grpc/health/grpc_health_v1"
)
import (
@@ -38,10 +41,36 @@ const (
func init() {
extension.SetProtocol(GRPC, GetProtocol)
+
+ // register graceful shutdown callback
+ extension.RegisterGracefulShutdownCallback(GRPC, func(ctx
context.Context) error {
+ grpcProto := GetProtocol()
+ if grpcProto == nil {
+ return nil
+ }
+
+ gp, ok := grpcProto.(*GrpcProtocol)
+ if !ok {
+ return nil
+ }
+
+ gp.serverLock.Lock()
+ defer gp.serverLock.Unlock()
+
+ for _, server := range gp.serverMap {
+ server.SetAllServicesNotServing()
+ }
+
+ return nil
+ })
}
var grpcProtocol *GrpcProtocol
+var grpcServerGracefulStop = func(server *Server) {
+ server.GracefulStop()
+}
+
// GrpcProtocol is gRPC protocol
type GrpcProtocol struct {
base.BaseProtocol
@@ -64,16 +93,17 @@ func (gp *GrpcProtocol) Export(invoker base.Invoker)
base.Exporter {
exporter := NewGrpcExporter(serviceKey, invoker, gp.ExporterMap())
gp.SetExporterMap(serviceKey, exporter)
logger.Infof("[GRPC Protocol] Export service: %s", url.String())
- gp.openServer(url)
+ srv := gp.openServer(url)
+ srv.SetServingStatus(serviceKey,
grpc_health_v1.HealthCheckResponse_SERVING)
return exporter
}
-func (gp *GrpcProtocol) openServer(url *common.URL) {
+func (gp *GrpcProtocol) openServer(url *common.URL) *Server {
gp.serverLock.Lock()
defer gp.serverLock.Unlock()
- if _, ok := gp.serverMap[url.Location]; ok {
- return
+ if srv, ok := gp.serverMap[url.Location]; ok {
+ return srv
}
if _, ok := gp.ExporterMap().Load(url.ServiceKey()); !ok {
@@ -83,6 +113,7 @@ func (gp *GrpcProtocol) openServer(url *common.URL) {
srv := NewServer()
gp.serverMap[url.Location] = srv
srv.Start(url)
+ return srv
}
// Refer a remote gRPC service
@@ -102,14 +133,23 @@ func (gp *GrpcProtocol) Refer(url *common.URL)
base.Invoker {
func (gp *GrpcProtocol) Destroy() {
logger.Infof("GrpcProtocol destroy.")
+ for _, server := range gp.drainServers() {
+ grpcServerGracefulStop(server)
+ }
+
+ gp.BaseProtocol.Destroy()
+}
+
+func (gp *GrpcProtocol) drainServers() []*Server {
gp.serverLock.Lock()
defer gp.serverLock.Unlock()
+
+ servers := make([]*Server, 0, len(gp.serverMap))
for key, server := range gp.serverMap {
delete(gp.serverMap, key)
- server.GracefulStop()
+ servers = append(servers, server)
}
-
- gp.BaseProtocol.Destroy()
+ return servers
}
// GetProtocol gets gRPC protocol, will create if null.
diff --git a/protocol/grpc/grpc_protocol_test.go
b/protocol/grpc/grpc_protocol_test.go
index 93bd64ea1..0f6644360 100644
--- a/protocol/grpc/grpc_protocol_test.go
+++ b/protocol/grpc/grpc_protocol_test.go
@@ -18,7 +18,9 @@
package grpc
import (
+ "sync/atomic"
"testing"
+ "time"
)
import (
@@ -56,3 +58,59 @@ func TestGrpcProtocolRefer(t *testing.T) {
invokersLen = len(proto.(*GrpcProtocol).Invokers())
assert.Equal(t, 0, invokersLen)
}
+
+func TestGrpcProtocolDestroyDoesNotHoldServerLockWhileGracefulStopping(t
*testing.T) {
+ proto := NewGRPCProtocol()
+ proto.serverMap["127.0.0.1:20000"] = &Server{}
+
+ originalStop := grpcServerGracefulStop
+ t.Cleanup(func() {
+ grpcServerGracefulStop = originalStop
+ })
+
+ entered := make(chan struct{})
+ release := make(chan struct{})
+ var stopCalls atomic.Int32
+ grpcServerGracefulStop = func(server *Server) {
+ stopCalls.Add(1)
+ close(entered)
+ <-release
+ }
+
+ done := make(chan struct{})
+ go func() {
+ proto.Destroy()
+ close(done)
+ }()
+
+ select {
+ case <-entered:
+ case <-time.After(time.Second):
+ t.Fatal("Destroy did not reach graceful stop")
+ }
+
+ lockAcquired := make(chan struct{})
+ go func() {
+ proto.serverLock.Lock()
+ _ = proto.serverMap
+ proto.serverLock.Unlock()
+ close(lockAcquired)
+ }()
+
+ select {
+ case <-lockAcquired:
+ case <-time.After(time.Second):
+ t.Fatal("serverLock remained held during graceful stop")
+ }
+
+ close(release)
+
+ select {
+ case <-done:
+ case <-time.After(time.Second):
+ t.Fatal("Destroy did not finish")
+ }
+
+ assert.Equal(t, int32(1), stopCalls.Load())
+ assert.Empty(t, proto.serverMap)
+}
diff --git a/protocol/grpc/server.go b/protocol/grpc/server.go
index 354a732d7..9c1b31e89 100644
--- a/protocol/grpc/server.go
+++ b/protocol/grpc/server.go
@@ -36,6 +36,8 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
+ grpc_health "google.golang.org/grpc/health"
+ grpc_health_v1 "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/reflection"
)
@@ -59,13 +61,19 @@ type DubboGrpcService interface {
// Server is a gRPC server
type Server struct {
- grpcServer *grpc.Server
- bufferSize int
+ grpcServer *grpc.Server
+ healthServer *grpc_health.Server
+ bufferSize int
+ serviceLock sync.Mutex
+ services map[string]struct{}
}
// NewServer creates a new server
func NewServer() *Server {
- return &Server{}
+ return &Server{
+ healthServer: grpc_health.NewServer(),
+ services: make(map[string]struct{}),
+ }
}
func (s *Server) SetBufferSize(n int) {
@@ -134,6 +142,7 @@ func (s *Server) Start(url *common.URL) {
server := grpc.NewServer(serverOpts...)
s.grpcServer = server
+ grpc_health_v1.RegisterHealthServer(server, s.healthServer)
success = true
go func() {
@@ -152,6 +161,29 @@ func (s *Server) Start(url *common.URL) {
}()
}
+func (s *Server) SetServingStatus(service string, status
grpc_health_v1.HealthCheckResponse_ServingStatus) {
+ if s.healthServer == nil || service == "" {
+ return
+ }
+
+ s.serviceLock.Lock()
+ s.services[service] = struct{}{}
+ s.serviceLock.Unlock()
+ s.healthServer.SetServingStatus(service, status)
+}
+
+func (s *Server) SetAllServicesNotServing() {
+ if s.healthServer == nil {
+ return
+ }
+
+ s.serviceLock.Lock()
+ defer s.serviceLock.Unlock()
+ for service := range s.services {
+ s.healthServer.SetServingStatus(service,
grpc_health_v1.HealthCheckResponse_NOT_SERVING)
+ }
+}
+
// getProviderServices retrieves provider services from URL attributes.
func getProviderServices(url *common.URL) map[string]*global.ServiceConfig {
if providerConfRaw, ok := url.GetAttribute(constant.ProviderConfigKey);
ok {
diff --git a/protocol/triple/active_notify_test.go
b/protocol/triple/active_notify_test.go
new file mode 100644
index 000000000..83730b221
--- /dev/null
+++ b/protocol/triple/active_notify_test.go
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package triple
+
+import (
+ "context"
+ "fmt"
+ "net"
+ "sync"
+ "testing"
+ "time"
+)
+
+import (
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+
+ grpc_health_v1 "google.golang.org/grpc/health/grpc_health_v1"
+)
+
+import (
+ "dubbo.apache.org/dubbo-go/v3/common"
+ "dubbo.apache.org/dubbo-go/v3/common/constant"
+ gracefulshutdown "dubbo.apache.org/dubbo-go/v3/graceful_shutdown"
+ "dubbo.apache.org/dubbo-go/v3/protocol/base"
+ tri "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol"
+)
+
+type testTripleClosingEventHandler struct {
+ events []gracefulshutdown.ClosingEvent
+}
+
+func (h *testTripleClosingEventHandler) HandleClosingEvent(event
gracefulshutdown.ClosingEvent) bool {
+ h.events = append(h.events, event)
+ return true
+}
+
+func TestTripleInvokerHandleHealthStatusNotServing(t *testing.T) {
+ url, err :=
common.NewURL("tri://127.0.0.1:20000/org.apache.dubbo-go.mockService?group=group&version=1.0.0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ invoker := &TripleInvoker{BaseInvoker: *base.NewBaseInvoker(url)}
+ handler := &testTripleClosingEventHandler{}
+
+ handled :=
invoker.handleHealthStatus(grpc_health_v1.HealthCheckResponse_NOT_SERVING,
handler)
+
+ assert.True(t, handled)
+ if assert.Len(t, handler.events, 1) {
+ assert.Equal(t, "triple-health-watch", handler.events[0].Source)
+ assert.Equal(t, url.GetCacheInvokerMapKey(),
handler.events[0].InstanceKey)
+ assert.Equal(t, url.ServiceKey(), handler.events[0].ServiceKey)
+ }
+}
+
+func TestTripleHealthWatchEmitsClosingEvent(t *testing.T) {
+ listener, err := net.Listen("tcp", "127.0.0.1:0")
+ require.NoError(t, err)
+ addr := listener.Addr().String()
+ _ = listener.Close()
+
+ serviceKey := common.ServiceKey(constant.HealthCheckServiceInterface,
"group", "1.0.0")
+ notServing := make(chan struct{})
+
+ server := tri.NewServer(addr, nil)
+ err = server.RegisterServerStreamHandler(
+ "/grpc.health.v1.Health/Watch",
+ func() any { return new(grpc_health_v1.HealthCheckRequest) },
+ func(ctx context.Context, req *tri.Request, stream
*tri.ServerStream) error {
+ request, ok :=
req.Msg.(*grpc_health_v1.HealthCheckRequest)
+ if !ok {
+ return fmt.Errorf("unexpected request type %T",
req.Msg)
+ }
+ if request.GetService() != serviceKey {
+ return fmt.Errorf("unexpected service %s",
request.GetService())
+ }
+ if sendErr :=
stream.Send(&grpc_health_v1.HealthCheckResponse{
+ Status:
grpc_health_v1.HealthCheckResponse_SERVING,
+ }); sendErr != nil {
+ return sendErr
+ }
+ select {
+ case <-notServing:
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+ return stream.Send(&grpc_health_v1.HealthCheckResponse{
+ Status:
grpc_health_v1.HealthCheckResponse_NOT_SERVING,
+ })
+ },
+ )
+ require.NoError(t, err)
+
+ go func() {
+ _ = server.Run(constant.CallHTTP2, nil)
+ }()
+ defer func() {
+ _ = server.Stop()
+ }()
+
+ require.Eventually(t, func() bool {
+ conn, dialErr := net.DialTimeout("tcp", addr,
100*time.Millisecond)
+ if dialErr != nil {
+ return false
+ }
+ _ = conn.Close()
+ return true
+ }, 3*time.Second, 20*time.Millisecond)
+
+ url, err := common.NewURL(fmt.Sprintf(
+ "tri://%s/%s?interface=%s&group=group&version=1.0.0&timeout=1s",
+ addr,
+ constant.HealthCheckServiceInterface,
+ constant.HealthCheckServiceInterface,
+ ))
+ require.NoError(t, err)
+
+ cm, err := newClientManager(url)
+ require.NoError(t, err)
+ invoker := &TripleInvoker{
+ BaseInvoker: *base.NewBaseInvoker(url),
+ clientGuard: &sync.RWMutex{},
+ clientManager: cm,
+ }
+ defer invoker.Destroy()
+
+ handler := &testTripleClosingEventHandler{}
+ invoker.startHealthWatch(handler)
+
+ close(notServing)
+
+ require.Eventually(t, func() bool {
+ return len(handler.events) == 1
+ }, 3*time.Second, 20*time.Millisecond)
+
+ assert.Equal(t, "triple-health-watch", handler.events[0].Source)
+ assert.Equal(t, url.GetCacheInvokerMapKey(),
handler.events[0].InstanceKey)
+ assert.Equal(t, serviceKey, handler.events[0].ServiceKey)
+}
diff --git a/protocol/triple/client.go b/protocol/triple/client.go
index b235ee4c0..40014d468 100644
--- a/protocol/triple/client.go
+++ b/protocol/triple/client.go
@@ -37,6 +37,8 @@ import (
"github.com/quic-go/quic-go/http3"
"golang.org/x/net/http2"
+
+ grpc_health_v1 "google.golang.org/grpc/health/grpc_health_v1"
)
import (
@@ -56,8 +58,9 @@ const (
// callUnary, callClientStream, callServerStream, callBidiStream.
// A Reference has a clientManager.
type clientManager struct {
- isIDL bool
- triClient *tri.Client
+ isIDL bool
+ triClient *tri.Client
+ healthClient *tri.Client
}
// TODO: code a triple client between clientManager and triple_protocol client
@@ -265,13 +268,31 @@ func newClientManager(url *common.URL) (*clientManager,
error) {
}
triClient := tri.NewClient(httpClient, triURL, cliOpts...)
+ healthURL, err := joinPath(baseTriURL,
constant.HealthCheckServiceInterface)
+ if err != nil {
+ return nil, fmt.Errorf("JoinPath failed for base %s, health
interface %s", baseTriURL, constant.HealthCheckServiceInterface)
+ }
+ healthClient := tri.NewClient(httpClient, healthURL,
tri.WithTimeout(timeout))
return &clientManager{
- isIDL: isIDL,
- triClient: triClient,
+ isIDL: isIDL,
+ triClient: triClient,
+ healthClient: healthClient,
}, nil
}
+// callHealthWatch opens a server-streaming "Watch" call on the dedicated health
+// client, asking for status updates of the given service key.
+func (cm *clientManager) callHealthWatch(ctx context.Context, service string) (*tri.ServerStreamForClient, error) {
+	hc := cm.healthClient
+	if hc == nil {
+		return nil, errors.New("triple health client is not initialized")
+	}
+
+	watchReq := tri.NewRequest(&grpc_health_v1.HealthCheckRequest{Service: service})
+	stream, err := hc.CallServerStream(ctx, watchReq, "Watch")
+	if err != nil {
+		return nil, err
+	}
+	return stream, nil
+}
+
func genKeepAliveOptions(url *common.URL, tripleConf *global.TripleConfig)
([]tri.ClientOption, time.Duration, time.Duration, error) {
var cliKeepAliveOpts []tri.ClientOption
diff --git a/protocol/triple/health/healthServer.go
b/protocol/triple/health/healthServer.go
index d2791fa1c..ec4715b4b 100644
--- a/protocol/triple/health/healthServer.go
+++ b/protocol/triple/health/healthServer.go
@@ -172,6 +172,7 @@ func (srv *HealthTripleServer) Resume() {
func init() {
healthServer = NewServer()
internal.HealthSetServingStatusServing = SetServingStatusServing
+ internal.HealthSetServingStatusNotServing = SetServingStatusNotServing
server.SetProviderServices(&server.InternalService{
Name: "healthCheck",
Init: func(options *server.ServiceOptions)
(*server.ServiceDefinition, bool) {
diff --git a/protocol/triple/triple.go b/protocol/triple/triple.go
index 731dcce41..2b394a9ca 100644
--- a/protocol/triple/triple.go
+++ b/protocol/triple/triple.go
@@ -45,8 +45,36 @@ var (
tripleProtocol *TripleProtocol
)
+// tripleServerGracefulStop gracefully stops a single triple server. It is a
+// package-level variable so tests can substitute a stub.
+var tripleServerGracefulStop = func(srv *Server) {
+	srv.GracefulStop()
+}
+
func init() {
extension.SetProtocol(TRIPLE, GetProtocol)
+
+	// On graceful shutdown, mark every exported triple service NOT_SERVING in the
+	// health server so consumers watching health status can react.
+	extension.RegisterGracefulShutdownCallback(TRIPLE, func(ctx context.Context) error {
+		// The hook is a func variable assigned by the triple health package's
+		// init; if it was never assigned, calling it would panic on a nil func.
+		setNotServing := internal.HealthSetServingStatusNotServing
+		if setNotServing == nil {
+			return nil
+		}
+
+		p := GetProtocol()
+		if p == nil {
+			return nil
+		}
+
+		tp, ok := p.(*TripleProtocol)
+		if !ok {
+			return nil
+		}
+
+		// Export stores exporters under their service key, which is also the key
+		// the health server was marked SERVING with.
+		tp.ExporterMap().Range(func(key, _ any) bool {
+			if serviceKey, isString := key.(string); isString && serviceKey != "" {
+				setNotServing(serviceKey)
+			}
+			return true
+		})
+
+		return nil
+	})
}
type TripleProtocol struct {
@@ -68,7 +96,7 @@ func (tp *TripleProtocol) Export(invoker base.Invoker)
base.Exporter {
tp.SetExporterMap(serviceKey, exporter)
logger.Infof("[TRIPLE Protocol] Export service: %s", url.String())
tp.openServer(invoker, info)
- internal.HealthSetServingStatusServing(url.Service())
+ internal.HealthSetServingStatusServing(serviceKey)
return exporter
}
@@ -139,14 +167,23 @@ func (tp *TripleProtocol) Refer(url *common.URL)
base.Invoker {
+// Destroy detaches every server from serverMap, stops each one gracefully and
+// then destroys the embedded BaseProtocol.
func (tp *TripleProtocol) Destroy() {
logger.Infof("TripleProtocol destroy.")
+ // Servers are stopped after drainServers has released serverLock, so a slow
+ // GracefulStop does not keep the lock held.
+ for _, server := range tp.drainServers() {
+ tripleServerGracefulStop(server)
+ }
+
+ tp.BaseProtocol.Destroy()
+}
+
+// drainServers empties serverMap while holding serverLock and returns the
+// removed servers; stopping them is left to the caller.
+func (tp *TripleProtocol) drainServers() []*Server {
tp.serverLock.Lock()
defer tp.serverLock.Unlock()
+
+ servers := make([]*Server, 0, len(tp.serverMap))
for key, server := range tp.serverMap {
delete(tp.serverMap, key)
- server.GracefulStop()
+ servers = append(servers, server)
}
-
- tp.BaseProtocol.Destroy()
+ return servers
}
// isGenericCall checks if the generic parameter indicates a generic call
diff --git a/protocol/triple/triple_invoker.go
b/protocol/triple/triple_invoker.go
index 11d81ae3a..53faaf067 100644
--- a/protocol/triple/triple_invoker.go
+++ b/protocol/triple/triple_invoker.go
@@ -26,11 +26,14 @@ import (
import (
"github.com/dubbogo/gost/log/logger"
+
+ grpc_health_v1 "google.golang.org/grpc/health/grpc_health_v1"
)
import (
"dubbo.apache.org/dubbo-go/v3/common"
"dubbo.apache.org/dubbo-go/v3/common/constant"
+ gracefulshutdown "dubbo.apache.org/dubbo-go/v3/graceful_shutdown"
"dubbo.apache.org/dubbo-go/v3/protocol/base"
"dubbo.apache.org/dubbo-go/v3/protocol/result"
tri "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol"
@@ -45,6 +48,7 @@ type TripleInvoker struct {
quitOnce sync.Once
clientGuard *sync.RWMutex
clientManager *clientManager
+ watchCancel context.CancelFunc
}
func (ti *TripleInvoker) setClientManager(cm *clientManager) {
@@ -238,6 +242,9 @@ func (ti *TripleInvoker) IsDestroyed() bool {
// Destroy will destroy Triple's invoker and client, so it is only called once
func (ti *TripleInvoker) Destroy() {
ti.quitOnce.Do(func() {
+ if ti.watchCancel != nil {
+ ti.watchCancel()
+ }
ti.BaseInvoker.Destroy()
if cm := ti.getClientManager(); cm != nil {
ti.setClientManager(nil)
@@ -252,10 +259,63 @@ func NewTripleInvoker(url *common.URL) (*TripleInvoker,
error) {
if err != nil {
return nil, err
}
- return &TripleInvoker{
+ invoker := &TripleInvoker{
BaseInvoker: *base.NewBaseInvoker(url),
quitOnce: sync.Once{},
clientGuard: &sync.RWMutex{},
clientManager: cm,
- }, nil
+ }
+ invoker.startHealthWatch(gracefulshutdown.DefaultClosingEventHandler())
+ return invoker, nil
+}
+
+// startHealthWatch subscribes, in a background goroutine, to the provider's
+// health "Watch" stream for this invoker's service key. Each status is passed
+// to handleHealthStatus; the goroutine stops once the handler accepts a closing
+// event, the stream ends, or Destroy cancels the watch via watchCancel.
+func (ti *TripleInvoker) startHealthWatch(handler gracefulshutdown.ClosingEventHandler) {
+	if handler == nil || ti.GetURL() == nil || ti.GetURL().ServiceKey() == "" {
+		return
+	}
+
+	cm := ti.getClientManager()
+	if cm == nil {
+		return
+	}
+
+	// Destroy calls watchCancel; the ctx.Err() check below treats the resulting
+	// Receive failure as a normal stop rather than an error.
+	ctx, cancel := context.WithCancel(context.Background())
+	ti.watchCancel = cancel
+
+	go func() {
+		stream, err := cm.callHealthWatch(ctx, ti.GetURL().ServiceKey())
+		if err != nil {
+			logger.Debugf("[TRIPLE Protocol] health watch start failed for %s: %v", ti.GetURL().String(), err)
+			return
+		}
+		// Release the stream on every exit path. After the handler accepts a
+		// closing event the ctx stays live until Destroy, so without this the
+		// stream could remain open.
+		defer func() {
+			_ = stream.Close()
+		}()
+
+		for {
+			resp := new(grpc_health_v1.HealthCheckResponse)
+			if ok := stream.Receive(resp); !ok {
+				if ctx.Err() == nil {
+					logger.Debugf("[TRIPLE Protocol] health watch recv failed for %s: %v", ti.GetURL().String(), stream.Err())
+				}
+				return
+			}
+			if ti.handleHealthStatus(resp.GetStatus(), handler) {
+				return
+			}
+		}
+	}()
+}
+
+// handleHealthStatus turns a NOT_SERVING status into a closing event for this
+// invoker and returns whatever handler.HandleClosingEvent reports. Any other
+// status, a nil handler, or a missing URL yields false, which keeps the watch
+// loop running.
+func (ti *TripleInvoker) handleHealthStatus(status grpc_health_v1.HealthCheckResponse_ServingStatus, handler gracefulshutdown.ClosingEventHandler) bool {
+	u := ti.GetURL()
+	if handler == nil || u == nil || status != grpc_health_v1.HealthCheckResponse_NOT_SERVING {
+		return false
+	}
+
+	event := gracefulshutdown.ClosingEvent{
+		Source:      "triple-health-watch",
+		InstanceKey: u.GetCacheInvokerMapKey(),
+		ServiceKey:  u.ServiceKey(),
+		Address:     u.Location,
+	}
+	return handler.HandleClosingEvent(event)
}
diff --git a/protocol/triple/triple_test.go b/protocol/triple/triple_test.go
index 3fa918a79..ab561f7bf 100644
--- a/protocol/triple/triple_test.go
+++ b/protocol/triple/triple_test.go
@@ -18,7 +18,10 @@
package triple
import (
+ "context"
+ "sync/atomic"
"testing"
+ "time"
)
import (
@@ -27,6 +30,7 @@ import (
import (
"dubbo.apache.org/dubbo-go/v3/common/extension"
+ tri "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol"
)
func TestNewTripleProtocol(t *testing.T) {
@@ -59,6 +63,24 @@ func TestTripleConstant(t *testing.T) {
assert.Equal(t, "tri", TRIPLE)
}
+// TestTripleGracefulShutdownCallbackRegistration checks that the triple
+// protocol registers a graceful-shutdown callback and that running it against
+// a freshly built protocol returns nil without panicking.
+func TestTripleGracefulShutdownCallbackRegistration(t *testing.T) {
+	callback, found := extension.LookupGracefulShutdownCallback(TRIPLE)
+	assert.True(t, found)
+	assert.NotNil(t, callback)
+
+	saved := tripleProtocol
+	t.Cleanup(func() {
+		tripleProtocol = saved
+	})
+
+	fresh := NewTripleProtocol()
+	fresh.serverMap["graceful-test"] = &Server{triServer: tri.NewServer("", nil)}
+	tripleProtocol = fresh
+
+	assert.NotPanics(t, func() {
+		assert.NoError(t, callback(context.Background()))
+	})
+}
+
func TestTripleProtocol_Destroy_EmptyServerMap(t *testing.T) {
tp := NewTripleProtocol()
@@ -68,6 +90,62 @@ func TestTripleProtocol_Destroy_EmptyServerMap(t *testing.T)
{
})
}
+// Replaces tripleServerGracefulStop with a stub that blocks until released,
+// then checks that another goroutine can take serverLock while the stub is
+// blocked, i.e. Destroy does not hold the lock during graceful stop.
+func TestTripleProtocolDestroyDoesNotHoldServerLockWhileGracefulStopping(t
*testing.T) {
+ tp := NewTripleProtocol()
+ tp.serverMap["127.0.0.1:20000"] = &Server{}
+
+ originalStop := tripleServerGracefulStop
+ t.Cleanup(func() {
+ tripleServerGracefulStop = originalStop
+ })
+
+ // entered is closed when the stub is reached; the stub blocks on release.
+ entered := make(chan struct{})
+ release := make(chan struct{})
+ var stopCalls atomic.Int32
+ tripleServerGracefulStop = func(server *Server) {
+ stopCalls.Add(1)
+ close(entered)
+ <-release
+ }
+
+ done := make(chan struct{})
+ go func() {
+ tp.Destroy()
+ close(done)
+ }()
+
+ select {
+ case <-entered:
+ case <-time.After(time.Second):
+ t.Fatal("Destroy did not reach graceful stop")
+ }
+
+ // The stub is now blocked inside Destroy; serverLock must still be obtainable.
+ lockAcquired := make(chan struct{})
+ go func() {
+ tp.serverLock.Lock()
+ _ = tp.serverMap
+ tp.serverLock.Unlock()
+ close(lockAcquired)
+ }()
+
+ select {
+ case <-lockAcquired:
+ case <-time.After(time.Second):
+ t.Fatal("serverLock remained held during graceful stop")
+ }
+
+ close(release)
+
+ select {
+ case <-done:
+ case <-time.After(time.Second):
+ t.Fatal("Destroy did not finish")
+ }
+
+ assert.Equal(t, int32(1), stopCalls.Load())
+ assert.Empty(t, tp.serverMap)
+}
+
// Test isGenericCall checks if the generic parameter indicates a generic call
func Test_isGenericCall(t *testing.T) {
tests := []struct {
diff --git a/registry/directory/directory.go b/registry/directory/directory.go
index f2e80705d..7689b34c4 100644
--- a/registry/directory/directory.go
+++ b/registry/directory/directory.go
@@ -43,6 +43,7 @@ import (
"dubbo.apache.org/dubbo-go/v3/config_center"
_ "dubbo.apache.org/dubbo-go/v3/config_center/configurator"
"dubbo.apache.org/dubbo-go/v3/global"
+ "dubbo.apache.org/dubbo-go/v3/graceful_shutdown"
"dubbo.apache.org/dubbo-go/v3/metrics"
metricsRegistry "dubbo.apache.org/dubbo-go/v3/metrics/registry"
protocolbase "dubbo.apache.org/dubbo-go/v3/protocol/base"
@@ -74,8 +75,25 @@ type RegistryDirectory struct {
registerLock sync.Mutex // this lock if for register
SubscribedUrl *common.URL
RegisteredUrl *common.URL
+ closingTombstones *sync.Map // map[string]closingTombstone
+ closingTombstoneTTL time.Duration
}
+// closingTombstone marks an instance key removed by a closing event. While it
+// is unexpired, doCacheInvoker skips rebuilding that key; a registry delete
+// event (uncacheInvoker) or expiry removes it.
+type closingTombstone struct {
+ InstanceKey string
+ // ServiceKey and Address are copied from the removed invoker's URL when one
+ // was cached; Source names the trigger. Within this diff only ExpireAt is read.
+ ServiceKey string
+ Address string
+ Source string
+ // ExpireAt is the instant after which the tombstone is ignored and evicted.
+ ExpireAt time.Time
+}
+
+// defaultClosingTombstoneTTL is how long a closing tombstone suppresses
+// re-adding an instance. It is read once from the default shutdown config's
+// ClosingInvokerExpireTime and falls back to 30s when that value does not
+// parse as a positive duration.
+var defaultClosingTombstoneTTL = func() time.Duration {
+	const fallback = 30 * time.Second
+	configured, err := time.ParseDuration(global.DefaultShutdownConfig().ClosingInvokerExpireTime)
+	if err != nil || configured <= 0 {
+		return fallback
+	}
+	return configured
+}()
+
// NewRegistryDirectory will create a new RegistryDirectory
func NewRegistryDirectory(url *common.URL, registry registry.Registry)
(directory.Directory, error) {
if url.SubURL == nil {
@@ -132,11 +150,13 @@ func NewRegistryDirectory(url *common.URL, registry
registry.Registry) (director
}
dir := &RegistryDirectory{
- Directory: base.NewDirectory(url),
- cacheInvokers: []protocolbase.Invoker{},
- cacheInvokersMap: &sync.Map{},
- serviceType: url.SubURL.Service(),
- registry: registry,
+ Directory: base.NewDirectory(url),
+ cacheInvokers: []protocolbase.Invoker{},
+ cacheInvokersMap: &sync.Map{},
+ serviceType: url.SubURL.Service(),
+ registry: registry,
+ closingTombstones: &sync.Map{},
+ closingTombstoneTTL: defaultClosingTombstoneTTL,
}
dir.consumerURL = dir.getConsumerUrl(url.SubURL)
@@ -150,6 +170,7 @@ func NewRegistryDirectory(url *common.URL, registry
registry.Registry) (director
dir.consumerConfigurationListener =
newConsumerConfigurationListener(dir, url)
dir.consumerConfigurationListener.addNotifyListener(dir)
dir.referenceConfigurationListener =
newReferenceConfigurationListener(dir, url)
+
graceful_shutdown.DefaultClosingDirectoryRegistry().Register(dir.closingServiceKey(),
dir)
if err := dir.registry.LoadSubscribeInstances(url.SubURL, dir); err !=
nil {
return nil, err
@@ -446,6 +467,7 @@ func (dir *RegistryDirectory)
uncacheInvokerWithClusterID(clusterID string) []pr
// uncacheInvoker will return abandoned Invoker, if no Invoker to be
abandoned, return nil
func (dir *RegistryDirectory) uncacheInvoker(event *registry.ServiceEvent)
[]protocolbase.Invoker {
defer
metrics.Publish(metricsRegistry.NewDirectoryEvent(metricsRegistry.NumDisableTotal))
+ dir.clearClosingTombstone(event.Key())
if clusterID := event.Service.GetParam(constant.MeshClusterIDKey, "");
event.Service.Location == constant.MeshAnyAddrMatcher && clusterID != "" {
dir.uncacheInvokerWithClusterID(clusterID)
}
@@ -462,6 +484,86 @@ func (dir *RegistryDirectory) uncacheInvokerWithKey(key
string) protocolbase.Inv
return nil
}
+// RemoveClosingInstance removes a single service instance from the directory by instanceKey.
+// It is intended to be called by graceful shutdown logic before registry updates converge.
+// A tombstone is recorded for the key even when nothing is cached, so a later
+// registry add for that key is ignored until the tombstone expires or is cleared.
+// It reports whether a cached invoker was actually removed (and destroyed).
+func (dir *RegistryDirectory) RemoveClosingInstance(instanceKey string) bool {
+	if instanceKey == "" {
+		return false
+	}
+
+	detached := dir.detachClosingInstance(instanceKey)
+	if detached == nil {
+		return false
+	}
+	// Destroy outside registerLock; it may be slow.
+	detached.Destroy()
+	return true
+}
+
+// detachClosingInstance tombstones instanceKey and drops its cached invoker
+// while holding registerLock, refreshing the invoker list when something was
+// dropped. It returns the dropped invoker, or nil.
+func (dir *RegistryDirectory) detachClosingInstance(instanceKey string) protocolbase.Invoker {
+	dir.registerLock.Lock()
+	defer dir.registerLock.Unlock()
+
+	var current protocolbase.Invoker
+	if cached, ok := dir.cacheInvokersMap.Load(instanceKey); ok {
+		current = cached.(protocolbase.Invoker)
+	}
+	dir.markClosingTombstone(instanceKey, current, "closing-event")
+
+	detached := dir.uncacheInvokerWithKey(instanceKey)
+	if detached != nil {
+		dir.setNewInvokers()
+	}
+	return detached
+}
+
+// markClosingTombstone stores (or overwrites) the tombstone for instanceKey,
+// expiring closingTombstoneTTL from now. When invoker and its URL are known, the
+// tombstone also records their service key and address. Empty keys are ignored.
+func (dir *RegistryDirectory) markClosingTombstone(instanceKey string, invoker protocolbase.Invoker, source string) {
+	if instanceKey == "" {
+		return
+	}
+
+	expireAt := time.Now().Add(dir.closingTombstoneTTL)
+	var serviceKey, address string
+	if invoker != nil {
+		if u := invoker.GetURL(); u != nil {
+			serviceKey, address = u.ServiceKey(), u.Location
+		}
+	}
+
+	dir.closingTombstones.Store(instanceKey, closingTombstone{
+		InstanceKey: instanceKey,
+		ServiceKey:  serviceKey,
+		Address:     address,
+		Source:      source,
+		ExpireAt:    expireAt,
+	})
+}
+
+// hasActiveClosingTombstone reports whether instanceKey has a tombstone that has
+// not yet expired. An expired tombstone found here is evicted immediately.
+func (dir *RegistryDirectory) hasActiveClosingTombstone(instanceKey string) bool {
+	if instanceKey == "" {
+		return false
+	}
+	stored, found := dir.closingTombstones.Load(instanceKey)
+	if !found {
+		return false
+	}
+	if expireAt := stored.(closingTombstone).ExpireAt; !time.Now().After(expireAt) {
+		return true
+	}
+	// NOTE(review): Load and Delete are separate steps; if a fresh tombstone can
+	// be stored for this key concurrently, this Delete would drop it. Confirm
+	// callers are serialized (e.g. by registerLock).
+	dir.closingTombstones.Delete(instanceKey)
+	return false
+}
+
+// clearClosingTombstone forgets any tombstone for instanceKey; empty keys are ignored.
+func (dir *RegistryDirectory) clearClosingTombstone(instanceKey string) {
+	if instanceKey != "" {
+		dir.closingTombstones.Delete(instanceKey)
+	}
+}
+
+// cleanupExpiredClosingTombstones evicts every tombstone whose ExpireAt lies
+// strictly before the moment the sweep starts.
+func (dir *RegistryDirectory) cleanupExpiredClosingTombstones() {
+	sweepStart := time.Now()
+	dir.closingTombstones.Range(func(key, value any) bool {
+		if value.(closingTombstone).ExpireAt.Before(sweepStart) {
+			dir.closingTombstones.Delete(key)
+		}
+		return true // visit every entry
+	})
+}
+
// cacheInvoker will return abandoned Invoker,if no Invoker to be
abandoned,return nil
func (dir *RegistryDirectory) cacheInvoker(url *common.URL, event
*registry.ServiceEvent) protocolbase.Invoker {
dir.overrideUrl(dir.GetDirectoryUrl())
@@ -490,6 +592,11 @@ func (dir *RegistryDirectory) cacheInvoker(url
*common.URL, event *registry.Serv
func (dir *RegistryDirectory) doCacheInvoker(newUrl *common.URL, event
*registry.ServiceEvent) (protocolbase.Invoker, bool) {
key := event.Key()
+ dir.cleanupExpiredClosingTombstones()
+ if dir.hasActiveClosingTombstone(key) {
+ logger.Infof("[Registry Directory] skip rebuilding closing
instance due to tombstone, instance key: %s", key)
+ return nil, true
+ }
if cacheInvoker, ok := dir.cacheInvokersMap.Load(key); !ok {
logger.Debugf("service will be added in cache invokers:
invokers url is %s!", newUrl)
newInvoker :=
extension.GetProtocol(protocolwrapper.FILTER).Refer(newUrl)
@@ -549,6 +656,7 @@ func (dir *RegistryDirectory) IsAvailable() bool {
func (dir *RegistryDirectory) Destroy() {
// TODO:unregister & unsubscribe
dir.DoDestroy(func() {
+
graceful_shutdown.DefaultClosingDirectoryRegistry().Unregister(dir.closingServiceKey(),
dir)
if dir.RegisteredUrl != nil {
err := dir.registry.UnRegister(dir.RegisteredUrl)
if err != nil {
@@ -572,6 +680,17 @@ func (dir *RegistryDirectory) Destroy() {
metrics.Publish(metricsRegistry.NewDirectoryEvent(metricsRegistry.NumAllDec))
}
+// closingServiceKey is the key this directory uses with the closing-directory
+// registry: the directory URL's service key, or its SubURL's service key when
+// the former is empty. It returns "" when the directory has no URL.
+func (dir *RegistryDirectory) closingServiceKey() string {
+	u := dir.GetURL()
+	if u == nil {
+		return ""
+	}
+	if key := u.ServiceKey(); key != "" || u.SubURL == nil {
+		return key
+	}
+	return u.SubURL.ServiceKey()
+}
+
func (dir *RegistryDirectory) overrideUrl(targetUrl *common.URL) {
doOverrideUrl(dir.configurators, targetUrl)
doOverrideUrl(dir.consumerConfigurationListener.Configurators(),
targetUrl)
diff --git a/registry/directory/directory_test.go
b/registry/directory/directory_test.go
index 03d321ca7..4038a6e1e 100644
--- a/registry/directory/directory_test.go
+++ b/registry/directory/directory_test.go
@@ -151,6 +151,110 @@ func Test_RefreshUrl(t *testing.T) {
assert.Empty(t, registryDirectory.cacheInvokers)
}
+// TestRemoveClosingInstanceRemovesExactInstanceKey caches two providers and
+// checks that RemoveClosingInstance removes only the one whose key it is given.
+func TestRemoveClosingInstanceRemovesExactInstanceKey(t *testing.T) {
+	registryDirectory, mockRegistry := normalRegistryDir(true)
+
+	providerURL1, err := common.NewURL("dubbo://0.0.0.0:20000/org.apache.dubbo-go.mockService",
+		common.WithParamsValue(constant.ClusterKey, "mock1"),
+		common.WithParamsValue(constant.GroupKey, "group"),
+		common.WithParamsValue(constant.VersionKey, "1.0.0"))
+	require.NoError(t, err)
+	providerURL2, err := common.NewURL("dubbo://0.0.0.0:20001/org.apache.dubbo-go.mockService",
+		common.WithParamsValue(constant.ClusterKey, "mock1"),
+		common.WithParamsValue(constant.GroupKey, "group"),
+		common.WithParamsValue(constant.VersionKey, "1.0.0"))
+	require.NoError(t, err)
+
+	event1 := &registry.ServiceEvent{Action: remoting.EventTypeAdd, Service: providerURL1}
+	event2 := &registry.ServiceEvent{Action: remoting.EventTypeAdd, Service: providerURL2}
+	key1 := registryDirectory.invokerCacheKey(event1)
+	key2 := registryDirectory.invokerCacheKey(event2)
+
+	mockRegistry.MockEvent(event1)
+	mockRegistry.MockEvent(event2)
+	time.Sleep(time.Second)
+
+	assert.Len(t, registryDirectory.cacheInvokers, 2)
+	assert.NotEqual(t, key1, key2)
+
+	removed := registryDirectory.RemoveClosingInstance(key1)
+	require.True(t, removed)
+
+	assert.Len(t, registryDirectory.cacheInvokers, 1)
+	assert.Len(t, registryDirectory.List(&invocation.RPCInvocation{}), 1)
+
+	_, stillExists := registryDirectory.cacheInvokersMap.Load(key1)
+	assert.False(t, stillExists)
+
+	remaining, ok := registryDirectory.cacheInvokersMap.Load(key2)
+	require.True(t, ok)
+	assert.Equal(t, key2, remaining.(interface{ GetURL() *common.URL }).GetURL().GetCacheInvokerMapKey())
+}
+
+// TestRemoveClosingInstanceReturnsFalseForUnknownKey checks that removing a key
+// that was never cached reports false and leaves the directory empty.
+func TestRemoveClosingInstanceReturnsFalseForUnknownKey(t *testing.T) {
+	dir, _ := normalRegistryDir(true)
+
+	assert.False(t, dir.RemoveClosingInstance("missing-instance-key"))
+	assert.Empty(t, dir.cacheInvokers)
+}
+
+// TestClosingTombstonePreventsRebuildUntilDeleteEvent checks the tombstone
+// lifecycle: a closing removal blocks re-adding the instance, a delete event
+// clears the tombstone, and the next add event rebuilds the invoker.
+func TestClosingTombstonePreventsRebuildUntilDeleteEvent(t *testing.T) {
+	registryDirectory, mockRegistry := normalRegistryDir(true)
+
+	providerURL, err := common.NewURL("dubbo://0.0.0.0:20000/org.apache.dubbo-go.mockService",
+		common.WithParamsValue(constant.ClusterKey, "mock1"),
+		common.WithParamsValue(constant.GroupKey, "group"),
+		common.WithParamsValue(constant.VersionKey, "1.0.0"))
+	require.NoError(t, err)
+	event := &registry.ServiceEvent{Action: remoting.EventTypeAdd, Service: providerURL}
+	key := registryDirectory.invokerCacheKey(event)
+
+	mockRegistry.MockEvent(event)
+	time.Sleep(time.Second)
+	assert.Len(t, registryDirectory.cacheInvokers, 1)
+
+	// A closing removal drops the invoker and leaves a tombstone.
+	removed := registryDirectory.RemoveClosingInstance(key)
+	require.True(t, removed)
+	assert.Empty(t, registryDirectory.cacheInvokers)
+	assert.True(t, registryDirectory.hasActiveClosingTombstone(key))
+
+	// While the tombstone is active, a repeated add must not rebuild the invoker.
+	mockRegistry.MockEvent(&registry.ServiceEvent{Action: remoting.EventTypeAdd, Service: providerURL})
+	time.Sleep(time.Second)
+	assert.Empty(t, registryDirectory.cacheInvokers)
+
+	// A delete event clears the tombstone...
+	mockRegistry.MockEvent(&registry.ServiceEvent{Action: remoting.EventTypeDel, Service: providerURL})
+	time.Sleep(time.Second)
+	assert.False(t, registryDirectory.hasActiveClosingTombstone(key))
+
+	// ...so the next add rebuilds it.
+	mockRegistry.MockEvent(&registry.ServiceEvent{Action: remoting.EventTypeAdd, Service: providerURL})
+	time.Sleep(time.Second)
+	assert.Len(t, registryDirectory.cacheInvokers, 1)
+}
+
+// TestExpiredClosingTombstoneAllowsRebuild shrinks the tombstone TTL and checks
+// that once the tombstone has expired, an add event rebuilds the invoker.
+func TestExpiredClosingTombstoneAllowsRebuild(t *testing.T) {
+	registryDirectory, mockRegistry := normalRegistryDir(true)
+	registryDirectory.closingTombstoneTTL = 20 * time.Millisecond
+
+	providerURL, err := common.NewURL("dubbo://0.0.0.0:20000/org.apache.dubbo-go.mockService",
+		common.WithParamsValue(constant.ClusterKey, "mock1"),
+		common.WithParamsValue(constant.GroupKey, "group"),
+		common.WithParamsValue(constant.VersionKey, "1.0.0"))
+	require.NoError(t, err)
+	event := &registry.ServiceEvent{Action: remoting.EventTypeAdd, Service: providerURL}
+	key := registryDirectory.invokerCacheKey(event)
+
+	mockRegistry.MockEvent(event)
+	time.Sleep(time.Second)
+	require.Len(t, registryDirectory.cacheInvokers, 1)
+
+	require.True(t, registryDirectory.RemoveClosingInstance(key))
+	assert.Empty(t, registryDirectory.cacheInvokers)
+
+	// Wait past the 20ms TTL so the tombstone is treated as expired.
+	time.Sleep(40 * time.Millisecond)
+	assert.False(t, registryDirectory.hasActiveClosingTombstone(key))
+
+	mockRegistry.MockEvent(&registry.ServiceEvent{Action: remoting.EventTypeAdd, Service: providerURL})
+	time.Sleep(time.Second)
+	assert.Len(t, registryDirectory.cacheInvokers, 1)
+}
+
func normalRegistryDir(noMockEvent ...bool) (*RegistryDirectory,
*registry.MockRegistry) {
extension.SetProtocol(protocolwrapper.FILTER,
protocolwrapper.NewMockProtocolFilter)
diff --git a/registry/protocol/protocol.go b/registry/protocol/protocol.go
index 3944f3fb4..eee3b4534 100644
--- a/registry/protocol/protocol.go
+++ b/registry/protocol/protocol.go
@@ -495,6 +495,19 @@ func (proto *registryProtocol) Destroy() {
})
}
+// UnregisterRegistries only unregisters exported services from registries during graceful shutdown.
+// Protocol servers keep running until the later destroy phase.
+// Failures are logged and do not stop the remaining exporters from being unregistered.
+func (proto *registryProtocol) UnregisterRegistries() {
+	proto.bounds.Range(func(_, value any) bool {
+		exporter := value.(*exporterChangeableWrapper)
+		reg := proto.getRegistry(getRegistryUrl(exporter.originInvoker))
+		if reg == nil {
+			// NOTE(review): assumes getRegistry yields nil when the registry cannot be
+			// obtained; skip it instead of panicking in the middle of shutdown.
+			logger.Warnf("Unregister provider url skipped, no registry available for %s", exporter.registerUrl.String())
+			return true
+		}
+		if err := reg.UnRegister(exporter.registerUrl); err != nil {
+			// %v, not %w: Warnf formats like fmt.Sprintf, where %w is unsupported.
+			logger.Warnf("Unregister provider url failed, %s, error: %v", exporter.registerUrl.String(), err)
+		}
+		return true
+	})
+}
+
func getRegistryUrl(invoker base.Invoker) *common.URL {
// here add * for return a new url
url := invoker.GetURL()