Feat/mcp oauth (#4837)

* feat: implement OAuth client provider and lockfile management

* feat: implement OAuth callback server and refactor authentication flow

* fix(McpService): restrict command handling to 'npx' for improved clarity

* refactor: make callbackPort optional in OAuthProviderOptions and clean up MCPService

* refactor: restructure OAuth handling by creating separate callback and provider classes, and remove unused utility functions
This commit is contained in:
LiuVaayne 2025-04-16 22:07:32 +08:00 committed by GitHub
parent ac0fe75078
commit 35c50b54a8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 425 additions and 14 deletions

View File

@ -1,3 +1,4 @@
import crypto from 'node:crypto'
import fs from 'node:fs' import fs from 'node:fs'
import os from 'node:os' import os from 'node:os'
import path from 'node:path' import path from 'node:path'
@ -22,9 +23,12 @@ import {
} from '@types' } from '@types'
import { app } from 'electron' import { app } from 'electron'
import Logger from 'electron-log' import Logger from 'electron-log'
import { EventEmitter } from 'events'
import { memoize } from 'lodash' import { memoize } from 'lodash'
import { CacheService } from './CacheService' import { CacheService } from './CacheService'
import { CallBackServer } from './mcp/oauth/callback'
import { McpOAuthClientProvider } from './mcp/oauth/provider'
import { StreamableHTTPClientTransport, type StreamableHTTPClientTransportOptions } from './MCPStreamableHttpClient' import { StreamableHTTPClientTransport, type StreamableHTTPClientTransportOptions } from './MCPStreamableHttpClient'
// Generic type for caching wrapped functions // Generic type for caching wrapped functions
@ -117,9 +121,17 @@ class McpService {
const args = [...(server.args || [])] const args = [...(server.args || [])]
let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport // let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
const authProvider = new McpOAuthClientProvider({
serverUrlHash: crypto
.createHash('md5')
.update(server.baseUrl || '')
.digest('hex')
})
try { const initTransport = async (): Promise<
StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
> => {
// Create appropriate transport based on configuration // Create appropriate transport based on configuration
if (server.type === 'inMemory') { if (server.type === 'inMemory') {
Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`) Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`)
@ -134,29 +146,31 @@ class McpService {
throw new Error(`Failed to start in-memory server: ${error.message}`) throw new Error(`Failed to start in-memory server: ${error.message}`)
} }
// set the client transport to the client // set the client transport to the client
transport = clientTransport return clientTransport
} else if (server.baseUrl) { } else if (server.baseUrl) {
if (server.type === 'streamableHttp') { if (server.type === 'streamableHttp') {
const options: StreamableHTTPClientTransportOptions = { const options: StreamableHTTPClientTransportOptions = {
requestInit: { requestInit: {
headers: server.headers || {} headers: server.headers || {}
} },
authProvider
} }
transport = new StreamableHTTPClientTransport(new URL(server.baseUrl!), options) return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options)
} else if (server.type === 'sse') { } else if (server.type === 'sse') {
const options: SSEClientTransportOptions = { const options: SSEClientTransportOptions = {
requestInit: { requestInit: {
headers: server.headers || {} headers: server.headers || {}
} },
authProvider
} }
transport = new SSEClientTransport(new URL(server.baseUrl!), options) return new SSEClientTransport(new URL(server.baseUrl!), options)
} else { } else {
throw new Error('Invalid server type') throw new Error('Invalid server type')
} }
} else if (server.command) { } else if (server.command) {
let cmd = server.command let cmd = server.command
if (server.command === 'npx' || server.command === 'bun' || server.command === 'bunx') { if (server.command === 'npx') {
cmd = await getBinaryPath('bun') cmd = await getBinaryPath('bun')
Logger.info(`[MCP] Using command: ${cmd}`) Logger.info(`[MCP] Using command: ${cmd}`)
@ -196,7 +210,7 @@ class McpService {
Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`) Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`)
// Logger.info(`[MCP] Environment variables for server:`, server.env) // Logger.info(`[MCP] Environment variables for server:`, server.env)
transport = new StdioClientTransport({ const stdioTransport = new StdioClientTransport({
command: cmd, command: cmd,
args, args,
env: { env: {
@ -206,14 +220,72 @@ class McpService {
}, },
stderr: 'pipe' stderr: 'pipe'
}) })
transport.stderr?.on('data', (data) => stdioTransport.stderr?.on('data', (data) =>
Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString()) Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString())
) )
return stdioTransport
} else { } else {
throw new Error('Either baseUrl or command must be provided') throw new Error('Either baseUrl or command must be provided')
} }
}
await client.connect(transport) const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => {
Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`)
// Create an event emitter for the OAuth callback
const events = new EventEmitter()
// Create a callback server
const callbackServer = new CallBackServer({
port: authProvider.config.callbackPort,
path: authProvider.config.callbackPath || '/oauth/callback',
events
})
// Set a timeout to close the callback server
const timeoutId = setTimeout(() => {
Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`)
callbackServer.close()
}, 300000) // 5 minutes timeout
try {
// Wait for the authorization code
const authCode = await callbackServer.waitForAuthCode()
Logger.info(`[MCP] Received auth code: ${authCode}`)
// Complete the OAuth flow
await transport.finishAuth(authCode)
Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`)
const newTransport = await initTransport()
// Try to connect again
await client.connect(newTransport)
Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`)
} catch (oauthError) {
Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError)
throw new Error(
`OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}`
)
} finally {
// Clear the timeout and close the callback server
clearTimeout(timeoutId)
callbackServer.close()
}
}
try {
const transport = await initTransport()
try {
await client.connect(transport)
} catch (error: Error | any) {
if (error instanceof Error && (error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))) {
Logger.info(`[MCP] Authentication required for server: ${server.name}`)
await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport)
} else {
throw error
}
}
// Store the new client in the cache // Store the new client in the cache
this.clients.set(serverKey, client) this.clients.set(serverKey, client)
@ -537,15 +609,15 @@ class McpService {
}) })
let path = '' let path = ''
child.stdout.on('data', (data) => { child.stdout.on('data', (data: Buffer) => {
path += data.toString() path += data.toString()
}) })
child.stderr.on('data', (data) => { child.stderr.on('data', (data: Buffer) => {
console.error('Error getting PATH:', data.toString()) console.error('Error getting PATH:', data.toString())
}) })
child.on('close', (code) => { child.on('close', (code: number) => {
if (code === 0) { if (code === 0) {
const trimmedPath = path.trim() const trimmedPath = path.trim()
resolve(trimmedPath) resolve(trimmedPath)

View File

@ -0,0 +1,76 @@
import Logger from 'electron-log'
import EventEmitter from 'events'
import http from 'http'
import { URL } from 'url'
import { OAuthCallbackServerOptions } from './types'
export class CallBackServer {
private server: Promise<http.Server>
private events: EventEmitter
constructor(options: OAuthCallbackServerOptions) {
const { port, path, events } = options
this.events = events
this.server = this.initialize(port, path)
}
initialize(port: number, path: string): Promise<http.Server> {
const server = http.createServer((req, res) => {
// Only handle requests to the callback path
if (req.url?.startsWith(path)) {
try {
// Parse the URL to extract the authorization code
const url = new URL(req.url, `http://localhost:${port}`)
const code = url.searchParams.get('code')
if (code) {
// Emit the code event
this.events.emit('auth-code-received', code)
}
} catch (error) {
Logger.error('Error processing OAuth callback:', error)
res.writeHead(500, { 'Content-Type': 'text/plain' })
res.end('Internal Server Error')
}
} else {
// Not a callback request
res.writeHead(404, { 'Content-Type': 'text/plain' })
res.end('Not Found')
}
})
// Handle server errors
server.on('error', (error) => {
Logger.error('OAuth callback server error:', error)
})
const runningServer = new Promise<http.Server>((resolve, reject) => {
server.listen(port, () => {
Logger.info(`OAuth callback server listening on port ${port}`)
resolve(server)
})
server.on('error', (error) => {
reject(error)
})
})
return runningServer
}
get getServer(): Promise<http.Server> {
return this.server
}
async close() {
const server = await this.server
server.close()
}
async waitForAuthCode(): Promise<string> {
return new Promise((resolve) => {
this.events.once('auth-code-received', (code) => {
resolve(code)
})
})
}
}

View File

@ -0,0 +1,78 @@
import path from 'node:path'
import { getConfigDir } from '@main/utils/file'
import { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth'
import { OAuthClientInformation, OAuthClientInformationFull, OAuthTokens } from '@modelcontextprotocol/sdk/shared/auth'
import Logger from 'electron-log'
import open from 'open'
import { JsonFileStorage } from './storage'
import { OAuthProviderOptions } from './types'
export class McpOAuthClientProvider implements OAuthClientProvider {
private storage: JsonFileStorage
public readonly config: Required<OAuthProviderOptions>
constructor(options: OAuthProviderOptions) {
const configDir = path.join(getConfigDir(), 'mcp', 'oauth')
this.config = {
serverUrlHash: options.serverUrlHash,
callbackPort: options.callbackPort || 12346,
callbackPath: options.callbackPath || '/oauth/callback',
configDir: options.configDir || configDir,
clientName: options.clientName || 'Cherry Studio',
clientUri: options.clientUri || 'https://github.com/CherryHQ/cherry-studio'
}
this.storage = new JsonFileStorage(this.config.serverUrlHash, this.config.configDir)
}
get redirectUrl(): string {
return `http://localhost:${this.config.callbackPort}${this.config.callbackPath}`
}
get clientMetadata() {
return {
redirect_uris: [this.redirectUrl],
token_endpoint_auth_method: 'none',
grant_types: ['authorization_code', 'refresh_token'],
response_types: ['code'],
client_name: this.config.clientName,
client_uri: this.config.clientUri
}
}
async clientInformation(): Promise<OAuthClientInformation | undefined> {
return this.storage.getClientInformation()
}
async saveClientInformation(info: OAuthClientInformationFull): Promise<void> {
await this.storage.saveClientInformation(info)
}
async tokens(): Promise<OAuthTokens | undefined> {
return this.storage.getTokens()
}
async saveTokens(tokens: OAuthTokens): Promise<void> {
await this.storage.saveTokens(tokens)
}
async redirectToAuthorization(authorizationUrl: URL): Promise<void> {
try {
// Open the browser to the authorization URL
await open(authorizationUrl.toString())
Logger.info('Browser opened automatically.')
} catch (error) {
Logger.error('Could not open browser automatically.')
throw error // Let caller handle the error
}
}
async saveCodeVerifier(codeVerifier: string): Promise<void> {
await this.storage.saveCodeVerifier(codeVerifier)
}
async codeVerifier(): Promise<string> {
return this.storage.getCodeVerifier()
}
}

View File

@ -0,0 +1,120 @@
import {
OAuthClientInformation,
OAuthClientInformationFull,
OAuthTokens
} from '@modelcontextprotocol/sdk/shared/auth.js'
import Logger from 'electron-log'
import fs from 'fs/promises'
import path from 'path'
import { IOAuthStorage, OAuthStorageData, OAuthStorageSchema } from './types'
export class JsonFileStorage implements IOAuthStorage {
private readonly filePath: string
private cache: OAuthStorageData | null = null
constructor(
readonly serverUrlHash: string,
configDir: string
) {
this.filePath = path.join(configDir, `${serverUrlHash}_oauth.json`)
}
private async readStorage(): Promise<OAuthStorageData> {
if (this.cache) {
return this.cache
}
try {
const data = await fs.readFile(this.filePath, 'utf-8')
const parsed = JSON.parse(data)
const validated = OAuthStorageSchema.parse(parsed)
this.cache = validated
return validated
} catch (error) {
if (error instanceof Error && 'code' in error && error.code === 'ENOENT') {
// File doesn't exist, return initial state
const initial: OAuthStorageData = { lastUpdated: Date.now() }
await this.writeStorage(initial)
return initial
}
Logger.error('Error reading OAuth storage:', error)
throw new Error(`Failed to read OAuth storage: ${error instanceof Error ? error.message : String(error)}`)
}
}
private async writeStorage(data: OAuthStorageData): Promise<void> {
try {
// Ensure directory exists
await fs.mkdir(path.dirname(this.filePath), { recursive: true })
// Update timestamp
data.lastUpdated = Date.now()
// Write file atomically
const tempPath = `${this.filePath}.tmp`
await fs.writeFile(tempPath, JSON.stringify(data, null, 2))
await fs.rename(tempPath, this.filePath)
// Update cache
this.cache = data
} catch (error) {
Logger.error('Error writing OAuth storage:', error)
throw new Error(`Failed to write OAuth storage: ${error instanceof Error ? error.message : String(error)}`)
}
}
async getClientInformation(): Promise<OAuthClientInformation | undefined> {
const data = await this.readStorage()
return data.clientInfo
}
async saveClientInformation(info: OAuthClientInformationFull): Promise<void> {
const data = await this.readStorage()
await this.writeStorage({
...data,
clientInfo: info
})
}
async getTokens(): Promise<OAuthTokens | undefined> {
const data = await this.readStorage()
return data.tokens
}
async saveTokens(tokens: OAuthTokens): Promise<void> {
const data = await this.readStorage()
await this.writeStorage({
...data,
tokens
})
}
async getCodeVerifier(): Promise<string> {
const data = await this.readStorage()
if (!data.codeVerifier) {
throw new Error('No code verifier saved for session')
}
return data.codeVerifier
}
async saveCodeVerifier(codeVerifier: string): Promise<void> {
const data = await this.readStorage()
await this.writeStorage({
...data,
codeVerifier
})
}
async clear(): Promise<void> {
try {
await fs.unlink(this.filePath)
this.cache = null
} catch (error) {
if (error instanceof Error && 'code' in error && error.code !== 'ENOENT') {
Logger.error('Error clearing OAuth storage:', error)
throw new Error(`Failed to clear OAuth storage: ${error instanceof Error ? error.message : String(error)}`)
}
}
}
}

View File

@ -0,0 +1,61 @@
import {
OAuthClientInformation,
OAuthClientInformationFull,
OAuthTokens
} from '@modelcontextprotocol/sdk/shared/auth.js'
import EventEmitter from 'events'
import { z } from 'zod'
export interface OAuthStorageData {
clientInfo?: OAuthClientInformation
tokens?: OAuthTokens
codeVerifier?: string
lastUpdated: number
}
export const OAuthStorageSchema = z.object({
clientInfo: z.any().optional(),
tokens: z.any().optional(),
codeVerifier: z.string().optional(),
lastUpdated: z.number()
})
export interface IOAuthStorage {
getClientInformation(): Promise<OAuthClientInformation | undefined>
saveClientInformation(info: OAuthClientInformationFull): Promise<void>
getTokens(): Promise<OAuthTokens | undefined>
saveTokens(tokens: OAuthTokens): Promise<void>
getCodeVerifier(): Promise<string>
saveCodeVerifier(codeVerifier: string): Promise<void>
clear(): Promise<void>
}
/**
* OAuth callback server setup options
*/
export interface OAuthCallbackServerOptions {
/** Port for the callback server */
port: number
/** Path for the callback endpoint */
path: string
/** Event emitter to signal when auth code is received */
events: EventEmitter
}
/**
* Options for creating an OAuth client provider
*/
export interface OAuthProviderOptions {
/** Server URL to connect to */
serverUrlHash: string
/** Port for the OAuth callback server */
callbackPort?: number
/** Path for the OAuth callback endpoint */
callbackPath?: string
/** Directory to store OAuth credentials */
configDir?: string
/** Client name to use for OAuth registration */
clientName?: string
/** Client URI to use for OAuth registration */
clientUri?: string
}

View File

@ -79,3 +79,7 @@ export function getFilesDir() {
export function getConfigDir() { export function getConfigDir() {
return path.join(os.homedir(), '.cherrystudio', 'config') return path.join(os.homedir(), '.cherrystudio', 'config')
} }
export function getAppConfigDir(name: string) {
return path.join(getConfigDir(), name)
}