feat: add generate to ai provider api

kangfenmao 2024-09-13 09:57:27 +08:00
parent 6a1a861ecc
commit 56af85cc3e
7 changed files with 76 additions and 7 deletions

View File

@@ -15,6 +15,7 @@ module.exports = {
     '@typescript-eslint/no-non-null-asserted-optional-chain': 'off',
     'react/prop-types': 'off',
     'simple-import-sort/imports': 'error',
-    'simple-import-sort/exports': 'error'
+    'simple-import-sort/exports': 'error',
+    'react/no-is-mounted': 'off'
   }
 }

View File

@@ -22,7 +22,7 @@ export default class AiProvider {
     return this.sdk.translate(message, assistant)
   }
 
-  public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
+  public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
     return this.sdk.summaries(messages, assistant)
   }
@@ -30,6 +30,10 @@ export default class AiProvider {
     return this.sdk.suggestions(messages, assistant)
   }
 
+  public async generate({ prompt, content }: { prompt: string; content: string }): Promise<string> {
+    return this.sdk.generate({ prompt, content })
+  }
+
   public async check(): Promise<{ valid: boolean; error: Error | null }> {
     return this.sdk.check()
   }
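
For orientation, here is a minimal TypeScript usage sketch (not part of this commit) of the new facade method. The polish() helper, its arguments, and the prompt text are illustrative placeholders; AiProvider, its constructor argument, and generate() come from the diffs in this commit, and the Provider type is assumed to be imported from the project.

// Sketch only: "provider" is assumed to be a configured Provider record with a valid API key.
async function polish(provider: Provider, draft: string): Promise<string> {
  const AI = new AiProvider(provider)

  // prompt is passed as the system instruction, content as the user message
  return AI.generate({
    prompt: 'Rewrite the following text in a more formal tone.',
    content: draft
  })
}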

View File

@@ -90,7 +90,7 @@ export default class AnthropicProvider extends BaseProvider {
     return response.content[0].type === 'text' ? response.content[0].text : ''
   }
 
-  public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
+  public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
     const model = getTopNamingModel() || assistant.model || getDefaultModel()
     const userMessages = takeRight(messages, 5).map((message) => ({
@@ -115,7 +115,26 @@ export default class AnthropicProvider extends BaseProvider {
       max_tokens: 4096
     })
 
-    return message.content[0].type === 'text' ? message.content[0].text : null
+    return message.content[0].type === 'text' ? message.content[0].text : ''
   }
 
+  public async generate({ prompt, content }: { prompt: string; content: string }): Promise<string> {
+    const model = getDefaultModel()
+
+    const message = await this.sdk.messages.create({
+      messages: [
+        {
+          role: 'user',
+          content
+        }
+      ],
+      model: model.id,
+      system: prompt,
+      stream: false,
+      max_tokens: 4096
+    })
+
+    return message.content[0].type === 'text' ? message.content[0].text : ''
+  }
+
   public async suggestions(): Promise<Suggestion[]> {

View File

@@ -26,8 +26,9 @@ export default abstract class BaseProvider {
     onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
   ): Promise<void>
   abstract translate(message: Message, assistant: Assistant): Promise<string>
-  abstract summaries(messages: Message[], assistant: Assistant): Promise<string | null>
+  abstract summaries(messages: Message[], assistant: Assistant): Promise<string>
   abstract suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]>
+  abstract generate({ prompt, content }: { prompt: string; content: string }): Promise<string>
   abstract check(): Promise<{ valid: boolean; error: Error | null }>
   abstract models(): Promise<OpenAI.Models.Model[]>
 }
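
For reference, a self-contained TypeScript sketch (not from the repository) of the provider surface touched by this hunk, with local stand-in types in place of the app's own Message, Assistant, and Suggestion definitions; the interface name is made up for illustration.

// Stand-in types for illustration only; the real definitions live in the app's type modules.
type Message = { role: string; content: string }
type Assistant = { model?: { id: string } }
type Suggestion = { content: string }

// The parts of the BaseProvider contract touched by this commit.
interface ProviderGenerateSurface {
  translate(message: Message, assistant: Assistant): Promise<string>
  summaries(messages: Message[], assistant: Assistant): Promise<string> // no longer nullable
  suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]>
  generate(args: { prompt: string; content: string }): Promise<string> // new abstract method
  check(): Promise<{ valid: boolean; error: Error | null }>
}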

View File

@@ -85,7 +85,7 @@ export default class GeminiProvider extends BaseProvider {
     return response.text()
   }
 
-  public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
+  public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
     const model = getTopNamingModel() || assistant.model || getDefaultModel()
     const userMessages = takeRight(messages, 5).map((message) => ({
@@ -120,6 +120,18 @@ export default class GeminiProvider extends BaseProvider {
     return response.text()
   }
 
+  public async generate({ prompt, content }: { prompt: string; content: string }): Promise<string> {
+    const model = getDefaultModel()
+    const systemMessage = { role: 'system', content: prompt }
+
+    const geminiModel = this.sdk.getGenerativeModel({ model: model.id })
+
+    const chat = await geminiModel.startChat({ systemInstruction: systemMessage.content })
+    const { response } = await chat.sendMessage(content)
+
+    return response.text()
+  }
+
   public async suggestions(): Promise<Suggestion[]> {
     return []
   }

View File

@@ -103,7 +103,7 @@ export default class OpenAIProvider extends BaseProvider {
     return response.choices[0].message?.content || ''
   }
 
-  public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
+  public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
     const model = getTopNamingModel() || assistant.model || getDefaultModel()
     const userMessages = takeRight(messages, 5).map((message) => ({
@@ -128,6 +128,21 @@ export default class OpenAIProvider extends BaseProvider {
     return removeQuotes(response.choices[0].message?.content?.substring(0, 50) || '')
   }
 
+  public async generate({ prompt, content }: { prompt: string; content: string }): Promise<string> {
+    const model = getDefaultModel()
+
+    const response = await this.sdk.chat.completions.create({
+      model: model.id,
+      stream: false,
+      messages: [
+        { role: 'user', content },
+        { role: 'system', content: prompt }
+      ]
+    })
+
+    return response.choices[0].message?.content || ''
+  }
+
   async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
     const model = assistant.model

View File

@@ -129,6 +129,23 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages:
   }
 }
 
+export async function fetchGenerate({ prompt, content }: { prompt: string; content: string }): Promise<string> {
+  const model = getDefaultModel()
+  const provider = getProviderByModel(model)
+
+  if (!hasApiKey(provider)) {
+    return ''
+  }
+
+  const AI = new AiProvider(provider)
+
+  try {
+    return await AI.generate({ prompt, content })
+  } catch (error: any) {
+    return ''
+  }
+}
+
 export async function fetchSuggestions({
   messages,
   assistant
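
Finally, a hypothetical call site (not in this commit) for the new helper, assuming fetchGenerate is imported from the service module shown above; the describeSelection() wrapper, its argument, and the prompt string are placeholders.

// fetchGenerate resolves to '' when no API key is configured or when the provider call throws,
// so callers can treat an empty string as "nothing generated".
async function describeSelection(selection: string): Promise<string> {
  return fetchGenerate({
    prompt: 'Summarize the user content in one sentence.',
    content: selection
  })
}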