refactor(MCP): enhance schema validation for gemini (#4153)

This commit is contained in:
SuYao 2025-03-31 21:13:59 +08:00 committed by GitHub
parent 8c5f61d407
commit ba640d4070
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 160 additions and 58 deletions

View File

@ -63,7 +63,7 @@
"@cherrystudio/embedjs-openai": "^0.1.28", "@cherrystudio/embedjs-openai": "^0.1.28",
"@electron-toolkit/utils": "^3.0.0", "@electron-toolkit/utils": "^3.0.0",
"@electron/notarize": "^2.5.0", "@electron/notarize": "^2.5.0",
"@google/generative-ai": "^0.21.0", "@google/generative-ai": "^0.24.0",
"@langchain/community": "^0.3.36", "@langchain/community": "^0.3.36",
"@notionhq/client": "^2.2.15", "@notionhq/client": "^2.2.15",
"@tryfabric/martian": "^1.2.4", "@tryfabric/martian": "^1.2.4",

View File

@ -1,5 +1,17 @@
import { Tool, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources' import { Tool, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
import { FunctionCall, FunctionDeclaration, SchemaType, Tool as geminiToool } from '@google/generative-ai' import { FunctionCall, FunctionDeclaration, SchemaType, Tool as geminiTool } from '@google/generative-ai'
import {
ArraySchema,
BaseSchema,
BooleanSchema,
EnumStringSchema,
FunctionDeclarationSchema,
FunctionDeclarationSchemaProperty,
IntegerSchema,
NumberSchema,
ObjectSchema,
SimpleStringSchema
} from '@google/generative-ai'
import { nanoid } from '@reduxjs/toolkit' import { nanoid } from '@reduxjs/toolkit'
import store from '@renderer/store' import store from '@renderer/store'
import { addMCPServer } from '@renderer/store/mcp' import { addMCPServer } from '@renderer/store/mcp'
@ -8,54 +20,143 @@ import { ChatCompletionMessageToolCall, ChatCompletionTool } from 'openai/resour
import { ChunkCallbackData } from '../providers' import { ChunkCallbackData } from '../providers'
const supportedAttributes = [ const ensureValidSchema = (obj: Record<string, any>): FunctionDeclarationSchemaProperty => {
'type', // Filter out unsupported keys for Gemini
'nullable', const filteredObj = filterUnsupportedKeys(obj)
'required',
// 'format',
'description',
'properties',
'items',
'enum',
'anyOf'
]
function filterPropertieAttributes(tool: MCPTool, filterNestedObj = false) { // Handle base schema properties
const baseSchema = {
description: filteredObj.description,
nullable: filteredObj.nullable
} as BaseSchema
// Handle string type
if (filteredObj.type?.toLowerCase() === SchemaType.STRING) {
if (filteredObj.enum && Array.isArray(filteredObj.enum)) {
return {
...baseSchema,
type: SchemaType.STRING,
format: 'enum',
enum: filteredObj.enum as string[]
} as EnumStringSchema
}
return {
...baseSchema,
type: SchemaType.STRING,
format: filteredObj.format === 'date-time' ? 'date-time' : undefined
} as SimpleStringSchema
}
// Handle number type
if (filteredObj.type?.toLowerCase() === SchemaType.NUMBER) {
return {
...baseSchema,
type: SchemaType.NUMBER,
format: ['float', 'double'].includes(filteredObj.format) ? (filteredObj.format as 'float' | 'double') : undefined
} as NumberSchema
}
// Handle integer type
if (filteredObj.type?.toLowerCase() === SchemaType.INTEGER) {
return {
...baseSchema,
type: SchemaType.INTEGER,
format: ['int32', 'int64'].includes(filteredObj.format) ? (filteredObj.format as 'int32' | 'int64') : undefined
} as IntegerSchema
}
// Handle boolean type
if (filteredObj.type?.toLowerCase() === SchemaType.BOOLEAN) {
return {
...baseSchema,
type: SchemaType.BOOLEAN
} as BooleanSchema
}
// Handle array type
if (filteredObj.type?.toLowerCase() === SchemaType.ARRAY) {
return {
...baseSchema,
type: SchemaType.ARRAY,
items: filteredObj.items
? ensureValidSchema(filteredObj.items as Record<string, any>)
: ({ type: SchemaType.STRING } as SimpleStringSchema),
minItems: filteredObj.minItems,
maxItems: filteredObj.maxItems
} as ArraySchema
}
// Handle object type (default)
const properties = filteredObj.properties
? Object.fromEntries(
Object.entries(filteredObj.properties).map(([key, value]) => [
key,
ensureValidSchema(value as Record<string, any>)
])
)
: { _empty: { type: SchemaType.STRING } as SimpleStringSchema } // Ensure properties is never empty
return {
...baseSchema,
type: SchemaType.OBJECT,
properties,
required: Array.isArray(filteredObj.required) ? filteredObj.required : undefined
} as ObjectSchema
}
function filterUnsupportedKeys(obj: Record<string, any>): Record<string, any> {
const supportedBaseKeys = ['description', 'nullable']
const supportedStringKeys = [...supportedBaseKeys, 'type', 'format', 'enum']
const supportedNumberKeys = [...supportedBaseKeys, 'type', 'format']
const supportedBooleanKeys = [...supportedBaseKeys, 'type']
const supportedArrayKeys = [...supportedBaseKeys, 'type', 'items', 'minItems', 'maxItems']
const supportedObjectKeys = [...supportedBaseKeys, 'type', 'properties', 'required']
const filtered: Record<string, any> = {}
let keysToKeep: string[]
if (obj.type?.toLowerCase() === SchemaType.STRING) {
keysToKeep = supportedStringKeys
} else if (obj.type?.toLowerCase() === SchemaType.NUMBER) {
keysToKeep = supportedNumberKeys
} else if (obj.type?.toLowerCase() === SchemaType.INTEGER) {
keysToKeep = supportedNumberKeys
} else if (obj.type?.toLowerCase() === SchemaType.BOOLEAN) {
keysToKeep = supportedBooleanKeys
} else if (obj.type?.toLowerCase() === SchemaType.ARRAY) {
keysToKeep = supportedArrayKeys
} else {
// Default to object type
keysToKeep = supportedObjectKeys
}
// copy supported keys
for (const key of keysToKeep) {
if (obj[key] !== undefined) {
filtered[key] = obj[key]
}
}
return filtered
}
function filterPropertieAttributes(tool: MCPTool, filterNestedObj: boolean = false): Record<string, object> {
const properties = tool.inputSchema.properties const properties = tool.inputSchema.properties
if (!properties) { if (!properties) {
return {} return {}
} }
const getSubMap = (obj: Record<string, any>, keys: string[]) => {
const filtered = Object.fromEntries(Object.entries(obj).filter(([key]) => keys.includes(key)))
if (filterNestedObj) { // For OpenAI, we don't need to validate as strictly
return { if (!filterNestedObj) {
...filtered, return properties
...(obj.type === 'object' && obj.properties
? {
properties: Object.fromEntries(
Object.entries(obj.properties).map(([k, v]) => [
k,
(v as any).type === 'object' ? getSubMap(v as Record<string, any>, keys) : v
])
)
}
: {}),
...(obj.type === 'array' && obj.items?.type === 'object'
? {
items: getSubMap(obj.items, keys)
}
: {})
}
}
return filtered
} }
for (const [key, val] of Object.entries(properties)) { const processedProperties = Object.fromEntries(
properties[key] = getSubMap(val, supportedAttributes) Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
} )
return properties
return processedProperties
} }
export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> { export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
@ -132,7 +233,7 @@ export async function callMCPTool(tool: MCPTool): Promise<any> {
if (tool.serverName === 'mcp-auto-install') { if (tool.serverName === 'mcp-auto-install') {
if (resp.data) { if (resp.data) {
const mcpServer: MCPServer = { const mcpServer: MCPServer = {
id: nanoid(), id: `f${nanoid()}`,
name: resp.data.name, name: resp.data.name,
description: resp.data.description, description: resp.data.description,
baseUrl: resp.data.baseUrl, baseUrl: resp.data.baseUrl,
@ -183,7 +284,7 @@ export function anthropicToolUseToMcpTool(mcpTools: MCPTool[] | undefined, toolU
return tool return tool
} }
export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiToool[] { export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTool[] {
if (!mcpTools || mcpTools.length === 0) { if (!mcpTools || mcpTools.length === 0) {
// No tools available // No tools available
return [] return []
@ -195,18 +296,19 @@ export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTo
const functionDeclaration: FunctionDeclaration = { const functionDeclaration: FunctionDeclaration = {
name: tool.id, name: tool.id,
description: tool.description, description: tool.description,
...(Object.keys(properties).length > 0 parameters: {
? { type: SchemaType.OBJECT,
parameters: { properties:
type: SchemaType.OBJECT, Object.keys(properties).length > 0
properties ? Object.fromEntries(
} Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
} )
: {}) : { _empty: { type: SchemaType.STRING } as SimpleStringSchema }
} as FunctionDeclarationSchema
} }
functions.push(functionDeclaration) functions.push(functionDeclaration)
} }
const tool: geminiToool = { const tool: geminiTool = {
functionDeclarations: functions functionDeclarations: functions
} }
return [tool] return [tool]

View File

@ -1212,10 +1212,10 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"@google/generative-ai@npm:^0.21.0": "@google/generative-ai@npm:^0.24.0":
version: 0.21.0 version: 0.24.0
resolution: "@google/generative-ai@npm:0.21.0" resolution: "@google/generative-ai@npm:0.24.0"
checksum: 10c0/cff5946c5964f2380e5097d82bd563d79be27a1a5ac604aaaad3f9ba3382992e4f0a371bd255baabfba4e5bdf296d8ce1410cbd65424afa98e64b2590fe49f3b checksum: 10c0/31452bf2653cdee7fd61eb209f16ac0ef82c94c4175909ba40e1088e938e3e19e01f628dfb80d429dae3338fc8487e9a0fd8a6ff0164189f2722211175690b0b
languageName: node languageName: node
linkType: hard linkType: hard
@ -3767,7 +3767,7 @@ __metadata:
"@eslint-react/eslint-plugin": "npm:^1.36.1" "@eslint-react/eslint-plugin": "npm:^1.36.1"
"@eslint/js": "npm:^9.22.0" "@eslint/js": "npm:^9.22.0"
"@google/genai": "npm:^0.4.0" "@google/genai": "npm:^0.4.0"
"@google/generative-ai": "npm:^0.21.0" "@google/generative-ai": "npm:^0.24.0"
"@hello-pangea/dnd": "npm:^16.6.0" "@hello-pangea/dnd": "npm:^16.6.0"
"@kangfenmao/keyv-storage": "npm:^0.1.0" "@kangfenmao/keyv-storage": "npm:^0.1.0"
"@langchain/community": "npm:^0.3.36" "@langchain/community": "npm:^0.3.36"