Feat/mcp oauth (#4837)
* feat: implement OAuth client provider and lockfile management * feat: implement OAuth callback server and refactor authentication flow * fix(McpService): restrict command handling to 'npx' for improved clarity * refactor: make callbackPort optional in OAuthProviderOptions and clean up MCPService * refactor: restructure OAuth handling by creating separate callback and provider classes, and remove unused utility functions
This commit is contained in:
parent
ac0fe75078
commit
35c50b54a8
@ -1,3 +1,4 @@
|
||||
import crypto from 'node:crypto'
|
||||
import fs from 'node:fs'
|
||||
import os from 'node:os'
|
||||
import path from 'node:path'
|
||||
@ -22,9 +23,12 @@ import {
|
||||
} from '@types'
|
||||
import { app } from 'electron'
|
||||
import Logger from 'electron-log'
|
||||
import { EventEmitter } from 'events'
|
||||
import { memoize } from 'lodash'
|
||||
|
||||
import { CacheService } from './CacheService'
|
||||
import { CallBackServer } from './mcp/oauth/callback'
|
||||
import { McpOAuthClientProvider } from './mcp/oauth/provider'
|
||||
import { StreamableHTTPClientTransport, type StreamableHTTPClientTransportOptions } from './MCPStreamableHttpClient'
|
||||
|
||||
// Generic type for caching wrapped functions
|
||||
@ -117,9 +121,17 @@ class McpService {
|
||||
|
||||
const args = [...(server.args || [])]
|
||||
|
||||
let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
|
||||
// let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
|
||||
const authProvider = new McpOAuthClientProvider({
|
||||
serverUrlHash: crypto
|
||||
.createHash('md5')
|
||||
.update(server.baseUrl || '')
|
||||
.digest('hex')
|
||||
})
|
||||
|
||||
try {
|
||||
const initTransport = async (): Promise<
|
||||
StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
|
||||
> => {
|
||||
// Create appropriate transport based on configuration
|
||||
if (server.type === 'inMemory') {
|
||||
Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`)
|
||||
@ -134,29 +146,31 @@ class McpService {
|
||||
throw new Error(`Failed to start in-memory server: ${error.message}`)
|
||||
}
|
||||
// set the client transport to the client
|
||||
transport = clientTransport
|
||||
return clientTransport
|
||||
} else if (server.baseUrl) {
|
||||
if (server.type === 'streamableHttp') {
|
||||
const options: StreamableHTTPClientTransportOptions = {
|
||||
requestInit: {
|
||||
headers: server.headers || {}
|
||||
},
|
||||
authProvider
|
||||
}
|
||||
}
|
||||
transport = new StreamableHTTPClientTransport(new URL(server.baseUrl!), options)
|
||||
return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options)
|
||||
} else if (server.type === 'sse') {
|
||||
const options: SSEClientTransportOptions = {
|
||||
requestInit: {
|
||||
headers: server.headers || {}
|
||||
},
|
||||
authProvider
|
||||
}
|
||||
}
|
||||
transport = new SSEClientTransport(new URL(server.baseUrl!), options)
|
||||
return new SSEClientTransport(new URL(server.baseUrl!), options)
|
||||
} else {
|
||||
throw new Error('Invalid server type')
|
||||
}
|
||||
} else if (server.command) {
|
||||
let cmd = server.command
|
||||
|
||||
if (server.command === 'npx' || server.command === 'bun' || server.command === 'bunx') {
|
||||
if (server.command === 'npx') {
|
||||
cmd = await getBinaryPath('bun')
|
||||
Logger.info(`[MCP] Using command: ${cmd}`)
|
||||
|
||||
@ -196,7 +210,7 @@ class McpService {
|
||||
Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`)
|
||||
// Logger.info(`[MCP] Environment variables for server:`, server.env)
|
||||
|
||||
transport = new StdioClientTransport({
|
||||
const stdioTransport = new StdioClientTransport({
|
||||
command: cmd,
|
||||
args,
|
||||
env: {
|
||||
@ -206,14 +220,72 @@ class McpService {
|
||||
},
|
||||
stderr: 'pipe'
|
||||
})
|
||||
transport.stderr?.on('data', (data) =>
|
||||
stdioTransport.stderr?.on('data', (data) =>
|
||||
Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString())
|
||||
)
|
||||
return stdioTransport
|
||||
} else {
|
||||
throw new Error('Either baseUrl or command must be provided')
|
||||
}
|
||||
}
|
||||
|
||||
const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => {
|
||||
Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`)
|
||||
// Create an event emitter for the OAuth callback
|
||||
const events = new EventEmitter()
|
||||
|
||||
// Create a callback server
|
||||
const callbackServer = new CallBackServer({
|
||||
port: authProvider.config.callbackPort,
|
||||
path: authProvider.config.callbackPath || '/oauth/callback',
|
||||
events
|
||||
})
|
||||
|
||||
// Set a timeout to close the callback server
|
||||
const timeoutId = setTimeout(() => {
|
||||
Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`)
|
||||
callbackServer.close()
|
||||
}, 300000) // 5 minutes timeout
|
||||
|
||||
try {
|
||||
// Wait for the authorization code
|
||||
const authCode = await callbackServer.waitForAuthCode()
|
||||
Logger.info(`[MCP] Received auth code: ${authCode}`)
|
||||
|
||||
// Complete the OAuth flow
|
||||
await transport.finishAuth(authCode)
|
||||
|
||||
Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`)
|
||||
|
||||
const newTransport = await initTransport()
|
||||
// Try to connect again
|
||||
await client.connect(newTransport)
|
||||
|
||||
Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`)
|
||||
} catch (oauthError) {
|
||||
Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError)
|
||||
throw new Error(
|
||||
`OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}`
|
||||
)
|
||||
} finally {
|
||||
// Clear the timeout and close the callback server
|
||||
clearTimeout(timeoutId)
|
||||
callbackServer.close()
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
const transport = await initTransport()
|
||||
try {
|
||||
await client.connect(transport)
|
||||
} catch (error: Error | any) {
|
||||
if (error instanceof Error && (error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))) {
|
||||
Logger.info(`[MCP] Authentication required for server: ${server.name}`)
|
||||
await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport)
|
||||
} else {
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// Store the new client in the cache
|
||||
this.clients.set(serverKey, client)
|
||||
@ -537,15 +609,15 @@ class McpService {
|
||||
})
|
||||
|
||||
let path = ''
|
||||
child.stdout.on('data', (data) => {
|
||||
child.stdout.on('data', (data: Buffer) => {
|
||||
path += data.toString()
|
||||
})
|
||||
|
||||
child.stderr.on('data', (data) => {
|
||||
child.stderr.on('data', (data: Buffer) => {
|
||||
console.error('Error getting PATH:', data.toString())
|
||||
})
|
||||
|
||||
child.on('close', (code) => {
|
||||
child.on('close', (code: number) => {
|
||||
if (code === 0) {
|
||||
const trimmedPath = path.trim()
|
||||
resolve(trimmedPath)
|
||||
|
||||
76
src/main/services/mcp/oauth/callback.ts
Normal file
76
src/main/services/mcp/oauth/callback.ts
Normal file
@ -0,0 +1,76 @@
|
||||
import Logger from 'electron-log'
|
||||
import EventEmitter from 'events'
|
||||
import http from 'http'
|
||||
import { URL } from 'url'
|
||||
|
||||
import { OAuthCallbackServerOptions } from './types'
|
||||
|
||||
export class CallBackServer {
|
||||
private server: Promise<http.Server>
|
||||
private events: EventEmitter
|
||||
|
||||
constructor(options: OAuthCallbackServerOptions) {
|
||||
const { port, path, events } = options
|
||||
this.events = events
|
||||
this.server = this.initialize(port, path)
|
||||
}
|
||||
|
||||
initialize(port: number, path: string): Promise<http.Server> {
|
||||
const server = http.createServer((req, res) => {
|
||||
// Only handle requests to the callback path
|
||||
if (req.url?.startsWith(path)) {
|
||||
try {
|
||||
// Parse the URL to extract the authorization code
|
||||
const url = new URL(req.url, `http://localhost:${port}`)
|
||||
const code = url.searchParams.get('code')
|
||||
if (code) {
|
||||
// Emit the code event
|
||||
this.events.emit('auth-code-received', code)
|
||||
}
|
||||
} catch (error) {
|
||||
Logger.error('Error processing OAuth callback:', error)
|
||||
res.writeHead(500, { 'Content-Type': 'text/plain' })
|
||||
res.end('Internal Server Error')
|
||||
}
|
||||
} else {
|
||||
// Not a callback request
|
||||
res.writeHead(404, { 'Content-Type': 'text/plain' })
|
||||
res.end('Not Found')
|
||||
}
|
||||
})
|
||||
|
||||
// Handle server errors
|
||||
server.on('error', (error) => {
|
||||
Logger.error('OAuth callback server error:', error)
|
||||
})
|
||||
|
||||
const runningServer = new Promise<http.Server>((resolve, reject) => {
|
||||
server.listen(port, () => {
|
||||
Logger.info(`OAuth callback server listening on port ${port}`)
|
||||
resolve(server)
|
||||
})
|
||||
|
||||
server.on('error', (error) => {
|
||||
reject(error)
|
||||
})
|
||||
})
|
||||
return runningServer
|
||||
}
|
||||
|
||||
get getServer(): Promise<http.Server> {
|
||||
return this.server
|
||||
}
|
||||
|
||||
async close() {
|
||||
const server = await this.server
|
||||
server.close()
|
||||
}
|
||||
|
||||
async waitForAuthCode(): Promise<string> {
|
||||
return new Promise((resolve) => {
|
||||
this.events.once('auth-code-received', (code) => {
|
||||
resolve(code)
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
78
src/main/services/mcp/oauth/provider.ts
Normal file
78
src/main/services/mcp/oauth/provider.ts
Normal file
@ -0,0 +1,78 @@
|
||||
import path from 'node:path'
|
||||
|
||||
import { getConfigDir } from '@main/utils/file'
|
||||
import { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth'
|
||||
import { OAuthClientInformation, OAuthClientInformationFull, OAuthTokens } from '@modelcontextprotocol/sdk/shared/auth'
|
||||
import Logger from 'electron-log'
|
||||
import open from 'open'
|
||||
|
||||
import { JsonFileStorage } from './storage'
|
||||
import { OAuthProviderOptions } from './types'
|
||||
|
||||
export class McpOAuthClientProvider implements OAuthClientProvider {
|
||||
private storage: JsonFileStorage
|
||||
public readonly config: Required<OAuthProviderOptions>
|
||||
|
||||
constructor(options: OAuthProviderOptions) {
|
||||
const configDir = path.join(getConfigDir(), 'mcp', 'oauth')
|
||||
this.config = {
|
||||
serverUrlHash: options.serverUrlHash,
|
||||
callbackPort: options.callbackPort || 12346,
|
||||
callbackPath: options.callbackPath || '/oauth/callback',
|
||||
configDir: options.configDir || configDir,
|
||||
clientName: options.clientName || 'Cherry Studio',
|
||||
clientUri: options.clientUri || 'https://github.com/CherryHQ/cherry-studio'
|
||||
}
|
||||
this.storage = new JsonFileStorage(this.config.serverUrlHash, this.config.configDir)
|
||||
}
|
||||
|
||||
get redirectUrl(): string {
|
||||
return `http://localhost:${this.config.callbackPort}${this.config.callbackPath}`
|
||||
}
|
||||
|
||||
get clientMetadata() {
|
||||
return {
|
||||
redirect_uris: [this.redirectUrl],
|
||||
token_endpoint_auth_method: 'none',
|
||||
grant_types: ['authorization_code', 'refresh_token'],
|
||||
response_types: ['code'],
|
||||
client_name: this.config.clientName,
|
||||
client_uri: this.config.clientUri
|
||||
}
|
||||
}
|
||||
|
||||
async clientInformation(): Promise<OAuthClientInformation | undefined> {
|
||||
return this.storage.getClientInformation()
|
||||
}
|
||||
|
||||
async saveClientInformation(info: OAuthClientInformationFull): Promise<void> {
|
||||
await this.storage.saveClientInformation(info)
|
||||
}
|
||||
|
||||
async tokens(): Promise<OAuthTokens | undefined> {
|
||||
return this.storage.getTokens()
|
||||
}
|
||||
|
||||
async saveTokens(tokens: OAuthTokens): Promise<void> {
|
||||
await this.storage.saveTokens(tokens)
|
||||
}
|
||||
|
||||
async redirectToAuthorization(authorizationUrl: URL): Promise<void> {
|
||||
try {
|
||||
// Open the browser to the authorization URL
|
||||
await open(authorizationUrl.toString())
|
||||
Logger.info('Browser opened automatically.')
|
||||
} catch (error) {
|
||||
Logger.error('Could not open browser automatically.')
|
||||
throw error // Let caller handle the error
|
||||
}
|
||||
}
|
||||
|
||||
async saveCodeVerifier(codeVerifier: string): Promise<void> {
|
||||
await this.storage.saveCodeVerifier(codeVerifier)
|
||||
}
|
||||
|
||||
async codeVerifier(): Promise<string> {
|
||||
return this.storage.getCodeVerifier()
|
||||
}
|
||||
}
|
||||
120
src/main/services/mcp/oauth/storage.ts
Normal file
120
src/main/services/mcp/oauth/storage.ts
Normal file
@ -0,0 +1,120 @@
|
||||
import {
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthTokens
|
||||
} from '@modelcontextprotocol/sdk/shared/auth.js'
|
||||
import Logger from 'electron-log'
|
||||
import fs from 'fs/promises'
|
||||
import path from 'path'
|
||||
|
||||
import { IOAuthStorage, OAuthStorageData, OAuthStorageSchema } from './types'
|
||||
|
||||
export class JsonFileStorage implements IOAuthStorage {
|
||||
private readonly filePath: string
|
||||
private cache: OAuthStorageData | null = null
|
||||
|
||||
constructor(
|
||||
readonly serverUrlHash: string,
|
||||
configDir: string
|
||||
) {
|
||||
this.filePath = path.join(configDir, `${serverUrlHash}_oauth.json`)
|
||||
}
|
||||
|
||||
private async readStorage(): Promise<OAuthStorageData> {
|
||||
if (this.cache) {
|
||||
return this.cache
|
||||
}
|
||||
|
||||
try {
|
||||
const data = await fs.readFile(this.filePath, 'utf-8')
|
||||
const parsed = JSON.parse(data)
|
||||
const validated = OAuthStorageSchema.parse(parsed)
|
||||
this.cache = validated
|
||||
return validated
|
||||
} catch (error) {
|
||||
if (error instanceof Error && 'code' in error && error.code === 'ENOENT') {
|
||||
// File doesn't exist, return initial state
|
||||
const initial: OAuthStorageData = { lastUpdated: Date.now() }
|
||||
await this.writeStorage(initial)
|
||||
return initial
|
||||
}
|
||||
Logger.error('Error reading OAuth storage:', error)
|
||||
throw new Error(`Failed to read OAuth storage: ${error instanceof Error ? error.message : String(error)}`)
|
||||
}
|
||||
}
|
||||
|
||||
private async writeStorage(data: OAuthStorageData): Promise<void> {
|
||||
try {
|
||||
// Ensure directory exists
|
||||
await fs.mkdir(path.dirname(this.filePath), { recursive: true })
|
||||
|
||||
// Update timestamp
|
||||
data.lastUpdated = Date.now()
|
||||
|
||||
// Write file atomically
|
||||
const tempPath = `${this.filePath}.tmp`
|
||||
await fs.writeFile(tempPath, JSON.stringify(data, null, 2))
|
||||
await fs.rename(tempPath, this.filePath)
|
||||
|
||||
// Update cache
|
||||
this.cache = data
|
||||
} catch (error) {
|
||||
Logger.error('Error writing OAuth storage:', error)
|
||||
throw new Error(`Failed to write OAuth storage: ${error instanceof Error ? error.message : String(error)}`)
|
||||
}
|
||||
}
|
||||
|
||||
async getClientInformation(): Promise<OAuthClientInformation | undefined> {
|
||||
const data = await this.readStorage()
|
||||
return data.clientInfo
|
||||
}
|
||||
|
||||
async saveClientInformation(info: OAuthClientInformationFull): Promise<void> {
|
||||
const data = await this.readStorage()
|
||||
await this.writeStorage({
|
||||
...data,
|
||||
clientInfo: info
|
||||
})
|
||||
}
|
||||
|
||||
async getTokens(): Promise<OAuthTokens | undefined> {
|
||||
const data = await this.readStorage()
|
||||
return data.tokens
|
||||
}
|
||||
|
||||
async saveTokens(tokens: OAuthTokens): Promise<void> {
|
||||
const data = await this.readStorage()
|
||||
await this.writeStorage({
|
||||
...data,
|
||||
tokens
|
||||
})
|
||||
}
|
||||
|
||||
async getCodeVerifier(): Promise<string> {
|
||||
const data = await this.readStorage()
|
||||
if (!data.codeVerifier) {
|
||||
throw new Error('No code verifier saved for session')
|
||||
}
|
||||
return data.codeVerifier
|
||||
}
|
||||
|
||||
async saveCodeVerifier(codeVerifier: string): Promise<void> {
|
||||
const data = await this.readStorage()
|
||||
await this.writeStorage({
|
||||
...data,
|
||||
codeVerifier
|
||||
})
|
||||
}
|
||||
|
||||
async clear(): Promise<void> {
|
||||
try {
|
||||
await fs.unlink(this.filePath)
|
||||
this.cache = null
|
||||
} catch (error) {
|
||||
if (error instanceof Error && 'code' in error && error.code !== 'ENOENT') {
|
||||
Logger.error('Error clearing OAuth storage:', error)
|
||||
throw new Error(`Failed to clear OAuth storage: ${error instanceof Error ? error.message : String(error)}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
61
src/main/services/mcp/oauth/types.ts
Normal file
61
src/main/services/mcp/oauth/types.ts
Normal file
@ -0,0 +1,61 @@
|
||||
import {
|
||||
OAuthClientInformation,
|
||||
OAuthClientInformationFull,
|
||||
OAuthTokens
|
||||
} from '@modelcontextprotocol/sdk/shared/auth.js'
|
||||
import EventEmitter from 'events'
|
||||
import { z } from 'zod'
|
||||
|
||||
export interface OAuthStorageData {
|
||||
clientInfo?: OAuthClientInformation
|
||||
tokens?: OAuthTokens
|
||||
codeVerifier?: string
|
||||
lastUpdated: number
|
||||
}
|
||||
|
||||
export const OAuthStorageSchema = z.object({
|
||||
clientInfo: z.any().optional(),
|
||||
tokens: z.any().optional(),
|
||||
codeVerifier: z.string().optional(),
|
||||
lastUpdated: z.number()
|
||||
})
|
||||
|
||||
export interface IOAuthStorage {
|
||||
getClientInformation(): Promise<OAuthClientInformation | undefined>
|
||||
saveClientInformation(info: OAuthClientInformationFull): Promise<void>
|
||||
getTokens(): Promise<OAuthTokens | undefined>
|
||||
saveTokens(tokens: OAuthTokens): Promise<void>
|
||||
getCodeVerifier(): Promise<string>
|
||||
saveCodeVerifier(codeVerifier: string): Promise<void>
|
||||
clear(): Promise<void>
|
||||
}
|
||||
|
||||
/**
|
||||
* OAuth callback server setup options
|
||||
*/
|
||||
export interface OAuthCallbackServerOptions {
|
||||
/** Port for the callback server */
|
||||
port: number
|
||||
/** Path for the callback endpoint */
|
||||
path: string
|
||||
/** Event emitter to signal when auth code is received */
|
||||
events: EventEmitter
|
||||
}
|
||||
|
||||
/**
|
||||
* Options for creating an OAuth client provider
|
||||
*/
|
||||
export interface OAuthProviderOptions {
|
||||
/** Server URL to connect to */
|
||||
serverUrlHash: string
|
||||
/** Port for the OAuth callback server */
|
||||
callbackPort?: number
|
||||
/** Path for the OAuth callback endpoint */
|
||||
callbackPath?: string
|
||||
/** Directory to store OAuth credentials */
|
||||
configDir?: string
|
||||
/** Client name to use for OAuth registration */
|
||||
clientName?: string
|
||||
/** Client URI to use for OAuth registration */
|
||||
clientUri?: string
|
||||
}
|
||||
@ -79,3 +79,7 @@ export function getFilesDir() {
|
||||
export function getConfigDir() {
|
||||
return path.join(os.homedir(), '.cherrystudio', 'config')
|
||||
}
|
||||
|
||||
export function getAppConfigDir(name: string) {
|
||||
return path.join(getConfigDir(), name)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user