From 6c6af2a12bffadc191fa21f4de784bbc90af432d Mon Sep 17 00:00:00 2001 From: Chen Tao <70054568+eeee0717@users.noreply.github.com> Date: Sat, 22 Mar 2025 21:50:45 +0800 Subject: [PATCH] feat(provider): gemini-2.0-flash-exp image (#3421) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: finish basic gemini-2.0-flash-exp generate image * feat: support edit image * chore * fix: package https-proxy-agent-v5 version * feat: throw finish message and add history messages * feat: update generate image models * chore --------- Co-authored-by: 亢奋猫 --- package.json | 1 + src/renderer/src/config/models.ts | 24 + src/renderer/src/i18n/locales/en-us.json | 13 +- src/renderer/src/i18n/locales/ja-jp.json | 13 +- src/renderer/src/i18n/locales/ru-ru.json | 13 +- src/renderer/src/i18n/locales/zh-cn.json | 4 +- src/renderer/src/i18n/locales/zh-tw.json | 13 +- .../src/pages/home/Inputbar/Inputbar.tsx | 25 +- .../pages/home/Messages/MessageContent.tsx | 2 + .../src/pages/home/Messages/MessageImage.tsx | 29 + src/renderer/src/providers/GeminiProvider.ts | 507 ++++++++++++------ src/renderer/src/providers/index.d.ts | 1 + src/renderer/src/services/ApiService.ts | 9 +- src/renderer/src/types/index.ts | 7 + yarn.lock | 143 ++++- 15 files changed, 623 insertions(+), 181 deletions(-) create mode 100644 src/renderer/src/pages/home/Messages/MessageImage.tsx diff --git a/package.json b/package.json index 067c2ed7..1e3bef94 100644 --- a/package.json +++ b/package.json @@ -56,6 +56,7 @@ "@electron-toolkit/utils": "^3.0.0", "@electron/notarize": "^2.5.0", "@emotion/is-prop-valid": "^1.3.1", + "@google/genai": "^0.4.0", "@google/generative-ai": "^0.21.0", "@llm-tools/embedjs": "patch:@llm-tools/embedjs@npm%3A0.1.28#~/.yarn/patches/@llm-tools-embedjs-npm-0.1.28-8e4393fa2d.patch", "@llm-tools/embedjs-libsql": "^0.1.28", diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index a5920d14..597fbcbf 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -1878,6 +1878,8 @@ export const TEXT_TO_IMAGES_MODELS_SUPPORT_IMAGE_ENHANCEMENT = [ 'stabilityai/stable-diffusion-xl-base-1.0' ] +export const GENERATE_IMAGE_MODELS = ['gemini-2.0-flash-exp-image-generation', 'gemini-2.0-flash-exp'] + export function isTextToImageModel(model: Model): boolean { return TEXT_TO_IMAGE_REGEX.test(model.id) } @@ -2009,6 +2011,28 @@ export function isWebSearchModel(model: Model): boolean { return false } +export function isGenerateImageModel(model: Model): boolean { + if (!model) { + return false + } + + const provider = getProviderByModel(model) + + if (!provider) { + return false + } + + const isEmbedding = isEmbeddingModel(model) + + if (isEmbedding) { + return false + } + if (GENERATE_IMAGE_MODELS.includes(model.id)) { + return true + } + return false +} + export function getOpenAIWebSearchParams(assistant: Assistant, model: Model): Record { if (isWebSearchModel(model)) { if (assistant.enableWebSearch) { diff --git a/src/renderer/src/i18n/locales/en-us.json b/src/renderer/src/i18n/locales/en-us.json index 73d653a0..8d7b5de2 100644 --- a/src/renderer/src/i18n/locales/en-us.json +++ b/src/renderer/src/i18n/locales/en-us.json @@ -198,7 +198,16 @@ "topics.prompt.tips": "Topic Prompts: Additional supplementary prompts provided for the current topic", "topics.title": "Topics", "topics.unpinned": "Unpinned Topics", - "translate": "Translate" + "topics.new": "New Topic", + "translate": "Translate", + "navigation": { + "prev": "Previous Message", + "next": "Next Message", + "first": "Already at the first message", + "last": "Already at the last message" + }, + "input.generate_image": "Generate image", + "input.generate_image_not_supported": "The model does not support generating images." }, "code_block": { "collapse": "Collapse", @@ -1168,4 +1177,4 @@ "visualization": "Visualization" } } -} +} \ No newline at end of file diff --git a/src/renderer/src/i18n/locales/ja-jp.json b/src/renderer/src/i18n/locales/ja-jp.json index 37162354..fe0c1d72 100644 --- a/src/renderer/src/i18n/locales/ja-jp.json +++ b/src/renderer/src/i18n/locales/ja-jp.json @@ -198,7 +198,16 @@ "topics.prompt.tips": "トピック提示語:現在のトピックに対して追加の補足提示語を提供", "topics.title": "トピック", "topics.unpinned": "固定解除", - "translate": "翻訳" + "topics.new": "新しいトピック", + "translate": "翻訳", + "navigation": { + "prev": "前のメッセージ", + "next": "次のメッセージ", + "first": "最初のメッセージです", + "last": "最後のメッセージです" + }, + "input.generate_image": "画像を生成する", + "input.generate_image_not_supported": "モデルは画像の生成をサポートしていません。" }, "code_block": { "collapse": "折りたたむ", @@ -1168,4 +1177,4 @@ "visualization": "可視化" } } -} +} \ No newline at end of file diff --git a/src/renderer/src/i18n/locales/ru-ru.json b/src/renderer/src/i18n/locales/ru-ru.json index 3b4583f6..7e648567 100644 --- a/src/renderer/src/i18n/locales/ru-ru.json +++ b/src/renderer/src/i18n/locales/ru-ru.json @@ -198,7 +198,16 @@ "topics.prompt.tips": "Тематические подсказки: Дополнительные подсказки, предоставленные для текущей темы", "topics.title": "Топики", "topics.unpinned": "Открепленные темы", - "translate": "Перевести" + "topics.new": "Новый топик", + "translate": "Перевести", + "navigation": { + "prev": "Предыдущее сообщение", + "next": "Следующее сообщение", + "first": "Уже первое сообщение", + "last": "Уже последнее сообщение" + }, + "input.generate_image": "Сгенерировать изображение", + "input.generate_image_not_supported": "Модель не поддерживает генерацию изображений." }, "code_block": { "collapse": "Свернуть", @@ -1168,4 +1177,4 @@ "visualization": "Визуализация" } } -} +} \ No newline at end of file diff --git a/src/renderer/src/i18n/locales/zh-cn.json b/src/renderer/src/i18n/locales/zh-cn.json index e5ed1435..1eaab35b 100644 --- a/src/renderer/src/i18n/locales/zh-cn.json +++ b/src/renderer/src/i18n/locales/zh-cn.json @@ -116,6 +116,8 @@ "input.translate": "翻译成{{target_language}}", "input.upload": "上传图片或文档", "input.upload.document": "上传文档(模型不支持图片)", + "input.generate_image": "生成图片", + "input.generate_image_not_supported": "模型不支持生成图片", "input.web_search": "开启网络搜索", "input.web_search.button.ok": "去设置", "input.web_search.enable": "开启网络搜索", @@ -1168,4 +1170,4 @@ "visualization": "可视化" } } -} +} \ No newline at end of file diff --git a/src/renderer/src/i18n/locales/zh-tw.json b/src/renderer/src/i18n/locales/zh-tw.json index f5e3aeae..11ad3b84 100644 --- a/src/renderer/src/i18n/locales/zh-tw.json +++ b/src/renderer/src/i18n/locales/zh-tw.json @@ -200,7 +200,16 @@ "topics.prompt.tips": "話題提示詞:針對目前話題提供額外的補充提示詞", "topics.title": "話題", "topics.unpinned": "取消固定", - "translate": "翻譯" + "topics.new": "開始新對話", + "translate": "翻譯", + "navigation": { + "prev": "上一條訊息", + "next": "下一條訊息", + "first": "已經是第一條訊息", + "last": "已經是最後一條訊息" + }, + "input.generate_image": "生成圖片", + "input.generate_image_not_supported": "模型不支援生成圖片" }, "code_block": { "collapse": "折疊", @@ -1170,4 +1179,4 @@ "visualization": "視覺化" } } -} +} \ No newline at end of file diff --git a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx index dce62038..d60ccb49 100644 --- a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx +++ b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx @@ -7,10 +7,11 @@ import { GlobalOutlined, HolderOutlined, PauseCircleOutlined, + PictureOutlined, QuestionCircleOutlined } from '@ant-design/icons' import TranslateButton from '@renderer/components/TranslateButton' -import { isFunctionCallingModel, isVisionModel, isWebSearchModel } from '@renderer/config/models' +import { isFunctionCallingModel, isGenerateImageModel, isVisionModel, isWebSearchModel } from '@renderer/config/models' import db from '@renderer/databases' import { useAssistant } from '@renderer/hooks/useAssistant' import { useMessageOperations } from '@renderer/hooks/useMessageOperations' @@ -626,7 +627,6 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) = } const onEnableWebSearch = () => { - console.log(assistant) if (!isWebSearchModel(model)) { if (!WebSearchService.isWebSearchEnabled()) { window.modal.confirm({ @@ -645,10 +645,17 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) = updateAssistant({ ...assistant, enableWebSearch: !assistant.enableWebSearch }) } + const onEnableGenerateImage = () => { + updateAssistant({ ...assistant, enableGenerateImage: !assistant.enableGenerateImage }) + } + useEffect(() => { if (!isWebSearchModel(model) && !WebSearchService.isWebSearchEnabled() && assistant.enableWebSearch) { updateAssistant({ ...assistant, enableWebSearch: false }) } + if (!isGenerateImageModel(model) && assistant.enableGenerateImage) { + updateAssistant({ ...assistant, enableGenerateImage: false }) + } }, [assistant, model, updateAssistant]) const resetHeight = () => { @@ -738,6 +745,20 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) = ToolbarButton={ToolbarButton} /> )} + + + + + onMentionModel(model, mentionFromKeyboard)} diff --git a/src/renderer/src/pages/home/Messages/MessageContent.tsx b/src/renderer/src/pages/home/Messages/MessageContent.tsx index 6d5f7b69..eea78503 100644 --- a/src/renderer/src/pages/home/Messages/MessageContent.tsx +++ b/src/renderer/src/pages/home/Messages/MessageContent.tsx @@ -16,6 +16,7 @@ import styled from 'styled-components' import Markdown from '../Markdown/Markdown' import MessageAttachments from './MessageAttachments' import MessageError from './MessageError' +import MessageImage from './MessageImage' import MessageSearchResults from './MessageSearchResults' import MessageThought from './MessageThought' import MessageTools from './MessageTools' @@ -150,6 +151,7 @@ const MessageContent: React.FC = ({ message: _message, model }) => { + {message.metadata?.generateImage && } {message.translatedContent && ( diff --git a/src/renderer/src/pages/home/Messages/MessageImage.tsx b/src/renderer/src/pages/home/Messages/MessageImage.tsx new file mode 100644 index 00000000..0f24eded --- /dev/null +++ b/src/renderer/src/pages/home/Messages/MessageImage.tsx @@ -0,0 +1,29 @@ +import { Message } from '@renderer/types' +import { Image as AntdImage } from 'antd' +import { FC } from 'react' +import styled from 'styled-components' + +interface Props { + message: Message +} + +const MessageImage: FC = ({ message }) => { + return ( + + {message.metadata?.generateImage!.images.map((image, index) => ( + + ))} + + ) +} +const Container = styled.div` + display: flex; + flex-direction: row; + gap: 10px; + margin-top: 8px; +` +const Image = styled(AntdImage)` + border-radius: 10px; +` + +export default MessageImage diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index 653877e3..a21385fb 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -1,3 +1,10 @@ +import { + ContentListUnion, + createPartFromBase64, + FinishReason, + GenerateContentResponse, + GoogleGenAI +} from '@google/genai' import { Content, FileDataPart, @@ -35,16 +42,19 @@ import axios from 'axios' import { isEmpty, takeRight } from 'lodash' import OpenAI from 'openai' -import { CompletionsParams } from '.' +import { ChunkCallbackData, CompletionsParams } from '.' import BaseProvider from './BaseProvider' export default class GeminiProvider extends BaseProvider { private sdk: GoogleGenerativeAI private requestOptions: RequestOptions + private imageSdk: GoogleGenAI constructor(provider: Provider) { super(provider) this.sdk = new GoogleGenerativeAI(this.apiKey) + /// this sdk is experimental + this.imageSdk = new GoogleGenAI({ apiKey: this.apiKey }) this.requestOptions = { baseUrl: this.getBaseURL() } @@ -105,6 +115,25 @@ export default class GeminiProvider extends BaseProvider { const role = message.role === 'user' ? 'user' : 'model' const parts: Part[] = [{ text: await this.getMessageContent(message) }] + // Add any generated images from previous responses + if (message.metadata?.generateImage?.images && message.metadata.generateImage.images.length > 0) { + for (const imageUrl of message.metadata.generateImage.images) { + if (imageUrl && imageUrl.startsWith('data:')) { + // Extract base64 data and mime type from the data URL + const matches = imageUrl.match(/^data:(.+);base64,(.*)$/) + if (matches && matches.length === 3) { + const mimeType = matches[1] + const base64Data = matches[2] + parts.push({ + inlineData: { + data: base64Data, + mimeType: mimeType + } + } as InlineDataPart) + } + } + } + } for (const file of message.files || []) { if (file.type === FileTypes.IMAGE) { @@ -179,180 +208,184 @@ export default class GeminiProvider extends BaseProvider { * @param onFilterMessages - The onFilterMessages callback */ public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) { - const defaultModel = getDefaultModel() - const model = assistant.model || defaultModel - const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) + if (assistant.enableGenerateImage) { + await this.generateImageExp({ messages, assistant, onFilterMessages, onChunk }) + } else { + const defaultModel = getDefaultModel() + const model = assistant.model || defaultModel + const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) - const userMessages = filterUserRoleStartMessages( - filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 2))) - ) - onFilterMessages(userMessages) + const userMessages = filterUserRoleStartMessages( + filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 2))) + ) + onFilterMessages(userMessages) - const userLastMessage = userMessages.pop() + const userLastMessage = userMessages.pop() - const history: Content[] = [] + const history: Content[] = [] - for (const message of userMessages) { - history.push(await this.getMessageContents(message)) - } - - const tools = mcpToolsToGeminiTools(mcpTools) - const toolResponses: MCPToolResponse[] = [] - - if (assistant.enableWebSearch && isWebSearchModel(model)) { - tools.push({ - // @ts-ignore googleSearch is not a valid tool for Gemini - googleSearch: {} - }) - } - - const geminiModel = this.sdk.getGenerativeModel( - { - model: model.id, - ...(isGemmaModel(model) ? {} : { systemInstruction: assistant.prompt }), - safetySettings: this.getSafetySettings(model.id), - tools: tools, - generationConfig: { - maxOutputTokens: maxTokens, - temperature: assistant?.settings?.temperature, - topP: assistant?.settings?.topP, - ...this.getCustomParameters(assistant) - } - }, - this.requestOptions - ) - - const chat = geminiModel.startChat({ history }) - const messageContents = await this.getMessageContents(userLastMessage!) - - if (isGemmaModel(model) && assistant.prompt) { - const isFirstMessage = history.length === 0 - if (isFirstMessage) { - const systemMessage = { - role: 'user', - parts: [ - { - text: - 'user\n' + - assistant.prompt + - '\n' + - 'user\n' + - messageContents.parts[0].text + - '' - } - ] - } - messageContents.parts = systemMessage.parts + for (const message of userMessages) { + history.push(await this.getMessageContents(message)) } - } - const start_time_millsec = new Date().getTime() - const { abortController, cleanup } = this.createAbortController(userLastMessage?.id) - const { signal } = abortController + const tools = mcpToolsToGeminiTools(mcpTools) + const toolResponses: MCPToolResponse[] = [] - if (!streamOutput) { - const { response } = await chat.sendMessage(messageContents.parts, { signal }) - const time_completion_millsec = new Date().getTime() - start_time_millsec - onChunk({ - text: response.candidates?.[0].content.parts[0].text, - usage: { - prompt_tokens: response.usageMetadata?.promptTokenCount || 0, - completion_tokens: response.usageMetadata?.candidatesTokenCount || 0, - total_tokens: response.usageMetadata?.totalTokenCount || 0 - }, - metrics: { - completion_tokens: response.usageMetadata?.candidatesTokenCount, - time_completion_millsec, - time_first_token_millsec: 0 - }, - search: response.candidates?.[0]?.groundingMetadata - }) - return - } - - const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal }) - let time_first_token_millsec = 0 - - const processStream = async (stream: GenerateContentStreamResult, idx: number) => { - for await (const chunk of stream.stream) { - if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break - - if (time_first_token_millsec == 0) { - time_first_token_millsec = new Date().getTime() - start_time_millsec - } - - const time_completion_millsec = new Date().getTime() - start_time_millsec - - const functionCalls = chunk.functionCalls() - - if (functionCalls) { - const fcallParts: FunctionCallPart[] = [] - const fcRespParts: FunctionResponsePart[] = [] - for (const call of functionCalls) { - console.log('Function call:', call) - fcallParts.push({ functionCall: call } as FunctionCallPart) - const mcpTool = geminiFunctionCallToMcpTool(mcpTools, call) - if (mcpTool) { - upsertMCPToolResponse( - toolResponses, - { - tool: mcpTool, - status: 'invoking', - id: `${call.name}-${idx}` - }, - onChunk - ) - const toolCallResponse = await callMCPTool(mcpTool) - fcRespParts.push({ - functionResponse: { - name: mcpTool.id, - response: toolCallResponse - } - }) - upsertMCPToolResponse( - toolResponses, - { - tool: mcpTool, - status: 'done', - response: toolCallResponse, - id: `${call.name}-${idx}` - }, - onChunk - ) - } - } - - if (fcRespParts) { - history.push(messageContents) - history.push({ - role: 'model', - parts: fcallParts - }) - const newChat = geminiModel.startChat({ history }) - const newStream = await newChat.sendMessageStream(fcRespParts, { signal }) - await processStream(newStream, idx + 1) - } - } - - onChunk({ - text: chunk.text(), - usage: { - prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0, - completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0, - total_tokens: chunk.usageMetadata?.totalTokenCount || 0 - }, - metrics: { - completion_tokens: chunk.usageMetadata?.candidatesTokenCount, - time_completion_millsec, - time_first_token_millsec - }, - search: chunk.candidates?.[0]?.groundingMetadata, - mcpToolResponse: toolResponses + if (assistant.enableWebSearch && isWebSearchModel(model)) { + tools.push({ + // @ts-ignore googleSearch is not a valid tool for Gemini + googleSearch: {} }) } - } - await processStream(userMessagesStream, 0).finally(cleanup) + const geminiModel = this.sdk.getGenerativeModel( + { + model: model.id, + ...(isGemmaModel(model) ? {} : { systemInstruction: assistant.prompt }), + safetySettings: this.getSafetySettings(model.id), + tools: tools, + generationConfig: { + maxOutputTokens: maxTokens, + temperature: assistant?.settings?.temperature, + topP: assistant?.settings?.topP, + ...this.getCustomParameters(assistant) + } + }, + this.requestOptions + ) + + const chat = geminiModel.startChat({ history }) + const messageContents = await this.getMessageContents(userLastMessage!) + + if (isGemmaModel(model) && assistant.prompt) { + const isFirstMessage = history.length === 0 + if (isFirstMessage) { + const systemMessage = { + role: 'user', + parts: [ + { + text: + 'user\n' + + assistant.prompt + + '\n' + + 'user\n' + + messageContents.parts[0].text + + '' + } + ] + } + messageContents.parts = systemMessage.parts + } + } + + const start_time_millsec = new Date().getTime() + const { abortController, cleanup } = this.createAbortController(userLastMessage?.id) + const { signal } = abortController + + if (!streamOutput) { + const { response } = await chat.sendMessage(messageContents.parts, { signal }) + const time_completion_millsec = new Date().getTime() - start_time_millsec + onChunk({ + text: response.candidates?.[0].content.parts[0].text, + usage: { + prompt_tokens: response.usageMetadata?.promptTokenCount || 0, + completion_tokens: response.usageMetadata?.candidatesTokenCount || 0, + total_tokens: response.usageMetadata?.totalTokenCount || 0 + }, + metrics: { + completion_tokens: response.usageMetadata?.candidatesTokenCount, + time_completion_millsec, + time_first_token_millsec: 0 + }, + search: response.candidates?.[0]?.groundingMetadata + }) + return + } + + const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal }) + let time_first_token_millsec = 0 + + const processStream = async (stream: GenerateContentStreamResult, idx: number) => { + for await (const chunk of stream.stream) { + if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break + + if (time_first_token_millsec == 0) { + time_first_token_millsec = new Date().getTime() - start_time_millsec + } + + const time_completion_millsec = new Date().getTime() - start_time_millsec + + const functionCalls = chunk.functionCalls() + + if (functionCalls) { + const fcallParts: FunctionCallPart[] = [] + const fcRespParts: FunctionResponsePart[] = [] + for (const call of functionCalls) { + console.log('Function call:', call) + fcallParts.push({ functionCall: call } as FunctionCallPart) + const mcpTool = geminiFunctionCallToMcpTool(mcpTools, call) + if (mcpTool) { + upsertMCPToolResponse( + toolResponses, + { + tool: mcpTool, + status: 'invoking', + id: `${call.name}-${idx}` + }, + onChunk + ) + const toolCallResponse = await callMCPTool(mcpTool) + fcRespParts.push({ + functionResponse: { + name: mcpTool.id, + response: toolCallResponse + } + }) + upsertMCPToolResponse( + toolResponses, + { + tool: mcpTool, + status: 'done', + response: toolCallResponse, + id: `${call.name}-${idx}` + }, + onChunk + ) + } + } + + if (fcRespParts) { + history.push(messageContents) + history.push({ + role: 'model', + parts: fcallParts + }) + const newChat = geminiModel.startChat({ history }) + const newStream = await newChat.sendMessageStream(fcRespParts, { signal }) + await processStream(newStream, idx + 1) + } + } + + onChunk({ + text: chunk.text(), + usage: { + prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0, + completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0, + total_tokens: chunk.usageMetadata?.totalTokenCount || 0 + }, + metrics: { + completion_tokens: chunk.usageMetadata?.candidatesTokenCount, + time_completion_millsec, + time_first_token_millsec + }, + search: chunk.candidates?.[0]?.groundingMetadata, + mcpToolResponse: toolResponses + }) + } + } + + await processStream(userMessagesStream, 0).finally(cleanup) + } } /** @@ -536,6 +569,150 @@ export default class GeminiProvider extends BaseProvider { return [] } + /** + * 生成图像 + * @param messages - 消息列表 + * @param assistant - 助手配置 + * @param onChunk - 处理生成块的回调 + * @param onFilterMessages - 过滤消息的回调 + * @returns Promise + */ + private async generateImageExp({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise { + const defaultModel = getDefaultModel() + const model = assistant.model || defaultModel + const { contextCount } = getAssistantSettings(assistant) + + const userMessages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 2))) + onFilterMessages(userMessages) + + const userLastMessage = userMessages.pop() + if (!userLastMessage) { + throw new Error('No user message found') + } + + const history: Content[] = [] + + for (const message of userMessages) { + history.push(await this.getMessageContents(message)) + } + + const userLastMessageContent = await this.getMessageContents(userLastMessage) + const allContents = [...history, userLastMessageContent] + + let contents: ContentListUnion = allContents.length > 0 ? (allContents as ContentListUnion) : [] + + contents = await this.addImageFileToContents(userLastMessage, contents) + + const response = await this.callGeminiGenerateContent(model.id, contents) + + console.log('response', response) + + const { isValid, message } = this.isValidGeminiResponse(response) + if (!isValid) { + throw new Error(`Gemini API error: ${message}`) + } + + this.processGeminiImageResponse(response, onChunk) + } + + /** + * 添加图片文件到内容列表 + * @param message - 用户消息 + * @param contents - 内容列表 + * @returns 更新后的内容列表 + */ + private async addImageFileToContents(message: Message, contents: ContentListUnion): Promise { + if (message.files && message.files.length > 0) { + const file = message.files[0] + const fileContent = await window.api.file.base64Image(file.id + file.ext) + + if (fileContent && fileContent.base64) { + const contentsArray = Array.isArray(contents) ? contents : [contents] + return [...contentsArray, createPartFromBase64(fileContent.base64, fileContent.mime)] + } + } + return contents + } + + /** + * 调用Gemini API生成内容 + * @param modelId - 模型ID + * @param contents - 内容列表 + * @returns 生成结果 + */ + private async callGeminiGenerateContent( + modelId: string, + contents: ContentListUnion + ): Promise { + try { + return await this.imageSdk.models.generateContent({ + model: modelId, + contents: contents, + config: { + responseModalities: ['Text', 'Image'], + responseMimeType: 'text/plain' + } + }) + } catch (error) { + console.error('Gemini API error:', error) + throw error + } + } + + /** + * 检查Gemini响应是否有效 + * @param response - Gemini响应 + * @returns 是否有效 + */ + private isValidGeminiResponse(response: GenerateContentResponse): { isValid: boolean; message: string } { + return { + isValid: response?.candidates?.[0]?.finishReason === FinishReason.STOP ? true : false, + message: response?.candidates?.[0]?.finishReason || '' + } + } + + /** + * 处理Gemini图像响应 + * @param response - Gemini响应 + * @param onChunk - 处理生成块的回调 + */ + private processGeminiImageResponse(response: any, onChunk: (chunk: ChunkCallbackData) => void): void { + const parts = response.candidates[0].content.parts + + // 提取图像数据 + const images = parts + .filter((part: Part) => part.inlineData) + .map((part: Part) => { + if (!part.inlineData) { + return null + } + const dataPrefix = `data:${part.inlineData.mimeType || 'image/png'};base64,` + return part.inlineData.data.startsWith('data:') ? part.inlineData.data : dataPrefix + part.inlineData.data + }) + + // 提取文本数据 + const text = parts + .filter((part: Part) => part.text !== undefined) + .map((part: Part) => part.text) + .join('') + + // 返回结果 + onChunk({ + text, + generateImage: { + images + }, + usage: { + prompt_tokens: response.usageMetadata?.promptTokenCount || 0, + completion_tokens: response.usageMetadata?.candidatesTokenCount || 0, + total_tokens: response.usageMetadata?.totalTokenCount || 0 + }, + metrics: { + completion_tokens: response.usageMetadata?.candidatesTokenCount + } + }) + } + /** * Check if the model is valid * @param model - The model diff --git a/src/renderer/src/providers/index.d.ts b/src/renderer/src/providers/index.d.ts index 40b9f680..f21880b4 100644 --- a/src/renderer/src/providers/index.d.ts +++ b/src/renderer/src/providers/index.d.ts @@ -9,6 +9,7 @@ interface ChunkCallbackData { search?: GroundingMetadata citations?: string[] mcpToolResponse?: MCPToolResponse[] + generateImage?: GenerateImageResponse } interface CompletionsParams { diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index a8e2e3c2..49290912 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -111,7 +111,7 @@ export async function fetchChatCompletion({ messages: filterUsefulMessages(messages), assistant, onFilterMessages: (messages) => (_messages = messages), - onChunk: ({ text, reasoning_content, usage, metrics, search, citations, mcpToolResponse }) => { + onChunk: ({ text, reasoning_content, usage, metrics, search, citations, mcpToolResponse, generateImage }) => { message.content = message.content + text || '' message.usage = usage message.metrics = metrics @@ -127,6 +127,12 @@ export async function fetchChatCompletion({ if (mcpToolResponse) { message.metadata = { ...message.metadata, mcpTools: cloneDeep(mcpToolResponse) } } + if (generateImage) { + message.metadata = { + ...message.metadata, + generateImage: generateImage + } + } // Handle citations from Perplexity API if (isFirstChunk && citations) { @@ -162,6 +168,7 @@ export async function fetchChatCompletion({ } } } + console.log('message', message) } catch (error: any) { if (isAbortError(error)) { message.status = 'paused' diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index 1e3c3735..10e981c8 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -16,6 +16,7 @@ export type Assistant = { settings?: Partial messages?: AssistantMessage[] enableWebSearch?: boolean + enableGenerateImage?: boolean } export type AssistantMessage = { @@ -77,6 +78,8 @@ export type Message = { webSearch?: WebSearchResponse // MCP Tools mcpTools?: MCPToolResponse[] + // Generate Image + generateImage?: GenerateImageResponse } // 多模型消息样式 multiModelMessageStyle?: 'horizontal' | 'vertical' | 'fold' | 'grid' @@ -295,6 +298,10 @@ export type GenerateImageParams = { promptEnhancement?: boolean } +export type GenerateImageResponse = { + images: string[] +} + export interface TranslateHistory { id: string sourceText: string diff --git a/yarn.lock b/yarn.lock index 14207c21..4d47dc60 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1021,6 +1021,16 @@ __metadata: languageName: node linkType: hard +"@google/genai@npm:^0.4.0": + version: 0.4.0 + resolution: "@google/genai@npm:0.4.0" + dependencies: + google-auth-library: "npm:^9.14.2" + ws: "npm:^8.18.0" + checksum: 10c0/4feb837b373cdbe60a5388b880b2384b116ffa369ae17ec2562c4e9da0f90e315d5e30c413ee3a620b6d147c55e1e9165f0e143aba6d945f1dfbe61fa584fefc + languageName: node + linkType: hard + "@google/generative-ai@npm:^0.21.0": version: 0.21.0 resolution: "@google/generative-ai@npm:0.21.0" @@ -3336,6 +3346,7 @@ __metadata: "@emotion/is-prop-valid": "npm:^1.3.1" "@eslint-react/eslint-plugin": "npm:^1.36.1" "@eslint/js": "npm:^9.22.0" + "@google/genai": "npm:^0.4.0" "@google/generative-ai": "npm:^0.21.0" "@hello-pangea/dnd": "npm:^16.6.0" "@kangfenmao/keyv-storage": "npm:^0.1.0" @@ -4012,7 +4023,7 @@ __metadata: languageName: node linkType: hard -"base64-js@npm:^1.3.1, base64-js@npm:^1.5.1": +"base64-js@npm:^1.3.0, base64-js@npm:^1.3.1, base64-js@npm:^1.5.1": version: 1.5.1 resolution: "base64-js@npm:1.5.1" checksum: 10c0/f23823513b63173a001030fae4f2dabe283b99a9d324ade3ad3d148e218134676f1ee8568c877cd79ec1c53158dcf2d2ba527a97c606618928ba99dd930102bf @@ -4035,6 +4046,13 @@ __metadata: languageName: node linkType: hard +"bignumber.js@npm:^9.0.0": + version: 9.1.2 + resolution: "bignumber.js@npm:9.1.2" + checksum: 10c0/e17786545433f3110b868725c449fa9625366a6e675cd70eb39b60938d6adbd0158cb4b3ad4f306ce817165d37e63f4aa3098ba4110db1d9a3b9f66abfbaf10d + languageName: node + linkType: hard + "bindings@npm:^1.5.0": version: 1.5.0 resolution: "bindings@npm:1.5.0" @@ -4201,6 +4219,13 @@ __metadata: languageName: node linkType: hard +"buffer-equal-constant-time@npm:1.0.1": + version: 1.0.1 + resolution: "buffer-equal-constant-time@npm:1.0.1" + checksum: 10c0/fb2294e64d23c573d0dd1f1e7a466c3e978fe94a4e0f8183937912ca374619773bef8e2aceb854129d2efecbbc515bbd0cc78d2734a3e3031edb0888531bbc8e + languageName: node + linkType: hard + "buffer-equal@npm:0.0.1": version: 0.0.1 resolution: "buffer-equal@npm:0.0.1" @@ -5648,6 +5673,15 @@ __metadata: languageName: node linkType: hard +"ecdsa-sig-formatter@npm:1.0.11, ecdsa-sig-formatter@npm:^1.0.11": + version: 1.0.11 + resolution: "ecdsa-sig-formatter@npm:1.0.11" + dependencies: + safe-buffer: "npm:^5.0.1" + checksum: 10c0/ebfbf19d4b8be938f4dd4a83b8788385da353d63307ede301a9252f9f7f88672e76f2191618fd8edfc2f24679236064176fab0b78131b161ee73daa37125408c + languageName: node + linkType: hard + "ee-first@npm:1.1.1": version: 1.1.1 resolution: "ee-first@npm:1.1.1" @@ -6663,7 +6697,7 @@ __metadata: languageName: node linkType: hard -"extend@npm:^3.0.0, extend@npm:~3.0.2": +"extend@npm:^3.0.0, extend@npm:^3.0.2, extend@npm:~3.0.2": version: 3.0.2 resolution: "extend@npm:3.0.2" checksum: 10c0/73bf6e27406e80aa3e85b0d1c4fd987261e628064e170ca781125c0b635a3dabad5e05adbf07595ea0cf1e6c5396cacb214af933da7cbaf24fe75ff14818e8f9 @@ -7249,6 +7283,30 @@ __metadata: languageName: node linkType: hard +"gaxios@npm:^6.0.0, gaxios@npm:^6.1.1": + version: 6.7.1 + resolution: "gaxios@npm:6.7.1" + dependencies: + extend: "npm:^3.0.2" + https-proxy-agent: "npm:^7.0.1" + is-stream: "npm:^2.0.0" + node-fetch: "npm:^2.6.9" + uuid: "npm:^9.0.1" + checksum: 10c0/53e92088470661c5bc493a1de29d05aff58b1f0009ec5e7903f730f892c3642a93e264e61904383741ccbab1ce6e519f12a985bba91e13527678b32ee6d7d3fd + languageName: node + linkType: hard + +"gcp-metadata@npm:^6.1.0": + version: 6.1.1 + resolution: "gcp-metadata@npm:6.1.1" + dependencies: + gaxios: "npm:^6.1.1" + google-logging-utils: "npm:^0.0.2" + json-bigint: "npm:^1.0.0" + checksum: 10c0/71f6ad4800aa622c246ceec3955014c0c78cdcfe025971f9558b9379f4019f5e65772763428ee8c3244fa81b8631977316eaa71a823493f82e5c44d7259ffac8 + languageName: node + linkType: hard + "gensync@npm:^1.0.0-beta.2": version: 1.0.0-beta.2 resolution: "gensync@npm:1.0.0-beta.2" @@ -7492,6 +7550,27 @@ __metadata: languageName: node linkType: hard +"google-auth-library@npm:^9.14.2": + version: 9.15.1 + resolution: "google-auth-library@npm:9.15.1" + dependencies: + base64-js: "npm:^1.3.0" + ecdsa-sig-formatter: "npm:^1.0.11" + gaxios: "npm:^6.1.1" + gcp-metadata: "npm:^6.1.0" + gtoken: "npm:^7.0.0" + jws: "npm:^4.0.0" + checksum: 10c0/6eef36d9a9cb7decd11e920ee892579261c6390104b3b24d3e0f3889096673189fe2ed0ee43fd563710e2560de98e63ad5aa4967b91e7f4e69074a422d5f7b65 + languageName: node + linkType: hard + +"google-logging-utils@npm:^0.0.2": + version: 0.0.2 + resolution: "google-logging-utils@npm:0.0.2" + checksum: 10c0/9a4bbd470dd101c77405e450fffca8592d1d7114f245a121288d04a957aca08c9dea2dd1a871effe71e41540d1bb0494731a0b0f6fea4358e77f06645e4268c1 + languageName: node + linkType: hard + "gopd@npm:^1.0.1, gopd@npm:^1.2.0": version: 1.2.0 resolution: "gopd@npm:1.2.0" @@ -7551,6 +7630,16 @@ __metadata: languageName: node linkType: hard +"gtoken@npm:^7.0.0": + version: 7.1.0 + resolution: "gtoken@npm:7.1.0" + dependencies: + gaxios: "npm:^6.0.0" + jws: "npm:^4.0.0" + checksum: 10c0/0a3dcacb1a3c4578abe1ee01c7d0bf20bffe8ded3ee73fc58885d53c00f6eb43b4e1372ff179f0da3ed5cfebd5b7c6ab8ae2776f1787e90d943691b4fe57c716 + languageName: node + linkType: hard + "har-schema@npm:^2.0.0": version: 2.0.0 resolution: "har-schema@npm:2.0.0" @@ -8486,6 +8575,13 @@ __metadata: languageName: node linkType: hard +"is-stream@npm:^2.0.0": + version: 2.0.1 + resolution: "is-stream@npm:2.0.1" + checksum: 10c0/7c284241313fc6efc329b8d7f08e16c0efeb6baab1b4cd0ba579eb78e5af1aa5da11e68559896a2067cd6c526bd29241dda4eb1225e627d5aa1a89a76d4635a5 + languageName: node + linkType: hard + "is-stream@npm:^3.0.0": version: 3.0.0 resolution: "is-stream@npm:3.0.0" @@ -8662,6 +8758,15 @@ __metadata: languageName: node linkType: hard +"json-bigint@npm:^1.0.0": + version: 1.0.0 + resolution: "json-bigint@npm:1.0.0" + dependencies: + bignumber.js: "npm:^9.0.0" + checksum: 10c0/e3f34e43be3284b573ea150a3890c92f06d54d8ded72894556357946aeed9877fd795f62f37fe16509af189fd314ab1104d0fd0f163746ad231b9f378f5b33f4 + languageName: node + linkType: hard + "json-bignum@npm:^0.0.3": version: 0.0.3 resolution: "json-bignum@npm:0.0.3" @@ -8813,6 +8918,27 @@ __metadata: languageName: node linkType: hard +"jwa@npm:^2.0.0": + version: 2.0.0 + resolution: "jwa@npm:2.0.0" + dependencies: + buffer-equal-constant-time: "npm:1.0.1" + ecdsa-sig-formatter: "npm:1.0.11" + safe-buffer: "npm:^5.0.1" + checksum: 10c0/6baab823b93c038ba1d2a9e531984dcadbc04e9eb98d171f4901b7a40d2be15961a359335de1671d78cb6d987f07cbe5d350d8143255977a889160c4d90fcc3c + languageName: node + linkType: hard + +"jws@npm:^4.0.0": + version: 4.0.0 + resolution: "jws@npm:4.0.0" + dependencies: + jwa: "npm:^2.0.0" + safe-buffer: "npm:^5.0.1" + checksum: 10c0/f1ca77ea5451e8dc5ee219cb7053b8a4f1254a79cb22417a2e1043c1eb8a569ae118c68f24d72a589e8a3dd1824697f47d6bd4fb4bebb93a3bdf53545e721661 + languageName: node + linkType: hard + "katex@npm:^0.12.0": version: 0.12.0 resolution: "katex@npm:0.12.0" @@ -10803,7 +10929,7 @@ __metadata: languageName: node linkType: hard -"node-fetch@npm:^2.6.1, node-fetch@npm:^2.6.7": +"node-fetch@npm:^2.6.1, node-fetch@npm:^2.6.7, node-fetch@npm:^2.6.9": version: 2.7.0 resolution: "node-fetch@npm:2.7.0" dependencies: @@ -15259,6 +15385,15 @@ __metadata: languageName: node linkType: hard +"uuid@npm:^9.0.1": + version: 9.0.1 + resolution: "uuid@npm:9.0.1" + bin: + uuid: dist/bin/uuid + checksum: 10c0/1607dd32ac7fc22f2d8f77051e6a64845c9bce5cd3dd8aa0070c074ec73e666a1f63c7b4e0f4bf2bc8b9d59dc85a15e17807446d9d2b17c8485fbc2147b27f9b + languageName: node + linkType: hard + "uzip@npm:0.20201231.0": version: 0.20201231.0 resolution: "uzip@npm:0.20201231.0" @@ -15587,7 +15722,7 @@ __metadata: languageName: node linkType: hard -"ws@npm:^8.13.0": +"ws@npm:^8.13.0, ws@npm:^8.18.0": version: 8.18.1 resolution: "ws@npm:8.18.1" peerDependencies: