From ba640d4070792915f24636e95c1d60b430084ea8 Mon Sep 17 00:00:00 2001 From: SuYao Date: Mon, 31 Mar 2025 21:13:59 +0800 Subject: [PATCH] refactor(MCP): enhance schema validation for gemini (#4153) --- package.json | 2 +- src/renderer/src/utils/mcp-tools.ts | 206 +++++++++++++++++++++------- yarn.lock | 10 +- 3 files changed, 160 insertions(+), 58 deletions(-) diff --git a/package.json b/package.json index 8635696b..23d8c55f 100644 --- a/package.json +++ b/package.json @@ -63,7 +63,7 @@ "@cherrystudio/embedjs-openai": "^0.1.28", "@electron-toolkit/utils": "^3.0.0", "@electron/notarize": "^2.5.0", - "@google/generative-ai": "^0.21.0", + "@google/generative-ai": "^0.24.0", "@langchain/community": "^0.3.36", "@notionhq/client": "^2.2.15", "@tryfabric/martian": "^1.2.4", diff --git a/src/renderer/src/utils/mcp-tools.ts b/src/renderer/src/utils/mcp-tools.ts index 3b0110cb..22b78f11 100644 --- a/src/renderer/src/utils/mcp-tools.ts +++ b/src/renderer/src/utils/mcp-tools.ts @@ -1,5 +1,17 @@ import { Tool, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources' -import { FunctionCall, FunctionDeclaration, SchemaType, Tool as geminiToool } from '@google/generative-ai' +import { FunctionCall, FunctionDeclaration, SchemaType, Tool as geminiTool } from '@google/generative-ai' +import { + ArraySchema, + BaseSchema, + BooleanSchema, + EnumStringSchema, + FunctionDeclarationSchema, + FunctionDeclarationSchemaProperty, + IntegerSchema, + NumberSchema, + ObjectSchema, + SimpleStringSchema +} from '@google/generative-ai' import { nanoid } from '@reduxjs/toolkit' import store from '@renderer/store' import { addMCPServer } from '@renderer/store/mcp' @@ -8,54 +20,143 @@ import { ChatCompletionMessageToolCall, ChatCompletionTool } from 'openai/resour import { ChunkCallbackData } from '../providers' -const supportedAttributes = [ - 'type', - 'nullable', - 'required', - // 'format', - 'description', - 'properties', - 'items', - 'enum', - 'anyOf' -] +const ensureValidSchema = (obj: Record): FunctionDeclarationSchemaProperty => { + // Filter out unsupported keys for Gemini + const filteredObj = filterUnsupportedKeys(obj) -function filterPropertieAttributes(tool: MCPTool, filterNestedObj = false) { + // Handle base schema properties + const baseSchema = { + description: filteredObj.description, + nullable: filteredObj.nullable + } as BaseSchema + + // Handle string type + if (filteredObj.type?.toLowerCase() === SchemaType.STRING) { + if (filteredObj.enum && Array.isArray(filteredObj.enum)) { + return { + ...baseSchema, + type: SchemaType.STRING, + format: 'enum', + enum: filteredObj.enum as string[] + } as EnumStringSchema + } + return { + ...baseSchema, + type: SchemaType.STRING, + format: filteredObj.format === 'date-time' ? 'date-time' : undefined + } as SimpleStringSchema + } + + // Handle number type + if (filteredObj.type?.toLowerCase() === SchemaType.NUMBER) { + return { + ...baseSchema, + type: SchemaType.NUMBER, + format: ['float', 'double'].includes(filteredObj.format) ? (filteredObj.format as 'float' | 'double') : undefined + } as NumberSchema + } + + // Handle integer type + if (filteredObj.type?.toLowerCase() === SchemaType.INTEGER) { + return { + ...baseSchema, + type: SchemaType.INTEGER, + format: ['int32', 'int64'].includes(filteredObj.format) ? (filteredObj.format as 'int32' | 'int64') : undefined + } as IntegerSchema + } + + // Handle boolean type + if (filteredObj.type?.toLowerCase() === SchemaType.BOOLEAN) { + return { + ...baseSchema, + type: SchemaType.BOOLEAN + } as BooleanSchema + } + + // Handle array type + if (filteredObj.type?.toLowerCase() === SchemaType.ARRAY) { + return { + ...baseSchema, + type: SchemaType.ARRAY, + items: filteredObj.items + ? ensureValidSchema(filteredObj.items as Record) + : ({ type: SchemaType.STRING } as SimpleStringSchema), + minItems: filteredObj.minItems, + maxItems: filteredObj.maxItems + } as ArraySchema + } + + // Handle object type (default) + const properties = filteredObj.properties + ? Object.fromEntries( + Object.entries(filteredObj.properties).map(([key, value]) => [ + key, + ensureValidSchema(value as Record) + ]) + ) + : { _empty: { type: SchemaType.STRING } as SimpleStringSchema } // Ensure properties is never empty + + return { + ...baseSchema, + type: SchemaType.OBJECT, + properties, + required: Array.isArray(filteredObj.required) ? filteredObj.required : undefined + } as ObjectSchema +} + +function filterUnsupportedKeys(obj: Record): Record { + const supportedBaseKeys = ['description', 'nullable'] + const supportedStringKeys = [...supportedBaseKeys, 'type', 'format', 'enum'] + const supportedNumberKeys = [...supportedBaseKeys, 'type', 'format'] + const supportedBooleanKeys = [...supportedBaseKeys, 'type'] + const supportedArrayKeys = [...supportedBaseKeys, 'type', 'items', 'minItems', 'maxItems'] + const supportedObjectKeys = [...supportedBaseKeys, 'type', 'properties', 'required'] + + const filtered: Record = {} + + let keysToKeep: string[] + + if (obj.type?.toLowerCase() === SchemaType.STRING) { + keysToKeep = supportedStringKeys + } else if (obj.type?.toLowerCase() === SchemaType.NUMBER) { + keysToKeep = supportedNumberKeys + } else if (obj.type?.toLowerCase() === SchemaType.INTEGER) { + keysToKeep = supportedNumberKeys + } else if (obj.type?.toLowerCase() === SchemaType.BOOLEAN) { + keysToKeep = supportedBooleanKeys + } else if (obj.type?.toLowerCase() === SchemaType.ARRAY) { + keysToKeep = supportedArrayKeys + } else { + // Default to object type + keysToKeep = supportedObjectKeys + } + + // copy supported keys + for (const key of keysToKeep) { + if (obj[key] !== undefined) { + filtered[key] = obj[key] + } + } + + return filtered +} + +function filterPropertieAttributes(tool: MCPTool, filterNestedObj: boolean = false): Record { const properties = tool.inputSchema.properties if (!properties) { return {} } - const getSubMap = (obj: Record, keys: string[]) => { - const filtered = Object.fromEntries(Object.entries(obj).filter(([key]) => keys.includes(key))) - if (filterNestedObj) { - return { - ...filtered, - ...(obj.type === 'object' && obj.properties - ? { - properties: Object.fromEntries( - Object.entries(obj.properties).map(([k, v]) => [ - k, - (v as any).type === 'object' ? getSubMap(v as Record, keys) : v - ]) - ) - } - : {}), - ...(obj.type === 'array' && obj.items?.type === 'object' - ? { - items: getSubMap(obj.items, keys) - } - : {}) - } - } - - return filtered + // For OpenAI, we don't need to validate as strictly + if (!filterNestedObj) { + return properties } - for (const [key, val] of Object.entries(properties)) { - properties[key] = getSubMap(val, supportedAttributes) - } - return properties + const processedProperties = Object.fromEntries( + Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record)]) + ) + + return processedProperties } export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array { @@ -132,7 +233,7 @@ export async function callMCPTool(tool: MCPTool): Promise { if (tool.serverName === 'mcp-auto-install') { if (resp.data) { const mcpServer: MCPServer = { - id: nanoid(), + id: `f${nanoid()}`, name: resp.data.name, description: resp.data.description, baseUrl: resp.data.baseUrl, @@ -183,7 +284,7 @@ export function anthropicToolUseToMcpTool(mcpTools: MCPTool[] | undefined, toolU return tool } -export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiToool[] { +export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTool[] { if (!mcpTools || mcpTools.length === 0) { // No tools available return [] @@ -195,18 +296,19 @@ export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTo const functionDeclaration: FunctionDeclaration = { name: tool.id, description: tool.description, - ...(Object.keys(properties).length > 0 - ? { - parameters: { - type: SchemaType.OBJECT, - properties - } - } - : {}) + parameters: { + type: SchemaType.OBJECT, + properties: + Object.keys(properties).length > 0 + ? Object.fromEntries( + Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record)]) + ) + : { _empty: { type: SchemaType.STRING } as SimpleStringSchema } + } as FunctionDeclarationSchema } functions.push(functionDeclaration) } - const tool: geminiToool = { + const tool: geminiTool = { functionDeclarations: functions } return [tool] diff --git a/yarn.lock b/yarn.lock index 8c6ce606..5f1eb10f 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1212,10 +1212,10 @@ __metadata: languageName: node linkType: hard -"@google/generative-ai@npm:^0.21.0": - version: 0.21.0 - resolution: "@google/generative-ai@npm:0.21.0" - checksum: 10c0/cff5946c5964f2380e5097d82bd563d79be27a1a5ac604aaaad3f9ba3382992e4f0a371bd255baabfba4e5bdf296d8ce1410cbd65424afa98e64b2590fe49f3b +"@google/generative-ai@npm:^0.24.0": + version: 0.24.0 + resolution: "@google/generative-ai@npm:0.24.0" + checksum: 10c0/31452bf2653cdee7fd61eb209f16ac0ef82c94c4175909ba40e1088e938e3e19e01f628dfb80d429dae3338fc8487e9a0fd8a6ff0164189f2722211175690b0b languageName: node linkType: hard @@ -3767,7 +3767,7 @@ __metadata: "@eslint-react/eslint-plugin": "npm:^1.36.1" "@eslint/js": "npm:^9.22.0" "@google/genai": "npm:^0.4.0" - "@google/generative-ai": "npm:^0.21.0" + "@google/generative-ai": "npm:^0.24.0" "@hello-pangea/dnd": "npm:^16.6.0" "@kangfenmao/keyv-storage": "npm:^0.1.0" "@langchain/community": "npm:^0.3.36"