Feat/mcp oauth (#4837)
* feat: implement OAuth client provider and lockfile management * feat: implement OAuth callback server and refactor authentication flow * fix(McpService): restrict command handling to 'npx' for improved clarity * refactor: make callbackPort optional in OAuthProviderOptions and clean up MCPService * refactor: restructure OAuth handling by creating separate callback and provider classes, and remove unused utility functions
This commit is contained in:
parent
ac0fe75078
commit
35c50b54a8
@ -1,3 +1,4 @@
|
|||||||
|
import crypto from 'node:crypto'
|
||||||
import fs from 'node:fs'
|
import fs from 'node:fs'
|
||||||
import os from 'node:os'
|
import os from 'node:os'
|
||||||
import path from 'node:path'
|
import path from 'node:path'
|
||||||
@ -22,9 +23,12 @@ import {
|
|||||||
} from '@types'
|
} from '@types'
|
||||||
import { app } from 'electron'
|
import { app } from 'electron'
|
||||||
import Logger from 'electron-log'
|
import Logger from 'electron-log'
|
||||||
|
import { EventEmitter } from 'events'
|
||||||
import { memoize } from 'lodash'
|
import { memoize } from 'lodash'
|
||||||
|
|
||||||
import { CacheService } from './CacheService'
|
import { CacheService } from './CacheService'
|
||||||
|
import { CallBackServer } from './mcp/oauth/callback'
|
||||||
|
import { McpOAuthClientProvider } from './mcp/oauth/provider'
|
||||||
import { StreamableHTTPClientTransport, type StreamableHTTPClientTransportOptions } from './MCPStreamableHttpClient'
|
import { StreamableHTTPClientTransport, type StreamableHTTPClientTransportOptions } from './MCPStreamableHttpClient'
|
||||||
|
|
||||||
// Generic type for caching wrapped functions
|
// Generic type for caching wrapped functions
|
||||||
@ -117,9 +121,17 @@ class McpService {
|
|||||||
|
|
||||||
const args = [...(server.args || [])]
|
const args = [...(server.args || [])]
|
||||||
|
|
||||||
let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
|
// let transport: StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
|
||||||
|
const authProvider = new McpOAuthClientProvider({
|
||||||
|
serverUrlHash: crypto
|
||||||
|
.createHash('md5')
|
||||||
|
.update(server.baseUrl || '')
|
||||||
|
.digest('hex')
|
||||||
|
})
|
||||||
|
|
||||||
try {
|
const initTransport = async (): Promise<
|
||||||
|
StdioClientTransport | SSEClientTransport | InMemoryTransport | StreamableHTTPClientTransport
|
||||||
|
> => {
|
||||||
// Create appropriate transport based on configuration
|
// Create appropriate transport based on configuration
|
||||||
if (server.type === 'inMemory') {
|
if (server.type === 'inMemory') {
|
||||||
Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`)
|
Logger.info(`[MCP] Using in-memory transport for server: ${server.name}`)
|
||||||
@ -134,29 +146,31 @@ class McpService {
|
|||||||
throw new Error(`Failed to start in-memory server: ${error.message}`)
|
throw new Error(`Failed to start in-memory server: ${error.message}`)
|
||||||
}
|
}
|
||||||
// set the client transport to the client
|
// set the client transport to the client
|
||||||
transport = clientTransport
|
return clientTransport
|
||||||
} else if (server.baseUrl) {
|
} else if (server.baseUrl) {
|
||||||
if (server.type === 'streamableHttp') {
|
if (server.type === 'streamableHttp') {
|
||||||
const options: StreamableHTTPClientTransportOptions = {
|
const options: StreamableHTTPClientTransportOptions = {
|
||||||
requestInit: {
|
requestInit: {
|
||||||
headers: server.headers || {}
|
headers: server.headers || {}
|
||||||
}
|
},
|
||||||
|
authProvider
|
||||||
}
|
}
|
||||||
transport = new StreamableHTTPClientTransport(new URL(server.baseUrl!), options)
|
return new StreamableHTTPClientTransport(new URL(server.baseUrl!), options)
|
||||||
} else if (server.type === 'sse') {
|
} else if (server.type === 'sse') {
|
||||||
const options: SSEClientTransportOptions = {
|
const options: SSEClientTransportOptions = {
|
||||||
requestInit: {
|
requestInit: {
|
||||||
headers: server.headers || {}
|
headers: server.headers || {}
|
||||||
}
|
},
|
||||||
|
authProvider
|
||||||
}
|
}
|
||||||
transport = new SSEClientTransport(new URL(server.baseUrl!), options)
|
return new SSEClientTransport(new URL(server.baseUrl!), options)
|
||||||
} else {
|
} else {
|
||||||
throw new Error('Invalid server type')
|
throw new Error('Invalid server type')
|
||||||
}
|
}
|
||||||
} else if (server.command) {
|
} else if (server.command) {
|
||||||
let cmd = server.command
|
let cmd = server.command
|
||||||
|
|
||||||
if (server.command === 'npx' || server.command === 'bun' || server.command === 'bunx') {
|
if (server.command === 'npx') {
|
||||||
cmd = await getBinaryPath('bun')
|
cmd = await getBinaryPath('bun')
|
||||||
Logger.info(`[MCP] Using command: ${cmd}`)
|
Logger.info(`[MCP] Using command: ${cmd}`)
|
||||||
|
|
||||||
@ -196,7 +210,7 @@ class McpService {
|
|||||||
Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`)
|
Logger.info(`[MCP] Starting server with command: ${cmd} ${args ? args.join(' ') : ''}`)
|
||||||
// Logger.info(`[MCP] Environment variables for server:`, server.env)
|
// Logger.info(`[MCP] Environment variables for server:`, server.env)
|
||||||
|
|
||||||
transport = new StdioClientTransport({
|
const stdioTransport = new StdioClientTransport({
|
||||||
command: cmd,
|
command: cmd,
|
||||||
args,
|
args,
|
||||||
env: {
|
env: {
|
||||||
@ -206,14 +220,72 @@ class McpService {
|
|||||||
},
|
},
|
||||||
stderr: 'pipe'
|
stderr: 'pipe'
|
||||||
})
|
})
|
||||||
transport.stderr?.on('data', (data) =>
|
stdioTransport.stderr?.on('data', (data) =>
|
||||||
Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString())
|
Logger.info(`[MCP] Stdio stderr for server: ${server.name} `, data.toString())
|
||||||
)
|
)
|
||||||
|
return stdioTransport
|
||||||
} else {
|
} else {
|
||||||
throw new Error('Either baseUrl or command must be provided')
|
throw new Error('Either baseUrl or command must be provided')
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
await client.connect(transport)
|
const handleAuth = async (client: Client, transport: SSEClientTransport | StreamableHTTPClientTransport) => {
|
||||||
|
Logger.info(`[MCP] Starting OAuth flow for server: ${server.name}`)
|
||||||
|
// Create an event emitter for the OAuth callback
|
||||||
|
const events = new EventEmitter()
|
||||||
|
|
||||||
|
// Create a callback server
|
||||||
|
const callbackServer = new CallBackServer({
|
||||||
|
port: authProvider.config.callbackPort,
|
||||||
|
path: authProvider.config.callbackPath || '/oauth/callback',
|
||||||
|
events
|
||||||
|
})
|
||||||
|
|
||||||
|
// Set a timeout to close the callback server
|
||||||
|
const timeoutId = setTimeout(() => {
|
||||||
|
Logger.warn(`[MCP] OAuth flow timed out for server: ${server.name}`)
|
||||||
|
callbackServer.close()
|
||||||
|
}, 300000) // 5 minutes timeout
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Wait for the authorization code
|
||||||
|
const authCode = await callbackServer.waitForAuthCode()
|
||||||
|
Logger.info(`[MCP] Received auth code: ${authCode}`)
|
||||||
|
|
||||||
|
// Complete the OAuth flow
|
||||||
|
await transport.finishAuth(authCode)
|
||||||
|
|
||||||
|
Logger.info(`[MCP] OAuth flow completed for server: ${server.name}`)
|
||||||
|
|
||||||
|
const newTransport = await initTransport()
|
||||||
|
// Try to connect again
|
||||||
|
await client.connect(newTransport)
|
||||||
|
|
||||||
|
Logger.info(`[MCP] Successfully authenticated with server: ${server.name}`)
|
||||||
|
} catch (oauthError) {
|
||||||
|
Logger.error(`[MCP] OAuth authentication failed for server ${server.name}:`, oauthError)
|
||||||
|
throw new Error(
|
||||||
|
`OAuth authentication failed: ${oauthError instanceof Error ? oauthError.message : String(oauthError)}`
|
||||||
|
)
|
||||||
|
} finally {
|
||||||
|
// Clear the timeout and close the callback server
|
||||||
|
clearTimeout(timeoutId)
|
||||||
|
callbackServer.close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const transport = await initTransport()
|
||||||
|
try {
|
||||||
|
await client.connect(transport)
|
||||||
|
} catch (error: Error | any) {
|
||||||
|
if (error instanceof Error && (error.name === 'UnauthorizedError' || error.message.includes('Unauthorized'))) {
|
||||||
|
Logger.info(`[MCP] Authentication required for server: ${server.name}`)
|
||||||
|
await handleAuth(client, transport as SSEClientTransport | StreamableHTTPClientTransport)
|
||||||
|
} else {
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Store the new client in the cache
|
// Store the new client in the cache
|
||||||
this.clients.set(serverKey, client)
|
this.clients.set(serverKey, client)
|
||||||
@ -537,15 +609,15 @@ class McpService {
|
|||||||
})
|
})
|
||||||
|
|
||||||
let path = ''
|
let path = ''
|
||||||
child.stdout.on('data', (data) => {
|
child.stdout.on('data', (data: Buffer) => {
|
||||||
path += data.toString()
|
path += data.toString()
|
||||||
})
|
})
|
||||||
|
|
||||||
child.stderr.on('data', (data) => {
|
child.stderr.on('data', (data: Buffer) => {
|
||||||
console.error('Error getting PATH:', data.toString())
|
console.error('Error getting PATH:', data.toString())
|
||||||
})
|
})
|
||||||
|
|
||||||
child.on('close', (code) => {
|
child.on('close', (code: number) => {
|
||||||
if (code === 0) {
|
if (code === 0) {
|
||||||
const trimmedPath = path.trim()
|
const trimmedPath = path.trim()
|
||||||
resolve(trimmedPath)
|
resolve(trimmedPath)
|
||||||
|
|||||||
76
src/main/services/mcp/oauth/callback.ts
Normal file
76
src/main/services/mcp/oauth/callback.ts
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
import Logger from 'electron-log'
|
||||||
|
import EventEmitter from 'events'
|
||||||
|
import http from 'http'
|
||||||
|
import { URL } from 'url'
|
||||||
|
|
||||||
|
import { OAuthCallbackServerOptions } from './types'
|
||||||
|
|
||||||
|
export class CallBackServer {
|
||||||
|
private server: Promise<http.Server>
|
||||||
|
private events: EventEmitter
|
||||||
|
|
||||||
|
constructor(options: OAuthCallbackServerOptions) {
|
||||||
|
const { port, path, events } = options
|
||||||
|
this.events = events
|
||||||
|
this.server = this.initialize(port, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
initialize(port: number, path: string): Promise<http.Server> {
|
||||||
|
const server = http.createServer((req, res) => {
|
||||||
|
// Only handle requests to the callback path
|
||||||
|
if (req.url?.startsWith(path)) {
|
||||||
|
try {
|
||||||
|
// Parse the URL to extract the authorization code
|
||||||
|
const url = new URL(req.url, `http://localhost:${port}`)
|
||||||
|
const code = url.searchParams.get('code')
|
||||||
|
if (code) {
|
||||||
|
// Emit the code event
|
||||||
|
this.events.emit('auth-code-received', code)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
Logger.error('Error processing OAuth callback:', error)
|
||||||
|
res.writeHead(500, { 'Content-Type': 'text/plain' })
|
||||||
|
res.end('Internal Server Error')
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Not a callback request
|
||||||
|
res.writeHead(404, { 'Content-Type': 'text/plain' })
|
||||||
|
res.end('Not Found')
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Handle server errors
|
||||||
|
server.on('error', (error) => {
|
||||||
|
Logger.error('OAuth callback server error:', error)
|
||||||
|
})
|
||||||
|
|
||||||
|
const runningServer = new Promise<http.Server>((resolve, reject) => {
|
||||||
|
server.listen(port, () => {
|
||||||
|
Logger.info(`OAuth callback server listening on port ${port}`)
|
||||||
|
resolve(server)
|
||||||
|
})
|
||||||
|
|
||||||
|
server.on('error', (error) => {
|
||||||
|
reject(error)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
return runningServer
|
||||||
|
}
|
||||||
|
|
||||||
|
get getServer(): Promise<http.Server> {
|
||||||
|
return this.server
|
||||||
|
}
|
||||||
|
|
||||||
|
async close() {
|
||||||
|
const server = await this.server
|
||||||
|
server.close()
|
||||||
|
}
|
||||||
|
|
||||||
|
async waitForAuthCode(): Promise<string> {
|
||||||
|
return new Promise((resolve) => {
|
||||||
|
this.events.once('auth-code-received', (code) => {
|
||||||
|
resolve(code)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
78
src/main/services/mcp/oauth/provider.ts
Normal file
78
src/main/services/mcp/oauth/provider.ts
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
import path from 'node:path'
|
||||||
|
|
||||||
|
import { getConfigDir } from '@main/utils/file'
|
||||||
|
import { OAuthClientProvider } from '@modelcontextprotocol/sdk/client/auth'
|
||||||
|
import { OAuthClientInformation, OAuthClientInformationFull, OAuthTokens } from '@modelcontextprotocol/sdk/shared/auth'
|
||||||
|
import Logger from 'electron-log'
|
||||||
|
import open from 'open'
|
||||||
|
|
||||||
|
import { JsonFileStorage } from './storage'
|
||||||
|
import { OAuthProviderOptions } from './types'
|
||||||
|
|
||||||
|
export class McpOAuthClientProvider implements OAuthClientProvider {
|
||||||
|
private storage: JsonFileStorage
|
||||||
|
public readonly config: Required<OAuthProviderOptions>
|
||||||
|
|
||||||
|
constructor(options: OAuthProviderOptions) {
|
||||||
|
const configDir = path.join(getConfigDir(), 'mcp', 'oauth')
|
||||||
|
this.config = {
|
||||||
|
serverUrlHash: options.serverUrlHash,
|
||||||
|
callbackPort: options.callbackPort || 12346,
|
||||||
|
callbackPath: options.callbackPath || '/oauth/callback',
|
||||||
|
configDir: options.configDir || configDir,
|
||||||
|
clientName: options.clientName || 'Cherry Studio',
|
||||||
|
clientUri: options.clientUri || 'https://github.com/CherryHQ/cherry-studio'
|
||||||
|
}
|
||||||
|
this.storage = new JsonFileStorage(this.config.serverUrlHash, this.config.configDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
get redirectUrl(): string {
|
||||||
|
return `http://localhost:${this.config.callbackPort}${this.config.callbackPath}`
|
||||||
|
}
|
||||||
|
|
||||||
|
get clientMetadata() {
|
||||||
|
return {
|
||||||
|
redirect_uris: [this.redirectUrl],
|
||||||
|
token_endpoint_auth_method: 'none',
|
||||||
|
grant_types: ['authorization_code', 'refresh_token'],
|
||||||
|
response_types: ['code'],
|
||||||
|
client_name: this.config.clientName,
|
||||||
|
client_uri: this.config.clientUri
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async clientInformation(): Promise<OAuthClientInformation | undefined> {
|
||||||
|
return this.storage.getClientInformation()
|
||||||
|
}
|
||||||
|
|
||||||
|
async saveClientInformation(info: OAuthClientInformationFull): Promise<void> {
|
||||||
|
await this.storage.saveClientInformation(info)
|
||||||
|
}
|
||||||
|
|
||||||
|
async tokens(): Promise<OAuthTokens | undefined> {
|
||||||
|
return this.storage.getTokens()
|
||||||
|
}
|
||||||
|
|
||||||
|
async saveTokens(tokens: OAuthTokens): Promise<void> {
|
||||||
|
await this.storage.saveTokens(tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
async redirectToAuthorization(authorizationUrl: URL): Promise<void> {
|
||||||
|
try {
|
||||||
|
// Open the browser to the authorization URL
|
||||||
|
await open(authorizationUrl.toString())
|
||||||
|
Logger.info('Browser opened automatically.')
|
||||||
|
} catch (error) {
|
||||||
|
Logger.error('Could not open browser automatically.')
|
||||||
|
throw error // Let caller handle the error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async saveCodeVerifier(codeVerifier: string): Promise<void> {
|
||||||
|
await this.storage.saveCodeVerifier(codeVerifier)
|
||||||
|
}
|
||||||
|
|
||||||
|
async codeVerifier(): Promise<string> {
|
||||||
|
return this.storage.getCodeVerifier()
|
||||||
|
}
|
||||||
|
}
|
||||||
120
src/main/services/mcp/oauth/storage.ts
Normal file
120
src/main/services/mcp/oauth/storage.ts
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
import {
|
||||||
|
OAuthClientInformation,
|
||||||
|
OAuthClientInformationFull,
|
||||||
|
OAuthTokens
|
||||||
|
} from '@modelcontextprotocol/sdk/shared/auth.js'
|
||||||
|
import Logger from 'electron-log'
|
||||||
|
import fs from 'fs/promises'
|
||||||
|
import path from 'path'
|
||||||
|
|
||||||
|
import { IOAuthStorage, OAuthStorageData, OAuthStorageSchema } from './types'
|
||||||
|
|
||||||
|
export class JsonFileStorage implements IOAuthStorage {
|
||||||
|
private readonly filePath: string
|
||||||
|
private cache: OAuthStorageData | null = null
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
readonly serverUrlHash: string,
|
||||||
|
configDir: string
|
||||||
|
) {
|
||||||
|
this.filePath = path.join(configDir, `${serverUrlHash}_oauth.json`)
|
||||||
|
}
|
||||||
|
|
||||||
|
private async readStorage(): Promise<OAuthStorageData> {
|
||||||
|
if (this.cache) {
|
||||||
|
return this.cache
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const data = await fs.readFile(this.filePath, 'utf-8')
|
||||||
|
const parsed = JSON.parse(data)
|
||||||
|
const validated = OAuthStorageSchema.parse(parsed)
|
||||||
|
this.cache = validated
|
||||||
|
return validated
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof Error && 'code' in error && error.code === 'ENOENT') {
|
||||||
|
// File doesn't exist, return initial state
|
||||||
|
const initial: OAuthStorageData = { lastUpdated: Date.now() }
|
||||||
|
await this.writeStorage(initial)
|
||||||
|
return initial
|
||||||
|
}
|
||||||
|
Logger.error('Error reading OAuth storage:', error)
|
||||||
|
throw new Error(`Failed to read OAuth storage: ${error instanceof Error ? error.message : String(error)}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async writeStorage(data: OAuthStorageData): Promise<void> {
|
||||||
|
try {
|
||||||
|
// Ensure directory exists
|
||||||
|
await fs.mkdir(path.dirname(this.filePath), { recursive: true })
|
||||||
|
|
||||||
|
// Update timestamp
|
||||||
|
data.lastUpdated = Date.now()
|
||||||
|
|
||||||
|
// Write file atomically
|
||||||
|
const tempPath = `${this.filePath}.tmp`
|
||||||
|
await fs.writeFile(tempPath, JSON.stringify(data, null, 2))
|
||||||
|
await fs.rename(tempPath, this.filePath)
|
||||||
|
|
||||||
|
// Update cache
|
||||||
|
this.cache = data
|
||||||
|
} catch (error) {
|
||||||
|
Logger.error('Error writing OAuth storage:', error)
|
||||||
|
throw new Error(`Failed to write OAuth storage: ${error instanceof Error ? error.message : String(error)}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async getClientInformation(): Promise<OAuthClientInformation | undefined> {
|
||||||
|
const data = await this.readStorage()
|
||||||
|
return data.clientInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
async saveClientInformation(info: OAuthClientInformationFull): Promise<void> {
|
||||||
|
const data = await this.readStorage()
|
||||||
|
await this.writeStorage({
|
||||||
|
...data,
|
||||||
|
clientInfo: info
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async getTokens(): Promise<OAuthTokens | undefined> {
|
||||||
|
const data = await this.readStorage()
|
||||||
|
return data.tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
async saveTokens(tokens: OAuthTokens): Promise<void> {
|
||||||
|
const data = await this.readStorage()
|
||||||
|
await this.writeStorage({
|
||||||
|
...data,
|
||||||
|
tokens
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async getCodeVerifier(): Promise<string> {
|
||||||
|
const data = await this.readStorage()
|
||||||
|
if (!data.codeVerifier) {
|
||||||
|
throw new Error('No code verifier saved for session')
|
||||||
|
}
|
||||||
|
return data.codeVerifier
|
||||||
|
}
|
||||||
|
|
||||||
|
async saveCodeVerifier(codeVerifier: string): Promise<void> {
|
||||||
|
const data = await this.readStorage()
|
||||||
|
await this.writeStorage({
|
||||||
|
...data,
|
||||||
|
codeVerifier
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async clear(): Promise<void> {
|
||||||
|
try {
|
||||||
|
await fs.unlink(this.filePath)
|
||||||
|
this.cache = null
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof Error && 'code' in error && error.code !== 'ENOENT') {
|
||||||
|
Logger.error('Error clearing OAuth storage:', error)
|
||||||
|
throw new Error(`Failed to clear OAuth storage: ${error instanceof Error ? error.message : String(error)}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
61
src/main/services/mcp/oauth/types.ts
Normal file
61
src/main/services/mcp/oauth/types.ts
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
import {
|
||||||
|
OAuthClientInformation,
|
||||||
|
OAuthClientInformationFull,
|
||||||
|
OAuthTokens
|
||||||
|
} from '@modelcontextprotocol/sdk/shared/auth.js'
|
||||||
|
import EventEmitter from 'events'
|
||||||
|
import { z } from 'zod'
|
||||||
|
|
||||||
|
export interface OAuthStorageData {
|
||||||
|
clientInfo?: OAuthClientInformation
|
||||||
|
tokens?: OAuthTokens
|
||||||
|
codeVerifier?: string
|
||||||
|
lastUpdated: number
|
||||||
|
}
|
||||||
|
|
||||||
|
export const OAuthStorageSchema = z.object({
|
||||||
|
clientInfo: z.any().optional(),
|
||||||
|
tokens: z.any().optional(),
|
||||||
|
codeVerifier: z.string().optional(),
|
||||||
|
lastUpdated: z.number()
|
||||||
|
})
|
||||||
|
|
||||||
|
export interface IOAuthStorage {
|
||||||
|
getClientInformation(): Promise<OAuthClientInformation | undefined>
|
||||||
|
saveClientInformation(info: OAuthClientInformationFull): Promise<void>
|
||||||
|
getTokens(): Promise<OAuthTokens | undefined>
|
||||||
|
saveTokens(tokens: OAuthTokens): Promise<void>
|
||||||
|
getCodeVerifier(): Promise<string>
|
||||||
|
saveCodeVerifier(codeVerifier: string): Promise<void>
|
||||||
|
clear(): Promise<void>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* OAuth callback server setup options
|
||||||
|
*/
|
||||||
|
export interface OAuthCallbackServerOptions {
|
||||||
|
/** Port for the callback server */
|
||||||
|
port: number
|
||||||
|
/** Path for the callback endpoint */
|
||||||
|
path: string
|
||||||
|
/** Event emitter to signal when auth code is received */
|
||||||
|
events: EventEmitter
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Options for creating an OAuth client provider
|
||||||
|
*/
|
||||||
|
export interface OAuthProviderOptions {
|
||||||
|
/** Server URL to connect to */
|
||||||
|
serverUrlHash: string
|
||||||
|
/** Port for the OAuth callback server */
|
||||||
|
callbackPort?: number
|
||||||
|
/** Path for the OAuth callback endpoint */
|
||||||
|
callbackPath?: string
|
||||||
|
/** Directory to store OAuth credentials */
|
||||||
|
configDir?: string
|
||||||
|
/** Client name to use for OAuth registration */
|
||||||
|
clientName?: string
|
||||||
|
/** Client URI to use for OAuth registration */
|
||||||
|
clientUri?: string
|
||||||
|
}
|
||||||
@ -79,3 +79,7 @@ export function getFilesDir() {
|
|||||||
export function getConfigDir() {
|
export function getConfigDir() {
|
||||||
return path.join(os.homedir(), '.cherrystudio', 'config')
|
return path.join(os.homedir(), '.cherrystudio', 'config')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function getAppConfigDir(name: string) {
|
||||||
|
return path.join(getConfigDir(), name)
|
||||||
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user