From c883fd85d802a91ce441b60a46661847082578bd Mon Sep 17 00:00:00 2001 From: Vaayne Date: Tue, 8 Apr 2025 20:41:18 +0800 Subject: [PATCH] feat(MCP): add StreamableHTTPClientTransport and update server type handling --- package.json | 2 +- src/main/services/MCPService.ts | 14 +- src/main/services/MCPStreamableHttpClient.ts | 365 ++++++++++++++++++ src/renderer/src/i18n/locales/en-us.json | 5 +- src/renderer/src/i18n/locales/ja-jp.json | 5 +- src/renderer/src/i18n/locales/ru-ru.json | 5 +- src/renderer/src/i18n/locales/zh-cn.json | 5 +- src/renderer/src/i18n/locales/zh-tw.json | 5 +- .../settings/MCPSettings/McpSettings.tsx | 14 +- src/renderer/src/types/index.ts | 2 +- yarn.lock | 27 +- 11 files changed, 416 insertions(+), 33 deletions(-) create mode 100644 src/main/services/MCPStreamableHttpClient.ts diff --git a/package.json b/package.json index 9719a9a5..915d68c1 100644 --- a/package.json +++ b/package.json @@ -111,7 +111,7 @@ "@google/genai": "^0.4.0", "@hello-pangea/dnd": "^16.6.0", "@kangfenmao/keyv-storage": "^0.1.0", - "@modelcontextprotocol/sdk": "^1.8.0", + "@modelcontextprotocol/sdk": "^1.9.0", "@notionhq/client": "^2.2.15", "@reduxjs/toolkit": "^2.2.5", "@tavily/core": "patch:@tavily/core@npm%3A0.3.1#~/.yarn/patches/@tavily-core-npm-0.3.1-fe69bf2bea.patch", diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts index 6f6199a5..aef46e9b 100644 --- a/src/main/services/MCPService.ts +++ b/src/main/services/MCPService.ts @@ -15,6 +15,7 @@ import { app } from 'electron' import Logger from 'electron-log' import { CacheService } from './CacheService' +import { StreamableHTTPClientTransport, type StreamableHTTPClientTransportOptions } from './MCPStreamableHttpClient' class McpService { private clients: Map = new Map() @@ -67,7 +68,7 @@ class McpService { const args = [...(server.args || [])] - let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport + let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport try { // Create appropriate transport based on configuration @@ -86,7 +87,16 @@ class McpService { // set the client transport to the client transport = clientTransport } else if (server.baseUrl) { - transport = new SSEClientTransport(new URL(server.baseUrl)) + if (server.type === 'streamableHttp') { + transport = new StreamableHTTPClientTransport( + new URL(server.baseUrl!), + {} as StreamableHTTPClientTransportOptions + ) + } else if (server.type === 'sse') { + transport = new SSEClientTransport(new URL(server.baseUrl!)) + } else { + throw new Error('Invalid server type') + } } else if (server.command) { let cmd = server.command diff --git a/src/main/services/MCPStreamableHttpClient.ts b/src/main/services/MCPStreamableHttpClient.ts new file mode 100644 index 00000000..1e080d2a --- /dev/null +++ b/src/main/services/MCPStreamableHttpClient.ts @@ -0,0 +1,365 @@ +import { auth, AuthResult, OAuthClientProvider, UnauthorizedError } from '@modelcontextprotocol/sdk/client/auth.js' +import { Transport } from '@modelcontextprotocol/sdk/shared/transport.js' +import { JSONRPCMessage, JSONRPCMessageSchema } from '@modelcontextprotocol/sdk/types.js' + +export class StreamableHTTPError extends Error { + constructor( + public readonly code: number | undefined, + message: string | undefined, + public readonly event: ErrorEvent + ) { + super(`Streamable HTTP error: ${message}`) + } +} + +/** + * Configuration options for the `StreamableHTTPClientTransport`. + */ +export type StreamableHTTPClientTransportOptions = { + /** + * An OAuth client provider to use for authentication. + * + * When an `authProvider` is specified and the connection is started: + * 1. The connection is attempted with any existing access token from the `authProvider`. + * 2. If the access token has expired, the `authProvider` is used to refresh the token. + * 3. If token refresh fails or no access token exists, and auth is required, `OAuthClientProvider.redirectToAuthorization` is called, and an `UnauthorizedError` will be thrown from `connect`/`start`. + * + * After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call `StreamableHTTPClientTransport.finishAuth` with the authorization code before retrying the connection. + * + * If an `authProvider` is not provided, and auth is required, an `UnauthorizedError` will be thrown. + * + * `UnauthorizedError` might also be thrown when sending any message over the transport, indicating that the session has expired, and needs to be re-authed and reconnected. + */ + authProvider?: OAuthClientProvider + + /** + * Customizes HTTP requests to the server. + */ + requestInit?: RequestInit +} + +/** + * Client transport for Streamable HTTP: this implements the MCP Streamable HTTP transport specification. + * It will connect to a server using HTTP POST for sending messages and HTTP GET with Server-Sent Events + * for receiving messages. + */ +export class StreamableHTTPClientTransport implements Transport { + private _activeStreams: Map> = new Map() + private _abortController?: AbortController + private _url: URL + private _requestInit?: RequestInit + private _authProvider?: OAuthClientProvider + private _sessionId?: string + private _lastEventId?: string + + onclose?: () => void + onerror?: (error: Error) => void + onmessage?: (message: JSONRPCMessage) => void + + constructor(url: URL, opts?: StreamableHTTPClientTransportOptions) { + this._url = url + this._requestInit = opts?.requestInit + this._authProvider = opts?.authProvider + } + + private async _authThenStart(): Promise { + if (!this._authProvider) { + throw new UnauthorizedError('No auth provider') + } + + let result: AuthResult + try { + result = await auth(this._authProvider, { serverUrl: this._url }) + } catch (error) { + this.onerror?.(error as Error) + throw error + } + + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError() + } + + return await this._startOrAuth() + } + + private async _commonHeaders(): Promise { + const headers: HeadersInit = {} + if (this._authProvider) { + const tokens = await this._authProvider.tokens() + if (tokens) { + headers['Authorization'] = `Bearer ${tokens.access_token}` + } + } + + if (this._sessionId) { + headers['mcp-session-id'] = this._sessionId + } + + return headers + } + + private async _startOrAuth(): Promise { + try { + // Try to open an initial SSE stream with GET to listen for server messages + // This is optional according to the spec - server may not support it + const commonHeaders = await this._commonHeaders() + const headers = new Headers(commonHeaders) + headers.set('Accept', 'text/event-stream') + + // Include Last-Event-ID header for resumable streams + if (this._lastEventId) { + headers.set('last-event-id', this._lastEventId) + } + + const response = await fetch(this._url, { + method: 'GET', + headers, + signal: this._abortController?.signal + }) + + if (response.status === 405) { + // Server doesn't support GET for SSE, which is allowed by the spec + // We'll rely on SSE responses to POST requests for communication + return + } + + if (!response.ok) { + if (response.status === 401 && this._authProvider) { + // Need to authenticate + return await this._authThenStart() + } + + const error = new Error(`Failed to open SSE stream: ${response.status} ${response.statusText}`) + this.onerror?.(error) + throw error + } + + // Successful connection, handle the SSE stream as a standalone listener + const streamId = `initial-${Date.now()}` + this._handleSseStream(response.body, streamId) + } catch (error) { + this.onerror?.(error as Error) + throw error + } + } + + async start() { + if (this._activeStreams.size > 0) { + throw new Error( + 'StreamableHTTPClientTransport already started! If using Client class, note that connect() calls start() automatically.' + ) + } + + this._abortController = new AbortController() + return await this._startOrAuth() + } + + /** + * Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth. + */ + async finishAuth(authorizationCode: string): Promise { + if (!this._authProvider) { + throw new UnauthorizedError('No auth provider') + } + + const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode }) + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError('Failed to authorize') + } + } + + async close(): Promise { + // Close all active streams + for (const reader of this._activeStreams.values()) { + try { + reader.cancel() + } catch (error) { + this.onerror?.(error as Error) + } + } + this._activeStreams.clear() + + // Abort any pending requests + this._abortController?.abort() + + // If we have a session ID, send a DELETE request to explicitly terminate the session + if (this._sessionId) { + try { + const commonHeaders = await this._commonHeaders() + const response = await fetch(this._url, { + method: 'DELETE', + headers: commonHeaders, + signal: this._abortController?.signal + }) + + if (!response.ok) { + // Server might respond with 405 if it doesn't support explicit session termination + // We don't throw an error in that case + if (response.status !== 405) { + const text = await response.text().catch(() => null) + throw new Error(`Error terminating session (HTTP ${response.status}): ${text}`) + } + } + } catch (error) { + // We still want to invoke onclose even if the session termination fails + this.onerror?.(error as Error) + } + } + + this.onclose?.() + } + + async send(message: JSONRPCMessage | JSONRPCMessage[]): Promise { + try { + const commonHeaders = await this._commonHeaders() + const headers = new Headers({ ...commonHeaders, ...this._requestInit?.headers }) + headers.set('content-type', 'application/json') + headers.set('accept', 'application/json, text/event-stream') + + const init = { + ...this._requestInit, + method: 'POST', + headers, + body: JSON.stringify(message), + signal: this._abortController?.signal + } + + const response = await fetch(this._url, init) + + // Handle session ID received during initialization + const sessionId = response.headers.get('mcp-session-id') + if (sessionId) { + this._sessionId = sessionId + } + + if (!response.ok) { + if (response.status === 401 && this._authProvider) { + const result = await auth(this._authProvider, { serverUrl: this._url }) + if (result !== 'AUTHORIZED') { + throw new UnauthorizedError() + } + + // Purposely _not_ awaited, so we don't call onerror twice + return this.send(message) + } + + const text = await response.text().catch(() => null) + throw new Error(`Error POSTing to endpoint (HTTP ${response.status}): ${text}`) + } + + // If the response is 202 Accepted, there's no body to process + if (response.status === 202) { + return + } + + // Get original message(s) for detecting request IDs + const messages = Array.isArray(message) ? message : [message] + + // Extract IDs from request messages for tracking responses + const requestIds = messages + .filter((msg) => 'method' in msg && 'id' in msg) + .map((msg) => ('id' in msg ? msg.id : undefined)) + .filter((id) => id !== undefined) + + // If we have request IDs and an SSE response, create a unique stream ID + const hasRequests = requestIds.length > 0 + + // Check the response type + const contentType = response.headers.get('content-type') + + if (hasRequests) { + if (contentType?.includes('text/event-stream')) { + // For streaming responses, create a unique stream ID based on request IDs + const streamId = `req-${requestIds.join('-')}-${Date.now()}` + this._handleSseStream(response.body, streamId) + } else if (contentType?.includes('application/json')) { + // For non-streaming servers, we might get direct JSON responses + const data = await response.json() + const responseMessages = Array.isArray(data) + ? data.map((msg) => JSONRPCMessageSchema.parse(msg)) + : [JSONRPCMessageSchema.parse(data)] + + for (const msg of responseMessages) { + this.onmessage?.(msg) + } + } + } + } catch (error) { + this.onerror?.(error as Error) + throw error + } + } + + private _handleSseStream(stream: ReadableStream | null, streamId: string): void { + if (!stream) { + return + } + + // Set up stream handling for server-sent events + const reader = stream.getReader() + this._activeStreams.set(streamId, reader) + const decoder = new TextDecoder() + let buffer = '' + + const processStream = async () => { + try { + while (true) { + const { done, value } = await reader.read() + if (done) { + // Stream closed by server + this._activeStreams.delete(streamId) + break + } + + buffer += decoder.decode(value, { stream: true }) + + // Process SSE messages in the buffer + const events = buffer.split('\n\n') + buffer = events.pop() || '' + + for (const event of events) { + const lines = event.split('\n') + let id: string | undefined + let eventType: string | undefined + let data: string | undefined + + // Parse SSE message according to the format + for (const line of lines) { + if (line.startsWith('id:')) { + id = line.slice(3).trim() + } else if (line.startsWith('event:')) { + eventType = line.slice(6).trim() + } else if (line.startsWith('data:')) { + data = line.slice(5).trim() + } + } + + // Update last event ID if provided by server + // As per spec: the ID MUST be globally unique across all streams within that session + if (id) { + this._lastEventId = id + } + + // Handle message event + if (data) { + // Default event type is 'message' per SSE spec if not specified + if (!eventType || eventType === 'message') { + try { + const message = JSONRPCMessageSchema.parse(JSON.parse(data)) + this.onmessage?.(message) + } catch (error) { + this.onerror?.(error as Error) + } + } + } + } + } + } catch (error) { + this._activeStreams.delete(streamId) + this.onerror?.(error as Error) + } + } + + processStream() + } +} diff --git a/src/renderer/src/i18n/locales/en-us.json b/src/renderer/src/i18n/locales/en-us.json index 2d5727a3..73070fcc 100644 --- a/src/renderer/src/i18n/locales/en-us.json +++ b/src/renderer/src/i18n/locales/en-us.json @@ -1028,8 +1028,9 @@ "argsTooltip": "Each argument on a new line", "baseUrlTooltip": "Remote server base URL", "command": "Command", - "sse": "Server-Sent Events(sse)", - "stdio": "Standard Input/Output(stdio)", + "sse": "Server-Sent Events (sse)", + "streamableHttp": "Streamable HTTP (streamableHttp)", + "stdio": "Standard Input/Output (stdio)", "inMemory": "Memory", "config_description": "Configure Model Context Protocol servers", "deleteError": "Failed to delete server", diff --git a/src/renderer/src/i18n/locales/ja-jp.json b/src/renderer/src/i18n/locales/ja-jp.json index 0179a70c..1b78f598 100644 --- a/src/renderer/src/i18n/locales/ja-jp.json +++ b/src/renderer/src/i18n/locales/ja-jp.json @@ -1027,8 +1027,9 @@ "argsTooltip": "1行に1つの引数を入力してください", "baseUrlTooltip": "リモートURLアドレス", "command": "コマンド", - "sse": "サーバー送信イベント(sse)", - "stdio": "標準入力/出力(stdio)", + "sse": "サーバー送信イベント (sse)", + "streamableHttp": "ストリーミング可能なHTTP (streamable)", + "stdio": "標準入力/出力 (stdio)", "inMemory": "メモリ", "config_description": "モデルコンテキストプロトコルサーバーの設定", "deleteError": "サーバーの削除に失敗しました", diff --git a/src/renderer/src/i18n/locales/ru-ru.json b/src/renderer/src/i18n/locales/ru-ru.json index 31f2268c..4402c736 100644 --- a/src/renderer/src/i18n/locales/ru-ru.json +++ b/src/renderer/src/i18n/locales/ru-ru.json @@ -1027,8 +1027,9 @@ "argsTooltip": "Каждый аргумент с новой строки", "baseUrlTooltip": "Адрес удаленного URL", "command": "Команда", - "sse": "События, отправляемые сервером(sse)", - "stdio": "Стандартный ввод/вывод(stdio)", + "sse": "События, отправляемые сервером (sse)", + "streamableHttp": "Потоковый HTTP (streamableHttp)", + "stdio": "Стандартный ввод/вывод (stdio)", "inMemory": "Память", "config_description": "Настройка серверов протокола контекста модели", "deleteError": "Не удалось удалить сервер", diff --git a/src/renderer/src/i18n/locales/zh-cn.json b/src/renderer/src/i18n/locales/zh-cn.json index b14951af..a0121456 100644 --- a/src/renderer/src/i18n/locales/zh-cn.json +++ b/src/renderer/src/i18n/locales/zh-cn.json @@ -1028,8 +1028,9 @@ "argsTooltip": "每个参数占一行", "baseUrlTooltip": "远程 URL 地址", "command": "命令", - "sse": "服务器发送事件(sse)", - "stdio": "标准输入/输出(stdio)", + "sse": "服务器发送事件 (sse)", + "streamableHttp": "可流式传输的HTTP (streamableHttp)", + "stdio": "标准输入/输出 (stdio)", "inMemory": "内存", "config_description": "配置模型上下文协议服务器", "deleteError": "删除服务器失败", diff --git a/src/renderer/src/i18n/locales/zh-tw.json b/src/renderer/src/i18n/locales/zh-tw.json index 06e83d9d..b4a93998 100644 --- a/src/renderer/src/i18n/locales/zh-tw.json +++ b/src/renderer/src/i18n/locales/zh-tw.json @@ -1027,8 +1027,9 @@ "argsTooltip": "每個參數佔一行", "baseUrlTooltip": "遠端 URL 地址", "command": "指令", - "sse": "伺服器傳送事件(sse)", - "stdio": "標準輸入/輸出(stdio)", + "sse": "伺服器傳送事件 (sse)", + "streamableHttp": "可串流的HTTP (streamableHttp)", + "stdio": "標準輸入/輸出 (stdio)", "inMemory": "記憶體", "config_description": "設定模型上下文協議伺服器", "deleteError": "刪除伺服器失敗", diff --git a/src/renderer/src/pages/settings/MCPSettings/McpSettings.tsx b/src/renderer/src/pages/settings/MCPSettings/McpSettings.tsx index 59435675..8bc1b562 100644 --- a/src/renderer/src/pages/settings/MCPSettings/McpSettings.tsx +++ b/src/renderer/src/pages/settings/MCPSettings/McpSettings.tsx @@ -149,7 +149,7 @@ const McpSettings: React.FC = ({ server }) => { } // set stdio or sse server - if (values.serverType === 'sse') { + if (values.serverType === 'sse' || server.type === 'streamableHttp') { mcpServer.baseUrl = values.baseUrl } else { mcpServer.command = values.command @@ -358,7 +358,8 @@ const McpSettings: React.FC = ({ server }) => { onChange={(e) => setServerType(e.target.value)} options={[ { label: t('settings.mcp.stdio'), value: 'stdio' }, - { label: t('settings.mcp.sse'), value: 'sse' } + { label: t('settings.mcp.sse'), value: 'sse' }, + { label: t('settings.mcp.streamableHttp'), value: 'streamableHttp' } ]} /> @@ -372,6 +373,15 @@ const McpSettings: React.FC = ({ server }) => { )} + {serverType === 'streamableHttp' && ( + + + + )} {serverType === 'stdio' && ( <>