diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts
index 749750fc..9165eb85 100644
--- a/src/main/services/MCPService.ts
+++ b/src/main/services/MCPService.ts
@@ -1,6 +1,6 @@
+import fs from 'node:fs'
 import os from 'node:os'
 import path from 'node:path'
-import fs from 'node:fs'
 
 import { isLinux, isMac, isWin } from '@main/constant'
 import { createInMemoryMCPServer } from '@main/mcpServers/factory'
@@ -11,10 +11,19 @@
 import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'
 import { getDefaultEnvironment, StdioClientTransport } from '@modelcontextprotocol/sdk/client/stdio.js'
 import { InMemoryTransport } from '@modelcontextprotocol/sdk/inMemory'
 import { nanoid } from '@reduxjs/toolkit'
-import { GetMCPPromptResponse, GetResourceResponse, MCPPrompt, MCPResource, MCPServer, MCPTool } from '@types'
+import {
+  GetMCPPromptResponse,
+  GetResourceResponse,
+  MCPCallToolResponse,
+  MCPPrompt,
+  MCPResource,
+  MCPServer,
+  MCPTool
+} from '@types'
 import { app } from 'electron'
 import Logger from 'electron-log'
 import { memoize } from 'lodash'
+
 import { CacheService } from './CacheService'
 import { StreamableHTTPClientTransport, type StreamableHTTPClientTransportOptions } from './MCPStreamableHttpClient'
@@ -297,12 +306,12 @@ class McpService {
   public async callTool(
     _: Electron.IpcMainInvokeEvent,
     { server, name, args }: { server: MCPServer; name: string; args: any }
-  ): Promise<any> {
+  ): Promise<MCPCallToolResponse> {
    try {
       Logger.info('[MCP] Calling:', server.name, name, args)
       const client = await this.initClient(server)
       const result = await client.callTool({ name, arguments: args })
-      return result
+      return result as MCPCallToolResponse
     } catch (error) {
       Logger.error(`[MCP] Error calling tool ${name} on ${server.name}:`, error)
       throw error
diff --git a/src/preload/index.d.ts b/src/preload/index.d.ts
index 18bb3b00..163780e5 100644
--- a/src/preload/index.d.ts
+++ b/src/preload/index.d.ts
@@ -151,7 +151,15 @@ declare global {
       restartServer: (server: MCPServer) => Promise<void>
       stopServer: (server: MCPServer) => Promise<void>
       listTools: (server: MCPServer) => Promise<MCPTool[]>
-      callTool: ({ server, name, args }: { server: MCPServer; name: string; args: any }) => Promise<any>
+      callTool: ({
+        server,
+        name,
+        args
+      }: {
+        server: MCPServer
+        name: string
+        args: any
+      }) => Promise<MCPCallToolResponse>
       listPrompts: (server: MCPServer) => Promise<MCPPrompt[]>
       getPrompt: ({
         server,
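
For orientation before the provider changes: with the bridge typed as above, renderer-side callers get a structured result instead of any. A minimal sketch, assuming the preload API is exposed to the renderer as window.api.mcp (the exact namespace is not shown in this patch) and using invented server/args values:

  const resp: MCPCallToolResponse = await window.api.mcp.callTool({
    server,
    name: 'fetch_page',
    args: { url: 'https://example.com' }
  })
  if (resp.isError) {
    console.error('Tool failed:', resp.content)
  } else {
    const imageParts = resp.content.filter((part) => part.type === 'image')
  }
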
diff --git a/src/renderer/src/providers/AiProvider/AnthropicProvider.ts b/src/renderer/src/providers/AiProvider/AnthropicProvider.ts
index d9f72f71..0c936d69 100644
--- a/src/renderer/src/providers/AiProvider/AnthropicProvider.ts
+++ b/src/renderer/src/providers/AiProvider/AnthropicProvider.ts
@@ -1,7 +1,7 @@
 import Anthropic from '@anthropic-ai/sdk'
 import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources'
 import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
-import { isReasoningModel } from '@renderer/config/models'
+import { isReasoningModel, isVisionModel } from '@renderer/config/models'
 import { getStoreSetting } from '@renderer/hooks/useSettings'
 import i18n from '@renderer/i18n'
 import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
@@ -12,7 +12,7 @@ import {
 } from '@renderer/services/MessagesService'
 import { Assistant, FileTypes, MCPToolResponse, Message, Model, Provider, Suggestion } from '@renderer/types'
 import { removeSpecialCharactersForTopicName } from '@renderer/utils'
-import { parseAndCallTools } from '@renderer/utils/mcp-tools'
+import { mcpToolCallResponseToAnthropicMessage, parseAndCallTools } from '@renderer/utils/mcp-tools'
 import { buildSystemPrompt } from '@renderer/utils/prompt'
 import { first, flatten, sum, takeRight } from 'lodash'
 import OpenAI from 'openai'
@@ -290,17 +290,22 @@ export default class AnthropicProvider extends BaseProvider {
         .on('finalMessage', async (message) => {
           const content = message.content[0]
           if (content && content.type === 'text') {
-            const toolResults = await parseAndCallTools(content.text, toolResponses, onChunk, idx, mcpTools)
+            const toolResults = await parseAndCallTools(
+              content.text,
+              toolResponses,
+              onChunk,
+              idx,
+              mcpToolCallResponseToAnthropicMessage,
+              mcpTools,
+              isVisionModel(model)
+            )
             if (toolResults.length > 0) {
               userMessages.push({
                 role: message.role,
                 content: message.content
               })
-              userMessages.push({
-                role: 'user',
-                content: toolResults.join('\n')
-              })
+              toolResults.forEach((ts) => userMessages.push(ts as MessageParam))
               const newBody = body
               newBody.messages = userMessages
               await processStream(newBody, idx + 1)
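
For reference, the converter imported above replaces the old joined-string follow-up with a complete user turn. A sketch of the non-vision path, with invented values (the behavior is read off mcpToolCallResponseToAnthropicMessage, defined later in this patch):

  const msg = mcpToolCallResponseToAnthropicMessage('tool-1', {
    content: [{ type: 'text', text: 'ok' }]
  })
  // => {
  //      role: 'user',
  //      content: [
  //        { type: 'text', text: 'Here is the result of tool call tool-1:' },
  //        { type: 'text', text: '[{"type":"text","text":"ok"}]' }
  //      ]
  //    }
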
diff --git a/src/renderer/src/providers/AiProvider/GeminiProvider.ts b/src/renderer/src/providers/AiProvider/GeminiProvider.ts
index be40686f..f337f802 100644
--- a/src/renderer/src/providers/AiProvider/GeminiProvider.ts
+++ b/src/renderer/src/providers/AiProvider/GeminiProvider.ts
@@ -19,7 +19,7 @@ import {
   TextPart,
   Tool
 } from '@google/generative-ai'
-import { isGemmaModel, isWebSearchModel } from '@renderer/config/models'
+import { isGemmaModel, isVisionModel, isWebSearchModel } from '@renderer/config/models'
 import { getStoreSetting } from '@renderer/hooks/useSettings'
 import i18n from '@renderer/i18n'
 import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
@@ -32,11 +32,11 @@ import {
 import WebSearchService from '@renderer/services/WebSearchService'
 import { Assistant, FileType, FileTypes, MCPToolResponse, Message, Model, Provider, Suggestion } from '@renderer/types'
 import { removeSpecialCharactersForTopicName } from '@renderer/utils'
-import { parseAndCallTools } from '@renderer/utils/mcp-tools'
+import { mcpToolCallResponseToGeminiMessage, parseAndCallTools } from '@renderer/utils/mcp-tools'
 import { buildSystemPrompt } from '@renderer/utils/prompt'
 import { MB } from '@shared/config/constant'
 import axios from 'axios'
-import { isEmpty, takeRight } from 'lodash'
+import { flatten, isEmpty, takeRight } from 'lodash'
 import OpenAI from 'openai'
 
 import { ChunkCallbackData, CompletionsParams } from '.'
@@ -310,18 +310,21 @@ export default class GeminiProvider extends BaseProvider {
     let time_first_token_millsec = 0
 
     const processToolUses = async (content: string, idx: number) => {
-      const toolResults = await parseAndCallTools(content, toolResponses, onChunk, idx, mcpTools)
+      const toolResults = await parseAndCallTools(
+        content,
+        toolResponses,
+        onChunk,
+        idx,
+        mcpToolCallResponseToGeminiMessage,
+        mcpTools,
+        isVisionModel(model)
+      )
       if (toolResults && toolResults.length > 0) {
         history.push(messageContents)
         const newChat = geminiModel.startChat({ history })
-        const newStream = await newChat.sendMessageStream(
-          [
-            {
-              text: toolResults.join('\n')
-            }
-          ],
-          { signal }
-        )
+        const newStream = await newChat.sendMessageStream(flatten(toolResults.map((ts) => (ts as Content).parts)), {
+          signal
+        })
         await processStream(newStream, idx + 1)
       }
     }
diff --git a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts b/src/renderer/src/providers/AiProvider/OpenAIProvider.ts
index 96469470..b78bf2c1 100644
--- a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts
+++ b/src/renderer/src/providers/AiProvider/OpenAIProvider.ts
@@ -32,7 +32,7 @@ import {
 } from '@renderer/types'
 import { removeSpecialCharactersForTopicName } from '@renderer/utils'
 import { addImageFileToContents } from '@renderer/utils/formats'
-import { parseAndCallTools } from '@renderer/utils/mcp-tools'
+import { mcpToolCallResponseToOpenAIMessage, parseAndCallTools } from '@renderer/utils/mcp-tools'
 import { buildSystemPrompt } from '@renderer/utils/prompt'
 import { isEmpty, takeRight } from 'lodash'
 import OpenAI, { AzureOpenAI } from 'openai'
@@ -390,17 +390,22 @@ export default class OpenAIProvider extends BaseProvider {
     let firstChunk = true
 
     const processToolUses = async (content: string, idx: number) => {
-      const toolResults = await parseAndCallTools(content, toolResponses, onChunk, idx, mcpTools)
+      const toolResults = await parseAndCallTools(
+        content,
+        toolResponses,
+        onChunk,
+        idx,
+        mcpToolCallResponseToOpenAIMessage,
+        mcpTools,
+        isVisionModel(model)
+      )
       if (toolResults.length > 0) {
         reqMessages.push({
           role: 'assistant',
           content: content
         } as ChatCompletionMessageParam)
-        reqMessages.push({
-          role: 'user',
-          content: toolResults.join('\n')
-        } as ChatCompletionMessageParam)
+        toolResults.forEach((ts) => reqMessages.push(ts as ChatCompletionMessageParam))
         const newStream = await this.sdk.chat.completions
           // @ts-ignore key is not typed
diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts
index a3800506..0d7489c0 100644
--- a/src/renderer/src/types/index.ts
+++ b/src/renderer/src/types/index.ts
@@ -444,6 +444,23 @@ export interface MCPToolResponse {
   response?: any
 }
 
+export interface MCPToolResultContent {
+  type: 'text' | 'image' | 'audio' | 'resource'
+  text?: string
+  data?: string
+  mimeType?: string
+  resource?: {
+    uri?: string
+    text?: string
+    mimeType?: string
+  }
+}
+
+export interface MCPCallToolResponse {
+  content: MCPToolResultContent[]
+  isError?: boolean
+}
+
 export interface MCPResource {
   serverId: string
   serverName: string
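
To make the new shape concrete: an illustrative MCPCallToolResponse, roughly what an MCP server might return for a screenshot-style tool (values invented, base64 payload elided):

  const resp: MCPCallToolResponse = {
    isError: false,
    content: [
      { type: 'text', text: 'Captured the page.' },
      { type: 'image', data: 'iVBORw0KGgo...', mimeType: 'image/png' }
    ]
  }
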
diff --git a/src/renderer/src/utils/mcp-tools.ts b/src/renderer/src/utils/mcp-tools.ts
index a5a6caec..941cc202 100644
--- a/src/renderer/src/utils/mcp-tools.ts
+++ b/src/renderer/src/utils/mcp-tools.ts
@@ -1,4 +1,5 @@
-import { Tool, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
+import { ContentBlockParam, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
+import { MessageParam } from '@anthropic-ai/sdk/resources'
 import {
   ArraySchema,
   BaseSchema,
@@ -15,11 +16,15 @@ import {
   SimpleStringSchema,
   Tool as geminiTool
 } from '@google/generative-ai'
-import { nanoid } from '@reduxjs/toolkit'
+import { Content, Part } from '@google/generative-ai'
 import store from '@renderer/store'
-import { addMCPServer } from '@renderer/store/mcp'
-import { MCPServer, MCPTool, MCPToolResponse } from '@renderer/types'
-import { ChatCompletionMessageToolCall, ChatCompletionTool } from 'openai/resources'
+import { MCPCallToolResponse, MCPServer, MCPTool, MCPToolResponse } from '@renderer/types'
+import {
+  ChatCompletionContentPart,
+  ChatCompletionMessageParam,
+  ChatCompletionMessageToolCall,
+  ChatCompletionTool
+} from 'openai/resources'
 
 import { ChunkCallbackData, CompletionsParams } from '../providers/AiProvider'
@@ -218,7 +223,7 @@ export function openAIToolsToMcpTool(
   }
 }
 
-export async function callMCPTool(tool: MCPTool): Promise<any> {
+export async function callMCPTool(tool: MCPTool): Promise<MCPCallToolResponse> {
   console.log(`[MCP] Calling Tool: ${tool.serverName} ${tool.name}`, tool)
   try {
     const server = getMcpServerByTool(tool)
@@ -234,24 +239,6 @@
     })
 
     console.log(`[MCP] Tool called: ${tool.serverName} ${tool.name}`, resp)
-
-    if (tool.serverName === '@cherry/mcp-auto-install') {
-      if (resp.data) {
-        const mcpServer: MCPServer = {
-          id: `f${nanoid()}`,
-          name: resp.data.name,
-          description: resp.data.description,
-          baseUrl: resp.data.baseUrl,
-          command: resp.data.command,
-          args: resp.data.args,
-          env: resp.data.env,
-          registryUrl: '',
-          isActive: false
-        }
-        store.dispatch(addMCPServer(mcpServer))
-      }
-    }
-
     return resp
   } catch (e) {
     console.error(`[MCP] Error calling Tool: ${tool.serverName} ${tool.name}`, e)
@@ -269,7 +256,7 @@
 
 export function mcpToolsToAnthropicTools(mcpTools: MCPTool[]): Array<ToolUnion> {
   return mcpTools.map((tool) => {
-    const t: Tool = {
+    const t: ToolUnion = {
       name: tool.id,
       description: tool.description,
       // @ts-ignore no check
@@ -427,9 +414,15 @@ export async function parseAndCallTools(
   content: string,
   toolResponses: MCPToolResponse[],
   onChunk: CompletionsParams['onChunk'],
   idx: number,
-  mcpTools?: MCPTool[]
-): Promise<string[]> {
-  const toolResults: string[] = []
+  convertToMessage: (
+    toolCallId: string,
+    resp: MCPCallToolResponse,
+    isVisionModel: boolean
+  ) => ChatCompletionMessageParam | MessageParam | Content,
+  mcpTools?: MCPTool[],
+  isVisionModel: boolean = false
+): Promise<(ChatCompletionMessageParam | MessageParam | Content)[]> {
+  const toolResults: (ChatCompletionMessageParam | MessageParam | Content)[] = []
   // process tool use
   const tools = parseToolUse(content, mcpTools || [])
   if (!tools || tools.length === 0) {
@@ -440,22 +433,228 @@
     upsertMCPToolResponse(toolResponses, { id: `${tool.id}-${idx}-${i}`, tool: tool.tool, status: 'invoking' }, onChunk)
   }
 
+  const images: string[] = []
   const toolPromises = tools.map(async (tool, i) => {
     const toolCallResponse = await callMCPTool(tool.tool)
-    const result = `
-
-  ${tool.id}
-  ${JSON.stringify(toolCallResponse)}
-
-    `.trim()
     upsertMCPToolResponse(
       toolResponses,
       { id: `${tool.id}-${idx}-${i}`, tool: tool.tool, status: 'done', response: toolCallResponse },
       onChunk
     )
-    return result
+
+    for (const content of toolCallResponse.content) {
+      if (content.type === 'image' && content.data) {
+        images.push(`data:${content.mimeType};base64,${content.data}`)
+      }
+    }
+
+    onChunk({
+      text: '\n',
+      generateImage: {
+        type: 'base64',
+        images: images
+      }
+    })
+
+    return convertToMessage(tool.tool.id, toolCallResponse, isVisionModel)
   })
 
   toolResults.push(...(await Promise.all(toolPromises)))
 
   return toolResults
 }
+
+export function mcpToolCallResponseToOpenAIMessage(
+  toolCallId: string,
+  resp: MCPCallToolResponse,
+  isVisionModel: boolean = false
+): ChatCompletionMessageParam {
+  const message = {
+    role: 'user'
+  } as ChatCompletionMessageParam
+
+  if (resp.isError) {
+    message.content = JSON.stringify(resp.content)
+  } else {
+    const content: ChatCompletionContentPart[] = [
+      {
+        type: 'text',
+        text: `Here is the result of tool call ${toolCallId}:`
+      }
+    ]
+
+    if (isVisionModel) {
+      for (const item of resp.content) {
+        switch (item.type) {
+          case 'text':
+            content.push({
+              type: 'text',
+              text: item.text || 'no content'
+            })
+            break
+          case 'image':
+            content.push({
+              type: 'image_url',
+              image_url: {
+                url: `data:${item.mimeType};base64,${item.data}`,
+                detail: 'auto'
+              }
+            })
+            break
+          case 'audio':
+            content.push({
+              type: 'input_audio',
+              input_audio: {
+                // input_audio expects raw base64, not a data: URL
+                data: item.data || '',
+                format: 'mp3'
+              }
+            })
+            break
+          default:
+            content.push({
+              type: 'text',
+              text: `Unsupported type: ${item.type}`
+            })
+            break
+        }
+      }
+    } else {
+      content.push({
+        type: 'text',
+        text: JSON.stringify(resp.content)
+      })
+    }
+
+    message.content = content
+  }
+
+  return message
+}
+
+export function mcpToolCallResponseToAnthropicMessage(
+  toolCallId: string,
+  resp: MCPCallToolResponse,
+  isVisionModel: boolean = false
+): MessageParam {
+  const message = {
+    role: 'user'
+  } as MessageParam
+  if (resp.isError) {
+    message.content = JSON.stringify(resp.content)
+  } else {
+    const content: ContentBlockParam[] = [
+      {
+        type: 'text',
+        text: `Here is the result of tool call ${toolCallId}:`
+      }
+    ]
+    if (isVisionModel) {
+      for (const item of resp.content) {
+        switch (item.type) {
+          case 'text':
+            content.push({
+              type: 'text',
+              text: item.text || 'no content'
+            })
+            break
+          case 'image':
+            if (
+              item.mimeType === 'image/png' ||
+              item.mimeType === 'image/jpeg' ||
+              item.mimeType === 'image/webp' ||
+              item.mimeType === 'image/gif'
+            ) {
+              content.push({
+                type: 'image',
+                source: {
+                  type: 'base64',
+                  // base64 source expects raw base64, not a data: URL
+                  data: item.data || '',
+                  media_type: item.mimeType
+                }
+              })
+            } else {
+              content.push({
+                type: 'text',
+                text: `Unsupported image type: ${item.mimeType}`
+              })
+            }
+            break
+          default:
+            content.push({
+              type: 'text',
+              text: `Unsupported type: ${item.type}`
+            })
+            break
+        }
+      }
+    } else {
+      content.push({
+        type: 'text',
+        text: JSON.stringify(resp.content)
+      })
+    }
+    message.content = content
+  }
+
+  return message
+}
+
+export function mcpToolCallResponseToGeminiMessage(
+  toolCallId: string,
+  resp: MCPCallToolResponse,
+  isVisionModel: boolean = false
+): Content {
+  const message = {
+    role: 'user'
+  } as Content
+
+  if (resp.isError) {
+    message.parts = [
+      {
+        text: JSON.stringify(resp.content)
+      }
+    ]
+  } else {
+    const parts: Part[] = [
+      {
+        text: `Here is the result of tool call ${toolCallId}:`
+      }
+    ]
+    if (isVisionModel) {
+      for (const item of resp.content) {
+        switch (item.type) {
+          case 'text':
+            parts.push({
+              text: item.text || 'no content'
+            })
+            break
+          case 'image':
+            if (!item.data) {
+              parts.push({
+                text: 'No image data provided'
+              })
+            } else {
+              parts.push({
+                inlineData: {
+                  data: item.data,
+                  mimeType: item.mimeType || 'image/png'
+                }
+              })
+            }
+            break
+          default:
+            parts.push({
+              text: `Unsupported type: ${item.type}`
+            })
+            break
+        }
+      }
+    } else {
+      parts.push({
+        text: JSON.stringify(resp.content)
+      })
+    }
+    message.parts = parts
+  }
+
+  return message
+}
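
Taken together: each provider now hands parseAndCallTools its own converter plus a vision flag, and pushes the typed messages straight back into its request. A condensed restatement of the OpenAI wiring from this patch (content, toolResponses, onChunk, idx, mcpTools, model and reqMessages come from the surrounding completion context):

  const toolResults = await parseAndCallTools(
    content,
    toolResponses,
    onChunk,
    idx,
    mcpToolCallResponseToOpenAIMessage,
    mcpTools,
    isVisionModel(model)
  )
  toolResults.forEach((ts) => reqMessages.push(ts as ChatCompletionMessageParam))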