diff --git a/src/main/services/mcp.ts b/src/main/services/mcp.ts
index 09cdb75a..f362ba10 100644
--- a/src/main/services/mcp.ts
+++ b/src/main/services/mcp.ts
@@ -270,7 +270,7 @@ export default class MCPService extends EventEmitter {
       const { tools } = await this.clients[serverName].listTools()
       return tools.map((tool: any) => {
         tool.serverName = serverName
-        tool.id = uuidv4()
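+        // NOTE: the id doubles as the function name offered to the LLMs, and a
+        // raw UUID may start with a digit, which Gemini for one rejects in
+        // function names. Prefixing 'f' and stripping dashes (presumably) keeps
+        // the id a valid function name for every provider.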
+        tool.id = 'f' + uuidv4().replace(/-/g, '')
         return tool
       })
     } else {
@@ -282,7 +282,7 @@ export default class MCPService extends EventEmitter {
       allTools = allTools.concat(
         tools.map((tool: MCPTool) => {
           tool.serverName = clientName
-          tool.id = uuidv4()
+          tool.id = 'f' + uuidv4().replace(/-/g, '')
           return tool
         })
       )
diff --git a/src/renderer/src/providers/AnthropicProvider.ts b/src/renderer/src/providers/AnthropicProvider.ts
index f2962c6d..13f91c45 100644
--- a/src/renderer/src/providers/AnthropicProvider.ts
+++ b/src/renderer/src/providers/AnthropicProvider.ts
@@ -1,5 +1,10 @@
 import Anthropic from '@anthropic-ai/sdk'
-import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources'
+import {
+  MessageCreateParamsNonStreaming,
+  MessageParam,
+  ToolResultBlockParam,
+  ToolUseBlock
+} from '@anthropic-ai/sdk/resources'
 import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
 import { isReasoningModel } from '@renderer/config/models'
 import { getStoreSetting } from '@renderer/hooks/useSettings'
@@ -14,6 +19,7 @@ import OpenAI from 'openai'
 
 import { CompletionsParams } from '.'
 import BaseProvider from './BaseProvider'
+import { anthropicToolUseToMcpTool, callMCPTool, mcpToolsToAnthropicTools } from './mcpToolUtils'
 
 type ReasoningEffort = 'high' | 'medium' | 'low'
 
@@ -118,7 +124,7 @@ export default class AnthropicProvider extends BaseProvider {
     }
   }
 
-  public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) {
+  public async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams) {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
     const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
@@ -133,10 +139,12 @@
     }
 
     const userMessages = flatten(userMessagesParams)
+    const tools = mcpTools ? mcpToolsToAnthropicTools(mcpTools) : undefined
 
     const body: MessageCreateParamsNonStreaming = {
       model: model.id,
       messages: userMessages,
+      tools: tools,
       max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
       temperature: this.getTemperature(assistant, model),
       top_p: this.getTopP(assistant, model),
@@ -186,74 +194,118 @@ export default class AnthropicProvider extends BaseProvider {
     const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
     const { signal } = abortController
 
-    return new Promise((resolve, reject) => {
-      let hasThinkingContent = false
-      const stream = this.sdk.messages
-        .stream({ ...body, stream: true }, { signal })
-        .on('text', (text) => {
-          if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
-            stream.controller.abort()
-            return resolve()
-          }
-          if (time_first_token_millsec == 0) {
-            time_first_token_millsec = new Date().getTime() - start_time_millsec
-          }
 
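+    // processStream re-issues the request whenever the model asks for MCP tool
+    // calls, so a single completions() call may stream several model passes.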
+    const processStream = async (body: MessageCreateParamsNonStreaming) => {
+      return new Promise((resolve, reject) => {
+        const toolCalls: ToolUseBlock[] = []
+        let hasThinkingContent = false
+        const stream = this.sdk.messages
+          .stream({ ...body, stream: true }, { signal })
+          .on('text', (text) => {
+            if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
+              stream.controller.abort()
+              return resolve()
+            }
+            if (time_first_token_millsec == 0) {
+              time_first_token_millsec = new Date().getTime() - start_time_millsec
+            }
 
-          if (hasThinkingContent && time_first_content_millsec === 0) {
-            time_first_content_millsec = new Date().getTime()
-          }
+            if (hasThinkingContent && time_first_content_millsec === 0) {
+              time_first_content_millsec = new Date().getTime()
+            }
 
-          const time_thinking_millsec = time_first_content_millsec ? time_first_content_millsec - start_time_millsec : 0
+            const time_thinking_millsec = time_first_content_millsec
+              ? time_first_content_millsec - start_time_millsec
+              : 0
 
-          const time_completion_millsec = new Date().getTime() - start_time_millsec
-          onChunk({
-            text,
-            metrics: {
-              completion_tokens: undefined,
-              time_completion_millsec,
-              time_first_token_millsec,
-              time_thinking_millsec
+            const time_completion_millsec = new Date().getTime() - start_time_millsec
+            onChunk({
+              text,
+              metrics: {
+                completion_tokens: undefined,
+                time_completion_millsec,
+                time_first_token_millsec,
+                time_thinking_millsec
+              }
+            })
+          })
+          .on('thinking', (thinking) => {
+            hasThinkingContent = true
+            if (time_first_token_millsec == 0) {
+              time_first_token_millsec = new Date().getTime() - start_time_millsec
+            }
+
+            const time_completion_millsec = new Date().getTime() - start_time_millsec
+            onChunk({
+              reasoning_content: thinking,
+              text: '',
+              metrics: {
+                completion_tokens: undefined,
+                time_completion_millsec,
+                time_first_token_millsec
+              }
+            })
+          })
+          .on('contentBlock', (content) => {
+            if (content.type == 'tool_use') {
+              toolCalls.push(content)
             }
           })
-        })
-        .on('thinking', (thinking) => {
-          hasThinkingContent = true
-          if (time_first_token_millsec == 0) {
-            time_first_token_millsec = new Date().getTime() - start_time_millsec
-          }
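+          // Once the final message arrives, run every queued tool_use block
+          // through MCP and send the results back as tool_result blocks in a
+          // follow-up user turn, then stream the model's next pass.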
+          .on('finalMessage', async (message) => {
+            if (toolCalls.length > 0) {
+              const toolCallResults: ToolResultBlockParam[] = []
+              for (const toolCall of toolCalls) {
+                const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
+                if (mcpTool) {
+                  const resp = await callMCPTool(mcpTool)
+                  toolCallResults.push({
+                    type: 'tool_result',
+                    tool_use_id: toolCall.id,
+                    content: resp.content
+                  })
+                }
+              }
 
-          const time_completion_millsec = new Date().getTime() - start_time_millsec
-          onChunk({
-            reasoning_content: thinking,
-            text: '',
-            metrics: {
-              completion_tokens: undefined,
-              time_completion_millsec,
-              time_first_token_millsec
+              if (toolCallResults.length > 0) {
+                userMessages.push({
+                  role: message.role,
+                  content: message.content
+                })
+                userMessages.push({
+                  role: 'user',
+                  content: toolCallResults
+                })
+
+                const newBody = body
+                body.messages = userMessages
+                await processStream(newBody)
+              }
             }
+
+            const time_completion_millsec = new Date().getTime() - start_time_millsec
+            const time_thinking_millsec = time_first_content_millsec
+              ? time_first_content_millsec - start_time_millsec
+              : 0
+            onChunk({
+              text: '',
+              usage: {
+                prompt_tokens: message.usage.input_tokens,
+                completion_tokens: message.usage.output_tokens,
+                total_tokens: sum(Object.values(message.usage))
+              },
+              metrics: {
+                completion_tokens: message.usage.output_tokens,
+                time_completion_millsec,
+                time_first_token_millsec,
+                time_thinking_millsec
+              }
+            })
+            resolve()
           })
-        })
-        .on('finalMessage', (message) => {
-          const time_completion_millsec = new Date().getTime() - start_time_millsec
-          const time_thinking_millsec = time_first_content_millsec ? time_first_content_millsec - start_time_millsec : 0
-          onChunk({
-            text: '',
-            usage: {
-              prompt_tokens: message.usage.input_tokens,
-              completion_tokens: message.usage.output_tokens,
-              total_tokens: sum(Object.values(message.usage))
-            },
-            metrics: {
-              completion_tokens: message.usage.output_tokens,
-              time_completion_millsec,
-              time_first_token_millsec,
-              time_thinking_millsec
-            }
-          })
-          resolve()
-        })
-        .on('error', (error) => reject(error))
-    }).finally(cleanup)
+          .on('error', (error) => reject(error))
+      }).finally(cleanup)
+    }
+
+    await processStream(body)
   }
 
   public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts
index be3a0304..84dddc3d 100644
--- a/src/renderer/src/providers/GeminiProvider.ts
+++ b/src/renderer/src/providers/GeminiProvider.ts
@@ -1,6 +1,9 @@
 import {
   Content,
   FileDataPart,
+  FunctionCallPart,
+  FunctionResponsePart,
+  GenerateContentStreamResult,
   GoogleGenerativeAI,
   HarmBlockThreshold,
   HarmCategory,
@@ -24,6 +27,8 @@ import OpenAI from 'openai'
 
 import { CompletionsParams } from '.'
 import BaseProvider from './BaseProvider'
+import { callMCPTool, geminiFunctionCallToMcpTool, mcpToolsToGeminiTools } from './mcpToolUtils'
+
 export default class GeminiProvider extends BaseProvider {
   private sdk: GoogleGenerativeAI
   private requestOptions: RequestOptions
@@ -141,7 +146,7 @@ export default class GeminiProvider extends BaseProvider {
     ]
   }
 
-  public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) {
+  public async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams) {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
     const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
@@ -157,12 +162,19 @@ export default class GeminiProvider extends BaseProvider {
       history.push(await this.getMessageContents(message))
     }
 
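+    // MCP tools and the built-in googleSearch tool now share one tools array;
+    // googleSearch is still missing from the SDK typings, hence the @ts-ignore.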
+    const tools = mcpToolsToGeminiTools(mcpTools)
+    if (assistant.enableWebSearch && isWebSearchModel(model)) {
+      tools.push({
+        // @ts-ignore googleSearch is not a valid tool for Gemini
+        googleSearch: {}
+      })
+    }
+
     const geminiModel = this.sdk.getGenerativeModel(
       {
         model: model.id,
         systemInstruction: assistant.prompt,
-        // @ts-ignore googleSearch is not a valid tool for Gemini
-        tools: assistant.enableWebSearch && isWebSearchModel(model) ? [{ googleSearch: {} }] : undefined,
+        tools: tools.length > 0 ? tools : undefined,
         safetySettings: this.getSafetySettings(model.id),
         generationConfig: {
           maxOutputTokens: maxTokens,
@@ -206,27 +218,62 @@ export default class GeminiProvider extends BaseProvider {
     const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal }).finally(cleanup)
 
     let time_first_token_millsec = 0
 
-    for await (const chunk of userMessagesStream.stream) {
-      if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
-      if (time_first_token_millsec == 0) {
-        time_first_token_millsec = new Date().getTime() - start_time_millsec
+    const processStream = async (stream: GenerateContentStreamResult) => {
+      for await (const chunk of stream.stream) {
+        if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
+        if (time_first_token_millsec == 0) {
+          time_first_token_millsec = new Date().getTime() - start_time_millsec
+        }
+        const time_completion_millsec = new Date().getTime() - start_time_millsec
+
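+        // Gemini surfaces tool invocations as functionCalls() on each chunk:
+        // execute them via MCP, then replay the call/response pair in a new
+        // chat session and recurse into its stream.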
+        const functionCalls = chunk.functionCalls()
+        if (functionCalls) {
+          const fcallParts: FunctionCallPart[] = []
+          const fcRespParts: FunctionResponsePart[] = []
+          for (const call of functionCalls) {
+            console.log('Function call:', call)
+            fcallParts.push({ functionCall: call } as FunctionCallPart)
+            const mcpTool = geminiFunctionCallToMcpTool(mcpTools, call)
+            if (mcpTool) {
+              const toolCallResponse = await callMCPTool(mcpTool)
+              fcRespParts.push({
+                functionResponse: {
+                  name: mcpTool.id,
+                  response: toolCallResponse
+                }
+              })
+            }
+          }
+          if (fcRespParts.length > 0) {
+            history.push(messageContents)
+            history.push({
+              role: 'model',
+              parts: fcallParts
+            })
+            const newChat = geminiModel.startChat({ history })
+            const newStream = await newChat.sendMessageStream(fcRespParts, { signal }).finally(cleanup)
+            await processStream(newStream)
+          }
+        }
+
+        onChunk({
+          text: chunk.text(),
+          usage: {
+            prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
+            completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
+            total_tokens: chunk.usageMetadata?.totalTokenCount || 0
+          },
+          metrics: {
+            completion_tokens: chunk.usageMetadata?.candidatesTokenCount,
+            time_completion_millsec,
+            time_first_token_millsec
+          },
+          search: chunk.candidates?.[0]?.groundingMetadata
+        })
       }
-      const time_completion_millsec = new Date().getTime() - start_time_millsec
-      onChunk({
-        text: chunk.text(),
-        usage: {
-          prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
-          completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
-          total_tokens: chunk.usageMetadata?.totalTokenCount || 0
-        },
-        metrics: {
-          completion_tokens: chunk.usageMetadata?.candidatesTokenCount,
-          time_completion_millsec,
-          time_first_token_millsec
-        },
-        search: chunk.candidates?.[0]?.groundingMetadata
-      })
     }
+
+    await processStream(userMessagesStream)
   }
 
   async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts
index 9cc3bef9..e365923d 100644
--- a/src/renderer/src/providers/OpenAIProvider.ts
+++ b/src/renderer/src/providers/OpenAIProvider.ts
@@ -11,16 +11,7 @@ import i18n from '@renderer/i18n'
 import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
 import { EVENT_NAMES } from '@renderer/services/EventService'
 import { filterContextMessages, filterUserRoleStartMessages } from '@renderer/services/MessagesService'
-import {
-  Assistant,
-  FileTypes,
-  GenerateImageParams,
-  MCPTool,
-  Message,
-  Model,
-  Provider,
-  Suggestion
-} from '@renderer/types'
+import { Assistant, FileTypes, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
 import { removeSpecialCharacters } from '@renderer/utils'
 import { takeRight } from 'lodash'
 import OpenAI, { AzureOpenAI } from 'openai'
@@ -30,12 +21,12 @@ import {
   ChatCompletionCreateParamsNonStreaming,
   ChatCompletionMessageParam,
   ChatCompletionMessageToolCall,
-  ChatCompletionTool,
   ChatCompletionToolMessageParam
 } from 'openai/resources'
 
 import { CompletionsParams } from '.'
 import BaseProvider from './BaseProvider'
+import { callMCPTool, mcpToolsToOpenAITools, openAIToolsToMcpTool } from './mcpToolUtils'
 
 type ReasoningEffort = 'high' | 'medium' | 'low'
 
@@ -226,34 +217,6 @@ export default class OpenAIProvider extends BaseProvider {
     return model.id.startsWith('o1')
   }
 
-  private mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
-    return mcpTools.map((tool) => ({
-      type: 'function',
-      function: {
-        name: tool.id,
-        description: tool.description,
-        parameters: {
-          type: 'object',
-          properties: tool.inputSchema.properties,
-          required: tool.inputSchema.required
-        }
-      }
-    }))
-  }
-
-  private openAIToolsToMcpTool(
-    mcpTools: MCPTool[] | undefined,
-    llmTool: ChatCompletionMessageToolCall
-  ): MCPTool | undefined {
-    if (!mcpTools) return undefined
-    const tool = mcpTools.find((tool) => tool.id === llmTool.function.name)
-    if (!tool) {
-      return undefined
-    }
-    tool.inputSchema = JSON.parse(llmTool.function.arguments)
-    return tool
-  }
-
   async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams): Promise<void> {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
@@ -326,7 +289,7 @@ export default class OpenAIProvider extends BaseProvider {
     const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
     const { signal } = abortController
 
-    const tools = mcpTools && mcpTools.length > 0 ? this.mcpToolsToOpenAITools(mcpTools) : undefined
+    const tools = mcpTools && mcpTools.length > 0 ? mcpToolsToOpenAITools(mcpTools) : undefined
 
     const reqMessages: ChatCompletionMessageParam[] = [systemMessage, ...userMessages].filter(
       Boolean
     )
@@ -399,21 +362,16 @@ export default class OpenAIProvider extends BaseProvider {
         } as ChatCompletionAssistantMessageParam)
 
         for (const toolCall of toolCalls) {
-          const mcpTool = this.openAIToolsToMcpTool(mcpTools, toolCall)
-
+          const mcpTool = openAIToolsToMcpTool(mcpTools, toolCall)
           if (!mcpTool) {
             continue
           }
 
-          const toolCallResponse = await window.api.mcp.callTool({
-            client: mcpTool.serverName,
-            name: mcpTool.name,
-            args: mcpTool.inputSchema
-          })
-
+          const toolCallResponse = await callMCPTool(mcpTool)
+          console.log(toolCallResponse)
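+          // The MCP result's content blocks are forwarded as-is rather than
+          // JSON.stringify-ing the whole response envelope; this assumes the
+          // shape is acceptable to the (OpenAI-compatible) endpoint.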
           reqMessages.push({
             role: 'tool',
-            content: JSON.stringify(toolCallResponse, null, 2),
+            content: toolCallResponse.content,
             tool_call_id: toolCall.id
           } as ChatCompletionToolMessageParam)
         }
diff --git a/src/renderer/src/providers/mcpToolUtils.ts b/src/renderer/src/providers/mcpToolUtils.ts
new file mode 100644
index 00000000..a27d7d71
--- /dev/null
+++ b/src/renderer/src/providers/mcpToolUtils.ts
@@ -0,0 +1,124 @@
+import { Tool, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
+import { FunctionCall, FunctionDeclaration, SchemaType, Tool as geminiTool } from '@google/generative-ai'
+import { MCPTool } from '@renderer/types'
+import { ChatCompletionMessageToolCall, ChatCompletionTool } from 'openai/resources'
+
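+// JSON Schema attributes that are safe to forward to the providers; everything
+// else is dropped, since Gemini in particular rejects schema keywords it does
+// not know ('format' is commented out below, presumably for the same reason).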
+const supportedAttributes = [
+  'type',
+  'nullable',
+  'required',
+  // 'format',
+  'description',
+  'properties',
+  'items',
+  'enum',
+  'anyOf'
+]
+
+function filterPropertyAttributes(tool: MCPTool) {
+  const properties = tool.inputSchema.properties
+  const getSubMap = (obj: Record<string, any>, keys: string[]) => {
+    return Object.fromEntries(Object.entries(obj).filter(([key]) => keys.includes(key)))
+  }
+
+  for (const [key, val] of Object.entries(properties)) {
+    properties[key] = getSubMap(val, supportedAttributes)
+  }
+  return properties
+}
+
+export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
+  return mcpTools.map((tool) => ({
+    type: 'function',
+    function: {
+      name: tool.id,
+      description: tool.description,
+      parameters: {
+        type: 'object',
+        properties: filterPropertyAttributes(tool)
+      }
+    }
+  }))
+}
+
+export function openAIToolsToMcpTool(
+  mcpTools: MCPTool[] | undefined,
+  llmTool: ChatCompletionMessageToolCall
+): MCPTool | undefined {
+  if (!mcpTools) return undefined
+  const tool = mcpTools.find((tool) => tool.id === llmTool.function.name)
+  if (!tool) {
+    return undefined
+  }
+  tool.inputSchema = JSON.parse(llmTool.function.arguments)
+  return tool
+}
+
+export async function callMCPTool(tool: MCPTool): Promise<any> {
+  return await window.api.mcp.callTool({
+    client: tool.serverName,
+    name: tool.name,
+    args: tool.inputSchema
+  })
+}
+
+export function mcpToolsToAnthropicTools(mcpTools: MCPTool[]): Array<ToolUnion> {
+  return mcpTools.map((tool) => {
+    const t: Tool = {
+      name: tool.id,
+      description: tool.description,
+      // @ts-ignore no check
+      input_schema: tool.inputSchema
+    }
+    return t
+  })
+}
+
+export function anthropicToolUseToMcpTool(mcpTools: MCPTool[] | undefined, toolUse: ToolUseBlock): MCPTool | undefined {
+  if (!mcpTools) return undefined
+  const tool = mcpTools.find((tool) => tool.id === toolUse.name)
+  if (!tool) {
+    return undefined
+  }
+  // @ts-ignore the tool_use input type is unknown
+  tool.inputSchema = toolUse.input
+  return tool
+}
+
+export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTool[] {
+  if (!mcpTools) {
+    return []
+  }
+  const functions: FunctionDeclaration[] = []
+
+  for (const tool of mcpTools) {
+    const functionDeclaration: FunctionDeclaration = {
+      name: tool.id,
+      description: tool.description,
+      parameters: {
+        type: SchemaType.OBJECT,
+        properties: filterPropertyAttributes(tool)
+      }
+    }
+    functions.push(functionDeclaration)
+  }
+  const tool: geminiTool = {
+    functionDeclarations: functions
+  }
+  return [tool]
+}
+
+export function geminiFunctionCallToMcpTool(
+  mcpTools: MCPTool[] | undefined,
+  fcall: FunctionCall | undefined
+): MCPTool | undefined {
+  if (!fcall) return undefined
+  if (!mcpTools) return undefined
+  const tool = mcpTools.find((tool) => tool.id === fcall.name)
+  if (!tool) {
+    return undefined
+  }
+  // @ts-ignore fcall.args does not match the declared inputSchema type
+  tool.inputSchema = fcall.args
+  return tool
+}
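
Usage note (editorial sketch, not part of the patch): mcpToolUtils reduces each
provider's tool loop to convert / call / wrap. The helper below is hypothetical
and only illustrates the OpenAI round trip, assuming the MCPTool shape from
'@renderer/types' and the window.api.mcp bridge registered by the main process:

  import { ChatCompletionMessageToolCall, ChatCompletionToolMessageParam } from 'openai/resources'

  import { MCPTool } from '@renderer/types'

  import { callMCPTool, openAIToolsToMcpTool } from './mcpToolUtils'

  // Resolve each tool call requested by the model against the offered MCP
  // tools, execute it over IPC, and wrap the result as a `tool`-role message.
  async function handleToolCalls(
    mcpTools: MCPTool[],
    toolCalls: ChatCompletionMessageToolCall[]
  ): Promise<ChatCompletionToolMessageParam[]> {
    const toolMessages: ChatCompletionToolMessageParam[] = []
    for (const toolCall of toolCalls) {
      const mcpTool = openAIToolsToMcpTool(mcpTools, toolCall)
      if (!mcpTool) continue // the model named a tool that was never offered
      const response = await callMCPTool(mcpTool)
      toolMessages.push({
        role: 'tool',
        content: response.content,
        tool_call_id: toolCall.id
      })
    }
    return toolMessages
  }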