cherry-studio/src/renderer/src/providers/OpenAIProvider.ts

import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {
getOpenAIWebSearchParams,
isOpenAIoSeries,
isReasoningModel,
isSupportedModel,
isVisionModel
} from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import { EVENT_NAMES } from '@renderer/services/EventService'
import {
filterContextMessages,
filterEmptyMessages,
filterUserRoleStartMessages
} from '@renderer/services/MessagesService'
import {
Assistant,
FileTypes,
GenerateImageParams,
MCPToolResponse,
Message,
Model,
Provider,
Suggestion
} from '@renderer/types'
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
import {
callMCPTool,
filterMCPTools,
mcpToolsToOpenAITools,
openAIToolsToMcpTool,
upsertMCPToolResponse
} from '@renderer/utils/mcp-tools'
import { isString, takeRight } from 'lodash'
import OpenAI, { AzureOpenAI } from 'openai'
import {
ChatCompletionAssistantMessageParam,
ChatCompletionContentPart,
ChatCompletionCreateParamsNonStreaming,
ChatCompletionMessageParam,
ChatCompletionMessageToolCall,
ChatCompletionToolMessageParam
} from 'openai/resources'
import { CompletionsParams } from '.'
import BaseProvider from './BaseProvider'
type ReasoningEffort = 'high' | 'medium' | 'low'
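// OpenAI-compatible chat provider built on the official `openai` SDK
// (`AzureOpenAI` for Azure deployments). Implements chat completions with
// streaming, reasoning and MCP tool-call support, plus translation, topic
// naming, suggestions, model listing, image generation and embedding probing.
//
// Illustrative usage (a minimal sketch; the concrete Provider / Assistant
// objects are assumed to come from the app's stores):
//
//   const openai = new OpenAIProvider(provider)
//   await openai.completions({ messages, assistant, mcpTools, onChunk, onFilterMessages })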
export default class OpenAIProvider extends BaseProvider {
private sdk: OpenAI
constructor(provider: Provider) {
super(provider)
if (provider.id === 'azure-openai' || provider.type === 'azure-openai') {
this.sdk = new AzureOpenAI({
dangerouslyAllowBrowser: true,
apiKey: this.apiKey,
apiVersion: provider.apiVersion,
endpoint: provider.apiHost
})
return
}
this.sdk = new OpenAI({
dangerouslyAllowBrowser: true,
apiKey: this.apiKey,
baseURL: this.getBaseURL(),
defaultHeaders: this.defaultHeaders()
})
}
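// Providers whose chat endpoints cannot accept file/image parts; for these,
// text and document file contents are inlined into the message text instead.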
private get isNotSupportFiles() {
const providers = ['deepseek', 'baichuan', 'minimax', 'doubao', 'xirang']
return providers.includes(this.provider.id)
}
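// Converts an app Message into an OpenAI ChatCompletionMessageParam:
// plain text when there are no attachments, inlined file text for providers
// without file support, otherwise multimodal parts (text + image_url).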
private async getMessageParam(
message: Message,
model: Model
): Promise<OpenAI.Chat.Completions.ChatCompletionMessageParam> {
const isVision = isVisionModel(model)
const content = await this.getMessageContent(message)
if (!message.files) {
return {
role: message.role,
content
}
}
if (this.isNotSupportFiles) {
if (message.files) {
const textFiles = message.files.filter((file) => [FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type))
if (textFiles.length > 0) {
let text = ''
const divider = '\n\n---\n\n'
for (const file of textFiles) {
const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
const fileNameRow = 'file: ' + file.origin_name + '\n\n'
text = text + fileNameRow + fileContent + divider
}
return {
role: message.role,
content: content + divider + text
}
}
}
return {
role: message.role,
content
}
}
const parts: ChatCompletionContentPart[] = [
{
type: 'text',
text: content
}
]
for (const file of message.files || []) {
if (file.type === FileTypes.IMAGE && isVision) {
const image = await window.api.file.base64Image(file.id + file.ext)
parts.push({
type: 'image_url',
image_url: { url: image.data }
})
}
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
parts.push({
type: 'text',
text: file.origin_name + '\n' + fileContent
})
}
}
return {
role: message.role,
content: parts
} as ChatCompletionMessageParam
}
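// Temperature is omitted for reasoning models; otherwise the assistant's setting is used.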
private getTemperature(assistant: Assistant, model: Model) {
if (isReasoningModel(model)) return undefined
return assistant?.settings?.temperature
}
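// Provider/model-specific request tweaks: OpenRouter deepseek-r1 models get
// include_reasoning, and OpenAI reasoning models use max_completion_tokens
// instead of max_tokens.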
private getProviderSpecificParameters(assistant: Assistant, model: Model) {
const { maxTokens } = getAssistantSettings(assistant)
if (this.provider.id === 'openrouter') {
if (model.id.includes('deepseek-r1')) {
return {
include_reasoning: true
}
}
}
if (this.isOpenAIReasoning(model)) {
return {
max_tokens: undefined,
max_completion_tokens: maxTokens
}
}
return {}
}
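// top_p is omitted for reasoning models; otherwise the assistant's setting is used.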
private getTopP(assistant: Assistant, model: Model) {
if (isReasoningModel(model)) return undefined
return assistant?.settings?.topP
}
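// Maps the assistant's reasoning_effort setting onto provider-specific
// parameters: OpenRouter uses `reasoning.effort`, OpenAI o-series models use
// `reasoning_effort`, and Claude 3.7 Sonnet gets a thinking budget of
// effort-ratio * maxTokens clamped to [1024, 32000]. Groq gets nothing.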
private getReasoningEffort(assistant: Assistant, model: Model) {
if (this.provider.id === 'groq') {
return {}
}
if (isReasoningModel(model)) {
if (model.provider === 'openrouter') {
return {
reasoning: {
effort: assistant?.settings?.reasoning_effort
}
}
}
if (isOpenAIoSeries(model)) {
return {
reasoning_effort: assistant?.settings?.reasoning_effort
}
}
if (model.id.includes('claude-3.7-sonnet') || model.id.includes('claude-3-7-sonnet')) {
const effortRatios: Record<ReasoningEffort, number> = {
high: 0.8,
medium: 0.5,
low: 0.2
}
const effort = assistant?.settings?.reasoning_effort as ReasoningEffort
const effortRatio = effortRatios[effort]
if (!effortRatio) {
return {}
}
const maxTokens = assistant?.settings?.maxTokens || DEFAULT_MAX_TOKENS
const budgetTokens = Math.trunc(Math.max(Math.min(maxTokens * effortRatio, 32000), 1024))
return {
thinking: {
type: 'enabled',
budget_tokens: budgetTokens
}
}
}
return {}
}
return {}
}
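// OpenAI reasoning models (o1/o3 series); these disable streaming and use
// max_completion_tokens elsewhere in this class.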
private isOpenAIReasoning(model: Model) {
return model.id.startsWith('o1') || model.id.startsWith('o3')
}
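// Main chat entry point: builds the system/developer message, trims and
// filters context, converts messages to OpenAI params, then streams the
// response through onChunk. MCP tool calls returned by the model are
// executed and the conversation is continued recursively via processStream.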
async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams): Promise<void> {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
let systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
if (isOpenAIoSeries(model)) {
systemMessage = {
role: 'developer',
content: `Formatting re-enabled${systemMessage ? '\n' + systemMessage.content : ''}`
}
}
const userMessages: ChatCompletionMessageParam[] = []
const _messages = filterUserRoleStartMessages(
filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 1)))
)
onFilterMessages(_messages)
for (const message of _messages) {
userMessages.push(await this.getMessageParam(message, model))
}
const isOpenAIReasoning = this.isOpenAIReasoning(model)
const isSupportStreamOutput = () => {
if (isOpenAIReasoning) {
return false
}
return streamOutput
}
let hasReasoningContent = false
let lastChunk = ''
const isReasoningJustDone = (
delta: OpenAI.Chat.Completions.ChatCompletionChunk.Choice.Delta & {
reasoning_content?: string
reasoning?: string
thinking?: string
}
) => {
if (!delta?.content) return false
// Check whether the previous chunk combined with the current one forms the ###Response marker
const combinedChunks = lastChunk + delta.content
lastChunk = delta.content
// Detect the end of the thinking phase
if (combinedChunks.includes('###Response') || delta.content === '</think>') {
return true
}
// reasoning_content / reasoning / thinking in the delta means the model is still thinking
if (delta?.reasoning_content || delta?.reasoning || delta?.thinking) {
hasReasoningContent = true
}
// If reasoning content was seen earlier and plain content arrives now, thinking has finished
if (hasReasoningContent && delta.content) {
return true
}
return false
}
let time_first_token_millsec = 0
let time_first_content_millsec = 0
const start_time_millsec = new Date().getTime()
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
const { signal } = abortController
mcpTools = filterMCPTools(mcpTools, lastUserMessage?.enabledMCPs)
const tools = mcpTools && mcpTools.length > 0 ? mcpToolsToOpenAITools(mcpTools) : undefined
const reqMessages: ChatCompletionMessageParam[] = [systemMessage, ...userMessages].filter(
Boolean
) as ChatCompletionMessageParam[]
const toolResponses: MCPToolResponse[] = []
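// Processes one completion response: non-streaming responses are emitted as a
// single chunk; streamed responses are iterated chunk by chunk, tracking
// first-token / thinking timings, accumulating tool-call fragments and, once
// finish_reason is 'tool_calls', invoking the MCP tools and recursing.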
const processStream = async (stream: any, idx: number) => {
if (!isSupportStreamOutput()) {
const time_completion_millsec = new Date().getTime() - start_time_millsec
return onChunk({
text: stream.choices[0].message?.content || '',
usage: stream.usage,
metrics: {
completion_tokens: stream.usage?.completion_tokens,
time_completion_millsec,
time_first_token_millsec: 0
}
})
}
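// Tool-call fragments arrive spread across chunks; accumulate them by index.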
const final_tool_calls = {} as Record<number, ChatCompletionMessageToolCall>
for await (const chunk of stream) {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
break
}
const delta = chunk.choices[0]?.delta
if (delta?.reasoning_content || delta?.reasoning) {
hasReasoningContent = true
}
if (time_first_token_millsec == 0) {
time_first_token_millsec = new Date().getTime() - start_time_millsec
}
if (time_first_content_millsec == 0 && isReasoningJustDone(delta)) {
time_first_content_millsec = new Date().getTime()
}
const time_completion_millsec = new Date().getTime() - start_time_millsec
const time_thinking_millsec = time_first_content_millsec ? time_first_content_millsec - start_time_millsec : 0
// Extract citations from the raw response if available
const citations = (chunk as OpenAI.Chat.Completions.ChatCompletionChunk & { citations?: string[] })?.citations
const finishReason = chunk.choices[0]?.finish_reason
if (delta?.tool_calls) {
const chunkToolCalls = delta.tool_calls
for (const t of chunkToolCalls) {
const { index, id, function: fn, type } = t
const args = fn && typeof fn.arguments === 'string' ? fn.arguments : ''
if (!(index in final_tool_calls)) {
final_tool_calls[index] = {
id,
function: {
name: fn?.name,
arguments: args
},
type
} as ChatCompletionMessageToolCall
} else {
final_tool_calls[index].function.arguments += args
}
}
if (finishReason !== 'tool_calls') {
continue
}
}
if (finishReason === 'tool_calls') {
const toolCalls = Object.values(final_tool_calls)
reqMessages.push({
role: 'assistant',
tool_calls: toolCalls
} as ChatCompletionAssistantMessageParam)
for (const toolCall of toolCalls) {
const mcpTool = openAIToolsToMcpTool(mcpTools, toolCall)
if (!mcpTool) {
continue
}
upsertMCPToolResponse(toolResponses, { tool: mcpTool, status: 'invoking', id: toolCall.id }, onChunk)
const toolCallResponse = await callMCPTool(mcpTool)
reqMessages.push({
role: 'tool',
content: isString(toolCallResponse.content)
? toolCallResponse.content
: JSON.stringify(toolCallResponse.content),
tool_call_id: toolCall.id
} as ChatCompletionToolMessageParam)
upsertMCPToolResponse(
toolResponses,
{ tool: mcpTool, status: 'done', response: toolCallResponse, id: toolCall.id },
onChunk
)
}
const newStream = await this.sdk.chat.completions
// @ts-ignore key is not typed
.create(
{
model: model.id,
messages: reqMessages,
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
max_tokens: maxTokens,
keep_alive: this.keepAliveTime,
stream: isSupportStreamOutput(),
tools: tools,
...getOpenAIWebSearchParams(assistant, model),
...this.getReasoningEffort(assistant, model),
...this.getProviderSpecificParameters(assistant, model),
...this.getCustomParameters(assistant)
},
{
signal
}
)
await processStream(newStream, idx + 1)
}
onChunk({
text: delta?.content || '',
reasoning_content: delta?.reasoning_content || delta?.reasoning || '',
usage: chunk.usage,
metrics: {
completion_tokens: chunk.usage?.completion_tokens,
time_completion_millsec,
time_first_token_millsec,
time_thinking_millsec
},
citations,
mcpToolResponse: toolResponses
})
}
}
const stream = await this.sdk.chat.completions
// @ts-ignore key is not typed
.create(
{
model: model.id,
messages: reqMessages,
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
max_tokens: maxTokens,
keep_alive: this.keepAliveTime,
stream: isSupportStreamOutput(),
tools: tools,
...getOpenAIWebSearchParams(assistant, model),
...this.getReasoningEffort(assistant, model),
...this.getProviderSpecificParameters(assistant, model),
...this.getCustomParameters(assistant)
},
{
signal
}
)
await processStream(stream, 0).finally(cleanup)
}
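// Translates message.content using the assistant prompt as the system message.
// Streams partial text through onResponse when possible and, for reasoning
// models, skips any text emitted inside <think>...</think> tags.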
async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const messages = message.content
? [
{ role: 'system', content: assistant.prompt },
{ role: 'user', content: message.content }
]
: [{ role: 'user', content: assistant.prompt }]
const isOpenAIReasoning = this.isOpenAIReasoning(model)
const isSupportedStreamOutput = () => {
if (!onResponse) {
return false
}
if (isOpenAIReasoning) {
return false
}
return true
}
const stream = isSupportedStreamOutput()
// @ts-ignore key is not typed
const response = await this.sdk.chat.completions.create({
model: model.id,
messages: messages as ChatCompletionMessageParam[],
stream,
keep_alive: this.keepAliveTime,
temperature: assistant?.settings?.temperature
})
if (!stream) {
return response.choices[0].message?.content || ''
}
let text = ''
let isThinking = false
const isReasoning = isReasoningModel(model)
for await (const chunk of response) {
const deltaContent = chunk.choices[0]?.delta?.content || ''
if (!deltaContent.trim()) {
continue
}
if (isReasoning) {
if (deltaContent.includes('<think>')) {
isThinking = true
}
if (!isThinking) {
text += deltaContent
onResponse?.(text)
}
if (deltaContent.includes('</think>')) {
isThinking = false
}
} else {
text += deltaContent
onResponse?.(text)
}
}
return text
}
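// Generates a short topic title from the last five non-preset messages using
// the topic-naming model (falling back to the assistant/default model).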
public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
const model = getTopNamingModel() || assistant.model || getDefaultModel()
const userMessages = takeRight(messages, 5)
.filter((message) => !message.isPreset)
.map((message) => ({
role: message.role,
content: message.content
}))
const userMessageContent = userMessages.reduce((prev, curr) => {
const content = curr.role === 'user' ? `User: ${curr.content}` : `Assistant: ${curr.content}`
return prev + (prev ? '\n' : '') + content
}, '')
const systemMessage = {
role: 'system',
content: getStoreSetting('topicNamingPrompt') || i18n.t('prompts.title')
}
const userMessage = {
role: 'user',
content: userMessageContent
}
// @ts-ignore key is not typed
const response = await this.sdk.chat.completions.create({
model: model.id,
messages: [systemMessage, userMessage] as ChatCompletionMessageParam[],
stream: false,
keep_alive: this.keepAliveTime,
max_tokens: 1000
})
// For reasoning models, keep only the part of the reply after the </think> block when summarizing
let content = response.choices[0].message?.content || ''
content = content.replace(/^<think>(.*?)<\/think>/s, '')
return removeSpecialCharactersForTopicName(content.substring(0, 50))
}
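// Simple one-shot completion with a system prompt and user content.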
public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
const model = getDefaultModel()
const response = await this.sdk.chat.completions.create({
model: model.id,
stream: false,
messages: [
{ role: 'system', content: prompt },
{ role: 'user', content }
]
})
return response.choices[0].message?.content || ''
}
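// Fetches follow-up question suggestions from the provider's
// /advice_questions endpoint based on the user messages so far.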
async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
const model = assistant.model
if (!model) {
return []
}
const response: any = await this.sdk.request({
method: 'post',
path: '/advice_questions',
body: {
messages: messages.filter((m) => m.role === 'user').map((m) => ({ role: m.role, content: m.content })),
model: model.id,
max_tokens: 0,
temperature: 0,
n: 0
}
})
return response?.questions?.filter(Boolean)?.map((q: any) => ({ content: q })) || []
}
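// Verifies the API key and model by sending a minimal non-streaming request.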
public async check(model: Model): Promise<{ valid: boolean; error: Error | null }> {
if (!model) {
return { valid: false, error: new Error('No model found') }
}
const body = {
model: model.id,
messages: [{ role: 'user', content: 'hi' }],
stream: false
}
try {
const response = await this.sdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming)
return {
valid: Boolean(response?.choices[0].message),
error: null
}
} catch (error: any) {
return {
valid: false,
error
}
}
}
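// Lists available models, normalizing GitHub Models and Together responses to
// the OpenAI model shape and filtering out unsupported models.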
public async models(): Promise<OpenAI.Models.Model[]> {
try {
const response = await this.sdk.models.list()
if (this.provider.id === 'github') {
// @ts-ignore key is not typed
return response.body
.map((model) => ({
id: model.name,
description: model.summary,
object: 'model',
owned_by: model.publisher
}))
.filter(isSupportedModel)
}
if (this.provider.id === 'together') {
// @ts-ignore key is not typed
return response?.body
.map((model: any) => ({
id: model.id,
description: model.display_name,
object: 'model',
owned_by: model.organization
}))
.filter(isSupportedModel)
}
const models = response?.data || []
return models.filter(isSupportedModel)
} catch (error) {
return []
}
}
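// Generates images via the provider's /images/generations endpoint and
// returns the resulting image URLs.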
public async generateImage({
model,
prompt,
negativePrompt,
imageSize,
batchSize,
seed,
numInferenceSteps,
guidanceScale,
signal,
promptEnhancement
}: GenerateImageParams): Promise<string[]> {
const response = (await this.sdk.request({
method: 'post',
path: '/images/generations',
signal,
body: {
model,
prompt,
negative_prompt: negativePrompt,
image_size: imageSize,
batch_size: batchSize,
seed: seed ? parseInt(seed) : undefined,
num_inference_steps: numInferenceSteps,
guidance_scale: guidanceScale,
prompt_enhancement: promptEnhancement
}
})) as { data: Array<{ url: string }> }
return response.data.map((item) => item.url)
}
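// Probes the embedding endpoint with a short input and returns the vector length.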
public async getEmbeddingDimensions(model: Model): Promise<number> {
const data = await this.sdk.embeddings.create({
model: model.id,
input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi'
})
return data.data[0].embedding.length
}
}