feat: support gpt-4o image generation (#4054)
* feat: support gpt-4o image generation * clean code
This commit is contained in:
parent
53ae427f2f
commit
194ba1baa0
@ -731,6 +731,7 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
onChunk({
|
onChunk({
|
||||||
text,
|
text,
|
||||||
generateImage: {
|
generateImage: {
|
||||||
|
type: 'base64',
|
||||||
images
|
images
|
||||||
},
|
},
|
||||||
usage: {
|
usage: {
|
||||||
|
|||||||
@ -27,6 +27,7 @@ import {
|
|||||||
Suggestion
|
Suggestion
|
||||||
} from '@renderer/types'
|
} from '@renderer/types'
|
||||||
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
|
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
|
||||||
|
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||||
import {
|
import {
|
||||||
callMCPTool,
|
callMCPTool,
|
||||||
mcpToolsToOpenAITools,
|
mcpToolsToOpenAITools,
|
||||||
@ -354,7 +355,7 @@ export default class OpenAIProvider extends BaseProvider {
|
|||||||
const defaultModel = getDefaultModel()
|
const defaultModel = getDefaultModel()
|
||||||
const model = assistant.model || defaultModel
|
const model = assistant.model || defaultModel
|
||||||
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
||||||
|
messages = addImageFileToContents(messages)
|
||||||
let systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
|
let systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
|
||||||
|
|
||||||
if (isOpenAIoSeries(model)) {
|
if (isOpenAIoSeries(model)) {
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import store from '@renderer/store'
|
|||||||
import { setGenerating } from '@renderer/store/runtime'
|
import { setGenerating } from '@renderer/store/runtime'
|
||||||
import { Assistant, MCPTool, Message, Model, Provider, Suggestion } from '@renderer/types'
|
import { Assistant, MCPTool, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||||
import { formatMessageError, isAbortError } from '@renderer/utils/error'
|
import { formatMessageError, isAbortError } from '@renderer/utils/error'
|
||||||
|
import { withGenerateImage } from '@renderer/utils/formats'
|
||||||
import { cloneDeep, findLast, isEmpty } from 'lodash'
|
import { cloneDeep, findLast, isEmpty } from 'lodash'
|
||||||
|
|
||||||
import AiProvider from '../providers/AiProvider'
|
import AiProvider from '../providers/AiProvider'
|
||||||
@ -156,6 +157,7 @@ export async function fetchChatCompletion({
|
|||||||
})
|
})
|
||||||
|
|
||||||
message.status = 'success'
|
message.status = 'success'
|
||||||
|
message = withGenerateImage(message)
|
||||||
|
|
||||||
if (!message.usage || !message?.usage?.completion_tokens) {
|
if (!message.usage || !message?.usage?.completion_tokens) {
|
||||||
message.usage = await estimateMessagesUsage({
|
message.usage = await estimateMessagesUsage({
|
||||||
@ -191,7 +193,6 @@ export async function fetchChatCompletion({
|
|||||||
|
|
||||||
// Reset generating state
|
// Reset generating state
|
||||||
store.dispatch(setGenerating(false))
|
store.dispatch(setGenerating(false))
|
||||||
|
|
||||||
return message
|
return message
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -308,6 +308,7 @@ export type GenerateImageParams = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export type GenerateImageResponse = {
|
export type GenerateImageResponse = {
|
||||||
|
type: 'url' | 'base64'
|
||||||
images: string[]
|
images: string[]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -178,3 +178,56 @@ export function withMessageThought(message: Message) {
|
|||||||
|
|
||||||
return message
|
return message
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Moves Markdown image links out of a message's text and into its metadata.
 *
 * Every `![alt](url "optional title")` occurrence in `message.content` is
 * collected into `metadata.generateImage` (with `type: 'url'`) and stripped
 * from the content. After that, the first remaining Markdown link is removed
 * as well (presumably the "download" link emitted next to the generated
 * image - confirm against real model output), and runs of blank lines are
 * collapsed.
 *
 * @param message - Message whose `content` may contain Markdown images.
 * @returns The same `message` object when no image with a non-empty URL is
 *   found; otherwise a new message with cleaned `content` and
 *   `metadata.generateImage` set. Existing metadata keys are preserved.
 */
export function withGenerateImage(message: Message): Message {
  // Global so that every image is extracted; the title is limited to a single
  // quoted run so it cannot swallow a following link on the same line.
  const imagePattern = /!\[[^\]]*\]\((.*?)\s*("[^"]*")?\s*\)/g
  // Deliberately non-global: only the first remaining link is dropped.
  const linkPattern = /\[[^\]]*\]\((.*?)\s*("[^"]*")?\s*\)/

  const imageUrls: string[] = []
  for (const match of message.content.matchAll(imagePattern)) {
    const url = match[1]
    // Skip empty captures such as `![]()` instead of recording '' as an image.
    if (url) {
      imageUrls.push(url)
    }
  }

  if (imageUrls.length === 0) {
    return message
  }

  // Squash the blank lines left behind by the removed Markdown.
  const collapseBlankLines = (text: string) => text.replace(/\n\s*\n/g, '\n').trim()

  const withoutImages = collapseBlankLines(message.content.replace(imagePattern, ''))
  const cleanContent = collapseBlankLines(withoutImages.replace(linkPattern, ''))

  return {
    ...message,
    content: cleanContent,
    metadata: {
      ...message.metadata,
      generateImage: {
        type: 'url',
        images: imageUrls
      }
    }
  }
}
|
||||||
|
|
||||||
|
/**
 * Exposes the images generated by the most recent assistant message as an
 * `images` property on that message.
 *
 * Only the last message whose `role` is `'assistant'` is considered. If it
 * carries `metadata.generateImage`, a copy of that message with
 * `images: metadata.generateImage.images` takes its place in the returned
 * list. All other messages, including earlier assistant messages, are
 * passed through untouched.
 *
 * @param messages - Conversation history in its existing order.
 * @returns The original `messages` array when there is no assistant message
 *   or the last one has no generated image; otherwise a new array of the same
 *   length with only the last assistant message replaced.
 */
export function addImageFileToContents(messages: Message[]) {
  const lastAssistantIndex = messages.map((message) => message.role).lastIndexOf('assistant')
  // An index of -1 reads `undefined`, which the optional chain turns into "no image".
  const generateImage = messages[lastAssistantIndex]?.metadata?.generateImage

  if (!generateImage) {
    return messages
  }

  // Replace by position: matching on role alone would overwrite every earlier
  // assistant turn with the last one.
  return messages.map((message, index) =>
    index === lastAssistantIndex ? { ...message, images: generateImage.images } : message
  )
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user