feat(MCP): support gemini and claude models (#2936)

LiuVaayne 2025-03-06 19:32:34 +08:00 committed by GitHub
parent f24177d5c4
commit e5664048d9
5 changed files with 316 additions and 135 deletions

File: MCPService.ts

@@ -270,7 +270,7 @@ export default class MCPService extends EventEmitter {
      const { tools } = await this.clients[serverName].listTools()
      return tools.map((tool: any) => {
        tool.serverName = serverName
-        tool.id = uuidv4()
+        tool.id = 'f' + uuidv4().replace(/-/g, '')
        return tool
      })
    } else {
@@ -282,7 +282,7 @@ export default class MCPService extends EventEmitter {
        allTools = allTools.concat(
          tools.map((tool: MCPTool) => {
            tool.serverName = clientName
-            tool.id = uuidv4()
+            tool.id = 'f' + uuidv4().replace(/-/g, '')
            return tool
          })
        )
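Reviewer note: the id doubles as the function name advertised to every provider, and a bare UUID can violate the providers' naming rules. Gemini requires function names to start with a letter or underscore, and OpenAI and Anthropic restrict tool names to roughly `[a-zA-Z0-9_-]{1,64}`, so a UUID's leading digit can be rejected. Prefixing `'f'` and stripping hyphens keeps the id name-safe. A minimal sketch of the format (`nameSafeToolId` is a hypothetical name for the inline expression above):

```ts
import { v4 as uuidv4 } from 'uuid'

// Generate a tool id that is safe to use as a provider-side function name:
// it always starts with a letter and contains only [a-f0-9] afterwards,
// 33 characters in total.
function nameSafeToolId(): string {
  return 'f' + uuidv4().replace(/-/g, '')
}

console.log(nameSafeToolId()) // e.g. 'f9b1c2d3e4f5...'
```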

File: AnthropicProvider.ts

@@ -1,5 +1,10 @@
import Anthropic from '@anthropic-ai/sdk'
-import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources'
+import {
+  MessageCreateParamsNonStreaming,
+  MessageParam,
+  ToolResultBlockParam,
+  ToolUseBlock
+} from '@anthropic-ai/sdk/resources'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { isReasoningModel } from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
@@ -14,6 +19,7 @@ import OpenAI from 'openai'

import { CompletionsParams } from '.'
import BaseProvider from './BaseProvider'
+import { anthropicToolUseToMcpTool, callMCPTool, mcpToolsToAnthropicTools } from './mcpToolUtils'

type ReasoningEffort = 'high' | 'medium' | 'low'
@@ -118,7 +124,7 @@
    }
  }

-  public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) {
+  public async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams) {
    const defaultModel = getDefaultModel()
    const model = assistant.model || defaultModel
    const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
@@ -133,10 +139,12 @@
    }

    const userMessages = flatten(userMessagesParams)
+    const tools = mcpTools ? mcpToolsToAnthropicTools(mcpTools) : undefined

    const body: MessageCreateParamsNonStreaming = {
      model: model.id,
      messages: userMessages,
+      tools: tools,
      max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
      temperature: this.getTemperature(assistant, model),
      top_p: this.getTopP(assistant, model),
@@ -186,74 +194,118 @@
    const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
    const { signal } = abortController

-    return new Promise<void>((resolve, reject) => {
-      let hasThinkingContent = false
-      const stream = this.sdk.messages
-        .stream({ ...body, stream: true }, { signal })
-        .on('text', (text) => {
-          if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
-            stream.controller.abort()
-            return resolve()
-          }
-
-          if (time_first_token_millsec == 0) {
-            time_first_token_millsec = new Date().getTime() - start_time_millsec
-          }
-
-          if (hasThinkingContent && time_first_content_millsec === 0) {
-            time_first_content_millsec = new Date().getTime()
-          }
-
-          const time_thinking_millsec = time_first_content_millsec ? time_first_content_millsec - start_time_millsec : 0
-          const time_completion_millsec = new Date().getTime() - start_time_millsec
-
-          onChunk({
-            text,
-            metrics: {
-              completion_tokens: undefined,
-              time_completion_millsec,
-              time_first_token_millsec,
-              time_thinking_millsec
-            }
-          })
-        })
-        .on('thinking', (thinking) => {
-          hasThinkingContent = true
-          if (time_first_token_millsec == 0) {
-            time_first_token_millsec = new Date().getTime() - start_time_millsec
-          }
-          const time_completion_millsec = new Date().getTime() - start_time_millsec
-          onChunk({
-            reasoning_content: thinking,
-            text: '',
-            metrics: {
-              completion_tokens: undefined,
-              time_completion_millsec,
-              time_first_token_millsec
-            }
-          })
-        })
-        .on('finalMessage', (message) => {
-          const time_completion_millsec = new Date().getTime() - start_time_millsec
-          const time_thinking_millsec = time_first_content_millsec ? time_first_content_millsec - start_time_millsec : 0
-          onChunk({
-            text: '',
-            usage: {
-              prompt_tokens: message.usage.input_tokens,
-              completion_tokens: message.usage.output_tokens,
-              total_tokens: sum(Object.values(message.usage))
-            },
-            metrics: {
-              completion_tokens: message.usage.output_tokens,
-              time_completion_millsec,
-              time_first_token_millsec,
-              time_thinking_millsec
-            }
-          })
-          resolve()
-        })
-        .on('error', (error) => reject(error))
-    }).finally(cleanup)
+    const processStream = async (body: MessageCreateParamsNonStreaming) => {
+      new Promise<void>((resolve, reject) => {
+        const toolCalls: ToolUseBlock[] = []
+        let hasThinkingContent = false
+        const stream = this.sdk.messages
+          .stream({ ...body, stream: true }, { signal })
+          .on('text', (text) => {
+            if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
+              stream.controller.abort()
+              return resolve()
+            }
+
+            if (time_first_token_millsec == 0) {
+              time_first_token_millsec = new Date().getTime() - start_time_millsec
+            }
+
+            if (hasThinkingContent && time_first_content_millsec === 0) {
+              time_first_content_millsec = new Date().getTime()
+            }
+
+            const time_thinking_millsec = time_first_content_millsec
+              ? time_first_content_millsec - start_time_millsec
+              : 0
+
+            const time_completion_millsec = new Date().getTime() - start_time_millsec
+
+            onChunk({
+              text,
+              metrics: {
+                completion_tokens: undefined,
+                time_completion_millsec,
+                time_first_token_millsec,
+                time_thinking_millsec
+              }
+            })
+          })
+          .on('thinking', (thinking) => {
+            hasThinkingContent = true
+            if (time_first_token_millsec == 0) {
+              time_first_token_millsec = new Date().getTime() - start_time_millsec
+            }
+            const time_completion_millsec = new Date().getTime() - start_time_millsec
+            onChunk({
+              reasoning_content: thinking,
+              text: '',
+              metrics: {
+                completion_tokens: undefined,
+                time_completion_millsec,
+                time_first_token_millsec
+              }
+            })
+          })
+          .on('contentBlock', (content) => {
+            if (content.type == 'tool_use') {
+              toolCalls.push(content)
+            }
+          })
+          .on('finalMessage', async (message) => {
+            if (toolCalls.length > 0) {
+              const toolCallResults: ToolResultBlockParam[] = []
+              for (const toolCall of toolCalls) {
+                const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
+                if (mcpTool) {
+                  const resp = await callMCPTool(mcpTool)
+                  toolCallResults.push({
+                    type: 'tool_result',
+                    tool_use_id: toolCall.id,
+                    content: resp.content
+                  })
+                }
+              }
+              if (toolCallResults.length > 0) {
+                userMessages.push({
+                  role: message.role,
+                  content: message.content
+                })
+                userMessages.push({
+                  role: 'user',
+                  content: toolCallResults
+                })
+                const newBody = body
+                body.messages = userMessages
+                await processStream(newBody)
+              }
+            }
+            const time_completion_millsec = new Date().getTime() - start_time_millsec
+            const time_thinking_millsec = time_first_content_millsec
+              ? time_first_content_millsec - start_time_millsec
+              : 0
+            onChunk({
+              text: '',
+              usage: {
+                prompt_tokens: message.usage.input_tokens,
+                completion_tokens: message.usage.output_tokens,
+                total_tokens: sum(Object.values(message.usage))
+              },
+              metrics: {
+                completion_tokens: message.usage.output_tokens,
+                time_completion_millsec,
+                time_first_token_millsec,
+                time_thinking_millsec
+              }
+            })
+            resolve()
+          })
+          .on('error', (error) => reject(error))
+      }).finally(cleanup)
+    }
+
+    await processStream(body)
  }

  public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
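Reviewer note on the hunk above: two things stand out. First, `processStream` builds a `new Promise<void>(...)` but neither returns nor awaits it, so `await processStream(body)` (and each recursive call) can resolve before the stream finishes; returning the promise would restore the ordering. Second, `const newBody = body` is an alias rather than a copy, so `body.messages = userMessages` mutates the very object being re-sent; it works, but the name is misleading. For orientation, this is the tool round trip the handler implements, as a minimal non-streaming sketch; `runTool()` and `completeWithTools()` are hypothetical stand-ins, not provider code:

```ts
import Anthropic from '@anthropic-ai/sdk'
import {
  MessageCreateParamsNonStreaming,
  MessageParam,
  ToolResultBlockParam,
  ToolUseBlock
} from '@anthropic-ai/sdk/resources'

// Keep completing until the model stops asking for tools.
async function completeWithTools(
  client: Anthropic,
  params: MessageCreateParamsNonStreaming,
  runTool: (use: ToolUseBlock) => Promise<string>
) {
  const messages: MessageParam[] = [...params.messages]
  for (;;) {
    const message = await client.messages.create({ ...params, messages })
    const toolUses = message.content.filter((block): block is ToolUseBlock => block.type === 'tool_use')
    if (toolUses.length === 0) return message // plain answer, done
    const results: ToolResultBlockParam[] = []
    for (const use of toolUses) {
      results.push({ type: 'tool_result', tool_use_id: use.id, content: await runTool(use) })
    }
    // Echo the assistant turn, then answer its tool_use blocks with
    // tool_result blocks in a user turn, the same shape the hunk builds.
    messages.push({ role: message.role, content: message.content })
    messages.push({ role: 'user', content: results })
  }
}
```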

File: GeminiProvider.ts

@@ -1,6 +1,9 @@
import {
  Content,
  FileDataPart,
+  FunctionCallPart,
+  FunctionResponsePart,
+  GenerateContentStreamResult,
  GoogleGenerativeAI,
  HarmBlockThreshold,
  HarmCategory,
@@ -24,6 +27,8 @@ import OpenAI from 'openai'

import { CompletionsParams } from '.'
import BaseProvider from './BaseProvider'
+import { callMCPTool, geminiFunctionCallToMcpTool, mcpToolsToGeminiTools } from './mcpToolUtils'

export default class GeminiProvider extends BaseProvider {
  private sdk: GoogleGenerativeAI
  private requestOptions: RequestOptions
@@ -141,7 +146,7 @@
    ]
  }

-  public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) {
+  public async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams) {
    const defaultModel = getDefaultModel()
    const model = assistant.model || defaultModel
    const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
@@ -157,12 +162,19 @@
      history.push(await this.getMessageContents(message))
    }

+    const tools = mcpToolsToGeminiTools(mcpTools)
+    if (assistant.enableWebSearch && isWebSearchModel(model)) {
+      tools.push({
+        // @ts-ignore googleSearch is not a valid tool for Gemini
+        googleSearch: {}
+      })
+    }
+
    const geminiModel = this.sdk.getGenerativeModel(
      {
        model: model.id,
        systemInstruction: assistant.prompt,
-        // @ts-ignore googleSearch is not a valid tool for Gemini
-        tools: assistant.enableWebSearch && isWebSearchModel(model) ? [{ googleSearch: {} }] : undefined,
+        tools: tools.length > 0 ? tools : undefined,
        safetySettings: this.getSafetySettings(model.id),
        generationConfig: {
          maxOutputTokens: maxTokens,
@@ -206,27 +218,62 @@
    const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal }).finally(cleanup)

    let time_first_token_millsec = 0

-    for await (const chunk of userMessagesStream.stream) {
-      if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
-      if (time_first_token_millsec == 0) {
-        time_first_token_millsec = new Date().getTime() - start_time_millsec
-      }
-      const time_completion_millsec = new Date().getTime() - start_time_millsec
-      onChunk({
-        text: chunk.text(),
-        usage: {
-          prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
-          completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
-          total_tokens: chunk.usageMetadata?.totalTokenCount || 0
-        },
-        metrics: {
-          completion_tokens: chunk.usageMetadata?.candidatesTokenCount,
-          time_completion_millsec,
-          time_first_token_millsec
-        },
-        search: chunk.candidates?.[0]?.groundingMetadata
-      })
-    }
+    const processStream = async (stream: GenerateContentStreamResult) => {
+      for await (const chunk of stream.stream) {
+        if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
+        if (time_first_token_millsec == 0) {
+          time_first_token_millsec = new Date().getTime() - start_time_millsec
+        }
+
+        const time_completion_millsec = new Date().getTime() - start_time_millsec
+
+        const functionCalls = chunk.functionCalls()
+        if (functionCalls) {
+          const fcallParts: FunctionCallPart[] = []
+          const fcRespParts: FunctionResponsePart[] = []
+          for (const call of functionCalls) {
+            console.log('Function call:', call)
+            fcallParts.push({ functionCall: call } as FunctionCallPart)
+            const mcpTool = geminiFunctionCallToMcpTool(mcpTools, call)
+            if (mcpTool) {
+              const toolCallResponse = await callMCPTool(mcpTool)
+              fcRespParts.push({
+                functionResponse: {
+                  name: mcpTool.id,
+                  response: toolCallResponse
+                }
+              })
+            }
+          }
+          if (fcRespParts) {
+            history.push(messageContents)
+            history.push({
+              role: 'model',
+              parts: fcallParts
+            })
+            const newChat = geminiModel.startChat({ history })
+            const newStream = await newChat.sendMessageStream(fcRespParts, { signal }).finally(cleanup)
+            await processStream(newStream)
+          }
+        }
+
+        onChunk({
+          text: chunk.text(),
+          usage: {
+            prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
+            completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
+            total_tokens: chunk.usageMetadata?.totalTokenCount || 0
+          },
+          metrics: {
+            completion_tokens: chunk.usageMetadata?.candidatesTokenCount,
+            time_completion_millsec,
+            time_first_token_millsec
+          },
+          search: chunk.candidates?.[0]?.groundingMetadata
+        })
+      }
+    }
+
+    await processStream(userMessagesStream)
  }

  async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
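Reviewer note on the hunk above: `fcRespParts` is an array, so `if (fcRespParts)` is always true; `if (fcRespParts.length > 0)` looks like the intended guard. Also, `name: mcpTool.id` is correct here because the Gemini function declarations are registered under `tool.id`, not `tool.name`. The round trip itself follows the SDK pattern: read `chunk.functionCalls()`, run each matching MCP tool, then send the `functionResponse` parts back as the next turn. A minimal sketch of that loop, with a placeholder model name and a hypothetical `runTool()`:

```ts
import { FunctionResponsePart, GoogleGenerativeAI, Tool } from '@google/generative-ai'

// Answer function calls until the model produces a plain text turn.
// `tools` would come from mcpToolsToGeminiTools(); 'gemini-1.5-pro' is a placeholder.
async function chatWithTools(
  apiKey: string,
  tools: Tool[],
  prompt: string,
  runTool: (name: string, args: object) => Promise<object>
) {
  const genAI = new GoogleGenerativeAI(apiKey)
  const model = genAI.getGenerativeModel({ model: 'gemini-1.5-pro', tools })
  const chat = model.startChat() // the chat session accumulates history for us
  let result = await chat.sendMessage(prompt)
  for (;;) {
    const calls = result.response.functionCalls()
    if (!calls || calls.length === 0) break
    const parts: FunctionResponsePart[] = []
    for (const call of calls) {
      parts.push({
        functionResponse: { name: call.name, response: await runTool(call.name, call.args) }
      })
    }
    result = await chat.sendMessage(parts)
  }
  return result.response.text()
}
```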

File: OpenAIProvider.ts

@@ -11,16 +11,7 @@ import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import { EVENT_NAMES } from '@renderer/services/EventService'
import { filterContextMessages, filterUserRoleStartMessages } from '@renderer/services/MessagesService'
-import {
-  Assistant,
-  FileTypes,
-  GenerateImageParams,
-  MCPTool,
-  Message,
-  Model,
-  Provider,
-  Suggestion
-} from '@renderer/types'
+import { Assistant, FileTypes, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
import { removeSpecialCharacters } from '@renderer/utils'
import { takeRight } from 'lodash'
import OpenAI, { AzureOpenAI } from 'openai'
@@ -30,12 +21,12 @@ import {
  ChatCompletionCreateParamsNonStreaming,
  ChatCompletionMessageParam,
  ChatCompletionMessageToolCall,
-  ChatCompletionTool,
  ChatCompletionToolMessageParam
} from 'openai/resources'

import { CompletionsParams } from '.'
import BaseProvider from './BaseProvider'
+import { callMCPTool, mcpToolsToOpenAITools, openAIToolsToMcpTool } from './mcpToolUtils'

type ReasoningEffort = 'high' | 'medium' | 'low'
@@ -226,34 +217,6 @@
    return model.id.startsWith('o1')
  }

-  private mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
-    return mcpTools.map((tool) => ({
-      type: 'function',
-      function: {
-        name: tool.id,
-        description: tool.description,
-        parameters: {
-          type: 'object',
-          properties: tool.inputSchema.properties,
-          required: tool.inputSchema.required
-        }
-      }
-    }))
-  }
-
-  private openAIToolsToMcpTool(
-    mcpTools: MCPTool[] | undefined,
-    llmTool: ChatCompletionMessageToolCall
-  ): MCPTool | undefined {
-    if (!mcpTools) return undefined
-    const tool = mcpTools.find((tool) => tool.id === llmTool.function.name)
-    if (!tool) {
-      return undefined
-    }
-    tool.inputSchema = JSON.parse(llmTool.function.arguments)
-    return tool
-  }
-
  async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams): Promise<void> {
    const defaultModel = getDefaultModel()
    const model = assistant.model || defaultModel
@@ -326,7 +289,7 @@
    const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
    const { signal } = abortController

-    const tools = mcpTools && mcpTools.length > 0 ? this.mcpToolsToOpenAITools(mcpTools) : undefined
+    const tools = mcpTools && mcpTools.length > 0 ? mcpToolsToOpenAITools(mcpTools) : undefined

    const reqMessages: ChatCompletionMessageParam[] = [systemMessage, ...userMessages].filter(
      Boolean
@@ -399,21 +362,16 @@
      } as ChatCompletionAssistantMessageParam)

      for (const toolCall of toolCalls) {
-        const mcpTool = this.openAIToolsToMcpTool(mcpTools, toolCall)
+        const mcpTool = openAIToolsToMcpTool(mcpTools, toolCall)
        if (!mcpTool) {
          continue
        }
-        const toolCallResponse = await window.api.mcp.callTool({
-          client: mcpTool.serverName,
-          name: mcpTool.name,
-          args: mcpTool.inputSchema
-        })
+        const toolCallResponse = await callMCPTool(mcpTool)
        console.log(toolCallResponse)
        reqMessages.push({
          role: 'tool',
-          content: JSON.stringify(toolCallResponse, null, 2),
+          content: toolCallResponse.content,
          tool_call_id: toolCall.id
        } as ChatCompletionToolMessageParam)
      }
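Reviewer note: moving to the shared helpers is a straight refactor, but `content: toolCallResponse.content` deserves a second look. MCP's `callTool` returns a `content` array of typed blocks; text blocks (`{ type: 'text', text }`) happen to match OpenAI's text content parts, but image or resource blocks will not, and some OpenAI-compatible backends only accept a plain string here (the old `JSON.stringify` was ugly but never rejected). The leftover `console.log(toolCallResponse)` could also go. For reference, a sketch of the round trip these helpers feed into; `runTool()` and `completeWithTools()` are hypothetical stand-ins:

```ts
import OpenAI from 'openai'
import { ChatCompletionMessageParam, ChatCompletionTool } from 'openai/resources'

// Loop until the model answers without tool calls.
async function completeWithTools(
  client: OpenAI,
  model: string,
  messages: ChatCompletionMessageParam[],
  tools: ChatCompletionTool[],
  runTool: (name: string, argsJson: string) => Promise<string>
) {
  for (;;) {
    const completion = await client.chat.completions.create({ model, messages, tools })
    const message = completion.choices[0].message
    if (!message.tool_calls || message.tool_calls.length === 0) return message.content
    messages.push(message) // echo the assistant turn that carries the tool_calls
    for (const call of message.tool_calls) {
      messages.push({
        role: 'tool',
        tool_call_id: call.id,
        content: await runTool(call.function.name, call.function.arguments)
      })
    }
  }
}
```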

File: mcpToolUtils.ts (new file)

@@ -0,0 +1,124 @@
import { Tool, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
import { FunctionCall, FunctionDeclaration, SchemaType, Tool as GeminiTool } from '@google/generative-ai'
import { MCPTool } from '@renderer/types'
import { ChatCompletionMessageToolCall, ChatCompletionTool } from 'openai/resources'

const supportedAttributes = [
  'type',
  'nullable',
  'required',
  // 'format',
  'description',
  'properties',
  'items',
  'enum',
  'anyOf'
]

function filterPropertyAttributes(tool: MCPTool) {
  const properties = tool.inputSchema.properties
  const getSubMap = (obj: Record<string, any>, keys: string[]) => {
    return Object.fromEntries(Object.entries(obj).filter(([key]) => keys.includes(key)))
  }

  for (const [key, val] of Object.entries(properties)) {
    properties[key] = getSubMap(val, supportedAttributes)
  }

  return properties
}

export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
  return mcpTools.map((tool) => ({
    type: 'function',
    function: {
      name: tool.id,
      description: tool.description,
      parameters: {
        type: 'object',
        properties: filterPropertyAttributes(tool)
      }
    }
  }))
}

export function openAIToolsToMcpTool(
  mcpTools: MCPTool[] | undefined,
  llmTool: ChatCompletionMessageToolCall
): MCPTool | undefined {
  if (!mcpTools) return undefined
  const tool = mcpTools.find((tool) => tool.id === llmTool.function.name)
  if (!tool) {
    return undefined
  }
  tool.inputSchema = JSON.parse(llmTool.function.arguments)
  return tool
}

export async function callMCPTool(tool: MCPTool): Promise<any> {
  return await window.api.mcp.callTool({
    client: tool.serverName,
    name: tool.name,
    args: tool.inputSchema
  })
}

export function mcpToolsToAnthropicTools(mcpTools: MCPTool[]): Array<ToolUnion> {
  return mcpTools.map((tool) => {
    const t: Tool = {
      name: tool.id,
      description: tool.description,
      // @ts-ignore no check
      input_schema: tool.inputSchema
    }
    return t
  })
}

export function anthropicToolUseToMcpTool(mcpTools: MCPTool[] | undefined, toolUse: ToolUseBlock): MCPTool | undefined {
  if (!mcpTools) return undefined
  const tool = mcpTools.find((tool) => tool.id === toolUse.name)
  if (!tool) {
    return undefined
  }
  // @ts-ignore the tool_use input type is unknown
  tool.inputSchema = toolUse.input
  return tool
}

export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): GeminiTool[] {
  if (!mcpTools) {
    return []
  }
  const functions: FunctionDeclaration[] = []
  for (const tool of mcpTools) {
    const functionDeclaration: FunctionDeclaration = {
      name: tool.id,
      description: tool.description,
      parameters: {
        type: SchemaType.OBJECT,
        properties: filterPropertyAttributes(tool)
      }
    }
    functions.push(functionDeclaration)
  }
  const tool: GeminiTool = {
    functionDeclarations: functions
  }
  return [tool]
}

export function geminiFunctionCallToMcpTool(
  mcpTools: MCPTool[] | undefined,
  fcall: FunctionCall | undefined
): MCPTool | undefined {
  if (!fcall) return undefined
  if (!mcpTools) return undefined
  const tool = mcpTools.find((tool) => tool.id === fcall.name)
  if (!tool) {
    return undefined
  }
  // @ts-ignore schema is not a valid property
  tool.inputSchema = fcall.args
  return tool
}
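Usage sketch for the new helpers, with a hypothetical weather tool (the cast is only for the example). One caveat worth flagging: the `*ToMcpTool` converters mutate the tool they find, overwriting `tool.inputSchema` with the model's call arguments so that `callMCPTool` can pass them as `args`; after the first call the original schema is gone, so converting the same list again would advertise the last call's arguments as the schema.

```ts
import { ChatCompletionMessageToolCall } from 'openai/resources'
import { MCPTool } from '@renderer/types'

import { callMCPTool, mcpToolsToOpenAITools, openAIToolsToMcpTool } from './mcpToolUtils'

// A hypothetical MCP tool, shaped after the fields the helpers rely on.
const weatherTool = {
  id: 'f0c1d2e3', // normally 'f' + uuidv4().replace(/-/g, ''); doubles as the function name
  serverName: 'weather',
  name: 'get_weather',
  description: 'Look up the current weather for a city',
  inputSchema: {
    type: 'object',
    properties: { city: { type: 'string', description: 'City name' } },
    required: ['city']
  }
} as unknown as MCPTool

// Advertise the tool to the model...
const openAITools = mcpToolsToOpenAITools([weatherTool])

// ...and when a completion comes back with a tool call, map it back by id and run it.
async function handleToolCall(toolCall: ChatCompletionMessageToolCall) {
  const mcpTool = openAIToolsToMcpTool([weatherTool], toolCall)
  if (mcpTool) return await callMCPTool(mcpTool) // args come from the overwritten inputSchema
}
```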