diff --git a/package.json b/package.json index a008d836..fb5b75b4 100644 --- a/package.json +++ b/package.json @@ -59,6 +59,7 @@ "dotenv-cli": "^7.4.2", "electron": "^28.3.3", "electron-builder": "^24.9.1", + "electron-devtools-installer": "^3.2.0", "electron-vite": "^2.0.0", "emittery": "^1.0.3", "emoji-picker-element": "^1.22.1", @@ -67,7 +68,7 @@ "eslint-plugin-react-hooks": "^4.6.2", "eslint-plugin-simple-import-sort": "^12.1.1", "eslint-plugin-unused-imports": "^4.0.0", - "gpt-tokens": "^1.3.6", + "gpt-tokens": "^1.3.10", "i18next": "^23.11.5", "localforage": "^1.10.0", "lodash": "^4.17.21", diff --git a/src/main/index.ts b/src/main/index.ts index 2c96ff14..b1d17b4b 100644 --- a/src/main/index.ts +++ b/src/main/index.ts @@ -1,5 +1,6 @@ import { electronApp, optimizer } from '@electron-toolkit/utils' import { app, BrowserWindow } from 'electron' +import installExtension, { REDUX_DEVTOOLS } from 'electron-devtools-installer' import { registerIpc } from './ipc' import { updateUserDataPath } from './utils/upgrade' @@ -30,6 +31,12 @@ app.whenReady().then(async () => { const mainWindow = createMainWindow() registerIpc(mainWindow, app) + + if (process.env.NODE_ENV === 'development') { + installExtension(REDUX_DEVTOOLS) + .then((name) => console.log(`Added Extension: ${name}`)) + .catch((err) => console.log('An error occurred: ', err)) + } }) // Quit when all windows are closed, except on macOS. 
diff --git a/src/main/ipc.ts b/src/main/ipc.ts index 6f39d333..8558a96c 100644 --- a/src/main/ipc.ts +++ b/src/main/ipc.ts @@ -1,8 +1,5 @@ import { FileType } from '@types' import { BrowserWindow, ipcMain, OpenDialogOptions, session, shell } from 'electron' -import Logger from 'electron-log' -import fs from 'fs' -import path from 'path' import { appConfig, titleBarOverlayDark, titleBarOverlayLight } from './config' import AppUpdater from './services/AppUpdater' @@ -38,29 +35,12 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle('zip:compress', (_, text: string) => compress(text)) ipcMain.handle('zip:decompress', (_, text: Buffer) => decompress(text)) - ipcMain.handle('image:base64', async (_, filePath) => { - try { - const data = await fs.promises.readFile(filePath) - const base64 = data.toString('base64') - const mime = `image/${path.extname(filePath).slice(1)}` - return { - mime, - base64, - data: `data:${mime};base64,${base64}` - } - } catch (error) { - Logger.error('Error reading file:', error) - return '' - } - }) - + ipcMain.handle('file:base64Image', async (_, id) => await fileManager.base64Image(id)) ipcMain.handle('file:select', async (_, options?: OpenDialogOptions) => await fileManager.selectFile(options)) ipcMain.handle('file:upload', async (_, file: FileType) => await fileManager.uploadFile(file)) ipcMain.handle('file:clear', async () => await fileManager.clear()) - ipcMain.handle('file:delete', async (_, fileId: string) => { - await fileManager.deleteFile(fileId) - return { success: true } - }) + ipcMain.handle('file:read', async (_, id: string) => await fileManager.readFile(id)) + ipcMain.handle('file:delete', async (_, id: string) => await fileManager.deleteFile(id)) ipcMain.handle('minapp', (_, args) => { createMinappWindow({ diff --git a/src/main/services/File.ts b/src/main/services/File.ts index 0f551d53..8d511353 100644 --- a/src/main/services/File.ts +++ b/src/main/services/File.ts 
@@ -135,6 +135,23 @@ class File { await fs.promises.unlink(path.join(this.storageDir, id)) } + async readFile(id: string): Promise { + const filePath = path.join(this.storageDir, id) + return fs.readFileSync(filePath, 'utf8') + } + + async base64Image(id: string): Promise<{ mime: string; base64: string; data: string }> { + const filePath = path.join(this.storageDir, id) + const data = await fs.promises.readFile(filePath) + const base64 = data.toString('base64') + const mime = `image/${path.extname(filePath).slice(1)}` + return { + mime, + base64, + data: `data:${mime};base64,${base64}` + } + } + async clear(): Promise { await fs.promises.rmdir(this.storageDir, { recursive: true }) await this.initStorageDir() diff --git a/src/main/utils/file.ts b/src/main/utils/file.ts index ab979a90..9bfdf58c 100644 --- a/src/main/utils/file.ts +++ b/src/main/utils/file.ts @@ -56,12 +56,103 @@ export function getFileType(ext: string): FileTypes { const imageExts = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'] const videoExts = ['.mp4', '.avi', '.mov', '.wmv', '.flv', '.mkv'] const audioExts = ['.mp3', '.wav', '.ogg', '.flac', '.aac'] - const documentExts = ['.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', '.txt'] + const documentExts = ['.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx'] + const textExts = [ + '.txt', // 普通文本文件 + '.md', // Markdown 文件 + '.mdx', // Markdown 文件 + '.html', // HTML 文件 + '.htm', // HTML 文件的另一种扩展名 + '.xml', // XML 文件 + '.json', // JSON 文件 + '.yaml', // YAML 文件 + '.yml', // YAML 文件的另一种扩展名 + '.csv', // 逗号分隔值文件 + '.tsv', // 制表符分隔值文件 + '.ini', // 配置文件 + '.log', // 日志文件 + '.rtf', // 富文本格式文件 + '.tex', // LaTeX 文件 + '.srt', // 字幕文件 + '.xhtml', // XHTML 文件 + '.nfo', // 信息文件(主要用于场景发布) + '.conf', // 配置文件 + '.config', // 配置文件 + '.env', // 环境变量文件 + '.properties', // 配置属性文件 + '.latex', // LaTeX 文档文件 + '.rst', // reStructuredText 文件 + '.php', // PHP 脚本文件,包含嵌入的 HTML + '.js', // JavaScript 文件(部分是文本,部分可能包含代码) + '.ts', // TypeScript 文件 + '.jsp', // 
JavaServer Pages 文件 + '.aspx', // ASP.NET 文件 + '.bat', // Windows 批处理文件 + '.sh', // Unix/Linux Shell 脚本文件 + '.py', // Python 脚本文件 + '.rb', // Ruby 脚本文件 + '.pl', // Perl 脚本文件 + '.sql', // SQL 脚本文件 + '.css', // Cascading Style Sheets 文件 + '.less', // Less CSS 预处理器文件 + '.scss', // Sass CSS 预处理器文件 + '.sass', // Sass 文件 + '.styl', // Stylus CSS 预处理器文件 + '.coffee', // CoffeeScript 文件 + '.ino', // Arduino 代码文件 + '.ino', // Arduino 代码文件 + '.asm', // Assembly 语言文件 + '.go', // Go 语言文件 + '.scala', // Scala 语言文件 + '.swift', // Swift 语言文件 + '.kt', // Kotlin 语言文件 + '.rs', // Rust 语言文件 + '.lua', // Lua 语言文件 + '.groovy', // Groovy 语言文件 + '.dart', // Dart 语言文件 + '.hs', // Haskell 语言文件 + '.clj', // Clojure 语言文件 + '.cljs', // ClojureScript 语言文件 + '.elm', // Elm 语言文件 + '.erl', // Erlang 语言文件 + '.ex', // Elixir 语言文件 + '.exs', // Elixir 脚本文件 + '.pug', // Pug (formerly Jade) 模板文件 + '.haml', // Haml 模板文件 + '.slim', // Slim 模板文件 + '.tpl', // 模板文件(通用) + '.ejs', // Embedded JavaScript 模板文件 + '.hbs', // Handlebars 模板文件 + '.mustache', // Mustache 模板文件 + '.jade', // Jade 模板文件 (已重命名为 Pug) + '.twig', // Twig 模板文件 + '.blade', // Blade 模板文件 (Laravel) + '.vue', // Vue.js 单文件组件 + '.jsx', // React JSX 文件 + '.tsx', // React TSX 文件 + '.graphql', // GraphQL 查询语言文件 + '.gql', // GraphQL 查询语言文件 + '.proto', // Protocol Buffers 文件 + '.thrift', // Thrift 文件 + '.toml', // TOML 配置文件 + '.edn', // Clojure 数据表示文件 + '.cake', // CakePHP 配置文件 + '.ctp', // CakePHP 视图文件 + '.cfm', // ColdFusion 标记语言文件 + '.cfc', // ColdFusion 组件文件 + '.m', // Objective-C 源文件 + '.mm', // Objective-C++ 源文件 + '.gradle', // Gradle 构建文件 + '.groovy', // Gradle 构建文件 + '.gradle', // Gradle 构建文件 + '.kts' // Kotlin Script 文件 + ] ext = ext.toLowerCase() if (imageExts.includes(ext)) return FileTypes.IMAGE if (videoExts.includes(ext)) return FileTypes.VIDEO if (audioExts.includes(ext)) return FileTypes.AUDIO + if (textExts.includes(ext)) return FileTypes.TEXT if (documentExts.includes(ext)) return FileTypes.DOCUMENT return FileTypes.OTHER } diff --git 
a/src/preload/index.d.ts b/src/preload/index.d.ts index 3e5b1123..ac379e1d 100644 --- a/src/preload/index.d.ts +++ b/src/preload/index.d.ts @@ -24,12 +24,11 @@ declare global { file: { select: (options?: OpenDialogOptions) => Promise upload: (file: FileType) => Promise - delete: (fileId: string) => Promise<{ success: boolean }> + delete: (fileId: string) => Promise + read: (fileId: string) => Promise + base64Image: (fileId: string) => Promise<{ mime: string; base64: string; data: string }> clear: () => Promise } - image: { - base64: (filePath: string) => Promise<{ mime: string; base64: string; data: string }> - } } } } diff --git a/src/preload/index.ts b/src/preload/index.ts index f7ed8ee8..d1e85c7c 100644 --- a/src/preload/index.ts +++ b/src/preload/index.ts @@ -20,10 +20,9 @@ const api = { select: (options?: OpenDialogOptions) => ipcRenderer.invoke('file:select', options), upload: (filePath: string) => ipcRenderer.invoke('file:upload', filePath), delete: (fileId: string) => ipcRenderer.invoke('file:delete', fileId), + read: (fileId: string) => ipcRenderer.invoke('file:read', fileId), + base64Image: (fileId: string) => ipcRenderer.invoke('file:base64Image', fileId), clear: () => ipcRenderer.invoke('file:clear') - }, - image: { - base64: (filePath: string) => ipcRenderer.invoke('image:base64', filePath) } } diff --git a/src/renderer/src/assets/images/models/minicpm.webp b/src/renderer/src/assets/images/models/minicpm.webp new file mode 100644 index 00000000..15f66155 Binary files /dev/null and b/src/renderer/src/assets/images/models/minicpm.webp differ diff --git a/src/renderer/src/config/constant.ts b/src/renderer/src/config/constant.ts index b33c51a9..e2cd9a94 100644 --- a/src/renderer/src/config/constant.ts +++ b/src/renderer/src/config/constant.ts @@ -7,3 +7,95 @@ export const platform = window.electron?.process?.platform export const isMac = platform === 'darwin' export const isWindows = platform === 'win32' || platform === 'win64' export const isLinux = 
platform === 'linux' + +export const imageExts = ['jpg', 'png', 'jpeg'] +export const textExts = [ + '.txt', // 普通文本文件 + '.md', // Markdown 文件 + '.mdx', // Markdown 文件 + '.html', // HTML 文件 + '.htm', // HTML 文件的另一种扩展名 + '.xml', // XML 文件 + '.json', // JSON 文件 + '.yaml', // YAML 文件 + '.yml', // YAML 文件的另一种扩展名 + '.csv', // 逗号分隔值文件 + '.tsv', // 制表符分隔值文件 + '.ini', // 配置文件 + '.log', // 日志文件 + '.rtf', // 富文本格式文件 + '.tex', // LaTeX 文件 + '.srt', // 字幕文件 + '.xhtml', // XHTML 文件 + '.nfo', // 信息文件(主要用于场景发布) + '.conf', // 配置文件 + '.config', // 配置文件 + '.env', // 环境变量文件 + '.properties', // 配置属性文件 + '.latex', // LaTeX 文档文件 + '.rst', // reStructuredText 文件 + '.php', // PHP 脚本文件,包含嵌入的 HTML + '.js', // JavaScript 文件(部分是文本,部分可能包含代码) + '.ts', // TypeScript 文件 + '.jsp', // JavaServer Pages 文件 + '.aspx', // ASP.NET 文件 + '.bat', // Windows 批处理文件 + '.sh', // Unix/Linux Shell 脚本文件 + '.py', // Python 脚本文件 + '.rb', // Ruby 脚本文件 + '.pl', // Perl 脚本文件 + '.sql', // SQL 脚本文件 + '.css', // Cascading Style Sheets 文件 + '.less', // Less CSS 预处理器文件 + '.scss', // Sass CSS 预处理器文件 + '.sass', // Sass 文件 + '.styl', // Stylus CSS 预处理器文件 + '.coffee', // CoffeeScript 文件 + '.ino', // Arduino 代码文件 + '.ino', // Arduino 代码文件 + '.asm', // Assembly 语言文件 + '.go', // Go 语言文件 + '.scala', // Scala 语言文件 + '.swift', // Swift 语言文件 + '.kt', // Kotlin 语言文件 + '.rs', // Rust 语言文件 + '.lua', // Lua 语言文件 + '.groovy', // Groovy 语言文件 + '.dart', // Dart 语言文件 + '.hs', // Haskell 语言文件 + '.clj', // Clojure 语言文件 + '.cljs', // ClojureScript 语言文件 + '.elm', // Elm 语言文件 + '.erl', // Erlang 语言文件 + '.ex', // Elixir 语言文件 + '.exs', // Elixir 脚本文件 + '.pug', // Pug (formerly Jade) 模板文件 + '.haml', // Haml 模板文件 + '.slim', // Slim 模板文件 + '.tpl', // 模板文件(通用) + '.ejs', // Embedded JavaScript 模板文件 + '.hbs', // Handlebars 模板文件 + '.mustache', // Mustache 模板文件 + '.jade', // Jade 模板文件 (已重命名为 Pug) + '.twig', // Twig 模板文件 + '.blade', // Blade 模板文件 (Laravel) + '.vue', // Vue.js 单文件组件 + '.jsx', // React JSX 文件 + '.tsx', // React TSX 文件 + '.graphql', // GraphQL 
查询语言文件 + '.gql', // GraphQL 查询语言文件 + '.proto', // Protocol Buffers 文件 + '.thrift', // Thrift 文件 + '.toml', // TOML 配置文件 + '.edn', // Clojure 数据表示文件 + '.cake', // CakePHP 配置文件 + '.ctp', // CakePHP 视图文件 + '.cfm', // ColdFusion 标记语言文件 + '.cfc', // ColdFusion 组件文件 + '.m', // Objective-C 源文件 + '.mm', // Objective-C++ 源文件 + '.gradle', // Gradle 构建文件 + '.groovy', // Gradle 构建文件 + '.gradle', // Gradle 构建文件 + '.kts' // Kotlin Script 文件 +] diff --git a/src/renderer/src/config/provider.ts b/src/renderer/src/config/provider.ts index c9b7a17a..5593696d 100644 --- a/src/renderer/src/config/provider.ts +++ b/src/renderer/src/config/provider.ts @@ -14,6 +14,7 @@ import GemmaModelLogo from '@renderer/assets/images/models/gemma.jpeg' import HailuoModelLogo from '@renderer/assets/images/models/hailuo.png' import LlamaModelLogo from '@renderer/assets/images/models/llama.jpeg' import MicrosoftModelLogo from '@renderer/assets/images/models/microsoft.png' +import MinicpmModelLogo from '@renderer/assets/images/models/minicpm.webp' import MixtralModelLogo from '@renderer/assets/images/models/mixtral.jpeg' import PalmModelLogo from '@renderer/assets/images/models/palm.svg' import QwenModelLogo from '@renderer/assets/images/models/qwen.png' @@ -91,6 +92,7 @@ export function getModelLogo(modelId: string) { } const logoMap = { + o1: OpenAiProviderLogo, gpt: ChatGPTModelLogo, glm: ChatGLMModelLogo, deepseek: DeepSeekModelLogo, @@ -112,7 +114,8 @@ export function getModelLogo(modelId: string) { abab: HailuoModelLogo, 'ep-202': DoubaoModelLogo, cohere: CohereModelLogo, - command: CohereModelLogo + command: CohereModelLogo, + minicpm: MinicpmModelLogo } for (const key in logoMap) { diff --git a/src/renderer/src/i18n/index.ts b/src/renderer/src/i18n/index.ts index 51c6bfe3..320dec3b 100644 --- a/src/renderer/src/i18n/index.ts +++ b/src/renderer/src/i18n/index.ts @@ -88,7 +88,7 @@ const resources = { 'input.send': 'Send', 'input.pause': 'Pause', 'input.settings': 'Settings', - 'input.upload': 
'Upload image png、jpg、jpeg', + 'input.upload': 'Upload image or text file', 'input.context_count.tip': 'Context Count', 'input.estimated_tokens.tip': 'Estimated tokens', 'settings.temperature': 'Temperature', @@ -356,7 +356,7 @@ const resources = { 'input.send': '发送', 'input.pause': '暂停', 'input.settings': '设置', - 'input.upload': '上传图片 png、jpg、jpeg', + 'input.upload': '上传图片或纯文本文件', 'input.context_count.tip': '上下文数', 'input.estimated_tokens.tip': '预估 token 数', 'settings.temperature': '模型温度', diff --git a/src/renderer/src/pages/files/FilesPage.tsx b/src/renderer/src/pages/files/FilesPage.tsx index 62c10cd7..2ba4b82d 100644 --- a/src/renderer/src/pages/files/FilesPage.tsx +++ b/src/renderer/src/pages/files/FilesPage.tsx @@ -1,7 +1,8 @@ import { Navbar, NavbarCenter } from '@renderer/components/app/Navbar' import { VStack } from '@renderer/components/Layout' import db from '@renderer/databases' -import { FileType } from '@renderer/types' +import { FileType, FileTypes } from '@renderer/types' +import { getFileDirectory } from '@renderer/utils' import { Image, Table } from 'antd' import dayjs from 'dayjs' import { useLiveQuery } from 'dexie-react-hooks' @@ -13,13 +14,17 @@ const FilesPage: FC = () => { const { t } = useTranslation() const files = useLiveQuery(() => db.files.toArray()) - const dataSource = files?.map((file) => ({ - key: file.id, - file: , - name: {file.origin_name}, - size: `${(file.size / 1024 / 1024).toFixed(2)} MB`, - created_at: dayjs(file.created_at).format('MM-DD HH:mm') - })) + const dataSource = files?.map((file) => { + const isImage = file.type === FileTypes.IMAGE + const ImageView = + return { + key: file.id, + file: isImage ? 
ImageView : file.origin_name, + name: {file.origin_name}, + size: `${(file.size / 1024 / 1024).toFixed(2)} MB`, + created_at: dayjs(file.created_at).format('MM-DD HH:mm') + } + }) const columns = [ { diff --git a/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx b/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx index 9fe7dff3..474f7af6 100644 --- a/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx +++ b/src/renderer/src/pages/home/Inputbar/AttachmentButton.tsx @@ -1,4 +1,5 @@ import { PaperClipOutlined } from '@ant-design/icons' +import { imageExts, textExts } from '@renderer/config/constant' import { isVisionModel } from '@renderer/config/models' import { FileType, Model } from '@renderer/types' import { Tooltip } from 'antd' @@ -14,18 +15,13 @@ interface Props { const AttachmentButton: FC = ({ model, files, setFiles, ToolbarButton }) => { const { t } = useTranslation() + const extensions = isVisionModel(model) ? [...imageExts, ...textExts] : [...textExts] const onSelectFile = async () => { - const _files = await window.api.file.select({ - filters: [{ name: 'Files', extensions: ['jpg', 'png', 'jpeg'] }] - }) + const _files = await window.api.file.select({ filters: [{ name: 'Files', extensions }] }) _files && setFiles(_files) } - if (!isVisionModel(model)) { - return null - } - return ( diff --git a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx index c5ea036c..5f8d5bbc 100644 --- a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx +++ b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx @@ -15,7 +15,7 @@ import { useRuntime, useShowTopics } from '@renderer/hooks/useStore' import { getDefaultTopic } from '@renderer/services/assistant' import { EVENT_NAMES, EventEmitter } from '@renderer/services/event' import FileManager from '@renderer/services/file' -import { estimateInputTokenCount } from '@renderer/services/messages' +import { estimateTextTokens } from '@renderer/services/tokens' 
import store, { useAppDispatch, useAppSelector } from '@renderer/store' import { setGenerating, setSearching } from '@renderer/store/runtime' import { Assistant, FileType, Message, Topic } from '@renderer/types' @@ -92,7 +92,7 @@ const Inputbar: FC = ({ assistant, setActiveTopic }) => { setExpend(false) }, [assistant.id, assistant.topics, generating, files, text]) - const inputTokenCount = useMemo(() => estimateInputTokenCount(text), [text]) + const inputTokenCount = useMemo(() => estimateTextTokens(text), [text]) const handleKeyDown = (event: React.KeyboardEvent) => { const isEnterPressed = event.keyCode == 13 diff --git a/src/renderer/src/pages/home/Inputbar/TokenCount.tsx b/src/renderer/src/pages/home/Inputbar/TokenCount.tsx index 7d53f335..8d142185 100644 --- a/src/renderer/src/pages/home/Inputbar/TokenCount.tsx +++ b/src/renderer/src/pages/home/Inputbar/TokenCount.tsx @@ -44,8 +44,8 @@ const TokenCount: FC = ({ estimateTokenCount, inputTokenCount, contextCou - - + + {contextCount} diff --git a/src/renderer/src/pages/home/Messages/Message.tsx b/src/renderer/src/pages/home/Messages/Message.tsx index dfbcf576..51e57195 100644 --- a/src/renderer/src/pages/home/Messages/Message.tsx +++ b/src/renderer/src/pages/home/Messages/Message.tsx @@ -46,14 +46,12 @@ const MessageItem: FC = ({ message, index, showMenu, onDeleteMessage }) = const { assistant, setModel } = useAssistant(message.assistantId) const model = useModel(message.modelId) const { userName, showMessageDivider, messageFont, fontSize } = useSettings() - const { generating } = useRuntime() const [copied, setCopied] = useState(false) const isLastMessage = index === 0 const isUserMessage = message.role === 'user' const isAssistantMessage = message.role === 'assistant' const canRegenerate = isLastMessage && isAssistantMessage - const showMetadata = Boolean(message.usage) && !generating const onCopy = useCallback(() => { navigator.clipboard.writeText(removeTrailingDoubleSpaces(message.content)) @@ -133,7 +131,7 
@@ const MessageItem: FC = ({ message, index, showMenu, onDeleteMessage }) = style={{ borderRadius: '20%', cursor: 'pointer', - border: isLocalAi ? '1px solid var(--color-border)' : '' + border: '1px solid var(--color-border)' }} onClick={showMiniApp}> {avatarName} @@ -206,18 +204,39 @@ const MessageItem: FC = ({ message, index, showMenu, onDeleteMessage }) = )} )} - {showMetadata && ( - - Tokens: {message?.usage?.total_tokens} | ↑{message?.usage?.prompt_tokens} | ↓ - {message?.usage?.completion_tokens} - - )} + ) } +const MessgeTokens: React.FC<{ message: Message }> = ({ message }) => { + const { generating } = useRuntime() + + if (!message.usage) { + return null + } + + if (message.role === 'user') { + return Tokens: {message?.usage?.total_tokens} + } + + if (generating) { + return null + } + + if (message.role === 'assistant') { + return ( + + Tokens: {message?.usage?.total_tokens} | ↑{message?.usage?.prompt_tokens} | ↓{message?.usage?.completion_tokens} + + ) + } + + return null +} + const MessageContent: React.FC<{ message: Message }> = ({ message }) => { const { t } = useTranslation() diff --git a/src/renderer/src/pages/home/Messages/MessageAttachments.tsx b/src/renderer/src/pages/home/Messages/MessageAttachments.tsx index 8c7044d9..e82e99ab 100644 --- a/src/renderer/src/pages/home/Messages/MessageAttachments.tsx +++ b/src/renderer/src/pages/home/Messages/MessageAttachments.tsx @@ -1,5 +1,6 @@ -import { Message } from '@renderer/types' -import { Image as AntdImage } from 'antd' +import { FileTypes, Message } from '@renderer/types' +import { getFileDirectory } from '@renderer/utils' +import { Image as AntdImage, Upload } from 'antd' import { FC } from 'react' import styled from 'styled-components' @@ -8,9 +9,27 @@ interface Props { } const MessageAttachments: FC = ({ message }) => { + if (message?.files && message.files[0]?.type === FileTypes.IMAGE) { + return ( + + {message.files?.map((image) => )} + + ) + } + return ( - - {message.files?.map((image) => )} + 
+ item.url && window.open(getFileDirectory(item.url))} + fileList={message.files?.map((file) => ({ + uid: file.id, + url: 'file://' + file.path, + status: 'done', + name: file.origin_name + }))} + /> ) } diff --git a/src/renderer/src/pages/home/Messages/Messages.tsx b/src/renderer/src/pages/home/Messages/Messages.tsx index 6e6c8046..ef885e76 100644 --- a/src/renderer/src/pages/home/Messages/Messages.tsx +++ b/src/renderer/src/pages/home/Messages/Messages.tsx @@ -4,16 +4,12 @@ import { getTopic, TopicManager } from '@renderer/hooks/useTopic' import { fetchChatCompletion, fetchMessagesSummary } from '@renderer/services/api' import { getDefaultTopic } from '@renderer/services/assistant' import { EVENT_NAMES, EventEmitter } from '@renderer/services/event' -import { - deleteMessageFiles, - estimateHistoryTokenCount, - filterMessages, - getContextCount -} from '@renderer/services/messages' +import { deleteMessageFiles, filterMessages, getContextCount } from '@renderer/services/messages' +import { estimateHistoryTokens, estimateMessageUsage } from '@renderer/services/tokens' import { Assistant, Message, Model, Topic } from '@renderer/types' import { getBriefInfo, runAsyncFunction, uuid } from '@renderer/utils' import { t } from 'i18next' -import { last, reverse, take } from 'lodash' +import { flatten, last, reverse, take } from 'lodash' import { FC, useCallback, useEffect, useRef, useState } from 'react' import styled from 'styled-components' @@ -34,12 +30,15 @@ const Messages: FC = ({ assistant, topic, setActiveTopic }) => { const { updateTopic, addTopic } = useAssistant(assistant.id) const onSendMessage = useCallback( - (message: Message) => { + async (message: Message) => { + if (message.role === 'user') { + message.usage = await estimateMessageUsage(message) + } const _messages = [...messages, message] setMessages(_messages) db.topics.put({ id: topic.id, messages: _messages }) }, - [messages, topic] + [messages, topic.id] ) const autoRenameTopic = useCallback(async () 
=> { @@ -68,9 +67,14 @@ const Messages: FC = ({ assistant, topic, setActiveTopic }) => { const unsubscribes = [ EventEmitter.on(EVENT_NAMES.SEND_MESSAGE, async (msg: Message) => { onSendMessage(msg) - fetchChatCompletion({ assistant, messages: [...messages, msg], topic, onResponse: setLastMessage }) + fetchChatCompletion({ + assistant, + messages: [...messages, msg], + topic, + onResponse: setLastMessage + }) }), - EventEmitter.on(EVENT_NAMES.AI_CHAT_COMPLETION, async (msg: Message) => { + EventEmitter.on(EVENT_NAMES.RECEIVE_MESSAGE, async (msg: Message) => { setLastMessage(null) onSendMessage(msg) setTimeout(() => EventEmitter.emit(EVENT_NAMES.AI_AUTO_RENAME), 100) @@ -98,6 +102,7 @@ const Messages: FC = ({ assistant, topic, setActiveTopic }) => { const lastMessage = last(messages) if (lastMessage && lastMessage.type === 'clear') { + onDeleteMessage(lastMessage) return } @@ -117,16 +122,37 @@ const Messages: FC = ({ assistant, topic, setActiveTopic }) => { } as Message) }), EventEmitter.on(EVENT_NAMES.NEW_BRANCH, async (index: number) => { - const _topic = getDefaultTopic() - _topic.name = topic.name - await db.topics.add({ id: _topic.id, messages: take(messages, messages.length - index) }) - addTopic(_topic) - setActiveTopic(_topic) + const newTopic = getDefaultTopic() + newTopic.name = topic.name + const branchMessages = take(messages, messages.length - index) + + // 将分支的消息放入数据库 + await db.topics.add({ id: newTopic.id, messages: branchMessages }) + addTopic(newTopic) + setActiveTopic(newTopic) autoRenameTopic() + + // 由于复制了消息,消息中附带的文件的总数变了,需要更新 + const filesArr = branchMessages.map((m) => m.files) + const files = flatten(filesArr).filter(Boolean) + files.map(async (f) => { + const file = await db.files.get({ id: f?.id }) + file && db.files.update(file.id, { count: file.count + 1 }) + }) }) ] return () => unsubscribes.forEach((unsub) => unsub()) - }, [addTopic, assistant, autoRenameTopic, messages, onSendMessage, setActiveTopic, topic, updateTopic]) + }, [ + 
addTopic, + assistant, + autoRenameTopic, + messages, + onDeleteMessage, + onSendMessage, + setActiveTopic, + topic, + updateTopic + ]) useEffect(() => { runAsyncFunction(async () => { @@ -140,9 +166,11 @@ const Messages: FC = ({ assistant, topic, setActiveTopic }) => { }, [messages]) useEffect(() => { - EventEmitter.emit(EVENT_NAMES.ESTIMATED_TOKEN_COUNT, { - tokensCount: estimateHistoryTokenCount(assistant, messages), - contextCount: getContextCount(assistant, messages) + runAsyncFunction(async () => { + EventEmitter.emit(EVENT_NAMES.ESTIMATED_TOKEN_COUNT, { + tokensCount: await estimateHistoryTokens(assistant, messages), + contextCount: getContextCount(assistant, messages) + }) }) }, [assistant, messages]) diff --git a/src/renderer/src/pages/home/components/Suggestions.tsx b/src/renderer/src/pages/home/components/Suggestions.tsx index 83b34062..9cdb849a 100644 --- a/src/renderer/src/pages/home/components/Suggestions.tsx +++ b/src/renderer/src/pages/home/components/Suggestions.tsx @@ -37,7 +37,7 @@ const Suggestions: FC = ({ assistant, messages, lastMessage }) => { useEffect(() => { const unsubscribes = [ - EventEmitter.on(EVENT_NAMES.AI_CHAT_COMPLETION, async (msg: Message) => { + EventEmitter.on(EVENT_NAMES.RECEIVE_MESSAGE, async (msg: Message) => { setLoadingSuggestions(true) const _suggestions = await fetchSuggestions({ assistant, messages: [...messages, msg] }) if (_suggestions.length) { diff --git a/src/renderer/src/providers/AiProvider.ts b/src/renderer/src/providers/AiProvider.ts index 2ee7a295..82a5c0c1 100644 --- a/src/renderer/src/providers/AiProvider.ts +++ b/src/renderer/src/providers/AiProvider.ts @@ -10,12 +10,8 @@ export default class AiProvider { this.sdk = ProviderFactory.create(provider) } - public async completions( - messages: Message[], - assistant: Assistant, - onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void - ): Promise { - return this.sdk.completions(messages, assistant, onChunk) + public 
async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise { + return this.sdk.completions({ messages, assistant, onChunk, onFilterMessages }) } public async translate(message: Message, assistant: Assistant): Promise { diff --git a/src/renderer/src/providers/AnthropicProvider.ts b/src/renderer/src/providers/AnthropicProvider.ts index abd4b5be..d4e33e4c 100644 --- a/src/renderer/src/providers/AnthropicProvider.ts +++ b/src/renderer/src/providers/AnthropicProvider.ts @@ -4,8 +4,8 @@ import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant' import { EVENT_NAMES } from '@renderer/services/event' import { filterContextMessages, filterMessages } from '@renderer/services/messages' -import { Assistant, Message, Provider, Suggestion } from '@renderer/types' -import { first, sum, takeRight } from 'lodash' +import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types' +import { first, flatten, sum, takeRight } from 'lodash' import OpenAI from 'openai' import BaseProvider from './BaseProvider' @@ -18,49 +18,67 @@ export default class AnthropicProvider extends BaseProvider { this.sdk = new Anthropic({ apiKey: provider.apiKey, baseURL: this.getBaseURL() }) } - private async getMessageContent(message: Message): Promise { + private async getMessageParam(message: Message): Promise { const file = first(message.files) - if (!file) { - return message.content + if (file) { + if (file.type === FileTypes.IMAGE) { + const base64Data = await window.api.file.base64Image(file.id + file.ext) + return [ + { + role: message.role, + content: [ + { type: 'text', text: message.content }, + { + type: 'image', + source: { + data: base64Data.base64, + media_type: base64Data.mime.replace('jpg', 'jpeg') as any, + type: 'base64' + } + } + ] + } as MessageParam + ] + } + if (file.type === FileTypes.TEXT) { + return [ + { + role: 
message.role, + content: message.content + } as MessageParam, + { + role: 'assistant', + content: (await window.api.file.read(file.id + file.ext)).trimEnd() + } as MessageParam + ] + } } - if (file.type === 'image') { - const base64Data = await window.api.image.base64(file.path) - return [ - { type: 'text', text: message.content }, - { - type: 'image', - source: { - data: base64Data.base64, - media_type: base64Data.mime.replace('jpg', 'jpeg') as any, - type: 'base64' - } - } - ] - } - - return message.content + return [ + { + role: message.role, + content: message.content + } as MessageParam + ] } - public async completions( - messages: Message[], - assistant: Assistant, - onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void - ) { + public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) { const defaultModel = getDefaultModel() const model = assistant.model || defaultModel const { contextCount, maxTokens } = getAssistantSettings(assistant) - const userMessages: MessageParam[] = [] + let userMessagesParams: MessageParam[][] = [] + const _messages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 2))) - for (const message of filterMessages(filterContextMessages(takeRight(messages, contextCount + 2)))) { - userMessages.push({ - role: message.role, - content: await this.getMessageContent(message) - }) + onFilterMessages(_messages) + + for (const message of _messages) { + userMessagesParams = userMessagesParams.concat(await this.getMessageParam(message)) } + const userMessages = flatten(userMessagesParams) + if (first(userMessages)?.role === 'assistant') { userMessages.shift() } @@ -69,7 +87,7 @@ export default class AnthropicProvider extends BaseProvider { const stream = this.sdk.messages .stream({ model: model.id, - messages: userMessages.filter(Boolean) as MessageParam[], + messages: userMessages, max_tokens: maxTokens || DEFAULT_MAX_TOKENS, temperature: 
assistant?.settings?.temperature, system: assistant.prompt, diff --git a/src/renderer/src/providers/BaseProvider.ts b/src/renderer/src/providers/BaseProvider.ts index 7d36bd68..67efbf2d 100644 --- a/src/renderer/src/providers/BaseProvider.ts +++ b/src/renderer/src/providers/BaseProvider.ts @@ -20,11 +20,7 @@ export default abstract class BaseProvider { return this.provider.id === 'ollama' ? getOllamaKeepAliveTime() : undefined } - abstract completions( - messages: Message[], - assistant: Assistant, - onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void - ): Promise + abstract completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise abstract translate(message: Message, assistant: Assistant): Promise abstract summaries(messages: Message[], assistant: Assistant): Promise abstract suggestions(messages: Message[], assistant: Assistant): Promise diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index 4a92146a..c450820f 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -1,10 +1,10 @@ -import { Content, GoogleGenerativeAI, InlineDataPart, Part } from '@google/generative-ai' +import { Content, GoogleGenerativeAI, InlineDataPart, TextPart } from '@google/generative-ai' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant' import { EVENT_NAMES } from '@renderer/services/event' import { filterContextMessages, filterMessages } from '@renderer/services/messages' -import { Assistant, Message, Provider, Suggestion } from '@renderer/types' +import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types' import axios from 'axios' -import { first, isEmpty, takeRight } from 'lodash' +import { first, flatten, isEmpty, takeRight } from 'lodash' import OpenAI from 'openai' import BaseProvider from './BaseProvider' @@ -17,42 
+17,67 @@ export default class GeminiProvider extends BaseProvider { this.sdk = new GoogleGenerativeAI(provider.apiKey) } - private async getMessageParts(message: Message): Promise { + private async getMessageContents(message: Message): Promise { const file = first(message.files) + const role = message.role === 'user' ? 'user' : 'model' - if (file && file.type === 'image') { - const base64Data = await window.api.image.base64(file.path) - return [ - { - text: message.content - }, - { - inlineData: { - data: base64Data.base64, - mimeType: base64Data.mime + if (file) { + if (file.type === FileTypes.IMAGE) { + const base64Data = await window.api.file.base64Image(file.id + file.ext) + return [ + { + role: message.role, + parts: [ + { text: message.content } as TextPart, + { + inlineData: { + data: base64Data.base64, + mimeType: base64Data.mime + } + } as InlineDataPart + ] } - } as InlineDataPart - ] + ] + } + if (file.type === FileTypes.TEXT) { + return [ + { + role: 'model', + parts: [{ text: await window.api.file.read(file.id + file.ext) } as TextPart] + }, + { + role, + parts: [{ text: message.content } as TextPart] + } + ] + } } - return [{ text: message.content }] + return [ + { + role, + parts: [{ text: message.content } as TextPart] + } + ] } - public async completions( - messages: Message[], - assistant: Assistant, - onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void - ) { + public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) { const defaultModel = getDefaultModel() const model = assistant.model || defaultModel const { contextCount, maxTokens } = getAssistantSettings(assistant) - const userMessages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 1))).map((message) => { - return { - role: message.role, - message - } - }) + const userMessages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 1))) + 
onFilterMessages(userMessages) + + const userLastMessage = userMessages.pop() + + let historyContents: Content[][] = [] + + for (const message of userMessages) { + historyContents = historyContents.concat(await this.getMessageContents(message)) + } + + const history = flatten(historyContents) const geminiModel = this.sdk.getGenerativeModel({ model: model.id, @@ -63,21 +88,9 @@ export default class GeminiProvider extends BaseProvider { } }) - const userLastMessage = userMessages.pop() - - const history: Content[] = [] - - for (const message of userMessages) { - history.push({ - role: message.role === 'user' ? 'user' : 'model', - parts: await this.getMessageParts(message.message) - }) - } - const chat = geminiModel.startChat({ history }) - const message = await this.getMessageParts(userLastMessage?.message!) - - const userMessagesStream = await chat.sendMessageStream(message) + const messageContents = await this.getMessageContents(userLastMessage!) + const userMessagesStream = await chat.sendMessageStream(messageContents[0].parts) for await (const chunk of userMessagesStream.stream) { if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts index 90766a90..58ea55e9 100644 --- a/src/renderer/src/providers/OpenAIProvider.ts +++ b/src/renderer/src/providers/OpenAIProvider.ts @@ -2,7 +2,7 @@ import { isLocalAi } from '@renderer/config/env' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant' import { EVENT_NAMES } from '@renderer/services/event' import { filterContextMessages, filterMessages } from '@renderer/services/messages' -import { Assistant, Message, Provider, Suggestion } from '@renderer/types' +import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types' import { removeQuotes } from '@renderer/utils' import { first, takeRight } from 'lodash' import OpenAI from 'openai' @@ -26,61 
+26,92 @@ export default class OpenAIProvider extends BaseProvider { }) } - private async getMessageContent(message: Message): Promise { - const file = first(message.files) - - if (!file) { - return message.content + private isSupportStreamOutput(modelId: string): boolean { + if (this.provider.id === 'openai' && modelId.includes('o1-')) { + return false } - - if (file.type === 'image') { - const base64Data = await window.api.image.base64(file.path) - return [ - { type: 'text', text: message.content }, - { - type: 'image_url', - image_url: { - url: base64Data.data - } - } - ] - } - - return message.content + return true } - async completions( - messages: Message[], - assistant: Assistant, - onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void - ): Promise { + private async getMessageParam(message: Message): Promise { + const file = first(message.files) + + const content: string | ChatCompletionContentPart[] = message.content + + if (file) { + if (file.type === FileTypes.IMAGE) { + const image = await window.api.file.base64Image(file.id + file.ext) + return [ + { + role: message.role, + content: [ + { type: 'text', text: message.content }, + { + type: 'image_url', + image_url: { + url: image.data + } + } + ] + } as ChatCompletionMessageParam + ] + } + if (file.type === FileTypes.TEXT) { + return [ + { + role: 'assistant', + content: await window.api.file.read(file.id + file.ext) + } as ChatCompletionMessageParam, + { + role: message.role, + content + } as ChatCompletionMessageParam + ] + } + } + + return [ + { + role: message.role, + content + } as ChatCompletionMessageParam + ] + } + + async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise { const defaultModel = getDefaultModel() const model = assistant.model || defaultModel const { contextCount, maxTokens } = getAssistantSettings(assistant) const systemMessage = assistant.prompt ? 
{ role: 'system', content: assistant.prompt } : undefined - const userMessages: ChatCompletionMessageParam[] = [] + let userMessages: ChatCompletionMessageParam[] = [] - for (const message of filterMessages(filterContextMessages(takeRight(messages, contextCount + 1)))) { - userMessages.push({ - role: message.role, - content: await this.getMessageContent(message) - } as ChatCompletionMessageParam) + const _messages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 1))) + onFilterMessages(_messages) + + for (const message of _messages) { + userMessages = userMessages.concat(await this.getMessageParam(message)) } // @ts-ignore key is not typed const stream = await this.sdk.chat.completions.create({ model: model.id, messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[], - stream: true, + stream: this.isSupportStreamOutput(model.id), temperature: assistant?.settings?.temperature, max_tokens: maxTokens, keep_alive: this.keepAliveTime }) for await (const chunk of stream) { - if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break - onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage }) + if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) { + break + } + + onChunk({ + text: chunk.choices[0]?.delta?.content || '', + usage: chunk.usage + }) } } diff --git a/src/renderer/src/providers/index.d.ts b/src/renderer/src/providers/index.d.ts new file mode 100644 index 00000000..e894982a --- /dev/null +++ b/src/renderer/src/providers/index.d.ts @@ -0,0 +1,11 @@ +interface ChunkCallbackData { + text?: string + usage?: OpenAI.Completions.CompletionUsage +} + +interface CompletionsParams { + messages: Message[] + assistant: Assistant + onChunk: ({ text, usage }: ChunkCallbackData) => void + onFilterMessages: (messages: Message[]) => void +} diff --git a/src/renderer/src/services/api.ts b/src/renderer/src/services/api.ts index 39e00aec..a7ffb0ce 100644 --- a/src/renderer/src/services/api.ts 
+++ b/src/renderer/src/services/api.ts @@ -15,7 +15,8 @@ import { getTranslateModel } from './assistant' import { EVENT_NAMES, EventEmitter } from './event' -import { estimateMessagesToken, filterMessages } from './messages' +import { filterMessages } from './messages' +import { estimateMessagesUsage } from './tokens' export async function fetchChatCompletion({ messages, @@ -61,13 +62,27 @@ export async function fetchChatCompletion({ }, 1000) try { - await AI.completions(messages, assistant, ({ text, usage }) => { - message.content = message.content + text || '' - message.usage = usage - onResponse({ ...message, status: 'pending' }) + let _messages: Message[] = [] + + await AI.completions({ + messages, + assistant, + onFilterMessages: (messages) => (_messages = messages), + onChunk: ({ text, usage }) => { + message.content = message.content + text || '' + message.usage = usage + onResponse({ ...message, status: 'pending' }) + } }) + message.status = 'success' - message.usage = message.usage || (await estimateMessagesToken({ assistant, messages: [...messages, message] })) + + if (!message.usage) { + message.usage = await estimateMessagesUsage({ + assistant, + messages: [..._messages, message] + }) + } } catch (error: any) { message.content = `Error: ${error.message}` message.status = 'error' @@ -83,7 +98,7 @@ export async function fetchChatCompletion({ message.status = window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED) ? 
'paused' : message.status // Emit chat completion event - EventEmitter.emit(EVENT_NAMES.AI_CHAT_COMPLETION, message) + EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message) // Reset generating state store.dispatch(setGenerating(false)) diff --git a/src/renderer/src/services/event.ts b/src/renderer/src/services/event.ts index cdf402e3..d0b70b37 100644 --- a/src/renderer/src/services/event.ts +++ b/src/renderer/src/services/event.ts @@ -4,7 +4,7 @@ export const EventEmitter = new Emittery() export const EVENT_NAMES = { SEND_MESSAGE: 'SEND_MESSAGE', - AI_CHAT_COMPLETION: 'AI_CHAT_COMPLETION', + RECEIVE_MESSAGE: 'RECEIVE_MESSAGE', AI_AUTO_RENAME: 'AI_AUTO_RENAME', CLEAR_MESSAGES: 'CLEAR_MESSAGES', ADD_ASSISTANT: 'ADD_ASSISTANT', diff --git a/src/renderer/src/services/messages.ts b/src/renderer/src/services/messages.ts index 86418687..29a3a8ff 100644 --- a/src/renderer/src/services/messages.ts +++ b/src/renderer/src/services/messages.ts @@ -1,10 +1,7 @@ import { DEFAULT_CONEXTCOUNT } from '@renderer/config/constant' import { Assistant, Message } from '@renderer/types' -import { GPTTokens } from 'gpt-tokens' -import { isEmpty, last, takeRight } from 'lodash' -import { CompletionUsage } from 'openai/resources' +import { isEmpty, takeRight } from 'lodash' -import { getAssistantSettings } from './assistant' import FileManager from './file' export const filterMessages = (messages: Message[]) => { @@ -36,50 +33,6 @@ export function getContextCount(assistant: Assistant, messages: Message[]) { return messagesCount - (clearIndex + 1) } -export function estimateInputTokenCount(text: string) { - const input = new GPTTokens({ - model: 'gpt-4o', - messages: [{ role: 'user', content: text }] - }) - - return input.usedTokens - 7 -} - -export async function estimateMessagesToken({ - assistant, - messages -}: { - assistant: Assistant - messages: Message[] -}): Promise { - const responseMessageContent = last(messages)?.content - const inputMessageContent = messages[messages.length - 
2]?.content - const completion_tokens = await estimateInputTokenCount(responseMessageContent ?? '') - const prompt_tokens = await estimateInputTokenCount(assistant.prompt + inputMessageContent ?? '') - return { - completion_tokens, - prompt_tokens: prompt_tokens, - total_tokens: prompt_tokens + completion_tokens - } as CompletionUsage -} - -export function estimateHistoryTokenCount(assistant: Assistant, msgs: Message[]) { - const { contextCount } = getAssistantSettings(assistant) - - const all = new GPTTokens({ - model: 'gpt-4o', - messages: [ - { role: 'system', content: assistant.prompt }, - ...filterMessages(filterContextMessages(takeRight(msgs, contextCount))).map((message) => ({ - role: message.role, - content: message.content - })) - ] - }) - - return all.usedTokens - 7 -} - export function deleteMessageFiles(message: Message) { message.files && FileManager.deleteFiles(message.files.map((f) => f.id)) } diff --git a/src/renderer/src/services/tokens.ts b/src/renderer/src/services/tokens.ts new file mode 100644 index 00000000..8d5ace9a --- /dev/null +++ b/src/renderer/src/services/tokens.ts @@ -0,0 +1,130 @@ +import { Assistant, FileType, FileTypes, Message } from '@renderer/types' +import { GPTTokens } from 'gpt-tokens' +import { flatten, takeRight } from 'lodash' +import { CompletionUsage } from 'openai/resources' + +import { getAssistantSettings } from './assistant' +import { filterContextMessages, filterMessages } from './messages' + +interface MessageItem { + name?: string + role: 'system' | 'user' | 'assistant' + content: string +} + +async function getFileContent(file: FileType) { + if (!file) { + return '' + } + + const fileId = file.id + file.ext + + if (file.type === FileTypes.IMAGE) { + const data = await window.api.file.base64Image(fileId) + return data.data + } + + if (file.type === FileTypes.TEXT) { + return await window.api.file.read(fileId) + } + + return '' +} + +async function getMessageParam(message: Message): Promise { + const param: 
MessageItem[] = [] + + param.push({ + role: message.role, + content: message.content + }) + + if (message.files) { + for (const file of message.files) { + param.push({ + role: 'assistant', + content: await getFileContent(file) + }) + } + } + + return param +} + +export function estimateTextTokens(text: string) { + const { usedTokens } = new GPTTokens({ + model: 'gpt-4o', + messages: [{ role: 'user', content: text }] + }) + + return usedTokens - 7 +} + +export async function estimateMessageUsage(message: Message): Promise { + const { usedTokens, promptUsedTokens, completionUsedTokens } = new GPTTokens({ + model: 'gpt-4o', + messages: await getMessageParam(message) + }) + + const hasImage = message.files?.some((f) => f.type === FileTypes.IMAGE) + + return { + prompt_tokens: promptUsedTokens, + completion_tokens: completionUsedTokens, + total_tokens: hasImage ? Math.floor(usedTokens / 80) : usedTokens - 7 + } +} + +export async function estimateMessagesUsage({ + assistant, + messages +}: { + assistant: Assistant + messages: Message[] +}): Promise { + const outputMessage = messages.pop()! + + const prompt_tokens = await estimateHistoryTokens(assistant, messages) + const { completion_tokens } = await estimateMessageUsage(outputMessage) + + return { + prompt_tokens: await estimateHistoryTokens(assistant, messages), + completion_tokens, + total_tokens: prompt_tokens + completion_tokens + } as CompletionUsage +} + +export async function estimateHistoryTokens(assistant: Assistant, msgs: Message[]) { + const { contextCount } = getAssistantSettings(assistant) + const messages = filterMessages(filterContextMessages(takeRight(msgs, contextCount))) + + // 有 usage 数据的消息,快速计算总数 + const uasageTokens = messages + .filter((m) => m.usage) + .reduce((acc, message) => { + const inputTokens = message.usage?.total_tokens ?? 0 + const outputTokens = message.usage!.completion_tokens ?? 0 + return acc + (message.role === 'user' ? 
inputTokens : outputTokens) + }, 0) + + // 没有 usage 数据的消息,需要计算每条消息的 token + let allMessages: MessageItem[][] = [] + + for (const message of messages.filter((m) => !m.usage)) { + const items = await getMessageParam(message) + allMessages = allMessages.concat(items) + } + + const { usedTokens } = new GPTTokens({ + model: 'gpt-4o', + messages: [ + { + role: 'system', + content: assistant.prompt + }, + ...flatten(allMessages) + ] + }) + + return usedTokens - 7 + uasageTokens +} diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index 671a4ee9..d8bafc49 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -97,12 +97,14 @@ export interface FileType { type: FileTypes created_at: Date count: number + tokens?: number } export enum FileTypes { IMAGE = 'image', VIDEO = 'video', AUDIO = 'audio', + TEXT = 'text', DOCUMENT = 'document', OTHER = 'other' } diff --git a/src/renderer/src/utils/index.ts b/src/renderer/src/utils/index.ts index 31bf5a4f..9dab5102 100644 --- a/src/renderer/src/utils/index.ts +++ b/src/renderer/src/utils/index.ts @@ -229,3 +229,9 @@ export function removeTrailingDoubleSpaces(markdown: string): string { // 使用正则表达式匹配末尾的两个空格,并替换为空字符串 return markdown.replace(/ {2}$/gm, '') } + +export function getFileDirectory(filePath: string) { + const parts = filePath.split('/') + const directory = parts.slice(0, -1).join('/') + return directory +} diff --git a/yarn.lock b/yarn.lock index ae61e5ec..228950db 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1771,6 +1771,7 @@ __metadata: dotenv-cli: "npm:^7.4.2" electron: "npm:^28.3.3" electron-builder: "npm:^24.9.1" + electron-devtools-installer: "npm:^3.2.0" electron-log: "npm:^5.1.5" electron-store: "npm:^8.2.0" electron-updater: "npm:^6.1.7" @@ -1783,7 +1784,7 @@ __metadata: eslint-plugin-react-hooks: "npm:^4.6.2" eslint-plugin-simple-import-sort: "npm:^12.1.1" eslint-plugin-unused-imports: "npm:^4.0.0" - gpt-tokens: "npm:^1.3.6" + gpt-tokens: "npm:^1.3.10" i18next: 
"npm:^23.11.5" localforage: "npm:^1.10.0" lodash: "npm:^4.17.21" @@ -2837,6 +2838,13 @@ __metadata: languageName: node linkType: hard +"core-util-is@npm:~1.0.0": + version: 1.0.3 + resolution: "core-util-is@npm:1.0.3" + checksum: 10c0/90a0e40abbddfd7618f8ccd63a74d88deea94e77d0e8dbbea059fa7ebebb8fbb4e2909667fe26f3a467073de1a542ebe6ae4c73a73745ac5833786759cd906c9 + languageName: node + linkType: hard + "crc@npm:^3.8.0": version: 3.8.0 resolution: "crc@npm:3.8.0" @@ -3225,6 +3233,18 @@ __metadata: languageName: node linkType: hard +"electron-devtools-installer@npm:^3.2.0": + version: 3.2.0 + resolution: "electron-devtools-installer@npm:3.2.0" + dependencies: + rimraf: "npm:^3.0.2" + semver: "npm:^7.2.1" + tslib: "npm:^2.1.0" + unzip-crx-3: "npm:^0.2.0" + checksum: 10c0/50d56e174e3bbe568d3d4a56a56e8c87faf44aa54a49ecc93ab672905f30ca1bf4e6a1b5a0b297c6ffeec1e89848086a6ff47f0db8197edb16d1bda16d6440c2 + languageName: node + linkType: hard + "electron-log@npm:^5.1.5": version: 5.2.0 resolution: "electron-log@npm:5.2.0" @@ -4373,14 +4393,14 @@ __metadata: languageName: node linkType: hard -"gpt-tokens@npm:^1.3.6": - version: 1.3.9 - resolution: "gpt-tokens@npm:1.3.9" +"gpt-tokens@npm:^1.3.10": + version: 1.3.10 + resolution: "gpt-tokens@npm:1.3.10" dependencies: decimal.js: "npm:^10.4.3" js-tiktoken: "npm:^1.0.14" openai-chat-tokens: "npm:^0.2.8" - checksum: 10c0/14ea94c0df4b83fdbfc8ee9c337aca5584441f8b4440a619eea9defc65dc3782fdb81c138d7153e39bae34cff60ce778e1d38f62e775d7cc378c2eac78d3299c + checksum: 10c0/9aa83bf1aecc3a11b8557769fa20d0f9daae61e3e79e1e308e2435069cee9c49f06563f68d9ce90b53c4981e84c92bf9bb43168f042e4606b51ccfce9210f95c languageName: node linkType: hard @@ -4813,7 +4833,7 @@ __metadata: languageName: node linkType: hard -"inherits@npm:2": +"inherits@npm:2, inherits@npm:~2.0.3": version: 2.0.4 resolution: "inherits@npm:2.0.4" checksum: 
10c0/4e531f648b29039fb7426fb94075e6545faa1eb9fe83c29f0b6d9e7263aceb4289d2d4557db0d428188eeb449cc7c5e77b0a0b2c4e248ff2a65933a0dee49ef2 @@ -5187,6 +5207,13 @@ __metadata: languageName: node linkType: hard +"isarray@npm:~1.0.0": + version: 1.0.0 + resolution: "isarray@npm:1.0.0" + checksum: 10c0/18b5be6669be53425f0b84098732670ed4e727e3af33bc7f948aac01782110eb9a18b3b329c5323bcdd3acdaae547ee077d3951317e7f133bff7105264b3003d + languageName: node + linkType: hard + "isbinaryfile@npm:^4.0.8": version: 4.0.10 resolution: "isbinaryfile@npm:4.0.10" @@ -5395,6 +5422,18 @@ __metadata: languageName: node linkType: hard +"jszip@npm:^3.1.0": + version: 3.10.1 + resolution: "jszip@npm:3.10.1" + dependencies: + lie: "npm:~3.3.0" + pako: "npm:~1.0.2" + readable-stream: "npm:~2.3.6" + setimmediate: "npm:^1.0.5" + checksum: 10c0/58e01ec9c4960383fb8b38dd5f67b83ccc1ec215bf74c8a5b32f42b6e5fb79fada5176842a11409c4051b5b94275044851814a31076bf49e1be218d3ef57c863 + languageName: node + linkType: hard + "katex@npm:^0.16.0": version: 0.16.11 resolution: "katex@npm:0.16.11" @@ -5441,6 +5480,15 @@ __metadata: languageName: node linkType: hard +"lie@npm:~3.3.0": + version: 3.3.0 + resolution: "lie@npm:3.3.0" + dependencies: + immediate: "npm:~3.0.5" + checksum: 10c0/56dd113091978f82f9dc5081769c6f3b947852ecf9feccaf83e14a123bc630c2301439ce6182521e5fbafbde88e88ac38314327a4e0493a1bea7e0699a7af808 + languageName: node + linkType: hard + "localforage@npm:^1.10.0": version: 1.10.0 resolution: "localforage@npm:1.10.0" @@ -6713,6 +6761,13 @@ __metadata: languageName: node linkType: hard +"pako@npm:~1.0.2": + version: 1.0.11 + resolution: "pako@npm:1.0.11" + checksum: 10c0/86dd99d8b34c3930345b8bbeb5e1cd8a05f608eeb40967b293f72fe469d0e9c88b783a8777e4cc7dc7c91ce54c5e93d88ff4b4f060e6ff18408fd21030d9ffbe + languageName: node + linkType: hard + "parent-module@npm:^1.0.0": version: 1.0.1 resolution: "parent-module@npm:1.0.1" @@ -6936,6 +6991,13 @@ __metadata: languageName: node linkType: hard 
+"process-nextick-args@npm:~2.0.0": + version: 2.0.1 + resolution: "process-nextick-args@npm:2.0.1" + checksum: 10c0/bec089239487833d46b59d80327a1605e1c5287eaad770a291add7f45fda1bb5e28b38e0e061add0a1d0ee0984788ce74fa394d345eed1c420cacf392c554367 + languageName: node + linkType: hard + "progress@npm:^2.0.3": version: 2.0.3 resolution: "progress@npm:2.0.3" @@ -7763,6 +7825,21 @@ __metadata: languageName: node linkType: hard +"readable-stream@npm:~2.3.6": + version: 2.3.8 + resolution: "readable-stream@npm:2.3.8" + dependencies: + core-util-is: "npm:~1.0.0" + inherits: "npm:~2.0.3" + isarray: "npm:~1.0.0" + process-nextick-args: "npm:~2.0.0" + safe-buffer: "npm:~5.1.1" + string_decoder: "npm:~1.1.1" + util-deprecate: "npm:~1.0.1" + checksum: 10c0/7efdb01f3853bc35ac62ea25493567bf588773213f5f4a79f9c365e1ad13bab845ac0dae7bc946270dc40c3929483228415e92a3fc600cc7e4548992f41ee3fa + languageName: node + linkType: hard + "readdirp@npm:~3.6.0": version: 3.6.0 resolution: "readdirp@npm:3.6.0" @@ -8128,6 +8205,13 @@ __metadata: languageName: node linkType: hard +"safe-buffer@npm:~5.1.0, safe-buffer@npm:~5.1.1": + version: 5.1.2 + resolution: "safe-buffer@npm:5.1.2" + checksum: 10c0/780ba6b5d99cc9a40f7b951d47152297d0e260f0df01472a1b99d4889679a4b94a13d644f7dbc4f022572f09ae9005fa2fbb93bbbd83643316f365a3e9a45b21 + languageName: node + linkType: hard + "safe-regex-test@npm:^1.0.3": version: 1.0.3 resolution: "safe-regex-test@npm:1.0.3" @@ -8209,7 +8293,7 @@ __metadata: languageName: node linkType: hard -"semver@npm:^7.3.2, semver@npm:^7.3.5, semver@npm:^7.3.8, semver@npm:^7.5.3, semver@npm:^7.5.4, semver@npm:^7.6.3": +"semver@npm:^7.2.1, semver@npm:^7.3.2, semver@npm:^7.3.5, semver@npm:^7.3.8, semver@npm:^7.5.3, semver@npm:^7.5.4, semver@npm:^7.6.3": version: 7.6.3 resolution: "semver@npm:7.6.3" bin: @@ -8253,6 +8337,13 @@ __metadata: languageName: node linkType: hard +"setimmediate@npm:^1.0.5": + version: 1.0.5 + resolution: "setimmediate@npm:1.0.5" + checksum: 
10c0/5bae81bfdbfbd0ce992893286d49c9693c82b1bcc00dcaaf3a09c8f428fdeacf4190c013598b81875dfac2b08a572422db7df779a99332d0fce186d15a3e4d49 + languageName: node + linkType: hard + "shallowequal@npm:1.1.0": version: 1.1.0 resolution: "shallowequal@npm:1.1.0" @@ -8504,6 +8595,15 @@ __metadata: languageName: node linkType: hard +"string_decoder@npm:~1.1.1": + version: 1.1.1 + resolution: "string_decoder@npm:1.1.1" + dependencies: + safe-buffer: "npm:~5.1.0" + checksum: 10c0/b4f89f3a92fd101b5653ca3c99550e07bdf9e13b35037e9e2a1c7b47cec4e55e06ff3fc468e314a0b5e80bfbaf65c1ca5a84978764884ae9413bec1fc6ca924e + languageName: node + linkType: hard + "stringify-entities@npm:^4.0.0": version: 4.0.4 resolution: "stringify-entities@npm:4.0.4" @@ -8763,7 +8863,7 @@ __metadata: languageName: node linkType: hard -"tslib@npm:^2.6.2": +"tslib@npm:^2.1.0, tslib@npm:^2.6.2": version: 2.7.0 resolution: "tslib@npm:2.7.0" checksum: 10c0/469e1d5bf1af585742128827000711efa61010b699cb040ab1800bcd3ccdd37f63ec30642c9e07c4439c1db6e46345582614275daca3e0f4abae29b0083f04a6 @@ -9013,6 +9113,17 @@ __metadata: languageName: node linkType: hard +"unzip-crx-3@npm:^0.2.0": + version: 0.2.0 + resolution: "unzip-crx-3@npm:0.2.0" + dependencies: + jszip: "npm:^3.1.0" + mkdirp: "npm:^0.5.1" + yaku: "npm:^0.16.6" + checksum: 10c0/e551cb3d57d0271da41825e9bd9a7f4ef9ec5c3f15edc37bf909928c8327f21a6938ddc922787ee2b1b31f95ac83232dac79fd5a44e2727f9e800df9017a3b91 + languageName: node + linkType: hard + "update-browserslist-db@npm:^1.1.0": version: 1.1.0 resolution: "update-browserslist-db@npm:1.1.0" @@ -9061,6 +9172,13 @@ __metadata: languageName: node linkType: hard +"util-deprecate@npm:~1.0.1": + version: 1.0.2 + resolution: "util-deprecate@npm:1.0.2" + checksum: 10c0/41a5bdd214df2f6c3ecf8622745e4a366c4adced864bc3c833739791aeeeb1838119af7daed4ba36428114b5c67dcda034a79c882e97e43c03e66a4dd7389942 + languageName: node + linkType: hard + "uuid@npm:^10.0.0": version: 10.0.0 resolution: "uuid@npm:10.0.0" @@ -9343,6 +9461,13 @@ 
__metadata: languageName: node linkType: hard +"yaku@npm:^0.16.6": + version: 0.16.7 + resolution: "yaku@npm:0.16.7" + checksum: 10c0/de45a2f6c31ab905174c4e5a3aad93972cb3cf2946d206ab9718ad5fb9adf0781c011d8b1d576daf20ebfa02fa120a8e7552a742a88bb3d288eea5a3a693187c + languageName: node + linkType: hard + "yallist@npm:^3.0.2": version: 3.1.1 resolution: "yallist@npm:3.1.1"