diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts index 5ce2f28b..00129350 100644 --- a/src/renderer/src/providers/OpenAIProvider.ts +++ b/src/renderer/src/providers/OpenAIProvider.ts @@ -73,11 +73,47 @@ export default class OpenAIProvider extends BaseProvider { }) } + /** + * Check if the provider does not support files + * @returns True if the provider does not support files, false otherwise + */ private get isNotSupportFiles() { - const providers = ['deepseek', 'baichuan', 'minimax', 'doubao', 'xirang'] + const providers = ['deepseek', 'baichuan', 'minimax', 'xirang'] return providers.includes(this.provider.id) } + /** + * Extract the file content from the message + * @param message - The message + * @returns The file content + */ + private async extractFileContent(message: Message) { + if (message.files) { + const textFiles = message.files.filter((file) => [FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) + + if (textFiles.length > 0) { + let text = '' + const divider = '\n\n---\n\n' + + for (const file of textFiles) { + const fileContent = (await window.api.file.read(file.id + file.ext)).trim() + const fileNameRow = 'file: ' + file.origin_name + '\n\n' + text = text + fileNameRow + fileContent + divider + } + + return text + } + } + + return '' + } + + /** + * Get the message parameter + * @param message - The message + * @param model - The model + * @returns The message parameter + */ private async getMessageParam( message: Message, model: Model @@ -85,6 +121,7 @@ export default class OpenAIProvider extends BaseProvider { const isVision = isVisionModel(model) const content = await this.getMessageContent(message) + // If the message does not have files, return the message if (!message.files) { return { role: message.role, @@ -92,39 +129,22 @@ export default class OpenAIProvider extends BaseProvider { } } + // If the model does not support files, extract the file content if (this.isNotSupportFiles) { - if (message.files) { - const textFiles = message.files.filter((file) => [FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) - - if (textFiles.length > 0) { - let text = '' - const divider = '\n\n---\n\n' - - for (const file of textFiles) { - const fileContent = (await window.api.file.read(file.id + file.ext)).trim() - const fileNameRow = 'file: ' + file.origin_name + '\n\n' - text = text + fileNameRow + fileContent + divider - } - - return { - role: message.role, - content: content + divider + text - } - } - } + const fileContent = await this.extractFileContent(message) return { role: message.role, - content + content: content + '\n\n---\n\n' + fileContent } } - const parts: ChatCompletionContentPart[] = [ - { - type: 'text', - text: content - } - ] + // If the model supports files, add the file content to the message + const parts: ChatCompletionContentPart[] = [] + + if (content) { + parts.push({ type: 'text', text: content }) + } for (const file of message.files || []) { if (file.type === FileTypes.IMAGE && isVision) { @@ -149,12 +169,22 @@ export default class OpenAIProvider extends BaseProvider { } as ChatCompletionMessageParam } + /** + * Get the temperature for the assistant + * @param assistant - The assistant + * @param model - The model + * @returns The temperature + */ private getTemperature(assistant: Assistant, model: Model) { - if (isReasoningModel(model)) return undefined - - return assistant?.settings?.temperature + return isReasoningModel(model) ? undefined : assistant?.settings?.temperature } + /** + * Get the provider specific parameters for the assistant + * @param assistant - The assistant + * @param model - The model + * @returns The provider specific parameters + */ private getProviderSpecificParameters(assistant: Assistant, model: Model) { const { maxTokens } = getAssistantSettings(assistant) @@ -176,12 +206,24 @@ export default class OpenAIProvider extends BaseProvider { return {} } + /** + * Get the top P for the assistant + * @param assistant - The assistant + * @param model - The model + * @returns The top P + */ private getTopP(assistant: Assistant, model: Model) { if (isReasoningModel(model)) return undefined return assistant?.settings?.topP } + /** + * Get the reasoning effort for the assistant + * @param assistant - The assistant + * @param model - The model + * @returns The reasoning effort + */ private getReasoningEffort(assistant: Assistant, model: Model) { if (this.provider.id === 'groq') { return {} @@ -233,10 +275,24 @@ export default class OpenAIProvider extends BaseProvider { return {} } + /** + * Check if the model is an OpenAI reasoning model + * @param model - The model + * @returns True if the model is an OpenAI reasoning model, false otherwise + */ private isOpenAIReasoning(model: Model) { return model.id.startsWith('o1') || model.id.startsWith('o3') } + /** + * Generate completions for the assistant + * @param messages - The messages + * @param assistant - The assistant + * @param onChunk - The onChunk callback + * @param onFilterMessages - The onFilterMessages callback + * @param mcpTools - The MCP tools + * @returns The completions + */ async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams): Promise { const defaultModel = getDefaultModel() const model = assistant.model || defaultModel @@ -482,6 +538,13 @@ export default class OpenAIProvider extends BaseProvider { await processStream(stream, 0).finally(cleanup) } + /** + * Translate a message + * @param message - The message + * @param assistant - The assistant + * @param onResponse - The onResponse callback + * @returns The translated message + */ async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) { const defaultModel = getDefaultModel() const model = assistant.model || defaultModel @@ -552,6 +615,12 @@ export default class OpenAIProvider extends BaseProvider { return text } + /** + * Summarize a message + * @param messages - The messages + * @param assistant - The assistant + * @returns The summary + */ public async summaries(messages: Message[], assistant: Assistant): Promise { const model = getTopNamingModel() || assistant.model || getDefaultModel() @@ -593,6 +662,12 @@ export default class OpenAIProvider extends BaseProvider { return removeSpecialCharactersForTopicName(content.substring(0, 50)) } + /** + * Generate text + * @param prompt - The prompt + * @param content - The content + * @returns The generated text + */ public async generateText({ prompt, content }: { prompt: string; content: string }): Promise { const model = getDefaultModel() @@ -608,6 +683,12 @@ export default class OpenAIProvider extends BaseProvider { return response.choices[0].message?.content || '' } + /** + * Generate suggestions + * @param messages - The messages + * @param assistant - The assistant + * @returns The suggestions + */ async suggestions(messages: Message[], assistant: Assistant): Promise { const model = assistant.model @@ -630,6 +711,11 @@ export default class OpenAIProvider extends BaseProvider { return response?.questions?.filter(Boolean)?.map((q: any) => ({ content: q })) || [] } + /** + * Check if the model is valid + * @param model - The model + * @returns The validity of the model + */ public async check(model: Model): Promise<{ valid: boolean; error: Error | null }> { if (!model) { return { valid: false, error: new Error('No model found') } @@ -656,6 +742,10 @@ export default class OpenAIProvider extends BaseProvider { } } + /** + * Get the models + * @returns The models + */ public async models(): Promise { try { const response = await this.sdk.models.list() @@ -692,6 +782,11 @@ export default class OpenAIProvider extends BaseProvider { } } + /** + * Generate an image + * @param params - The parameters + * @returns The generated image + */ public async generateImage({ model, prompt, @@ -724,6 +819,11 @@ export default class OpenAIProvider extends BaseProvider { return response.data.map((item) => item.url) } + /** + * Get the embedding dimensions + * @param model - The model + * @returns The embedding dimensions + */ public async getEmbeddingDimensions(model: Model): Promise { const data = await this.sdk.embeddings.create({ model: model.id,