feat(MCP): add support for enabling/disabling MCPServers per message (#2989)

*  feat: add MCP servers in chat input

- Introduce MCPToolsButton component for managing MCP servers
- Add new icon for MCP server tools in iconfont.css
- Update Inputbar to include MCP tools functionality
- Add toggle functionality for enabling/disabling MCP servers
- Implement styled dropdown menu for server selection
- Add necessary type imports and useState for MCP server management

*  feat: add support for enabling/disabling MCPServers per message (main)

- Added `enabledMCPs` property to the `Message` type to track enabled MCPServers.
- Modified `MCPToolsButton` to enable all active MCPServers by default using a new `enableAll` state.
- Introduced `filterMCPTools` utility to filter tools based on enabled MCPServers.
- Updated `AnthropicProvider`, `GeminiProvider`, and `OpenAIProvider` to filter tools using `filterMCPTools`.
- Enhanced `Inputbar` to include `enabledMCPs` in the message payload when set.
This commit is contained in:
LiuVaayne 2025-03-07 19:17:29 +08:00 committed by GitHub
parent a0351fb5ad
commit a8451b7c3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 269 additions and 8 deletions

View File

@ -19,6 +19,10 @@
content: '\e623';
}
.icon-mcp:before {
content: '\e78e';
}
.icon-icon-adaptive-width:before {
content: '\e87a';
}

View File

