diff --git a/src/renderer/src/providers/BaseProvider.ts b/src/renderer/src/providers/BaseProvider.ts index 77337b77..bdb15c93 100644 --- a/src/renderer/src/providers/BaseProvider.ts +++ b/src/renderer/src/providers/BaseProvider.ts @@ -17,6 +17,12 @@ export default abstract class BaseProvider { return host.endsWith('/') ? host : `${host}/v1/` } + public getHeaders() { + return { + 'X-Api-Key': this.provider.apiKey + } + } + public get keepAliveTime() { return this.provider.id === 'ollama' ? getOllamaKeepAliveTime() : undefined } diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts index 4c3f5c66..d2a77b2c 100644 --- a/src/renderer/src/providers/OpenAIProvider.ts +++ b/src/renderer/src/providers/OpenAIProvider.ts @@ -138,16 +138,21 @@ export default class OpenAIProvider extends BaseProvider { const isSupportStreamOutput = streamOutput && this.isSupportStreamOutput(model.id) // @ts-ignore key is not typed - const stream = await this.sdk.chat.completions.create({ - model: model.id, - messages: [isOpenAIo1 ? undefined : systemMessage, ...userMessages].filter( - Boolean - ) as ChatCompletionMessageParam[], - temperature: isOpenAIo1 ? 1 : assistant?.settings?.temperature, - max_tokens: maxTokens, - keep_alive: this.keepAliveTime, - stream: isSupportStreamOutput - }) + const stream = await this.sdk.chat.completions.create( + { + model: model.id, + messages: [isOpenAIo1 ? undefined : systemMessage, ...userMessages].filter( + Boolean + ) as ChatCompletionMessageParam[], + temperature: isOpenAIo1 ? 1 : assistant?.settings?.temperature, + max_tokens: maxTokens, + keep_alive: this.keepAliveTime, + stream: isSupportStreamOutput + }, + { + headers: this.getHeaders() + } + ) if (!isSupportStreamOutput) { return onChunk({ @@ -177,12 +182,17 @@ export default class OpenAIProvider extends BaseProvider { ] // @ts-ignore key is not typed - const response = await this.sdk.chat.completions.create({ - model: model.id, - messages: messages as ChatCompletionMessageParam[], - stream: false, - keep_alive: this.keepAliveTime - }) + const response = await this.sdk.chat.completions.create( + { + model: model.id, + messages: messages as ChatCompletionMessageParam[], + stream: false, + keep_alive: this.keepAliveTime + }, + { + headers: this.getHeaders() + } + ) return response.choices[0].message?.content || '' } @@ -213,13 +223,18 @@ export default class OpenAIProvider extends BaseProvider { } // @ts-ignore key is not typed - const response = await this.sdk.chat.completions.create({ - model: model.id, - messages: [systemMessage, userMessage] as ChatCompletionMessageParam[], - stream: false, - keep_alive: this.keepAliveTime, - max_tokens: 1000 - }) + const response = await this.sdk.chat.completions.create( + { + model: model.id, + messages: [systemMessage, userMessage] as ChatCompletionMessageParam[], + stream: false, + keep_alive: this.keepAliveTime, + max_tokens: 1000 + }, + { + headers: this.getHeaders() + } + ) return removeQuotes(response.choices[0].message?.content?.substring(0, 50) || '') } @@ -227,14 +242,19 @@ export default class OpenAIProvider extends BaseProvider { public async generateText({ prompt, content }: { prompt: string; content: string }): Promise { const model = getDefaultModel() - const response = await this.sdk.chat.completions.create({ - model: model.id, - stream: false, - messages: [ - { role: 'system', content: prompt }, - { role: 'user', content } - ] - }) + const response = await this.sdk.chat.completions.create( + { + model: model.id, + stream: false, + messages: [ + { role: 'system', content: prompt }, + { role: 'user', content } + ] + }, + { + headers: this.getHeaders() + } + ) return response.choices[0].message?.content || '' } @@ -249,6 +269,7 @@ export default class OpenAIProvider extends BaseProvider { const response: any = await this.sdk.request({ method: 'post', path: '/advice_questions', + headers: this.getHeaders(), body: { messages: messages.filter((m) => m.role === 'user').map((m) => ({ role: m.role, content: m.content })), model: model.id, @@ -272,7 +293,9 @@ export default class OpenAIProvider extends BaseProvider { } try { - const response = await this.sdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming) + const response = await this.sdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming, { + headers: this.getHeaders() + }) return { valid: Boolean(response?.choices[0].message), @@ -294,7 +317,7 @@ export default class OpenAIProvider extends BaseProvider { query.type = 'text' } - const response = await this.sdk.models.list({ query }) + const response = await this.sdk.models.list({ query, headers: this.getHeaders() }) if (this.provider.id === 'github') { // @ts-ignore key is not typed