This is an automated email from the ASF dual-hosted git repository.

wusheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/skywalking-mcp.git


The following commit(s) were added to refs/heads/main by this push:
     new 74320e3  refactor: remove the unnecessary tool and parameter, 
implement unit tests (#39)
74320e3 is described below

commit 74320e315f4bbe268f20ebc220792b963bea7cc2
Author: Fine0830 <[email protected]>
AuthorDate: Tue Mar 31 14:25:05 2026 +0800

    refactor: remove the unnecessary tool and parameter, implement unit tests 
(#39)
---
 .github/workflows/CI.yaml              |   3 +
 CLAUDE.md                              |  11 +-
 Makefile                               |  12 +-
 README.md                              |   8 +-
 internal/swmcp/server.go               |  53 +------
 internal/swmcp/server_registry_test.go | 280 +++++++++++++++++++++++++++++++++
 internal/swmcp/server_test.go          | 171 ++++++++++++++++++++
 internal/swmcp/session.go              | 138 ----------------
 internal/swmcp/sse.go                  |   2 +-
 internal/swmcp/stdio.go                |   2 +-
 internal/swmcp/streamable.go           |   2 +-
 internal/tools/mqe.go                  | 142 ++++++++++++++++-
 internal/tools/mqe_test.go             | 119 ++++++++++++++
 13 files changed, 740 insertions(+), 203 deletions(-)

diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml
index 6cd622e..1e7e362 100644
--- a/.github/workflows/CI.yaml
+++ b/.github/workflows/CI.yaml
@@ -67,6 +67,9 @@ jobs:
       - name: Lint
         run: make lint
 
+      - name: Test
+        run: make test
+
       - name: Build Docker images
         run: make docker
 
diff --git a/CLAUDE.md b/CLAUDE.md
index 5de6db1..dfaf53c 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -32,7 +32,7 @@ make build-image      # Build Docker image 
skywalking-mcp:latest
 make clean            # Remove build artifacts
 ```
 
-No unit tests exist yet. CI runs license checks, lint, and docker build.
+Unit tests (`make test`) cover selected server registration, transport/context, and MQE tool behavior. CI runs license checks, lint, unit tests, and docker build.
 
 ## Architecture
 
@@ -41,12 +41,11 @@ No unit tests exist yet. CI runs license checks, lint, and 
docker build.
 Three MCP transport modes as cobra subcommands: `stdio`, `sse`, `streamable`.
 
 The SkyWalking OAP URL is resolved in priority order:
-- **stdio**: `set_skywalking_url` session tool > `--sw-url` flag > 
`http://localhost:12800/graphql`
-- **SSE/HTTP**: `SW-URL` HTTP header > `--sw-url` flag > 
`http://localhost:12800/graphql`
+- **All transports**: `--sw-url` flag > `http://localhost:12800/graphql`
 
-The `set_skywalking_url` tool is only available in stdio mode (single client, 
well-defined session). SSE and HTTP transports use per-request headers instead.
+SSE and streamable HTTP transports ignore the `SW-URL` request header and always use the configured server URL.
 
-Basic auth is configured via `--sw-username` / `--sw-password` flags. Both 
flags (and the `set_skywalking_url` tool) support `${ENV_VAR}` syntax to 
resolve credentials from environment variables (e.g. `--sw-password 
${MY_SECRET}`).
+Basic auth is configured via `--sw-username` / `--sw-password` flags. The 
startup flags support `${ENV_VAR}` syntax to resolve credentials from 
environment variables (e.g. `--sw-password ${MY_SECRET}`).
 
 Each transport injects the OAP URL and auth into the request context via 
`WithSkyWalkingURLAndInsecure()` and `WithSkyWalkingAuth()`. Tools extract them 
downstream using `skywalking-cli`'s `contextkey.BaseURL{}`, 
`contextkey.Username{}`, and `contextkey.Password{}`.
 
@@ -99,4 +98,4 @@ Tool handlers should return `(mcp.NewToolResultError(...), 
nil)` for expected qu
 
 ## CI & Merge Policy
 
-Squash-merge only. PRs to `main` require 1 approval and passing `Required` 
status check (license + lint + docker build). Go 1.25.
\ No newline at end of file
+Squash-merge only. PRs to `main` require 1 approval and passing `Required` status check (license + lint + unit tests + docker build). Go 1.25.
diff --git a/Makefile b/Makefile
index 982d91a..a28039a 100644
--- a/Makefile
+++ b/Makefile
@@ -31,6 +31,8 @@ PLATFORMS ?= linux/amd64
 MULTI_PLATFORMS ?= linux/amd64,linux/arm64
 OUTPUT ?= --load
 IMAGE_TAGS ?= -t $(IMAGE):$(VERSION) -t $(IMAGE):latest
+GO_TEST_FLAGS ?=
+GO_TEST_PKGS ?= ./...
 
 .PHONY: all
 all: build ;
@@ -48,6 +50,14 @@ build: ## Build the binary.
                -X ${VERSION_PATH}.date=${BUILD_DATE}" \
                -o bin/swmcp cmd/skywalking-mcp/main.go
 
+.PHONY: test
+test: ## Run unit tests.
+       go test $(GO_TEST_FLAGS) $(GO_TEST_PKGS)
+
+.PHONY: test-cover
+test-cover: ## Run unit tests with coverage output in coverage.txt.
+       go test $(GO_TEST_FLAGS) -coverprofile=coverage.txt $(GO_TEST_PKGS)
+
 $(GO_LINT):
        @$(GO_LINT) version > /dev/null 2>&1 || go install 
github.com/golangci/golangci-lint/cmd/[email protected]
 $(LICENSE_EYE):
@@ -139,7 +149,7 @@ PUSH_RELEASE_SCRIPTS := ./scripts/push-release.sh
 release-push-candidate:
        ${PUSH_RELEASE_SCRIPTS}
 
-.PHONY: lint fix-lint
+.PHONY: lint fix-lint test test-cover
 .PHONY: license-header fix-license-header dependency-license 
fix-dependency-license
 .PHONY: release-binary release-source release-sign release-assembly
 .PHONY: release-push-candidate docker-build-multi
diff --git a/README.md b/README.md
index a269b79..af91904 100644
--- a/README.md
+++ b/README.md
@@ -65,6 +65,11 @@ bin/swmcp stdio --sw-url http://localhost:12800 
--sw-username admin --sw-passwor
 bin/swmcp sse --sse-address localhost:8000 --base-path /mcp --sw-url 
http://localhost:12800
 ```
 
+Transport URL behavior:
+
+- `stdio`, `sse`, and `streamable` all use the configured `--sw-url` value (or 
the default `http://localhost:12800/graphql`).
+- `sse` and `streamable` ignore request-level URL overrides such as the `SW-URL` HTTP header.
+
 ### Usage with Cursor, Copilot, Claude Code
 
 ```json
@@ -128,7 +133,6 @@ SkyWalking MCP provides the following tools to query and 
analyze SkyWalking OAP
 
 | Category     | Tool Name                      | Description                  
                                                                     |
 
|--------------|--------------------------------|---------------------------------------------------------------------------------------------------|
-| **Session**  | `set_skywalking_url`           | Set the SkyWalking OAP 
server URL and optional basic auth credentials for the current session (stdio 
mode only). Supports `${ENV_VAR}` syntax for credentials. |
 | **Trace**    | `query_traces`                 | Query traces with 
multi-condition filtering (service, endpoint, state, tags, and time range via 
start/end/step). Supports `full`, `summary`, and `errors_only` views with 
performance insights. |
 | **Log**      | `query_logs`                   | Query logs with filters for 
service, instance, endpoint, trace ID, tags, and time range. Supports cold 
storage and pagination. |
 | **MQE**      | `execute_mqe_expression`       | Execute MQE (Metrics Query 
Expression) to query and calculate metrics data. Supports calculations, 
aggregations, TopN, trend analysis, and multiple result types. |
@@ -176,4 +180,4 @@ SkyWalking MCP provides the following prompts for guided 
analysis workflows:
 
 [Apache 2.0 License.](/LICENSE)
 
-[mcp]: https://modelcontextprotocol.io/
\ No newline at end of file
+[mcp]: https://modelcontextprotocol.io/
diff --git a/internal/swmcp/server.go b/internal/swmcp/server.go
index 8b4a740..0cdba97 100644
--- a/internal/swmcp/server.go
+++ b/internal/swmcp/server.go
@@ -37,18 +37,13 @@ import (
 )
 
 // newMCPServer creates a new MCP server with all tools, resources, and 
prompts registered.
-// When stdio is true, session management tools (set_skywalking_url) are also 
registered,
-// since stdio has a single client and session semantics are well-defined.
-func newMCPServer(stdio bool) *server.MCPServer {
+func newMCPServer() *server.MCPServer {
        s := server.NewMCPServer(
                "skywalking-mcp", "0.1.0",
                server.WithResourceCapabilities(true, true),
                server.WithPromptCapabilities(true),
                server.WithLogging(),
        )
-       if stdio {
-               AddSessionTools(s)
-       }
        tools.AddTraceTools(s)
        tools.AddLogTools(s)
        tools.AddMQETools(s)
@@ -131,63 +126,31 @@ func withConfiguredAuth(ctx context.Context) 
context.Context {
        return ctx
 }
 
-// urlFromHeaders extracts URL for a request.
-// URL is sourced from Header > configured value > Default.
-func urlFromHeaders(req *http.Request) string {
-       urlStr := req.Header.Get("SW-URL")
-       if urlStr == "" {
-               return configuredSkyWalkingURL()
-       }
-
-       return tools.FinalizeURL(urlStr)
-}
-
-// applySessionOverrides checks for a session in the context and applies any
-// URL or auth overrides that were set via the set_skywalking_url tool.
-func applySessionOverrides(ctx context.Context) context.Context {
-       session := SessionFromContext(ctx)
-       if session == nil {
-               return ctx
-       }
-       if url := session.URL(); url != "" {
-               ctx = context.WithValue(ctx, contextkey.BaseURL{}, url)
-       }
-       if username := session.Username(); username != "" {
-               ctx = WithSkyWalkingAuth(ctx, username, session.Password())
-       }
-       return ctx
-}
-
 // EnhanceStdioContextFunc returns a StdioContextFunc that enriches the context
-// with SkyWalking settings from the global configuration and a per-session 
store.
+// with SkyWalking settings from the global configuration.
 func EnhanceStdioContextFunc() server.StdioContextFunc {
-       session := &Session{}
        return func(ctx context.Context) context.Context {
-               ctx = WithSession(ctx, session)
                ctx = WithSkyWalkingURLAndInsecure(ctx, 
configuredSkyWalkingURL(), false)
                ctx = withConfiguredAuth(ctx)
-               ctx = applySessionOverrides(ctx)
                return ctx
        }
 }
 
 // EnhanceSSEContextFunc returns a SSEContextFunc that enriches the context
-// with SkyWalking settings from SSE request headers and CLI-configured auth.
+// with SkyWalking settings from the CLI configuration and configured auth.
 func EnhanceSSEContextFunc() server.SSEContextFunc {
-       return func(ctx context.Context, req *http.Request) context.Context {
-               urlStr := urlFromHeaders(req)
-               ctx = WithSkyWalkingURLAndInsecure(ctx, urlStr, false)
+       return func(ctx context.Context, _ *http.Request) context.Context {
+               ctx = WithSkyWalkingURLAndInsecure(ctx, 
configuredSkyWalkingURL(), false)
                ctx = withConfiguredAuth(ctx)
                return ctx
        }
 }
 
 // EnhanceHTTPContextFunc returns a HTTPContextFunc that enriches the context
-// with SkyWalking settings from HTTP request headers and CLI-configured auth.
+// with SkyWalking settings from the CLI configuration and configured auth.
 func EnhanceHTTPContextFunc() server.HTTPContextFunc {
-       return func(ctx context.Context, req *http.Request) context.Context {
-               urlStr := urlFromHeaders(req)
-               ctx = WithSkyWalkingURLAndInsecure(ctx, urlStr, false)
+       return func(ctx context.Context, _ *http.Request) context.Context {
+               ctx = WithSkyWalkingURLAndInsecure(ctx, 
configuredSkyWalkingURL(), false)
                ctx = withConfiguredAuth(ctx)
                return ctx
        }
diff --git a/internal/swmcp/server_registry_test.go 
b/internal/swmcp/server_registry_test.go
new file mode 100644
index 0000000..55d1384
--- /dev/null
+++ b/internal/swmcp/server_registry_test.go
@@ -0,0 +1,280 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+package swmcp
+
+import (
+       "reflect"
+       "sort"
+       "testing"
+       "unsafe"
+
+       "github.com/mark3labs/mcp-go/mcp"
+       "github.com/mark3labs/mcp-go/server"
+)
+
+// These registry tests verify that newMCPServer wires up the expected tools,
+// prompts, and resources. mcp-go v0.45.0 does not expose a public inventory 
API
+// for MCPServer, so the tests read server internals through a single helper
+// layer below. If mcp-go changes its internal field layout, update only the
+// helpers in this file rather than spreading reflect/unsafe access across 
tests.
+
+func TestNewMCPServerRegistersExpectedTools(t *testing.T) {
+       srv := newMCPServer()
+
+       got := sortedToolNames(srv)
+       want := []string{
+               "execute_mqe_expression",
+               "get_mqe_metric_type",
+               "list_endpoints",
+               "list_instances",
+               "list_layers",
+               "list_mqe_metrics",
+               "list_processes",
+               "list_services",
+               "query_alarms",
+               "query_endpoints_topology",
+               "query_events",
+               "query_instances_topology",
+               "query_logs",
+               "query_processes_topology",
+               "query_services_topology",
+               "query_traces",
+       }
+
+       assertStringSlicesEqual(t, got, want)
+}
+
+func TestNewMCPServerRegistersExpectedPrompts(t *testing.T) {
+       srv := newMCPServer()
+
+       got := sortedPromptNames(srv)
+       want := []string{
+               "analyze-logs",
+               "analyze-performance",
+               "build-mqe-query",
+               "compare-services",
+               "explore-metrics",
+               "explore-service-topology",
+               "generate_duration",
+               "investigate-traces",
+               "top-services",
+               "trace-deep-dive",
+       }
+
+       assertStringSlicesEqual(t, got, want)
+}
+
+func TestNewMCPServerRegistersExpectedResources(t *testing.T) {
+       srv := newMCPServer()
+
+       resources := resourceMap(srv)
+       got := make([]string, 0, len(resources))
+       for uri := range resources {
+               got = append(got, uri)
+       }
+       sort.Strings(got)
+
+       want := []string{
+               "mqe://docs/ai_prompt",
+               "mqe://docs/examples",
+               "mqe://docs/syntax",
+               "mqe://metrics/available",
+       }
+
+       assertStringSlicesEqual(t, got, want)
+}
+
+func TestPromptMetadataIncludesExpectedArguments(t *testing.T) {
+       srv := newMCPServer()
+       prompts := promptMap(srv)
+
+       prompt, ok := prompts["generate_duration"]
+       if !ok {
+               t.Fatal("generate_duration prompt not registered")
+       }
+       if prompt.Description == "" {
+               t.Fatal("generate_duration prompt description is empty")
+       }
+       if len(prompt.Arguments) != 1 {
+               t.Fatalf("generate_duration prompt arguments = %d, want 1", 
len(prompt.Arguments))
+       }
+       if prompt.Arguments[0].Name != "time_range" || 
!prompt.Arguments[0].Required {
+               t.Fatalf("unexpected generate_duration argument: %+v", 
prompt.Arguments[0])
+       }
+
+       tracePrompt, ok := prompts["trace-deep-dive"]
+       if !ok {
+               t.Fatal("trace-deep-dive prompt not registered")
+       }
+       if len(tracePrompt.Arguments) != 2 {
+               t.Fatalf("trace-deep-dive prompt arguments = %d, want 2", 
len(tracePrompt.Arguments))
+       }
+       if tracePrompt.Arguments[0].Name != "trace_id" || 
!tracePrompt.Arguments[0].Required {
+               t.Fatalf("unexpected first trace-deep-dive argument: %+v", 
tracePrompt.Arguments[0])
+       }
+}
+
+func TestResourceMetadataIncludesExpectedMIMETypes(t *testing.T) {
+       srv := newMCPServer()
+       resources := resourceMap(srv)
+
+       tests := []struct {
+               uri      string
+               name     string
+               mimeType string
+       }{
+               {uri: "mqe://docs/syntax", name: "MQE Detailed Syntax Rules", 
mimeType: "text/markdown"},
+               {uri: "mqe://docs/examples", name: "MQE Examples", mimeType: 
"application/json"},
+               {uri: "mqe://metrics/available", name: "Available Metrics", 
mimeType: "application/json"},
+               {uri: "mqe://docs/ai_prompt", name: "MQE AI Understanding 
Guide", mimeType: "text/markdown"},
+       }
+
+       for _, tc := range tests {
+               t.Run(tc.uri, func(t *testing.T) {
+                       resource, ok := resources[tc.uri]
+                       if !ok {
+                               t.Fatalf("resource %q not registered", tc.uri)
+                       }
+                       if resource.Name != tc.name {
+                               t.Fatalf("resource name = %q, want %q", 
resource.Name, tc.name)
+                       }
+                       if resource.MIMEType != tc.mimeType {
+                               t.Fatalf("resource MIME type = %q, want %q", 
resource.MIMEType, tc.mimeType)
+                       }
+               })
+       }
+}
+
+func TestToolMetadataIncludesExpectedDescriptionsAndSchemas(t *testing.T) {
+       srv := newMCPServer()
+       tools := toolMap(srv)
+
+       tests := []struct {
+               name             string
+               expectDesc       bool
+               expectProperties []string
+       }{
+               {name: "query_traces", expectDesc: true, expectProperties: 
[]string{"service_id", "trace_id", "view"}},
+               {name: "execute_mqe_expression", expectDesc: true, 
expectProperties: []string{"expression", "service_name", "debug"}},
+       }
+
+       for _, tc := range tests {
+               t.Run(tc.name, func(t *testing.T) {
+                       tool, ok := tools[tc.name]
+                       if !ok {
+                               t.Fatalf("tool %q not registered", tc.name)
+                       }
+                       if tc.expectDesc && tool.Description == "" {
+                               t.Fatalf("tool %q description is empty", 
tc.name)
+                       }
+                       properties := tool.InputSchema.Properties
+                       for _, property := range tc.expectProperties {
+                               if _, ok := properties[property]; !ok {
+                                       t.Fatalf("tool %q missing input schema 
property %q", tc.name, property)
+                               }
+                       }
+               })
+       }
+}
+
+func toolMap(srv *server.MCPServer) map[string]mcp.Tool {
+       serverTools := mustReadServerField(testedServerValue(srv), "tools")
+       result := make(map[string]mcp.Tool, serverTools.Len())
+
+       iter := serverTools.MapRange()
+       for iter.Next() {
+               name := iter.Key().String()
+               toolValue := copyReflectValue(iter.Value())
+               result[name] = 
toolValue.FieldByName("Tool").Interface().(mcp.Tool)
+       }
+
+       return result
+}
+
+func promptMap(srv *server.MCPServer) map[string]mcp.Prompt {
+       serverPrompts := mustReadServerField(testedServerValue(srv), "prompts")
+       result := make(map[string]mcp.Prompt, serverPrompts.Len())
+
+       iter := serverPrompts.MapRange()
+       for iter.Next() {
+               result[iter.Key().String()] = 
copyReflectValue(iter.Value()).Interface().(mcp.Prompt)
+       }
+
+       return result
+}
+
+func resourceMap(srv *server.MCPServer) map[string]mcp.Resource {
+       serverResources := mustReadServerField(testedServerValue(srv), 
"resources")
+       result := make(map[string]mcp.Resource, serverResources.Len())
+
+       iter := serverResources.MapRange()
+       for iter.Next() {
+               resourceField := 
copyReflectValue(iter.Value()).FieldByName("resource")
+               result[iter.Key().String()] = 
readPrivateField(resourceField).Interface().(mcp.Resource)
+       }
+
+       return result
+}
+
+func sortedToolNames(srv *server.MCPServer) []string {
+       tools := toolMap(srv)
+       names := make([]string, 0, len(tools))
+       for name := range tools {
+               names = append(names, name)
+       }
+       sort.Strings(names)
+       return names
+}
+
+func sortedPromptNames(srv *server.MCPServer) []string {
+       prompts := promptMap(srv)
+       names := make([]string, 0, len(prompts))
+       for name := range prompts {
+               names = append(names, name)
+       }
+       sort.Strings(names)
+       return names
+}
+
+func assertStringSlicesEqual(t *testing.T, got, want []string) {
+       t.Helper()
+       if !reflect.DeepEqual(got, want) {
+               t.Fatalf("values mismatch:\n got: %v\nwant: %v", got, want)
+       }
+}
+
+func readPrivateField(v reflect.Value) reflect.Value {
+       return reflect.NewAt(v.Type(), unsafe.Pointer(v.UnsafeAddr())).Elem()
+}
+
+func testedServerValue(srv *server.MCPServer) reflect.Value {
+       return reflect.ValueOf(srv).Elem()
+}
+
+func mustReadServerField(srv reflect.Value, fieldName string) reflect.Value {
+       field := srv.FieldByName(fieldName)
+       if !field.IsValid() {
+               panic("mcp-go MCPServer no longer has field " + fieldName)
+       }
+       return readPrivateField(field)
+}
+
+func copyReflectValue(v reflect.Value) reflect.Value {
+       cloned := reflect.New(v.Type()).Elem()
+       cloned.Set(v)
+       return cloned
+}
diff --git a/internal/swmcp/server_test.go b/internal/swmcp/server_test.go
new file mode 100644
index 0000000..28056c8
--- /dev/null
+++ b/internal/swmcp/server_test.go
@@ -0,0 +1,171 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+package swmcp
+
+import (
+       "context"
+       "net/http"
+       "testing"
+
+       "github.com/apache/skywalking-cli/pkg/contextkey"
+       "github.com/spf13/viper"
+
+       "github.com/apache/skywalking-mcp/internal/config"
+)
+
+const (
+       configuredHTTPOAPURL  = "http://configured-oap:12800/graphql"
+       configuredHTTPSOAPURL = "https://configured-oap.example.com/graphql"
+)
+
+func TestConfiguredSkyWalkingURLUsesDefaultWhenUnset(t *testing.T) {
+       t.Cleanup(viper.Reset)
+
+       got := configuredSkyWalkingURL()
+       if got != config.DefaultSWURL {
+               t.Fatalf("configuredSkyWalkingURL() = %q, want %q", got, 
config.DefaultSWURL)
+       }
+}
+
+func TestConfiguredSkyWalkingURLFinalizesConfiguredValue(t *testing.T) {
+       t.Cleanup(viper.Reset)
+       viper.Set("url", "https://configured-oap.example.com:12800/")
+
+       got := configuredSkyWalkingURL()
+       want := "https://configured-oap.example.com:12800/graphql"
+       if got != want {
+               t.Fatalf("configuredSkyWalkingURL() = %q, want %q", got, want)
+       }
+}
+
+func TestResolveEnvVar(t *testing.T) {
+       t.Setenv("SW_TEST_SECRET", "resolved-secret")
+
+       tests := []struct {
+               name  string
+               value string
+               want  string
+       }{
+               {name: "raw", value: "raw-value", want: "raw-value"},
+               {name: "env", value: "${SW_TEST_SECRET}", want: 
"resolved-secret"},
+               {name: "trimmed env", value: " ${SW_TEST_SECRET} ", want: 
"resolved-secret"},
+               {name: "missing env", value: "${SW_TEST_MISSING}", want: ""},
+       }
+
+       for _, tc := range tests {
+               t.Run(tc.name, func(t *testing.T) {
+                       if got := resolveEnvVar(tc.value); got != tc.want {
+                               t.Fatalf("resolveEnvVar(%q) = %q, want %q", 
tc.value, got, tc.want)
+                       }
+               })
+       }
+}
+
+func TestWithConfiguredAuth(t *testing.T) {
+       t.Cleanup(viper.Reset)
+       t.Setenv("SW_TEST_USER", "env-user")
+       t.Setenv("SW_TEST_PASS", "env-pass")
+       viper.Set("username", "${SW_TEST_USER}")
+       viper.Set("password", "${SW_TEST_PASS}")
+
+       ctx := withConfiguredAuth(context.Background())
+
+       gotUser, _ := ctx.Value(contextkey.Username{}).(string)
+       if gotUser != "env-user" {
+               t.Fatalf("username = %q, want %q", gotUser, "env-user")
+       }
+
+       gotPass, _ := ctx.Value(contextkey.Password{}).(string)
+       if gotPass != "env-pass" {
+               t.Fatalf("password = %q, want %q", gotPass, "env-pass")
+       }
+}
+
+func TestWithConfiguredAuthSkipsEmptyUsername(t *testing.T) {
+       t.Cleanup(viper.Reset)
+       viper.Set("password", "password-only")
+
+       ctx := withConfiguredAuth(context.Background())
+
+       if got, ok := ctx.Value(contextkey.Username{}).(string); ok || got != 
"" {
+               t.Fatalf("username unexpectedly set to %q", got)
+       }
+       if got, ok := ctx.Value(contextkey.Password{}).(string); ok || got != 
"" {
+               t.Fatalf("password unexpectedly set to %q", got)
+       }
+}
+
+func TestEnhanceStdioContextFuncUsesConfiguredURLAndAuth(t *testing.T) {
+       t.Cleanup(viper.Reset)
+       t.Setenv("SW_STDIO_PASS", "stdio-pass")
+       viper.Set("url", "https://configured-oap.example.com")
+       viper.Set("username", "stdio-user")
+       viper.Set("password", "${SW_STDIO_PASS}")
+
+       ctx := EnhanceStdioContextFunc()(context.Background())
+
+       gotURL, _ := ctx.Value(contextkey.BaseURL{}).(string)
+       if gotURL != configuredHTTPSOAPURL {
+               t.Fatalf("base URL = %q, want %q", gotURL, configuredHTTPSOAPURL)
+       }
+
+       gotUser, _ := ctx.Value(contextkey.Username{}).(string)
+       if gotUser != "stdio-user" {
+               t.Fatalf("username = %q, want %q", gotUser, "stdio-user")
+       }
+
+       gotPass, _ := ctx.Value(contextkey.Password{}).(string)
+       if gotPass != "stdio-pass" {
+               t.Fatalf("password = %q, want %q", gotPass, "stdio-pass")
+       }
+}
+
+func TestEnhanceHTTPContextFuncDoesNotUseSWURLHeader(t *testing.T) {
+       t.Cleanup(viper.Reset)
+       viper.Set("url", "http://configured-oap:12800")
+
+       req, err := http.NewRequest(http.MethodPost, "http://client/request", http.NoBody)
+       if err != nil {
+               t.Fatalf("create request: %v", err)
+       }
+       req.Header.Set("SW-URL", "http://attacker.invalid:8080")
+
+       ctx := EnhanceHTTPContextFunc()(context.Background(), req)
+
+       gotURL, _ := ctx.Value(contextkey.BaseURL{}).(string)
+       if gotURL != configuredHTTPOAPURL {
+               t.Fatalf("base URL = %q, want %q", gotURL, configuredHTTPOAPURL)
+       }
+}
+
+func TestEnhanceSSEContextFuncDoesNotUseSWURLHeader(t *testing.T) {
+       t.Cleanup(viper.Reset)
+       viper.Set("url", "https://configured-oap.example.com")
+
+       req, err := http.NewRequest(http.MethodGet, "http://client/events", http.NoBody)
+       if err != nil {
+               t.Fatalf("create request: %v", err)
+       }
+       req.Header.Set("SW-URL", "https://attacker.invalid")
+
+       ctx := EnhanceSSEContextFunc()(context.Background(), req)
+
+       gotURL, _ := ctx.Value(contextkey.BaseURL{}).(string)
+       if gotURL != configuredHTTPSOAPURL {
+               t.Fatalf("base URL = %q, want %q", gotURL, configuredHTTPSOAPURL)
+       }
+}
diff --git a/internal/swmcp/session.go b/internal/swmcp/session.go
deleted file mode 100644
index 5a684c1..0000000
--- a/internal/swmcp/session.go
+++ /dev/null
@@ -1,138 +0,0 @@
-// Licensed to Apache Software Foundation (ASF) under one or more contributor
-// license agreements. See the NOTICE file distributed with
-// this work for additional information regarding copyright
-// ownership. Apache Software Foundation (ASF) licenses this file to you under
-// the Apache License, Version 2.0 (the "License"); you may
-// not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package swmcp
-
-import (
-       "context"
-       "fmt"
-       "sync"
-
-       "github.com/mark3labs/mcp-go/mcp"
-       "github.com/mark3labs/mcp-go/server"
-
-       "github.com/apache/skywalking-mcp/internal/tools"
-)
-
-// sessionKey is the context key for looking up the session store.
-type sessionKey struct{}
-
-// Session holds per-session SkyWalking connection configuration.
-type Session struct {
-       mu       sync.RWMutex
-       url      string
-       username string
-       password string
-}
-
-// SetConnection updates the session's connection parameters.
-func (s *Session) SetConnection(url, username, password string) {
-       s.mu.Lock()
-       defer s.mu.Unlock()
-       s.url = url
-       s.username = username
-       s.password = password
-}
-
-// URL returns the session's configured URL, or empty if not set.
-func (s *Session) URL() string {
-       s.mu.RLock()
-       defer s.mu.RUnlock()
-       return s.url
-}
-
-// Username returns the session's configured username.
-func (s *Session) Username() string {
-       s.mu.RLock()
-       defer s.mu.RUnlock()
-       return s.username
-}
-
-// Password returns the session's configured password.
-func (s *Session) Password() string {
-       s.mu.RLock()
-       defer s.mu.RUnlock()
-       return s.password
-}
-
-// SessionFromContext retrieves the session from the context, or nil if not 
present.
-func SessionFromContext(ctx context.Context) *Session {
-       s, _ := ctx.Value(sessionKey{}).(*Session)
-       return s
-}
-
-// WithSession attaches a session to the context.
-func WithSession(ctx context.Context, s *Session) context.Context {
-       return context.WithValue(ctx, sessionKey{}, s)
-}
-
-// SetSkyWalkingURLRequest represents the request for the set_skywalking_url 
tool.
-type SetSkyWalkingURLRequest struct {
-       URL      string `json:"url"`
-       Username string `json:"username,omitempty"`
-       Password string `json:"password,omitempty"`
-}
-
-func setSkyWalkingURL(ctx context.Context, req *SetSkyWalkingURLRequest) 
(*mcp.CallToolResult, error) {
-       if req.URL == "" {
-               return mcp.NewToolResultError("url is required"), nil
-       }
-
-       session := SessionFromContext(ctx)
-       if session == nil {
-               return mcp.NewToolResultError("session not available"), nil
-       }
-
-       finalURL := tools.FinalizeURL(req.URL)
-       session.SetConnection(finalURL, resolveEnvVar(req.Username), 
resolveEnvVar(req.Password))
-
-       msg := fmt.Sprintf("SkyWalking URL set to %s", finalURL)
-       if req.Username != "" {
-               msg += " with basic auth credentials"
-       }
-       return mcp.NewToolResultText(msg), nil
-}
-
-// AddSessionTools registers session management tools with the MCP server.
-func AddSessionTools(s *server.MCPServer) {
-       tool := tools.NewTool(
-               "set_skywalking_url",
-               `Set the SkyWalking OAP server URL and optional basic auth 
credentials for this session.
-This tool is only available in stdio transport mode.
-
-This tool configures the connection to SkyWalking OAP for all subsequent tool 
calls in the current session.
-The URL and credentials persist for the lifetime of the session.
-
-Priority: session URL (set by this tool) > --sw-url flag > default 
(http://localhost:12800/graphql)
-For SSE/HTTP transports, use the SW-URL HTTP header or --sw-url flag instead.
-
-Credentials support raw values or environment variable references using 
${ENV_VAR} syntax.
-
-Examples:
-- {"url": "http://demo.skywalking.apache.org:12800"}: Connect without auth
-- {"url": "http://oap.internal:12800";, "username": "admin", "password": 
"admin"}: Connect with basic auth
-- {"url": "https://skywalking.example.com:443";, "username": "${SW_USER}", 
"password": "${SW_PASS}"}: Auth via env vars`,
-               setSkyWalkingURL,
-               mcp.WithString("url", mcp.Required(),
-                       mcp.Description("SkyWalking OAP server URL (required). 
Example: http://localhost:12800";)),
-               mcp.WithString("username",
-                       mcp.Description("Username for basic auth (optional). 
Supports ${ENV_VAR} syntax.")),
-               mcp.WithString("password",
-                       mcp.Description("Password for basic auth (optional). 
Supports ${ENV_VAR} syntax.")),
-       )
-       tool.Register(s)
-}
diff --git a/internal/swmcp/sse.go b/internal/swmcp/sse.go
index 1e9a04e..14365a9 100644
--- a/internal/swmcp/sse.go
+++ b/internal/swmcp/sse.go
@@ -72,7 +72,7 @@ func runSSEServer(ctx context.Context, cfg *config.SSEServerConfig) error {
        }
 
        sseServer := server.NewSSEServer(
-               newMCPServer(false),
+               newMCPServer(),
                server.WithStaticBasePath(cfg.BasePath),
                server.WithSSEContextFunc(EnhanceSSEContextFunc()),
        )
diff --git a/internal/swmcp/stdio.go b/internal/swmcp/stdio.go
index 9dd58df..02abb4a 100644
--- a/internal/swmcp/stdio.go
+++ b/internal/swmcp/stdio.go
@@ -60,7 +60,7 @@ func runStdioServer(ctx context.Context, cfg *config.StdioServerConfig) error {
        ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM)
        defer stop()
 
-       stdioServer := server.NewStdioServer(newMCPServer(true))
+       stdioServer := server.NewStdioServer(newMCPServer())
 
        logger, err := initLogger(cfg.LogFilePath)
        if err != nil {
diff --git a/internal/swmcp/streamable.go b/internal/swmcp/streamable.go
index 8f41194..0500352 100644
--- a/internal/swmcp/streamable.go
+++ b/internal/swmcp/streamable.go
@@ -57,7 +57,7 @@ func NewStreamable() *cobra.Command {
 // runStreamableServer starts the Streamable server with the provided configuration.
 func runStreamableServer(cfg *config.StreamableServerConfig) error {
        httpServer := server.NewStreamableHTTPServer(
-               newMCPServer(false),
+               newMCPServer(),
                server.WithStateLess(true),
                server.WithLogger(log.StandardLogger()),
                server.WithHTTPContextFunc(EnhanceHTTPContextFunc()),
diff --git a/internal/tools/mqe.go b/internal/tools/mqe.go
index 0f4189a..3a45396 100644
--- a/internal/tools/mqe.go
+++ b/internal/tools/mqe.go
@@ -25,8 +25,10 @@ import (
        "fmt"
        "io"
        "net/http"
-       "strings"
+       "regexp"
        "time"
+       "unicode"
+       "unicode/utf8"
 
        "github.com/mark3labs/mcp-go/mcp"
        "github.com/mark3labs/mcp-go/server"
@@ -41,6 +43,16 @@ func AddMQETools(mcp *server.MCPServer) {
        MQEMetricsTypeTool.Register(mcp)
 }
 
+const (
+       maxMQEExpressionLength = 2048
+       maxMQEExpressionDepth  = 12
+       maxMQEEntityFieldLen   = 256
+       maxMQERegexLength      = 256
+       maxMetricNameLength    = 128
+)
+
+var metricNamePattern = regexp.MustCompile(`^[A-Za-z0-9_.:-]+$`)
+
 // GraphQLRequest represents a GraphQL request
 type GraphQLRequest struct {
        Query     string                 `json:"query"`
@@ -101,8 +113,8 @@ func executeGraphQLWithContext(ctx context.Context, query string, variables map[
        defer resp.Body.Close()
 
        if resp.StatusCode != http.StatusOK {
-               bodyBytes, _ := io.ReadAll(resp.Body)
-               return nil, fmt.Errorf("HTTP request failed with status: %d, body: %s", resp.StatusCode, string(bodyBytes))
+               _, _ = io.ReadAll(resp.Body)
+               return nil, fmt.Errorf("GraphQL request failed with HTTP status %d", resp.StatusCode)
        }
 
        var graphqlResp GraphQLResponse
@@ -111,11 +123,7 @@ func executeGraphQLWithContext(ctx context.Context, query string, variables map[
        }
 
        if len(graphqlResp.Errors) > 0 {
-               var errorMsgs []string
-               for _, err := range graphqlResp.Errors {
-                       errorMsgs = append(errorMsgs, err.Message)
-               }
-               return nil, fmt.Errorf("GraphQL errors: %s", strings.Join(errorMsgs, ", "))
+               return nil, fmt.Errorf("GraphQL query failed")
        }
 
        return &graphqlResp, nil
@@ -307,6 +315,9 @@ func executeMQEExpression(ctx context.Context, req *MQEExpressionRequest) (*mcp.
        if req.Expression == "" {
                return mcp.NewToolResultError("expression is required"), nil
        }
+       if err := validateMQEExpressionRequest(req); err != nil {
+               return mcp.NewToolResultError(err.Error()), nil
+       }
 
        entity := buildMQEEntity(ctx, req)
        timeCtx := GetTimeContext(ctx)
@@ -386,6 +397,10 @@ func executeMQEExpression(ctx context.Context, req *MQEExpressionRequest) (*mcp.
 
 // listMQEMetrics lists available metrics
 func listMQEMetrics(ctx context.Context, req *MQEMetricsListRequest) (*mcp.CallToolResult, error) {
+       if err := validateMQEMetricsListRequest(req); err != nil {
+               return mcp.NewToolResultError(err.Error()), nil
+       }
+
        // GraphQL query for listing metrics
        query := `
                query listMetrics($regex: String) {
@@ -438,6 +453,9 @@ func getMQEMetricsType(ctx context.Context, req *MQEMetricsTypeRequest) (*mcp.Ca
        if req.MetricName == "" {
                return mcp.NewToolResultError("metric_name must be provided"), nil
        }
+       if err := validateMetricName(req.MetricName); err != nil {
+               return mcp.NewToolResultError(err.Error()), nil
+       }
 
        // GraphQL query for getting metric type
        query := `
@@ -462,6 +480,114 @@ func getMQEMetricsType(ctx context.Context, req *MQEMetricsTypeRequest) (*mcp.Ca
        return mcp.NewToolResultText(string(jsonBytes)), nil
 }
 
+func validateMQEExpressionRequest(req *MQEExpressionRequest) error {
+       if err := validateMQEExpression(req.Expression); err != nil {
+               return err
+       }
+
+       for fieldName, value := range map[string]string{
+               "service_name":               req.ServiceName,
+               "layer":                      req.Layer,
+               "service_instance_name":      req.ServiceInstanceName,
+               "endpoint_name":              req.EndpointName,
+               "process_name":               req.ProcessName,
+               "dest_service_name":          req.DestServiceName,
+               "dest_layer":                 req.DestLayer,
+               "dest_service_instance_name": req.DestServiceInstanceName,
+               "dest_endpoint_name":         req.DestEndpointName,
+               "dest_process_name":          req.DestProcessName,
+       } {
+               if err := validateMQETextField(fieldName, value, maxMQEEntityFieldLen); err != nil {
+                       return err
+               }
+       }
+
+       return nil
+}
+
+func validateMQEMetricsListRequest(req *MQEMetricsListRequest) error {
+       if req == nil || req.Regex == "" {
+               return nil
+       }
+       if err := validateMQETextField("regex", req.Regex, maxMQERegexLength); err != nil {
+               return err
+       }
+       if _, err := regexp.Compile(req.Regex); err != nil {
+               return fmt.Errorf("regex is invalid")
+       }
+       return nil
+}
+
+func validateMetricName(metricName string) error {
+       if err := validateMQETextField("metric_name", metricName, maxMetricNameLength); err != nil {
+               return err
+       }
+       if !metricNamePattern.MatchString(metricName) {
+               return fmt.Errorf("metric_name contains invalid characters")
+       }
+       return nil
+}
+
+func validateMQEExpression(expression string) error {
+       if !utf8.ValidString(expression) {
+               return fmt.Errorf("expression must be valid UTF-8")
+       }
+       if len(expression) > maxMQEExpressionLength {
+               return fmt.Errorf("expression exceeds maximum length of %d characters", maxMQEExpressionLength)
+       }
+       if containsUnsafeControlChars(expression) {
+               return fmt.Errorf("expression contains invalid control characters")
+       }
+       if nestingDepth(expression) > maxMQEExpressionDepth {
+               return fmt.Errorf("expression exceeds maximum nesting depth of %d", maxMQEExpressionDepth)
+       }
+       return nil
+}
+
+func validateMQETextField(fieldName, value string, maxLen int) error {
+       if value == "" {
+               return nil
+       }
+       if !utf8.ValidString(value) {
+               return fmt.Errorf("%s must be valid UTF-8", fieldName)
+       }
+       if len(value) > maxLen {
+               return fmt.Errorf("%s exceeds maximum length of %d characters", fieldName, maxLen)
+       }
+       if containsUnsafeControlChars(value) {
+               return fmt.Errorf("%s contains invalid control characters", fieldName)
+       }
+       return nil
+}
+
+func containsUnsafeControlChars(value string) bool {
+       for _, r := range value {
+               if unicode.IsControl(r) && r != '\n' && r != '\r' && r != '\t' {
+                       return true
+               }
+       }
+       return false
+}
+
+func nestingDepth(value string) int {
+       depth := 0
+       maxDepth := 0
+       for _, r := range value {
+               switch r {
+               case '(', '{', '[':
+                       depth++
+                       if depth > maxDepth {
+                               maxDepth = depth
+                       }
+               case ')', '}', ']':
+                       if depth > 0 {
+                               depth--
+                       }
+               }
+       }
+       return maxDepth
+}
+
 var MQEExpressionTool = NewTool(
        "execute_mqe_expression",
        `Execute MQE (Metrics Query Expression) to query and calculate metrics data.
diff --git a/internal/tools/mqe_test.go b/internal/tools/mqe_test.go
new file mode 100644
index 0000000..37d08af
--- /dev/null
+++ b/internal/tools/mqe_test.go
@@ -0,0 +1,119 @@
+// Licensed to Apache Software Foundation (ASF) under one or more contributor
+// license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright
+// ownership. Apache Software Foundation (ASF) licenses this file to you under
+// the Apache License, Version 2.0 (the "License"); you may
+// not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tools
+
+import (
+       "context"
+       "net/http"
+       "net/http/httptest"
+       "strings"
+       "testing"
+
+       "github.com/apache/skywalking-cli/pkg/contextkey"
+       "github.com/mark3labs/mcp-go/mcp"
+)
+
+func TestValidateMQEExpressionRequestRejectsDeeplyNestedExpression(t *testing.T) {
+       req := &MQEExpressionRequest{
+               Expression: strings.Repeat("(", maxMQEExpressionDepth+1) + "service_cpm" + strings.Repeat(")", maxMQEExpressionDepth+1),
+       }
+
+       err := validateMQEExpressionRequest(req)
+       if err == nil || !strings.Contains(err.Error(), "maximum nesting depth") {
+               t.Fatalf("unexpected error: %v", err)
+       }
+}
+
+func TestValidateMQEMetricsListRequestRejectsInvalidRegex(t *testing.T) {
+       err := validateMQEMetricsListRequest(&MQEMetricsListRequest{Regex: "("})
+       if err == nil || err.Error() != "regex is invalid" {
+               t.Fatalf("unexpected error: %v", err)
+       }
+}
+
+func TestValidateMetricNameRejectsInvalidCharacters(t *testing.T) {
+       err := validateMetricName("service cpm")
+       if err == nil || err.Error() != "metric_name contains invalid characters" {
+               t.Fatalf("unexpected error: %v", err)
+       }
+}
+
+func TestExecuteMQEExpressionRejectsOverlongEntityField(t *testing.T) {
+       req := &MQEExpressionRequest{
+               Expression:  "service_cpm",
+               ServiceName: strings.Repeat("a", maxMQEEntityFieldLen+1),
+       }
+
+       result, err := executeMQEExpression(context.Background(), req)
+       if err != nil {
+               t.Fatalf("executeMQEExpression returned error: %v", err)
+       }
+       if !result.IsError {
+               t.Fatal("expected tool error result")
+       }
+       assertToolResultContains(t, result, "service_name exceeds maximum length")
+}
+
+func TestExecuteGraphQLWithContextSanitizesHTTPErrorBody(t *testing.T) {
+       ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+               http.Error(w, "sensitive backend details", http.StatusBadGateway)
+       }))
+       defer ts.Close()
+
+       ctx := context.WithValue(context.Background(), contextkey.BaseURL{}, ts.URL)
+       _, err := executeGraphQLWithContext(ctx, "query { ping }", map[string]interface{}{})
+       if err == nil {
+               t.Fatal("expected error")
+       }
+       if !strings.Contains(err.Error(), "HTTP status 502") {
+               t.Fatalf("unexpected error: %v", err)
+       }
+       if strings.Contains(err.Error(), "sensitive backend details") {
+               t.Fatalf("backend body leaked in error: %v", err)
+       }
+}
+
+func TestExecuteGraphQLWithContextSanitizesGraphQLErrors(t *testing.T) {
+       ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+               w.Header().Set("Content-Type", "application/json")
+               _, _ = w.Write([]byte(`{"errors":[{"message":"database stack trace"}]}`))
+       }))
+       defer ts.Close()
+
+       ctx := context.WithValue(context.Background(), contextkey.BaseURL{}, ts.URL)
+       _, err := executeGraphQLWithContext(ctx, "query { ping }", map[string]interface{}{})
+       if err == nil {
+               t.Fatal("expected error")
+       }
+       if err.Error() != "GraphQL query failed" {
+               t.Fatalf("unexpected error: %v", err)
+       }
+}
+
+func assertToolResultContains(t *testing.T, result *mcp.CallToolResult, want string) {
+       t.Helper()
+       if len(result.Content) == 0 {
+               t.Fatal("tool result had no content")
+       }
+       text, ok := result.Content[0].(mcp.TextContent)
+       if !ok {
+               t.Fatalf("unexpected content type: %T", result.Content[0])
+       }
+       if !strings.Contains(text.Text, want) {
+               t.Fatalf("tool result text %q does not contain %q", text.Text, want)
+       }
+}


Reply via email to