feat(OpenAIProvider): Add file content extraction and enhance message handling

- Implemented a method to extract file content from messages, supporting text and document file types.
- Updated message parameter handling to inline extracted file content when the model does not support file attachments (see the sketch after this list).
- Added detailed JSDoc comments to new methods and existing functionality for better documentation.
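
For context, here is a minimal TypeScript sketch of the fallback behavior this commit introduces. The Message and FileTypes shapes and the readFile helper below are simplified stand-ins for the app's real types and window.api.file.read, not the actual implementation:

// Simplified stand-ins for the real types (illustrative only)
enum FileTypes {
  TEXT = 'text',
  DOCUMENT = 'document',
  IMAGE = 'image'
}

interface FileEntry {
  id: string
  ext: string
  origin_name: string
  type: FileTypes
}

interface Message {
  role: string
  content: string
  files?: FileEntry[]
}

// Hypothetical reader standing in for window.api.file.read
declare function readFile(name: string): Promise<string>

// Mirrors extractFileContent: each text/document file becomes a
// "file: <name>" block followed by its trimmed content, with blocks
// separated by "\n\n---\n\n"
async function extractFileContent(message: Message): Promise<string> {
  const textFiles = (message.files ?? []).filter((file) =>
    [FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)
  )
  const divider = '\n\n---\n\n'
  let text = ''
  for (const file of textFiles) {
    const fileContent = (await readFile(file.id + file.ext)).trim()
    text += 'file: ' + file.origin_name + '\n\n' + fileContent + divider
  }
  return text
}

// For providers that cannot accept file attachments, the message is
// flattened to plain text: user content, a divider, then the inlined files
async function toPlainTextParam(message: Message) {
  const fileContent = await extractFileContent(message)
  return {
    role: message.role,
    content: message.content + '\n\n---\n\n' + fileContent
  }
}

The trailing divider after each file keeps file boundaries unambiguous when several files are concatenated into one prompt.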
commit 18b7618a8d
parent 008bb33013
Author: kangfenmao
Date:   2025-03-14 10:29:59 +08:00

@@ -73,11 +73,47 @@ export default class OpenAIProvider extends BaseProvider {
     })
   }

+  /**
+   * Check if the provider does not support files
+   * @returns True if the provider does not support files, false otherwise
+   */
   private get isNotSupportFiles() {
-    const providers = ['deepseek', 'baichuan', 'minimax', 'doubao', 'xirang']
+    const providers = ['deepseek', 'baichuan', 'minimax', 'xirang']
     return providers.includes(this.provider.id)
   }

+  /**
+   * Extract the file content from the message
+   * @param message - The message
+   * @returns The file content
+   */
+  private async extractFileContent(message: Message) {
+    if (message.files) {
+      const textFiles = message.files.filter((file) => [FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type))
+      if (textFiles.length > 0) {
+        let text = ''
+        const divider = '\n\n---\n\n'
+        for (const file of textFiles) {
+          const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
+          const fileNameRow = 'file: ' + file.origin_name + '\n\n'
+          text = text + fileNameRow + fileContent + divider
+        }
+        return text
+      }
+    }
+
+    return ''
+  }
+
+  /**
+   * Get the message parameter
+   * @param message - The message
+   * @param model - The model
+   * @returns The message parameter
+   */
   private async getMessageParam(
     message: Message,
     model: Model
@@ -85,6 +121,7 @@ export default class OpenAIProvider extends BaseProvider {
     const isVision = isVisionModel(model)
     const content = await this.getMessageContent(message)

+    // If the message does not have files, return the message
     if (!message.files) {
       return {
         role: message.role,
@@ -92,39 +129,22 @@ export default class OpenAIProvider extends BaseProvider {
       }
     }

+    // If the model does not support files, extract the file content
     if (this.isNotSupportFiles) {
-      if (message.files) {
-        const textFiles = message.files.filter((file) => [FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type))
-        if (textFiles.length > 0) {
-          let text = ''
-          const divider = '\n\n---\n\n'
-          for (const file of textFiles) {
-            const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
-            const fileNameRow = 'file: ' + file.origin_name + '\n\n'
-            text = text + fileNameRow + fileContent + divider
-          }
-          return {
-            role: message.role,
-            content: content + divider + text
-          }
-        }
-      }
+      const fileContent = await this.extractFileContent(message)
       return {
         role: message.role,
-        content
+        content: content + '\n\n---\n\n' + fileContent
       }
     }

-    const parts: ChatCompletionContentPart[] = [
-      {
-        type: 'text',
-        text: content
-      }
-    ]
+    // If the model supports files, add the file content to the message
+    const parts: ChatCompletionContentPart[] = []
+
+    if (content) {
+      parts.push({ type: 'text', text: content })
+    }

     for (const file of message.files || []) {
       if (file.type === FileTypes.IMAGE && isVision) {
@@ -149,12 +169,22 @@ export default class OpenAIProvider extends BaseProvider {
     } as ChatCompletionMessageParam
   }

+  /**
+   * Get the temperature for the assistant
+   * @param assistant - The assistant
+   * @param model - The model
+   * @returns The temperature
+   */
   private getTemperature(assistant: Assistant, model: Model) {
-    if (isReasoningModel(model)) return undefined
-    return assistant?.settings?.temperature
+    return isReasoningModel(model) ? undefined : assistant?.settings?.temperature
   }

+  /**
+   * Get the provider specific parameters for the assistant
+   * @param assistant - The assistant
+   * @param model - The model
+   * @returns The provider specific parameters
+   */
   private getProviderSpecificParameters(assistant: Assistant, model: Model) {
     const { maxTokens } = getAssistantSettings(assistant)
@@ -176,12 +206,24 @@ export default class OpenAIProvider extends BaseProvider {
     return {}
   }

+  /**
+   * Get the top P for the assistant
+   * @param assistant - The assistant
+   * @param model - The model
+   * @returns The top P
+   */
   private getTopP(assistant: Assistant, model: Model) {
     if (isReasoningModel(model)) return undefined
     return assistant?.settings?.topP
   }

+  /**
+   * Get the reasoning effort for the assistant
+   * @param assistant - The assistant
+   * @param model - The model
+   * @returns The reasoning effort
+   */
   private getReasoningEffort(assistant: Assistant, model: Model) {
     if (this.provider.id === 'groq') {
       return {}
@@ -233,10 +275,24 @@ export default class OpenAIProvider extends BaseProvider {
     return {}
   }

+  /**
+   * Check if the model is an OpenAI reasoning model
+   * @param model - The model
+   * @returns True if the model is an OpenAI reasoning model, false otherwise
+   */
   private isOpenAIReasoning(model: Model) {
     return model.id.startsWith('o1') || model.id.startsWith('o3')
   }

+  /**
+   * Generate completions for the assistant
+   * @param messages - The messages
+   * @param assistant - The assistant
+   * @param onChunk - The onChunk callback
+   * @param onFilterMessages - The onFilterMessages callback
+   * @param mcpTools - The MCP tools
+   * @returns The completions
+   */
   async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams): Promise<void> {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
@@ -482,6 +538,13 @@ export default class OpenAIProvider extends BaseProvider {
     await processStream(stream, 0).finally(cleanup)
   }

+  /**
+   * Translate a message
+   * @param message - The message
+   * @param assistant - The assistant
+   * @param onResponse - The onResponse callback
+   * @returns The translated message
+   */
   async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
@@ -552,6 +615,12 @@ export default class OpenAIProvider extends BaseProvider {
     return text
   }

+  /**
+   * Summarize a message
+   * @param messages - The messages
+   * @param assistant - The assistant
+   * @returns The summary
+   */
   public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
     const model = getTopNamingModel() || assistant.model || getDefaultModel()
@@ -593,6 +662,12 @@ export default class OpenAIProvider extends BaseProvider {
     return removeSpecialCharactersForTopicName(content.substring(0, 50))
   }

+  /**
+   * Generate text
+   * @param prompt - The prompt
+   * @param content - The content
+   * @returns The generated text
+   */
   public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
     const model = getDefaultModel()
@@ -608,6 +683,12 @@ export default class OpenAIProvider extends BaseProvider {
     return response.choices[0].message?.content || ''
   }

+  /**
+   * Generate suggestions
+   * @param messages - The messages
+   * @param assistant - The assistant
+   * @returns The suggestions
+   */
   async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
     const model = assistant.model
@@ -630,6 +711,11 @@ export default class OpenAIProvider extends BaseProvider {
     return response?.questions?.filter(Boolean)?.map((q: any) => ({ content: q })) || []
   }

+  /**
+   * Check if the model is valid
+   * @param model - The model
+   * @returns The validity of the model
+   */
   public async check(model: Model): Promise<{ valid: boolean; error: Error | null }> {
     if (!model) {
       return { valid: false, error: new Error('No model found') }
@@ -656,6 +742,10 @@ export default class OpenAIProvider extends BaseProvider {
     }
   }

+  /**
+   * Get the models
+   * @returns The models
+   */
   public async models(): Promise<OpenAI.Models.Model[]> {
     try {
       const response = await this.sdk.models.list()
@@ -692,6 +782,11 @@ export default class OpenAIProvider extends BaseProvider {
     }
   }

+  /**
+   * Generate an image
+   * @param params - The parameters
+   * @returns The generated image
+   */
   public async generateImage({
     model,
     prompt,
@@ -724,6 +819,11 @@ export default class OpenAIProvider extends BaseProvider {
     return response.data.map((item) => item.url)
   }

+  /**
+   * Get the embedding dimensions
+   * @param model - The model
+   * @returns The embedding dimensions
+   */
   public async getEmbeddingDimensions(model: Model): Promise<number> {
     const data = await this.sdk.embeddings.create({
       model: model.id,
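
For reference, the parts-based message that the updated getMessageParam builds for models that do support files has roughly this shape. This is a hand-written example using the openai SDK's types, not output captured from the code:

import OpenAI from 'openai'

// A text part plus an image part, as built for vision-capable models
const param: OpenAI.Chat.ChatCompletionMessageParam = {
  role: 'user',
  content: [
    { type: 'text', text: 'Summarize this screenshot.' },
    { type: 'image_url', image_url: { url: 'data:image/png;base64,...' } }
  ]
}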