refactor: provider sdk

This commit is contained in:
kangfenmao 2024-09-03 19:00:24 +08:00
parent 8d7b10d21e
commit 4d6cbf5073
25 changed files with 718 additions and 454 deletions

View File

@ -5,13 +5,13 @@ import { PersistGate } from 'redux-persist/integration/react'
import Sidebar from './components/app/Sidebar'
import TopViewContainer from './components/TopView'
import AntdProvider from './context/AntdProvider'
import { ThemeProvider } from './context/ThemeProvider'
import AgentsPage from './pages/agents/AgentsPage'
import AppsPage from './pages/apps/AppsPage'
import HomePage from './pages/home/HomePage'
import SettingsPage from './pages/settings/SettingsPage'
import TranslatePage from './pages/translate/TranslatePage'
import AntdProvider from './providers/AntdProvider'
import { ThemeProvider } from './providers/ThemeProvider'
function App(): JSX.Element {
return (

View File

@ -1,4 +1,4 @@
import { useTheme } from '@renderer/providers/ThemeProvider'
import { useTheme } from '@renderer/context/ThemeProvider'
import { FC, useEffect, useRef } from 'react'
interface Props {

View File

@ -95,7 +95,7 @@ const PopupContainer: React.FC<Props> = ({ app, resolve }) => {
maskClosable={false}
closeIcon={null}
style={{ marginLeft: 'var(--sidebar-width)' }}>
<webview src={app.url} ref={webviewRef} style={WebviewStyle} allowpopups={true} />
<webview src={app.url} ref={webviewRef} style={WebviewStyle} allowpopups={'true' as any} />
</Drawer>
)
}

View File

@ -6,7 +6,7 @@ import { useRuntime, useShowAssistants } from '@renderer/hooks/useStore'
import { Avatar } from 'antd'
import { FC } from 'react'
import { useTranslation } from 'react-i18next'
import { Link, useLocation } from 'react-router-dom'
import { useLocation, useNavigate } from 'react-router-dom'
import styled from 'styled-components'
import UserPopup from '../Popups/UserPopup'
@ -20,6 +20,7 @@ const Sidebar: FC = () => {
const { toggleShowAssistants } = useShowAssistants()
const { generating } = useRuntime()
const { t } = useTranslation()
const navigate = useNavigate()
const isRoute = (path: string): string => (pathname === path ? 'active' : '')
@ -28,15 +29,13 @@ const Sidebar: FC = () => {
const to = (path: string) => {
if (generating) {
window.message.warning({ content: t('message.switch.disabled'), key: 'switch-assistant' })
return '/'
return
}
return path
navigate(path)
}
const onToggleShowAssistants = () => {
if (pathname === '/') {
toggleShowAssistants()
}
pathname === '/' ? toggleShowAssistants() : navigate('/')
}
return (
@ -44,22 +43,22 @@ const Sidebar: FC = () => {
<AvatarImg src={avatar || AppLogo} draggable={false} className="nodrag" onClick={onEditUser} />
<MainMenus>
<Menus>
<StyledLink to={to('/')} onClick={onToggleShowAssistants}>
<StyledLink onClick={onToggleShowAssistants}>
<Icon className={isRoute('/')}>
<i className="iconfont icon-chat"></i>
</Icon>
</StyledLink>
<StyledLink to={to('/agents')}>
<StyledLink onClick={() => to('/agents')}>
<Icon className={isRoute('/agents')}>
<i className="iconfont icon-business-smart-assistant"></i>
</Icon>
</StyledLink>
<StyledLink to={to('/translate')}>
<StyledLink onClick={() => to('/translate')}>
<Icon className={isRoute('/translate')}>
<TranslationOutlined />
</Icon>
</StyledLink>
<StyledLink to={to('/apps')}>
<StyledLink onClick={() => to('/apps')}>
<Icon className={isRoute('/apps')}>
<i className="iconfont icon-appstore"></i>
</Icon>
@ -67,7 +66,7 @@ const Sidebar: FC = () => {
</Menus>
</MainMenus>
<Menus>
<StyledLink to={to(isLocalAi ? '/settings/assistant' : '/settings/provider')}>
<StyledLink onClick={() => to(isLocalAi ? '/settings/assistant' : '/settings/provider')}>
<Icon className={pathname.startsWith('/settings') ? 'active' : ''}>
<i className="iconfont icon-setting"></i>
</Icon>
@ -149,7 +148,7 @@ const Icon = styled.div`
}
`
const StyledLink = styled(Link)`
const StyledLink = styled.div`
text-decoration: none;
-webkit-app-region: none;
&* {

View File

@ -1,5 +1,5 @@
import MinApp from '@renderer/components/MinApp'
import { useTheme } from '@renderer/providers/ThemeProvider'
import { useTheme } from '@renderer/context/ThemeProvider'
import { MinAppType } from '@renderer/types'
import { FC } from 'react'
import styled from 'styled-components'

View File

@ -1,5 +1,4 @@
import { ArrowRightOutlined, CopyOutlined, DeleteOutlined, EditOutlined } from '@ant-design/icons'
import { ArrowLeftOutlined } from '@ant-design/icons'
import DragableList from '@renderer/components/DragableList'
import { HStack } from '@renderer/components/Layout'
import AssistantSettingPopup from '@renderer/components/Popups/AssistantSettingPopup'
@ -104,10 +103,6 @@ const Assistants: FC<Props> = ({
if (showTopics) {
return (
<Container>
<NavigtaionHeader onClick={() => setShowTopics(false)}>
<ArrowLeftOutlined />
{t('common.back')}
</NavigtaionHeader>
<Topics assistant={activeAssistant} activeTopic={activeTopic} setActiveTopic={setActiveTopic} />
</Container>
)
@ -142,6 +137,7 @@ const Container = styled.div`
height: calc(100vh - var(--navbar-height));
overflow-y: auto;
padding: 10px 0;
padding-bottom: 0;
`
const AssistantItem = styled.div`
@ -155,40 +151,24 @@ const AssistantItem = styled.div`
cursor: pointer;
font-family: Ubuntu;
.anticon {
display: none;
opacity: 0;
color: var(--color-text-3);
transition: opacity 0.2s ease-in-out;
}
&:hover {
background-color: var(--color-background-soft);
.count {
display: none;
}
.anticon {
display: block;
opacity: 1;
}
}
&.active {
background-color: var(--color-background-mute);
cursor: pointer;
.name {
font-weight: 500;
}
}
`
const NavigtaionHeader = styled.div`
display: flex;
flex-direction: row;
align-items: center;
justify-content: flex-start;
gap: 10px;
padding: 0 5px;
cursor: pointer;
color: var(--color-text-3);
margin: 10px;
margin-top: 0;
`
const AssistantName = styled.div`
color: var(--color-text);
display: -webkit-box;

View File

@ -29,7 +29,7 @@ const Chat: FC<Props> = (props) => {
setShowSetting={setShowSetting}
/>
</Main>
{showSetting && <Settings assistant={assistant} />}
{showSetting && <Settings assistant={assistant} onClose={() => setShowSetting(false)} />}
</Container>
)
}

View File

@ -1,9 +1,11 @@
import { ArrowLeftOutlined } from '@ant-design/icons'
import { Navbar, NavbarCenter, NavbarLeft, NavbarRight } from '@renderer/components/app/Navbar'
import { isMac, isWindows } from '@renderer/config/constant'
import { useAssistants, useDefaultAssistant } from '@renderer/hooks/useAssistant'
import { useTheme } from '@renderer/context/ThemeProvider'
import { useAssistant, useAssistants, useDefaultAssistant } from '@renderer/hooks/useAssistant'
import { useShowAssistants } from '@renderer/hooks/useStore'
import { useActiveTopic } from '@renderer/hooks/useTopic'
import { useTheme } from '@renderer/providers/ThemeProvider'
import { getDefaultTopic } from '@renderer/services/assistant'
import { Assistant, Topic } from '@renderer/types'
import { uuid } from '@renderer/utils'
import { Switch } from 'antd'
@ -29,6 +31,7 @@ const HomePage: FC = () => {
const { t } = useTranslation()
const { activeTopic, setActiveTopic } = useActiveTopic(activeAssistant)
const { addTopic } = useAssistant(activeAssistant.id)
_activeAssistant = activeAssistant
_showTopics = showTopics
@ -39,10 +42,16 @@ const HomePage: FC = () => {
setActiveAssistant(assistant)
}
const onCreateAssistant = async () => {
const onCreate = async () => {
if (showTopics) {
const topic = getDefaultTopic()
addTopic(topic)
setActiveTopic(topic)
} else {
const assistant = await AddAssistantPopup.show()
assistant && setActiveAssistant(assistant)
}
}
const onSetActiveTopic = (topic: Topic) => {
setActiveTopic(topic)
@ -53,8 +62,13 @@ const HomePage: FC = () => {
<Container>
<Navbar>
{showAssistants && (
<NavbarLeft style={{ justifyContent: 'flex-end', borderRight: 'none', padding: '0 8px' }}>
<NewButton onClick={onCreateAssistant}>
<NavbarLeft
style={{ justifyContent: 'space-between', alignItems: 'center', borderRight: 'none', padding: '0 8px' }}>
<NavigtaionBack onClick={() => setShowTopics(false)} style={{ opacity: showTopics ? 1 : 0 }}>
<ArrowLeftOutlined />
{t('common.back')}
</NavigtaionBack>
<NewButton onClick={onCreate}>
<i className="iconfont icon-a-addchat"></i>
</NewButton>
</NavbarLeft>
@ -103,6 +117,23 @@ const ContentContainer = styled.div`
background-color: var(--color-background);
`
const NavigtaionBack = styled.div`
display: flex;
flex-direction: row;
align-items: center;
justify-content: flex-start;
gap: 10px;
cursor: pointer;
margin-left: ${isMac ? '16px' : 0};
-webkit-app-region: none;
transition: all 0.2s ease-in-out;
color: var(--color-icon);
transition: opacity 0.2s ease-in-out;
&:hover {
color: var(--color-text);
}
`
const AssistantName = styled.span`
margin-left: 5px;
margin-right: 10px;

View File

@ -4,12 +4,12 @@ import { FC } from 'react'
import { useTranslation } from 'react-i18next'
interface Props {
images: string[]
setImages: (images: string[]) => void
files: File[]
setFiles: (files: File[]) => void
ToolbarButton: any
}
const AttachmentButton: FC<Props> = ({ images, setImages, ToolbarButton }) => {
const AttachmentButton: FC<Props> = ({ files, setFiles, ToolbarButton }) => {
const { t } = useTranslation()
return (
@ -19,22 +19,8 @@ const AttachmentButton: FC<Props> = ({ images, setImages, ToolbarButton }) => {
accept="image/*"
itemRender={() => null}
maxCount={1}
onChange={async ({ file }) => {
try {
const _file = file.originFileObj as File
const reader = new FileReader()
reader.onload = (e: ProgressEvent<FileReader>) => {
const result = e.target?.result
if (typeof result === 'string') {
setImages([result])
}
}
reader.readAsDataURL(_file)
} catch (error: any) {
window.message.error(error.message)
}
}}>
<ToolbarButton type="text" className={images.length ? 'active' : ''}>
onChange={async ({ file }) => file?.originFileObj && setFiles([file.originFileObj as File])}>
<ToolbarButton type="text" className={files.length ? 'active' : ''}>
<PaperClipOutlined style={{ rotate: '135deg' }} />
</ToolbarButton>
</Upload>

View File

@ -25,7 +25,6 @@ import { CSSProperties, FC, useCallback, useEffect, useMemo, useRef, useState }
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
import AttachmentButton from './AttachmentButton'
import SendMessageButton from './SendMessageButton'
interface Props {
@ -46,7 +45,7 @@ const Inputbar: FC<Props> = ({ assistant, setActiveTopic, showSetting, setShowSe
const [estimateTokenCount, setEstimateTokenCount] = useState(0)
const generating = useAppSelector((state) => state.runtime.generating)
const textareaRef = useRef<TextAreaRef>(null)
const [images, setImages] = useState<string[]>([])
const [files, setFiles] = useState<File[]>([])
const { t } = useTranslation()
const containerRef = useRef(null)
@ -71,18 +70,19 @@ const Inputbar: FC<Props> = ({ assistant, setActiveTopic, showSetting, setShowSe
status: 'success'
}
if (images.length > 0) {
message.images = images
if (files.length > 0) {
message.files = files
}
EventEmitter.emit(EVENT_NAMES.SEND_MESSAGE, message)
setText('')
setImages([])
setFiles([])
setTimeout(() => setText(''), 500)
setTimeout(() => resizeTextArea(), 0)
setExpend(false)
}, [assistant.id, assistant.topics, generating, images, text])
}, [assistant.id, assistant.topics, generating, files, text])
const inputTokenCount = useMemo(() => estimateInputTokenCount(text), [text])
@ -226,7 +226,7 @@ const Inputbar: FC<Props> = ({ assistant, setActiveTopic, showSetting, setShowSe
<ControlOutlined />
</ToolbarButton>
</Tooltip>
<AttachmentButton images={images} setImages={setImages} ToolbarButton={ToolbarButton} />
{/* <AttachmentButton files={files} setFiles={setFiles} ToolbarButton={ToolbarButton} /> */}
<Tooltip placement="top" title={expended ? t('chat.input.collapse') : t('chat.input.expand')} arrow>
<ToolbarButton type="text" onClick={onToggleExpended}>
{expended ? <FullscreenExitOutlined /> : <FullscreenOutlined />}

View File

@ -1,7 +1,7 @@
import { CheckOutlined } from '@ant-design/icons'
import CopyIcon from '@renderer/components/Icons/CopyIcon'
import { useTheme } from '@renderer/context/ThemeProvider'
import { initMermaid } from '@renderer/init'
import { useTheme } from '@renderer/providers/ThemeProvider'
import { ThemeMode } from '@renderer/store/settings'
import React, { useState } from 'react'
import { useTranslation } from 'react-i18next'

View File

@ -1,4 +1,4 @@
import { CheckOutlined, QuestionCircleOutlined, ReloadOutlined } from '@ant-design/icons'
import { CheckOutlined, CloseOutlined, QuestionCircleOutlined, ReloadOutlined } from '@ant-design/icons'
import { HStack } from '@renderer/components/Layout'
import { DEFAULT_CONEXTCOUNT, DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE } from '@renderer/config/constant'
import { useAssistant } from '@renderer/hooks/useAssistant'
@ -19,6 +19,7 @@ import styled from 'styled-components'
interface Props {
assistant: Assistant
onClose: () => void
}
const SettingsTab: FC<Props> = (props) => {
@ -87,6 +88,10 @@ const SettingsTab: FC<Props> = (props) => {
return (
<Container>
<SettingsHeader>
{t('settings.title')}
<CloseIcon onClick={props.onClose} />
</SettingsHeader>
<SettingSubtitle>
{t('settings.messages.model.title')}{' '}
<Tooltip title={t('chat.settings.reset')}>
@ -259,4 +264,21 @@ const SettingRowTitleSmall = styled(SettingRowTitle)`
font-size: 13px;
`
const SettingsHeader = styled.div`
display: flex;
flex-direction: row;
align-items: center;
justify-content: space-between;
padding: 10px 15px;
border-bottom: 0.5px solid var(--color-border);
margin-left: -15px;
margin-right: -15px;
`
const CloseIcon = styled(CloseOutlined)`
font-size: 14px;
cursor: pointer;
color: var(--color-text-3);
`
export default SettingsTab

View File

@ -8,8 +8,8 @@ import {
} from '@ant-design/icons'
import { getModelLogo } from '@renderer/config/provider'
import { PROVIDER_CONFIG } from '@renderer/config/provider'
import { useTheme } from '@renderer/context/ThemeProvider'
import { useProvider } from '@renderer/hooks/useProvider'
import { useTheme } from '@renderer/providers/ThemeProvider'
import { checkApi } from '@renderer/services/api'
import { Provider } from '@renderer/types'
import { Avatar, Button, Card, Divider, Flex, Input, Space, Switch } from 'antd'

View File

@ -0,0 +1,40 @@
import BaseProvider from '@renderer/providers/BaseProvider'
import ProviderFactory from '@renderer/providers/ProviderFactory'
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
import OpenAI from 'openai'
export default class AiProvider {
private sdk: BaseProvider
constructor(provider: Provider) {
this.sdk = ProviderFactory.create(provider)
}
public async completions(
messages: Message[],
assistant: Assistant,
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
): Promise<void> {
return this.sdk.completions(messages, assistant, onChunk)
}
public async translate(message: Message, assistant: Assistant): Promise<string> {
return this.sdk.translate(message, assistant)
}
public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
return this.sdk.summaries(messages, assistant)
}
public async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
return this.sdk.suggestions(messages, assistant)
}
public async check(): Promise<{ valid: boolean; error: Error | null }> {
return this.sdk.check()
}
public async models(): Promise<OpenAI.Models.Model[]> {
return this.sdk.models()
}
}

View File

@ -0,0 +1,143 @@
import Anthropic from '@anthropic-ai/sdk'
import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant'
import { EVENT_NAMES } from '@renderer/services/event'
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
import { sum, takeRight } from 'lodash'
import OpenAI from 'openai'
import BaseProvider from './BaseProvider'
export default class AnthropicProvider extends BaseProvider {
private sdk: Anthropic
constructor(provider: Provider) {
super(provider)
this.sdk = new Anthropic({ apiKey: provider.apiKey, baseURL: this.getBaseURL() })
}
public async completions(
messages: Message[],
assistant: Assistant,
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
) {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, maxTokens } = getAssistantSettings(assistant)
const userMessages = takeRight(messages, contextCount + 1).map((message) => {
return {
role: message.role,
content: message.content
}
})
return new Promise<void>((resolve, reject) => {
const stream = this.sdk.messages
.stream({
model: model.id,
messages: userMessages.filter(Boolean) as MessageParam[],
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
temperature: assistant?.settings?.temperature,
system: assistant.prompt,
stream: true
})
.on('text', (text) => {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
resolve()
return stream.controller.abort()
}
onChunk({ text })
})
.on('finalMessage', (message) => {
onChunk({
text: '',
usage: {
prompt_tokens: message.usage.input_tokens,
completion_tokens: message.usage.output_tokens,
total_tokens: sum(Object.values(message.usage))
}
})
resolve()
})
.on('error', (error) => reject(error))
})
}
public async translate(message: Message, assistant: Assistant) {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const messages = [
{ role: 'system', content: assistant.prompt },
{ role: 'user', content: message.content }
]
const response = await this.sdk.messages.create({
model: model.id,
messages: messages.filter((m) => m.role === 'user') as MessageParam[],
max_tokens: 4096,
temperature: assistant?.settings?.temperature,
system: assistant.prompt,
stream: false
})
return response.content[0].type === 'text' ? response.content[0].text : ''
}
public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
const model = getTopNamingModel() || assistant.model || getDefaultModel()
const userMessages = takeRight(messages, 5).map((message) => ({
role: message.role,
content: message.content
}))
const systemMessage = {
role: 'system',
content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要使用标点符号和其他特殊符号。'
}
const message = await this.sdk.messages.create({
messages: userMessages as Anthropic.Messages.MessageParam[],
model: model.id,
system: systemMessage.content,
stream: false,
max_tokens: 4096
})
return message.content[0].type === 'text' ? message.content[0].text : null
}
public async suggestions(): Promise<Suggestion[]> {
return []
}
public async check(): Promise<{ valid: boolean; error: Error | null }> {
const model = this.provider.models[0]
const body = {
model: model.id,
messages: [{ role: 'user', content: 'hi' }],
max_tokens: 100,
stream: false
}
try {
const message = await this.sdk.messages.create(body as MessageCreateParamsNonStreaming)
return {
valid: message.content.length > 0,
error: null
}
} catch (error: any) {
return {
valid: false,
error
}
}
}
public async models(): Promise<OpenAI.Models.Model[]> {
return []
}
}

View File

@ -0,0 +1,33 @@
import { getOllamaKeepAliveTime } from '@renderer/hooks/useOllama'
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
import OpenAI from 'openai'
export default abstract class BaseProvider {
protected provider: Provider
protected host: string
constructor(provider: Provider) {
this.provider = provider
this.host = this.getBaseURL()
}
public getBaseURL(): string {
const host = this.provider.apiHost
return host.endsWith('/') ? host : `${host}/v1/`
}
public get keepAliveTime() {
return this.provider.id === 'ollama' ? getOllamaKeepAliveTime() : undefined
}
abstract completions(
messages: Message[],
assistant: Assistant,
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
): Promise<void>
abstract translate(message: Message, assistant: Assistant): Promise<string>
abstract summaries(messages: Message[], assistant: Assistant): Promise<string | null>
abstract suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]>
abstract check(): Promise<{ valid: boolean; error: Error | null }>
abstract models(): Promise<OpenAI.Models.Model[]>
}

View File

@ -0,0 +1,170 @@
import { GoogleGenerativeAI } from '@google/generative-ai'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant'
import { EVENT_NAMES } from '@renderer/services/event'
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
import axios from 'axios'
import { isEmpty, takeRight } from 'lodash'
import OpenAI from 'openai'
import BaseProvider from './BaseProvider'
export default class GeminiProvider extends BaseProvider {
private sdk: GoogleGenerativeAI
constructor(provider: Provider) {
super(provider)
this.sdk = new GoogleGenerativeAI(provider.apiKey)
}
public async completions(
messages: Message[],
assistant: Assistant,
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
) {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, maxTokens } = getAssistantSettings(assistant)
const userMessages = takeRight(messages, contextCount + 1).map((message) => {
return {
role: message.role,
content: message.content
}
})
const geminiModel = this.sdk.getGenerativeModel({
model: model.id,
systemInstruction: assistant.prompt,
generationConfig: {
maxOutputTokens: maxTokens,
temperature: assistant?.settings?.temperature
}
})
const userLastMessage = userMessages.pop()
const chat = geminiModel.startChat({
history: userMessages.map((message) => ({
role: message.role === 'user' ? 'user' : 'model',
parts: [{ text: message.content }]
}))
})
const userMessagesStream = await chat.sendMessageStream(userLastMessage?.content!)
for await (const chunk of userMessagesStream.stream) {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
onChunk({
text: chunk.text(),
usage: {
prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
total_tokens: chunk.usageMetadata?.totalTokenCount || 0
}
})
}
}
async translate(message: Message, assistant: Assistant) {
const defaultModel = getDefaultModel()
const { maxTokens } = getAssistantSettings(assistant)
const model = assistant.model || defaultModel
const geminiModel = this.sdk.getGenerativeModel({
model: model.id,
systemInstruction: assistant.prompt,
generationConfig: {
maxOutputTokens: maxTokens,
temperature: assistant?.settings?.temperature
}
})
const { response } = await geminiModel.generateContent(message.content)
return response.text()
}
public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
const model = getTopNamingModel() || assistant.model || getDefaultModel()
const userMessages = takeRight(messages, 5).map((message) => ({
role: message.role,
content: message.content
}))
const systemMessage = {
role: 'system',
content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要使用标点符号和其他特殊符号。'
}
const geminiModel = this.sdk.getGenerativeModel({
model: model.id,
systemInstruction: systemMessage.content,
generationConfig: {
temperature: assistant?.settings?.temperature
}
})
const lastUserMessage = userMessages.pop()
const chat = await geminiModel.startChat({
history: userMessages.map((message) => ({
role: message.role === 'user' ? 'user' : 'model',
parts: [{ text: message.content }]
}))
})
const { response } = await chat.sendMessage(lastUserMessage?.content!)
return response.text()
}
public async suggestions(): Promise<Suggestion[]> {
return []
}
public async check(): Promise<{ valid: boolean; error: Error | null }> {
const model = this.provider.models[0]
const body = {
model: model.id,
messages: [{ role: 'user', content: 'hi' }],
max_tokens: 100,
stream: false
}
try {
const geminiModel = this.sdk.getGenerativeModel({ model: body.model })
const result = await geminiModel.generateContent(body.messages[0].content)
return {
valid: !isEmpty(result.response.text()),
error: null
}
} catch (error: any) {
return {
valid: false,
error
}
}
}
public async models(): Promise<OpenAI.Models.Model[]> {
try {
const api = this.provider.apiHost + '/v1beta/models'
const { data } = await axios.get(api, { params: { key: this.provider.apiKey } })
return data.models.map(
(m: any) =>
({
id: m.name.replace('models/', ''),
name: m.displayName,
description: m.description,
object: 'model',
created: Date.now(),
owned_by: 'gemini'
}) as OpenAI.Models.Model
)
} catch (error) {
return []
}
}
}

View File

@ -0,0 +1,185 @@
import { isLocalAi } from '@renderer/config/env'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant'
import { EVENT_NAMES } from '@renderer/services/event'
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
import { fileToBase64, removeQuotes } from '@renderer/utils'
import { first, takeRight } from 'lodash'
import OpenAI from 'openai'
import {
ChatCompletionContentPart,
ChatCompletionCreateParamsNonStreaming,
ChatCompletionMessageParam
} from 'openai/resources'
import BaseProvider from './BaseProvider'
export default class OpenAIProvider extends BaseProvider {
private sdk: OpenAI
constructor(provider: Provider) {
super(provider)
this.sdk = new OpenAI({
dangerouslyAllowBrowser: true,
apiKey: provider.apiKey,
baseURL: this.getBaseURL()
})
}
private async getMessageContent(message: Message): Promise<string | ChatCompletionContentPart[]> {
const file = first(message.files)
if (!file) {
return message.content
}
if (file.type.includes('image')) {
return [
{ type: 'text', text: message.content },
{
type: 'image_url',
image_url: {
url: await fileToBase64(file)
}
}
]
}
return message.content
}
async completions(
messages: Message[],
assistant: Assistant,
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
): Promise<void> {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, maxTokens } = getAssistantSettings(assistant)
const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
const userMessages: ChatCompletionMessageParam[] = []
for (const message of takeRight(messages, contextCount + 1)) {
userMessages.push({
role: message.role,
content: await this.getMessageContent(message)
} as ChatCompletionMessageParam)
}
// @ts-ignore key is not typed
const stream = await this.sdk.chat.completions.create({
model: model.id,
messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[],
stream: true,
temperature: assistant?.settings?.temperature,
max_tokens: maxTokens,
keep_alive: this.keepAliveTime
})
for await (const chunk of stream) {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage })
}
}
async translate(message: Message, assistant: Assistant) {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const messages = [
{ role: 'system', content: assistant.prompt },
{ role: 'user', content: message.content }
]
// @ts-ignore key is not typed
const response = await this.sdk.chat.completions.create({
model: model.id,
messages: messages as ChatCompletionMessageParam[],
stream: false,
keep_alive: this.keepAliveTime
})
return response.choices[0].message?.content || ''
}
public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
const model = getTopNamingModel() || assistant.model || getDefaultModel()
const userMessages = takeRight(messages, 5).map((message) => ({
role: message.role,
content: message.content
}))
const systemMessage = {
role: 'system',
content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要使用标点符号和其他特殊符号。'
}
// @ts-ignore key is not typed
const response = await this.sdk.chat.completions.create({
model: model.id,
messages: [systemMessage, ...(isLocalAi ? [first(userMessages)] : userMessages)] as ChatCompletionMessageParam[],
stream: false,
max_tokens: 50,
keep_alive: this.keepAliveTime
})
return removeQuotes(response.choices[0].message?.content?.substring(0, 50) || '')
}
async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
const model = assistant.model
if (!model) {
return []
}
const response: any = await this.sdk.request({
method: 'post',
path: '/advice_questions',
body: {
messages: messages.filter((m) => m.role === 'user').map((m) => ({ role: m.role, content: m.content })),
model: model.id,
max_tokens: 0,
temperature: 0,
n: 0
}
})
return response?.questions?.filter(Boolean)?.map((q: any) => ({ content: q })) || []
}
public async check(): Promise<{ valid: boolean; error: Error | null }> {
const model = this.provider.models[0]
const body = {
model: model.id,
messages: [{ role: 'user', content: 'hi' }],
max_tokens: 100,
stream: false
}
try {
const response = await this.sdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming)
return {
valid: Boolean(response?.choices[0].message),
error: null
}
} catch (error: any) {
return {
valid: false,
error
}
}
}
public async models(): Promise<OpenAI.Models.Model[]> {
try {
const response = await this.sdk.models.list()
return response.data
} catch (error) {
return []
}
}
}

View File

@ -0,0 +1,19 @@
import { Provider } from '@renderer/types'
import AnthropicProvider from './AnthropicProvider'
import BaseProvider from './BaseProvider'
import GeminiProvider from './GeminiProvider'
import OpenAIProvider from './OpenAIProvider'
export default class ProviderFactory {
static create(provider: Provider): BaseProvider {
switch (provider.id) {
case 'anthropic':
return new AnthropicProvider(provider)
case 'gemini':
return new GeminiProvider(provider)
default:
return new OpenAIProvider(provider)
}
}
}

View File

@ -1,358 +0,0 @@
import Anthropic from '@anthropic-ai/sdk'
import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources'
import { GoogleGenerativeAI } from '@google/generative-ai'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { isLocalAi } from '@renderer/config/env'
import { getOllamaKeepAliveTime } from '@renderer/hooks/useOllama'
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
import { removeQuotes } from '@renderer/utils'
import axios from 'axios'
import { first, isEmpty, sum, takeRight } from 'lodash'
import OpenAI from 'openai'
import { ChatCompletionCreateParamsNonStreaming, ChatCompletionMessageParam } from 'openai/resources'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from './assistant'
import { EVENT_NAMES } from './event'
export default class ProviderSDK {
provider: Provider
openaiSdk: OpenAI
anthropicSdk: Anthropic
geminiSdk: GoogleGenerativeAI
constructor(provider: Provider) {
this.provider = provider
const host = provider.apiHost
const baseURL = host.endsWith('/') ? host : `${provider.apiHost}/v1/`
this.anthropicSdk = new Anthropic({ apiKey: provider.apiKey, baseURL })
this.openaiSdk = new OpenAI({ dangerouslyAllowBrowser: true, apiKey: provider.apiKey, baseURL })
this.geminiSdk = new GoogleGenerativeAI(provider.apiKey)
}
private get isAnthropic() {
return this.provider.id === 'anthropic'
}
private get isGemini() {
return this.provider.id === 'gemini'
}
private get keepAliveTime() {
return this.provider.id === 'ollama' ? getOllamaKeepAliveTime() : undefined
}
public async completions(
messages: Message[],
assistant: Assistant,
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
) {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, maxTokens } = getAssistantSettings(assistant)
const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
const userMessages = takeRight(messages, contextCount + 1).map((message) => {
return {
role: message.role,
content: message.content
}
})
if (this.isAnthropic) {
return new Promise<void>((resolve, reject) => {
const stream = this.anthropicSdk.messages
.stream({
model: model.id,
messages: userMessages.filter(Boolean) as MessageParam[],
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
temperature: assistant?.settings?.temperature,
system: assistant.prompt,
stream: true
})
.on('text', (text) => {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
resolve()
return stream.controller.abort()
}
onChunk({ text })
})
.on('finalMessage', (message) => {
onChunk({
text: '',
usage: {
prompt_tokens: message.usage.input_tokens,
completion_tokens: message.usage.output_tokens,
total_tokens: sum(Object.values(message.usage))
}
})
resolve()
})
.on('error', (error) => reject(error))
})
}
if (this.isGemini) {
const geminiModel = this.geminiSdk.getGenerativeModel({
model: model.id,
systemInstruction: assistant.prompt,
generationConfig: {
maxOutputTokens: maxTokens,
temperature: assistant?.settings?.temperature
}
})
const userLastMessage = userMessages.pop()
const chat = geminiModel.startChat({
history: userMessages.map((message) => ({
role: message.role === 'user' ? 'user' : 'model',
parts: [{ text: message.content }]
}))
})
const userMessagesStream = await chat.sendMessageStream(userLastMessage?.content!)
for await (const chunk of userMessagesStream.stream) {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
onChunk({
text: chunk.text(),
usage: {
prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
total_tokens: chunk.usageMetadata?.totalTokenCount || 0
}
})
}
return
}
const _userMessages = takeRight(messages, contextCount + 1).map((message) => {
return {
role: message.role,
content: message.images
? [
{ type: 'text', text: message.content },
...message.images!.map((image) => ({ type: 'image_url', image_url: image }))
]
: message.content
}
})
// @ts-ignore key is not typed
const stream = await this.openaiSdk.chat.completions.create({
model: model.id,
messages: [systemMessage, ..._userMessages].filter(Boolean) as ChatCompletionMessageParam[],
stream: true,
temperature: assistant?.settings?.temperature,
max_tokens: maxTokens,
keep_alive: this.keepAliveTime
})
for await (const chunk of stream) {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage })
}
}
public async translate(message: Message, assistant: Assistant) {
const defaultModel = getDefaultModel()
const { maxTokens } = getAssistantSettings(assistant)
const model = assistant.model || defaultModel
const messages = [
{ role: 'system', content: assistant.prompt },
{ role: 'user', content: message.content }
]
if (this.isAnthropic) {
const response = await this.anthropicSdk.messages.create({
model: model.id,
messages: messages.filter((m) => m.role === 'user') as MessageParam[],
max_tokens: 4096,
temperature: assistant?.settings?.temperature,
system: assistant.prompt,
stream: false
})
return response.content[0].type === 'text' ? response.content[0].text : ''
}
if (this.isGemini) {
const geminiModel = this.geminiSdk.getGenerativeModel({
model: model.id,
systemInstruction: assistant.prompt,
generationConfig: {
maxOutputTokens: maxTokens,
temperature: assistant?.settings?.temperature
}
})
const { response } = await geminiModel.generateContent(message.content)
return response.text()
}
// @ts-ignore key is not typed
const response = await this.openaiSdk.chat.completions.create({
model: model.id,
messages: messages as ChatCompletionMessageParam[],
stream: false,
keep_alive: this.keepAliveTime
})
return response.choices[0].message?.content || ''
}
public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
const model = getTopNamingModel() || assistant.model || getDefaultModel()
const userMessages = takeRight(messages, 5).map((message) => ({
role: message.role,
content: message.content
}))
const systemMessage = {
role: 'system',
content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要使用标点符号和其他特殊符号。'
}
if (this.isAnthropic) {
const message = await this.anthropicSdk.messages.create({
messages: userMessages as Anthropic.Messages.MessageParam[],
model: model.id,
system: systemMessage.content,
stream: false,
max_tokens: 4096
})
return message.content[0].type === 'text' ? message.content[0].text : null
}
if (this.isGemini) {
const geminiModel = this.geminiSdk.getGenerativeModel({
model: model.id,
systemInstruction: systemMessage.content,
generationConfig: {
temperature: assistant?.settings?.temperature
}
})
const lastUserMessage = userMessages.pop()
const chat = await geminiModel.startChat({
history: userMessages.map((message) => ({
role: message.role === 'user' ? 'user' : 'model',
parts: [{ text: message.content }]
}))
})
const { response } = await chat.sendMessage(lastUserMessage?.content!)
return response.text()
}
// @ts-ignore key is not typed
const response = await this.openaiSdk.chat.completions.create({
model: model.id,
messages: [systemMessage, ...(isLocalAi ? [first(userMessages)] : userMessages)] as ChatCompletionMessageParam[],
stream: false,
max_tokens: 50,
keep_alive: this.keepAliveTime
})
return removeQuotes(response.choices[0].message?.content?.substring(0, 50) || '')
}
public async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
const model = assistant.model
if (!model) {
return []
}
const response: any = await this.openaiSdk.request({
method: 'post',
path: '/advice_questions',
body: {
messages: messages.filter((m) => m.role === 'user').map((m) => ({ role: m.role, content: m.content })),
model: model.id,
max_tokens: 0,
temperature: 0,
n: 0
}
})
return response?.questions?.filter(Boolean)?.map((q: any) => ({ content: q })) || []
}
public async check(): Promise<{ valid: boolean; error: Error | null }> {
const model = this.provider.models[0]
const body = {
model: model.id,
messages: [{ role: 'user', content: 'hi' }],
max_tokens: 100,
stream: false
}
try {
if (this.isAnthropic) {
const message = await this.anthropicSdk.messages.create(body as MessageCreateParamsNonStreaming)
return {
valid: message.content.length > 0,
error: null
}
}
if (this.isGemini) {
const geminiModel = this.geminiSdk.getGenerativeModel({ model: body.model })
const result = await geminiModel.generateContent(body.messages[0].content)
return {
valid: !isEmpty(result.response.text()),
error: null
}
}
const response = await this.openaiSdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming)
return {
valid: Boolean(response?.choices[0].message),
error: null
}
} catch (error: any) {
return {
valid: false,
error
}
}
}
public async models(): Promise<OpenAI.Models.Model[]> {
try {
if (this.isAnthropic) {
return []
}
if (this.isGemini) {
const api = this.provider.apiHost + '/v1beta/models'
const { data } = await axios.get(api, { params: { key: this.provider.apiKey } })
return data.models.map(
(m: any) =>
({
id: m.name.replace('models/', ''),
name: m.displayName,
description: m.description,
object: 'model',
created: Date.now(),
owned_by: 'gemini'
}) as OpenAI.Models.Model
)
}
const response = await this.openaiSdk.models.list()
return response.data
} catch (error) {
return []
}
}
}

View File

@ -6,6 +6,7 @@ import { uuid } from '@renderer/utils'
import dayjs from 'dayjs'
import { isEmpty } from 'lodash'
import AiProvider from '../providers/AiProvider'
import {
getAssistantProvider,
getDefaultModel,
@ -15,7 +16,6 @@ import {
} from './assistant'
import { EVENT_NAMES, EventEmitter } from './event'
import { filterMessages } from './messages'
import ProviderSDK from './ProviderSDK'
export async function fetchChatCompletion({
messages,
@ -33,7 +33,7 @@ export async function fetchChatCompletion({
const provider = getAssistantProvider(assistant)
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const providerSdk = new ProviderSDK(provider)
const AI = new AiProvider(provider)
store.dispatch(setGenerating(true))
@ -61,7 +61,7 @@ export async function fetchChatCompletion({
}, 1000)
try {
await providerSdk.completions(filterMessages(messages), assistant, ({ text, usage }) => {
await AI.completions(filterMessages(messages), assistant, ({ text, usage }) => {
message.content = message.content + text || ''
message.usage = usage
onResponse({ ...message, status: 'pending' })
@ -103,10 +103,10 @@ export async function fetchTranslate({ message, assistant }: { message: Message;
return ''
}
const providerSdk = new ProviderSDK(provider)
const AI = new AiProvider(provider)
try {
return await providerSdk.translate(message, assistant)
return await AI.translate(message, assistant)
} catch (error: any) {
return ''
}
@ -120,10 +120,10 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages:
return null
}
const providerSdk = new ProviderSDK(provider)
const AI = new AiProvider(provider)
try {
return await providerSdk.summaries(filterMessages(messages), assistant)
return await AI.summaries(filterMessages(messages), assistant)
} catch (error: any) {
return null
}
@ -136,10 +136,8 @@ export async function fetchSuggestions({
messages: Message[]
assistant: Assistant
}): Promise<Suggestion[]> {
console.debug('fetchSuggestions', messages, assistant)
const provider = getAssistantProvider(assistant)
const providerSdk = new ProviderSDK(provider)
console.debug('fetchSuggestions', provider)
const AI = new AiProvider(provider)
const model = assistant.model
if (!model) {
@ -155,7 +153,7 @@ export async function fetchSuggestions({
}
try {
return await providerSdk.suggestions(messages, assistant)
return await AI.suggestions(messages, assistant)
} catch (error: any) {
return []
}
@ -183,9 +181,9 @@ export async function checkApi(provider: Provider) {
return false
}
const providerSdk = new ProviderSDK(provider)
const AI = new AiProvider(provider)
const { valid } = await providerSdk.check()
const { valid } = await AI.check()
window.message[valid ? 'success' : 'error']({
key: 'api-check',
@ -204,10 +202,10 @@ function hasApiKey(provider: Provider) {
}
export async function fetchModels(provider: Provider) {
const providerSdk = new ProviderSDK(provider)
const AI = new AiProvider(provider)
try {
return await providerSdk.models()
return await AI.models()
} catch (error) {
return []
}

View File

@ -20,14 +20,15 @@ export type AssistantSettings = {
export type Message = {
id: string
assistantId: string
role: 'user' | 'assistant'
content: string
images?: string[]
assistantId: string
topicId: string
modelId?: string
createdAt: string
status: 'sending' | 'pending' | 'success' | 'paused' | 'error'
modelId?: string
files?: File[]
images?: string[]
usage?: OpenAI.Completions.CompletionUsage
type?: 'text' | '@'
}

View File

@ -223,3 +223,18 @@ export function getBriefInfo(text: string, maxLength: number = 50): string {
// 截取前面的内容,并在末尾添加 "..."
return truncatedText + '...'
}
export async function fileToBase64(file: File): Promise<string> {
return new Promise((resolve, reject) => {
try {
const reader = new FileReader()
reader.onload = (e: ProgressEvent<FileReader>) => {
const result = e.target?.result
resolve(typeof result === 'string' ? result : '')
}
reader.readAsDataURL(file)
} catch (error: any) {
reject(error)
}
})
}