feat[MCP]: Optimize list tool performance. (#3598)

* refactor: remove unused filterMCPTools function calls from providers

* fix: ensure enabledMCPs is checked for length before processing tools

* feat: implement caching for tools retrieved from MCP server
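The core change caches each server's tool list so repeated list-tool calls do not round-trip to the MCP server every time: a cache entry keyed by server name is checked first, empty or stale entries are evicted, and fresh results are stored with a 5-minute TTL. The snippet below is only an illustrative sketch of that pattern, not the project's code; `SimpleCache`, `fetchTools`, and `listToolsCached` are hypothetical stand-ins for the app's `CacheService` and `MCPService.listTools` shown in the diff further down.

```ts
// Illustrative sketch only: a minimal in-memory TTL cache mirroring the pattern
// in this commit (check -> validate non-empty -> fetch -> store for 5 minutes).
type Entry<T> = { value: T; expiresAt: number }

class SimpleCache {
  private store = new Map<string, Entry<unknown>>()

  has(key: string): boolean {
    const entry = this.store.get(key)
    if (!entry) return false
    if (Date.now() > entry.expiresAt) {
      // Expired entries are dropped lazily on access
      this.store.delete(key)
      return false
    }
    return true
  }

  get<T>(key: string): T | undefined {
    return this.has(key) ? (this.store.get(key)!.value as T) : undefined
  }

  set<T>(key: string, value: T, ttlMs: number): void {
    this.store.set(key, { value, expiresAt: Date.now() + ttlMs })
  }

  remove(key: string): void {
    this.store.delete(key)
  }
}

const cache = new SimpleCache()

// Same shape as the new listTools path: serve from cache when a non-empty entry
// exists, otherwise fetch from the server and cache the result for 5 minutes.
async function listToolsCached(serverName: string, fetchTools: () => Promise<string[]>) {
  const cacheKey = `mcp:list_tool:${serverName}`
  const cached = cache.get<string[]>(cacheKey)
  if (cached && cached.length > 0) return cached
  cache.remove(cacheKey)

  const tools = await fetchTools()
  if (tools.length > 0) cache.set(cacheKey, tools, 5 * 60 * 1000)
  return tools
}
```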
Author: LiuVaayne, 2025-03-19 20:09:05 +08:00, committed by GitHub
parent eef141cbe7
commit 0fe7d559c8
5 changed files with 33 additions and 11 deletions

View File

@@ -8,6 +8,7 @@ import log from 'electron-log'
 import { EventEmitter } from 'events'
 import { v4 as uuidv4 } from 'uuid'
+import { CacheService } from './CacheService'
 import { windowService } from './WindowService'
 /**
@@ -446,14 +447,33 @@ export default class MCPService extends EventEmitter {
     if (!this.clients[serverName]) {
       throw new Error(`MCP Client ${serverName} not found`)
     }
+    const cacheKey = `mcp:list_tool:${serverName}`
+    if (CacheService.has(cacheKey)) {
+      log.info(`[MCP] Tools from ${serverName} loaded from cache`)
+      // Check if cache is still valid
+      const cachedTools = CacheService.get<MCPTool[]>(cacheKey)
+      if (cachedTools && cachedTools.length > 0) {
+        return cachedTools
+      }
+      CacheService.remove(cacheKey)
+    }
     const { tools } = await this.clients[serverName].listTools()
-    log.info(`[MCP] Tools from ${serverName}:`, tools)
-    return tools.map((tool: any) => ({
+    const transformedTools = tools.map((tool: any) => ({
       ...tool,
       serverName,
       id: 'f' + uuidv4().replace(/-/g, '')
     }))
+    // Cache the tools for 5 minutes
+    if (transformedTools.length > 0) {
+      CacheService.set(cacheKey, transformedTools, 5 * 60 * 1000)
+    }
+    log.info(`[MCP] Tools from ${serverName}:`, transformedTools)
+    return transformedTools
   }
   /**

View File

@@ -20,7 +20,6 @@ import { removeSpecialCharactersForTopicName } from '@renderer/utils'
 import {
   anthropicToolUseToMcpTool,
   callMCPTool,
-  filterMCPTools,
   mcpToolsToAnthropicTools,
   upsertMCPToolResponse
 } from '@renderer/utils/mcp-tools'
@@ -180,7 +179,6 @@ export default class AnthropicProvider extends BaseProvider {
     const userMessages = flatten(userMessagesParams)
     const lastUserMessage = _messages.findLast((m) => m.role === 'user')
-    mcpTools = filterMCPTools(mcpTools, lastUserMessage?.enabledMCPs)
     const tools = mcpTools ? mcpToolsToAnthropicTools(mcpTools) : undefined
     const body: MessageCreateParamsNonStreaming = {

View File

@@ -27,7 +27,6 @@ import { Assistant, FileType, FileTypes, MCPToolResponse, Message, Model, Provid
 import { removeSpecialCharactersForTopicName } from '@renderer/utils'
 import {
   callMCPTool,
-  filterMCPTools,
   geminiFunctionCallToMcpTool,
   mcpToolsToGeminiTools,
   upsertMCPToolResponse
@@ -197,7 +196,6 @@ export default class GeminiProvider extends BaseProvider {
       history.push(await this.getMessageContents(message))
     }
-    mcpTools = filterMCPTools(mcpTools, userLastMessage?.enabledMCPs)
     const tools = mcpToolsToGeminiTools(mcpTools)
     const toolResponses: MCPToolResponse[] = []

View File

@@ -29,7 +29,6 @@ import {
 import { removeSpecialCharactersForTopicName } from '@renderer/utils'
 import {
   callMCPTool,
-  filterMCPTools,
   mcpToolsToOpenAITools,
   openAIToolsToMcpTool,
   upsertMCPToolResponse
@@ -426,7 +425,6 @@ export default class OpenAIProvider extends BaseProvider {
     const { signal } = abortController
     await this.checkIsCopilot()
-    mcpTools = filterMCPTools(mcpTools, lastUserMessage?.enabledMCPs)
     const tools = mcpTools && mcpTools.length > 0 ? mcpToolsToOpenAITools(mcpTools) : undefined
     const reqMessages: ChatCompletionMessageParam[] = [systemMessage, ...userMessages].filter(

View File

@@ -3,7 +3,7 @@ import { SEARCH_SUMMARY_PROMPT } from '@renderer/config/prompts'
 import i18n from '@renderer/i18n'
 import store from '@renderer/store'
 import { setGenerating } from '@renderer/store/runtime'
-import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types'
+import { Assistant, MCPTool, Message, Model, Provider, Suggestion } from '@renderer/types'
 import { formatMessageError, isAbortError } from '@renderer/utils/error'
 import { cloneDeep, findLast, isEmpty } from 'lodash'
@@ -97,7 +97,15 @@ export async function fetchChatCompletion({
     }
   }
-  const allMCPTools = await window.api.mcp.listTools()
+  const lastUserMessage = findLast(messages, (m) => m.role === 'user')
+  // Get MCP tools
+  let mcpTools: MCPTool[] = []
+  const enabledMCPs = lastUserMessage?.enabledMCPs
+  if (enabledMCPs && enabledMCPs.length > 0) {
+    const allMCPTools = await window.api.mcp.listTools()
+    mcpTools = allMCPTools.filter((tool) => enabledMCPs.some((mcp) => mcp.name === tool.serverName))
+  }
   await AI.completions({
     messages: filterUsefulMessages(messages),
@@ -131,7 +139,7 @@
       onResponse({ ...message, status: 'pending' })
     },
-    mcpTools: allMCPTools
+    mcpTools: mcpTools
   })
   message.status = 'success'