feat(MCP): support gemini and claude models (#2936)

parent f24177d5c4
commit e5664048d9
MCPService:

@@ -270,7 +270,7 @@ export default class MCPService extends EventEmitter {
       const { tools } = await this.clients[serverName].listTools()
       return tools.map((tool: any) => {
         tool.serverName = serverName
-        tool.id = uuidv4()
+        tool.id = 'f' + uuidv4().replace(/-/g, '')
         return tool
       })
     } else {
@@ -282,7 +282,7 @@ export default class MCPService extends EventEmitter {
       allTools = allTools.concat(
         tools.map((tool: MCPTool) => {
           tool.serverName = clientName
-          tool.id = uuidv4()
+          tool.id = 'f' + uuidv4().replace(/-/g, '')
           return tool
         })
       )
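A note on the id change: the tool id doubles as the function/tool name later sent to each provider (see mcpToolUtils.ts below, where `name: tool.id` is used in the OpenAI, Anthropic, and Gemini declarations). Those APIs constrain name format, and Gemini in particular requires function names to start with a letter or underscore, which a raw UUID does not guarantee; prefixing 'f' and stripping dashes yields one conservative token valid everywhere. That rationale is inferred from the code, not stated in the commit. A minimal sketch of the scheme:

import { v4 as uuidv4 } from 'uuid'

// 'f' guarantees a leading letter; stripping '-' leaves only hex chars,
// so the id satisfies the tool/function name rules of all three providers.
function makeToolId(): string {
  return 'f' + uuidv4().replace(/-/g, '')
}

console.log(makeToolId()) // e.g. 'f9b1deb4d3b7d4bad9bdd2b0d7b3dcb6d'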
AnthropicProvider:

@@ -1,5 +1,10 @@
 import Anthropic from '@anthropic-ai/sdk'
-import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources'
+import {
+  MessageCreateParamsNonStreaming,
+  MessageParam,
+  ToolResultBlockParam,
+  ToolUseBlock
+} from '@anthropic-ai/sdk/resources'
 import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
 import { isReasoningModel } from '@renderer/config/models'
 import { getStoreSetting } from '@renderer/hooks/useSettings'
@@ -14,6 +19,7 @@ import OpenAI from 'openai'

 import { CompletionsParams } from '.'
 import BaseProvider from './BaseProvider'
+import { anthropicToolUseToMcpTool, callMCPTool, mcpToolsToAnthropicTools } from './mcpToolUtils'

 type ReasoningEffort = 'high' | 'medium' | 'low'

@@ -118,7 +124,7 @@ export default class AnthropicProvider extends BaseProvider {
     }
   }

-  public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) {
+  public async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams) {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
     const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
@@ -133,10 +139,12 @@ export default class AnthropicProvider extends BaseProvider {
     }

     const userMessages = flatten(userMessagesParams)
+    const tools = mcpTools ? mcpToolsToAnthropicTools(mcpTools) : undefined

     const body: MessageCreateParamsNonStreaming = {
       model: model.id,
       messages: userMessages,
+      tools: tools,
       max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
       temperature: this.getTemperature(assistant, model),
       top_p: this.getTopP(assistant, model),
@@ -186,74 +194,118 @@ export default class AnthropicProvider extends BaseProvider {
     const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
     const { signal } = abortController

-    return new Promise<void>((resolve, reject) => {
-      let hasThinkingContent = false
-      const stream = this.sdk.messages
-        .stream({ ...body, stream: true }, { signal })
-        .on('text', (text) => {
-          if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
-            stream.controller.abort()
-            return resolve()
-          }
-          if (time_first_token_millsec == 0) {
-            time_first_token_millsec = new Date().getTime() - start_time_millsec
-          }
+    const processStream = async (body: MessageCreateParamsNonStreaming) => {
+      new Promise<void>((resolve, reject) => {
+        const toolCalls: ToolUseBlock[] = []
+        let hasThinkingContent = false
+        const stream = this.sdk.messages
+          .stream({ ...body, stream: true }, { signal })
+          .on('text', (text) => {
+            if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
+              stream.controller.abort()
+              return resolve()
+            }
+            if (time_first_token_millsec == 0) {
+              time_first_token_millsec = new Date().getTime() - start_time_millsec
+            }

            if (hasThinkingContent && time_first_content_millsec === 0) {
              time_first_content_millsec = new Date().getTime()
            }

-          const time_thinking_millsec = time_first_content_millsec ? time_first_content_millsec - start_time_millsec : 0
+            const time_thinking_millsec = time_first_content_millsec
+              ? time_first_content_millsec - start_time_millsec
+              : 0

            const time_completion_millsec = new Date().getTime() - start_time_millsec
            onChunk({
              text,
              metrics: {
                completion_tokens: undefined,
                time_completion_millsec,
                time_first_token_millsec,
                time_thinking_millsec
+              }
+            })
+          })
+          .on('thinking', (thinking) => {
+            hasThinkingContent = true
+            if (time_first_token_millsec == 0) {
+              time_first_token_millsec = new Date().getTime() - start_time_millsec
+            }
+
+            const time_completion_millsec = new Date().getTime() - start_time_millsec
+            onChunk({
+              reasoning_content: thinking,
+              text: '',
+              metrics: {
+                completion_tokens: undefined,
+                time_completion_millsec,
+                time_first_token_millsec
+              }
+            })
+          })
+          .on('contentBlock', (content) => {
+            if (content.type == 'tool_use') {
+              toolCalls.push(content)
            }
          })
-        })
-        .on('thinking', (thinking) => {
-          hasThinkingContent = true
-          if (time_first_token_millsec == 0) {
-            time_first_token_millsec = new Date().getTime() - start_time_millsec
-          }
+          .on('finalMessage', async (message) => {
+            if (toolCalls.length > 0) {
+              const toolCallResults: ToolResultBlockParam[] = []
+              for (const toolCall of toolCalls) {
+                const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
+                if (mcpTool) {
+                  const resp = await callMCPTool(mcpTool)
+                  toolCallResults.push({
+                    type: 'tool_result',
+                    tool_use_id: toolCall.id,
+                    content: resp.content
+                  })
+                }
+              }

-          const time_completion_millsec = new Date().getTime() - start_time_millsec
-          onChunk({
-            reasoning_content: thinking,
-            text: '',
-            metrics: {
-              completion_tokens: undefined,
-              time_completion_millsec,
-              time_first_token_millsec
+              if (toolCallResults.length > 0) {
+                userMessages.push({
+                  role: message.role,
+                  content: message.content
+                })
+                userMessages.push({
+                  role: 'user',
+                  content: toolCallResults
+                })
+
+                const newBody = body
+                body.messages = userMessages
+                await processStream(newBody)
+              }
            }
+
+            const time_completion_millsec = new Date().getTime() - start_time_millsec
+            const time_thinking_millsec = time_first_content_millsec
+              ? time_first_content_millsec - start_time_millsec
+              : 0
+            onChunk({
+              text: '',
+              usage: {
+                prompt_tokens: message.usage.input_tokens,
+                completion_tokens: message.usage.output_tokens,
+                total_tokens: sum(Object.values(message.usage))
+              },
+              metrics: {
+                completion_tokens: message.usage.output_tokens,
+                time_completion_millsec,
+                time_first_token_millsec,
+                time_thinking_millsec
+              }
+            })
+            resolve()
          })
-        })
-        .on('finalMessage', (message) => {
-          const time_completion_millsec = new Date().getTime() - start_time_millsec
-          const time_thinking_millsec = time_first_content_millsec ? time_first_content_millsec - start_time_millsec : 0
-          onChunk({
-            text: '',
-            usage: {
-              prompt_tokens: message.usage.input_tokens,
-              completion_tokens: message.usage.output_tokens,
-              total_tokens: sum(Object.values(message.usage))
-            },
-            metrics: {
-              completion_tokens: message.usage.output_tokens,
-              time_completion_millsec,
-              time_first_token_millsec,
-              time_thinking_millsec
-            }
-          })
-          resolve()
-        })
-        .on('error', (error) => reject(error))
-    }).finally(cleanup)
+          .on('error', (error) => reject(error))
+      }).finally(cleanup)
+    }
+
+    await processStream(body)
   }

   public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
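How the new flow fits together: `processStream` wraps the original streaming promise; `contentBlock` events collect `tool_use` blocks, and on `finalMessage` each block is mapped back to its MCPTool (`anthropicToolUseToMcpTool`), executed with `callMCPTool`, and the conversation is extended before `processStream` recurses with the longer message list. The extension follows Anthropic's documented tool-use protocol: replay the assistant turn that requested the tool, then answer it with a user turn of `tool_result` blocks. A sketch of the two appended turns (ids and values illustrative):

import { MessageParam } from '@anthropic-ai/sdk/resources'

const appended: MessageParam[] = [
  // 1. replay the assistant turn carrying the tool_use blocks (message.content above)
  { role: 'assistant', content: [{ type: 'tool_use', id: 'toolu_abc123', name: 'f9b1deb4d3b7d4bad9bdd2b0d7b3dcb6d', input: {} }] },
  // 2. answer it with the MCP call's result
  { role: 'user', content: [{ type: 'tool_result', tool_use_id: 'toolu_abc123', content: 'result text' }] }
]

Two caveats, from reading the diff rather than the commit message: `processStream` neither returns nor awaits the inner promise, so callers cannot observe stream completion or errors; and `newBody` merely aliases `body` (no copy is taken before `body.messages` is reassigned).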
GeminiProvider:

@@ -1,6 +1,9 @@
 import {
   Content,
   FileDataPart,
+  FunctionCallPart,
+  FunctionResponsePart,
+  GenerateContentStreamResult,
   GoogleGenerativeAI,
   HarmBlockThreshold,
   HarmCategory,
@@ -24,6 +27,8 @@ import OpenAI from 'openai'

 import { CompletionsParams } from '.'
 import BaseProvider from './BaseProvider'
+import { callMCPTool, geminiFunctionCallToMcpTool, mcpToolsToGeminiTools } from './mcpToolUtils'

 export default class GeminiProvider extends BaseProvider {
   private sdk: GoogleGenerativeAI
   private requestOptions: RequestOptions
@@ -141,7 +146,7 @@ export default class GeminiProvider extends BaseProvider {
     ]
   }

-  public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) {
+  public async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams) {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
     const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
@@ -157,12 +162,19 @@ export default class GeminiProvider extends BaseProvider {
       history.push(await this.getMessageContents(message))
     }

+    const tools = mcpToolsToGeminiTools(mcpTools)
+    if (assistant.enableWebSearch && isWebSearchModel(model)) {
+      tools.push({
+        // @ts-ignore googleSearch is not a valid tool for Gemini
+        googleSearch: {}
+      })
+    }
+
     const geminiModel = this.sdk.getGenerativeModel(
       {
         model: model.id,
         systemInstruction: assistant.prompt,
-        // @ts-ignore googleSearch is not a valid tool for Gemini
-        tools: assistant.enableWebSearch && isWebSearchModel(model) ? [{ googleSearch: {} }] : undefined,
+        tools: tools.length > 0 ? tools : undefined,
         safetySettings: this.getSafetySettings(model.id),
         generationConfig: {
           maxOutputTokens: maxTokens,
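With this change, MCP function declarations and the built-in web search share a single `tools` array instead of web search owning the `tools` slot outright. A sketch of the resulting value (shapes per `@google/generative-ai`; names and values illustrative):

import { SchemaType, Tool } from '@google/generative-ai'

const tools: Tool[] = [
  {
    functionDeclarations: [
      {
        name: 'f9b1deb4d3b7d4bad9bdd2b0d7b3dcb6d', // the MCPTool id
        description: 'example MCP tool',
        parameters: { type: SchemaType.OBJECT, properties: {} }
      }
    ]
  }
]
// @ts-ignore googleSearch is not part of the typed Tool union
tools.push({ googleSearch: {} }) // appended only when web search is enabled for the model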
@@ -206,27 +218,62 @@ export default class GeminiProvider extends BaseProvider {
     const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal }).finally(cleanup)
     let time_first_token_millsec = 0

-    for await (const chunk of userMessagesStream.stream) {
-      if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
-      if (time_first_token_millsec == 0) {
-        time_first_token_millsec = new Date().getTime() - start_time_millsec
+    const processStream = async (stream: GenerateContentStreamResult) => {
+      for await (const chunk of stream.stream) {
+        if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
+        if (time_first_token_millsec == 0) {
+          time_first_token_millsec = new Date().getTime() - start_time_millsec
+        }
+        const time_completion_millsec = new Date().getTime() - start_time_millsec
+
+        const functionCalls = chunk.functionCalls()
+        if (functionCalls) {
+          const fcallParts: FunctionCallPart[] = []
+          const fcRespParts: FunctionResponsePart[] = []
+          for (const call of functionCalls) {
+            console.log('Function call:', call)
+            fcallParts.push({ functionCall: call } as FunctionCallPart)
+            const mcpTool = geminiFunctionCallToMcpTool(mcpTools, call)
+            if (mcpTool) {
+              const toolCallResponse = await callMCPTool(mcpTool)
+              fcRespParts.push({
+                functionResponse: {
+                  name: mcpTool.id,
+                  response: toolCallResponse
+                }
+              })
+            }
+          }
+          if (fcRespParts) {
+            history.push(messageContents)
+            history.push({
+              role: 'model',
+              parts: fcallParts
+            })
+            const newChat = geminiModel.startChat({ history })
+            const newStream = await newChat.sendMessageStream(fcRespParts, { signal }).finally(cleanup)
+            await processStream(newStream)
+          }
+        }
+
+        onChunk({
+          text: chunk.text(),
+          usage: {
+            prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
+            completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
+            total_tokens: chunk.usageMetadata?.totalTokenCount || 0
+          },
+          metrics: {
+            completion_tokens: chunk.usageMetadata?.candidatesTokenCount,
+            time_completion_millsec,
+            time_first_token_millsec
+          },
+          search: chunk.candidates?.[0]?.groundingMetadata
+        })
       }
-      const time_completion_millsec = new Date().getTime() - start_time_millsec
-      onChunk({
-        text: chunk.text(),
-        usage: {
-          prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
-          completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
-          total_tokens: chunk.usageMetadata?.totalTokenCount || 0
-        },
-        metrics: {
-          completion_tokens: chunk.usageMetadata?.candidatesTokenCount,
-          time_completion_millsec,
-          time_first_token_millsec
-        },
-        search: chunk.candidates?.[0]?.groundingMetadata
-      })
     }
+
+    await processStream(userMessagesStream)
   }

   async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
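The Gemini loop mirrors the Anthropic one: function calls in a chunk are echoed back as a `model` turn of `functionCall` parts, each call is resolved through MCP, and the `functionResponse` parts are sent on a fresh chat whose stream is processed recursively. One likely bug worth flagging: `if (fcRespParts)` tests the array object itself, which is always truthy in JavaScript; `fcRespParts.length > 0` was presumably intended. For reference, the shape of one reply part (values illustrative):

import { FunctionResponsePart } from '@google/generative-ai'

const part: FunctionResponsePart = {
  functionResponse: {
    name: 'f9b1deb4d3b7d4bad9bdd2b0d7b3dcb6d', // must match the declared function name (the MCPTool id)
    response: { content: [{ type: 'text', text: 'tool output' }] } // the raw MCP call result
  }
}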
OpenAIProvider:

@@ -11,16 +11,7 @@ import i18n from '@renderer/i18n'
 import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
 import { EVENT_NAMES } from '@renderer/services/EventService'
 import { filterContextMessages, filterUserRoleStartMessages } from '@renderer/services/MessagesService'
-import {
-  Assistant,
-  FileTypes,
-  GenerateImageParams,
-  MCPTool,
-  Message,
-  Model,
-  Provider,
-  Suggestion
-} from '@renderer/types'
+import { Assistant, FileTypes, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
 import { removeSpecialCharacters } from '@renderer/utils'
 import { takeRight } from 'lodash'
 import OpenAI, { AzureOpenAI } from 'openai'
@@ -30,12 +21,12 @@ import {
   ChatCompletionCreateParamsNonStreaming,
   ChatCompletionMessageParam,
   ChatCompletionMessageToolCall,
-  ChatCompletionTool,
   ChatCompletionToolMessageParam
 } from 'openai/resources'

 import { CompletionsParams } from '.'
 import BaseProvider from './BaseProvider'
+import { callMCPTool, mcpToolsToOpenAITools, openAIToolsToMcpTool } from './mcpToolUtils'

 type ReasoningEffort = 'high' | 'medium' | 'low'

@@ -226,34 +217,6 @@ export default class OpenAIProvider extends BaseProvider {
     return model.id.startsWith('o1')
   }

-  private mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
-    return mcpTools.map((tool) => ({
-      type: 'function',
-      function: {
-        name: tool.id,
-        description: tool.description,
-        parameters: {
-          type: 'object',
-          properties: tool.inputSchema.properties,
-          required: tool.inputSchema.required
-        }
-      }
-    }))
-  }
-
-  private openAIToolsToMcpTool(
-    mcpTools: MCPTool[] | undefined,
-    llmTool: ChatCompletionMessageToolCall
-  ): MCPTool | undefined {
-    if (!mcpTools) return undefined
-    const tool = mcpTools.find((tool) => tool.id === llmTool.function.name)
-    if (!tool) {
-      return undefined
-    }
-    tool.inputSchema = JSON.parse(llmTool.function.arguments)
-    return tool
-  }
-
   async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams): Promise<void> {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
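These two private helpers are not deleted but promoted to the shared mcpToolUtils.ts (new file below) so all three providers can reuse them. One behavioral difference in the move: the shared `mcpToolsToOpenAITools` now runs each schema through `filterPropertieAttributes`, which whitelists a fixed set of JSON Schema attributes per property, and it no longer forwards `required` at the top level of `parameters`.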
@@ -326,7 +289,7 @@ export default class OpenAIProvider extends BaseProvider {
     const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
     const { signal } = abortController

-    const tools = mcpTools && mcpTools.length > 0 ? this.mcpToolsToOpenAITools(mcpTools) : undefined
+    const tools = mcpTools && mcpTools.length > 0 ? mcpToolsToOpenAITools(mcpTools) : undefined

     const reqMessages: ChatCompletionMessageParam[] = [systemMessage, ...userMessages].filter(
       Boolean
@@ -399,21 +362,16 @@ export default class OpenAIProvider extends BaseProvider {
       } as ChatCompletionAssistantMessageParam)

       for (const toolCall of toolCalls) {
-        const mcpTool = this.openAIToolsToMcpTool(mcpTools, toolCall)
+        const mcpTool = openAIToolsToMcpTool(mcpTools, toolCall)

         if (!mcpTool) {
           continue
         }

-        const toolCallResponse = await window.api.mcp.callTool({
-          client: mcpTool.serverName,
-          name: mcpTool.name,
-          args: mcpTool.inputSchema
-        })
+        const toolCallResponse = await callMCPTool(mcpTool)
+        console.log(toolCallResponse)

         reqMessages.push({
           role: 'tool',
-          content: JSON.stringify(toolCallResponse, null, 2),
+          content: toolCallResponse.content,
           tool_call_id: toolCall.id
         } as ChatCompletionToolMessageParam)
       }
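Why `content: toolCallResponse.content` can replace the earlier `JSON.stringify`: per the MCP spec, a callTool result carries a `content` array of typed blocks rather than a plain string, and MCP text blocks happen to share the shape of OpenAI's text content parts, so they can be forwarded as-is. A sketch of the value (illustrative):

// What callMCPTool resolves to, per the MCP CallToolResult shape:
const toolCallResponse = {
  content: [{ type: 'text', text: '42' }], // typed blocks, not a plain string
  isError: false
}
// { type: 'text', text } matches OpenAI's ChatCompletionContentPartText,
// which is presumably why the blocks can be passed through unchanged.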
src/renderer/src/providers/mcpToolUtils.ts (new file, +124 lines)

@@ -0,0 +1,124 @@
+import { Tool, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
+import { FunctionCall, FunctionDeclaration, SchemaType, Tool as geminiToool } from '@google/generative-ai'
+import { MCPTool } from '@renderer/types'
+import { ChatCompletionMessageToolCall, ChatCompletionTool } from 'openai/resources'
+
+const supportedAttributes = [
+  'type',
+  'nullable',
+  'required',
+  // 'format',
+  'description',
+  'properties',
+  'items',
+  'enum',
+  'anyOf'
+]
+
+function filterPropertieAttributes(tool: MCPTool) {
+  const roperties = tool.inputSchema.properties
+  const getSubMap = (obj: Record<string, any>, keys: string[]) => {
+    return Object.fromEntries(Object.entries(obj).filter(([key]) => keys.includes(key)))
+  }
+
+  for (const [key, val] of Object.entries(roperties)) {
+    roperties[key] = getSubMap(val, supportedAttributes)
+  }
+  return roperties
+}
+
+export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
+  return mcpTools.map((tool) => ({
+    type: 'function',
+    function: {
+      name: tool.id,
+      description: tool.description,
+      parameters: {
+        type: 'object',
+        properties: filterPropertieAttributes(tool)
+      }
+    }
+  }))
+}
+
+export function openAIToolsToMcpTool(
+  mcpTools: MCPTool[] | undefined,
+  llmTool: ChatCompletionMessageToolCall
+): MCPTool | undefined {
+  if (!mcpTools) return undefined
+  const tool = mcpTools.find((tool) => tool.id === llmTool.function.name)
+  if (!tool) {
+    return undefined
+  }
+  tool.inputSchema = JSON.parse(llmTool.function.arguments)
+  return tool
+}
+
+export async function callMCPTool(tool: MCPTool): Promise<any> {
+  return await window.api.mcp.callTool({
+    client: tool.serverName,
+    name: tool.name,
+    args: tool.inputSchema
+  })
+}
+
+export function mcpToolsToAnthropicTools(mcpTools: MCPTool[]): Array<ToolUnion> {
+  return mcpTools.map((tool) => {
+    const t: Tool = {
+      name: tool.id,
+      description: tool.description,
+      // @ts-ignore no check
+      input_schema: tool.inputSchema
+    }
+    return t
+  })
+}
+
+export function anthropicToolUseToMcpTool(mcpTools: MCPTool[] | undefined, toolUse: ToolUseBlock): MCPTool | undefined {
+  if (!mcpTools) return undefined
+  const tool = mcpTools.find((tool) => tool.id === toolUse.name)
+  if (!tool) {
+    return undefined
+  }
+  // @ts-ignore ignore type as it is unknown
+  tool.inputSchema = toolUse.input
+  return tool
+}
+
+export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiToool[] {
+  if (!mcpTools) {
+    return []
+  }
+  const functions: FunctionDeclaration[] = []
+
+  for (const tool of mcpTools) {
+    const functionDeclaration: FunctionDeclaration = {
+      name: tool.id,
+      description: tool.description,
+      parameters: {
+        type: SchemaType.OBJECT,
+        properties: filterPropertieAttributes(tool)
+      }
+    }
+    functions.push(functionDeclaration)
+  }
+  const tool: geminiToool = {
+    functionDeclarations: functions
+  }
+  return [tool]
+}
+
+export function geminiFunctionCallToMcpTool(
+  mcpTools: MCPTool[] | undefined,
+  fcall: FunctionCall | undefined
+): MCPTool | undefined {
+  if (!fcall) return undefined
+  if (!mcpTools) return undefined
+  const tool = mcpTools.find((tool) => tool.id === fcall.name)
+  if (!tool) {
+    return undefined
+  }
+  // @ts-ignore schema is not a valid property
+  tool.inputSchema = fcall.args
+  return tool
+}
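Taken together, these helpers give every provider the same three-step loop: advertise, map back, execute. A hedged usage sketch for the OpenAI direction (`mcpTools` and `toolCall` are placeholders for values produced elsewhere in the app):

import { ChatCompletionMessageToolCall } from 'openai/resources'
import { MCPTool } from '@renderer/types'
import { callMCPTool, mcpToolsToOpenAITools, openAIToolsToMcpTool } from './mcpToolUtils'

declare const mcpTools: MCPTool[] // tools listed from the MCP servers
declare const toolCall: ChatCompletionMessageToolCall // a tool_call returned by the model

const tools = mcpToolsToOpenAITools(mcpTools) // 1. advertise MCP tools in the request body
const mcpTool = openAIToolsToMcpTool(mcpTools, toolCall) // 2. map the call back to its MCPTool by id
if (mcpTool) {
  const result = await callMCPTool(mcpTool) // 3. execute over IPC (window.api.mcp.callTool)
}

One sharp edge: all three `*ToMcpTool` mappers overwrite the shared MCPTool's `inputSchema` with the model-supplied arguments, so the original schema is gone after the first call; callers that re-advertise tools should re-list them first.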