From 24e28b86cf78bb9f83639ce81dd65b5e1d965bda Mon Sep 17 00:00:00 2001
From: LiuVaayne <10231735+vaayne@users.noreply.github.com>
Date: Wed, 9 Apr 2025 11:22:14 +0800
Subject: [PATCH] feat(mcp): support MCP by prompt (#4476)

* feat: implement tool usage handling and system prompt building for AI providers

* refactor: streamline tool usage handling and remove unused code in OpenAIProvider and formats

* refactor: simplify tool usage handling in Anthropic and Gemini providers, and update prompt instructions

* refactor: remove unused function calling model checks and simplify MCP tools handling in Inputbar

* hidden tool use in message

* revert import

* Add idx parameter to parseAndCallTools for unique tool IDs
---
 .../src/pages/home/Inputbar/Inputbar.tsx      |  32 +--
 .../pages/home/Messages/MessageContent.tsx    |   4 +-
 .../providers/AiProvider/AnthropicProvider.ts |  62 ++---
 .../providers/AiProvider/GeminiProvider.ts    |  96 +++----
 .../providers/AiProvider/OpenAIProvider.ts    | 238 ++++--------------
 src/renderer/src/services/ApiService.ts       |   2 +-
 src/renderer/src/utils/mcp-tools.ts           |  86 ++++++-
 src/renderer/src/utils/prompt.ts              | 158 ++++++++++++
 8 files changed, 365 insertions(+), 313 deletions(-)
 create mode 100644 src/renderer/src/utils/prompt.ts

diff --git a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx
index 63f23bc7..899f2380 100644
--- a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx
+++ b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx
@@ -15,7 +15,7 @@ import {
 } from '@ant-design/icons'
 import { QuickPanelListItem, QuickPanelView, useQuickPanel } from '@renderer/components/QuickPanel'
 import TranslateButton from '@renderer/components/TranslateButton'
-import { isFunctionCallingModel, isGenerateImageModel, isVisionModel, isWebSearchModel } from '@renderer/config/models'
+import { isGenerateImageModel, isVisionModel, isWebSearchModel } from '@renderer/config/models'
 import db from '@renderer/databases'
 import { useAssistant } from '@renderer/hooks/useAssistant'
 import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
@@ -118,7 +118,7 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) => {
   const quickPanel = useQuickPanel()
   const showKnowledgeIcon = useSidebarIconShow('knowledge')
-  const showMCPToolsIcon = isFunctionCallingModel(model)
+  // const showMCPToolsIcon = isFunctionCallingModel(model)
 
   const [tokenCount, setTokenCount] = useState(0)
 
@@ -198,10 +198,8 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) => {
       userMessage.mentions = mentionModels
     }
 
-    if (isFunctionCallingModel(model)) {
-      if (!isEmpty(enabledMCPs) && !isEmpty(activedMcpServers)) {
-        userMessage.enabledMCPs = activedMcpServers.filter((server) => enabledMCPs?.some((s) => s.id === server.id))
-      }
+    if (!isEmpty(enabledMCPs) && !isEmpty(activedMcpServers)) {
+      userMessage.enabledMCPs = activedMcpServers.filter((server) => enabledMCPs?.some((s) => s.id === server.id))
     }
 
     userMessage.usage = await estimateMessageUsage(userMessage)
@@ -230,7 +228,6 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) => {
       inputEmpty,
       loading,
       mentionModels,
-      model,
       resizeTextArea,
       selectedKnowledgeBases,
       text,
@@ -346,17 +343,16 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) => {
         description: '',
         icon: ,
         isMenu: true,
-        disabled: !showKnowledgeIcon || files.length > 0,
+        disabled: files.length > 0,
         action: () => {
          knowledgeBaseButtonRef.current?.openQuickPanel()
        }
      },
      {
        label: t('settings.mcp.title'),
-        description: showMCPToolsIcon ? '' : t('settings.mcp.not_support'),
+        description: t('settings.mcp.not_support'),
         icon: ,
         isMenu: true,
-        disabled: !showMCPToolsIcon,
         action: () => {
           mcpToolsButtonRef.current?.openQuickPanel()
         }
       },
@@ -378,7 +374,7 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) => {
         }
       }
     ]
-  }, [files.length, model, openSelectFileMenu, showKnowledgeIcon, showMCPToolsIcon, t, text, translate])
+  }, [files.length, model, openSelectFileMenu, t, text, translate])
 
   const handleKeyDown = (event: React.KeyboardEvent) => {
     const isEnterPressed = event.keyCode == 13
@@ -954,14 +950,12 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) => {
               disabled={files.length > 0}
             />
           )}
-          {showMCPToolsIcon && (
-            
-          )}
+          

diff --git a/src/renderer/src/pages/home/Messages/MessageContent.tsx b/src/renderer/src/pages/home/Messages/MessageContent.tsx
--- a/src/renderer/src/pages/home/Messages/MessageContent.tsx
+++ b/src/renderer/src/pages/home/Messages/MessageContent.tsx
 const MessageContent: React.FC = ({ message: _message, model }) => {
     const content = `[@${model.name}](#) ${getBriefInfo(message.content)}`
     return 
   }
-
+  const toolUseRegex = /<tool_use>([\s\S]*?)<\/tool_use>/g
   return (
@@ -205,7 +205,7 @@ const MessageContent: React.FC = ({ message: _message, model }) => {
-      
+      
       {message.metadata?.generateImage && }
       {message.translatedContent && (

diff --git a/src/renderer/src/providers/AiProvider/AnthropicProvider.ts b/src/renderer/src/providers/AiProvider/AnthropicProvider.ts
index f495646d..d9f72f71 100644
--- a/src/renderer/src/providers/AiProvider/AnthropicProvider.ts
+++ b/src/renderer/src/providers/AiProvider/AnthropicProvider.ts
@@ -1,10 +1,5 @@
 import Anthropic from '@anthropic-ai/sdk'
-import {
-  MessageCreateParamsNonStreaming,
-  MessageParam,
-  ToolResultBlockParam,
-  ToolUseBlock
-} from '@anthropic-ai/sdk/resources'
+import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources'
 import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
 import { isReasoningModel } from '@renderer/config/models'
 import { getStoreSetting } from '@renderer/hooks/useSettings'
 import {
   filterContextMessages,
   filterMessages,
   filterUserRoleStartMessages,
 } from '@renderer/services/MessagesService'
 import { Assistant, FileTypes, MCPToolResponse, Message, Model, Provider, Suggestion } from '@renderer/types'
 import { removeSpecialCharactersForTopicName } from '@renderer/utils'
-import {
-  anthropicToolUseToMcpTool,
-  callMCPTool,
-  mcpToolsToAnthropicTools,
-  upsertMCPToolResponse
-} from '@renderer/utils/mcp-tools'
+import { parseAndCallTools } from '@renderer/utils/mcp-tools'
+import { buildSystemPrompt } from '@renderer/utils/prompt'
-import { first, flatten, isEmpty, sum, takeRight } from 'lodash'
+import { first, flatten, sum, takeRight } from 'lodash'
 import OpenAI from 'openai'
 
 import { CompletionsParams } from '.'
@@ -182,16 +173,21 @@ export default class AnthropicProvider extends BaseProvider {
     const userMessages = flatten(userMessagesParams)
     const lastUserMessage = _messages.findLast((m) => m.role === 'user')
 
-    const tools = mcpTools ? mcpToolsToAnthropicTools(mcpTools) : undefined
+    // const tools = mcpTools ? mcpToolsToAnthropicTools(mcpTools) : undefined
+
+    let systemPrompt = assistant.prompt
+    if (mcpTools && mcpTools.length > 0) {
+      systemPrompt = buildSystemPrompt(systemPrompt, mcpTools)
+    }
 
     const body: MessageCreateParamsNonStreaming = {
       model: model.id,
       messages: userMessages,
-      tools: isEmpty(tools) ? undefined : tools,
+      // tools: isEmpty(tools) ?
undefined : tools, max_tokens: maxTokens || DEFAULT_MAX_TOKENS, temperature: this.getTemperature(assistant, model), top_p: this.getTopP(assistant, model), - system: assistant.prompt, + system: systemPrompt, // @ts-ignore thinking thinking: this.getReasoningEffort(assistant, model), ...this.getCustomParameters(assistant) @@ -239,7 +235,6 @@ export default class AnthropicProvider extends BaseProvider { const processStream = (body: MessageCreateParamsNonStreaming, idx: number) => { return new Promise((resolve, reject) => { - const toolCalls: ToolUseBlock[] = [] let hasThinkingContent = false this.sdk.messages .stream({ ...body, stream: true }, { signal }) @@ -292,30 +287,11 @@ export default class AnthropicProvider extends BaseProvider { } }) }) - .on('contentBlock', (content) => { - if (content.type == 'tool_use') { - toolCalls.push(content) - } - }) .on('finalMessage', async (message) => { - if (toolCalls.length > 0) { - const toolCallResults: ToolResultBlockParam[] = [] - - for (const toolCall of toolCalls) { - const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall) - if (mcpTool) { - upsertMCPToolResponse(toolResponses, { tool: mcpTool, status: 'invoking', id: toolCall.id }, onChunk) - const resp = await callMCPTool(mcpTool) - toolCallResults.push({ type: 'tool_result', tool_use_id: toolCall.id, content: resp.content }) - upsertMCPToolResponse( - toolResponses, - { tool: mcpTool, status: 'done', response: resp, id: toolCall.id }, - onChunk - ) - } - } - - if (toolCallResults.length > 0) { + const content = message.content[0] + if (content && content.type === 'text') { + const toolResults = await parseAndCallTools(content.text, toolResponses, onChunk, idx, mcpTools) + if (toolResults.length > 0) { userMessages.push({ role: message.role, content: message.content @@ -323,12 +299,10 @@ export default class AnthropicProvider extends BaseProvider { userMessages.push({ role: 'user', - content: toolCallResults + content: toolResults.join('\n') }) - const newBody = body - body.messages = userMessages - + newBody.messages = userMessages await processStream(newBody, idx + 1) } } diff --git a/src/renderer/src/providers/AiProvider/GeminiProvider.ts b/src/renderer/src/providers/AiProvider/GeminiProvider.ts index 7717d6cb..be40686f 100644 --- a/src/renderer/src/providers/AiProvider/GeminiProvider.ts +++ b/src/renderer/src/providers/AiProvider/GeminiProvider.ts @@ -8,8 +8,6 @@ import { import { Content, FileDataPart, - FunctionCallPart, - FunctionResponsePart, GenerateContentStreamResult, GoogleGenerativeAI, HarmBlockThreshold, @@ -18,7 +16,8 @@ import { Part, RequestOptions, SafetySetting, - TextPart + TextPart, + Tool } from '@google/generative-ai' import { isGemmaModel, isWebSearchModel } from '@renderer/config/models' import { getStoreSetting } from '@renderer/hooks/useSettings' @@ -33,12 +32,8 @@ import { import WebSearchService from '@renderer/services/WebSearchService' import { Assistant, FileType, FileTypes, MCPToolResponse, Message, Model, Provider, Suggestion } from '@renderer/types' import { removeSpecialCharactersForTopicName } from '@renderer/utils' -import { - callMCPTool, - geminiFunctionCallToMcpTool, - mcpToolsToGeminiTools, - upsertMCPToolResponse -} from '@renderer/utils/mcp-tools' +import { parseAndCallTools } from '@renderer/utils/mcp-tools' +import { buildSystemPrompt } from '@renderer/utils/prompt' import { MB } from '@shared/config/constant' import axios from 'axios' import { isEmpty, takeRight } from 'lodash' @@ -230,7 +225,14 @@ export default class GeminiProvider extends 
BaseProvider {
       history.push(await this.getMessageContents(message))
     }
 
-    const tools = mcpToolsToGeminiTools(mcpTools)
+    let systemInstruction = assistant.prompt
+
+    if (mcpTools && mcpTools.length > 0) {
+      systemInstruction = buildSystemPrompt(assistant.prompt || '', mcpTools)
+    }
+
+    // const tools = mcpToolsToGeminiTools(mcpTools)
+    const tools: Tool[] = []
 
     const toolResponses: MCPToolResponse[] = []
 
     if (!WebSearchService.isOverwriteEnabled() && assistant.enableWebSearch && isWebSearchModel(model)) {
@@ -243,7 +245,7 @@ export default class GeminiProvider extends BaseProvider {
     const geminiModel = this.sdk.getGenerativeModel(
       {
         model: model.id,
-        ...(isGemmaModel(model) ? {} : { systemInstruction: assistant.prompt }),
+        ...(isGemmaModel(model) ? {} : { systemInstruction: systemInstruction }),
         safetySettings: this.getSafetySettings(model.id),
         tools: tools,
         generationConfig: {
@@ -268,7 +270,7 @@ export default class GeminiProvider extends BaseProvider {
           {
             text:
               '<start_of_turn>user\n' +
-              assistant.prompt +
+              systemInstruction +
               '<end_of_turn>\n' +
               '<start_of_turn>user\n' +
               messageContents.parts[0].text +
               '<end_of_turn>'
@@ -307,7 +309,25 @@ export default class GeminiProvider extends BaseProvider {
     const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal })
 
     let time_first_token_millsec = 0
 
+    const processToolUses = async (content: string, idx: number) => {
+      const toolResults = await parseAndCallTools(content, toolResponses, onChunk, idx, mcpTools)
+      if (toolResults && toolResults.length > 0) {
+        history.push(messageContents)
+        const newChat = geminiModel.startChat({ history })
+        const newStream = await newChat.sendMessageStream(
+          [
+            {
+              text: toolResults.join('\n')
+            }
+          ],
+          { signal }
+        )
+        await processStream(newStream, idx + 1)
+      }
+    }
+
     const processStream = async (stream: GenerateContentStreamResult, idx: number) => {
+      let content = ''
       for await (const chunk of stream.stream) {
         if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
@@ -317,56 +337,8 @@ export default class GeminiProvider extends BaseProvider {
 
         const time_completion_millsec = new Date().getTime() - start_time_millsec
 
-        const functionCalls = chunk.functionCalls()
-
-        if (functionCalls) {
-          const fcallParts: FunctionCallPart[] = []
-          const fcRespParts: FunctionResponsePart[] = []
-          for (const call of functionCalls) {
-            console.log('Function call:', call)
-            fcallParts.push({ functionCall: call } as FunctionCallPart)
-            const mcpTool = geminiFunctionCallToMcpTool(mcpTools, call)
-            if (mcpTool) {
-              upsertMCPToolResponse(
-                toolResponses,
-                {
-                  tool: mcpTool,
-                  status: 'invoking',
-                  id: `${call.name}-${idx}`
-                },
-                onChunk
-              )
-              const toolCallResponse = await callMCPTool(mcpTool)
-              fcRespParts.push({
-                functionResponse: {
-                  name: mcpTool.id,
-                  response: toolCallResponse
-                }
-              })
-              upsertMCPToolResponse(
-                toolResponses,
-                {
-                  tool: mcpTool,
-                  status: 'done',
-                  response: toolCallResponse,
-                  id: `${call.name}-${idx}`
-                },
-                onChunk
-              )
-            }
-          }
-
-          if (fcRespParts) {
-            history.push(messageContents)
-            history.push({
-              role: 'model',
-              parts: fcallParts
-            })
-            const newChat = geminiModel.startChat({ history })
-            const newStream = await newChat.sendMessageStream(fcRespParts, { signal })
-            await processStream(newStream, idx + 1)
-          }
-        }
+        content += chunk.text()
+        await processToolUses(content, idx)
 
         onChunk({
           text: chunk.text(),

diff --git a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts b/src/renderer/src/providers/AiProvider/OpenAIProvider.ts
index 97f639b1..1ece14e9 100644
--- a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts
+++ 
b/src/renderer/src/providers/AiProvider/OpenAIProvider.ts @@ -31,21 +31,14 @@ import { } from '@renderer/types' import { removeSpecialCharactersForTopicName } from '@renderer/utils' import { addImageFileToContents } from '@renderer/utils/formats' -import { - callMCPTool, - mcpToolsToOpenAITools, - openAIToolsToMcpTool, - upsertMCPToolResponse -} from '@renderer/utils/mcp-tools' +import { parseAndCallTools } from '@renderer/utils/mcp-tools' +import { buildSystemPrompt } from '@renderer/utils/prompt' import { isEmpty, takeRight } from 'lodash' import OpenAI, { AzureOpenAI } from 'openai' import { - ChatCompletionAssistantMessageParam, ChatCompletionContentPart, ChatCompletionCreateParamsNonStreaming, - ChatCompletionMessageParam, - ChatCompletionMessageToolCall, - ChatCompletionToolMessageParam + ChatCompletionMessageParam } from 'openai/resources' import { CompletionsParams } from '.' @@ -296,55 +289,6 @@ export default class OpenAIProvider extends BaseProvider { return model.id.startsWith('o1') || model.id.startsWith('o3') } - /** - * Check if the model is a Glm-4-alltools - * @param model - The model - * @returns True if the model is a Glm-4-alltools, false otherwise - */ - private isZhipuTool(model: Model) { - return model.id.includes('glm-4-alltools') - } - - /** - * Clean the tool call arguments - * @param toolCall - The tool call - * @returns The cleaned tool call - */ - private cleanToolCallArgs(toolCall: ChatCompletionMessageToolCall): ChatCompletionMessageToolCall { - if (toolCall.function.arguments) { - let args = toolCall.function.arguments - const codeBlockRegex = /```(?:\w*\n)?([\s\S]*?)```/ - const match = args.match(codeBlockRegex) - if (match) { - // Extract content from code block - let extractedArgs = match[1].trim() - // Clean function call format like tool_call(name1=value1,name2=value2) - const functionCallRegex = /^\s*\w+\s*\(([\s\S]*?)\)\s*$/ - const functionMatch = extractedArgs.match(functionCallRegex) - if (functionMatch) { - // Try to convert parameters to JSON format - const params = functionMatch[1].split(',').filter(Boolean) - const paramsObj = {} - params.forEach((param) => { - const [name, value] = param.split('=').map((p) => p.trim()) - if (name && value !== undefined) { - paramsObj[name] = value - } - }) - extractedArgs = JSON.stringify(paramsObj) - } - toolCall.function.arguments = extractedArgs - } - args = toolCall.function.arguments - const firstBraceIndex = args.indexOf('{') - const lastBraceIndex = args.lastIndexOf('}') - if (firstBraceIndex !== -1 && lastBraceIndex !== -1 && firstBraceIndex < lastBraceIndex) { - toolCall.function.arguments = args.substring(firstBraceIndex, lastBraceIndex + 1) - } - } - return toolCall - } - /** * Generate completions for the assistant * @param messages - The messages @@ -359,14 +303,16 @@ export default class OpenAIProvider extends BaseProvider { const model = assistant.model || defaultModel const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) messages = addImageFileToContents(messages) - let systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined - + let systemMessage = { role: 'system', content: assistant.prompt || '' } if (isOpenAIoSeries(model)) { systemMessage = { role: 'developer', content: `Formatting re-enabled${systemMessage ? 
'\n' + systemMessage.content : ''}` } } + if (mcpTools && mcpTools.length > 0) { + systemMessage.content = buildSystemPrompt(systemMessage.content || '', mcpTools) + } const userMessages: ChatCompletionMessageParam[] = [] const _messages = filterUserRoleStartMessages( @@ -429,14 +375,51 @@ export default class OpenAIProvider extends BaseProvider { const { signal } = abortController await this.checkIsCopilot() - const tools = mcpTools && mcpTools.length > 0 ? mcpToolsToOpenAITools(mcpTools) : undefined - const reqMessages: ChatCompletionMessageParam[] = [systemMessage, ...userMessages].filter( Boolean ) as ChatCompletionMessageParam[] const toolResponses: MCPToolResponse[] = [] let firstChunk = true + + const processToolUses = async (content: string, idx: number) => { + const toolResults = await parseAndCallTools(content, toolResponses, onChunk, idx, mcpTools) + + if (toolResults.length > 0) { + reqMessages.push({ + role: 'assistant', + content: content + } as ChatCompletionMessageParam) + reqMessages.push({ + role: 'user', + content: toolResults.join('\n') + } as ChatCompletionMessageParam) + + const newStream = await this.sdk.chat.completions + // @ts-ignore key is not typed + .create( + { + model: model.id, + messages: reqMessages, + temperature: this.getTemperature(assistant, model), + top_p: this.getTopP(assistant, model), + max_tokens: maxTokens, + keep_alive: this.keepAliveTime, + stream: isSupportStreamOutput(), + // tools: tools, + ...getOpenAIWebSearchParams(assistant, model), + ...this.getReasoningEffort(assistant, model), + ...this.getProviderSpecificParameters(assistant, model), + ...this.getCustomParameters(assistant) + }, + { + signal + } + ) + await processStream(newStream, idx + 1) + } + } + const processStream = async (stream: any, idx: number) => { if (!isSupportStreamOutput()) { const time_completion_millsec = new Date().getTime() - start_time_millsec @@ -450,14 +433,17 @@ export default class OpenAIProvider extends BaseProvider { } }) } - const final_tool_calls = {} as Record + let content = '' for await (const chunk of stream) { if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) { break } const delta = chunk.choices[0]?.delta + if (delta?.content) { + content += delta.content + } if (delta?.reasoning_content || delta?.reasoning) { hasReasoningContent = true @@ -479,29 +465,6 @@ export default class OpenAIProvider extends BaseProvider { const finishReason = chunk.choices[0]?.finish_reason - if (delta?.tool_calls?.length) { - const chunkToolCalls = delta.tool_calls - for (const t of chunkToolCalls) { - const { index, id, function: fn, type } = t - const args = fn && typeof fn.arguments === 'string' ? 
fn.arguments : '' - if (!(index in final_tool_calls)) { - final_tool_calls[index] = { - id, - function: { - name: fn?.name, - arguments: args - }, - type - } as ChatCompletionMessageToolCall - } else { - final_tool_calls[index].function.arguments += args - } - } - if (finishReason !== 'tool_calls') { - continue - } - } - let webSearch: any[] | undefined = undefined if (assistant.enableWebSearch && isZhipuModel(model) && finishReason === 'stop') { webSearch = chunk?.web_search @@ -510,102 +473,6 @@ export default class OpenAIProvider extends BaseProvider { webSearch = chunk?.search_info?.search_results firstChunk = true } - - if (finishReason === 'tool_calls' || (finishReason === 'stop' && Object.keys(final_tool_calls).length > 0)) { - const toolCalls = Object.values(final_tool_calls).map(this.cleanToolCallArgs) - console.log('start invoke tools', toolCalls) - if (this.isZhipuTool(model)) { - reqMessages.push({ - role: 'assistant', - content: `argments=${JSON.stringify(toolCalls[0].function.arguments)}` - }) - } else { - reqMessages.push({ - role: 'assistant', - tool_calls: toolCalls - } as ChatCompletionAssistantMessageParam) - } - - for (const toolCall of toolCalls) { - const mcpTool = openAIToolsToMcpTool(mcpTools, toolCall) - - if (!mcpTool) { - continue - } - - upsertMCPToolResponse(toolResponses, { tool: mcpTool, status: 'invoking', id: toolCall.id }, onChunk) - - const toolCallResponse = await callMCPTool(mcpTool) - const toolResponsContent: { type: string; text?: string; image_url?: { url: string } }[] = [] - for (const content of toolCallResponse.content) { - if (content.type === 'text') { - toolResponsContent.push({ - type: 'text', - text: content.text - }) - } else if (content.type === 'image') { - toolResponsContent.push({ - type: 'image_url', - image_url: { url: `data:${content.mimeType};base64,${content.data}` } - }) - } else { - console.warn('Unsupported content type:', content.type) - toolResponsContent.push({ - type: 'text', - text: 'unsupported content type: ' + content.type - }) - } - } - - const provider = lastUserMessage?.model?.provider - const modelName = lastUserMessage?.model?.name - - if ( - modelName?.toLocaleLowerCase().includes('gpt') || - (provider === 'dashscope' && modelName?.toLocaleLowerCase().includes('qwen')) - ) { - reqMessages.push({ - role: 'tool', - content: toolResponsContent, - tool_call_id: toolCall.id - } as ChatCompletionToolMessageParam) - } else { - reqMessages.push({ - role: 'tool', - content: JSON.stringify(toolResponsContent), - tool_call_id: toolCall.id - } as ChatCompletionToolMessageParam) - } - upsertMCPToolResponse( - toolResponses, - { tool: mcpTool, status: 'done', response: toolCallResponse, id: toolCall.id }, - onChunk - ) - } - const newStream = await this.sdk.chat.completions - // @ts-ignore key is not typed - .create( - { - model: model.id, - messages: reqMessages, - temperature: this.getTemperature(assistant, model), - top_p: this.getTopP(assistant, model), - max_tokens: maxTokens, - keep_alive: this.keepAliveTime, - stream: isSupportStreamOutput(), - tools: tools, - ...getOpenAIWebSearchParams(assistant, model), - ...this.getReasoningEffort(assistant, model), - ...this.getProviderSpecificParameters(assistant, model), - ...this.getCustomParameters(assistant) - }, - { - signal - } - ) - await processStream(newStream, idx + 1) - } - onChunk({ text: delta?.content || '', reasoning_content: delta?.reasoning_content || delta?.reasoning || '', @@ -622,7 +489,10 @@ export default class OpenAIProvider extends BaseProvider { 
             mcpToolResponse: toolResponses
           })
         }
+
+        await processToolUses(content, idx)
       }
+
       const stream = await this.sdk.chat.completions
         // @ts-ignore key is not typed
         .create(
           {
             model: model.id,
             messages: reqMessages,
             temperature: this.getTemperature(assistant, model),
             top_p: this.getTopP(assistant, model),
             max_tokens: maxTokens,
             keep_alive: this.keepAliveTime,
             stream: isSupportStreamOutput(),
-            tools: tools,
+            // tools: tools,
             ...getOpenAIWebSearchParams(assistant, model),
             ...this.getReasoningEffort(assistant, model),
             ...this.getProviderSpecificParameters(assistant, model),
             ...this.getCustomParameters(assistant)

diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts
index c87b7856..d5447ec9 100644
--- a/src/renderer/src/services/ApiService.ts
+++ b/src/renderer/src/services/ApiService.ts
@@ -250,7 +250,7 @@ export async function fetchChatCompletion({
         }
       }
     }
-    console.log('message', message)
+    // console.log('message', message)
   } catch (error: any) {
     if (isAbortError(error)) {
       message.status = 'paused'

diff --git a/src/renderer/src/utils/mcp-tools.ts b/src/renderer/src/utils/mcp-tools.ts
index a7573c3a..a5a6caec 100644
--- a/src/renderer/src/utils/mcp-tools.ts
+++ b/src/renderer/src/utils/mcp-tools.ts
@@ -21,7 +21,7 @@ import { addMCPServer } from '@renderer/store/mcp'
 import { MCPServer, MCPTool, MCPToolResponse } from '@renderer/types'
 import { ChatCompletionMessageToolCall, ChatCompletionTool } from 'openai/resources'
 
-import { ChunkCallbackData } from '../providers/AiProvider'
+import { ChunkCallbackData, CompletionsParams } from '../providers/AiProvider'
 
 const ensureValidSchema = (obj: Record<string, any>): FunctionDeclarationSchemaProperty => {
   // Filter out unsupported keys for Gemini
@@ -375,3 +375,87 @@ export function getMcpServerByTool(tool: MCPTool) {
   const servers = store.getState().mcp.servers
   return servers.find((s) => s.id === tool.serverId)
 }
+
+export function parseToolUse(content: string, mcpTools: MCPTool[]): MCPToolResponse[] {
+  if (!content || !mcpTools || mcpTools.length === 0) {
+    return []
+  }
+  const toolUsePattern =
+    /<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g
+  const tools: MCPToolResponse[] = []
+  let match
+  let idx = 0
+  // Find all tool use blocks
+  while ((match = toolUsePattern.exec(content)) !== null) {
+    // const fullMatch = match[0]
+    const toolName = match[2].trim()
+    const toolArgs = match[4].trim()
+
+    // Try to parse the arguments as JSON
+    let parsedArgs
+    try {
+      parsedArgs = JSON.parse(toolArgs)
+    } catch (error) {
+      // If parsing fails, use the string as is
+      parsedArgs = toolArgs
+    }
+    // console.log(`Parsed arguments for tool "${toolName}":`, parsedArgs)
+    const mcpTool = mcpTools.find((tool) => tool.id === toolName)
+    if (!mcpTool) {
+      console.error(`Tool "${toolName}" not found in MCP tools`)
+      continue
+    }
+
+    // Add to tools array
+    tools.push({
+      id: `${toolName}-${idx++}`, // Unique ID for each tool use
+      tool: {
+        ...mcpTool,
+        inputSchema: parsedArgs
+      },
+      status: 'pending'
+    })
+
+    // Remove the tool use block from the content
+    // content = content.replace(fullMatch, '')
+  }
+  return tools
+}
+
+export async function parseAndCallTools(
+  content: string,
+  toolResponses: MCPToolResponse[],
+  onChunk: CompletionsParams['onChunk'],
+  idx: number,
+  mcpTools?: MCPTool[]
+): Promise<string[]> {
+  const toolResults: string[] = []
+  // process tool use
+  const tools = parseToolUse(content, mcpTools || [])
+  if (!tools || tools.length === 0) {
+    return toolResults
+  }
+  for (let i = 0; i < tools.length; i++) {
+    const tool = tools[i]
+    upsertMCPToolResponse(toolResponses, { id: `${tool.id}-${idx}-${i}`, tool: tool.tool, status: 'invoking' }, onChunk)
+  }
+
+  const toolPromises = tools.map(async (tool, i) => {
+    const toolCallResponse = await callMCPTool(tool.tool)
+    const result = `
+<tool_use_result>
+  <name>${tool.id}</name>
+  <result>${JSON.stringify(toolCallResponse)}</result>
+</tool_use_result>
+`.trim()
+    upsertMCPToolResponse(
+      toolResponses,
+      { id: `${tool.id}-${idx}-${i}`, tool: tool.tool, status: 'done', response: toolCallResponse },
+      onChunk
+    )
+    return result
+  })
+
+  toolResults.push(...(await Promise.all(toolPromises)))
+  return toolResults
+}

diff --git a/src/renderer/src/utils/prompt.ts b/src/renderer/src/utils/prompt.ts
new file mode 100644
index 00000000..698fdb11
--- /dev/null
+++ b/src/renderer/src/utils/prompt.ts
@@ -0,0 +1,158 @@
+import { MCPTool } from '@renderer/types'
+
+export const SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \
+You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
+
+## Tool Use Formatting
+
+Tool use is formatted using XML-style tags. The tool name is enclosed in opening and closing tags, and each parameter is similarly enclosed within its own set of tags. Here's the structure:
+
+<tool_use>
+  <name>{tool_name}</name>
+  <arguments>{json_arguments}</arguments>
+</tool_use>
+
+The tool name should be the exact name of the tool you are using, and the arguments should be a JSON object containing the parameters required by that tool. For example:
+<tool_use>
+  <name>python_interpreter</name>
+  <arguments>{"code": "5 + 3 + 1294.678"}</arguments>
+</tool_use>
+
+The user will respond with the result of the tool use, which should be formatted as follows:
+
+<tool_use_result>
+  <name>{tool_name}</name>
+  <result>{result}</result>
+</tool_use_result>
+
+The result should be a string, which can represent a file or any other output type. You can use this result as input for the next action.
+For example, if the result of the tool use is an image file, you can use it in the next action like this:
+
+<tool_use>
+  <name>image_transformer</name>
+  <arguments>{"image": "image_1.jpg"}</arguments>
+</tool_use>
+
+Always adhere to this format for the tool use to ensure proper parsing and execution.
+
+## Tool Use Examples
+{{ TOOL_USE_EXAMPLES }}
+
+## Available Tools
+The examples above use notional tools that might not exist for you. You only have access to these tools:
+{{ AVAILABLE_TOOLS }}
+
+## Tool Use Rules
+Here are the rules you should always follow to solve your task:
+1. Always use the right arguments for the tools. Never use variable names as the action arguments, use the values instead.
+2. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself.
+3. If no tool call is needed, just answer the question directly.
+4. Never re-do a tool call that you previously did with the exact same parameters.
+5. For tool use, MAKE SURE to use the XML tag format shown in the examples above. Do not use any other format.
+
+# User Instructions
+{{ USER_SYSTEM_PROMPT }}
+
+Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.
+`
+
+export const ToolUseExamples = `
+Here are a few examples using notional tools:
+---
+User: Generate an image of the oldest person in this document.
+
+Assistant: I can use the document_qa tool to find out who the oldest person is in the document.
+<tool_use>
+  <name>document_qa</name>
+  <arguments>{"document": "document.pdf", "question": "Who is the oldest person mentioned?"}</arguments>
+</tool_use>
+
+User: <tool_use_result>
+  <name>document_qa</name>
+  <result>John Doe, a 55 year old lumberjack living in Newfoundland.</result>
+</tool_use_result>
+
+Assistant: I can use the image_generator tool to create a portrait of John Doe.
+<tool_use>
+  <name>image_generator</name>
+  <arguments>{"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}</arguments>
+</tool_use>
+
+User: <tool_use_result>
+  <name>image_generator</name>
+  <result>image.png</result>
+</tool_use_result>
+
+Assistant: the image is generated as image.png
+
+---
+User: "What is the result of the following operation: 5 + 3 + 1294.678?"
+
+Assistant: I can use the python_interpreter tool to calculate the result of the operation.
+<tool_use>
+  <name>python_interpreter</name>
+  <arguments>{"code": "5 + 3 + 1294.678"}</arguments>
+</tool_use>
+
+User: <tool_use_result>
+  <name>python_interpreter</name>
+  <result>1302.678</result>
+</tool_use_result>
+
+Assistant: The result of the operation is 1302.678.
+
+---
+User: "Which city has the highest population, Guangzhou or Shanghai?"
+
+Assistant: I can use the search tool to find the population of Guangzhou.
+<tool_use>
+  <name>search</name>
+  <arguments>{"query": "Population Guangzhou"}</arguments>
+</tool_use>
+
+User: <tool_use_result>
+  <name>search</name>
+  <result>Guangzhou has a population of 15 million inhabitants as of 2021.</result>
+</tool_use_result>
+
+Assistant: I can use the search tool to find the population of Shanghai.
+<tool_use>
+  <name>search</name>
+  <arguments>{"query": "Population Shanghai"}</arguments>
+</tool_use>
+
+User: <tool_use_result>
+  <name>search</name>
+  <result>26 million (2019)</result>
+</tool_use_result>
+
+Assistant: The population of Shanghai is 26 million, while Guangzhou has a population of 15 million. Therefore, Shanghai has the highest population.
+`
+
+export const AvailableTools = (tools: MCPTool[]) => {
+  const availableTools = tools
+    .map((tool) => {
+      return `
+<tool>
+  <name>${tool.id}</name>
+  <description>${tool.description}</description>
+  <arguments>
+    ${tool.inputSchema ? JSON.stringify(tool.inputSchema) : ''}
+  </arguments>
+</tool>
+`
+    })
+    .join('\n')
+  return `<tools>
+${availableTools}
+</tools>`
+}
+
+export const buildSystemPrompt = (userSystemPrompt: string, tools: MCPTool[]): string => {
+  if (tools && tools.length > 0) {
+    return SYSTEM_PROMPT.replace('{{ USER_SYSTEM_PROMPT }}', userSystemPrompt)
+      .replace('{{ TOOL_USE_EXAMPLES }}', ToolUseExamples)
+      .replace('{{ AVAILABLE_TOOLS }}', AvailableTools(tools))
+  }
+
+  return userSystemPrompt
+}
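Editor's note on the prompt-based tool flow (a standalone sketch, not part of the patch): the providers no longer send native tools arrays; instead buildSystemPrompt injects the tool catalog into the system prompt, the model answers with <tool_use> blocks, parseAndCallTools executes them, and the stringified result is sent back as the next user turn. The sketch below walks that round trip with an invented echo tool. SketchTool, modelReply, and the hard-coded response object are hypothetical stand-ins; the regex and the <tool_use_result> shape mirror parseToolUse and parseAndCallTools above.

    // sketch.ts -- minimal, self-contained TypeScript; run with ts-node
    interface SketchTool {
      id: string
      description: string
    }

    const tools: SketchTool[] = [{ id: 'echo', description: 'Returns its input unchanged' }]

    // 1. A model prompted with SYSTEM_PROMPT answers with an XML tool-use block.
    const modelReply = `
    I can use the echo tool.
    <tool_use>
      <name>echo</name>
      <arguments>{"text": "hello"}</arguments>
    </tool_use>
    `

    // 2. Extract tool calls the same way parseToolUse does.
    const pattern =
      /<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g

    let match: RegExpExecArray | null
    while ((match = pattern.exec(modelReply)) !== null) {
      const name = match[2].trim()
      const args = JSON.parse(match[4].trim()) // the real code falls back to the raw string on parse failure
      const tool = tools.find((t) => t.id === name)
      if (!tool) continue

      // 3. "Invoke" the tool (a stand-in for callMCPTool) and wrap the result
      //    the way parseAndCallTools does before it becomes the next user turn.
      const response = { content: [{ type: 'text', text: args.text }] }
      const result = `
    <tool_use_result>
      <name>${name}</name>
      <result>${JSON.stringify(response)}</result>
    </tool_use_result>
    `.trim()
      console.log(result)
    }

The appeal of this design is that any chat-capable model can drive MCP tools, with no dependence on a provider's function-calling API. The trade-off is that correctness now rests on the model emitting well-formed XML, and on callers taking care not to re-execute a block that was already parsed from earlier accumulated text.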