From 0fe7d559c8b2b1a979ee71479759afefe3e93ec6 Mon Sep 17 00:00:00 2001 From: LiuVaayne <10231735+vaayne@users.noreply.github.com> Date: Wed, 19 Mar 2025 20:09:05 +0800 Subject: [PATCH] feat[MCP]: Optimize list tool performance. (#3598) * refactor: remove unused filterMCPTools function calls from providers * fix: ensure enabledMCPs is checked for length before processing tools * feat: implement caching for tools retrieved from MCP server --- src/main/services/MCPService.ts | 24 +++++++++++++++++-- .../src/providers/AnthropicProvider.ts | 2 -- src/renderer/src/providers/GeminiProvider.ts | 2 -- src/renderer/src/providers/OpenAIProvider.ts | 2 -- src/renderer/src/services/ApiService.ts | 14 ++++++++--- 5 files changed, 33 insertions(+), 11 deletions(-) diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts index 48418aa0..f234d4d1 100644 --- a/src/main/services/MCPService.ts +++ b/src/main/services/MCPService.ts @@ -8,6 +8,7 @@ import log from 'electron-log' import { EventEmitter } from 'events' import { v4 as uuidv4 } from 'uuid' +import { CacheService } from './CacheService' import { windowService } from './WindowService' /** @@ -446,14 +447,33 @@ export default class MCPService extends EventEmitter { if (!this.clients[serverName]) { throw new Error(`MCP Client ${serverName} not found`) } + const cacheKey = `mcp:list_tool:${serverName}` + + if (CacheService.has(cacheKey)) { + log.info(`[MCP] Tools from ${serverName} loaded from cache`) + // Check if cache is still valid + const cachedTools = CacheService.get(cacheKey) + if (cachedTools && cachedTools.length > 0) { + return cachedTools + } + CacheService.remove(cacheKey) + } + const { tools } = await this.clients[serverName].listTools() - log.info(`[MCP] Tools from ${serverName}:`, tools) - return tools.map((tool: any) => ({ + const transformedTools = tools.map((tool: any) => ({ ...tool, serverName, id: 'f' + uuidv4().replace(/-/g, '') })) + + // Cache the tools for 5 minutes + if (transformedTools.length > 0) { + CacheService.set(cacheKey, transformedTools, 5 * 60 * 1000) + } + + log.info(`[MCP] Tools from ${serverName}:`, transformedTools) + return transformedTools } /** diff --git a/src/renderer/src/providers/AnthropicProvider.ts b/src/renderer/src/providers/AnthropicProvider.ts index 98cf044e..cad5df45 100644 --- a/src/renderer/src/providers/AnthropicProvider.ts +++ b/src/renderer/src/providers/AnthropicProvider.ts @@ -20,7 +20,6 @@ import { removeSpecialCharactersForTopicName } from '@renderer/utils' import { anthropicToolUseToMcpTool, callMCPTool, - filterMCPTools, mcpToolsToAnthropicTools, upsertMCPToolResponse } from '@renderer/utils/mcp-tools' @@ -180,7 +179,6 @@ export default class AnthropicProvider extends BaseProvider { const userMessages = flatten(userMessagesParams) const lastUserMessage = _messages.findLast((m) => m.role === 'user') - mcpTools = filterMCPTools(mcpTools, lastUserMessage?.enabledMCPs) const tools = mcpTools ? mcpToolsToAnthropicTools(mcpTools) : undefined const body: MessageCreateParamsNonStreaming = { diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index b73c9d0d..653877e3 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -27,7 +27,6 @@ import { Assistant, FileType, FileTypes, MCPToolResponse, Message, Model, Provid import { removeSpecialCharactersForTopicName } from '@renderer/utils' import { callMCPTool, - filterMCPTools, geminiFunctionCallToMcpTool, mcpToolsToGeminiTools, upsertMCPToolResponse @@ -197,7 +196,6 @@ export default class GeminiProvider extends BaseProvider { history.push(await this.getMessageContents(message)) } - mcpTools = filterMCPTools(mcpTools, userLastMessage?.enabledMCPs) const tools = mcpToolsToGeminiTools(mcpTools) const toolResponses: MCPToolResponse[] = [] diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts index 47828e77..802b39b8 100644 --- a/src/renderer/src/providers/OpenAIProvider.ts +++ b/src/renderer/src/providers/OpenAIProvider.ts @@ -29,7 +29,6 @@ import { import { removeSpecialCharactersForTopicName } from '@renderer/utils' import { callMCPTool, - filterMCPTools, mcpToolsToOpenAITools, openAIToolsToMcpTool, upsertMCPToolResponse @@ -426,7 +425,6 @@ export default class OpenAIProvider extends BaseProvider { const { signal } = abortController await this.checkIsCopilot() - mcpTools = filterMCPTools(mcpTools, lastUserMessage?.enabledMCPs) const tools = mcpTools && mcpTools.length > 0 ? mcpToolsToOpenAITools(mcpTools) : undefined const reqMessages: ChatCompletionMessageParam[] = [systemMessage, ...userMessages].filter( diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index 80a24409..a8e2e3c2 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -3,7 +3,7 @@ import { SEARCH_SUMMARY_PROMPT } from '@renderer/config/prompts' import i18n from '@renderer/i18n' import store from '@renderer/store' import { setGenerating } from '@renderer/store/runtime' -import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types' +import { Assistant, MCPTool, Message, Model, Provider, Suggestion } from '@renderer/types' import { formatMessageError, isAbortError } from '@renderer/utils/error' import { cloneDeep, findLast, isEmpty } from 'lodash' @@ -97,7 +97,15 @@ export async function fetchChatCompletion({ } } - const allMCPTools = await window.api.mcp.listTools() + const lastUserMessage = findLast(messages, (m) => m.role === 'user') + // Get MCP tools + let mcpTools: MCPTool[] = [] + const enabledMCPs = lastUserMessage?.enabledMCPs + + if (enabledMCPs && enabledMCPs.length > 0) { + const allMCPTools = await window.api.mcp.listTools() + mcpTools = allMCPTools.filter((tool) => enabledMCPs.some((mcp) => mcp.name === tool.serverName)) + } await AI.completions({ messages: filterUsefulMessages(messages), @@ -131,7 +139,7 @@ export async function fetchChatCompletion({ onResponse({ ...message, status: 'pending' }) }, - mcpTools: allMCPTools + mcpTools: mcpTools }) message.status = 'success'