From 35c50b54a892806dddb49f5d08dc1d3519839c6b Mon Sep 17 00:00:00 2001 From: LiuVaayne <10231735+vaayne@users.noreply.github.com> Date: Wed, 16 Apr 2025 22:07:32 +0800 Subject: [PATCH] Feat/mcp oauth (#4837) * feat: implement OAuth client provider and lockfile management * feat: implement OAuth callback server and refactor authentication flow * fix(McpService): restrict command handling to 'npx' for improved clarity * refactor: make callbackPort optional in OAuthProviderOptions and clean up MCPService * refactor: restructure OAuth handling by creating separate callback and provider classes, and remove unused utility functions --- src/main/services/MCPService.ts | 100 +++++++++++++++++--- src/main/services/mcp/oauth/callback.ts | 76 +++++++++++++++ src/main/services/mcp/oauth/provider.ts | 78 +++++++++++++++ src/main/services/mcp/oauth/storage.ts | 120 ++++++++++++++++++++++++ src/main/services/mcp/oauth/types.ts | 61 ++++++++++++ src/main/utils/file.ts | 4 + 6 files changed, 425 insertions(+), 14 deletions(-) create mode 100644 src/main/services/mcp/oauth/callback.ts create mode 100644 src/main/services/mcp/oauth/provider.ts create mode 100644 src/main/services/mcp/oauth/storage.ts create mode 100644 src/main/services/mcp/oauth/types.ts diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts index 1e4e44dd..c2314ccf 100644 --- a/src/main/services/MCPService.ts +++ b/src/main/services/MCPService.ts @@ -1,3 +1,4 @@ +import crypto from 'node:crypto' import fs from 'node:fs' import os from 'node:os' import path from 'node:path' @@ -22,9 +23,12 @@ import { } from '@types' import { app } from 'electron' import Logger from 'electron-log' +import { EventEmitter } from 'events' import { memoize } from 'lodash' import { CacheService } from './CacheService' +import { CallBackServer } from './mcp/oauth/callback' +import { McpOAuthClientProvider } from './mcp/oauth/provider' import { StreamableHTTPClientTransport, type StreamableHTTPClientTransportOptions } from './MCPStreamableHttpClient' // Generic type for caching wrapped functions @@ -117,9 +121,17 @@ class McpService { const args = [...(server.args || [])] - let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport + // let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport + const authProvider = new McpOAuthClientProvider({ + serverUrlHash: crypto + .createHash('md5') + .update(server.baseUrl || '') + .digest('hex') + }) - try { + const initTransport = async (): Promise< + StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport + > => { // Create appropriate transport based on configuration if (server.type === 'inMemory') { Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`) @@ -134,29 +146,31 @@ class McpService { throw new Error(`Failed to start in-memory server: ${error.message}`) } // set the client transport to the client - transport = clientTransport + return clientTransport } else if (server.baseUrl) { if (server.type === 'streamableHttp') { const options: StreamableHTTPClientTransportOptions = { requestInit: { headers: server.headers || {} - } + }, + authProvider } - transport = new StreamableHTTPClientTransport(new URL(server.baseUrl!), options) + return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options) } else if (server.type === 'sse') { const options: SSEClientTransportOptions = { requestInit: { headers: server.headers || {} - } + }, + authProvider } - transport = new SSEClientTransport(new URL(server.baseUrl!), options) + return new SSEClientTransport(new URL(server.baseUrl!), options) } else { throw new Error('Invalid server type') } } else if (server.command) { let cmd = server.command - if (server.command === 'npx' || server.command === 'bun' || server.command === 'bunx') { + if (server.command === 'npx') { cmd = await getBinaryPath('bun') Logger.info(`[MCP] Using command: ${cmd}`) @@ -196,7 +210,7 @@ class McpService { Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`) // Logger.info(`[MCP] Environment variables for server:`, server.env) - transport = new StdioClientTransport({ + const stdioTransport = new StdioClientTransport({ command: cmd, args, env: { @@ -206,14 +220,72 @@ class McpService { }, stderr: 'pipe' }) - transport.stderr?.on('data', (data) => + stdioTransport.stderr?.on('data', (data) => Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString()) ) + return stdioTransport } else { throw new Error('Either baseUrl or command must be provided') } + } - await client.connect(transport) + const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => { + Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`) + // Create an event emitter for the OAuth callback + const events = new EventEmitter() + + // Create a callback server + const callbackServer = new CallBackServer({ + port: authProvider.config.callbackPort, + path: authProvider.config.callbackPath || '/oauth/callback', + events + }) + + // Set a timeout to close the callback server + const timeoutId = setTimeout(() => { + Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`) + callbackServer.close() + }, 300000) // 5 minutes timeout + + try { + // Wait for the authorization code + const authCode = await callbackServer.waitForAuthCode() + Logger.info(`[MCP] Received auth code: ${authCode}`) + + // Complete the OAuth flow + await transport.finishAuth(authCode) + + Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`) + + const newTransport = await initTransport() + // Try to connect again + await client.connect(newTransport) + + Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`) + } catch (oauthError) { + Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError) + throw new Error( + `OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}` + ) + } finally { + // Clear the timeout and close the callback server + clearTimeout(timeoutId) + callbackServer.close() + } + } + + try { + const transport = await initTransport() + try { + await client.connect(transport) + } catch (error: Error | any) { + if (error instanceof Error && (error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))) { + Logger.info(`[MCP] Authentication required for server: ${server.name}`) + await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport) + } else { + throw error + } + } // Store the new client in the cache this.clients.set(serverKey, client) @@ -537,15 +609,15 @@ class McpService { }) let path = '' - child.stdout.on('data', (data) => { + child.stdout.on('data', (data: Buffer) => { path += data.toString() }) - child.stderr.on('data', (data) => { + child.stderr.on('data', (data: Buffer) => { console.error('Error getting PATH:', data.toString()) }) - child.on('close', (code) => { + child.on('close', (code: number) => { if (code === 0) { const trimmedPath = path.trim() resolve(trimmedPath) diff --git a/src/main/services/mcp/oauth/callback.ts b/src/main/services/mcp/oauth/callback.ts new file mode 100644 index 00000000..6884c530 --- /dev/null +++ b/src/main/services/mcp/oauth/callback.ts @@ -0,0 +1,76 @@ +import Logger from 'electron-log' +import EventEmitter from 'events' +import http from 'http' +import { URL } from 'url' + +import { OAuthCallbackServerOptions } from './types' + +export class CallBackServer { + private server: Promise + private events: EventEmitter + + constructor(options: OAuthCallbackServerOptions) { + const { port, path, events } = options + this.events = events + this.server = this.initialize(port, path) + } + + initialize(port: number, path: string): Promise { + const server = http.createServer((req, res) => { + // Only handle requests to the callback path + if (req.url?.startsWith(path)) { + try { + // Parse the URL to extract the authorization code + const url = new URL(req.url, `http://localhost:${port}`) + const code = url.searchParams.get('code') + if (code) { + // Emit the code event + this.events.emit('auth-code-received', code) + } + } catch (error) { + Logger.error('Error processing OAuth callback:', error) + res.writeHead(500, { 'Content-Type': 'text/plain' }) + res.end('Internal Server Error') + } + } else { + // Not a callback request + res.writeHead(404, { 'Content-Type': 'text/plain' }) + res.end('Not Found') + } + }) + + // Handle server errors + server.on('error', (error) => { + Logger.error('OAuth callback server error:', error) + }) + + const runningServer = new Promise((resolve, reject) => { + server.listen(port, () => { + Logger.info(`OAuth callback server listening on port ${port}`) + resolve(server) + }) + + server.on('error', (error) => { + reject(error) + }) + }) + return runningServer + } + + get getServer(): Promise { + return this.server + } + + async close() { + const server = await this.server + server.close() + } + + async waitForAuthCode(): Promise { + return new Promise((resolve) => { + this.events.once('auth-code-received', (code) => { + resolve(code) + }) + }) + } +} diff --git a/src/main/services/mcp/oauth/provider.ts b/src/main/services/mcp/oauth/provider.ts new file mode 100644 index 00000000..e56fada6 --- /dev/null +++ b/src/main/services/mcp/oauth/provider.ts @@ -0,0 +1,78 @@ +import path from 'node:path' + +import { getConfigDir } from '@main/utils/file' +import { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth' +import { OAuthClientInformation, OAuthClientInformationFull, OAuthTokens } from '@modelcontextprotocol/sdk/shared/auth' +import Logger from 'electron-log' +import open from 'open' + +import { JsonFileStorage } from './storage' +import { OAuthProviderOptions } from './types' + +export class McpOAuthClientProvider implements OAuthClientProvider { + private storage: JsonFileStorage + public readonly config: Required + + constructor(options: OAuthProviderOptions) { + const configDir = path.join(getConfigDir(), 'mcp', 'oauth') + this.config = { + serverUrlHash: options.serverUrlHash, + callbackPort: options.callbackPort || 12346, + callbackPath: options.callbackPath || '/oauth/callback', + configDir: options.configDir || configDir, + clientName: options.clientName || 'Cherry Studio', + clientUri: options.clientUri || 'https://github.com/CherryHQ/cherry-studio' + } + this.storage = new JsonFileStorage(this.config.serverUrlHash, this.config.configDir) + } + + get redirectUrl(): string { + return `http://localhost:${this.config.callbackPort}${this.config.callbackPath}` + } + + get clientMetadata() { + return { + redirect_uris: [this.redirectUrl], + token_endpoint_auth_method: 'none', + grant_types: ['authorization_code', 'refresh_token'], + response_types: ['code'], + client_name: this.config.clientName, + client_uri: this.config.clientUri + } + } + + async clientInformation(): Promise { + return this.storage.getClientInformation() + } + + async saveClientInformation(info: OAuthClientInformationFull): Promise { + await this.storage.saveClientInformation(info) + } + + async tokens(): Promise { + return this.storage.getTokens() + } + + async saveTokens(tokens: OAuthTokens): Promise { + await this.storage.saveTokens(tokens) + } + + async redirectToAuthorization(authorizationUrl: URL): Promise { + try { + // Open the browser to the authorization URL + await open(authorizationUrl.toString()) + Logger.info('Browser opened automatically.') + } catch (error) { + Logger.error('Could not open browser automatically.') + throw error // Let caller handle the error + } + } + + async saveCodeVerifier(codeVerifier: string): Promise { + await this.storage.saveCodeVerifier(codeVerifier) + } + + async codeVerifier(): Promise { + return this.storage.getCodeVerifier() + } +} diff --git a/src/main/services/mcp/oauth/storage.ts b/src/main/services/mcp/oauth/storage.ts new file mode 100644 index 00000000..349fcf8b --- /dev/null +++ b/src/main/services/mcp/oauth/storage.ts @@ -0,0 +1,120 @@ +import { + OAuthClientInformation, + OAuthClientInformationFull, + OAuthTokens +} from '@modelcontextprotocol/sdk/shared/auth.js' +import Logger from 'electron-log' +import fs from 'fs/promises' +import path from 'path' + +import { IOAuthStorage, OAuthStorageData, OAuthStorageSchema } from './types' + +export class JsonFileStorage implements IOAuthStorage { + private readonly filePath: string + private cache: OAuthStorageData | null = null + + constructor( + readonly serverUrlHash: string, + configDir: string + ) { + this.filePath = path.join(configDir, `${serverUrlHash}_oauth.json`) + } + + private async readStorage(): Promise { + if (this.cache) { + return this.cache + } + + try { + const data = await fs.readFile(this.filePath, 'utf-8') + const parsed = JSON.parse(data) + const validated = OAuthStorageSchema.parse(parsed) + this.cache = validated + return validated + } catch (error) { + if (error instanceof Error && 'code' in error && error.code === 'ENOENT') { + // File doesn't exist, return initial state + const initial: OAuthStorageData = { lastUpdated: Date.now() } + await this.writeStorage(initial) + return initial + } + Logger.error('Error reading OAuth storage:', error) + throw new Error(`Failed to read OAuth storage: ${error instanceof Error ? error.message : String(error)}`) + } + } + + private async writeStorage(data: OAuthStorageData): Promise { + try { + // Ensure directory exists + await fs.mkdir(path.dirname(this.filePath), { recursive: true }) + + // Update timestamp + data.lastUpdated = Date.now() + + // Write file atomically + const tempPath = `${this.filePath}.tmp` + await fs.writeFile(tempPath, JSON.stringify(data, null, 2)) + await fs.rename(tempPath, this.filePath) + + // Update cache + this.cache = data + } catch (error) { + Logger.error('Error writing OAuth storage:', error) + throw new Error(`Failed to write OAuth storage: ${error instanceof Error ? error.message : String(error)}`) + } + } + + async getClientInformation(): Promise { + const data = await this.readStorage() + return data.clientInfo + } + + async saveClientInformation(info: OAuthClientInformationFull): Promise { + const data = await this.readStorage() + await this.writeStorage({ + ...data, + clientInfo: info + }) + } + + async getTokens(): Promise { + const data = await this.readStorage() + return data.tokens + } + + async saveTokens(tokens: OAuthTokens): Promise { + const data = await this.readStorage() + await this.writeStorage({ + ...data, + tokens + }) + } + + async getCodeVerifier(): Promise { + const data = await this.readStorage() + if (!data.codeVerifier) { + throw new Error('No code verifier saved for session') + } + return data.codeVerifier + } + + async saveCodeVerifier(codeVerifier: string): Promise { + const data = await this.readStorage() + await this.writeStorage({ + ...data, + codeVerifier + }) + } + + async clear(): Promise { + try { + await fs.unlink(this.filePath) + this.cache = null + } catch (error) { + if (error instanceof Error && 'code' in error && error.code !== 'ENOENT') { + Logger.error('Error clearing OAuth storage:', error) + throw new Error(`Failed to clear OAuth storage: ${error instanceof Error ? error.message : String(error)}`) + } + } + } +} diff --git a/src/main/services/mcp/oauth/types.ts b/src/main/services/mcp/oauth/types.ts new file mode 100644 index 00000000..de631c16 --- /dev/null +++ b/src/main/services/mcp/oauth/types.ts @@ -0,0 +1,61 @@ +import { + OAuthClientInformation, + OAuthClientInformationFull, + OAuthTokens +} from '@modelcontextprotocol/sdk/shared/auth.js' +import EventEmitter from 'events' +import { z } from 'zod' + +export interface OAuthStorageData { + clientInfo?: OAuthClientInformation + tokens?: OAuthTokens + codeVerifier?: string + lastUpdated: number +} + +export const OAuthStorageSchema = z.object({ + clientInfo: z.any().optional(), + tokens: z.any().optional(), + codeVerifier: z.string().optional(), + lastUpdated: z.number() +}) + +export interface IOAuthStorage { + getClientInformation(): Promise + saveClientInformation(info: OAuthClientInformationFull): Promise + getTokens(): Promise + saveTokens(tokens: OAuthTokens): Promise + getCodeVerifier(): Promise + saveCodeVerifier(codeVerifier: string): Promise + clear(): Promise +} + +/** + * OAuth callback server setup options + */ +export interface OAuthCallbackServerOptions { + /** Port for the callback server */ + port: number + /** Path for the callback endpoint */ + path: string + /** Event emitter to signal when auth code is received */ + events: EventEmitter +} + +/** + * Options for creating an OAuth client provider + */ +export interface OAuthProviderOptions { + /** Server URL to connect to */ + serverUrlHash: string + /** Port for the OAuth callback server */ + callbackPort?: number + /** Path for the OAuth callback endpoint */ + callbackPath?: string + /** Directory to store OAuth credentials */ + configDir?: string + /** Client name to use for OAuth registration */ + clientName?: string + /** Client URI to use for OAuth registration */ + clientUri?: string +} diff --git a/src/main/utils/file.ts b/src/main/utils/file.ts index 636b5999..de091555 100644 --- a/src/main/utils/file.ts +++ b/src/main/utils/file.ts @@ -79,3 +79,7 @@ export function getFilesDir() { export function getConfigDir() { return path.join(os.homedir(), '.cherrystudio', 'config') } + +export function getAppConfigDir(name: string) { + return path.join(getConfigDir(), name) +}