refactor(MCP): enhance schema validation for gemini (#4153)

This commit is contained in:
SuYao 2025-03-31 21:13:59 +08:00 committed by GitHub
parent 8c5f61d407
commit ba640d4070
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 160 additions and 58 deletions

View File

@ -63,7 +63,7 @@
"@cherrystudio/embedjs-openai": "^0.1.28",
"@electron-toolkit/utils": "^3.0.0",
"@electron/notarize": "^2.5.0",
"@google/generative-ai": "^0.21.0",
"@google/generative-ai": "^0.24.0",
"@langchain/community": "^0.3.36",
"@notionhq/client": "^2.2.15",
"@tryfabric/martian": "^1.2.4",

View File

@ -1,5 +1,17 @@
import { Tool, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
import { FunctionCall, FunctionDeclaration, SchemaType, Tool as geminiToool } from '@google/generative-ai'
import { FunctionCall, FunctionDeclaration, SchemaType, Tool as geminiTool } from '@google/generative-ai'
import {
ArraySchema,
BaseSchema,
BooleanSchema,
EnumStringSchema,
FunctionDeclarationSchema,
FunctionDeclarationSchemaProperty,
IntegerSchema,
NumberSchema,
ObjectSchema,
SimpleStringSchema
} from '@google/generative-ai'
import { nanoid } from '@reduxjs/toolkit'
import store from '@renderer/store'
import { addMCPServer } from '@renderer/store/mcp'
@ -8,54 +20,143 @@ import { ChatCompletionMessageToolCall, ChatCompletionTool } from 'openai/resour
import { ChunkCallbackData } from '../providers'
// Schema attribute names that survive when a tool property schema is trimmed.
// The list is handed to getSubMap as its `keys` argument, and getSubMap keeps
// only the entries whose key is included here.
// 'format' is commented out below, so it is not part of the keep list.
const supportedAttributes = [
'type',
'nullable',
'required',
// 'format',
'description',
'properties',
'items',
'enum',
'anyOf'
]
/**
 * Converts a loosely typed JSON-schema-like object into one of the Gemini
 * schema shapes (string / enum string / number / integer / boolean / array /
 * object), recursing into array `items` and object `properties`.
 *
 * The input is first passed through filterUnsupportedKeys, then dispatched on
 * its lower-cased `type`. Anything that is not string, number, integer,
 * boolean or array falls through to the object branch.
 *
 * NOTE(review): this span appears to interleave lines of the earlier
 * filterPropertieAttributes/getSubMap implementation with the new function
 * (the view looks like a diff without +/- markers) - confirm against the
 * real file before relying on statement order.
 * NOTE(review): `type?.toLowerCase()` only guards against a missing `type`;
 * a non-string `type` value would throw - confirm callers always pass a string.
 * NOTE(review): the comparisons assume SchemaType members are lower-case
 * strings in the installed SDK version - verify.
 *
 * @param obj Raw schema node.
 * @returns The schema node narrowed to one of the FunctionDeclarationSchemaProperty variants.
 */
const ensureValidSchema = (obj: Record<string, any>): FunctionDeclarationSchemaProperty => {
// Filter out unsupported keys for Gemini
const filteredObj = filterUnsupportedKeys(obj)
function filterPropertieAttributes(tool: MCPTool, filterNestedObj = false) {
const properties = tool.inputSchema.properties
if (!properties) {
return {}
}
const getSubMap = (obj: Record<string, any>, keys: string[]) => {
const filtered = Object.fromEntries(Object.entries(obj).filter(([key]) => keys.includes(key)))
// Handle base schema properties
// description and nullable are spread into every schema variant returned below.
const baseSchema = {
description: filteredObj.description,
nullable: filteredObj.nullable
} as BaseSchema
if (filterNestedObj) {
// Handle string type
// A string with an array `enum` becomes an enum string schema with format 'enum'.
if (filteredObj.type?.toLowerCase() === SchemaType.STRING) {
if (filteredObj.enum && Array.isArray(filteredObj.enum)) {
return {
...filtered,
...(obj.type === 'object' && obj.properties
? {
properties: Object.fromEntries(
Object.entries(obj.properties).map(([k, v]) => [
k,
(v as any).type === 'object' ? getSubMap(v as Record<string, any>, keys) : v
...baseSchema,
type: SchemaType.STRING,
format: 'enum',
enum: filteredObj.enum as string[]
} as EnumStringSchema
}
// Plain string: 'date-time' is the only format value passed through; any other becomes undefined.
return {
...baseSchema,
type: SchemaType.STRING,
format: filteredObj.format === 'date-time' ? 'date-time' : undefined
} as SimpleStringSchema
}
// Handle number type
// Only 'float' and 'double' formats are passed through.
if (filteredObj.type?.toLowerCase() === SchemaType.NUMBER) {
return {
...baseSchema,
type: SchemaType.NUMBER,
format: ['float', 'double'].includes(filteredObj.format) ? (filteredObj.format as 'float' | 'double') : undefined
} as NumberSchema
}
// Handle integer type
// Only 'int32' and 'int64' formats are passed through.
if (filteredObj.type?.toLowerCase() === SchemaType.INTEGER) {
return {
...baseSchema,
type: SchemaType.INTEGER,
format: ['int32', 'int64'].includes(filteredObj.format) ? (filteredObj.format as 'int32' | 'int64') : undefined
} as IntegerSchema
}
// Handle boolean type
if (filteredObj.type?.toLowerCase() === SchemaType.BOOLEAN) {
return {
...baseSchema,
type: SchemaType.BOOLEAN
} as BooleanSchema
}
// Handle array type
// `items` is converted recursively; when it is missing, a plain string schema is used instead.
if (filteredObj.type?.toLowerCase() === SchemaType.ARRAY) {
return {
...baseSchema,
type: SchemaType.ARRAY,
items: filteredObj.items
? ensureValidSchema(filteredObj.items as Record<string, any>)
: ({ type: SchemaType.STRING } as SimpleStringSchema),
minItems: filteredObj.minItems,
maxItems: filteredObj.maxItems
} as ArraySchema
}
// Handle object type (default)
// Each property value is converted recursively. The `_empty` placeholder is used only
// when `properties` is absent/falsy; an explicit empty object `{}` is truthy and maps to `{}`.
const properties = filteredObj.properties
? Object.fromEntries(
Object.entries(filteredObj.properties).map(([key, value]) => [
key,
ensureValidSchema(value as Record<string, any>)
])
)
: { _empty: { type: SchemaType.STRING } as SimpleStringSchema } // Ensure properties is never empty
// `required` is forwarded only when it is an array; otherwise it is set to undefined.
return {
...baseSchema,
type: SchemaType.OBJECT,
properties,
required: Array.isArray(filteredObj.required) ? filteredObj.required : undefined
} as ObjectSchema
}
/**
 * Returns a new object holding only the top-level keys of `obj` that are
 * allowed for its schema type. The allow-list is chosen from the lower-cased
 * `obj.type`; any type other than string, number, integer, boolean or array
 * is treated as an object.
 *
 * Only keys whose value is not `undefined` are copied. Values are assigned as
 * they are (no cloning and no recursion into nested schemas), and `obj` itself
 * is not written to.
 *
 * NOTE(review): the fragments between the type dispatch and the copy loop look
 * like leftover lines of an earlier getSubMap implementation (diff residue) -
 * confirm against the real file.
 * NOTE(review): `type?.toLowerCase()` would throw for a non-string `type`.
 *
 * @param obj Raw schema node.
 * @returns A shallow copy restricted to the supported keys.
 */
function filterUnsupportedKeys(obj: Record<string, any>): Record<string, any> {
// Per-type allow-lists; every list starts from the shared base keys.
const supportedBaseKeys = ['description', 'nullable']
const supportedStringKeys = [...supportedBaseKeys, 'type', 'format', 'enum']
const supportedNumberKeys = [...supportedBaseKeys, 'type', 'format']
const supportedBooleanKeys = [...supportedBaseKeys, 'type']
const supportedArrayKeys = [...supportedBaseKeys, 'type', 'items', 'minItems', 'maxItems']
const supportedObjectKeys = [...supportedBaseKeys, 'type', 'properties', 'required']
const filtered: Record<string, any> = {}
let keysToKeep: string[]
if (obj.type?.toLowerCase() === SchemaType.STRING) {
keysToKeep = supportedStringKeys
} else if (obj.type?.toLowerCase() === SchemaType.NUMBER) {
keysToKeep = supportedNumberKeys
} else if (obj.type?.toLowerCase() === SchemaType.INTEGER) {
// Integers share the number allow-list (type, format plus the base keys).
keysToKeep = supportedNumberKeys
} else if (obj.type?.toLowerCase() === SchemaType.BOOLEAN) {
keysToKeep = supportedBooleanKeys
} else if (obj.type?.toLowerCase() === SchemaType.ARRAY) {
keysToKeep = supportedArrayKeys
} else {
// Default to object type
keysToKeep = supportedObjectKeys
}
: {}),
...(obj.type === 'array' && obj.items?.type === 'object'
? {
items: getSubMap(obj.items, keys)
}
: {})
// copy supported keys
// Keys whose value is undefined are skipped, so they never appear on the result.
for (const key of keysToKeep) {
if (obj[key] !== undefined) {
filtered[key] = obj[key]
}
}
return filtered
}
/**
 * Returns the property schemas of a tool's input schema, optionally converted
 * property by property through ensureValidSchema.
 *
 * @param tool MCP tool whose `inputSchema.properties` is read.
 * @param filterNestedObj When false (the default) the properties are returned
 *   without per-property ensureValidSchema conversion; when true each property
 *   value is passed through ensureValidSchema.
 * @returns An empty object when the tool declares no properties, otherwise the
 *   (possibly converted) property map.
 */
function filterPropertieAttributes(tool: MCPTool, filterNestedObj: boolean = false): Record<string, object> {
const properties = tool.inputSchema.properties
if (!properties) {
return {}
}
// As written, this loop assigns back into `properties`, i.e. into the object held by
// tool.inputSchema, replacing each entry with getSubMap(val, supportedAttributes).
// NOTE(review): getSubMap is not declared in this function's scope in the visible text;
// this loop looks like a leftover from the previous implementation - confirm.
for (const [key, val] of Object.entries(properties)) {
properties[key] = getSubMap(val, supportedAttributes)
}
// For OpenAI, we don't need to validate as strictly
if (!filterNestedObj) {
return properties
}
// Strict path: build a new map with every property converted by ensureValidSchema.
const processedProperties = Object.fromEntries(
Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
)
return processedProperties
}
export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
@ -132,7 +233,7 @@ export async function callMCPTool(tool: MCPTool): Promise<any> {
if (tool.serverName === 'mcp-auto-install') {
if (resp.data) {
const mcpServer: MCPServer = {
id: nanoid(),
id: `f${nanoid()}`,
name: resp.data.name,
description: resp.data.description,
baseUrl: resp.data.baseUrl,
@ -183,7 +284,7 @@ export function anthropicToolUseToMcpTool(mcpTools: MCPTool[] | undefined, toolU
return tool
}
export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiToool[] {
export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTool[] {
if (!mcpTools || mcpTools.length === 0) {
// No tools available
return []
@ -195,18 +296,19 @@ export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTo
const functionDeclaration: FunctionDeclaration = {
name: tool.id,
description: tool.description,
...(Object.keys(properties).length > 0
? {
parameters: {
type: SchemaType.OBJECT,
properties
}
}
: {})
properties:
Object.keys(properties).length > 0
? Object.fromEntries(
Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
)
: { _empty: { type: SchemaType.STRING } as SimpleStringSchema }
} as FunctionDeclarationSchema
}
functions.push(functionDeclaration)
}
const tool: geminiToool = {
const tool: geminiTool = {
functionDeclarations: functions
}
return [tool]

View File

@ -1212,10 +1212,10 @@ __metadata:
languageName: node
linkType: hard
"@google/generative-ai@npm:^0.21.0":
version: 0.21.0
resolution: "@google/generative-ai@npm:0.21.0"
checksum: 10c0/cff5946c5964f2380e5097d82bd563d79be27a1a5ac604aaaad3f9ba3382992e4f0a371bd255baabfba4e5bdf296d8ce1410cbd65424afa98e64b2590fe49f3b
"@google/generative-ai@npm:^0.24.0":
version: 0.24.0
resolution: "@google/generative-ai@npm:0.24.0"
checksum: 10c0/31452bf2653cdee7fd61eb209f16ac0ef82c94c4175909ba40e1088e938e3e19e01f628dfb80d429dae3338fc8487e9a0fd8a6ff0164189f2722211175690b0b
languageName: node
linkType: hard
@ -3767,7 +3767,7 @@ __metadata:
"@eslint-react/eslint-plugin": "npm:^1.36.1"
"@eslint/js": "npm:^9.22.0"
"@google/genai": "npm:^0.4.0"
"@google/generative-ai": "npm:^0.21.0"
"@google/generative-ai": "npm:^0.24.0"
"@hello-pangea/dnd": "npm:^16.6.0"
"@kangfenmao/keyv-storage": "npm:^0.1.0"
"@langchain/community": "npm:^0.3.36"