From 5a636e76147c82924a42dfed5c69fd6c2dfce4e7 Mon Sep 17 00:00:00 2001 From: kangfenmao Date: Fri, 19 Jul 2024 15:49:08 +0800 Subject: [PATCH] refactor: ProviderSDK --- src/renderer/src/services/ProviderSDK.ts | 27 ++++++++++++------------ 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/src/renderer/src/services/ProviderSDK.ts b/src/renderer/src/services/ProviderSDK.ts index 2f15cdc2..a1247dac 100644 --- a/src/renderer/src/services/ProviderSDK.ts +++ b/src/renderer/src/services/ProviderSDK.ts @@ -2,11 +2,7 @@ import { Assistant, Message, Provider } from '@renderer/types' import OpenAI from 'openai' import Anthropic from '@anthropic-ai/sdk' import { getDefaultModel, getTopNamingModel } from './assistant' -import { - ChatCompletionCreateParamsNonStreaming, - ChatCompletionMessageParam, - ChatCompletionSystemMessageParam -} from 'openai/resources' +import { ChatCompletionCreateParamsNonStreaming, ChatCompletionMessageParam } from 'openai/resources' import { sum, takeRight } from 'lodash' import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources' import { EVENT_NAMES } from './event' @@ -47,14 +43,18 @@ export default class ProviderSDK { if (this.isAnthropic) { await this.anthropicSdk.messages .stream({ - max_tokens: 1024, + max_tokens: 2048, messages: [systemMessage, ...userMessages].filter(Boolean) as MessageParam[], model: model.id }) .on('text', (text) => onChunk({ text: text || '' })) .on('finalMessage', (message) => onChunk({ - usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: sum(Object.values(message.usage)) } + usage: { + prompt_tokens: message.usage.input_tokens, + completion_tokens: message.usage.output_tokens, + total_tokens: sum(Object.values(message.usage)) + } }) ) } else { @@ -64,9 +64,7 @@ export default class ProviderSDK { stream: true }) for await (const chunk of stream) { - if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) { - break - } + if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage }) } } @@ -75,12 +73,12 @@ export default class ProviderSDK { public async summaries(messages: Message[], assistant: Assistant): Promise { const model = getTopNamingModel() || assistant.model || getDefaultModel() - const userMessages: ChatCompletionMessageParam[] = takeRight(messages, 5).map((message) => ({ + const userMessages = takeRight(messages, 5).map((message) => ({ role: 'user', content: message.content })) - const systemMessage: ChatCompletionSystemMessageParam = { + const systemMessage = { role: 'system', content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要加标点符号' } @@ -97,8 +95,9 @@ export default class ProviderSDK { } else { const response = await this.openaiSdk.chat.completions.create({ model: model.id, - messages: [systemMessage, ...userMessages], - stream: false + messages: [systemMessage, ...userMessages] as ChatCompletionMessageParam[], + stream: false, + max_tokens: 50 }) return removeQuotes(response.choices[0].message?.content || '')