refactor(MCP): enhance schema validation for gemini (#4153)
This commit is contained in:
parent
8c5f61d407
commit
ba640d4070
@ -63,7 +63,7 @@
|
||||
"@cherrystudio/embedjs-openai": "^0.1.28",
|
||||
"@electron-toolkit/utils": "^3.0.0",
|
||||
"@electron/notarize": "^2.5.0",
|
||||
"@google/generative-ai": "^0.21.0",
|
||||
"@google/generative-ai": "^0.24.0",
|
||||
"@langchain/community": "^0.3.36",
|
||||
"@notionhq/client": "^2.2.15",
|
||||
"@tryfabric/martian": "^1.2.4",
|
||||
|
||||
@ -1,5 +1,17 @@
|
||||
import { Tool, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
|
||||
import { FunctionCall, FunctionDeclaration, SchemaType, Tool as geminiToool } from '@google/generative-ai'
|
||||
import { FunctionCall, FunctionDeclaration, SchemaType, Tool as geminiTool } from '@google/generative-ai'
|
||||
import {
|
||||
ArraySchema,
|
||||
BaseSchema,
|
||||
BooleanSchema,
|
||||
EnumStringSchema,
|
||||
FunctionDeclarationSchema,
|
||||
FunctionDeclarationSchemaProperty,
|
||||
IntegerSchema,
|
||||
NumberSchema,
|
||||
ObjectSchema,
|
||||
SimpleStringSchema
|
||||
} from '@google/generative-ai'
|
||||
import { nanoid } from '@reduxjs/toolkit'
|
||||
import store from '@renderer/store'
|
||||
import { addMCPServer } from '@renderer/store/mcp'
|
||||
@ -8,56 +20,145 @@ import { ChatCompletionMessageToolCall, ChatCompletionTool } from 'openai/resour
|
||||
|
||||
import { ChunkCallbackData } from '../providers'
|
||||
|
||||
const supportedAttributes = [
|
||||
'type',
|
||||
'nullable',
|
||||
'required',
|
||||
// 'format',
|
||||
'description',
|
||||
'properties',
|
||||
'items',
|
||||
'enum',
|
||||
'anyOf'
|
||||
]
|
||||
const ensureValidSchema = (obj: Record<string, any>): FunctionDeclarationSchemaProperty => {
|
||||
// Filter out unsupported keys for Gemini
|
||||
const filteredObj = filterUnsupportedKeys(obj)
|
||||
|
||||
function filterPropertieAttributes(tool: MCPTool, filterNestedObj = false) {
|
||||
const properties = tool.inputSchema.properties
|
||||
if (!properties) {
|
||||
return {}
|
||||
}
|
||||
const getSubMap = (obj: Record<string, any>, keys: string[]) => {
|
||||
const filtered = Object.fromEntries(Object.entries(obj).filter(([key]) => keys.includes(key)))
|
||||
// Handle base schema properties
|
||||
const baseSchema = {
|
||||
description: filteredObj.description,
|
||||
nullable: filteredObj.nullable
|
||||
} as BaseSchema
|
||||
|
||||
if (filterNestedObj) {
|
||||
// Handle string type
|
||||
if (filteredObj.type?.toLowerCase() === SchemaType.STRING) {
|
||||
if (filteredObj.enum && Array.isArray(filteredObj.enum)) {
|
||||
return {
|
||||
...filtered,
|
||||
...(obj.type === 'object' && obj.properties
|
||||
? {
|
||||
properties: Object.fromEntries(
|
||||
Object.entries(obj.properties).map(([k, v]) => [
|
||||
k,
|
||||
(v as any).type === 'object' ? getSubMap(v as Record<string, any>, keys) : v
|
||||
...baseSchema,
|
||||
type: SchemaType.STRING,
|
||||
format: 'enum',
|
||||
enum: filteredObj.enum as string[]
|
||||
} as EnumStringSchema
|
||||
}
|
||||
return {
|
||||
...baseSchema,
|
||||
type: SchemaType.STRING,
|
||||
format: filteredObj.format === 'date-time' ? 'date-time' : undefined
|
||||
} as SimpleStringSchema
|
||||
}
|
||||
|
||||
// Handle number type
|
||||
if (filteredObj.type?.toLowerCase() === SchemaType.NUMBER) {
|
||||
return {
|
||||
...baseSchema,
|
||||
type: SchemaType.NUMBER,
|
||||
format: ['float', 'double'].includes(filteredObj.format) ? (filteredObj.format as 'float' | 'double') : undefined
|
||||
} as NumberSchema
|
||||
}
|
||||
|
||||
// Handle integer type
|
||||
if (filteredObj.type?.toLowerCase() === SchemaType.INTEGER) {
|
||||
return {
|
||||
...baseSchema,
|
||||
type: SchemaType.INTEGER,
|
||||
format: ['int32', 'int64'].includes(filteredObj.format) ? (filteredObj.format as 'int32' | 'int64') : undefined
|
||||
} as IntegerSchema
|
||||
}
|
||||
|
||||
// Handle boolean type
|
||||
if (filteredObj.type?.toLowerCase() === SchemaType.BOOLEAN) {
|
||||
return {
|
||||
...baseSchema,
|
||||
type: SchemaType.BOOLEAN
|
||||
} as BooleanSchema
|
||||
}
|
||||
|
||||
// Handle array type
|
||||
if (filteredObj.type?.toLowerCase() === SchemaType.ARRAY) {
|
||||
return {
|
||||
...baseSchema,
|
||||
type: SchemaType.ARRAY,
|
||||
items: filteredObj.items
|
||||
? ensureValidSchema(filteredObj.items as Record<string, any>)
|
||||
: ({ type: SchemaType.STRING } as SimpleStringSchema),
|
||||
minItems: filteredObj.minItems,
|
||||
maxItems: filteredObj.maxItems
|
||||
} as ArraySchema
|
||||
}
|
||||
|
||||
// Handle object type (default)
|
||||
const properties = filteredObj.properties
|
||||
? Object.fromEntries(
|
||||
Object.entries(filteredObj.properties).map(([key, value]) => [
|
||||
key,
|
||||
ensureValidSchema(value as Record<string, any>)
|
||||
])
|
||||
)
|
||||
: { _empty: { type: SchemaType.STRING } as SimpleStringSchema } // Ensure properties is never empty
|
||||
|
||||
return {
|
||||
...baseSchema,
|
||||
type: SchemaType.OBJECT,
|
||||
properties,
|
||||
required: Array.isArray(filteredObj.required) ? filteredObj.required : undefined
|
||||
} as ObjectSchema
|
||||
}
|
||||
: {}),
|
||||
...(obj.type === 'array' && obj.items?.type === 'object'
|
||||
? {
|
||||
items: getSubMap(obj.items, keys)
|
||||
|
||||
function filterUnsupportedKeys(obj: Record<string, any>): Record<string, any> {
|
||||
const supportedBaseKeys = ['description', 'nullable']
|
||||
const supportedStringKeys = [...supportedBaseKeys, 'type', 'format', 'enum']
|
||||
const supportedNumberKeys = [...supportedBaseKeys, 'type', 'format']
|
||||
const supportedBooleanKeys = [...supportedBaseKeys, 'type']
|
||||
const supportedArrayKeys = [...supportedBaseKeys, 'type', 'items', 'minItems', 'maxItems']
|
||||
const supportedObjectKeys = [...supportedBaseKeys, 'type', 'properties', 'required']
|
||||
|
||||
const filtered: Record<string, any> = {}
|
||||
|
||||
let keysToKeep: string[]
|
||||
|
||||
if (obj.type?.toLowerCase() === SchemaType.STRING) {
|
||||
keysToKeep = supportedStringKeys
|
||||
} else if (obj.type?.toLowerCase() === SchemaType.NUMBER) {
|
||||
keysToKeep = supportedNumberKeys
|
||||
} else if (obj.type?.toLowerCase() === SchemaType.INTEGER) {
|
||||
keysToKeep = supportedNumberKeys
|
||||
} else if (obj.type?.toLowerCase() === SchemaType.BOOLEAN) {
|
||||
keysToKeep = supportedBooleanKeys
|
||||
} else if (obj.type?.toLowerCase() === SchemaType.ARRAY) {
|
||||
keysToKeep = supportedArrayKeys
|
||||
} else {
|
||||
// Default to object type
|
||||
keysToKeep = supportedObjectKeys
|
||||
}
|
||||
: {})
|
||||
|
||||
// copy supported keys
|
||||
for (const key of keysToKeep) {
|
||||
if (obj[key] !== undefined) {
|
||||
filtered[key] = obj[key]
|
||||
}
|
||||
}
|
||||
|
||||
return filtered
|
||||
}
|
||||
|
||||
for (const [key, val] of Object.entries(properties)) {
|
||||
properties[key] = getSubMap(val, supportedAttributes)
|
||||
function filterPropertieAttributes(tool: MCPTool, filterNestedObj: boolean = false): Record<string, object> {
|
||||
const properties = tool.inputSchema.properties
|
||||
if (!properties) {
|
||||
return {}
|
||||
}
|
||||
|
||||
// For OpenAI, we don't need to validate as strictly
|
||||
if (!filterNestedObj) {
|
||||
return properties
|
||||
}
|
||||
|
||||
const processedProperties = Object.fromEntries(
|
||||
Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
|
||||
)
|
||||
|
||||
return processedProperties
|
||||
}
|
||||
|
||||
export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
|
||||
return mcpTools.map((tool) => ({
|
||||
type: 'function',
|
||||
@ -132,7 +233,7 @@ export async function callMCPTool(tool: MCPTool): Promise<any> {
|
||||
if (tool.serverName === 'mcp-auto-install') {
|
||||
if (resp.data) {
|
||||
const mcpServer: MCPServer = {
|
||||
id: nanoid(),
|
||||
id: `f${nanoid()}`,
|
||||
name: resp.data.name,
|
||||
description: resp.data.description,
|
||||
baseUrl: resp.data.baseUrl,
|
||||
@ -183,7 +284,7 @@ export function anthropicToolUseToMcpTool(mcpTools: MCPTool[] | undefined, toolU
|
||||
return tool
|
||||
}
|
||||
|
||||
export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiToool[] {
|
||||
export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTool[] {
|
||||
if (!mcpTools || mcpTools.length === 0) {
|
||||
// No tools available
|
||||
return []
|
||||
@ -195,18 +296,19 @@ export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTo
|
||||
const functionDeclaration: FunctionDeclaration = {
|
||||
name: tool.id,
|
||||
description: tool.description,
|
||||
...(Object.keys(properties).length > 0
|
||||
? {
|
||||
parameters: {
|
||||
type: SchemaType.OBJECT,
|
||||
properties
|
||||
}
|
||||
}
|
||||
: {})
|
||||
properties:
|
||||
Object.keys(properties).length > 0
|
||||
? Object.fromEntries(
|
||||
Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
|
||||
)
|
||||
: { _empty: { type: SchemaType.STRING } as SimpleStringSchema }
|
||||
} as FunctionDeclarationSchema
|
||||
}
|
||||
functions.push(functionDeclaration)
|
||||
}
|
||||
const tool: geminiToool = {
|
||||
const tool: geminiTool = {
|
||||
functionDeclarations: functions
|
||||
}
|
||||
return [tool]
|
||||
|
||||
10
yarn.lock
10
yarn.lock
@ -1212,10 +1212,10 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@google/generative-ai@npm:^0.21.0":
|
||||
version: 0.21.0
|
||||
resolution: "@google/generative-ai@npm:0.21.0"
|
||||
checksum: 10c0/cff5946c5964f2380e5097d82bd563d79be27a1a5ac604aaaad3f9ba3382992e4f0a371bd255baabfba4e5bdf296d8ce1410cbd65424afa98e64b2590fe49f3b
|
||||
"@google/generative-ai@npm:^0.24.0":
|
||||
version: 0.24.0
|
||||
resolution: "@google/generative-ai@npm:0.24.0"
|
||||
checksum: 10c0/31452bf2653cdee7fd61eb209f16ac0ef82c94c4175909ba40e1088e938e3e19e01f628dfb80d429dae3338fc8487e9a0fd8a6ff0164189f2722211175690b0b
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
@ -3767,7 +3767,7 @@ __metadata:
|
||||
"@eslint-react/eslint-plugin": "npm:^1.36.1"
|
||||
"@eslint/js": "npm:^9.22.0"
|
||||
"@google/genai": "npm:^0.4.0"
|
||||
"@google/generative-ai": "npm:^0.21.0"
|
||||
"@google/generative-ai": "npm:^0.24.0"
|
||||
"@hello-pangea/dnd": "npm:^16.6.0"
|
||||
"@kangfenmao/keyv-storage": "npm:^0.1.0"
|
||||
"@langchain/community": "npm:^0.3.36"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user