refactor: ProviderSDK

This commit is contained in:
kangfenmao 2024-07-19 15:49:08 +08:00
parent 13c73a3de1
commit 5a636e7614

View File

@ -2,11 +2,7 @@ import { Assistant, Message, Provider } from '@renderer/types'
import OpenAI from 'openai' import OpenAI from 'openai'
import Anthropic from '@anthropic-ai/sdk' import Anthropic from '@anthropic-ai/sdk'
import { getDefaultModel, getTopNamingModel } from './assistant' import { getDefaultModel, getTopNamingModel } from './assistant'
import { import { ChatCompletionCreateParamsNonStreaming, ChatCompletionMessageParam } from 'openai/resources'
ChatCompletionCreateParamsNonStreaming,
ChatCompletionMessageParam,
ChatCompletionSystemMessageParam
} from 'openai/resources'
import { sum, takeRight } from 'lodash' import { sum, takeRight } from 'lodash'
import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources' import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources'
import { EVENT_NAMES } from './event' import { EVENT_NAMES } from './event'
@ -47,14 +43,18 @@ export default class ProviderSDK {
if (this.isAnthropic) { if (this.isAnthropic) {
await this.anthropicSdk.messages await this.anthropicSdk.messages
.stream({ .stream({
max_tokens: 1024, max_tokens: 2048,
messages: [systemMessage, ...userMessages].filter(Boolean) as MessageParam[], messages: [systemMessage, ...userMessages].filter(Boolean) as MessageParam[],
model: model.id model: model.id
}) })
.on('text', (text) => onChunk({ text: text || '' })) .on('text', (text) => onChunk({ text: text || '' }))
.on('finalMessage', (message) => .on('finalMessage', (message) =>
onChunk({ onChunk({
usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: sum(Object.values(message.usage)) } usage: {
prompt_tokens: message.usage.input_tokens,
completion_tokens: message.usage.output_tokens,
total_tokens: sum(Object.values(message.usage))
}
}) })
) )
} else { } else {
@ -64,9 +64,7 @@ export default class ProviderSDK {
stream: true stream: true
}) })
for await (const chunk of stream) { for await (const chunk of stream) {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) { if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
break
}
onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage }) onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage })
} }
} }
@ -75,12 +73,12 @@ export default class ProviderSDK {
public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> { public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
const model = getTopNamingModel() || assistant.model || getDefaultModel() const model = getTopNamingModel() || assistant.model || getDefaultModel()
const userMessages: ChatCompletionMessageParam[] = takeRight(messages, 5).map((message) => ({ const userMessages = takeRight(messages, 5).map((message) => ({
role: 'user', role: 'user',
content: message.content content: message.content
})) }))
const systemMessage: ChatCompletionSystemMessageParam = { const systemMessage = {
role: 'system', role: 'system',
content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要加标点符号' content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要加标点符号'
} }
@ -97,8 +95,9 @@ export default class ProviderSDK {
} else { } else {
const response = await this.openaiSdk.chat.completions.create({ const response = await this.openaiSdk.chat.completions.create({
model: model.id, model: model.id,
messages: [systemMessage, ...userMessages], messages: [systemMessage, ...userMessages] as ChatCompletionMessageParam[],
stream: false stream: false,
max_tokens: 50
}) })
return removeQuotes(response.choices[0].message?.content || '') return removeQuotes(response.choices[0].message?.content || '')