This is an automated email from the ASF dual-hosted git repository. xuetaoli pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/dubbo-go.git
commit fb8eff2f090851261e2d1acfc55414ede91ca9c0 Author: 承潜 <[email protected]> AuthorDate: Wed Mar 25 10:49:35 2026 +0800 fix(triple): case-insensitive method routing without metadata pollution (#3277) * fix(triple): case-insensitive method routing without metadata pollution (#3162) Introduce methodRouteMux in the Triple transport layer to support method name routing that ignores the case of the method name's first letter (Go's GetUser <-> Java's getUser); the service path and the remaining letters are still matched case-sensitively. Remove the SwapCase method duplication from enhanceServiceInfo, which was polluting ServiceInfo.Methods with duplicate entries visible to the registry and to gRPC reflection. * style(triple): fix test import formatting * fix(ci): address triple sonar warnings * style(triple): modernize attachment map type * test(triple): split method handler tests and fix lint assertion * fix(triple): stabilize method fallback and host-unavailable tests * test(triple): extract repeated streaming literals * test(triple): make lowerFirstRune empty-input path coverable --- protocol/triple/server.go | 221 +++++---- protocol/triple/server_test.go | 527 +++++++++++++++++++++ .../triple/triple_protocol/method_route_mux.go | 118 +++++ .../triple_protocol/method_route_mux_test.go | 208 ++++++++ protocol/triple/triple_protocol/server.go | 4 +- protocol/triple/triple_protocol/triple_ext_test.go | 12 +- server/server.go | 34 +- server/server_test.go | 55 ++- server/triple_case_route_integration_test.go | 255 ++++++++++ 9 files changed, 1304 insertions(+), 130 deletions(-) diff --git a/protocol/triple/server.go b/protocol/triple/server.go index 244377b63..7c46ac7fc 100644 --- a/protocol/triple/server.go +++ b/protocol/triple/server.go @@ -320,108 +320,133 @@ func (s *Server) compatRegisterHandler(interfaceName string, svc dubbo3.Dubbo3Gr } } -// handleServiceWithInfo injects invoker and create handler based on ServiceInfo +// handleServiceWithInfo injects invoker and creates handlers based on ServiceInfo. +// Each method is registered once under its canonical procedure path.
Triple's +// transport-layer route mux performs case-insensitive fallback matching. func (s *Server) handleServiceWithInfo(interfaceName string, invoker base.Invoker, info *common.ServiceInfo, opts ...tri.HandlerOption) { for _, method := range info.Methods { m := method procedure := joinProcedure(interfaceName, method.Name) - switch m.Type { - case constant.CallUnary: - _ = s.triServer.RegisterUnaryHandler( - procedure, - m.ReqInitFunc, - func(ctx context.Context, req *tri.Request) (*tri.Response, error) { - var args []any - if argsRaw, ok := req.Msg.([]any); ok { - // non-idl mode, req.Msg consists of many arguments - for _, argRaw := range argsRaw { - // refer to createServiceInfoWithReflection, in ReqInitFunc, argRaw is a pointer to real arg. - // so we have to invoke Elem to get the real arg. - args = append(args, reflect.ValueOf(argRaw).Elem().Interface()) - } - } else { - // triple idl mode and old triple idl mode - args = append(args, req.Msg) - } - attachments := generateAttachments(req.Header()) - // inject attachments - ctx = context.WithValue(ctx, constant.AttachmentKey, attachments) - invo := invocation.NewRPCInvocation(m.Name, args, attachments) - res := invoker.Invoke(ctx, invo) - // todo(DMwangnima): modify InfoInvoker to get a unified processing logic - // please refer to server/InfoInvoker.Invoke() - var triResp *tri.Response - if existingResp, ok := res.Result().(*tri.Response); ok { - triResp = existingResp - } else { - // please refer to proxy/proxy_factory/ProxyInvoker.Invoke - triResp = tri.NewResponse([]any{res.Result()}) - } - for k, v := range res.Attachments() { - switch val := v.(type) { - case string: - tri.AppendToOutgoingContext(ctx, k, val) - case []string: - for _, v := range val { - tri.AppendToOutgoingContext(ctx, k, v) - } - } - } - return triResp, res.Error() - }, - opts..., - ) - case constant.CallClientStream: - _ = s.triServer.RegisterClientStreamHandler( - procedure, - func(ctx context.Context, stream *tri.ClientStream) 
(*tri.Response, error) { - var args []any - args = append(args, m.StreamInitFunc(stream)) - attachments := generateAttachments(stream.RequestHeader()) - // inject attachments - ctx = context.WithValue(ctx, constant.AttachmentKey, attachments) - invo := invocation.NewRPCInvocation(m.Name, args, attachments) - res := invoker.Invoke(ctx, invo) - if triResp, ok := res.Result().(*tri.Response); ok { - return triResp, res.Error() - } - // please refer to proxy/proxy_factory/ProxyInvoker.Invoke - triResp := tri.NewResponse([]any{res.Result()}) - return triResp, res.Error() - }, - opts..., - ) - case constant.CallServerStream: - _ = s.triServer.RegisterServerStreamHandler( - procedure, - m.ReqInitFunc, - func(ctx context.Context, req *tri.Request, stream *tri.ServerStream) error { - var args []any - args = append(args, req.Msg, m.StreamInitFunc(stream)) - attachments := generateAttachments(req.Header()) - // inject attachments - ctx = context.WithValue(ctx, constant.AttachmentKey, attachments) - invo := invocation.NewRPCInvocation(m.Name, args, attachments) - res := invoker.Invoke(ctx, invo) - return res.Error() - }, - opts..., - ) - case constant.CallBidiStream: - _ = s.triServer.RegisterBidiStreamHandler( - procedure, - func(ctx context.Context, stream *tri.BidiStream) error { - var args []any - args = append(args, m.StreamInitFunc(stream)) - attachments := generateAttachments(stream.RequestHeader()) - // inject attachments - ctx = context.WithValue(ctx, constant.AttachmentKey, attachments) - invo := invocation.NewRPCInvocation(m.Name, args, attachments) - res := invoker.Invoke(ctx, invo) - return res.Error() - }, - opts..., - ) + s.registerMethodHandler(procedure, m, invoker, opts...) + } +} + +// registerMethodHandler registers a single method handler for the given procedure path. 
+func (s *Server) registerMethodHandler(procedure string, m common.MethodInfo, invoker base.Invoker, opts ...tri.HandlerOption) { + switch m.Type { + case constant.CallUnary: + s.registerUnaryMethodHandler(procedure, m, invoker, opts...) + case constant.CallClientStream: + s.registerClientStreamMethodHandler(procedure, m, invoker, opts...) + case constant.CallServerStream: + s.registerServerStreamMethodHandler(procedure, m, invoker, opts...) + case constant.CallBidiStream: + s.registerBidiStreamMethodHandler(procedure, m, invoker, opts...) + } +} + +func (s *Server) registerUnaryMethodHandler(procedure string, m common.MethodInfo, invoker base.Invoker, opts ...tri.HandlerOption) { + _ = s.triServer.RegisterUnaryHandler( + procedure, + m.ReqInitFunc, + func(ctx context.Context, req *tri.Request) (*tri.Response, error) { + args := extractUnaryInvocationArgs(req.Msg) + attachments := generateAttachments(req.Header()) + // inject attachments + ctx = context.WithValue(ctx, constant.AttachmentKey, attachments) + invo := invocation.NewRPCInvocation(m.Name, args, attachments) + res := invoker.Invoke(ctx, invo) + // todo(DMwangnima): modify InfoInvoker to get a unified processing logic + // please refer to server/InfoInvoker.Invoke() + triResp := wrapTripleResponse(res.Result()) + appendTripleOutgoingAttachments(ctx, res.Attachments()) + return triResp, res.Error() + }, + opts..., + ) +} + +func (s *Server) registerClientStreamMethodHandler(procedure string, m common.MethodInfo, invoker base.Invoker, opts ...tri.HandlerOption) { + _ = s.triServer.RegisterClientStreamHandler( + procedure, + func(ctx context.Context, stream *tri.ClientStream) (*tri.Response, error) { + args := []any{m.StreamInitFunc(stream)} + attachments := generateAttachments(stream.RequestHeader()) + // inject attachments + ctx = context.WithValue(ctx, constant.AttachmentKey, attachments) + invo := invocation.NewRPCInvocation(m.Name, args, attachments) + res := invoker.Invoke(ctx, invo) + return 
wrapTripleResponse(res.Result()), res.Error() + }, + opts..., + ) +} + +func (s *Server) registerServerStreamMethodHandler(procedure string, m common.MethodInfo, invoker base.Invoker, opts ...tri.HandlerOption) { + _ = s.triServer.RegisterServerStreamHandler( + procedure, + m.ReqInitFunc, + func(ctx context.Context, req *tri.Request, stream *tri.ServerStream) error { + args := []any{req.Msg, m.StreamInitFunc(stream)} + attachments := generateAttachments(req.Header()) + // inject attachments + ctx = context.WithValue(ctx, constant.AttachmentKey, attachments) + invo := invocation.NewRPCInvocation(m.Name, args, attachments) + res := invoker.Invoke(ctx, invo) + return res.Error() + }, + opts..., + ) +} + +func (s *Server) registerBidiStreamMethodHandler(procedure string, m common.MethodInfo, invoker base.Invoker, opts ...tri.HandlerOption) { + _ = s.triServer.RegisterBidiStreamHandler( + procedure, + func(ctx context.Context, stream *tri.BidiStream) error { + args := []any{m.StreamInitFunc(stream)} + attachments := generateAttachments(stream.RequestHeader()) + // inject attachments + ctx = context.WithValue(ctx, constant.AttachmentKey, attachments) + invo := invocation.NewRPCInvocation(m.Name, args, attachments) + res := invoker.Invoke(ctx, invo) + return res.Error() + }, + opts..., + ) +} + +func extractUnaryInvocationArgs(msg any) []any { + if argsRaw, ok := msg.([]any); ok { + args := make([]any, 0, len(argsRaw)) + // non-idl mode, req.Msg consists of many arguments + for _, argRaw := range argsRaw { + // refer to createServiceInfoWithReflection, in ReqInitFunc, argRaw is a pointer to real arg. + // so we have to invoke Elem to get the real arg. 
+ args = append(args, reflect.ValueOf(argRaw).Elem().Interface()) + } + return args + } + // triple idl mode and old triple idl mode + return []any{msg} +} + +func wrapTripleResponse(result any) *tri.Response { + if existingResp, ok := result.(*tri.Response); ok { + return existingResp + } + // please refer to proxy/proxy_factory/ProxyInvoker.Invoke + return tri.NewResponse([]any{result}) +} + +func appendTripleOutgoingAttachments(ctx context.Context, attachments map[string]any) { + for k, v := range attachments { + switch val := v.(type) { + case string: + tri.AppendToOutgoingContext(ctx, k, val) + case []string: + for _, item := range val { + tri.AppendToOutgoingContext(ctx, k, item) + } } } } diff --git a/protocol/triple/server_test.go b/protocol/triple/server_test.go index f6d801954..99da82179 100644 --- a/protocol/triple/server_test.go +++ b/protocol/triple/server_test.go @@ -21,6 +21,7 @@ import ( "context" "fmt" "net/http" + "net/http/httptest" "reflect" "sync" "testing" @@ -29,6 +30,7 @@ import ( import ( "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "google.golang.org/grpc" ) @@ -37,6 +39,9 @@ import ( "dubbo.apache.org/dubbo-go/v3/common" "dubbo.apache.org/dubbo-go/v3/common/constant" "dubbo.apache.org/dubbo-go/v3/global" + "dubbo.apache.org/dubbo-go/v3/protocol/base" + "dubbo.apache.org/dubbo-go/v3/protocol/result" + tri "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol" ) func Test_generateAttachments(t *testing.T) { @@ -542,3 +547,525 @@ func Test_isReflectValueNil_UnsafePointer(t *testing.T) { assert.False(t, isReflectValueNil(v)) }) } + +// TestHandleServiceWithInfoSaveServiceInfoOnlyOriginalMethods verifies that +// saveServiceInfo only records the original method names so registry and +// gRPC-reflection metadata stay clean. 
+func TestHandleServiceWithInfoSaveServiceInfoOnlyOriginalMethods(t *testing.T) { + server := NewServer(nil) + info := &common.ServiceInfo{ + Methods: []common.MethodInfo{ + {Name: "GetUser", Type: constant.CallUnary}, + {Name: "ListUsers", Type: constant.CallServerStream}, + }, + } + server.saveServiceInfo("com.example.UserService", info) + + svcInfo := server.GetServiceInfo() + svc, ok := svcInfo["com.example.UserService"] + assert.True(t, ok) + // Only the two original methods; no "getUser" / "listUsers" aliases. + assert.Len(t, svc.Methods, 2) + names := make([]string, 0, len(svc.Methods)) + for _, m := range svc.Methods { + names = append(names, m.Name) + } + assert.Contains(t, names, "GetUser") + assert.Contains(t, names, "ListUsers") + assert.NotContains(t, names, "getUser") + assert.NotContains(t, names, "listUsers") +} + +type tripleServerTestInvoker struct { + invokeFn func(context.Context, base.Invocation) result.Result +} + +func (m *tripleServerTestInvoker) GetURL() *common.URL { + return common.NewURLWithOptions() +} + +func (m *tripleServerTestInvoker) IsAvailable() bool { + return true +} + +func (m *tripleServerTestInvoker) Destroy() { + // No-op: this test double does not own lifecycle resources. 
+} + +func (m *tripleServerTestInvoker) Invoke(ctx context.Context, invocation base.Invocation) result.Result { + if m.invokeFn == nil { + return &result.RPCResult{} + } + return m.invokeFn(ctx, invocation) +} + +type tripleServerTestConn struct { + reqHeader http.Header + respHeader http.Header + respTrailer http.Header + receiveFn func(any) + sent []any +} + +const ( + tripleServerTestFromClientMessage = "from-client" + tripleServerTestServerToken = "server-token" + tripleServerTestBidiToken = "bidi-token" + tripleServerTestBidiValue = "bidi-v" +) + +func newTripleServerTestConn() *tripleServerTestConn { + return &tripleServerTestConn{ + reqHeader: make(http.Header), + respHeader: make(http.Header), + respTrailer: make(http.Header), + } +} + +func (c *tripleServerTestConn) Spec() tri.Spec { + return tri.Spec{} +} + +func (c *tripleServerTestConn) Peer() tri.Peer { + return tri.Peer{} +} + +func (c *tripleServerTestConn) Receive(msg any) error { + if c.receiveFn != nil { + c.receiveFn(msg) + } + return nil +} + +func (c *tripleServerTestConn) RequestHeader() http.Header { + return c.reqHeader +} + +func (c *tripleServerTestConn) ExportableHeader() http.Header { + return c.reqHeader +} + +func (c *tripleServerTestConn) Send(msg any) error { + c.sent = append(c.sent, msg) + return nil +} + +func (c *tripleServerTestConn) ResponseHeader() http.Header { + return c.respHeader +} + +func (c *tripleServerTestConn) ResponseTrailer() http.Header { + return c.respTrailer +} + +func TestServerRegisterUnaryMethodHandler(t *testing.T) { + server := newServerForMethodHandlerTest() + invoker := &tripleServerTestInvoker{ + invokeFn: func(ctx context.Context, invocation base.Invocation) result.Result { + assert.Equal(t, "UnaryMethod", invocation.MethodName()) + assert.Equal(t, []any{"alice", 7}, invocation.Arguments()) + assert.Equal(t, []string{"v1", "v2"}, invocation.Attachments()["x-test"]) + + ctxAttachments, ok := ctx.Value(constant.AttachmentKey).(map[string]any) + 
require.True(t, ok) + assert.Equal(t, []string{"v1", "v2"}, ctxAttachments["x-test"]) + + res := &result.RPCResult{} + res.SetResult("unary-ok") + res.SetAttachments(map[string]any{ + "resp-one": "val", + "resp-multi": []string{"a", "b"}, + "resp-omit": 123, + }) + return res + }, + } + method := common.MethodInfo{ + Name: "UnaryMethod", + Type: constant.CallUnary, + ReqInitFunc: func() any { + var name string + var age int + return []any{&name, &age} + }, + } + procedure := "/svc/UnaryMethod" + server.registerMethodHandler(procedure, method, invoker) + + conn := newTripleServerTestConn() + conn.reqHeader.Add("X-Test", "v1") + conn.reqHeader.Add("X-Test", "v2") + conn.receiveFn = func(msg any) { + args, ok := msg.([]any) + require.True(t, ok) + *(args[0].(*string)) = "alice" + *(args[1].(*int)) = 7 + } + require.NoError(t, invokeRegisteredHandlerImplementation(server.triServer, procedure, conn)) + require.Len(t, conn.sent, 1) + assert.Equal(t, []any{"unary-ok"}, conn.sent[0]) + assert.Equal(t, []string{"val"}, conn.respTrailer.Values("resp-one")) + assert.Equal(t, []string{"a", "b"}, conn.respTrailer.Values("resp-multi")) + assert.Empty(t, conn.respTrailer.Values("resp-omit")) +} + +func TestServerRegisterClientStreamMethodHandler(t *testing.T) { + server := newServerForMethodHandlerTest() + invoker := &tripleServerTestInvoker{ + invokeFn: func(ctx context.Context, invocation base.Invocation) result.Result { + assert.Equal(t, "ClientStreamMethod", invocation.MethodName()) + assert.Equal(t, []any{"client-token"}, invocation.Arguments()) + assert.Equal(t, []string{"trace-123"}, invocation.Attachments()["trace-id"]) + _, ok := ctx.Value(constant.AttachmentKey).(map[string]any) + assert.True(t, ok) + + res := &result.RPCResult{} + res.SetResult("client-ok") + return res + }, + } + method := common.MethodInfo{ + Name: "ClientStreamMethod", + Type: constant.CallClientStream, + StreamInitFunc: func(baseStream any) any { + _, ok := baseStream.(*tri.ClientStream) + 
require.True(t, ok) + return "client-token" + }, + } + procedure := "/svc/ClientStreamMethod" + server.registerMethodHandler(procedure, method, invoker) + + conn := newTripleServerTestConn() + conn.reqHeader.Set("Trace-Id", "trace-123") + require.NoError(t, invokeRegisteredHandlerImplementation(server.triServer, procedure, conn)) + require.Len(t, conn.sent, 1) + assert.Equal(t, []any{"client-ok"}, conn.sent[0]) +} + +func TestServerRegisterServerStreamMethodHandler(t *testing.T) { + type serverStreamReq struct { + Message string + } + + server := newServerForMethodHandlerTest() + invoker := &tripleServerTestInvoker{ + invokeFn: func(ctx context.Context, invocation base.Invocation) result.Result { + assert.Equal(t, "ServerStreamMethod", invocation.MethodName()) + require.Len(t, invocation.Arguments(), 2) + req, ok := invocation.Arguments()[0].(*serverStreamReq) + require.True(t, ok) + assert.Equal(t, tripleServerTestFromClientMessage, req.Message) + assert.Equal(t, tripleServerTestServerToken, invocation.Arguments()[1]) + assert.Equal(t, []string{"v"}, invocation.Attachments()["x-stream"]) + _, ok = ctx.Value(constant.AttachmentKey).(map[string]any) + assert.True(t, ok) + return &result.RPCResult{} + }, + } + method := common.MethodInfo{ + Name: "ServerStreamMethod", + Type: constant.CallServerStream, + ReqInitFunc: func() any { + return &serverStreamReq{} + }, + StreamInitFunc: func(baseStream any) any { + _, ok := baseStream.(*tri.ServerStream) + require.True(t, ok) + return tripleServerTestServerToken + }, + } + procedure := "/svc/ServerStreamMethod" + server.registerMethodHandler(procedure, method, invoker) + + conn := newTripleServerTestConn() + conn.reqHeader.Set("X-Stream", "v") + conn.receiveFn = func(msg any) { + req, ok := msg.(*serverStreamReq) + require.True(t, ok) + req.Message = tripleServerTestFromClientMessage + } + require.NoError(t, invokeRegisteredHandlerImplementation(server.triServer, procedure, conn)) + assert.Empty(t, conn.sent) +} + +func 
TestServerRegisterBidiStreamMethodHandler(t *testing.T) { + server := newServerForMethodHandlerTest() + invoker := &tripleServerTestInvoker{ + invokeFn: func(ctx context.Context, invocation base.Invocation) result.Result { + assert.Equal(t, "BidiStreamMethod", invocation.MethodName()) + assert.Equal(t, []any{tripleServerTestBidiToken}, invocation.Arguments()) + assert.Equal(t, []string{tripleServerTestBidiValue}, invocation.Attachments()["x-bidi"]) + _, ok := ctx.Value(constant.AttachmentKey).(map[string]any) + assert.True(t, ok) + return &result.RPCResult{} + }, + } + method := common.MethodInfo{ + Name: "BidiStreamMethod", + Type: constant.CallBidiStream, + StreamInitFunc: func(baseStream any) any { + _, ok := baseStream.(*tri.BidiStream) + require.True(t, ok) + return tripleServerTestBidiToken + }, + } + procedure := "/svc/BidiStreamMethod" + server.registerMethodHandler(procedure, method, invoker) + + conn := newTripleServerTestConn() + conn.reqHeader.Set("X-Bidi", tripleServerTestBidiValue) + require.NoError(t, invokeRegisteredHandlerImplementation(server.triServer, procedure, conn)) + assert.Empty(t, conn.sent) +} + +func TestServerRegisterMethodHandlerUnknownType(t *testing.T) { + server := newServerForMethodHandlerTest() + procedure := "/svc/Unknown" + server.registerMethodHandler(procedure, common.MethodInfo{ + Name: "UnknownMethod", + Type: "unknown", + }, &tripleServerTestInvoker{}) + + _, ok := getServerHandler(server.triServer, procedure) + assert.False(t, ok) +} + +func TestServerHandleServiceWithInfoFallbackHitsStreamingHandlers(t *testing.T) { + type streamReq struct { + Message string + } + + server := newServerForMethodHandlerTest() + calledMethods := make([]string, 0, 2) + invoker := &tripleServerTestInvoker{ + invokeFn: func(ctx context.Context, invocation base.Invocation) result.Result { + calledMethods = append(calledMethods, invocation.MethodName()) + switch invocation.MethodName() { + case "CountUp": + require.Len(t, invocation.Arguments(), 
2) + req, ok := invocation.Arguments()[0].(*streamReq) + require.True(t, ok) + assert.Equal(t, tripleServerTestFromClientMessage, req.Message) + assert.Equal(t, tripleServerTestServerToken, invocation.Arguments()[1]) + assert.Equal(t, []string{"stream-v"}, invocation.Attachments()["x-stream"]) + case "CumSum": + require.Len(t, invocation.Arguments(), 1) + assert.Equal(t, tripleServerTestBidiToken, invocation.Arguments()[0]) + assert.Equal(t, []string{tripleServerTestBidiValue}, invocation.Attachments()["x-bidi"]) + default: + t.Fatalf("unexpected method: %s", invocation.MethodName()) + } + + ctxAttachments, ok := ctx.Value(constant.AttachmentKey).(map[string]any) + require.True(t, ok) + assert.NotEmpty(t, ctxAttachments) + return &result.RPCResult{} + }, + } + info := &common.ServiceInfo{ + Methods: []common.MethodInfo{ + { + Name: "CountUp", + Type: constant.CallServerStream, + ReqInitFunc: func() any { + return &streamReq{} + }, + StreamInitFunc: func(baseStream any) any { + _, ok := baseStream.(*tri.ServerStream) + require.True(t, ok) + return tripleServerTestServerToken + }, + }, + { + Name: "CumSum", + Type: constant.CallBidiStream, + StreamInitFunc: func(baseStream any) any { + _, ok := baseStream.(*tri.BidiStream) + require.True(t, ok) + return tripleServerTestBidiToken + }, + }, + }, + } + server.handleServiceWithInfo("svc.Fallback", invoker, info) + + serverStreamConn := newTripleServerTestConn() + serverStreamConn.reqHeader.Set("X-Stream", "stream-v") + serverStreamConn.receiveFn = func(msg any) { + req, ok := msg.(*streamReq) + require.True(t, ok) + req.Message = tripleServerTestFromClientMessage + } + pattern, err := invokeRegisteredHandlerImplementationByRequestPath( + server.triServer, + "/svc.Fallback/countUp", + serverStreamConn, + ) + require.NoError(t, err) + assert.Equal(t, "/svc.Fallback/CountUp", pattern) + + bidiConn := newTripleServerTestConn() + bidiConn.reqHeader.Set("X-Bidi", tripleServerTestBidiValue) + pattern, err = 
invokeRegisteredHandlerImplementationByRequestPath( + server.triServer, + "/svc.Fallback/cumSum", + bidiConn, + ) + require.NoError(t, err) + assert.Equal(t, "/svc.Fallback/CumSum", pattern) + + assert.Equal(t, []string{"CountUp", "CumSum"}, calledMethods) +} + +func newServerForMethodHandlerTest() *Server { + return &Server{triServer: tri.NewServer("127.0.0.1:0", nil)} +} + +func TestExtractUnaryInvocationArgs(t *testing.T) { + t.Run("from non-idl argument slice", func(t *testing.T) { + name := "alice" + age := 18 + args := extractUnaryInvocationArgs([]any{&name, &age}) + assert.Equal(t, []any{"alice", 18}, args) + }) + + t.Run("from single message in idl mode", func(t *testing.T) { + msg := struct{ Name string }{Name: "idl"} + args := extractUnaryInvocationArgs(msg) + assert.Equal(t, []any{msg}, args) + }) +} + +func TestWrapTripleResponse(t *testing.T) { + resp := tri.NewResponse("already-wrapped") + assert.Same(t, resp, wrapTripleResponse(resp)) + + wrapped := wrapTripleResponse("plain-result") + assert.Equal(t, []any{"plain-result"}, wrapped.Msg) +} + +func TestAppendTripleOutgoingAttachments(t *testing.T) { + ctx := tri.NewOutgoingContext(context.Background(), make(http.Header)) + appendTripleOutgoingAttachments(ctx, map[string]any{ + "one": "1", + "multi": []string{"a", "b"}, + "omit": 100, + }) + + outgoing := tri.ExtractFromOutgoingContext(ctx) + require.NotNil(t, outgoing) + assert.Equal(t, []string{"1"}, outgoing.Values("one")) + assert.Equal(t, []string{"a", "b"}, outgoing.Values("multi")) + assert.Empty(t, outgoing.Values("omit")) +} + +const tripleServerDefaultImplementationKey = "/" + +// These helpers execute the default registered implementation directly so the +// tests can verify registerMethodHandler's invocation wiring without depending +// on protocol-specific HTTP framing details. 
+func invokeRegisteredHandlerImplementation(triServer *tri.Server, procedure string, conn tri.StreamingHandlerConn) error { + handler, ok := getServerHandler(triServer, procedure) + if !ok { + return fmt.Errorf("handler for procedure %s not found", procedure) + } + implementation, ok := getDefaultHandlerImplementation(handler) + if !ok { + return fmt.Errorf("default implementation for procedure %s not found", procedure) + } + return implementation(context.Background(), conn) +} + +func invokeRegisteredHandlerImplementationByRequestPath( + triServer *tri.Server, + requestPath string, + conn tri.StreamingHandlerConn, +) (string, error) { + handler, pattern, ok := getServerHandlerByRequestPath(triServer, requestPath) + if !ok { + return "", fmt.Errorf("handler for request path %s not found", requestPath) + } + implementation, ok := getDefaultHandlerImplementation(handler) + if !ok { + return "", fmt.Errorf("default implementation for request path %s not found", requestPath) + } + return pattern, implementation(context.Background(), conn) +} + +func getServerHandler(triServer *tri.Server, procedure string) (*tri.Handler, bool) { + if triServer == nil { + return nil, false + } + handlersField := reflect.ValueOf(triServer).Elem().FieldByName("handlers") + handlersValue, ok := extractUnexportedValue(handlersField) + if !ok { + return nil, false + } + handlers, ok := handlersValue.Interface().(map[string]*tri.Handler) + if !ok { + return nil, false + } + handler, ok := handlers[procedure] + return handler, ok +} + +func getServerHandlerByRequestPath(triServer *tri.Server, requestPath string) (*tri.Handler, string, bool) { + if triServer == nil { + return nil, "", false + } + muxField := reflect.ValueOf(triServer).Elem().FieldByName("mux") + muxValue, ok := extractUnexportedValue(muxField) + if !ok || !muxValue.IsValid() || muxValue.IsNil() { + return nil, "", false + } + + handlerMethod := muxValue.MethodByName("Handler") + if !handlerMethod.IsValid() { + return nil, "", 
false + } + req := httptest.NewRequest(http.MethodPost, requestPath, nil) + results := handlerMethod.Call([]reflect.Value{reflect.ValueOf(req)}) + if len(results) != 2 { + return nil, "", false + } + + handler, ok := results[0].Interface().(http.Handler) + if !ok || handler == nil { + return nil, "", false + } + pattern, ok := results[1].Interface().(string) + if !ok || pattern == "" { + return nil, "", false + } + triHandler, ok := handler.(*tri.Handler) + if !ok { + return nil, "", false + } + + return triHandler, pattern, true +} + +func getDefaultHandlerImplementation(handler *tri.Handler) (tri.StreamingHandlerFunc, bool) { + if handler == nil { + return nil, false + } + implField := reflect.ValueOf(handler).Elem().FieldByName("implementations") + implValue, ok := extractUnexportedValue(implField) + if !ok { + return nil, false + } + implementations, ok := implValue.Interface().(map[string]tri.StreamingHandlerFunc) + if !ok { + return nil, false + } + implementation, ok := implementations[tripleServerDefaultImplementationKey] + return implementation, ok +} + +func extractUnexportedValue(field reflect.Value) (reflect.Value, bool) { + if !field.IsValid() || !field.CanAddr() { + return reflect.Value{}, false + } + return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem(), true +} diff --git a/protocol/triple/triple_protocol/method_route_mux.go b/protocol/triple/triple_protocol/method_route_mux.go new file mode 100644 index 000000000..b4be12669 --- /dev/null +++ b/protocol/triple/triple_protocol/method_route_mux.go @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package triple_protocol + +import ( + "net/http" + "sync" + "unicode" + "unicode/utf8" +) + +// methodRouteMux wraps http.ServeMux to provide case-insensitive procedure +// routing without duplicating method metadata in higher layers. +// +// Lookup order: +// 1. exact match via http.ServeMux +// 2. lowercase-first-method fallback via an internal index +type methodRouteMux struct { + exact *http.ServeMux + + mu sync.RWMutex + lower map[string]methodRouteEntry +} + +type methodRouteEntry struct { + pattern string + handler http.Handler +} + +func newMethodRouteMux() *methodRouteMux { + return &methodRouteMux{ + exact: http.NewServeMux(), + lower: make(map[string]methodRouteEntry), + } +} + +// Handle registers the handler in the exact mux and the lowercase-first-method +// fallback index. If two patterns collide after fallback normalization, the +// first registration wins while exact matching keeps the original behavior. +func (m *methodRouteMux) Handle(pattern string, handler http.Handler) { + m.exact.Handle(pattern, handler) + + lowerKey := normalizeMethodRouteKey(pattern) + m.mu.Lock() + defer m.mu.Unlock() + // Keep the first registration for a collided fallback key so mixed-case + // registrations from different generators (for example Go/Java stubs) keep + // deterministic behavior. 
+ if _, exists := m.lower[lowerKey]; !exists { + m.lower[lowerKey] = methodRouteEntry{ + pattern: pattern, + handler: handler, + } + } +} + +func (m *methodRouteMux) Handler(r *http.Request) (http.Handler, string) { + if handler, pattern := m.exact.Handler(r); pattern != "" { + return handler, pattern + } + + lowerKey := normalizeMethodRouteKey(r.URL.Path) + m.mu.RLock() + entry, ok := m.lower[lowerKey] + m.mu.RUnlock() + if ok { + return entry.handler, entry.pattern + } + + return http.NotFoundHandler(), "" +} + +func (m *methodRouteMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + handler, pattern := m.Handler(r) + if pattern != "" { + handler.ServeHTTP(w, r) + return + } + + http.NotFound(w, r) +} + +func normalizeMethodRouteKey(path string) string { + lastSlash := -1 + for i := len(path) - 1; i >= 0; i-- { + if path[i] == '/' { + lastSlash = i + break + } + } + if lastSlash < 0 || lastSlash == len(path)-1 { + return path + } + return path[:lastSlash+1] + lowerFirstRune(path[lastSlash+1:]) +} + +func lowerFirstRune(s string) string { + r, size := utf8.DecodeRuneInString(s) + if r == utf8.RuneError && size == 0 { + return s + } + return string(unicode.ToLower(r)) + s[size:] +} diff --git a/protocol/triple/triple_protocol/method_route_mux_test.go b/protocol/triple/triple_protocol/method_route_mux_test.go new file mode 100644 index 000000000..331d7a8c9 --- /dev/null +++ b/protocol/triple/triple_protocol/method_route_mux_test.go @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package triple_protocol + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +import ( + "github.com/stretchr/testify/assert" +) + +const ( + methodRouteMuxGetUserPath = "/Service/GetUser" + methodRouteMuxGetUserLowerPath = "/Service/getUser" + methodRouteMuxNotFoundBody = "404 page not found\n" +) + +func TestMethodRouteMux(t *testing.T) { + tests := []struct { + name string + register func(*methodRouteMux) + requestPath string + wantStatusCode int + wantBody string + }{ + { + name: "exact match", + register: func(m *methodRouteMux) { + m.Handle(methodRouteMuxGetUserPath, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("exact")) + })) + }, + requestPath: methodRouteMuxGetUserPath, + wantStatusCode: http.StatusOK, + wantBody: "exact", + }, + { + name: "lowercase first rune fallback", + register: func(m *methodRouteMux) { + m.Handle(methodRouteMuxGetUserPath, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("fallback")) + })) + }, + requestPath: methodRouteMuxGetUserLowerPath, + wantStatusCode: http.StatusOK, + wantBody: "fallback", + }, + { + name: "uppercase first rune fallback", + register: func(m *methodRouteMux) { + m.Handle(methodRouteMuxGetUserLowerPath, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("fallback")) + })) + }, + requestPath: methodRouteMuxGetUserPath, + wantStatusCode: http.StatusOK, + wantBody: "fallback", + }, + { + name: "full lowercase fallback", + register: func(m *methodRouteMux) { + 
m.Handle(methodRouteMuxGetUserPath, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("fallback")) + })) + }, + requestPath: "/service/getuser", + wantStatusCode: http.StatusNotFound, + wantBody: methodRouteMuxNotFoundBody, + }, + { + name: "not found", + register: func(m *methodRouteMux) { + m.Handle(methodRouteMuxGetUserPath, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("unexpected")) + })) + }, + requestPath: "/Service/Delete", + wantStatusCode: http.StatusNotFound, + wantBody: methodRouteMuxNotFoundBody, + }, + { + name: "methods differing beyond first rune remain distinct", + register: func(m *methodRouteMux) { + m.Handle("/S/GetUser", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("first")) + })) + m.Handle("/S/Getuser", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("second")) + })) + }, + requestPath: "/S/getuser", + wantStatusCode: http.StatusOK, + wantBody: "second", + }, + { + name: "service path remains case sensitive", + register: func(m *methodRouteMux) { + m.Handle(methodRouteMuxGetUserPath, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("unexpected")) + })) + }, + requestPath: "/service/getUser", + wantStatusCode: http.StatusNotFound, + wantBody: methodRouteMuxNotFoundBody, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mux := newMethodRouteMux() + tt.register(mux) + + req := httptest.NewRequest(http.MethodPost, tt.requestPath, nil) + resp := httptest.NewRecorder() + mux.ServeHTTP(resp, req) + + assert.Equal(t, tt.wantStatusCode, resp.Code) + assert.Equal(t, tt.wantBody, resp.Body.String()) + }) + } +} + +func TestNormalizeMethodRouteKeyEdgeCases(t *testing.T) { + tests := []struct { + name string + path string + want string + }{ + { + name: "path without slash remains unchanged", + path: "GetUser", + want: 
"GetUser", + }, + { + name: "path ending with slash remains unchanged", + path: "/Service/", + want: "/Service/", + }, + { + name: "root path remains unchanged", + path: "/", + want: "/", + }, + { + name: "normal path lowercases first rune of method only", + path: methodRouteMuxGetUserPath, + want: methodRouteMuxGetUserLowerPath, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, normalizeMethodRouteKey(tt.path)) + }) + } +} + +func TestLowerFirstRuneEdgeCases(t *testing.T) { + assert.Empty(t, lowerFirstRune("")) + assert.Equal(t, "abc", lowerFirstRune("Abc")) + assert.Equal(t, "äbc", lowerFirstRune("Äbc")) +} + +func TestMethodRouteMuxFallbackCollisionFirstRegistrationWins(t *testing.T) { + mux := newMethodRouteMux() + firstPattern := methodRouteMuxGetUserPath + secondPattern := methodRouteMuxGetUserLowerPath + + mux.Handle(firstPattern, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("first")) + })) + mux.Handle(secondPattern, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("second")) + })) + + lowerKey := normalizeMethodRouteKey(firstPattern) + mux.mu.RLock() + entry, ok := mux.lower[lowerKey] + mux.mu.RUnlock() + + if !assert.True(t, ok) { + return + } + assert.Equal(t, firstPattern, entry.pattern) + + req := httptest.NewRequest(http.MethodPost, firstPattern, nil) + resp := httptest.NewRecorder() + entry.handler.ServeHTTP(resp, req) + assert.Equal(t, "first", resp.Body.String()) +} diff --git a/protocol/triple/triple_protocol/server.go b/protocol/triple/triple_protocol/server.go index 1346f9b0c..ef162ff8c 100644 --- a/protocol/triple/triple_protocol/server.go +++ b/protocol/triple/triple_protocol/server.go @@ -45,7 +45,7 @@ import ( type Server struct { addr string - mux *http.ServeMux + mux *methodRouteMux handlers map[string]*Handler httpSrv *http.Server http3Srv *http3.Server @@ -338,7 +338,7 @@ func (s *Server) 
GracefulStop(ctx context.Context) error { func NewServer(addr string, tripleConf *global.TripleConfig) *Server { return &Server{ - mux: http.NewServeMux(), + mux: newMethodRouteMux(), addr: addr, handlers: make(map[string]*Handler), tripleConfig: tripleConf, diff --git a/protocol/triple/triple_protocol/triple_ext_test.go b/protocol/triple/triple_protocol/triple_ext_test.go index bcc506342..e26a17de9 100644 --- a/protocol/triple/triple_protocol/triple_ext_test.go +++ b/protocol/triple/triple_protocol/triple_ext_test.go @@ -23,6 +23,7 @@ import ( "fmt" "io" "math/rand" + "net" "net/http" "net/http/httptest" "strings" @@ -790,8 +791,17 @@ func TestGRPCMissingTrailersError(t *testing.T) { func TestUnavailableIfHostInvalid(t *testing.T) { t.Parallel() + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Proxy = nil + transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) { + return nil, &net.DNSError{ + Err: "no such host", + Name: "api.invalid", + IsNotFound: true, + } + } client := pingv1connect.NewPingServiceClient( - http.DefaultClient, + &http.Client{Transport: transport}, "https://api.invalid/", ) err := client.Ping( diff --git a/server/server.go b/server/server.go index c4bba3950..eacb0c391 100644 --- a/server/server.go +++ b/server/server.go @@ -23,7 +23,6 @@ import ( "reflect" "sort" "strconv" - "strings" "sync" ) @@ -237,32 +236,33 @@ func createReflectionMethodFunc(method reflect.Method) func(ctx context.Context, } } -// Add a method with a name of a different first-letter case -// to achieve interoperability with java -// TODO: The method name case sensitivity in Dubbo-java should be addressed. -// We ought to make changes to handle this issue. +// enhanceServiceInfo fills in missing MethodFunc entries via reflection. 
+// Case-insensitive Triple routing is handled in the transport-layer route mux, +// but lowercase-first ServiceInfo method names still need MethodFunc backfill so +// reflection-based invocation can reach the exported Go method. func enhanceServiceInfo(info *common.ServiceInfo) *common.ServiceInfo { if info == nil { return info } - // Get service type for reflection-based method calls var svcType reflect.Type if info.ServiceType != nil { svcType = reflect.TypeOf(info.ServiceType) } - // Build method map for reflection lookup + // Build method map for reflection lookup. + // Keep the first-rune-swapped alias for lowercase-first ServiceInfo names + // (for example "sayHello" -> "SayHello") without duplicating metadata. methodMap := make(map[string]reflect.Method) if svcType != nil { for i := 0; i < svcType.NumMethod(); i++ { m := svcType.Method(i) methodMap[m.Name] = m - methodMap[strings.ToLower(m.Name)] = m + methodMap[dubboutil.SwapCaseFirstRune(m.Name)] = m } } - // Add MethodFunc to methods that don't have it + // Fill in MethodFunc for methods that don't already have one. for i := range info.Methods { if info.Methods[i].MethodFunc == nil && svcType != nil { if reflectMethod, ok := methodMap[info.Methods[i].Name]; ok { @@ -271,22 +271,6 @@ func enhanceServiceInfo(info *common.ServiceInfo) *common.ServiceInfo { } } - // Create additional methods with swapped-case names for Java interoperability - var additionalMethods []common.MethodInfo - for _, method := range info.Methods { - newMethod := method - newMethod.Name = dubboutil.SwapCaseFirstRune(method.Name) - if method.MethodFunc != nil { - newMethod.MethodFunc = method.MethodFunc - } else if svcType != nil { - if reflectMethod, ok := methodMap[dubboutil.SwapCaseFirstRune(method.Name)]; ok { - newMethod.MethodFunc = createReflectionMethodFunc(reflectMethod) - } - } - additionalMethods = append(additionalMethods, newMethod) - } - info.Methods = append(info.Methods, additionalMethods...) 
- return info } diff --git a/server/server_test.go b/server/server_test.go index a12daa36c..edced826a 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -18,6 +18,7 @@ package server import ( + "context" "reflect" "strconv" "sync" @@ -299,11 +300,57 @@ func TestEnhanceServiceInfo(t *testing.T) { result := enhanceServiceInfo(info) assert.NotNil(t, result) - // Should have doubled methods (original + case-swapped) - assert.Len(t, result.Methods, 2) + // ServiceInfo.Methods must only contain original method names. + // Swapped-case aliases are registered at the transport layer, not here. + assert.Len(t, result.Methods, 1) assert.Equal(t, "sayHello", result.Methods[0].Name) - // The swapped version should have capitalized first letter - assert.Equal(t, "SayHello", result.Methods[1].Name) +} + +// greetServiceForTest is a minimal service used to verify MethodFunc backfill. +type greetServiceForTest struct{} + +func (g *greetServiceForTest) Greet(ctx context.Context, req string) (string, error) { + return req, nil +} + +func (g *greetServiceForTest) Reference() string { return "greetServiceForTest" } + +// TestEnhanceServiceInfoMethodFuncBackfillExactName verifies that +// enhanceServiceInfo fills in MethodFunc when the ServiceInfo method name +// matches the Go exported method name exactly (PascalCase). 
+func TestEnhanceServiceInfoMethodFuncBackfillExactName(t *testing.T) { + svc := &greetServiceForTest{} + info := &common.ServiceInfo{ + ServiceType: svc, + Methods: []common.MethodInfo{ + {Name: "Greet"}, // exact Go name — must be found + }, + } + + result := enhanceServiceInfo(info) + assert.NotNil(t, result) + assert.Len(t, result.Methods, 1) + assert.NotNil(t, result.Methods[0].MethodFunc, + "MethodFunc must be filled in for exact-name match to avoid nil-func panic") +} + +// TestEnhanceServiceInfoMethodFuncBackfillJavaStyleName verifies that +// enhanceServiceInfo still fills in MethodFunc for lowercase-first method names +// so reflection-based invocation can reach the exported Go method. +func TestEnhanceServiceInfoMethodFuncBackfillJavaStyleName(t *testing.T) { + svc := &greetServiceForTest{} + info := &common.ServiceInfo{ + ServiceType: svc, + Methods: []common.MethodInfo{ + {Name: "greet"}, // Java/Dubbo-style lowercase-first name + }, + } + + result := enhanceServiceInfo(info) + assert.NotNil(t, result) + assert.Len(t, result.Methods, 1) + assert.NotNil(t, result.Methods[0].MethodFunc, + "MethodFunc must be found via swapped-case lookup to avoid nil-func panic on lowercase-first method names") } // Test getMetadataPort with default protocol diff --git a/server/triple_case_route_integration_test.go b/server/triple_case_route_integration_test.go new file mode 100644 index 000000000..db8604ca4 --- /dev/null +++ b/server/triple_case_route_integration_test.go @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import ( + "context" + "io" + "net" + "net/http" + "strconv" + "strings" + "testing" + "time" +) + +import ( + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +import ( + "dubbo.apache.org/dubbo-go/v3/common" + "dubbo.apache.org/dubbo-go/v3/common/constant" + "dubbo.apache.org/dubbo-go/v3/common/extension" + _ "dubbo.apache.org/dubbo-go/v3/filter/accesslog" // Register default provider filters for exportServices in this integration test. + _ "dubbo.apache.org/dubbo-go/v3/filter/echo" // Register default provider filters for exportServices in this integration test. + _ "dubbo.apache.org/dubbo-go/v3/filter/exec_limit" // Register default provider filters for exportServices in this integration test. + _ "dubbo.apache.org/dubbo-go/v3/filter/generic" // Register default provider filters for exportServices in this integration test. + _ "dubbo.apache.org/dubbo-go/v3/filter/graceful_shutdown" // Register default provider filters for exportServices in this integration test. + _ "dubbo.apache.org/dubbo-go/v3/filter/token" // Register default provider filters for exportServices in this integration test. + _ "dubbo.apache.org/dubbo-go/v3/filter/tps" // Register default provider filters for exportServices in this integration test. 
+ "dubbo.apache.org/dubbo-go/v3/protocol" + _ "dubbo.apache.org/dubbo-go/v3/protocol/protocolwrapper" // Register protocol wrappers used by exportServices in this integration test. + _ "dubbo.apache.org/dubbo-go/v3/protocol/triple" // Register the Triple protocol used by exportServices in this integration test. + tri "dubbo.apache.org/dubbo-go/v3/protocol/triple/triple_protocol" + _ "dubbo.apache.org/dubbo-go/v3/proxy/proxy_factory" // Register the default proxy factory used during export in this integration test. +) + +const ( + tripleCaseRouteHelloBody = "\"Hello test\"" + tripleCaseRouteNotFoundBody = "404 page not found\n" +) + +type TripleCaseRouteService struct{} + +func (s *TripleCaseRouteService) Reference() string { + return "com.example.GreetService" +} + +func (s *TripleCaseRouteService) SayHello(context.Context, *emptypb.Empty) (*wrapperspb.StringValue, error) { + return wrapperspb.String("Hello test"), nil +} + +func TestTripleCaseInsensitiveRoute(t *testing.T) { + t.Run("pascal case registration accepts lowercase fallback", func(t *testing.T) { + runTripleCaseRouteIntegration( + t, + "com.example.GreetService", + "SayHello", + "/com.example.GreetService/SayHello", + []tripleRouteExpectation{ + { + name: "canonical method name", + path: "/com.example.GreetService/SayHello", + wantStatus: http.StatusOK, + wantBody: tripleCaseRouteHelloBody, + }, + { + name: "lowercase method name fallback", + path: "/com.example.GreetService/sayHello", + wantStatus: http.StatusOK, + wantBody: tripleCaseRouteHelloBody, + }, + { + name: "unknown method remains not found", + path: "/com.example.GreetService/DeleteUser", + wantStatus: http.StatusNotFound, + wantBody: tripleCaseRouteNotFoundBody, + }, + }, + ) + }) + + t.Run("camel case registration accepts uppercase fallback", func(t *testing.T) { + runTripleCaseRouteIntegration( + t, + "com.example.JavaStyleGreetService", + "sayHello", + "/com.example.JavaStyleGreetService/sayHello", + []tripleRouteExpectation{ + { + name: 
"registered camel case method name", + path: "/com.example.JavaStyleGreetService/sayHello", + wantStatus: http.StatusOK, + wantBody: tripleCaseRouteHelloBody, + }, + { + name: "uppercase method name fallback", + path: "/com.example.JavaStyleGreetService/SayHello", + wantStatus: http.StatusOK, + wantBody: tripleCaseRouteHelloBody, + }, + { + name: "unknown method remains not found", + path: "/com.example.JavaStyleGreetService/DeleteUser", + wantStatus: http.StatusNotFound, + wantBody: tripleCaseRouteNotFoundBody, + }, + }, + ) + }) +} + +type tripleRouteExpectation struct { + name string + path string + wantStatus int + wantBody string +} + +func runTripleCaseRouteIntegration( + t *testing.T, + interfaceName string, + methodName string, + readyPath string, + expectations []tripleRouteExpectation, +) { + t.Helper() + + port := testFreePort(t) + srv, err := NewServer( + WithServerProtocol( + protocol.WithTriple(), + protocol.WithIp("127.0.0.1"), + protocol.WithPort(port), + ), + ) + require.NoError(t, err) + + service := &TripleCaseRouteService{} + info := &common.ServiceInfo{ + InterfaceName: interfaceName, + ServiceType: service, + Methods: []common.MethodInfo{ + { + Name: methodName, + Type: constant.CallUnary, + ReqInitFunc: func() any { + return &emptypb.Empty{} + }, + MethodFunc: func(ctx context.Context, args []any, handler any) (any, error) { + req := args[0].(*emptypb.Empty) + res, callErr := handler.(*TripleCaseRouteService).SayHello(ctx, req) + if callErr != nil { + return nil, callErr + } + return tri.NewResponse(res), nil + }, + }, + }, + } + + err = srv.Register(service, info, WithInterface(info.InterfaceName), WithNotRegister()) + require.NoError(t, err) + err = srv.exportServices() + require.NoError(t, err) + + t.Cleanup(func() { + extension.GetProtocol(constant.TriProtocol).Destroy() + }) + + client := &http.Client{Timeout: 2 * time.Second} + baseURL := "http://127.0.0.1:" + strconv.Itoa(port) + waitTripleRouteReady(t, client, baseURL+readyPath) + + 
for _, tt := range expectations { + t.Run(tt.name, func(t *testing.T) { + status, body, reqErr := tripleRouteRequest(client, baseURL+tt.path) + require.NoError(t, reqErr) + assert.Equal(t, tt.wantStatus, status) + assert.Equal(t, tt.wantBody, body) + }) + } +} + +func testFreePort(t *testing.T) int { + t.Helper() + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + _ = listener.Close() + }() + + return listener.Addr().(*net.TCPAddr).Port +} + +func tripleRouteRequest(client *http.Client, url string) (int, string, error) { + req, err := http.NewRequest(http.MethodPost, url, strings.NewReader("{}")) + if err != nil { + return 0, "", err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + return 0, "", err + } + defer func() { + _ = resp.Body.Close() + }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return 0, "", err + } + return resp.StatusCode, string(body), nil +} + +func waitTripleRouteReady(t *testing.T, client *http.Client, url string) { + t.Helper() + + deadline := time.Now().Add(5 * time.Second) + var lastStatus int + var lastBody string + var lastErr error + + for time.Now().Before(deadline) { + lastStatus, lastBody, lastErr = tripleRouteRequest(client, url) + if lastErr == nil && lastStatus == http.StatusOK && lastBody == tripleCaseRouteHelloBody { + return + } + time.Sleep(50 * time.Millisecond) + } + + t.Fatalf("triple route not ready: status=%d body=%q err=%v", lastStatus, lastBody, lastErr) +}
