diff --git a/src/renderer/src/pages/home/Messages/MessageTools.tsx b/src/renderer/src/pages/home/Messages/MessageTools.tsx index d5aad2e3..73d73cfc 100644 --- a/src/renderer/src/pages/home/Messages/MessageTools.tsx +++ b/src/renderer/src/pages/home/Messages/MessageTools.tsx @@ -43,20 +43,18 @@ const MessageTools: FC = ({ message }) => { // Format tool responses for collapse items const getCollapseItems = () => { const items: { key: string; label: JSX.Element; children: React.ReactNode }[] = [] - // Add tool responses toolResponses.forEach((toolResponse: MCPToolResponse) => { - const { tool, status } = toolResponse - const toolId = tool.id + const { id, tool, status, response } = toolResponse const isInvoking = status === 'invoking' const isDone = status === 'done' - const response = { + const result = { params: tool.inputSchema, response: toolResponse.response } items.push({ - key: toolId, + key: id, label: ( @@ -89,11 +87,11 @@ const MessageTools: FC = ({ message }) => { className="message-action-button" onClick={(e) => { e.stopPropagation() - copyContent(JSON.stringify(response, null, 2), toolId) + copyContent(JSON.stringify(result, null, 2), id) }} aria-label={t('common.copy')}> - {!copiedMap[toolId] && } - {copiedMap[toolId] && } + {!copiedMap[id] && } + {copiedMap[id] && } @@ -101,9 +99,9 @@ const MessageTools: FC = ({ message }) => { ), - children: isDone && response && ( + children: isDone && result && ( -
{JSON.stringify(response, null, 2)}
+
{JSON.stringify(result, null, 2)}
) }) diff --git a/src/renderer/src/providers/AnthropicProvider.ts b/src/renderer/src/providers/AnthropicProvider.ts index b2e84761..f2ffc64c 100644 --- a/src/renderer/src/providers/AnthropicProvider.ts +++ b/src/renderer/src/providers/AnthropicProvider.ts @@ -208,7 +208,7 @@ export default class AnthropicProvider extends BaseProvider { const { signal } = abortController const toolResponses: MCPToolResponse[] = [] - const processStream = (body: MessageCreateParamsNonStreaming) => { + const processStream = (body: MessageCreateParamsNonStreaming, idx: number) => { return new Promise((resolve, reject) => { const toolCalls: ToolUseBlock[] = [] let hasThinkingContent = false @@ -274,10 +274,14 @@ export default class AnthropicProvider extends BaseProvider { for (const toolCall of toolCalls) { const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall) if (mcpTool) { - upsertMCPToolResponse(toolResponses, { tool: mcpTool, status: 'invoking' }, onChunk) + upsertMCPToolResponse(toolResponses, { tool: mcpTool, status: 'invoking', id: toolCall.id }, onChunk) const resp = await callMCPTool(mcpTool) toolCallResults.push({ type: 'tool_result', tool_use_id: toolCall.id, content: resp.content }) - upsertMCPToolResponse(toolResponses, { tool: mcpTool, status: 'done', response: resp }, onChunk) + upsertMCPToolResponse( + toolResponses, + { tool: mcpTool, status: 'done', response: resp, id: toolCall.id }, + onChunk + ) } } @@ -295,7 +299,7 @@ export default class AnthropicProvider extends BaseProvider { const newBody = body body.messages = userMessages - await processStream(newBody) + await processStream(newBody, idx + 1) } } @@ -326,7 +330,7 @@ export default class AnthropicProvider extends BaseProvider { }) } - await processStream(body) + await processStream(body, 0) .catch((error) => { // 不加这个错误抛不出来 throw error diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index fd90a4f5..9c3ec59e 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -221,7 +221,7 @@ export default class GeminiProvider extends BaseProvider { const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal }) let time_first_token_millsec = 0 - const processStream = async (stream: GenerateContentStreamResult) => { + const processStream = async (stream: GenerateContentStreamResult, idx: number) => { for await (const chunk of stream.stream) { if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break if (time_first_token_millsec == 0) { @@ -242,7 +242,8 @@ export default class GeminiProvider extends BaseProvider { toolResponses, { tool: mcpTool, - status: 'invoking' + status: 'invoking', + id: `${call.name}-${idx}` }, onChunk ) @@ -258,7 +259,8 @@ export default class GeminiProvider extends BaseProvider { { tool: mcpTool, status: 'done', - response: toolCallResponse + response: toolCallResponse, + id: `${call.name}-${idx}` }, onChunk ) @@ -272,7 +274,7 @@ export default class GeminiProvider extends BaseProvider { }) const newChat = geminiModel.startChat({ history }) const newStream = await newChat.sendMessageStream(fcRespParts, { signal }) - await processStream(newStream).finally(cleanup) + await processStream(newStream, idx + 1).finally(cleanup) } } @@ -293,7 +295,7 @@ export default class GeminiProvider extends BaseProvider { }) } } - await processStream(userMessagesStream).finally(cleanup) + await processStream(userMessagesStream, 0).finally(cleanup) } async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) { diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts index 3fd6e4a8..5ce2f28b 100644 --- a/src/renderer/src/providers/OpenAIProvider.ts +++ b/src/renderer/src/providers/OpenAIProvider.ts @@ -321,7 +321,7 @@ export default class OpenAIProvider extends BaseProvider { const toolResponses: MCPToolResponse[] = [] - const processStream = async (stream: any) => { + const processStream = async (stream: any, idx: number) => { if (!isSupportStreamOutput()) { const time_completion_millsec = new Date().getTime() - start_time_millsec return onChunk({ @@ -365,30 +365,29 @@ export default class OpenAIProvider extends BaseProvider { if (delta?.tool_calls) { const chunkToolCalls = delta.tool_calls - if (finishReason !== 'tool_calls') { - for (const t of chunkToolCalls) { - const { index, id, function: fn, type } = t - const args = fn && typeof fn.arguments === 'string' ? fn.arguments : '' - if (!(index in final_tool_calls)) { - final_tool_calls[index] = { - id, - function: { - name: fn?.name, - arguments: args - }, - type - } as ChatCompletionMessageToolCall - } else { - final_tool_calls[index].function.arguments += args - } + for (const t of chunkToolCalls) { + const { index, id, function: fn, type } = t + const args = fn && typeof fn.arguments === 'string' ? fn.arguments : '' + if (!(index in final_tool_calls)) { + final_tool_calls[index] = { + id, + function: { + name: fn?.name, + arguments: args + }, + type + } as ChatCompletionMessageToolCall + } else { + final_tool_calls[index].function.arguments += args } + } + if (finishReason !== 'tool_calls') { continue } } if (finishReason === 'tool_calls') { const toolCalls = Object.values(final_tool_calls) - console.log('start invoke tools', toolCalls) reqMessages.push({ role: 'assistant', tool_calls: toolCalls @@ -400,12 +399,10 @@ export default class OpenAIProvider extends BaseProvider { continue } - upsertMCPToolResponse(toolResponses, { tool: mcpTool, status: 'invoking' }, onChunk) + upsertMCPToolResponse(toolResponses, { tool: mcpTool, status: 'invoking', id: toolCall.id }, onChunk) const toolCallResponse = await callMCPTool(mcpTool) - console.log('[OpenAIProvider] toolCallResponse', toolCallResponse) - reqMessages.push({ role: 'tool', content: isString(toolCallResponse.content) @@ -414,9 +411,12 @@ export default class OpenAIProvider extends BaseProvider { tool_call_id: toolCall.id } as ChatCompletionToolMessageParam) - upsertMCPToolResponse(toolResponses, { tool: mcpTool, status: 'done', response: toolCallResponse }, onChunk) + upsertMCPToolResponse( + toolResponses, + { tool: mcpTool, status: 'done', response: toolCallResponse, id: toolCall.id }, + onChunk + ) } - const newStream = await this.sdk.chat.completions // @ts-ignore key is not typed .create( @@ -438,7 +438,7 @@ export default class OpenAIProvider extends BaseProvider { signal } ) - await processStream(newStream) + await processStream(newStream, idx + 1) } onChunk({ @@ -479,7 +479,7 @@ export default class OpenAIProvider extends BaseProvider { } ) - await processStream(stream).finally(cleanup) + await processStream(stream, 0).finally(cleanup) } async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) { diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index 458eaade..f8a52896 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -359,7 +359,8 @@ export interface MCPConfig { } export interface MCPToolResponse { - tool: MCPTool - status: string + id: string // tool call id, it should be unique + tool: MCPTool // tool info + status: string // 'invoking' | 'done' response?: any } diff --git a/src/renderer/src/utils/mcp-tools.ts b/src/renderer/src/utils/mcp-tools.ts index d5e50652..810bef0a 100644 --- a/src/renderer/src/utils/mcp-tools.ts +++ b/src/renderer/src/utils/mcp-tools.ts @@ -52,16 +52,38 @@ export function openAIToolsToMcpTool( if (!tool) { return undefined } - tool.inputSchema = JSON.parse(llmTool.function.arguments) - return tool + console.log( + `[MCP] OpenAI Tool to MCP Tool: ${tool.serverName} ${tool.name}`, + tool, + 'args', + llmTool.function.arguments + ) + // use this to parse the arguments and avoid parsing errors + let args: any = {} + try { + args = JSON.parse(llmTool.function.arguments) + } catch (e) { + console.error('Error parsing arguments', e) + } + + return { + id: tool.id, + serverName: tool.serverName, + name: tool.name, + description: tool.description, + inputSchema: args + } } export async function callMCPTool(tool: MCPTool): Promise { - return await window.api.mcp.callTool({ + console.log(`[MCP] Calling Tool: ${tool.serverName} ${tool.name}`, tool) + const resp = await window.api.mcp.callTool({ client: tool.serverName, name: tool.name, args: tool.inputSchema }) + console.log(`[MCP] Tool called: ${tool.serverName} ${tool.name}`, resp) + return resp } export function mcpToolsToAnthropicTools(mcpTools: MCPTool[]): Array { @@ -133,7 +155,7 @@ export function upsertMCPToolResponse( ) { try { for (const ret of results) { - if (ret.tool.id == resp.tool.id) { + if (ret.id === resp.id) { ret.response = resp.response ret.status = resp.status return