refactor(MCP): enhance schema validation for gemini (#4153)
This commit is contained in:
parent
8c5f61d407
commit
ba640d4070
@ -63,7 +63,7 @@
|
|||||||
"@cherrystudio/embedjs-openai": "^0.1.28",
|
"@cherrystudio/embedjs-openai": "^0.1.28",
|
||||||
"@electron-toolkit/utils": "^3.0.0",
|
"@electron-toolkit/utils": "^3.0.0",
|
||||||
"@electron/notarize": "^2.5.0",
|
"@electron/notarize": "^2.5.0",
|
||||||
"@google/generative-ai": "^0.21.0",
|
"@google/generative-ai": "^0.24.0",
|
||||||
"@langchain/community": "^0.3.36",
|
"@langchain/community": "^0.3.36",
|
||||||
"@notionhq/client": "^2.2.15",
|
"@notionhq/client": "^2.2.15",
|
||||||
"@tryfabric/martian": "^1.2.4",
|
"@tryfabric/martian": "^1.2.4",
|
||||||
|
|||||||
@ -1,5 +1,17 @@
|
|||||||
import { Tool, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
|
import { Tool, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
|
||||||
import { FunctionCall, FunctionDeclaration, SchemaType, Tool as geminiToool } from '@google/generative-ai'
|
import { FunctionCall, FunctionDeclaration, SchemaType, Tool as geminiTool } from '@google/generative-ai'
|
||||||
|
import {
|
||||||
|
ArraySchema,
|
||||||
|
BaseSchema,
|
||||||
|
BooleanSchema,
|
||||||
|
EnumStringSchema,
|
||||||
|
FunctionDeclarationSchema,
|
||||||
|
FunctionDeclarationSchemaProperty,
|
||||||
|
IntegerSchema,
|
||||||
|
NumberSchema,
|
||||||
|
ObjectSchema,
|
||||||
|
SimpleStringSchema
|
||||||
|
} from '@google/generative-ai'
|
||||||
import { nanoid } from '@reduxjs/toolkit'
|
import { nanoid } from '@reduxjs/toolkit'
|
||||||
import store from '@renderer/store'
|
import store from '@renderer/store'
|
||||||
import { addMCPServer } from '@renderer/store/mcp'
|
import { addMCPServer } from '@renderer/store/mcp'
|
||||||
@ -8,54 +20,143 @@ import { ChatCompletionMessageToolCall, ChatCompletionTool } from 'openai/resour
|
|||||||
|
|
||||||
import { ChunkCallbackData } from '../providers'
|
import { ChunkCallbackData } from '../providers'
|
||||||
|
|
||||||
const supportedAttributes = [
|
const ensureValidSchema = (obj: Record<string, any>): FunctionDeclarationSchemaProperty => {
|
||||||
'type',
|
// Filter out unsupported keys for Gemini
|
||||||
'nullable',
|
const filteredObj = filterUnsupportedKeys(obj)
|
||||||
'required',
|
|
||||||
// 'format',
|
|
||||||
'description',
|
|
||||||
'properties',
|
|
||||||
'items',
|
|
||||||
'enum',
|
|
||||||
'anyOf'
|
|
||||||
]
|
|
||||||
|
|
||||||
function filterPropertieAttributes(tool: MCPTool, filterNestedObj = false) {
|
// Handle base schema properties
|
||||||
const properties = tool.inputSchema.properties
|
const baseSchema = {
|
||||||
if (!properties) {
|
description: filteredObj.description,
|
||||||
return {}
|
nullable: filteredObj.nullable
|
||||||
}
|
} as BaseSchema
|
||||||
const getSubMap = (obj: Record<string, any>, keys: string[]) => {
|
|
||||||
const filtered = Object.fromEntries(Object.entries(obj).filter(([key]) => keys.includes(key)))
|
|
||||||
|
|
||||||
if (filterNestedObj) {
|
// Handle string type
|
||||||
|
if (filteredObj.type?.toLowerCase() === SchemaType.STRING) {
|
||||||
|
if (filteredObj.enum && Array.isArray(filteredObj.enum)) {
|
||||||
return {
|
return {
|
||||||
...filtered,
|
...baseSchema,
|
||||||
...(obj.type === 'object' && obj.properties
|
type: SchemaType.STRING,
|
||||||
? {
|
format: 'enum',
|
||||||
properties: Object.fromEntries(
|
enum: filteredObj.enum as string[]
|
||||||
Object.entries(obj.properties).map(([k, v]) => [
|
} as EnumStringSchema
|
||||||
k,
|
}
|
||||||
(v as any).type === 'object' ? getSubMap(v as Record<string, any>, keys) : v
|
return {
|
||||||
|
...baseSchema,
|
||||||
|
type: SchemaType.STRING,
|
||||||
|
format: filteredObj.format === 'date-time' ? 'date-time' : undefined
|
||||||
|
} as SimpleStringSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle number type
|
||||||
|
if (filteredObj.type?.toLowerCase() === SchemaType.NUMBER) {
|
||||||
|
return {
|
||||||
|
...baseSchema,
|
||||||
|
type: SchemaType.NUMBER,
|
||||||
|
format: ['float', 'double'].includes(filteredObj.format) ? (filteredObj.format as 'float' | 'double') : undefined
|
||||||
|
} as NumberSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle integer type
|
||||||
|
if (filteredObj.type?.toLowerCase() === SchemaType.INTEGER) {
|
||||||
|
return {
|
||||||
|
...baseSchema,
|
||||||
|
type: SchemaType.INTEGER,
|
||||||
|
format: ['int32', 'int64'].includes(filteredObj.format) ? (filteredObj.format as 'int32' | 'int64') : undefined
|
||||||
|
} as IntegerSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle boolean type
|
||||||
|
if (filteredObj.type?.toLowerCase() === SchemaType.BOOLEAN) {
|
||||||
|
return {
|
||||||
|
...baseSchema,
|
||||||
|
type: SchemaType.BOOLEAN
|
||||||
|
} as BooleanSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle array type
|
||||||
|
if (filteredObj.type?.toLowerCase() === SchemaType.ARRAY) {
|
||||||
|
return {
|
||||||
|
...baseSchema,
|
||||||
|
type: SchemaType.ARRAY,
|
||||||
|
items: filteredObj.items
|
||||||
|
? ensureValidSchema(filteredObj.items as Record<string, any>)
|
||||||
|
: ({ type: SchemaType.STRING } as SimpleStringSchema),
|
||||||
|
minItems: filteredObj.minItems,
|
||||||
|
maxItems: filteredObj.maxItems
|
||||||
|
} as ArraySchema
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle object type (default)
|
||||||
|
const properties = filteredObj.properties
|
||||||
|
? Object.fromEntries(
|
||||||
|
Object.entries(filteredObj.properties).map(([key, value]) => [
|
||||||
|
key,
|
||||||
|
ensureValidSchema(value as Record<string, any>)
|
||||||
])
|
])
|
||||||
)
|
)
|
||||||
|
: { _empty: { type: SchemaType.STRING } as SimpleStringSchema } // Ensure properties is never empty
|
||||||
|
|
||||||
|
return {
|
||||||
|
...baseSchema,
|
||||||
|
type: SchemaType.OBJECT,
|
||||||
|
properties,
|
||||||
|
required: Array.isArray(filteredObj.required) ? filteredObj.required : undefined
|
||||||
|
} as ObjectSchema
|
||||||
|
}
|
||||||
|
|
||||||
|
function filterUnsupportedKeys(obj: Record<string, any>): Record<string, any> {
|
||||||
|
const supportedBaseKeys = ['description', 'nullable']
|
||||||
|
const supportedStringKeys = [...supportedBaseKeys, 'type', 'format', 'enum']
|
||||||
|
const supportedNumberKeys = [...supportedBaseKeys, 'type', 'format']
|
||||||
|
const supportedBooleanKeys = [...supportedBaseKeys, 'type']
|
||||||
|
const supportedArrayKeys = [...supportedBaseKeys, 'type', 'items', 'minItems', 'maxItems']
|
||||||
|
const supportedObjectKeys = [...supportedBaseKeys, 'type', 'properties', 'required']
|
||||||
|
|
||||||
|
const filtered: Record<string, any> = {}
|
||||||
|
|
||||||
|
let keysToKeep: string[]
|
||||||
|
|
||||||
|
if (obj.type?.toLowerCase() === SchemaType.STRING) {
|
||||||
|
keysToKeep = supportedStringKeys
|
||||||
|
} else if (obj.type?.toLowerCase() === SchemaType.NUMBER) {
|
||||||
|
keysToKeep = supportedNumberKeys
|
||||||
|
} else if (obj.type?.toLowerCase() === SchemaType.INTEGER) {
|
||||||
|
keysToKeep = supportedNumberKeys
|
||||||
|
} else if (obj.type?.toLowerCase() === SchemaType.BOOLEAN) {
|
||||||
|
keysToKeep = supportedBooleanKeys
|
||||||
|
} else if (obj.type?.toLowerCase() === SchemaType.ARRAY) {
|
||||||
|
keysToKeep = supportedArrayKeys
|
||||||
|
} else {
|
||||||
|
// Default to object type
|
||||||
|
keysToKeep = supportedObjectKeys
|
||||||
}
|
}
|
||||||
: {}),
|
|
||||||
...(obj.type === 'array' && obj.items?.type === 'object'
|
// copy supported keys
|
||||||
? {
|
for (const key of keysToKeep) {
|
||||||
items: getSubMap(obj.items, keys)
|
if (obj[key] !== undefined) {
|
||||||
}
|
filtered[key] = obj[key]
|
||||||
: {})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return filtered
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
function filterPropertieAttributes(tool: MCPTool, filterNestedObj: boolean = false): Record<string, object> {
|
||||||
|
const properties = tool.inputSchema.properties
|
||||||
|
if (!properties) {
|
||||||
|
return {}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const [key, val] of Object.entries(properties)) {
|
// For OpenAI, we don't need to validate as strictly
|
||||||
properties[key] = getSubMap(val, supportedAttributes)
|
if (!filterNestedObj) {
|
||||||
}
|
|
||||||
return properties
|
return properties
|
||||||
|
}
|
||||||
|
|
||||||
|
const processedProperties = Object.fromEntries(
|
||||||
|
Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
|
||||||
|
)
|
||||||
|
|
||||||
|
return processedProperties
|
||||||
}
|
}
|
||||||
|
|
||||||
export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
|
export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
|
||||||
@ -132,7 +233,7 @@ export async function callMCPTool(tool: MCPTool): Promise<any> {
|
|||||||
if (tool.serverName === 'mcp-auto-install') {
|
if (tool.serverName === 'mcp-auto-install') {
|
||||||
if (resp.data) {
|
if (resp.data) {
|
||||||
const mcpServer: MCPServer = {
|
const mcpServer: MCPServer = {
|
||||||
id: nanoid(),
|
id: `f${nanoid()}`,
|
||||||
name: resp.data.name,
|
name: resp.data.name,
|
||||||
description: resp.data.description,
|
description: resp.data.description,
|
||||||
baseUrl: resp.data.baseUrl,
|
baseUrl: resp.data.baseUrl,
|
||||||
@ -183,7 +284,7 @@ export function anthropicToolUseToMcpTool(mcpTools: MCPTool[] | undefined, toolU
|
|||||||
return tool
|
return tool
|
||||||
}
|
}
|
||||||
|
|
||||||
export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiToool[] {
|
export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTool[] {
|
||||||
if (!mcpTools || mcpTools.length === 0) {
|
if (!mcpTools || mcpTools.length === 0) {
|
||||||
// No tools available
|
// No tools available
|
||||||
return []
|
return []
|
||||||
@ -195,18 +296,19 @@ export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTo
|
|||||||
const functionDeclaration: FunctionDeclaration = {
|
const functionDeclaration: FunctionDeclaration = {
|
||||||
name: tool.id,
|
name: tool.id,
|
||||||
description: tool.description,
|
description: tool.description,
|
||||||
...(Object.keys(properties).length > 0
|
|
||||||
? {
|
|
||||||
parameters: {
|
parameters: {
|
||||||
type: SchemaType.OBJECT,
|
type: SchemaType.OBJECT,
|
||||||
properties
|
properties:
|
||||||
}
|
Object.keys(properties).length > 0
|
||||||
}
|
? Object.fromEntries(
|
||||||
: {})
|
Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
|
||||||
|
)
|
||||||
|
: { _empty: { type: SchemaType.STRING } as SimpleStringSchema }
|
||||||
|
} as FunctionDeclarationSchema
|
||||||
}
|
}
|
||||||
functions.push(functionDeclaration)
|
functions.push(functionDeclaration)
|
||||||
}
|
}
|
||||||
const tool: geminiToool = {
|
const tool: geminiTool = {
|
||||||
functionDeclarations: functions
|
functionDeclarations: functions
|
||||||
}
|
}
|
||||||
return [tool]
|
return [tool]
|
||||||
|
|||||||
10
yarn.lock
10
yarn.lock
@ -1212,10 +1212,10 @@ __metadata:
|
|||||||
languageName: node
|
languageName: node
|
||||||
linkType: hard
|
linkType: hard
|
||||||
|
|
||||||
"@google/generative-ai@npm:^0.21.0":
|
"@google/generative-ai@npm:^0.24.0":
|
||||||
version: 0.21.0
|
version: 0.24.0
|
||||||
resolution: "@google/generative-ai@npm:0.21.0"
|
resolution: "@google/generative-ai@npm:0.24.0"
|
||||||
checksum: 10c0/cff5946c5964f2380e5097d82bd563d79be27a1a5ac604aaaad3f9ba3382992e4f0a371bd255baabfba4e5bdf296d8ce1410cbd65424afa98e64b2590fe49f3b
|
checksum: 10c0/31452bf2653cdee7fd61eb209f16ac0ef82c94c4175909ba40e1088e938e3e19e01f628dfb80d429dae3338fc8487e9a0fd8a6ff0164189f2722211175690b0b
|
||||||
languageName: node
|
languageName: node
|
||||||
linkType: hard
|
linkType: hard
|
||||||
|
|
||||||
@ -3767,7 +3767,7 @@ __metadata:
|
|||||||
"@eslint-react/eslint-plugin": "npm:^1.36.1"
|
"@eslint-react/eslint-plugin": "npm:^1.36.1"
|
||||||
"@eslint/js": "npm:^9.22.0"
|
"@eslint/js": "npm:^9.22.0"
|
||||||
"@google/genai": "npm:^0.4.0"
|
"@google/genai": "npm:^0.4.0"
|
||||||
"@google/generative-ai": "npm:^0.21.0"
|
"@google/generative-ai": "npm:^0.24.0"
|
||||||
"@hello-pangea/dnd": "npm:^16.6.0"
|
"@hello-pangea/dnd": "npm:^16.6.0"
|
||||||
"@kangfenmao/keyv-storage": "npm:^0.1.0"
|
"@kangfenmao/keyv-storage": "npm:^0.1.0"
|
||||||
"@langchain/community": "npm:^0.3.36"
|
"@langchain/community": "npm:^0.3.36"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user