From 56af85cc3e73f092cbb8f1657e0920f9e3f16a82 Mon Sep 17 00:00:00 2001 From: kangfenmao Date: Fri, 13 Sep 2024 09:57:27 +0800 Subject: [PATCH] feat: add generate to ai provider api --- .eslintrc.cjs | 3 ++- src/renderer/src/providers/AiProvider.ts | 6 ++++- .../src/providers/AnthropicProvider.ts | 23 +++++++++++++++++-- src/renderer/src/providers/BaseProvider.ts | 3 ++- src/renderer/src/providers/GeminiProvider.ts | 14 ++++++++++- src/renderer/src/providers/OpenAIProvider.ts | 17 +++++++++++++- src/renderer/src/services/api.ts | 17 ++++++++++++++ 7 files changed, 76 insertions(+), 7 deletions(-) diff --git a/.eslintrc.cjs b/.eslintrc.cjs index 3b92776f..79a247e6 100644 --- a/.eslintrc.cjs +++ b/.eslintrc.cjs @@ -15,6 +15,7 @@ module.exports = { '@typescript-eslint/no-non-null-asserted-optional-chain': 'off', 'react/prop-types': 'off', 'simple-import-sort/imports': 'error', - 'simple-import-sort/exports': 'error' + 'simple-import-sort/exports': 'error', + 'react/no-is-mounted': 'off' } } diff --git a/src/renderer/src/providers/AiProvider.ts b/src/renderer/src/providers/AiProvider.ts index b986e215..f2ea3248 100644 --- a/src/renderer/src/providers/AiProvider.ts +++ b/src/renderer/src/providers/AiProvider.ts @@ -22,7 +22,7 @@ export default class AiProvider { return this.sdk.translate(message, assistant) } - public async summaries(messages: Message[], assistant: Assistant): Promise { + public async summaries(messages: Message[], assistant: Assistant): Promise { return this.sdk.summaries(messages, assistant) } @@ -30,6 +30,10 @@ export default class AiProvider { return this.sdk.suggestions(messages, assistant) } + public async generate({ prompt, content }: { prompt: string; content: string }): Promise { + return this.sdk.generate({ prompt, content }) + } + public async check(): Promise<{ valid: boolean; error: Error | null }> { return this.sdk.check() } diff --git a/src/renderer/src/providers/AnthropicProvider.ts b/src/renderer/src/providers/AnthropicProvider.ts index dabcec1e..4bc4ef63 100644 --- a/src/renderer/src/providers/AnthropicProvider.ts +++ b/src/renderer/src/providers/AnthropicProvider.ts @@ -90,7 +90,7 @@ export default class AnthropicProvider extends BaseProvider { return response.content[0].type === 'text' ? response.content[0].text : '' } - public async summaries(messages: Message[], assistant: Assistant): Promise { + public async summaries(messages: Message[], assistant: Assistant): Promise { const model = getTopNamingModel() || assistant.model || getDefaultModel() const userMessages = takeRight(messages, 5).map((message) => ({ @@ -115,7 +115,26 @@ export default class AnthropicProvider extends BaseProvider { max_tokens: 4096 }) - return message.content[0].type === 'text' ? message.content[0].text : null + return message.content[0].type === 'text' ? message.content[0].text : '' + } + + public async generate({ prompt, content }: { prompt: string; content: string }): Promise { + const model = getDefaultModel() + + const message = await this.sdk.messages.create({ + messages: [ + { + role: 'user', + content + } + ], + model: model.id, + system: prompt, + stream: false, + max_tokens: 4096 + }) + + return message.content[0].type === 'text' ? message.content[0].text : '' } public async suggestions(): Promise { diff --git a/src/renderer/src/providers/BaseProvider.ts b/src/renderer/src/providers/BaseProvider.ts index c1d83de2..ec1e4db1 100644 --- a/src/renderer/src/providers/BaseProvider.ts +++ b/src/renderer/src/providers/BaseProvider.ts @@ -26,8 +26,9 @@ export default abstract class BaseProvider { onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void ): Promise abstract translate(message: Message, assistant: Assistant): Promise - abstract summaries(messages: Message[], assistant: Assistant): Promise + abstract summaries(messages: Message[], assistant: Assistant): Promise abstract suggestions(messages: Message[], assistant: Assistant): Promise + abstract generate({ prompt, content }: { prompt: string; content: string }): Promise abstract check(): Promise<{ valid: boolean; error: Error | null }> abstract models(): Promise } diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index ffbe9762..4b09d45c 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -85,7 +85,7 @@ export default class GeminiProvider extends BaseProvider { return response.text() } - public async summaries(messages: Message[], assistant: Assistant): Promise { + public async summaries(messages: Message[], assistant: Assistant): Promise { const model = getTopNamingModel() || assistant.model || getDefaultModel() const userMessages = takeRight(messages, 5).map((message) => ({ @@ -120,6 +120,18 @@ export default class GeminiProvider extends BaseProvider { return response.text() } + public async generate({ prompt, content }: { prompt: string; content: string }): Promise { + const model = getDefaultModel() + const systemMessage = { role: 'system', content: prompt } + + const geminiModel = this.sdk.getGenerativeModel({ model: model.id }) + + const chat = await geminiModel.startChat({ systemInstruction: systemMessage.content }) + const { response } = await chat.sendMessage(content) + + return response.text() + } + public async suggestions(): Promise { return [] } diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts index a97b32fd..1f229d48 100644 --- a/src/renderer/src/providers/OpenAIProvider.ts +++ b/src/renderer/src/providers/OpenAIProvider.ts @@ -103,7 +103,7 @@ export default class OpenAIProvider extends BaseProvider { return response.choices[0].message?.content || '' } - public async summaries(messages: Message[], assistant: Assistant): Promise { + public async summaries(messages: Message[], assistant: Assistant): Promise { const model = getTopNamingModel() || assistant.model || getDefaultModel() const userMessages = takeRight(messages, 5).map((message) => ({ @@ -128,6 +128,21 @@ export default class OpenAIProvider extends BaseProvider { return removeQuotes(response.choices[0].message?.content?.substring(0, 50) || '') } + public async generate({ prompt, content }: { prompt: string; content: string }): Promise { + const model = getDefaultModel() + + const response = await this.sdk.chat.completions.create({ + model: model.id, + stream: false, + messages: [ + { role: 'user', content }, + { role: 'system', content: prompt } + ] + }) + + return response.choices[0].message?.content || '' + } + async suggestions(messages: Message[], assistant: Assistant): Promise { const model = assistant.model diff --git a/src/renderer/src/services/api.ts b/src/renderer/src/services/api.ts index ceeeeddd..dacffe97 100644 --- a/src/renderer/src/services/api.ts +++ b/src/renderer/src/services/api.ts @@ -129,6 +129,23 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages: } } +export async function fetchGenerate({ prompt, content }: { prompt: string; content: string }): Promise { + const model = getDefaultModel() + const provider = getProviderByModel(model) + + if (!hasApiKey(provider)) { + return '' + } + + const AI = new AiProvider(provider) + + try { + return await AI.generate({ prompt, content }) + } catch (error: any) { + return '' + } +} + export async function fetchSuggestions({ messages, assistant