feat: Improved IPC image handling and added vision model support.

- Improved IPC image handling to return mime type and base64 encoded data alongside the image data.
- Updated type definition for `base64` method in image object to return an object with mime, base64, and data properties.
- Added support for vision models using new function and regex.
- Reduced table cell size on the FilesPage component.
- Added support for vision model attachments.
- Added model dependency to AttachmentButton component.
- Implemented new functionality to handle image messages in the GeminiProvider class.
- Updated image base64 encoding to directly use API response data.
This commit is contained in:
kangfenmao 2024-09-11 13:53:54 +08:00
parent 2016ba7062
commit 8781388760
8 changed files with 62 additions and 28 deletions

View File

@ -41,7 +41,13 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle('image:base64', async (_, filePath) => { ipcMain.handle('image:base64', async (_, filePath) => {
try { try {
const data = await fs.promises.readFile(filePath) const data = await fs.promises.readFile(filePath)
return `data:image/${path.extname(filePath).slice(1)};base64,${data.toString('base64')}` const base64 = data.toString('base64')
const mime = `image/${path.extname(filePath).slice(1)}`
return {
mime,
base64,
data: `data:image/${mime};base64,${base64}`
}
} catch (error) { } catch (error) {
Logger.error('Error reading file:', error) Logger.error('Error reading file:', error)
return '' return ''

View File

@ -31,7 +31,7 @@ declare global {
all: () => Promise<FileMetadata[]> all: () => Promise<FileMetadata[]>
} }
image: { image: {
base64: (filePath: string) => Promise<string> base64: (filePath: string) => Promise<{ mime: string; base64: string; data: string }>
} }
} }
} }

View File

@ -1,6 +1,7 @@
import { Model } from '@renderer/types' import { Model } from '@renderer/types'
const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-turbo|dall|cogview/i const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-turbo|dall|cogview/i
const VISION_REGEX = /llava|moondream|minicpm|gemini/i
const EMBEDDING_REGEX = /embedding/i const EMBEDDING_REGEX = /embedding/i
export const SYSTEM_MODELS: Record<string, Model[]> = { export const SYSTEM_MODELS: Record<string, Model[]> = {
@ -395,3 +396,7 @@ export function isTextToImageModel(model: Model): boolean {
export function isEmbeddingModel(model: Model): boolean { export function isEmbeddingModel(model: Model): boolean {
return EMBEDDING_REGEX.test(model.id) return EMBEDDING_REGEX.test(model.id)
} }
/**
 * Whether the given model accepts image (vision) input.
 *
 * Detection is purely id-based via VISION_REGEX — it cannot know about
 * capabilities the model id does not advertise.
 */
export function isVisionModel(model: Model): boolean {
  const { id } = model
  return VISION_REGEX.test(id)
}

View File

@ -46,16 +46,6 @@ const FilesPage: FC = () => {
} }
] ]
// const handleSelectFile = async () => {
// const files = await window.api.fileSelect({
// properties: ['openFile', 'multiSelections']
// })
// for (const file of files || []) {
// const result = await window.api.fileUpload(file.path)
// console.log('Selected file:', file, result)
// }
// }
return ( return (
<Container> <Container>
<Navbar> <Navbar>
@ -63,7 +53,7 @@ const FilesPage: FC = () => {
</Navbar> </Navbar>
<ContentContainer> <ContentContainer>
<VStack style={{ flex: 1 }}> <VStack style={{ flex: 1 }}>
<Table dataSource={dataSource} columns={columns} style={{ width: '100%', height: '100%' }} /> <Table dataSource={dataSource} columns={columns} style={{ width: '100%', height: '100%' }} size="small" />
</VStack> </VStack>
</ContentContainer> </ContentContainer>
</Container> </Container>

View File

@ -1,16 +1,18 @@
import { PaperClipOutlined } from '@ant-design/icons' import { PaperClipOutlined } from '@ant-design/icons'
import { FileMetadata } from '@renderer/types' import { isVisionModel } from '@renderer/config/models'
import { FileMetadata, Model } from '@renderer/types'
import { Tooltip } from 'antd' import { Tooltip } from 'antd'
import { FC } from 'react' import { FC } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
interface Props { interface Props {
model: Model
files: FileMetadata[] files: FileMetadata[]
setFiles: (files: FileMetadata[]) => void setFiles: (files: FileMetadata[]) => void
ToolbarButton: any ToolbarButton: any
} }
const AttachmentButton: FC<Props> = ({ files, setFiles, ToolbarButton }) => { const AttachmentButton: FC<Props> = ({ model, files, setFiles, ToolbarButton }) => {
const { t } = useTranslation() const { t } = useTranslation()
const onSelectFile = async () => { const onSelectFile = async () => {
@ -20,6 +22,10 @@ const AttachmentButton: FC<Props> = ({ files, setFiles, ToolbarButton }) => {
_files && setFiles(_files) _files && setFiles(_files)
} }
if (!isVisionModel(model)) {
return null
}
return ( return (
<Tooltip placement="top" title={t('chat.input.upload')} arrow> <Tooltip placement="top" title={t('chat.input.upload')} arrow>
<ToolbarButton type="text" className={files.length ? 'active' : ''} onClick={onSelectFile}> <ToolbarButton type="text" className={files.length ? 'active' : ''} onClick={onSelectFile}>

View File

@ -40,7 +40,7 @@ let _text = ''
const Inputbar: FC<Props> = ({ assistant, setActiveTopic }) => { const Inputbar: FC<Props> = ({ assistant, setActiveTopic }) => {
const [text, setText] = useState(_text) const [text, setText] = useState(_text)
const [inputFocus, setInputFocus] = useState(false) const [inputFocus, setInputFocus] = useState(false)
const { addTopic } = useAssistant(assistant.id) const { addTopic, model } = useAssistant(assistant.id)
const { sendMessageShortcut, fontSize } = useSettings() const { sendMessageShortcut, fontSize } = useSettings()
const [expended, setExpend] = useState(false) const [expended, setExpend] = useState(false)
const [estimateTokenCount, setEstimateTokenCount] = useState(0) const [estimateTokenCount, setEstimateTokenCount] = useState(0)
@ -261,7 +261,7 @@ const Inputbar: FC<Props> = ({ assistant, setActiveTopic }) => {
<ControlOutlined /> <ControlOutlined />
</ToolbarButton> </ToolbarButton>
</Tooltip> </Tooltip>
<AttachmentButton files={files} setFiles={setFiles} ToolbarButton={ToolbarButton} /> <AttachmentButton model={model} files={files} setFiles={setFiles} ToolbarButton={ToolbarButton} />
<Tooltip placement="top" title={expended ? t('chat.input.collapse') : t('chat.input.expand')} arrow> <Tooltip placement="top" title={expended ? t('chat.input.collapse') : t('chat.input.expand')} arrow>
<ToolbarButton type="text" onClick={onToggleExpended}> <ToolbarButton type="text" onClick={onToggleExpended}>
{expended ? <FullscreenExitOutlined /> : <FullscreenOutlined />} {expended ? <FullscreenExitOutlined /> : <FullscreenOutlined />}

View File

@ -1,10 +1,10 @@
import { GoogleGenerativeAI } from '@google/generative-ai' import { Content, GoogleGenerativeAI, InlineDataPart, Part } from '@google/generative-ai'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant'
import { EVENT_NAMES } from '@renderer/services/event' import { EVENT_NAMES } from '@renderer/services/event'
import { filterContextMessages, filterMessages } from '@renderer/services/messages' import { filterContextMessages, filterMessages } from '@renderer/services/messages'
import { Assistant, Message, Provider, Suggestion } from '@renderer/types' import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
import axios from 'axios' import axios from 'axios'
import { isEmpty, takeRight } from 'lodash' import { first, isEmpty, takeRight } from 'lodash'
import OpenAI from 'openai' import OpenAI from 'openai'
import BaseProvider from './BaseProvider' import BaseProvider from './BaseProvider'
@ -17,6 +17,27 @@ export default class GeminiProvider extends BaseProvider {
this.sdk = new GoogleGenerativeAI(provider.apiKey) this.sdk = new GoogleGenerativeAI(provider.apiKey)
} }
private async getMessageParts(message: Message): Promise<Part[]> {
const file = first(message.files)
if (file && file.type === 'image') {
const base64Data = await window.api.image.base64(file.path)
return [
{
text: message.content
},
{
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
}
} as InlineDataPart
]
}
return [{ text: message.content }]
}
public async completions( public async completions(
messages: Message[], messages: Message[],
assistant: Assistant, assistant: Assistant,
@ -29,7 +50,7 @@ export default class GeminiProvider extends BaseProvider {
const userMessages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 1))).map((message) => { const userMessages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 1))).map((message) => {
return { return {
role: message.role, role: message.role,
content: message.content message
} }
}) })
@ -44,14 +65,19 @@ export default class GeminiProvider extends BaseProvider {
const userLastMessage = userMessages.pop() const userLastMessage = userMessages.pop()
const chat = geminiModel.startChat({ const history: Content[] = []
history: userMessages.map((message) => ({
role: message.role === 'user' ? 'user' : 'model',
parts: [{ text: message.content }]
}))
})
const userMessagesStream = await chat.sendMessageStream(userLastMessage?.content!) for (const message of userMessages) {
history.push({
role: message.role === 'user' ? 'user' : 'model',
parts: await this.getMessageParts(message.message)
})
}
const chat = geminiModel.startChat({ history })
const message = await this.getMessageParts(userLastMessage?.message!)
const userMessagesStream = await chat.sendMessageStream(message)
for await (const chunk of userMessagesStream.stream) { for await (const chunk of userMessagesStream.stream) {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break

View File

@ -34,12 +34,13 @@ export default class OpenAIProvider extends BaseProvider {
} }
if (file.type === 'image') { if (file.type === 'image') {
const base64Data = await window.api.image.base64(file.path)
return [ return [
{ type: 'text', text: message.content }, { type: 'text', text: message.content },
{ {
type: 'image_url', type: 'image_url',
image_url: { image_url: {
url: await window.api.image.base64(file.path) url: base64Data.data
} }
} }
] ]