refactor: ProviderSDK
This commit is contained in:
parent
13c73a3de1
commit
5a636e7614
@ -2,11 +2,7 @@ import { Assistant, Message, Provider } from '@renderer/types'
|
|||||||
import OpenAI from 'openai'
|
import OpenAI from 'openai'
|
||||||
import Anthropic from '@anthropic-ai/sdk'
|
import Anthropic from '@anthropic-ai/sdk'
|
||||||
import { getDefaultModel, getTopNamingModel } from './assistant'
|
import { getDefaultModel, getTopNamingModel } from './assistant'
|
||||||
import {
|
import { ChatCompletionCreateParamsNonStreaming, ChatCompletionMessageParam } from 'openai/resources'
|
||||||
ChatCompletionCreateParamsNonStreaming,
|
|
||||||
ChatCompletionMessageParam,
|
|
||||||
ChatCompletionSystemMessageParam
|
|
||||||
} from 'openai/resources'
|
|
||||||
import { sum, takeRight } from 'lodash'
|
import { sum, takeRight } from 'lodash'
|
||||||
import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources'
|
import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources'
|
||||||
import { EVENT_NAMES } from './event'
|
import { EVENT_NAMES } from './event'
|
||||||
@ -47,14 +43,18 @@ export default class ProviderSDK {
|
|||||||
if (this.isAnthropic) {
|
if (this.isAnthropic) {
|
||||||
await this.anthropicSdk.messages
|
await this.anthropicSdk.messages
|
||||||
.stream({
|
.stream({
|
||||||
max_tokens: 1024,
|
max_tokens: 2048,
|
||||||
messages: [systemMessage, ...userMessages].filter(Boolean) as MessageParam[],
|
messages: [systemMessage, ...userMessages].filter(Boolean) as MessageParam[],
|
||||||
model: model.id
|
model: model.id
|
||||||
})
|
})
|
||||||
.on('text', (text) => onChunk({ text: text || '' }))
|
.on('text', (text) => onChunk({ text: text || '' }))
|
||||||
.on('finalMessage', (message) =>
|
.on('finalMessage', (message) =>
|
||||||
onChunk({
|
onChunk({
|
||||||
usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: sum(Object.values(message.usage)) }
|
usage: {
|
||||||
|
prompt_tokens: message.usage.input_tokens,
|
||||||
|
completion_tokens: message.usage.output_tokens,
|
||||||
|
total_tokens: sum(Object.values(message.usage))
|
||||||
|
}
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
@ -64,9 +64,7 @@ export default class ProviderSDK {
|
|||||||
stream: true
|
stream: true
|
||||||
})
|
})
|
||||||
for await (const chunk of stream) {
|
for await (const chunk of stream) {
|
||||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
|
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
|
||||||
break
|
|
||||||
}
|
|
||||||
onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage })
|
onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -75,12 +73,12 @@ export default class ProviderSDK {
|
|||||||
public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
|
public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
|
||||||
const model = getTopNamingModel() || assistant.model || getDefaultModel()
|
const model = getTopNamingModel() || assistant.model || getDefaultModel()
|
||||||
|
|
||||||
const userMessages: ChatCompletionMessageParam[] = takeRight(messages, 5).map((message) => ({
|
const userMessages = takeRight(messages, 5).map((message) => ({
|
||||||
role: 'user',
|
role: 'user',
|
||||||
content: message.content
|
content: message.content
|
||||||
}))
|
}))
|
||||||
|
|
||||||
const systemMessage: ChatCompletionSystemMessageParam = {
|
const systemMessage = {
|
||||||
role: 'system',
|
role: 'system',
|
||||||
content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要加标点符号'
|
content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要加标点符号'
|
||||||
}
|
}
|
||||||
@ -97,8 +95,9 @@ export default class ProviderSDK {
|
|||||||
} else {
|
} else {
|
||||||
const response = await this.openaiSdk.chat.completions.create({
|
const response = await this.openaiSdk.chat.completions.create({
|
||||||
model: model.id,
|
model: model.id,
|
||||||
messages: [systemMessage, ...userMessages],
|
messages: [systemMessage, ...userMessages] as ChatCompletionMessageParam[],
|
||||||
stream: false
|
stream: false,
|
||||||
|
max_tokens: 50
|
||||||
})
|
})
|
||||||
|
|
||||||
return removeQuotes(response.choices[0].message?.content || '')
|
return removeQuotes(response.choices[0].message?.content || '')
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user