This is an automated email from the ASF dual-hosted git repository.
wusheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/skywalking-mcp.git
The following commit(s) were added to refs/heads/main by this push:
new 74320e3 refactor: remove the unnecessary tool and parameter, implement unit tests (#39)
74320e3 is described below
commit 74320e315f4bbe268f20ebc220792b963bea7cc2
Author: Fine0830 <[email protected]>
AuthorDate: Tue Mar 31 14:25:05 2026 +0800
refactor: remove the unnecessary tool and parameter, implement unit tests (#39)
---
.github/workflows/CI.yaml | 3 +
CLAUDE.md | 11 +-
Makefile | 12 +-
README.md | 8 +-
internal/swmcp/server.go | 53 +------
internal/swmcp/server_registry_test.go | 280 +++++++++++++++++++++++++++++++++
internal/swmcp/server_test.go | 171 ++++++++++++++++++++
internal/swmcp/session.go | 138 ----------------
internal/swmcp/sse.go | 2 +-
internal/swmcp/stdio.go | 2 +-
internal/swmcp/streamable.go | 2 +-
internal/tools/mqe.go | 142 ++++++++++++++++-
internal/tools/mqe_test.go | 119 ++++++++++++++
13 files changed, 740 insertions(+), 203 deletions(-)
diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml
index 6cd622e..1e7e362 100644
--- a/.github/workflows/CI.yaml
+++ b/.github/workflows/CI.yaml
@@ -67,6 +67,9 @@ jobs:
- name: Lint
run: make lint
+ - name: Test
+ run: make test
+
- name: Build Docker images
run: make docker
diff --git a/CLAUDE.md b/CLAUDE.md
index 5de6db1..dfaf53c 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -32,7 +32,7 @@ make build-image # Build Docker image skywalking-mcp:latest
make clean # Remove build artifacts
```
-No unit tests exist yet. CI runs license checks, lint, and docker build.
+Unit tests exist for selected transport/context behavior. CI runs license checks, lint, and docker build.
## Architecture
@@ -41,12 +41,11 @@ No unit tests exist yet. CI runs license checks, lint, and docker build.
Three MCP transport modes as cobra subcommands: `stdio`, `sse`, `streamable`.
The SkyWalking OAP URL is resolved in priority order:
-- **stdio**: `set_skywalking_url` session tool > `--sw-url` flag > `http://localhost:12800/graphql`
-- **SSE/HTTP**: `SW-URL` HTTP header > `--sw-url` flag > `http://localhost:12800/graphql`
+- **All transports**: `--sw-url` flag > `http://localhost:12800/graphql`
-The `set_skywalking_url` tool is only available in stdio mode (single client, well-defined session). SSE and HTTP transports use per-request headers instead.
+SSE and HTTP transports always use the configured server URL.
-Basic auth is configured via `--sw-username` / `--sw-password` flags. Both flags (and the `set_skywalking_url` tool) support `${ENV_VAR}` syntax to resolve credentials from environment variables (e.g. `--sw-password ${MY_SECRET}`).
+Basic auth is configured via `--sw-username` / `--sw-password` flags. The startup flags support `${ENV_VAR}` syntax to resolve credentials from environment variables (e.g. `--sw-password ${MY_SECRET}`).
Each transport injects the OAP URL and auth into the request context via `WithSkyWalkingURLAndInsecure()` and `WithSkyWalkingAuth()`. Tools extract them downstream using `skywalking-cli`'s `contextkey.BaseURL{}`, `contextkey.Username{}`, and `contextkey.Password{}`.
@@ -99,4 +98,4 @@ Tool handlers should return `(mcp.NewToolResultError(...), nil)` for expected qu
## CI & Merge Policy
-Squash-merge only. PRs to `main` require 1 approval and passing `Required` status check (license + lint + docker build). Go 1.25.
\ No newline at end of file
+Squash-merge only. PRs to `main` require 1 approval and passing `Required` status check (license + lint + docker build). Go 1.25.
diff --git a/Makefile b/Makefile
index 982d91a..a28039a 100644
--- a/Makefile
+++ b/Makefile
@@ -31,6 +31,8 @@ PLATFORMS ?= linux/amd64
MULTI_PLATFORMS ?= linux/amd64,linux/arm64
OUTPUT ?= --load
IMAGE_TAGS ?= -t $(IMAGE):$(VERSION) -t $(IMAGE):latest
+GO_TEST_FLAGS ?=
+GO_TEST_PKGS ?= ./...
.PHONY: all
all: build ;
@@ -48,6 +50,14 @@ build: ## Build the binary.
-X ${VERSION_PATH}.date=${BUILD_DATE}" \
-o bin/swmcp cmd/skywalking-mcp/main.go
+.PHONY: test
+test: ## Run unit tests.
+ go test $(GO_TEST_FLAGS) $(GO_TEST_PKGS)
+
+.PHONY: test-cover
+test-cover: ## Run unit tests with coverage output in coverage.txt.
+ go test $(GO_TEST_FLAGS) -coverprofile=coverage.txt $(GO_TEST_PKGS)
+
$(GO_LINT):
@$(GO_LINT) version > /dev/null 2>&1 || go install github.com/golangci/golangci-lint/cmd/[email protected]
$(LICENSE_EYE):
@@ -139,7 +149,7 @@ PUSH_RELEASE_SCRIPTS := ./scripts/push-release.sh
release-push-candidate:
${PUSH_RELEASE_SCRIPTS}
-.PHONY: lint fix-lint
+.PHONY: lint fix-lint test test-cover
.PHONY: license-header fix-license-header dependency-license fix-dependency-license
.PHONY: release-binary release-source release-sign release-assembly
.PHONY: release-push-candidate docker-build-multi
diff --git a/README.md b/README.md
index a269b79..af91904 100644
--- a/README.md
+++ b/README.md
@@ -65,6 +65,11 @@ bin/swmcp stdio --sw-url http://localhost:12800 --sw-username admin --sw-passwor
bin/swmcp sse --sse-address localhost:8000 --base-path /mcp --sw-url http://localhost:12800
```
+Transport URL behavior:
+
+- `stdio`, `sse`, and `streamable` all use the configured `--sw-url` value (or the default `http://localhost:12800/graphql`).
+- `sse` and `streamable` ignore request-level URL override headers.
+
### Usage with Cursor, Copilot, Claude Code
```json
@@ -128,7 +133,6 @@ SkyWalking MCP provides the following tools to query and analyze SkyWalking OAP
| Category | Tool Name | Description |
|--------------|--------------------------------|---------------------------------------------------------------------------------------------------|
-| **Session** | `set_skywalking_url` | Set the SkyWalking OAP server URL and optional basic auth credentials for the current session (stdio mode only). Supports `${ENV_VAR}` syntax for credentials. |
| **Trace** | `query_traces` | Query traces with multi-condition filtering (service, endpoint, state, tags, and time range via start/end/step). Supports `full`, `summary`, and `errors_only` views with performance insights. |
| **Log** | `query_logs` | Query logs with filters for service, instance, endpoint, trace ID, tags, and time range. Supports cold storage and pagination. |
| **MQE** | `execute_mqe_expression` | Execute MQE (Metrics Query Expression) to query and calculate metrics data. Supports calculations, aggregations, TopN, trend analysis, and multiple result types. |
@@ -176,4 +180,4 @@ SkyWalking MCP provides the following prompts for guided analysis workflows:
[Apache 2.0 License.](/LICENSE)
-[mcp]: https://modelcontextprotocol.io/
\ No newline at end of file
+[mcp]: https://modelcontextprotocol.io/
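According to the README text above, every transport resolves the OAP endpoint from `--sw-url` alone and falls back to `http://localhost:12800/graphql`. The diff references `tools.FinalizeURL`, which normalizes the configured value, but does not show its body.

The sketch below is a hypothetical approximation. It only reproduces the input/output pairs asserted in `server_test.go`; for example, `https://configured-oap.example.com:12800/` becomes `https://configured-oap.example.com:12800/graphql`. Both function names are illustrative.

```go
package main

import "strings"

// defaultSWURL mirrors the default OAP endpoint documented in the README.
const defaultSWURL = "http://localhost:12800/graphql"

// finalizeURLSketch is a hypothetical approximation of tools.FinalizeURL:
// it trims whitespace and trailing slashes, then appends "/graphql" if the
// URL does not already end with it.
func finalizeURLSketch(raw string) string {
	u := strings.TrimRight(strings.TrimSpace(raw), "/")
	if !strings.HasSuffix(u, "/graphql") {
		u += "/graphql"
	}
	return u
}

// effectiveURLSketch models the post-change precedence for all transports:
// the configured --sw-url value if set, otherwise the default. No request
// header is consulted.
func effectiveURLSketch(configured string) string {
	if strings.TrimSpace(configured) == "" {
		return defaultSWURL
	}
	return finalizeURLSketch(configured)
}
```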
diff --git a/internal/swmcp/server.go b/internal/swmcp/server.go
index 8b4a740..0cdba97 100644
--- a/internal/swmcp/server.go
+++ b/internal/swmcp/server.go
@@ -37,18 +37,13 @@ import (
)
// newMCPServer creates a new MCP server with all tools, resources, and prompts registered.
-// When stdio is true, session management tools (set_skywalking_url) are also registered,
-// since stdio has a single client and session semantics are well-defined.
-func newMCPServer(stdio bool) *server.MCPServer {
+func newMCPServer() *server.MCPServer {
s := server.NewMCPServer(
"skywalking-mcp", "0.1.0",
server.WithResourceCapabilities(true, true),
server.WithPromptCapabilities(true),
server.WithLogging(),
)
- if stdio {
- AddSessionTools(s)
- }
tools.AddTraceTools(s)
tools.AddLogTools(s)
tools.AddMQETools(s)
@@ -131,63 +126,31 @@ func withConfiguredAuth(ctx context.Context) context.Context {
return ctx
}
-// urlFromHeaders extracts URL for a request.
-// URL is sourced from Header > configured value > Default.
-func urlFromHeaders(req *http.Request) string {
- urlStr := req.Header.Get("SW-URL")
- if urlStr == "" {
- return configuredSkyWalkingURL()
- }
-
- return tools.FinalizeURL(urlStr)
-}
-
-// applySessionOverrides checks for a session in the context and applies any
-// URL or auth overrides that were set via the set_skywalking_url tool.
-func applySessionOverrides(ctx context.Context) context.Context {
- session := SessionFromContext(ctx)
- if session == nil {
- return ctx
- }
- if url := session.URL(); url != "" {
- ctx = context.WithValue(ctx, contextkey.BaseURL{}, url)
- }
- if username := session.Username(); username != "" {
- ctx = WithSkyWalkingAuth(ctx, username, session.Password())
- }
- return ctx
-}
-
// EnhanceStdioContextFunc returns a StdioContextFunc that enriches the context
-// with SkyWalking settings from the global configuration and a per-session store.
+// with SkyWalking settings from the global configuration.
func EnhanceStdioContextFunc() server.StdioContextFunc {
- session := &Session{}
return func(ctx context.Context) context.Context {
- ctx = WithSession(ctx, session)
ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), false)
ctx = withConfiguredAuth(ctx)
- ctx = applySessionOverrides(ctx)
return ctx
}
}
// EnhanceSSEContextFunc returns a SSEContextFunc that enriches the context
-// with SkyWalking settings from SSE request headers and CLI-configured auth.
+// with SkyWalking settings from the CLI configuration and configured auth.
func EnhanceSSEContextFunc() server.SSEContextFunc {
- return func(ctx context.Context, req *http.Request) context.Context {
- urlStr := urlFromHeaders(req)
- ctx = WithSkyWalkingURLAndInsecure(ctx, urlStr, false)
+ return func(ctx context.Context, _ *http.Request) context.Context {
+ ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), false)
ctx = withConfiguredAuth(ctx)
return ctx
}
}
// EnhanceHTTPContextFunc returns a HTTPContextFunc that enriches the context
-// with SkyWalking settings from HTTP request headers and CLI-configured auth.
+// with SkyWalking settings from the CLI configuration and configured auth.
func EnhanceHTTPContextFunc() server.HTTPContextFunc {
- return func(ctx context.Context, req *http.Request) context.Context {
- urlStr := urlFromHeaders(req)
- ctx = WithSkyWalkingURLAndInsecure(ctx, urlStr, false)
+ return func(ctx context.Context, _ *http.Request) context.Context {
+ ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), false)
ctx = withConfiguredAuth(ctx)
return ctx
}
diff --git a/internal/swmcp/server_registry_test.go b/internal/swmcp/server_registry_test.go
new file mode 100644
index 0000000..55d1384
--- /dev/null
+++ b/internal/swmcp/server_registry_test.go
@@ -0,0 +1,280 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+package swmcp
+
+import (
+ "reflect"
+ "sort"
+ "testing"
+ "unsafe"
+
+ "github.com/mark3labs/mcp-go/mcp"
+ "github.com/mark3labs/mcp-go/server"
+)
+
+// These registry tests verify that newMCPServer wires up the expected tools,
+// prompts, and resources. mcp-go v0.45.0 does not expose a public inventory API
+// for MCPServer, so the tests read server internals through a single helper
+// layer below. If mcp-go changes its internal field layout, update only the
+// helpers in this file rather than spreading reflect/unsafe access across tests.
+
+func TestNewMCPServerRegistersExpectedTools(t *testing.T) {
+ srv := newMCPServer()
+
+ got := sortedToolNames(srv)
+ want := []string{
+ "execute_mqe_expression",
+ "get_mqe_metric_type",
+ "list_endpoints",
+ "list_instances",
+ "list_layers",
+ "list_mqe_metrics",
+ "list_processes",
+ "list_services",
+ "query_alarms",
+ "query_endpoints_topology",
+ "query_events",
+ "query_instances_topology",
+ "query_logs",
+ "query_processes_topology",
+ "query_services_topology",
+ "query_traces",
+ }
+
+ assertStringSlicesEqual(t, got, want)
+}
+
+func TestNewMCPServerRegistersExpectedPrompts(t *testing.T) {
+ srv := newMCPServer()
+
+ got := sortedPromptNames(srv)
+ want := []string{
+ "analyze-logs",
+ "analyze-performance",
+ "build-mqe-query",
+ "compare-services",
+ "explore-metrics",
+ "explore-service-topology",
+ "generate_duration",
+ "investigate-traces",
+ "top-services",
+ "trace-deep-dive",
+ }
+
+ assertStringSlicesEqual(t, got, want)
+}
+
+func TestNewMCPServerRegistersExpectedResources(t *testing.T) {
+ srv := newMCPServer()
+
+ resources := resourceMap(srv)
+ got := make([]string, 0, len(resources))
+ for uri := range resources {
+ got = append(got, uri)
+ }
+ sort.Strings(got)
+
+ want := []string{
+ "mqe://docs/ai_prompt",
+ "mqe://docs/examples",
+ "mqe://docs/syntax",
+ "mqe://metrics/available",
+ }
+
+ assertStringSlicesEqual(t, got, want)
+}
+
+func TestPromptMetadataIncludesExpectedArguments(t *testing.T) {
+ srv := newMCPServer()
+ prompts := promptMap(srv)
+
+ prompt, ok := prompts["generate_duration"]
+ if !ok {
+ t.Fatal("generate_duration prompt not registered")
+ }
+ if prompt.Description == "" {
+ t.Fatal("generate_duration prompt description is empty")
+ }
+ if len(prompt.Arguments) != 1 {
+ t.Fatalf("generate_duration prompt arguments = %d, want 1", len(prompt.Arguments))
+ }
+ if prompt.Arguments[0].Name != "time_range" || !prompt.Arguments[0].Required {
+ t.Fatalf("unexpected generate_duration argument: %+v", prompt.Arguments[0])
+ }
+
+ tracePrompt, ok := prompts["trace-deep-dive"]
+ if !ok {
+ t.Fatal("trace-deep-dive prompt not registered")
+ }
+ if len(tracePrompt.Arguments) != 2 {
+ t.Fatalf("trace-deep-dive prompt arguments = %d, want 2", len(tracePrompt.Arguments))
+ }
+ if tracePrompt.Arguments[0].Name != "trace_id" || !tracePrompt.Arguments[0].Required {
+ t.Fatalf("unexpected first trace-deep-dive argument: %+v", tracePrompt.Arguments[0])
+ }
+}
+
+func TestResourceMetadataIncludesExpectedMIMETypes(t *testing.T) {
+ srv := newMCPServer()
+ resources := resourceMap(srv)
+
+ tests := []struct {
+ uri string
+ name string
+ mimeType string
+ }{
+ {uri: "mqe://docs/syntax", name: "MQE Detailed Syntax Rules", mimeType: "text/markdown"},
+ {uri: "mqe://docs/examples", name: "MQE Examples", mimeType: "application/json"},
+ {uri: "mqe://metrics/available", name: "Available Metrics", mimeType: "application/json"},
+ {uri: "mqe://docs/ai_prompt", name: "MQE AI Understanding Guide", mimeType: "text/markdown"},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.uri, func(t *testing.T) {
+ resource, ok := resources[tc.uri]
+ if !ok {
+ t.Fatalf("resource %q not registered", tc.uri)
+ }
+ if resource.Name != tc.name {
+ t.Fatalf("resource name = %q, want %q", resource.Name, tc.name)
+ }
+ if resource.MIMEType != tc.mimeType {
+ t.Fatalf("resource MIME type = %q, want %q", resource.MIMEType, tc.mimeType)
+ }
+ })
+ }
+}
+
+func TestToolMetadataIncludesExpectedDescriptionsAndSchemas(t *testing.T) {
+ srv := newMCPServer()
+ tools := toolMap(srv)
+
+ tests := []struct {
+ name string
+ expectDesc bool
+ expectProperties []string
+ }{
+ {name: "query_traces", expectDesc: true, expectProperties: []string{"service_id", "trace_id", "view"}},
+ {name: "execute_mqe_expression", expectDesc: true, expectProperties: []string{"expression", "service_name", "debug"}},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ tool, ok := tools[tc.name]
+ if !ok {
+ t.Fatalf("tool %q not registered", tc.name)
+ }
+ if tc.expectDesc && tool.Description == "" {
+ t.Fatalf("tool %q description is empty", tc.name)
+ }
+ properties := tool.InputSchema.Properties
+ for _, property := range tc.expectProperties {
+ if _, ok := properties[property]; !ok {
+ t.Fatalf("tool %q missing input schema property %q", tc.name, property)
+ }
+ }
+ })
+ }
+}
+
+func toolMap(srv *server.MCPServer) map[string]mcp.Tool {
+ serverTools := mustReadServerField(testedServerValue(srv), "tools")
+ result := make(map[string]mcp.Tool, serverTools.Len())
+
+ iter := serverTools.MapRange()
+ for iter.Next() {
+ name := iter.Key().String()
+ toolValue := copyReflectValue(iter.Value())
+ result[name] = toolValue.FieldByName("Tool").Interface().(mcp.Tool)
+ }
+
+ return result
+}
+
+func promptMap(srv *server.MCPServer) map[string]mcp.Prompt {
+ serverPrompts := mustReadServerField(testedServerValue(srv), "prompts")
+ result := make(map[string]mcp.Prompt, serverPrompts.Len())
+
+ iter := serverPrompts.MapRange()
+ for iter.Next() {
+ result[iter.Key().String()] = copyReflectValue(iter.Value()).Interface().(mcp.Prompt)
+ }
+
+ return result
+}
+
+func resourceMap(srv *server.MCPServer) map[string]mcp.Resource {
+ serverResources := mustReadServerField(testedServerValue(srv), "resources")
+ result := make(map[string]mcp.Resource, serverResources.Len())
+
+ iter := serverResources.MapRange()
+ for iter.Next() {
+ resourceField := copyReflectValue(iter.Value()).FieldByName("resource")
+ result[iter.Key().String()] = readPrivateField(resourceField).Interface().(mcp.Resource)
+ }
+
+ return result
+}
+
+func sortedToolNames(srv *server.MCPServer) []string {
+ tools := toolMap(srv)
+ names := make([]string, 0, len(tools))
+ for name := range tools {
+ names = append(names, name)
+ }
+ sort.Strings(names)
+ return names
+}
+
+func sortedPromptNames(srv *server.MCPServer) []string {
+ prompts := promptMap(srv)
+ names := make([]string, 0, len(prompts))
+ for name := range prompts {
+ names = append(names, name)
+ }
+ sort.Strings(names)
+ return names
+}
+
+func assertStringSlicesEqual(t *testing.T, got, want []string) {
+ t.Helper()
+ if !reflect.DeepEqual(got, want) {
+ t.Fatalf("values mismatch:\n got: %v\nwant: %v", got, want)
+ }
+}
+
+func readPrivateField(v reflect.Value) reflect.Value {
+ return reflect.NewAt(v.Type(), unsafe.Pointer(v.UnsafeAddr())).Elem()
+}
+
+func testedServerValue(srv *server.MCPServer) reflect.Value {
+ return reflect.ValueOf(srv).Elem()
+}
+
+func mustReadServerField(srv reflect.Value, fieldName string) reflect.Value {
+ field := srv.FieldByName(fieldName)
+ if !field.IsValid() {
+ panic("mcp-go MCPServer no longer has field " + fieldName)
+ }
+ return readPrivateField(field)
+}
+
+func copyReflectValue(v reflect.Value) reflect.Value {
+ cloned := reflect.New(v.Type()).Elem()
+ cloned.Set(v)
+ return cloned
+}
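The registry helpers above read unexported `MCPServer` fields by pairing `reflect` with `unsafe`. The sketch below applies the same technique to a small hypothetical type (`registry`, `entry`), so it runs without mcp-go. Two steps make it work:

- `reflect.NewAt` rebuilds a value from the field's address. Reflection marks values reached through unexported fields as read-only, and the rebuilt value no longer carries that flag.
- Map elements are first copied into an addressable value, because `UnsafeAddr` panics on non-addressable values.

As the test file's own comment warns, this couples the tests to a library's private layout.

```go
package main

import (
	"reflect"
	"sort"
	"unsafe"
)

// registry is a hypothetical stand-in for a library type (such as
// server.MCPServer) that keeps its inventory in an unexported map.
type registry struct {
	items map[string]entry
}

// entry mimics a map value whose interesting field is also unexported.
type entry struct {
	description string
}

// readPrivateField returns a readable view of an addressable unexported
// field. It panics if v is not addressable.
func readPrivateField(v reflect.Value) reflect.Value {
	return reflect.NewAt(v.Type(), unsafe.Pointer(v.UnsafeAddr())).Elem()
}

// copyValue copies v into a fresh addressable value; map elements are not
// addressable, so this is required before calling readPrivateField on them.
func copyValue(v reflect.Value) reflect.Value {
	c := reflect.New(v.Type()).Elem()
	c.Set(v)
	return c
}

// descriptions reads registry.items and each entry.description via reflection.
func descriptions(r *registry) map[string]string {
	field := reflect.ValueOf(r).Elem().FieldByName("items")
	if !field.IsValid() {
		panic("registry no longer has field items")
	}
	items := readPrivateField(field)
	out := make(map[string]string, items.Len())
	iter := items.MapRange()
	for iter.Next() {
		desc := copyValue(iter.Value()).FieldByName("description")
		out[iter.Key().String()] = readPrivateField(desc).Interface().(string)
	}
	return out
}

// sortedKeys returns map keys in a deterministic order for comparisons.
func sortedKeys(m map[string]string) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}
```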
diff --git a/internal/swmcp/server_test.go b/internal/swmcp/server_test.go
new file mode 100644
index 0000000..28056c8
--- /dev/null
+++ b/internal/swmcp/server_test.go
@@ -0,0 +1,171 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+package swmcp
+
+import (
+ "context"
+ "net/http"
+ "testing"
+
+ "github.com/apache/skywalking-cli/pkg/contextkey"
+ "github.com/spf13/viper"
+
+ "github.com/apache/skywalking-mcp/internal/config"
+)
+
+const (
+ configuredHTTPOAPURL = "http://configured-oap:12800/graphql"
+ configuredHTTPSOAPURL = "https://configured-oap.example.com/graphql"
+)
+
+func TestConfiguredSkyWalkingURLUsesDefaultWhenUnset(t *testing.T) {
+ t.Cleanup(viper.Reset)
+
+ got := configuredSkyWalkingURL()
+ if got != config.DefaultSWURL {
+ t.Fatalf("configuredSkyWalkingURL() = %q, want %q", got, config.DefaultSWURL)
+ }
+}
+
+func TestConfiguredSkyWalkingURLFinalizesConfiguredValue(t *testing.T) {
+ t.Cleanup(viper.Reset)
+ viper.Set("url", "https://configured-oap.example.com:12800/")
+
+ got := configuredSkyWalkingURL()
+ want := "https://configured-oap.example.com:12800/graphql"
+ if got != want {
+ t.Fatalf("configuredSkyWalkingURL() = %q, want %q", got, want)
+ }
+}
+
+func TestResolveEnvVar(t *testing.T) {
+ t.Setenv("SW_TEST_SECRET", "resolved-secret")
+
+ tests := []struct {
+ name string
+ value string
+ want string
+ }{
+ {name: "raw", value: "raw-value", want: "raw-value"},
+ {name: "env", value: "${SW_TEST_SECRET}", want: "resolved-secret"},
+ {name: "trimmed env", value: " ${SW_TEST_SECRET} ", want: "resolved-secret"},
+ {name: "missing env", value: "${SW_TEST_MISSING}", want: ""},
+ }
+
+ for _, tc := range tests {
+ t.Run(tc.name, func(t *testing.T) {
+ if got := resolveEnvVar(tc.value); got != tc.want {
+ t.Fatalf("resolveEnvVar(%q) = %q, want %q", tc.value, got, tc.want)
+ }
+ })
+ }
+}
+
+func TestWithConfiguredAuth(t *testing.T) {
+ t.Cleanup(viper.Reset)
+ t.Setenv("SW_TEST_USER", "env-user")
+ t.Setenv("SW_TEST_PASS", "env-pass")
+ viper.Set("username", "${SW_TEST_USER}")
+ viper.Set("password", "${SW_TEST_PASS}")
+
+ ctx := withConfiguredAuth(context.Background())
+
+ gotUser, _ := ctx.Value(contextkey.Username{}).(string)
+ if gotUser != "env-user" {
+ t.Fatalf("username = %q, want %q", gotUser, "env-user")
+ }
+
+ gotPass, _ := ctx.Value(contextkey.Password{}).(string)
+ if gotPass != "env-pass" {
+ t.Fatalf("password = %q, want %q", gotPass, "env-pass")
+ }
+}
+
+func TestWithConfiguredAuthSkipsEmptyUsername(t *testing.T) {
+ t.Cleanup(viper.Reset)
+ viper.Set("password", "password-only")
+
+ ctx := withConfiguredAuth(context.Background())
+
+ if got, ok := ctx.Value(contextkey.Username{}).(string); ok || got != "" {
+ t.Fatalf("username unexpectedly set to %q", got)
+ }
+ if got, ok := ctx.Value(contextkey.Password{}).(string); ok || got != "" {
+ t.Fatalf("password unexpectedly set to %q", got)
+ }
+}
+
+func TestEnhanceStdioContextFuncUsesConfiguredURLAndAuth(t *testing.T) {
+ t.Cleanup(viper.Reset)
+ t.Setenv("SW_STDIO_PASS", "stdio-pass")
+ viper.Set("url", "https://configured-oap.example.com")
+ viper.Set("username", "stdio-user")
+ viper.Set("password", "${SW_STDIO_PASS}")
+
+ ctx := EnhanceStdioContextFunc()(context.Background())
+
+ gotURL, _ := ctx.Value(contextkey.BaseURL{}).(string)
+ if gotURL != configuredHTTPSOAPURL {
+ t.Fatalf("base URL = %q", gotURL)
+ }
+
+ gotUser, _ := ctx.Value(contextkey.Username{}).(string)
+ if gotUser != "stdio-user" {
+ t.Fatalf("username = %q", gotUser)
+ }
+
+ gotPass, _ := ctx.Value(contextkey.Password{}).(string)
+ if gotPass != "stdio-pass" {
+ t.Fatalf("password = %q", gotPass)
+ }
+}
+
+func TestEnhanceHTTPContextFuncDoesNotUseSWURLHeader(t *testing.T) {
+ t.Cleanup(viper.Reset)
+ viper.Set("url", "http://configured-oap:12800")
+
+ req, err := http.NewRequest(http.MethodPost, "http://client/request", http.NoBody)
+ if err != nil {
+ t.Fatalf("create request: %v", err)
+ }
+ req.Header.Set("SW-URL", "http://attacker.invalid:8080")
+
+ ctx := EnhanceHTTPContextFunc()(context.Background(), req)
+
+ gotURL, _ := ctx.Value(contextkey.BaseURL{}).(string)
+ if gotURL != configuredHTTPOAPURL {
+ t.Fatalf("base URL = %q", gotURL)
+ }
+}
+
+func TestEnhanceSSEContextFuncDoesNotUseSWURLHeader(t *testing.T) {
+ t.Cleanup(viper.Reset)
+ viper.Set("url", "https://configured-oap.example.com")
+
+ req, err := http.NewRequest(http.MethodGet, "http://client/events", http.NoBody)
+ if err != nil {
+ t.Fatalf("create request: %v", err)
+ }
+ req.Header.Set("SW-URL", "https://attacker.invalid")
+
+ ctx := EnhanceSSEContextFunc()(context.Background(), req)
+
+ gotURL, _ := ctx.Value(contextkey.BaseURL{}).(string)
+ if gotURL != configuredHTTPSOAPURL {
+ t.Fatalf("base URL = %q", gotURL)
+ }
+}
diff --git a/internal/swmcp/session.go b/internal/swmcp/session.go
deleted file mode 100644
index 5a684c1..0000000
--- a/internal/swmcp/session.go
+++ /dev/null
@@ -1,138 +0,0 @@
-// Licensed to Apache Software Foundation (ASF) under one or more contributor
-// license agreements. See the NOTICE file distributed with
-// this work for additional information regarding copyright
-// ownership. Apache Software Foundation (ASF) licenses this file to you under
-// the Apache License, Version 2.0 (the "License"); you may
-// not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package swmcp
-
-import (
- "context"
- "fmt"
- "sync"
-
- "github.com/mark3labs/mcp-go/mcp"
- "github.com/mark3labs/mcp-go/server"
-
- "github.com/apache/skywalking-mcp/internal/tools"
-)
-
-// sessionKey is the context key for looking up the session store.
-type sessionKey struct{}
-
-// Session holds per-session SkyWalking connection configuration.
-type Session struct {
- mu sync.RWMutex
- url string
- username string
- password string
-}
-
-// SetConnection updates the session's connection parameters.
-func (s *Session) SetConnection(url, username, password string) {
- s.mu.Lock()
- defer s.mu.Unlock()
- s.url = url
- s.username = username
- s.password = password
-}
-
-// URL returns the session's configured URL, or empty if not set.
-func (s *Session) URL() string {
- s.mu.RLock()
- defer s.mu.RUnlock()
- return s.url
-}
-
-// Username returns the session's configured username.
-func (s *Session) Username() string {
- s.mu.RLock()
- defer s.mu.RUnlock()
- return s.username
-}
-
-// Password returns the session's configured password.
-func (s *Session) Password() string {
- s.mu.RLock()
- defer s.mu.RUnlock()
- return s.password
-}
-
-// SessionFromContext retrieves the session from the context, or nil if not present.
-func SessionFromContext(ctx context.Context) *Session {
- s, _ := ctx.Value(sessionKey{}).(*Session)
- return s
-}
-
-// WithSession attaches a session to the context.
-func WithSession(ctx context.Context, s *Session) context.Context {
- return context.WithValue(ctx, sessionKey{}, s)
-}
-
-// SetSkyWalkingURLRequest represents the request for the set_skywalking_url tool.
-type SetSkyWalkingURLRequest struct {
- URL string `json:"url"`
- Username string `json:"username,omitempty"`
- Password string `json:"password,omitempty"`
-}
-
-func setSkyWalkingURL(ctx context.Context, req *SetSkyWalkingURLRequest) (*mcp.CallToolResult, error) {
- if req.URL == "" {
- return mcp.NewToolResultError("url is required"), nil
- }
-
- session := SessionFromContext(ctx)
- if session == nil {
- return mcp.NewToolResultError("session not available"), nil
- }
-
- finalURL := tools.FinalizeURL(req.URL)
- session.SetConnection(finalURL, resolveEnvVar(req.Username), resolveEnvVar(req.Password))
-
- msg := fmt.Sprintf("SkyWalking URL set to %s", finalURL)
- if req.Username != "" {
- msg += " with basic auth credentials"
- }
- return mcp.NewToolResultText(msg), nil
-}
-
-// AddSessionTools registers session management tools with the MCP server.
-func AddSessionTools(s *server.MCPServer) {
- tool := tools.NewTool(
- "set_skywalking_url",
- `Set the SkyWalking OAP server URL and optional basic auth credentials for this session.
-This tool is only available in stdio transport mode.
-
-This tool configures the connection to SkyWalking OAP for all subsequent tool calls in the current session.
-The URL and credentials persist for the lifetime of the session.
-
-Priority: session URL (set by this tool) > --sw-url flag > default (http://localhost:12800/graphql)
-For SSE/HTTP transports, use the SW-URL HTTP header or --sw-url flag instead.
-
-Credentials support raw values or environment variable references using ${ENV_VAR} syntax.
-
-Examples:
-- {"url": "http://demo.skywalking.apache.org:12800"}: Connect without auth
-- {"url": "http://oap.internal:12800", "username": "admin", "password": "admin"}: Connect with basic auth
-- {"url": "https://skywalking.example.com:443", "username": "${SW_USER}", "password": "${SW_PASS}"}: Auth via env vars`,
- setSkyWalkingURL,
- mcp.WithString("url", mcp.Required(),
- mcp.Description("SkyWalking OAP server URL (required). Example: http://localhost:12800")),
- mcp.WithString("username",
- mcp.Description("Username for basic auth (optional). Supports ${ENV_VAR} syntax.")),
- mcp.WithString("password",
- mcp.Description("Password for basic auth (optional). Supports ${ENV_VAR} syntax.")),
- )
- tool.Register(s)
-}
diff --git a/internal/swmcp/sse.go b/internal/swmcp/sse.go
index 1e9a04e..14365a9 100644
--- a/internal/swmcp/sse.go
+++ b/internal/swmcp/sse.go
@@ -72,7 +72,7 @@ func runSSEServer(ctx context.Context, cfg *config.SSEServerConfig) error {
}
sseServer := server.NewSSEServer(
- newMCPServer(false),
+ newMCPServer(),
server.WithStaticBasePath(cfg.BasePath),
server.WithSSEContextFunc(EnhanceSSEContextFunc()),
)
diff --git a/internal/swmcp/stdio.go b/internal/swmcp/stdio.go
index 9dd58df..02abb4a 100644
--- a/internal/swmcp/stdio.go
+++ b/internal/swmcp/stdio.go
@@ -60,7 +60,7 @@ func runStdioServer(ctx context.Context, cfg *config.StdioServerConfig) error {
ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
defer stop()
- stdioServer := server.NewStdioServer(newMCPServer(true))
+ stdioServer := server.NewStdioServer(newMCPServer())
logger, err := initLogger(cfg.LogFilePath)
if err != nil {
diff --git a/internal/swmcp/streamable.go b/internal/swmcp/streamable.go
index 8f41194..0500352 100644
--- a/internal/swmcp/streamable.go
+++ b/internal/swmcp/streamable.go
@@ -57,7 +57,7 @@ func NewStreamable() *cobra.Command {
// runStreamableServer starts the Streamable server with the provided configuration.
func runStreamableServer(cfg *config.StreamableServerConfig) error {
httpServer := server.NewStreamableHTTPServer(
- newMCPServer(false),
+ newMCPServer(),
server.WithStateLess(true),
server.WithLogger(log.StandardLogger()),
server.WithHTTPContextFunc(EnhanceHTTPContextFunc()),
diff --git a/internal/tools/mqe.go b/internal/tools/mqe.go
index 0f4189a..3a45396 100644
--- a/internal/tools/mqe.go
+++ b/internal/tools/mqe.go
@@ -25,8 +25,10 @@ import (
"fmt"
"io"
"net/http"
- "strings"
+ "regexp"
"time"
+ "unicode"
+ "unicode/utf8"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
@@ -41,6 +43,16 @@ func AddMQETools(mcp *server.MCPServer) {
MQEMetricsTypeTool.Register(mcp)
}
+const (
+ maxMQEExpressionLength = 2048
+ maxMQEExpressionDepth = 12
+ maxMQEEntityFieldLen = 256
+ maxMQERegexLength = 256
+ maxMetricNameLength = 128
+)
+
+var metricNamePattern = regexp.MustCompile(`^[A-Za-z0-9_.:-]+$`)
+
// GraphQLRequest represents a GraphQL request
type GraphQLRequest struct {
Query string `json:"query"`
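The new `metricNamePattern` and `maxMetricNameLength` declarations suggest how `validateMetricName` constrains the `metric_name` argument, but its body is outside this excerpt. The following is a minimal sketch under that assumption. It reuses the exact pattern and limit from the diff; the function name is hypothetical.

```go
package main

import (
	"fmt"
	"regexp"
)

// maxMetricNameLength and metricNamePattern copy the values added in this diff.
const maxMetricNameLength = 128

var metricNamePattern = regexp.MustCompile(`^[A-Za-z0-9_.:-]+$`)

// validateMetricNameSketch is a hypothetical validator; the real
// validateMetricName body is not part of this excerpt.
func validateMetricNameSketch(name string) error {
	// Check length first so oversized input never reaches the regex.
	if len(name) > maxMetricNameLength {
		return fmt.Errorf("metric_name must be at most %d characters", maxMetricNameLength)
	}
	// Names containing spaces, quotes, parentheses, or other characters
	// outside the class are rejected before any request is sent to OAP.
	if !metricNamePattern.MatchString(name) {
		return fmt.Errorf("metric_name may only contain letters, digits, '_', '.', ':' and '-'")
	}
	return nil
}
```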
@@ -101,8 +113,8 @@ func executeGraphQLWithContext(ctx context.Context, query string, variables map[
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
- bodyBytes, _ := io.ReadAll(resp.Body)
- return nil, fmt.Errorf("HTTP request failed with status: %d, body: %s", resp.StatusCode, string(bodyBytes))
+ _, _ = io.ReadAll(resp.Body)
+ return nil, fmt.Errorf("GraphQL request failed with HTTP status %d", resp.StatusCode)
}
var graphqlResp GraphQLResponse
@@ -111,11 +123,7 @@ func executeGraphQLWithContext(ctx context.Context, query string, variables map[
}
if len(graphqlResp.Errors) > 0 {
- var errorMsgs []string
- for _, err := range graphqlResp.Errors {
- errorMsgs = append(errorMsgs, err.Message)
- }
- return nil, fmt.Errorf("GraphQL errors: %s", strings.Join(errorMsgs, ", "))
+ return nil, fmt.Errorf("GraphQL query failed")
}
return &graphqlResp, nil
@@ -307,6 +315,9 @@ func executeMQEExpression(ctx context.Context, req *MQEExpressionRequest) (*mcp.
if req.Expression == "" {
return mcp.NewToolResultError("expression is required"), nil
}
+ if err := validateMQEExpressionRequest(req); err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
entity := buildMQEEntity(ctx, req)
timeCtx := GetTimeContext(ctx)
@@ -386,6 +397,10 @@ func executeMQEExpression(ctx context.Context, req *MQEExpressionRequest) (*mcp.
// listMQEMetrics lists available metrics
func listMQEMetrics(ctx context.Context, req *MQEMetricsListRequest) (*mcp.CallToolResult, error) {
+ if err := validateMQEMetricsListRequest(req); err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
+
// GraphQL query for listing metrics
query := `
query listMetrics($regex: String) {
@@ -438,6 +453,9 @@ func getMQEMetricsType(ctx context.Context, req *MQEMetricsTypeRequest) (*mcp.Ca
if req.MetricName == "" {
return mcp.NewToolResultError("metric_name must be provided"), nil
}
+ if err := validateMetricName(req.MetricName); err != nil {
+ return mcp.NewToolResultError(err.Error()), nil
+ }
// GraphQL query for getting metric type
query := `
@@ -462,6 +480,114 @@ func getMQEMetricsType(ctx context.Context, req *MQEMetricsTypeRequest) (*mcp.Ca
return mcp.NewToolResultText(string(jsonBytes)), nil
}
+func validateMQEExpressionRequest(req *MQEExpressionRequest) error {
+ if err := validateMQEExpression(req.Expression); err != nil {
+ return err
+ }
+
+ for fieldName, value := range map[string]string{
+ "service_name": req.ServiceName,
+ "layer": req.Layer,
+ "service_instance_name": req.ServiceInstanceName,
+ "endpoint_name": req.EndpointName,
+ "process_name": req.ProcessName,
+ "dest_service_name": req.DestServiceName,
+ "dest_layer": req.DestLayer,
+ "dest_service_instance_name": req.DestServiceInstanceName,
+ "dest_endpoint_name": req.DestEndpointName,
+ "dest_process_name": req.DestProcessName,
+ } {
+ if err := validateMQETextField(fieldName, value, maxMQEEntityFieldLen); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func validateMQEMetricsListRequest(req *MQEMetricsListRequest) error {
+ if req == nil || req.Regex == "" {
+ return nil
+ }
+ if err := validateMQETextField("regex", req.Regex, maxMQERegexLength); err != nil {
+ return err
+ }
+ if _, err := regexp.Compile(req.Regex); err != nil {
+ return fmt.Errorf("regex is invalid")
+ }
+ return nil
+}
+
+func validateMetricName(metricName string) error {
+ if err := validateMQETextField("metric_name", metricName, maxMetricNameLength); err != nil {
+ return err
+ }
+ if !metricNamePattern.MatchString(metricName) {
+ return fmt.Errorf("metric_name contains invalid characters")
+ }
+ return nil
+}
+
+func validateMQEExpression(expression string) error {
+ if !utf8.ValidString(expression) {
+ return fmt.Errorf("expression must be valid UTF-8")
+ }
+ if len(expression) > maxMQEExpressionLength {
+ return fmt.Errorf("expression exceeds maximum length of %d characters", maxMQEExpressionLength)
+ }
+ if containsUnsafeControlChars(expression) {
+ return fmt.Errorf("expression contains invalid control characters")
+ }
+ if nestingDepth(expression) > maxMQEExpressionDepth {
+ return fmt.Errorf("expression exceeds maximum nesting depth of %d", maxMQEExpressionDepth)
+ }
+ return nil
+}
+
+func validateMQETextField(fieldName, value string, maxLen int) error {
+ if value == "" {
+ return nil
+ }
+ if !utf8.ValidString(value) {
+ return fmt.Errorf("%s must be valid UTF-8", fieldName)
+ }
+ if len(value) > maxLen {
+ return fmt.Errorf("%s exceeds maximum length of %d characters", fieldName, maxLen)
+ }
+ if containsUnsafeControlChars(value) {
+ return fmt.Errorf("%s contains invalid control characters", fieldName)
+ }
+ return nil
+}
+
+func containsUnsafeControlChars(value string) bool {
+ for _, r := range value {
+ if unicode.IsControl(r) && r != '\n' && r != '\r' && r != '\t' {
+ return true
+ }
+ }
+ return false
+}
+
+func nestingDepth(value string) int {
+ depth := 0
+ maxDepth := 0
+ for _, r := range value {
+ switch r {
+ case '(', '{', '[':
+ depth++
+ if depth > maxDepth {
+ maxDepth = depth
+ }
+ case ')', '}', ']':
+ if depth > 0 {
+ depth--
+ }
+ }
+ }
+ return maxDepth
+}
+
var MQEExpressionTool = NewTool(
"execute_mqe_expression",
`Execute MQE (Metrics Query Expression) to query and calculate metrics data.
diff --git a/internal/tools/mqe_test.go b/internal/tools/mqe_test.go
new file mode 100644
index 0000000..37d08af
--- /dev/null
+++ b/internal/tools/mqe_test.go
@@ -0,0 +1,119 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tools
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/apache/skywalking-cli/pkg/contextkey"
+ "github.com/mark3labs/mcp-go/mcp"
+)
+
+func TestValidateMQEExpressionRequestRejectsDeeplyNestedExpression(t *testing.T) {
+ req := &MQEExpressionRequest{
+ Expression: strings.Repeat("(", maxMQEExpressionDepth+1) + "service_cpm" + strings.Repeat(")", maxMQEExpressionDepth+1),
+ }
+
+ err := validateMQEExpressionRequest(req)
+ if err == nil || !strings.Contains(err.Error(), "maximum nesting depth") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func TestValidateMQEMetricsListRequestRejectsInvalidRegex(t *testing.T) {
+ err := validateMQEMetricsListRequest(&MQEMetricsListRequest{Regex: "("})
+ if err == nil || err.Error() != "regex is invalid" {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func TestValidateMetricNameRejectsInvalidCharacters(t *testing.T) {
+ err := validateMetricName("service cpm")
+ if err == nil || err.Error() != "metric_name contains invalid characters" {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func TestExecuteMQEExpressionRejectsOverlongEntityField(t *testing.T) {
+ req := &MQEExpressionRequest{
+ Expression: "service_cpm",
+ ServiceName: strings.Repeat("a", maxMQEEntityFieldLen+1),
+ }
+
+ result, err := executeMQEExpression(context.Background(), req)
+ if err != nil {
+ t.Fatalf("executeMQEExpression returned error: %v", err)
+ }
+ if !result.IsError {
+ t.Fatal("expected tool error result")
+ }
+ assertToolResultContains(t, result, "service_name exceeds maximum length")
+}
+
+func TestExecuteGraphQLWithContextSanitizesHTTPErrorBody(t *testing.T) {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ http.Error(w, "sensitive backend details", http.StatusBadGateway)
+ }))
+ defer ts.Close()
+
+ ctx := context.WithValue(context.Background(), contextkey.BaseURL{}, ts.URL)
+ _, err := executeGraphQLWithContext(ctx, "query { ping }", map[string]interface{}{})
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if !strings.Contains(err.Error(), "HTTP status 502") {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if strings.Contains(err.Error(), "sensitive backend details") {
+ t.Fatalf("backend body leaked in error: %v", err)
+ }
+}
+
+func TestExecuteGraphQLWithContextSanitizesGraphQLErrors(t *testing.T) {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+ w.Header().Set("Content-Type", "application/json")
+ _, _ = w.Write([]byte(`{"errors":[{"message":"database stack trace"}]}`))
+ }))
+ defer ts.Close()
+
+ ctx := context.WithValue(context.Background(), contextkey.BaseURL{}, ts.URL)
+ _, err := executeGraphQLWithContext(ctx, "query { ping }", map[string]interface{}{})
+ if err == nil {
+ t.Fatal("expected error")
+ }
+ if err.Error() != "GraphQL query failed" {
+ t.Fatalf("unexpected error: %v", err)
+ }
+}
+
+func assertToolResultContains(t *testing.T, result *mcp.CallToolResult, want string) {
+ t.Helper()
+ if len(result.Content) == 0 {
+ t.Fatal("tool result had no content")
+ }
+ text, ok := result.Content[0].(mcp.TextContent)
+ if !ok {
+ t.Fatalf("unexpected content type: %T", result.Content[0])
+ }
+ if !strings.Contains(text.Text, want) {
+ t.Fatalf("tool result text %q does not contain %q", text.Text, want)
+ }
+}