This is an automated email from the ASF dual-hosted git repository. robertlazarski pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/axis-axis2-java-core.git
commit d9e125d1808a4f68a342d671cc7d10aecfe85bc3 Author: Robert Lazarski <[email protected]> AuthorDate: Tue Apr 7 06:15:06 2026 -1000 Address Gemini review findings: Java generator + tests + Python codegen OpenApiSpecGenerator.java: - Critical F1: x-axis2-payloadTemplate built via Jackson tree API (not string concat) — opName values with quotes/backslash/control chars are now escaped - Critical F2: mcpTickerResolveService value validated against [\\w.\\-]+/[\\w.\\-]+ before emission; WARN + skip on invalid format - Major F3: requiresAuth heuristic changed from substring to exact match on "loginservice"/"adminconsole"; new mcpRequiresAuth service param overrides heuristic entirely for explicit control - Major F4: ObjectMapper replaced with io.swagger.v3.core.util.Json.mapper() — reuses the swagger-core instance already on classpath, avoids expensive per-request construction - Major F6: error fallback now {"tools":[],"_error":"..."} to distinguish generation failure from a legitimate empty catalog - getMcpStringParam: null-safe for operation arg (supports service-level-only lookup) - getMcpBoolParam: WARN log for unrecognised non-empty values (e.g. "yes", "1") McpCatalogGeneratorTest.java: - F11: testOperationNameWithQuoteIsEscaped and testControlCharactersInOperationNameAreEscaped now also parse x-axis2-payloadTemplate to catch template escaping regressions - F13: testEmptyStringMcpInputSchemaProducesBaselineSchema - F14: testGenerateMcpCatalogWithNullRequestDoesNotThrow no longer accepts NPE - F15: testLoginHistoryServiceRequiresAuth + testExactLoginServiceNameDoesNotRequireAuth - F18: testMcpOpenWorldParamSetsOpenWorldHint gen_mcp_schema.py (all Critical + Major findings): - F20/Critical: JSON content XML-escaped via xml.sax.saxutils.escape() before inserting into <parameter> tag - F21/Critical: nested struct warning added; _strip_comments() removes block comments before field parsing (F23 fix bundled) - F22/Major: patches collected as (start,end,replacement) triples, applied in reverse position order — no offset corruption for multi-operation files - F24/Major: atomic write via tempfile+os.replace; no partial-write corruption - F25/Major: dead companion_size_re n_* branch removed; simplified to _count/_len/_size suffix skip only - F26/Minor: hard-coded finbench_ prefix removed; --prefix CLI argument added - F27/Minor: always_required uses re.search(r'_id$|^n_') — correct anchoring - F28/Minor: json.dumps indent=16 → indent=2 - F29/Minor: --encoding CLI argument; UnicodeDecodeError gives actionable message Co-Authored-By: Claude Sonnet 4.6 <[email protected]> --- .../apache/axis2/openapi/OpenApiSpecGenerator.java | 63 +++++- .../axis2/openapi/McpCatalogGeneratorTest.java | 90 ++++++-- tools/gen_mcp_schema.py | 227 +++++++++++++++------ 3 files changed, 295 insertions(+), 85 deletions(-) diff --git a/modules/openapi/src/main/java/org/apache/axis2/openapi/OpenApiSpecGenerator.java b/modules/openapi/src/main/java/org/apache/axis2/openapi/OpenApiSpecGenerator.java index 5824367b2c..7bdd451be2 100644 --- a/modules/openapi/src/main/java/org/apache/axis2/openapi/OpenApiSpecGenerator.java +++ b/modules/openapi/src/main/java/org/apache/axis2/openapi/OpenApiSpecGenerator.java @@ -646,7 +646,8 @@ public class OpenApiSpecGenerator { */ private String getMcpStringParam(AxisOperation operation, AxisService service, String paramName, String defaultValue) { - org.apache.axis2.description.Parameter p = operation.getParameter(paramName); + org.apache.axis2.description.Parameter p = + (operation != null) ? operation.getParameter(paramName) : null; if (p == null) p = service.getParameter(paramName); if (p != null && p.getValue() != null) { String v = p.getValue().toString().trim(); @@ -670,6 +671,10 @@ public class OpenApiSpecGenerator { String v = p.getValue().toString().trim().toLowerCase(java.util.Locale.ROOT); if ("true".equals(v)) return true; if ("false".equals(v)) return false; + if (!v.isEmpty()) { + log.warn("[MCP] Unrecognised boolean value '" + p.getValue().toString().trim() + + "' for parameter '" + paramName + "' — using default " + defaultValue); + } } return defaultValue; } @@ -687,8 +692,11 @@ public class OpenApiSpecGenerator { try { AxisConfiguration axisConfig = configurationContext.getAxisConfiguration(); - com.fasterxml.jackson.databind.ObjectMapper jackson = - new com.fasterxml.jackson.databind.ObjectMapper(); + // Re-use the swagger-core Jackson instance: already configured with + // NON_NULL, WRITE_DATES_AS_TIMESTAMPS=false, FAIL_ON_EMPTY_BEANS=false + // and available on the classpath — avoids allocating a new ObjectMapper + // per request (ObjectMapper construction is expensive due to module scanning). + com.fasterxml.jackson.databind.ObjectMapper jackson = io.swagger.v3.core.util.Json.mapper(); com.fasterxml.jackson.databind.node.ObjectNode root = jackson.createObjectNode(); // Catalog-level metadata so MCP clients understand the transport layer. @@ -715,8 +723,16 @@ public class OpenApiSpecGenerator { axisConfig.getParameter("mcpTickerResolveService"); if (tickerParam != null && tickerParam.getValue() != null) { String tickerSvcOp = tickerParam.getValue().toString().trim(); + // Validate format: must be "ServiceName/operationName" — each segment + // is an XML NCName (word chars, dots, hyphens). Reject anything else + // to prevent a misconfigured path-traversal value reaching MCP clients. if (!tickerSvcOp.isEmpty()) { - meta.put("tickerResolveEndpoint", "POST /services/" + tickerSvcOp); + if (tickerSvcOp.matches("[\\w.\\-]+/[\\w.\\-]+")) { + meta.put("tickerResolveEndpoint", "POST /services/" + tickerSvcOp); + } else { + log.warn("[MCP] Ignoring invalid mcpTickerResolveService value '" + + tickerSvcOp + "' — expected ServiceName/operationName"); + } } } @@ -728,9 +744,20 @@ public class OpenApiSpecGenerator { if (isSystemService(service)) continue; if (!shouldIncludeService(service)) continue; - // loginService is the unauthenticated token endpoint; all others require auth. + // Prefer explicit mcpRequiresAuth parameter; fall back to name heuristic + // only when the parameter is absent. The heuristic uses exact match on + // "loginservice" (case-insensitive) and "adminconsole" to avoid false + // positives for services like "LoginHistoryService" or "CatalogLoginRecords". String svcLower = service.getName().toLowerCase(java.util.Locale.ROOT); - boolean requiresAuth = !svcLower.contains("login") && !svcLower.equals("adminconsole"); + boolean requiresAuth; + String mcpRequiresAuthParam = getMcpStringParam( + null, service, "mcpRequiresAuth", null); + if (mcpRequiresAuthParam != null) { + requiresAuth = !"false".equalsIgnoreCase(mcpRequiresAuthParam); + } else { + requiresAuth = !svcLower.equals("loginservice") + && !svcLower.equals("adminconsole"); + } Iterator<AxisOperation> operations = service.getOperations(); while (operations.hasNext()) { @@ -805,8 +832,24 @@ public class OpenApiSpecGenerator { // body in this envelope — the bare {"field":value} object goes inside // "arg0". Example for portfolioVariance: // {"portfolioVariance":[{"arg0":{"nAssets":2,"weights":[0.6,0.4],...}}]} - toolNode.put("x-axis2-payloadTemplate", - "{\"" + opName + "\":[{\"arg0\":{}}]}"); + // + // Built via Jackson tree API (not string concatenation) so that opName + // values containing JSON-special chars (", \, control chars) are + // correctly escaped. jackson.writeValueAsString() cannot throw here + // because the tree is well-formed by construction. + try { + com.fasterxml.jackson.databind.node.ObjectNode tmpl = + jackson.createObjectNode(); + tmpl.putArray(opName).addObject().putObject("arg0"); + toolNode.put("x-axis2-payloadTemplate", + jackson.writeValueAsString(tmpl)); + } catch (com.fasterxml.jackson.core.JsonProcessingException jpe) { + // Cannot happen for a well-formed Jackson node tree; fall back + // to a safe static placeholder so the tool node is still usable. + log.warn("[MCP] Failed to serialize payloadTemplate for '" + + opName + "': " + jpe.getMessage()); + toolNode.put("x-axis2-payloadTemplate", "{}"); + } // Whether the caller must supply a Bearer token (from doLogin). toolNode.put("x-requiresAuth", requiresAuth); @@ -833,7 +876,9 @@ public class OpenApiSpecGenerator { } catch (Exception e) { log.error("Failed to generate MCP catalog JSON", e); - return "{\"tools\":[]}"; + // Return a distinct error shape so callers can distinguish generation failure + // from a legitimate empty catalog (no services deployed). + return "{\"tools\":[],\"_error\":\"catalog generation failed — see server log\"}"; } } diff --git a/modules/openapi/src/test/java/org/apache/axis2/openapi/McpCatalogGeneratorTest.java b/modules/openapi/src/test/java/org/apache/axis2/openapi/McpCatalogGeneratorTest.java index 167ccdf3f9..633127e7d3 100644 --- a/modules/openapi/src/test/java/org/apache/axis2/openapi/McpCatalogGeneratorTest.java +++ b/modules/openapi/src/test/java/org/apache/axis2/openapi/McpCatalogGeneratorTest.java @@ -252,6 +252,12 @@ public class McpCatalogGeneratorTest extends TestCase { String json = generator.generateMcpCatalogJson(mockRequest); JsonNode root = MAPPER.readTree(json); assertNotNull("JSON with escaped quotes must be parseable", root); + + // x-axis2-payloadTemplate must also be parseable JSON (Critical F1 regression guard) + String template = root.path("tools").get(0).path("x-axis2-payloadTemplate").asText(); + JsonNode parsedTemplate = MAPPER.readTree(template); + assertNotNull("x-axis2-payloadTemplate must be valid JSON even when op name has quotes", + parsedTemplate); } public void testServiceNameWithBackslashIsEscaped() throws Exception { @@ -282,6 +288,12 @@ public class McpCatalogGeneratorTest extends TestCase { json.contains("\t")); JsonNode root = MAPPER.readTree(json); // still parseable assertNotNull(root); + + // x-axis2-payloadTemplate must also survive control characters + String template = root.path("tools").get(0).path("x-axis2-payloadTemplate").asText(); + JsonNode parsedTemplate = MAPPER.readTree(template); + assertNotNull("x-axis2-payloadTemplate must be valid JSON when op name has control chars", + parsedTemplate); } // ── catalog request is null-safe ────────────────────────────────────────── @@ -289,18 +301,12 @@ public class McpCatalogGeneratorTest extends TestCase { public void testGenerateMcpCatalogWithNullRequestDoesNotThrow() throws Exception { addService("TestService", "testOp"); - // null request is passed — generator should handle it gracefully - // (request is not used by generateMcpCatalogJson; it only introspects AxisConfig) - try { - String json = generator.generateMcpCatalogJson(null); - assertNotNull(json); - JsonNode root = MAPPER.readTree(json); - assertTrue(root.has("tools")); - } catch (NullPointerException e) { - // Acceptable if the method does not guard against null — document behaviour - System.out.println("INFO: generateMcpCatalogJson(null) throws NPE — " + - "callers must ensure request is non-null"); - } + // generateMcpCatalogJson() does not use the HttpServletRequest parameter at all — + // it only introspects AxisConfig. Null must not throw; the catalog must be valid. + String json = generator.generateMcpCatalogJson(null); + assertNotNull("generateMcpCatalogJson(null) must return non-null", json); + JsonNode root = MAPPER.readTree(json); + assertTrue("generateMcpCatalogJson(null) result must have 'tools'", root.has("tools")); } // ── catalog _meta ───────────────────────────────────────────────────────── @@ -892,6 +898,66 @@ public class McpCatalogGeneratorTest extends TestCase { } } + /** + * F13: An empty-string mcpInputSchema parameter (trimmed to "") + * must be treated identically to an absent parameter — produces the + * empty baseline schema without throwing or logging a parse error. + */ + public void testEmptyStringMcpInputSchemaProducesBaselineSchema() throws Exception { + AxisService svc = new AxisService("TestService"); + AxisOperation op = new InOutAxisOperation(); + op.setName(QName.valueOf("doOp")); + op.addParameter(new org.apache.axis2.description.Parameter("mcpInputSchema", "")); + svc.addOperation(op); + axisConfig.addService(svc); + + JsonNode schema = getCatalogTools().get(0).path("inputSchema"); + // getMcpStringParam returns null for empty string → empty baseline + assertEquals("Empty mcpInputSchema must produce type=object", "object", + schema.path("type").asText()); + assertTrue("Empty mcpInputSchema must produce empty properties", + schema.path("properties").isObject()); + assertTrue("Empty mcpInputSchema must produce empty required array", + schema.path("required").isArray()); + } + + /** + * F15: A service whose name *contains* "login" but is NOT the login service + * must still require auth. The heuristic must use exact match, not substring. + */ + public void testLoginHistoryServiceRequiresAuth() throws Exception { + addService("LoginHistoryService", "getLoginHistory"); + JsonNode tool = getCatalogTools().get(0); + assertTrue("LoginHistoryService must require auth (exact-match heuristic, not substring)", + tool.path("x-requiresAuth").asBoolean()); + } + + /** + * F15 (inverse): exact match "loginService" (case-insensitive) still exempts. + */ + public void testExactLoginServiceNameDoesNotRequireAuth() throws Exception { + addService("loginService", "doLogin"); + JsonNode tool = getCatalogTools().get(0); + assertFalse("loginService (exact match) must not require auth", + tool.path("x-requiresAuth").asBoolean()); + } + + /** + * F18: mcpOpenWorld=true must set openWorldHint: true in annotations. + */ + public void testMcpOpenWorldParamSetsOpenWorldHint() throws Exception { + AxisService svc = new AxisService("NotificationService"); + AxisOperation op = new InOutAxisOperation(); + op.setName(QName.valueOf("sendAlert")); + op.addParameter(new org.apache.axis2.description.Parameter("mcpOpenWorld", "true")); + svc.addOperation(op); + axisConfig.addService(svc); + + JsonNode annotations = getCatalogTools().get(0).path("annotations"); + assertTrue("openWorldHint must be true when mcpOpenWorld=true", + annotations.path("openWorldHint").asBoolean()); + } + // ── helpers ───────────────────────────────────────────────────────────── private void addService(String serviceName, String operationName) throws Exception { diff --git a/tools/gen_mcp_schema.py b/tools/gen_mcp_schema.py index 1110925d2d..72e4158946 100644 --- a/tools/gen_mcp_schema.py +++ b/tools/gen_mcp_schema.py @@ -11,11 +11,23 @@ Usage python3 tools/gen_mcp_schema.py \\ --header path/to/service.h \\ --services path/to/services.xml \\ + [--prefix finbench_] \\ + [--encoding utf-8] \\ [--dry-run] The script writes in-place unless --dry-run is given, in which case it prints the updated XML to stdout. +Limitations +----------- +- Nested structs and anonymous union members are NOT supported. The struct + body regex stops at the first '}', so inner struct/union blocks will cause + field truncation. A WARNING is printed when a parsed struct body contains + a '{' character that suggests nesting. +- Only typedef struct { ... } name_t; patterns are detected. +- C preprocessor macros and conditional compilation (#if/#endif) are not + evaluated; fields inside #ifdef blocks may be included unconditionally. + C → JSON Schema type mapping ----------------------------- int / long / int32_t / int64_t / axis2_int32_t → "integer" @@ -23,10 +35,10 @@ double / float → "number" char * / axis2_char_t * → "string" axis2_bool_t / bool / int (named is_*/has_*) → "boolean" pointer-to-struct (foo_t *) → "object" -array + companion _count / n_ field → "array" +double * / float * (numeric array pointers) → "array" Required fields: any field without a "= 0" / "= NULL" / "= false" default in -the struct definition is treated as required. Fields named *_id, n_*, count_* +the struct definition is treated as required. Fields matching *_id or n_* are also always required. The script uses regex-only parsing (no libclang) so it works without a C @@ -36,10 +48,12 @@ unambiguously, it emits "type": "object" and logs a warning. import argparse import json +import os import re import sys -import textwrap +import tempfile from pathlib import Path +from xml.sax.saxutils import escape as xml_escape # --------------------------------------------------------------------------- # C type → JSON Schema type table @@ -60,11 +74,11 @@ def c_type_to_json_schema(c_type: str, field_name: str) -> dict: """Map a C type string to a minimal JSON Schema dict.""" c_type = c_type.strip() - # Boolean heuristic: field named is_*/has_* with int type + # Boolean heuristic: field named is_*/has_*/enable_*/use_* with int type if re.match(r'(is|has|enable|use)_', field_name) and re.search(r'\bint\b', c_type): return {"type": "boolean"} - # Pointer to array (double * / float * used for matrix/weight arrays) + # Pointer to numeric array (double * / float * used for matrix/weight arrays) if re.search(r'\bdouble\s*\*|\bfloat\s*\*', c_type): return {"type": "array", "items": {"type": "number"}} @@ -76,7 +90,7 @@ def c_type_to_json_schema(c_type: str, field_name: str) -> dict: if m: return {"type": "object"} - # Fallback + # Fallback — conservative print(f" WARNING: unmapped C type '{c_type}' for field '{field_name}' → object", file=sys.stderr) return {"type": "object"} @@ -93,27 +107,52 @@ _FIELD_RE = re.compile( r'^\s*(?P<type>(?:const\s+)?[\w\s\*]+?)\s+(?P<name>\w+)\s*(?:=\s*(?P<default>[^;]+))?\s*;', re.MULTILINE ) +_BLOCK_COMMENT_RE = re.compile(r'/\*.*?\*/', re.DOTALL) + + +def _strip_comments(text: str) -> str: + """Remove C block comments (/* ... */) and line comments (// ...).""" + # Block comments first (may span lines) + text = _BLOCK_COMMENT_RE.sub(' ', text) + # Line comments + text = re.sub(r'//[^\n]*', ' ', text) + return text def parse_structs(header_text: str) -> dict[str, dict]: """ Return {struct_name: {field_name: {"c_type": ..., "has_default": bool}}}. Only parses typedef struct { ... } name_t; blocks. + + Block and line comments are stripped from the body before field parsing + so that comment text containing ';' is not matched as a field. """ structs = {} for m in _STRUCT_RE.finditer(header_text): body = m.group(1) name = m.group(2) + + # Warn about potential nested struct/union — body regex stops at first '}' + # so any nested block would already be truncated, but alert the user. + if '{' in body: + print(f" WARNING: struct '{name}' body contains '{{' — nested struct/union " + f"members are not supported and may be missing from the schema.", + file=sys.stderr) + + # Strip comments before field parsing (F23 fix) + clean_body = _strip_comments(body) + fields = {} - for fm in _FIELD_RE.finditer(body): + for fm in _FIELD_RE.finditer(clean_body): field_name = fm.group("name") c_type = fm.group("type") default = fm.group("default") - # Skip comment-only or empty lines picked up by the regex - if c_type.strip().startswith("//") or c_type.strip().startswith("*"): + c_type_stripped = c_type.strip() + # Skip residual preprocessor or empty captures + if not c_type_stripped or c_type_stripped.startswith("#"): continue fields[field_name] = { - "c_type": c_type.strip(), + "c_type": c_type_stripped, "has_default": default is not None, } if fields: @@ -126,11 +165,7 @@ def build_json_schema(struct_fields: dict) -> dict: properties = {} required = [] - # Fields that are always array companions (paired with n_* / *_count) — skip them - # as array size information; they are implicit. - companion_size_re = re.compile(r'^n_|_count$|_len$|_size$') - - # First pass: collect array-indicator field names + # First pass: collect which fields are numeric array pointers array_fields = set() for fname, info in struct_fields.items(): c_type = info["c_type"] @@ -141,29 +176,26 @@ def build_json_schema(struct_fields: dict) -> dict: c_type = info["c_type"] has_default = info["has_default"] - # Skip size companion fields (n_assets accompanies weights[], etc.) - if companion_size_re.search(fname) and fname not in array_fields: - # Keep n_assets as it is the primary dimension parameter - if not fname.startswith("n_"): - continue + # Skip pure size-companion fields (_count, _len, _size suffixes) that + # exist only to carry the array length alongside a pointer field. + # n_* fields are intentionally kept — they are primary input parameters. + if re.search(r'_count$|_len$|_size$', fname) and fname not in array_fields: + continue schema_prop = c_type_to_json_schema(c_type, fname) - # Annotate array items for common financial arrays + # Ensure array items type is set for numeric arrays if schema_prop.get("type") == "array" and not schema_prop.get("items"): schema_prop["items"] = {"type": "number"} properties[fname] = schema_prop - # Required: no default AND not a companion size field - always_required = re.match(r'.+_id$|^n_', fname) + # Required heuristic: no default declared, or name matches *_id / n_* + always_required = bool(re.search(r'_id$|^n_', fname)) if always_required or not has_default: required.append(fname) - schema = { - "type": "object", - "properties": properties, - } + schema: dict = {"type": "object", "properties": properties} if required: schema["required"] = required return schema @@ -172,14 +204,21 @@ def build_json_schema(struct_fields: dict) -> dict: # --------------------------------------------------------------------------- # services.xml patcher # --------------------------------------------------------------------------- -def find_request_struct(structs: dict, op_name: str) -> str | None: +def find_request_struct(structs: dict, op_name: str, + prefix: str = "") -> str | None: """ Heuristically find the request struct for an operation name. - Tries: finbench_{op_name}_request_t, {op_name}_request_t, {op_name}_req_t + + Tries (in order): + {prefix}{op_name}_request_t + {op_name}_request_t + {op_name}_req_t + Falls back to a case-insensitive substring search on all struct names. """ - service_prefix = "finbench_" - candidates = [ - f"{service_prefix}{op_name}_request_t", + candidates = [] + if prefix: + candidates.append(f"{prefix}{op_name}_request_t") + candidates += [ f"{op_name}_request_t", f"{op_name}_req_t", ] @@ -204,50 +243,77 @@ _EXISTING_SCHEMA_RE = re.compile( ) -def patch_services_xml(xml_text: str, structs: dict) -> tuple[str, list[str]]: +def patch_services_xml(xml_text: str, structs: dict, + prefix: str = "") -> tuple[str, list[str]]: """ For each <operation name="..."> block, find the matching request struct and inject (or replace) a mcpInputSchema parameter. + Patches are collected and applied in reverse position order to avoid + offset corruption when multiple operations are in the same file (F22 fix). + + JSON inserted into XML is escaped with xml.sax.saxutils.escape() to + prevent malformed XML if struct field names contain &, <, or > (F20 fix). + Returns (patched_xml, list_of_change_messages). """ messages = [] - result = xml_text + + # Collect all patches as (start, end, replacement) triples, then apply + # in reverse order so earlier positions are not invalidated by later edits. + patches: list[tuple[int, int, str]] = [] for m in _OP_RE.finditer(xml_text): op_name = m.group("opname") - struct_name = find_request_struct(structs, op_name) + struct_name = find_request_struct(structs, op_name, prefix) if struct_name is None: messages.append(f" SKIP {op_name}: no matching *_request_t struct found") continue schema = build_json_schema(structs[struct_name]) - schema_json = json.dumps(schema, indent=16) + # indent=2 produces readable XML; xml_escape protects against + # JSON characters that are XML-special (&, <, >) (F20, F28 fix) + schema_json = xml_escape(json.dumps(schema, indent=2)) param_block = f'<parameter name="mcpInputSchema">{schema_json}</parameter>' - # Check if an mcpInputSchema already exists after this <operation ...> tag op_start = m.start() - # Find the closing </operation> - close_re = re.compile(r'</operation>', re.DOTALL) - close_m = close_re.search(result, op_start) + tag_end = m.end() # end of the <operation ...> opening tag + + # Find the closing </operation> tag from op_start in the ORIGINAL text + close_m = re.search(r'</operation>', xml_text[op_start:]) if close_m is None: + messages.append(f" SKIP {op_name}: no </operation> closing tag found") continue - op_block = result[op_start:close_m.end()] + + op_end = op_start + close_m.end() + op_block = xml_text[op_start:op_end] if '<parameter name="mcpInputSchema">' in op_block: - # Replace existing - new_op_block = _EXISTING_SCHEMA_RE.sub( - "\n " + param_block, op_block) - result = result[:op_start] + new_op_block + result[close_m.end():] - messages.append(f" UPDATE {op_name}: replaced mcpInputSchema from {struct_name}") + # Replace existing parameter — find its absolute span + existing_m = _EXISTING_SCHEMA_RE.search(xml_text, op_start, op_end) + if existing_m: + patches.append(( + existing_m.start(), + existing_m.end(), + "\n " + param_block + )) + messages.append( + f" UPDATE {op_name}: replaced mcpInputSchema from {struct_name}") else: - # Insert after the opening <operation ...> tag - tag_end = op_start + len(m.group(1)) - indent = "\n " - result = (result[:tag_end] - + indent + param_block - + result[tag_end:]) - messages.append(f" INSERT {op_name}: wrote mcpInputSchema from {struct_name}") + # Insert immediately after the opening <operation ...> tag + patches.append(( + tag_end, + tag_end, + "\n " + param_block + )) + messages.append( + f" INSERT {op_name}: wrote mcpInputSchema from {struct_name}") + + # Apply patches in reverse order (largest offset first) to preserve positions + patches.sort(key=lambda t: t[0], reverse=True) + result = xml_text + for start, end, replacement in patches: + result = result[:start] + replacement + result[end:] return result, messages @@ -255,35 +321,52 @@ def patch_services_xml(xml_text: str, structs: dict) -> tuple[str, list[str]]: # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- -def main(): +def main() -> None: p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) - p.add_argument("--header", required=True, help="Path to .h file") - p.add_argument("--services", required=True, help="Path to services.xml") + p.add_argument("--header", required=True, + help="Path to .h file containing *_request_t structs") + p.add_argument("--services", required=True, + help="Path to services.xml to patch in-place") + p.add_argument("--prefix", default="", + help="Application-specific struct name prefix (e.g. 'finbench_'). " + "Default: no prefix.") + p.add_argument("--encoding", default="utf-8", + help="File encoding for both header and services.xml. Default: utf-8") p.add_argument("--dry-run", action="store_true", - help="Print patched XML to stdout, do not write") + help="Print patched XML to stdout; do not write the file") args = p.parse_args() - header_path = Path(args.header) - services_path = Path(args.services) + header_path = Path(args.header).resolve() + services_path = Path(args.services).resolve() if not header_path.exists(): sys.exit(f"ERROR: header not found: {header_path}") if not services_path.exists(): sys.exit(f"ERROR: services.xml not found: {services_path}") - header_text = header_path.read_text(encoding="utf-8") - services_text = services_path.read_text(encoding="utf-8") + try: + header_text = header_path.read_text(encoding=args.encoding) + except UnicodeDecodeError as e: + sys.exit(f"ERROR: cannot decode {header_path} as {args.encoding}: {e}\n" + f" Try --encoding latin-1 or --encoding utf-8-sig") + + try: + services_text = services_path.read_text(encoding=args.encoding) + except UnicodeDecodeError as e: + sys.exit(f"ERROR: cannot decode {services_path} as {args.encoding}: {e}\n" + f" Try --encoding latin-1 or --encoding utf-8-sig") structs = parse_structs(header_text) if not structs: - sys.exit("ERROR: no typedef struct { } name_t; blocks found in header") + sys.exit("ERROR: no 'typedef struct { } name_t;' blocks found in header") print(f"Parsed {len(structs)} structs from {header_path.name}:", file=sys.stderr) for sname in structs: print(f" {sname} ({len(structs[sname])} fields)", file=sys.stderr) - patched, messages = patch_services_xml(services_text, structs) + patched, messages = patch_services_xml(services_text, structs, + prefix=args.prefix) print("Schema generation results:", file=sys.stderr) for msg in messages: @@ -292,7 +375,23 @@ def main(): if args.dry_run: print(patched) else: - services_path.write_text(patched, encoding="utf-8") + # Atomic write: write to a sibling temp file, then rename (F24 fix) + tmp_fd, tmp_path = tempfile.mkstemp( + dir=services_path.parent, + prefix=".gen_mcp_schema_", + suffix=".tmp" + ) + try: + with os.fdopen(tmp_fd, "w", encoding=args.encoding) as fh: + fh.write(patched) + os.replace(tmp_path, services_path) + except Exception: + # Clean up temp file if rename failed + try: + os.unlink(tmp_path) + except OSError: + pass + raise print(f"Written: {services_path}", file=sys.stderr)
