diff --git a/src/main/ipc.ts b/src/main/ipc.ts index d6edc2e5..ac4831c1 100644 --- a/src/main/ipc.ts +++ b/src/main/ipc.ts @@ -41,7 +41,13 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle('image:base64', async (_, filePath) => { try { const data = await fs.promises.readFile(filePath) - return `data:image/${path.extname(filePath).slice(1)};base64,${data.toString('base64')}` + const base64 = data.toString('base64') + const mime = `image/${path.extname(filePath).slice(1)}` + return { + mime, + base64, + data: `data:${mime};base64,${base64}` + } } catch (error) { Logger.error('Error reading file:', error) return '' diff --git a/src/preload/index.d.ts b/src/preload/index.d.ts index 8b665dc7..3b459fb1 100644 --- a/src/preload/index.d.ts +++ b/src/preload/index.d.ts @@ -31,7 +31,7 @@ declare global { all: () => Promise } image: { - base64: (filePath: string) => Promise + base64: (filePath: string) => Promise<{ mime: string; base64: string; data: string }> } } } diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index ff170921..d89b8a45 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -1,6 +1,7 @@ import { Model } from '@renderer/types' const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-turbo|dall|cogview/i +const VISION_REGEX = /llava|moondream|minicpm|gemini/i const EMBEDDING_REGEX = /embedding/i export const SYSTEM_MODELS: Record = { @@ -395,3 +396,7 @@ export function isTextToImageModel(model: Model): boolean { export function isEmbeddingModel(model: Model): boolean { return EMBEDDING_REGEX.test(model.id) } + +export function isVisionModel(model: Model): boolean { + return VISION_REGEX.test(model.id) +} diff --git a/src/renderer/src/pages/files/FilesPage.tsx b/src/renderer/src/pages/files/FilesPage.tsx index f80f818f..4d363152 100644 --- a/src/renderer/src/pages/files/FilesPage.tsx +++ b/src/renderer/src/pages/files/FilesPage.tsx @@ -46,16 
+46,6 @@ const FilesPage: FC = () => { } ] - // const handleSelectFile = async () => { - // const files = await window.api.fileSelect({ - // properties: ['openFile', 'multiSelections'] - // }) - // for (const file of files || []) { - // const result = await window.api.fileUpload(file.path) - // console.log('Selected file:', file, result) - // } - // } - return ( @@ -63,7 +53,7 @@ const FilesPage: FC = () => { - +
diff --git a/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx b/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx index 58c49aa0..88be5a94 100644 --- a/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx +++ b/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx @@ -1,16 +1,18 @@ import { PaperClipOutlined } from '@ant-design/icons' -import { FileMetadata } from '@renderer/types' +import { isVisionModel } from '@renderer/config/models' +import { FileMetadata, Model } from '@renderer/types' import { Tooltip } from 'antd' import { FC } from 'react' import { useTranslation } from 'react-i18next' interface Props { + model: Model files: FileMetadata[] setFiles: (files: FileMetadata[]) => void ToolbarButton: any } -const AttachmentButton: FC = ({ files, setFiles, ToolbarButton }) => { +const AttachmentButton: FC = ({ model, files, setFiles, ToolbarButton }) => { const { t } = useTranslation() const onSelectFile = async () => { @@ -20,6 +22,10 @@ const AttachmentButton: FC = ({ files, setFiles, ToolbarButton }) => { _files && setFiles(_files) } + if (!isVisionModel(model)) { + return null + } + return ( diff --git a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx index d8b8aab2..76619934 100644 --- a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx +++ b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx @@ -40,7 +40,7 @@ let _text = '' const Inputbar: FC = ({ assistant, setActiveTopic }) => { const [text, setText] = useState(_text) const [inputFocus, setInputFocus] = useState(false) - const { addTopic } = useAssistant(assistant.id) + const { addTopic, model } = useAssistant(assistant.id) const { sendMessageShortcut, fontSize } = useSettings() const [expended, setExpend] = useState(false) const [estimateTokenCount, setEstimateTokenCount] = useState(0) @@ -261,7 +261,7 @@ const Inputbar: FC = ({ assistant, setActiveTopic }) => { - + {expended ? 
: } diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index 73606745..4a92146a 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -1,10 +1,10 @@ -import { GoogleGenerativeAI } from '@google/generative-ai' +import { Content, GoogleGenerativeAI, InlineDataPart, Part } from '@google/generative-ai' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant' import { EVENT_NAMES } from '@renderer/services/event' import { filterContextMessages, filterMessages } from '@renderer/services/messages' import { Assistant, Message, Provider, Suggestion } from '@renderer/types' import axios from 'axios' -import { isEmpty, takeRight } from 'lodash' +import { first, isEmpty, takeRight } from 'lodash' import OpenAI from 'openai' import BaseProvider from './BaseProvider' @@ -17,6 +17,27 @@ export default class GeminiProvider extends BaseProvider { this.sdk = new GoogleGenerativeAI(provider.apiKey) } + private async getMessageParts(message: Message): Promise { + const file = first(message.files) + + if (file && file.type === 'image') { + const base64Data = await window.api.image.base64(file.path) + return [ + { + text: message.content + }, + { + inlineData: { + data: base64Data.base64, + mimeType: base64Data.mime + } + } as InlineDataPart + ] + } + + return [{ text: message.content }] + } + public async completions( messages: Message[], assistant: Assistant, @@ -29,7 +50,7 @@ export default class GeminiProvider extends BaseProvider { const userMessages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 1))).map((message) => { return { role: message.role, - content: message.content + message } }) @@ -44,14 +65,19 @@ export default class GeminiProvider extends BaseProvider { const userLastMessage = userMessages.pop() - const chat = geminiModel.startChat({ - history: userMessages.map((message) => ({ - role: 
message.role === 'user' ? 'user' : 'model', - parts: [{ text: message.content }] - })) - }) + const history: Content[] = [] - const userMessagesStream = await chat.sendMessageStream(userLastMessage?.content!) + for (const message of userMessages) { + history.push({ + role: message.role === 'user' ? 'user' : 'model', + parts: await this.getMessageParts(message.message) + }) + } + + const chat = geminiModel.startChat({ history }) + const message = await this.getMessageParts(userLastMessage?.message!) + + const userMessagesStream = await chat.sendMessageStream(message) for await (const chunk of userMessagesStream.stream) { if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts index bc570384..90766a90 100644 --- a/src/renderer/src/providers/OpenAIProvider.ts +++ b/src/renderer/src/providers/OpenAIProvider.ts @@ -34,12 +34,13 @@ export default class OpenAIProvider extends BaseProvider { } if (file.type === 'image') { + const base64Data = await window.api.image.base64(file.path) return [ { type: 'text', text: message.content }, { type: 'image_url', image_url: { - url: await window.api.image.base64(file.path) + url: base64Data.data } } ]