feat: copy and paste files or images

This commit is contained in:
kangfenmao 2024-09-18 21:00:15 +08:00
parent 6e7e5cb1f1
commit 29605fbcdb
10 changed files with 168 additions and 108 deletions

View File

@ -41,6 +41,12 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle('file:clear', async () => await fileManager.clear())
ipcMain.handle('file:read', async (_, id: string) => await fileManager.readFile(id))
ipcMain.handle('file:delete', async (_, id: string) => await fileManager.deleteFile(id))
ipcMain.handle('file:get', async (_, filePath: string) => await fileManager.getFile(filePath))
ipcMain.handle('file:create', async (_, fileName: string) => await fileManager.createTempFile(fileName))
ipcMain.handle(
'file:write',
async (_, filePath: string, data: Uint8Array | string) => await fileManager.writeFile(filePath, data)
)
ipcMain.handle('minapp', (_, args) => {
createMinappWindow({

View File

@ -131,6 +131,30 @@ class File {
return fileMetadata
}
async getFile(filePath: string): Promise<FileType | null> {
if (!fs.existsSync(filePath)) {
return null
}
const stats = fs.statSync(filePath)
const ext = path.extname(filePath)
const fileType = getFileType(ext)
const fileInfo: FileType = {
id: uuidv4(),
origin_name: path.basename(filePath),
name: path.basename(filePath),
path: filePath,
created_at: stats.birthtime,
size: stats.size,
ext: ext,
type: fileType,
count: 1
}
return fileInfo
}
async deleteFile(id: string): Promise<void> {
await fs.promises.unlink(path.join(this.storageDir, id))
}
@ -140,6 +164,19 @@ class File {
return fs.readFileSync(filePath, 'utf8')
}
async createTempFile(fileName: string): Promise<string> {
const tempDir = path.join(app.getPath('temp'), 'CherryStudio')
if (!fs.existsSync(tempDir)) {
fs.mkdirSync(tempDir, { recursive: true })
}
const tempFilePath = path.join(tempDir, `${uuidv4()}_${fileName}`)
return tempFilePath
}
async writeFile(filePath: string, data: Uint8Array | string): Promise<void> {
await fs.promises.writeFile(filePath, data)
}
async base64Image(id: string): Promise<{ mime: string; base64: string; data: string }> {
const filePath = path.join(this.storageDir, id)
const data = await fs.promises.readFile(filePath)

View File

@ -28,6 +28,9 @@ declare global {
read: (fileId: string) => Promise<string>
base64Image: (fileId: string) => Promise<{ mime: string; base64: string; data: string }>
clear: () => Promise<void>
get: (filePath: string) => Promise<FileType | null>
create: (fileName: string) => Promise<string>
write: (filePath: string, data: Uint8Array | string) => Promise<void>
}
}
}

View File

@ -22,7 +22,10 @@ const api = {
delete: (fileId: string) => ipcRenderer.invoke('file:delete', fileId),
read: (fileId: string) => ipcRenderer.invoke('file:read', fileId),
base64Image: (fileId: string) => ipcRenderer.invoke('file:base64Image', fileId),
clear: () => ipcRenderer.invoke('file:clear')
clear: () => ipcRenderer.invoke('file:clear'),
get: (filePath: string) => ipcRenderer.invoke('file:get', filePath),
create: (fileName: string) => ipcRenderer.invoke('file:create', fileName),
write: (filePath: string, data: Uint8Array | string) => ipcRenderer.invoke('file:write', filePath, data)
}
}

View File

@ -18,6 +18,9 @@ const AttachmentButton: FC<Props> = ({ model, files, setFiles, ToolbarButton })
const extensions = isVisionModel(model) ? [...imageExts, ...textExts] : [...textExts]
const onSelectFile = async () => {
if (files.length > 0) {
return setFiles([])
}
const _files = await window.api.file.select({ filters: [{ name: 'Files', extensions }] })
_files && setFiles(_files)
}

View File

@ -8,6 +8,7 @@ import {
PauseCircleOutlined,
QuestionCircleOutlined
} from '@ant-design/icons'
import { textExts } from '@renderer/config/constant'
import db from '@renderer/databases'
import { useAssistant } from '@renderer/hooks/useAssistant'
import { useSettings } from '@renderer/hooks/useSettings'
@ -19,7 +20,7 @@ import { estimateTextTokens } from '@renderer/services/tokens'
import store, { useAppDispatch, useAppSelector } from '@renderer/store'
import { setGenerating, setSearching } from '@renderer/store/runtime'
import { Assistant, FileType, Message, Topic } from '@renderer/types'
import { delay, uuid } from '@renderer/utils'
import { delay, getFileExtension, uuid } from '@renderer/utils'
import { Button, Popconfirm, Tooltip } from 'antd'
import TextArea, { TextAreaRef } from 'antd/es/input/TextArea'
import dayjs from 'dayjs'
@ -171,6 +172,44 @@ const Inputbar: FC<Props> = ({ assistant, setActiveTopic }) => {
const onInput = () => !expended && resizeTextArea()
const onPaste = useCallback(async (event: ClipboardEvent) => {
for (const file of event.clipboardData?.files || []) {
event.preventDefault()
const ext = getFileExtension(file.path)
if (textExts.includes(ext)) {
const selectedFile = await window.api.file.get(file.path)
selectedFile && setFiles((prevFiles) => [...prevFiles, selectedFile])
}
}
if (event.clipboardData?.items) {
const item = event.clipboardData.items[0]
const file = item.getAsFile()
if (file && file.type.startsWith('image/')) {
const tempFilePath = await window.api.file.create(file.name)
const arrayBuffer = await file.arrayBuffer()
const uint8Array = new Uint8Array(arrayBuffer)
await window.api.file.write(tempFilePath, uint8Array)
const selectedFile = await window.api.file.get(tempFilePath)
selectedFile && setFiles((prevFiles) => [...prevFiles, selectedFile])
}
// if (item.kind === 'string' && item.type === 'text/plain') {
// // 处理文本内容
// await new Promise<void>((resolve) => {
// item.getAsString(async (text) => {
// const tempFilePath = await window.api.file.create('pasted_text.txt')
// await window.api.file.write(tempFilePath, text)
// const selectedFile = await window.api.file.get(tempFilePath)
// if (selectedFile) {
// newFiles.push(selectedFile)
// }
// resolve()
// })
// })
// }
}
}, [])
// Command or Ctrl + N create new topic
useEffect(() => {
const onKeydown = (e) => {
@ -206,6 +245,11 @@ const Inputbar: FC<Props> = ({ assistant, setActiveTopic }) => {
textareaRef.current?.focus()
}, [assistant])
useEffect(() => {
document.addEventListener('paste', onPaste)
return () => document.removeEventListener('paste', onPaste)
}, [onPaste])
return (
<Container>
<AttachmentPreview files={files} setFiles={setFiles} />

View File

@ -18,49 +18,33 @@ export default class AnthropicProvider extends BaseProvider {
this.sdk = new Anthropic({ apiKey: provider.apiKey, baseURL: this.getBaseURL() })
}
private async getMessageParam(message: Message): Promise<MessageParam[]> {
const file = first(message.files)
private async getMessageParam(message: Message): Promise<MessageParam> {
const parts: MessageParam['content'] = [{ type: 'text', text: message.content }]
if (file) {
for (const file of message.files || []) {
if (file.type === FileTypes.IMAGE) {
const base64Data = await window.api.file.base64Image(file.id + file.ext)
return [
{
role: message.role,
content: [
{ type: 'text', text: message.content },
{
type: 'image',
source: {
data: base64Data.base64,
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
type: 'base64'
}
}
]
} as MessageParam
]
parts.push({
type: 'image',
source: {
data: base64Data.base64,
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
type: 'base64'
}
})
}
if (file.type === FileTypes.TEXT) {
return [
{
role: message.role,
content: message.content
} as MessageParam,
{
role: 'assistant',
content: (await window.api.file.read(file.id + file.ext)).trimEnd()
} as MessageParam
]
parts.push({
type: 'text',
text: (await window.api.file.read(file.id + file.ext)).trimEnd()
})
}
}
return [
{
role: message.role,
content: message.content
} as MessageParam
]
return {
role: message.role,
content: parts
}
}
public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) {

View File

@ -1,10 +1,10 @@
import { Content, GoogleGenerativeAI, InlineDataPart, TextPart } from '@google/generative-ai'
import { Content, GoogleGenerativeAI, InlineDataPart, Part, TextPart } from '@google/generative-ai'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant'
import { EVENT_NAMES } from '@renderer/services/event'
import { filterContextMessages, filterMessages } from '@renderer/services/messages'
import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types'
import axios from 'axios'
import { first, flatten, isEmpty, takeRight } from 'lodash'
import { flatten, isEmpty, takeRight } from 'lodash'
import OpenAI from 'openai'
import BaseProvider from './BaseProvider'
@ -17,48 +17,37 @@ export default class GeminiProvider extends BaseProvider {
this.sdk = new GoogleGenerativeAI(provider.apiKey)
}
private async getMessageContents(message: Message): Promise<Content[]> {
const file = first(message.files)
private async getMessageContents(message: Message): Promise<Content> {
const role = message.role === 'user' ? 'user' : 'model'
if (file) {
const parts: Part[] = [
{
type: 'text',
text: message.content
} as TextPart
]
for (const file of message.files || []) {
if (file.type === FileTypes.IMAGE) {
const base64Data = await window.api.file.base64Image(file.id + file.ext)
return [
{
role: message.role,
parts: [
{ text: message.content } as TextPart,
{
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
}
} as InlineDataPart
]
parts.push({
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
}
]
} as InlineDataPart)
}
if (file.type === FileTypes.TEXT) {
return [
{
role: 'model',
parts: [{ text: await window.api.file.read(file.id + file.ext) } as TextPart]
},
{
role,
parts: [{ text: message.content } as TextPart]
}
]
parts.push({
text: await window.api.file.read(file.id + file.ext)
} as TextPart)
}
}
return [
{
role,
parts: [{ text: message.content } as TextPart]
}
]
return {
role,
parts: parts
}
}
public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) {

View File

@ -33,49 +33,34 @@ export default class OpenAIProvider extends BaseProvider {
return true
}
private async getMessageParam(message: Message): Promise<OpenAI.Chat.Completions.ChatCompletionMessageParam[]> {
const file = first(message.files)
private async getMessageParam(message: Message): Promise<OpenAI.Chat.Completions.ChatCompletionMessageParam> {
const parts: ChatCompletionContentPart[] = [
{
type: 'text',
text: message.content
}
]
const content: string | ChatCompletionContentPart[] = message.content
if (file) {
for (const file of message.files || []) {
if (file.type === FileTypes.IMAGE) {
const image = await window.api.file.base64Image(file.id + file.ext)
return [
{
role: message.role,
content: [
{ type: 'text', text: message.content },
{
type: 'image_url',
image_url: {
url: image.data
}
}
]
} as ChatCompletionMessageParam
]
parts.push({
type: 'image_url',
image_url: { url: image.data }
})
}
if (file.type === FileTypes.TEXT) {
return [
{
role: 'assistant',
content: await window.api.file.read(file.id + file.ext)
} as ChatCompletionMessageParam,
{
role: message.role,
content
} as ChatCompletionMessageParam
]
parts.push({
type: 'text',
text: await window.api.file.read(file.id + file.ext)
})
}
}
return [
{
role: message.role,
content
} as ChatCompletionMessageParam
]
return {
role: message.role,
content: parts
} as ChatCompletionMessageParam
}
async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
@ -84,13 +69,13 @@ export default class OpenAIProvider extends BaseProvider {
const { contextCount, maxTokens } = getAssistantSettings(assistant)
const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
let userMessages: ChatCompletionMessageParam[] = []
const userMessages: ChatCompletionMessageParam[] = []
const _messages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 1)))
onFilterMessages(_messages)
for (const message of _messages) {
userMessages = userMessages.concat(await this.getMessageParam(message))
userMessages.push(await this.getMessageParam(message))
}
// @ts-ignore key is not typed

View File

@ -235,3 +235,9 @@ export function getFileDirectory(filePath: string) {
const directory = parts.slice(0, -1).join('/')
return directory
}
export function getFileExtension(filePath: string) {
const parts = filePath.split('.')
const extension = parts.slice(-1)[0]
return '.' + extension
}