@ -26,7 +26,7 @@ import { translateText } from '@renderer/services/TranslateService'
import WebSearchService from '@renderer/services/WebSearchService'
import store, { useAppDispatch, useAppSelector } from '@renderer/store'
import { setGenerating, setSearching } from '@renderer/store/runtime'
import { Assistant, FileType, KnowledgeBase, Message, Model, Topic } from '@renderer/types'
import { Assistant, FileType, KnowledgeBase, MCPServer, Message, Model, Topic } from '@renderer/types'
import { classNames, delay, getFileExtension, uuid } from '@renderer/utils'
import { abortCompletion } from '@renderer/utils/abortController'
import { getFilesFromDropEvent } from '@renderer/utils/input'
@ -45,6 +45,7 @@ import NarrowLayout from '../Messages/NarrowLayout'
import AttachmentButton from './AttachmentButton'
import AttachmentPreview from './AttachmentPreview'
import KnowledgeBaseButton from './KnowledgeBaseButton'
import MCPToolsButton from './MCPToolsButton'
import MentionModelsButton from './MentionModelsButton'
import MentionModelsInput from './MentionModelsInput'
import SendMessageButton from './SendMessageButton'
@ -88,6 +89,7 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic }) => {
const [isTranslating, setIsTranslating] = useState(false)
const [selectedKnowledgeBases, setSelectedKnowledgeBases] = useState<KnowledgeBase[]>([])
const [mentionModels, setMentionModels] = useState<Model[]>([])
const [enabledMCPs, setEnabledMCPs] = useState<MCPServer[]>([])
const [isMentionPopupOpen, setIsMentionPopupOpen] = useState(false)
const [isDragging, setIsDragging] = useState(false)
const [textareaHeight, setTextareaHeight] = useState<number>()
@ -157,6 +159,11 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic }) => {
if (mentionModels.length > 0) {
message.mentions = mentionModels
}
if (enabledMCPs.length > 0) {
message.enabledMCPs = enabledMCPs
}
currentMessageId.current = message.id
EventEmitter.emit(EVENT_NAMES.SEND_MESSAGE, message)
@ -587,6 +594,17 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic }) => {
setMentionModels(mentionModels.filter((m) => m.id !== model.id))
}
const toggelEnableMCP = (mcp: MCPServer) => {
setEnabledMCPs((prev) => {
const exists = prev.some((item) => item.name === mcp.name)
if (exists) {
return prev.filter((item) => item.name !== mcp.name)
} else {
return [...prev, mcp]
}
})
}
const onEnableWebSearch = () => {
console.log(assistant)
if (!isWebSearchModel(model)) {
@ -682,6 +700,7 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic }) => {
onMentionModel={(model) => onMentionModel(model, mentionFromKeyboard)}
ToolbarButton={ToolbarButton}
/>
<MCPToolsButton enabledMCPs={enabledMCPs} onEnableMCP={toggelEnableMCP} ToolbarButton={ToolbarButton} />
<Tooltip placement="top" title={t('chat.input.web_search')} arrow>
<ToolbarButton type="text" onClick={onEnableWebSearch}>
<GlobalOutlined

View File

@ -0,0 +1,204 @@
import { useMCPServers } from '@renderer/hooks/useMCPServers'
import { MCPServer } from '@renderer/types'
import { Dropdown, Switch, Tooltip } from 'antd'
import { FC, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { createGlobalStyle } from 'styled-components'
interface Props {
enabledMCPs: MCPServer[]
onEnableMCP: (server: MCPServer) => void
ToolbarButton: any
}
const MCPToolsButton: FC<Props> = ({ enabledMCPs, onEnableMCP, ToolbarButton }) => {
const { mcpServers } = useMCPServers()
const [isOpen, setIsOpen] = useState(false)
const [enableAll, setEnableAll] = useState(true)
const dropdownRef = useRef<any>(null)
const menuRef = useRef<HTMLDivElement>(null)
const { t } = useTranslation()
const truncateText = (text: string, maxLength: number = 50) => {
if (!text || text.length <= maxLength) return text
return text.substring(0, maxLength) + '...'
}
// Check if all active servers are enabled
const activeServers = mcpServers.filter((s) => s.isActive)
// Enable all active servers by default
useEffect(() => {
if (activeServers.length > 0) {
activeServers.forEach((server) => {
if (enableAll && !enabledMCPs.includes(server)) {
onEnableMCP(server)
}
if (!enableAll && enabledMCPs.includes(server)) {
onEnableMCP(server)
}
})
}
}, [enableAll])
const menu = (
<div ref={menuRef} className="ant-dropdown-menu">
<div className="dropdown-header">
<div className="header-content">
<h4>{t('settings.mcp.title')}</h4>
<div className="enable-all-container">
{/* <span className="enable-all-label">{t('mcp.enable_all')}</span> */}
<Switch size="small" checked={enableAll} onChange={setEnableAll} />
</div>
</div>
</div>
{mcpServers.length > 0 ? (
mcpServers
.filter((s) => s.isActive)
.map((server) => (
<div key={server.name} className="ant-dropdown-menu-item mcp-server-item">
<div className="server-info">
<div className="server-name">{server.name}</div>
{server.description && (
<Tooltip title={server.description} placement="bottom">
<div className="server-description">{truncateText(server.description)}</div>
</Tooltip>
)}
{server.baseUrl && <div className="server-url">{server.baseUrl}</div>}
</div>
<Switch size="small" checked={enabledMCPs.includes(server)} onChange={() => onEnableMCP(server)} />
</div>
))
) : (
<div className="ant-dropdown-menu-item-group">
<div className="ant-dropdown-menu-item no-results">{t('models.no_matches')}</div>
</div>
)}
</div>
)
return (
<>
<DropdownMenuStyle />
<Dropdown
dropdownRender={() => menu}
trigger={['click']}
open={isOpen}
onOpenChange={setIsOpen}
overlayClassName="mention-models-dropdown">
<Tooltip placement="top" title="MCP Servers" arrow>
<ToolbarButton type="text" ref={dropdownRef}>
<i className="iconfont icon-mcp" style={{ fontSize: 18 }}></i>
</ToolbarButton>
</Tooltip>
</Dropdown>
</>
)
}
const DropdownMenuStyle = createGlobalStyle`
.mention-models-dropdown {
.ant-dropdown-menu {
max-height: 400px;
overflow-y: auto;
overflow-x: hidden;
padding: 4px 0;
margin-bottom: 40px;
position: relative;
&::-webkit-scrollbar {
width: 6px;
height: 6px;
}
&::-webkit-scrollbar-thumb {
border-radius: 10px;
background: var(--color-scrollbar-thumb);
&:hover {
background: var(--color-scrollbar-thumb-hover);
}
}
&::-webkit-scrollbar-track {
background: transparent;
}
.no-results {
padding: 8px 12px;
color: var(--color-text-3);
cursor: default;
font-size: 14px;
&:hover {
background: none;
}
}
.dropdown-header {
padding: 8px 12px;
border-bottom: 1px solid var(--color-border);
margin-bottom: 4px;
.header-content {
display: flex;
justify-content: space-between;
align-items: center;
}
h4 {
margin: 0;
color: var(--color-text-1);
font-size: 14px;
font-weight: 500;
}
.enable-all-container {
display: flex;
align-items: center;
gap: 8px;
.enable-all-label {
font-size: 12px;
color: var(--color-text-3);
}
}
}
.mcp-server-item {
display: flex;
justify-content: space-between;
align-items: center;
padding: 8px 12px;
.server-info {
flex: 1;
overflow: hidden;
.server-name {
font-weight: 500;
font-size: 14px;
color: var(--color-text-1);
}
.server-description {
font-size: 12px;
color: var(--color-text-3);
margin-top: 2px;
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.server-url {
font-size: 11px;
color: var(--color-text-4);
margin-top: 2px;
}
}
}
}
}
`
export default MCPToolsButton

View File

@ -19,7 +19,13 @@ import OpenAI from 'openai'
import { CompletionsParams } from '.'
import BaseProvider from './BaseProvider'
import { anthropicToolUseToMcpTool, callMCPTool, mcpToolsToAnthropicTools, upsertMCPToolResponse } from './mcpToolUtils'
import {
anthropicToolUseToMcpTool,
callMCPTool,
filterMCPTools,
mcpToolsToAnthropicTools,
upsertMCPToolResponse
} from './mcpToolUtils'
type ReasoningEffort = 'high' | 'medium' | 'low'
@ -139,6 +145,8 @@ export default class AnthropicProvider extends BaseProvider {
}
const userMessages = flatten(userMessagesParams)
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
mcpTools = filterMCPTools(mcpTools, lastUserMessage?.enabledMCPs)
const tools = mcpTools ? mcpToolsToAnthropicTools(mcpTools) : undefined
const body: MessageCreateParamsNonStreaming = {
@ -189,8 +197,6 @@ export default class AnthropicProvider extends BaseProvider {
})
}
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
const { signal } = abortController
const toolResponses: MCPToolResponse[] = []

View File

@ -27,7 +27,13 @@ import OpenAI from 'openai'
import { CompletionsParams } from '.'
import BaseProvider from './BaseProvider'
import { callMCPTool, geminiFunctionCallToMcpTool, mcpToolsToGeminiTools, upsertMCPToolResponse } from './mcpToolUtils'
import {
callMCPTool,
filterMCPTools,
geminiFunctionCallToMcpTool,
mcpToolsToGeminiTools,
upsertMCPToolResponse
} from './mcpToolUtils'
export default class GeminiProvider extends BaseProvider {
private sdk: GoogleGenerativeAI
@ -161,7 +167,7 @@ export default class GeminiProvider extends BaseProvider {
for (const message of userMessages) {
history.push(await this.getMessageContents(message))
}
mcpTools = filterMCPTools(mcpTools, userLastMessage?.enabledMCPs)
const tools = mcpToolsToGeminiTools(mcpTools)
const toolResponses: MCPToolResponse[] = []
if (assistant.enableWebSearch && isWebSearchModel(model)) {

View File

@ -35,7 +35,13 @@ import {
import { CompletionsParams } from '.'
import BaseProvider from './BaseProvider'
import { callMCPTool, mcpToolsToOpenAITools, openAIToolsToMcpTool, upsertMCPToolResponse } from './mcpToolUtils'
import {
callMCPTool,
filterMCPTools,
mcpToolsToOpenAITools,
openAIToolsToMcpTool,
upsertMCPToolResponse
} from './mcpToolUtils'
type ReasoningEffort = 'high' | 'medium' | 'low'
@ -298,6 +304,7 @@ export default class OpenAIProvider extends BaseProvider {
const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
const { signal } = abortController
mcpTools = filterMCPTools(mcpTools, lastUserMessage?.enabledMCPs)
const tools = mcpTools && mcpTools.length > 0 ? mcpToolsToOpenAITools(mcpTools) : undefined
const reqMessages: ChatCompletionMessageParam[] = [systemMessage, ...userMessages].filter(

View File

@ -1,6 +1,6 @@
import { Tool, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
import { FunctionCall, FunctionDeclaration, SchemaType, Tool as geminiToool } from '@google/generative-ai'
import { MCPTool, MCPToolResponse } from '@renderer/types'
import { MCPServer, MCPTool, MCPToolResponse } from '@renderer/types'
import { ChatCompletionMessageToolCall, ChatCompletionTool } from 'openai/resources'
import { ChunkCallbackData } from '.'
@ -146,3 +146,17 @@ export function upsertMCPToolResponse(
})
}
}
export function filterMCPTools(
mcpTools: MCPTool[] | undefined,
enabledServers: MCPServer[] | undefined
): MCPTool[] | undefined {
if (mcpTools) {
if (enabledServers) {
mcpTools = mcpTools.filter((t) => enabledServers.some((m) => m.name === t.serverName))
} else {
mcpTools = []
}
}
return mcpTools
}

View File

@ -67,6 +67,7 @@ export type Message = {
askId?: string
useful?: boolean
error?: Record<string, any>
enabledMCPs?: MCPServer[]
metadata?: {
// Gemini
groundingMetadata?: any