feat: support gpt-4o image generation (#4054)

* feat: support gpt-4o image generation

* clean code
This commit is contained in:
Chen Tao 2025-03-29 07:18:42 +08:00 committed by GitHub
parent 53ae427f2f
commit 194ba1baa0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 59 additions and 2 deletions

View File

@ -731,6 +731,7 @@ export default class GeminiProvider extends BaseProvider {
onChunk({ onChunk({
text, text,
generateImage: { generateImage: {
type: 'base64',
images images
}, },
usage: { usage: {

View File

@ -27,6 +27,7 @@ import {
Suggestion Suggestion
} from '@renderer/types' } from '@renderer/types'
import { removeSpecialCharactersForTopicName } from '@renderer/utils' import { removeSpecialCharactersForTopicName } from '@renderer/utils'
import { addImageFileToContents } from '@renderer/utils/formats'
import { import {
callMCPTool, callMCPTool,
mcpToolsToOpenAITools, mcpToolsToOpenAITools,
@ -354,7 +355,7 @@ export default class OpenAIProvider extends BaseProvider {
const defaultModel = getDefaultModel() const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel const model = assistant.model || defaultModel
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
messages = addImageFileToContents(messages)
let systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined let systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
if (isOpenAIoSeries(model)) { if (isOpenAIoSeries(model)) {

View File

@ -5,6 +5,7 @@ import store from '@renderer/store'
import { setGenerating } from '@renderer/store/runtime' import { setGenerating } from '@renderer/store/runtime'
import { Assistant, MCPTool, Message, Model, Provider, Suggestion } from '@renderer/types' import { Assistant, MCPTool, Message, Model, Provider, Suggestion } from '@renderer/types'
import { formatMessageError, isAbortError } from '@renderer/utils/error' import { formatMessageError, isAbortError } from '@renderer/utils/error'
import { withGenerateImage } from '@renderer/utils/formats'
import { cloneDeep, findLast, isEmpty } from 'lodash' import { cloneDeep, findLast, isEmpty } from 'lodash'
import AiProvider from '../providers/AiProvider' import AiProvider from '../providers/AiProvider'
@ -156,6 +157,7 @@ export async function fetchChatCompletion({
}) })
message.status = 'success' message.status = 'success'
message = withGenerateImage(message)
if (!message.usage || !message?.usage?.completion_tokens) { if (!message.usage || !message?.usage?.completion_tokens) {
message.usage = await estimateMessagesUsage({ message.usage = await estimateMessagesUsage({
@ -191,7 +193,6 @@ export async function fetchChatCompletion({
// Reset generating state // Reset generating state
store.dispatch(setGenerating(false)) store.dispatch(setGenerating(false))
return message return message
} }

View File

@ -308,6 +308,7 @@ export type GenerateImageParams = {
} }
export type GenerateImageResponse = { export type GenerateImageResponse = {
type: 'url' | 'base64'
images: string[] images: string[]
} }

View File

@ -178,3 +178,56 @@ export function withMessageThought(message: Message) {
return message return message
} }
/**
 * Extracts a generated image from a markdown-formatted assistant message.
 *
 * If `message.content` contains a markdown image (`![alt](url "title")`),
 * the image markup — and an optional trailing download link like
 * `[Download](url)` — is stripped from the content, and the image URL is
 * moved into `metadata.generateImage` so the renderer can display it.
 *
 * @param message the assistant message to inspect
 * @returns the original message when no image is present, otherwise a
 *          shallow copy with cleaned content and populated metadata
 */
export function withGenerateImage(message: Message) {
  // Markdown image: ![alt](url "optional title")
  const imagePattern = /!\[[^\]]*\]\((.*?)\s*("(?:.*[^"])")?\s*\)/
  const imageMatches = message.content.match(imagePattern)
  // An unmatched capture group is `undefined` (never `null`), and an empty
  // `![]()` captures '' — the falsy check handles both, so we never record
  // a missing/empty URL as a generated image.
  if (!imageMatches || !imageMatches[1]) {
    return message
  }
  // Remove the image markup and collapse the blank lines it leaves behind.
  const cleanImgContent = message.content
    .replace(imagePattern, '')
    .replace(/\n\s*\n/g, '\n')
    .trim()
  // Also strip a trailing plain markdown link, e.g. [Download](url),
  // which gpt-4o emits alongside the generated image.
  const downloadPattern = /\[[^\]]*\]\((.*?)\s*("(?:.*[^"])")?\s*\)/
  const downloadMatches = cleanImgContent.match(downloadPattern)
  let cleanContent = cleanImgContent
  if (downloadMatches) {
    cleanContent = cleanImgContent
      .replace(downloadPattern, '')
      .replace(/\n\s*\n/g, '\n')
      .trim()
  }
  // Return a copy: content without the markup, image URL in metadata.
  return {
    ...message,
    content: cleanContent,
    metadata: {
      ...message.metadata,
      generateImage: {
        type: 'url',
        images: [imageMatches[1]]
      }
    }
  }
}
/**
 * Copies generated-image data from the most recent assistant message's
 * `metadata.generateImage` onto that message's `images` field, so the
 * image is included when the conversation is sent back to the provider.
 *
 * @param messages the conversation history
 * @returns the original array when there is nothing to attach, otherwise
 *          a new array with only the last assistant message updated
 */
export function addImageFileToContents(messages: Message[]) {
  const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant')
  // Nothing to do if there is no assistant message or it produced no image.
  if (!lastAssistantMessage || !lastAssistantMessage.metadata || !lastAssistantMessage.metadata.generateImage) {
    return messages
  }
  const imageFiles = lastAssistantMessage.metadata.generateImage.images
  const updatedAssistantMessage = {
    ...lastAssistantMessage,
    images: imageFiles
  }
  // Replace only the matched message by identity. Matching on
  // `role === 'assistant'` would overwrite EVERY assistant turn in the
  // history with the last one, corrupting earlier messages.
  return messages.map((message) => (message === lastAssistantMessage ? updatedAssistantMessage : message))
}