refactor: provider sdk
This commit is contained in:
parent
8d7b10d21e
commit
4d6cbf5073
@ -5,13 +5,13 @@ import { PersistGate } from 'redux-persist/integration/react'
|
||||
|
||||
import Sidebar from './components/app/Sidebar'
|
||||
import TopViewContainer from './components/TopView'
|
||||
import AntdProvider from './context/AntdProvider'
|
||||
import { ThemeProvider } from './context/ThemeProvider'
|
||||
import AgentsPage from './pages/agents/AgentsPage'
|
||||
import AppsPage from './pages/apps/AppsPage'
|
||||
import HomePage from './pages/home/HomePage'
|
||||
import SettingsPage from './pages/settings/SettingsPage'
|
||||
import TranslatePage from './pages/translate/TranslatePage'
|
||||
import AntdProvider from './providers/AntdProvider'
|
||||
import { ThemeProvider } from './providers/ThemeProvider'
|
||||
|
||||
function App(): JSX.Element {
|
||||
return (
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { useTheme } from '@renderer/providers/ThemeProvider'
|
||||
import { useTheme } from '@renderer/context/ThemeProvider'
|
||||
import { FC, useEffect, useRef } from 'react'
|
||||
|
||||
interface Props {
|
||||
|
||||
@ -95,7 +95,7 @@ const PopupContainer: React.FC<Props> = ({ app, resolve }) => {
|
||||
maskClosable={false}
|
||||
closeIcon={null}
|
||||
style={{ marginLeft: 'var(--sidebar-width)' }}>
|
||||
<webview src={app.url} ref={webviewRef} style={WebviewStyle} allowpopups={true} />
|
||||
<webview src={app.url} ref={webviewRef} style={WebviewStyle} allowpopups={'true' as any} />
|
||||
</Drawer>
|
||||
)
|
||||
}
|
||||
|
||||
@ -6,7 +6,7 @@ import { useRuntime, useShowAssistants } from '@renderer/hooks/useStore'
|
||||
import { Avatar } from 'antd'
|
||||
import { FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Link, useLocation } from 'react-router-dom'
|
||||
import { useLocation, useNavigate } from 'react-router-dom'
|
||||
import styled from 'styled-components'
|
||||
|
||||
import UserPopup from '../Popups/UserPopup'
|
||||
@ -20,6 +20,7 @@ const Sidebar: FC = () => {
|
||||
const { toggleShowAssistants } = useShowAssistants()
|
||||
const { generating } = useRuntime()
|
||||
const { t } = useTranslation()
|
||||
const navigate = useNavigate()
|
||||
|
||||
const isRoute = (path: string): string => (pathname === path ? 'active' : '')
|
||||
|
||||
@ -28,15 +29,13 @@ const Sidebar: FC = () => {
|
||||
const to = (path: string) => {
|
||||
if (generating) {
|
||||
window.message.warning({ content: t('message.switch.disabled'), key: 'switch-assistant' })
|
||||
return '/'
|
||||
return
|
||||
}
|
||||
return path
|
||||
navigate(path)
|
||||
}
|
||||
|
||||
const onToggleShowAssistants = () => {
|
||||
if (pathname === '/') {
|
||||
toggleShowAssistants()
|
||||
}
|
||||
pathname === '/' ? toggleShowAssistants() : navigate('/')
|
||||
}
|
||||
|
||||
return (
|
||||
@ -44,22 +43,22 @@ const Sidebar: FC = () => {
|
||||
<AvatarImg src={avatar || AppLogo} draggable={false} className="nodrag" onClick={onEditUser} />
|
||||
<MainMenus>
|
||||
<Menus>
|
||||
<StyledLink to={to('/')} onClick={onToggleShowAssistants}>
|
||||
<StyledLink onClick={onToggleShowAssistants}>
|
||||
<Icon className={isRoute('/')}>
|
||||
<i className="iconfont icon-chat"></i>
|
||||
</Icon>
|
||||
</StyledLink>
|
||||
<StyledLink to={to('/agents')}>
|
||||
<StyledLink onClick={() => to('/agents')}>
|
||||
<Icon className={isRoute('/agents')}>
|
||||
<i className="iconfont icon-business-smart-assistant"></i>
|
||||
</Icon>
|
||||
</StyledLink>
|
||||
<StyledLink to={to('/translate')}>
|
||||
<StyledLink onClick={() => to('/translate')}>
|
||||
<Icon className={isRoute('/translate')}>
|
||||
<TranslationOutlined />
|
||||
</Icon>
|
||||
</StyledLink>
|
||||
<StyledLink to={to('/apps')}>
|
||||
<StyledLink onClick={() => to('/apps')}>
|
||||
<Icon className={isRoute('/apps')}>
|
||||
<i className="iconfont icon-appstore"></i>
|
||||
</Icon>
|
||||
@ -67,7 +66,7 @@ const Sidebar: FC = () => {
|
||||
</Menus>
|
||||
</MainMenus>
|
||||
<Menus>
|
||||
<StyledLink to={to(isLocalAi ? '/settings/assistant' : '/settings/provider')}>
|
||||
<StyledLink onClick={() => to(isLocalAi ? '/settings/assistant' : '/settings/provider')}>
|
||||
<Icon className={pathname.startsWith('/settings') ? 'active' : ''}>
|
||||
<i className="iconfont icon-setting"></i>
|
||||
</Icon>
|
||||
@ -149,7 +148,7 @@ const Icon = styled.div`
|
||||
}
|
||||
`
|
||||
|
||||
const StyledLink = styled(Link)`
|
||||
const StyledLink = styled.div`
|
||||
text-decoration: none;
|
||||
-webkit-app-region: none;
|
||||
&* {
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import MinApp from '@renderer/components/MinApp'
|
||||
import { useTheme } from '@renderer/providers/ThemeProvider'
|
||||
import { useTheme } from '@renderer/context/ThemeProvider'
|
||||
import { MinAppType } from '@renderer/types'
|
||||
import { FC } from 'react'
|
||||
import styled from 'styled-components'
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import { ArrowRightOutlined, CopyOutlined, DeleteOutlined, EditOutlined } from '@ant-design/icons'
|
||||
import { ArrowLeftOutlined } from '@ant-design/icons'
|
||||
import DragableList from '@renderer/components/DragableList'
|
||||
import { HStack } from '@renderer/components/Layout'
|
||||
import AssistantSettingPopup from '@renderer/components/Popups/AssistantSettingPopup'
|
||||
@ -104,10 +103,6 @@ const Assistants: FC<Props> = ({
|
||||
if (showTopics) {
|
||||
return (
|
||||
<Container>
|
||||
<NavigtaionHeader onClick={() => setShowTopics(false)}>
|
||||
<ArrowLeftOutlined />
|
||||
{t('common.back')}
|
||||
</NavigtaionHeader>
|
||||
<Topics assistant={activeAssistant} activeTopic={activeTopic} setActiveTopic={setActiveTopic} />
|
||||
</Container>
|
||||
)
|
||||
@ -142,6 +137,7 @@ const Container = styled.div`
|
||||
height: calc(100vh - var(--navbar-height));
|
||||
overflow-y: auto;
|
||||
padding: 10px 0;
|
||||
padding-bottom: 0;
|
||||
`
|
||||
|
||||
const AssistantItem = styled.div`
|
||||
@ -155,40 +151,24 @@ const AssistantItem = styled.div`
|
||||
cursor: pointer;
|
||||
font-family: Ubuntu;
|
||||
.anticon {
|
||||
display: none;
|
||||
opacity: 0;
|
||||
color: var(--color-text-3);
|
||||
transition: opacity 0.2s ease-in-out;
|
||||
}
|
||||
&:hover {
|
||||
background-color: var(--color-background-soft);
|
||||
.count {
|
||||
display: none;
|
||||
}
|
||||
.anticon {
|
||||
display: block;
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
&.active {
|
||||
background-color: var(--color-background-mute);
|
||||
cursor: pointer;
|
||||
.name {
|
||||
font-weight: 500;
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
const NavigtaionHeader = styled.div`
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
align-items: center;
|
||||
justify-content: flex-start;
|
||||
gap: 10px;
|
||||
padding: 0 5px;
|
||||
cursor: pointer;
|
||||
color: var(--color-text-3);
|
||||
margin: 10px;
|
||||
margin-top: 0;
|
||||
`
|
||||
|
||||
const AssistantName = styled.div`
|
||||
color: var(--color-text);
|
||||
display: -webkit-box;
|
||||
|
||||
@ -29,7 +29,7 @@ const Chat: FC<Props> = (props) => {
|
||||
setShowSetting={setShowSetting}
|
||||
/>
|
||||
</Main>
|
||||
{showSetting && <Settings assistant={assistant} />}
|
||||
{showSetting && <Settings assistant={assistant} onClose={() => setShowSetting(false)} />}
|
||||
</Container>
|
||||
)
|
||||
}
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
import { ArrowLeftOutlined } from '@ant-design/icons'
|
||||
import { Navbar, NavbarCenter, NavbarLeft, NavbarRight } from '@renderer/components/app/Navbar'
|
||||
import { isMac, isWindows } from '@renderer/config/constant'
|
||||
import { useAssistants, useDefaultAssistant } from '@renderer/hooks/useAssistant'
|
||||
import { useTheme } from '@renderer/context/ThemeProvider'
|
||||
import { useAssistant, useAssistants, useDefaultAssistant } from '@renderer/hooks/useAssistant'
|
||||
import { useShowAssistants } from '@renderer/hooks/useStore'
|
||||
import { useActiveTopic } from '@renderer/hooks/useTopic'
|
||||
import { useTheme } from '@renderer/providers/ThemeProvider'
|
||||
import { getDefaultTopic } from '@renderer/services/assistant'
|
||||
import { Assistant, Topic } from '@renderer/types'
|
||||
import { uuid } from '@renderer/utils'
|
||||
import { Switch } from 'antd'
|
||||
@ -29,6 +31,7 @@ const HomePage: FC = () => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const { activeTopic, setActiveTopic } = useActiveTopic(activeAssistant)
|
||||
const { addTopic } = useAssistant(activeAssistant.id)
|
||||
|
||||
_activeAssistant = activeAssistant
|
||||
_showTopics = showTopics
|
||||
@ -39,10 +42,16 @@ const HomePage: FC = () => {
|
||||
setActiveAssistant(assistant)
|
||||
}
|
||||
|
||||
const onCreateAssistant = async () => {
|
||||
const onCreate = async () => {
|
||||
if (showTopics) {
|
||||
const topic = getDefaultTopic()
|
||||
addTopic(topic)
|
||||
setActiveTopic(topic)
|
||||
} else {
|
||||
const assistant = await AddAssistantPopup.show()
|
||||
assistant && setActiveAssistant(assistant)
|
||||
}
|
||||
}
|
||||
|
||||
const onSetActiveTopic = (topic: Topic) => {
|
||||
setActiveTopic(topic)
|
||||
@ -53,8 +62,13 @@ const HomePage: FC = () => {
|
||||
<Container>
|
||||
<Navbar>
|
||||
{showAssistants && (
|
||||
<NavbarLeft style={{ justifyContent: 'flex-end', borderRight: 'none', padding: '0 8px' }}>
|
||||
<NewButton onClick={onCreateAssistant}>
|
||||
<NavbarLeft
|
||||
style={{ justifyContent: 'space-between', alignItems: 'center', borderRight: 'none', padding: '0 8px' }}>
|
||||
<NavigtaionBack onClick={() => setShowTopics(false)} style={{ opacity: showTopics ? 1 : 0 }}>
|
||||
<ArrowLeftOutlined />
|
||||
{t('common.back')}
|
||||
</NavigtaionBack>
|
||||
<NewButton onClick={onCreate}>
|
||||
<i className="iconfont icon-a-addchat"></i>
|
||||
</NewButton>
|
||||
</NavbarLeft>
|
||||
@ -103,6 +117,23 @@ const ContentContainer = styled.div`
|
||||
background-color: var(--color-background);
|
||||
`
|
||||
|
||||
const NavigtaionBack = styled.div`
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
align-items: center;
|
||||
justify-content: flex-start;
|
||||
gap: 10px;
|
||||
cursor: pointer;
|
||||
margin-left: ${isMac ? '16px' : 0};
|
||||
-webkit-app-region: none;
|
||||
transition: all 0.2s ease-in-out;
|
||||
color: var(--color-icon);
|
||||
transition: opacity 0.2s ease-in-out;
|
||||
&:hover {
|
||||
color: var(--color-text);
|
||||
}
|
||||
`
|
||||
|
||||
const AssistantName = styled.span`
|
||||
margin-left: 5px;
|
||||
margin-right: 10px;
|
||||
|
||||
@ -4,12 +4,12 @@ import { FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
interface Props {
|
||||
images: string[]
|
||||
setImages: (images: string[]) => void
|
||||
files: File[]
|
||||
setFiles: (files: File[]) => void
|
||||
ToolbarButton: any
|
||||
}
|
||||
|
||||
const AttachmentButton: FC<Props> = ({ images, setImages, ToolbarButton }) => {
|
||||
const AttachmentButton: FC<Props> = ({ files, setFiles, ToolbarButton }) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
@ -19,22 +19,8 @@ const AttachmentButton: FC<Props> = ({ images, setImages, ToolbarButton }) => {
|
||||
accept="image/*"
|
||||
itemRender={() => null}
|
||||
maxCount={1}
|
||||
onChange={async ({ file }) => {
|
||||
try {
|
||||
const _file = file.originFileObj as File
|
||||
const reader = new FileReader()
|
||||
reader.onload = (e: ProgressEvent<FileReader>) => {
|
||||
const result = e.target?.result
|
||||
if (typeof result === 'string') {
|
||||
setImages([result])
|
||||
}
|
||||
}
|
||||
reader.readAsDataURL(_file)
|
||||
} catch (error: any) {
|
||||
window.message.error(error.message)
|
||||
}
|
||||
}}>
|
||||
<ToolbarButton type="text" className={images.length ? 'active' : ''}>
|
||||
onChange={async ({ file }) => file?.originFileObj && setFiles([file.originFileObj as File])}>
|
||||
<ToolbarButton type="text" className={files.length ? 'active' : ''}>
|
||||
<PaperClipOutlined style={{ rotate: '135deg' }} />
|
||||
</ToolbarButton>
|
||||
</Upload>
|
||||
|
||||
@ -25,7 +25,6 @@ import { CSSProperties, FC, useCallback, useEffect, useMemo, useRef, useState }
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import styled from 'styled-components'
|
||||
|
||||
import AttachmentButton from './AttachmentButton'
|
||||
import SendMessageButton from './SendMessageButton'
|
||||
|
||||
interface Props {
|
||||
@ -46,7 +45,7 @@ const Inputbar: FC<Props> = ({ assistant, setActiveTopic, showSetting, setShowSe
|
||||
const [estimateTokenCount, setEstimateTokenCount] = useState(0)
|
||||
const generating = useAppSelector((state) => state.runtime.generating)
|
||||
const textareaRef = useRef<TextAreaRef>(null)
|
||||
const [images, setImages] = useState<string[]>([])
|
||||
const [files, setFiles] = useState<File[]>([])
|
||||
const { t } = useTranslation()
|
||||
const containerRef = useRef(null)
|
||||
|
||||
@ -71,18 +70,19 @@ const Inputbar: FC<Props> = ({ assistant, setActiveTopic, showSetting, setShowSe
|
||||
status: 'success'
|
||||
}
|
||||
|
||||
if (images.length > 0) {
|
||||
message.images = images
|
||||
if (files.length > 0) {
|
||||
message.files = files
|
||||
}
|
||||
|
||||
EventEmitter.emit(EVENT_NAMES.SEND_MESSAGE, message)
|
||||
|
||||
setText('')
|
||||
setImages([])
|
||||
setFiles([])
|
||||
setTimeout(() => setText(''), 500)
|
||||
setTimeout(() => resizeTextArea(), 0)
|
||||
|
||||
setExpend(false)
|
||||
}, [assistant.id, assistant.topics, generating, images, text])
|
||||
}, [assistant.id, assistant.topics, generating, files, text])
|
||||
|
||||
const inputTokenCount = useMemo(() => estimateInputTokenCount(text), [text])
|
||||
|
||||
@ -226,7 +226,7 @@ const Inputbar: FC<Props> = ({ assistant, setActiveTopic, showSetting, setShowSe
|
||||
<ControlOutlined />
|
||||
</ToolbarButton>
|
||||
</Tooltip>
|
||||
<AttachmentButton images={images} setImages={setImages} ToolbarButton={ToolbarButton} />
|
||||
{/* <AttachmentButton files={files} setFiles={setFiles} ToolbarButton={ToolbarButton} /> */}
|
||||
<Tooltip placement="top" title={expended ? t('chat.input.collapse') : t('chat.input.expand')} arrow>
|
||||
<ToolbarButton type="text" onClick={onToggleExpended}>
|
||||
{expended ? <FullscreenExitOutlined /> : <FullscreenOutlined />}
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import { CheckOutlined } from '@ant-design/icons'
|
||||
import CopyIcon from '@renderer/components/Icons/CopyIcon'
|
||||
import { useTheme } from '@renderer/context/ThemeProvider'
|
||||
import { initMermaid } from '@renderer/init'
|
||||
import { useTheme } from '@renderer/providers/ThemeProvider'
|
||||
import { ThemeMode } from '@renderer/store/settings'
|
||||
import React, { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { CheckOutlined, QuestionCircleOutlined, ReloadOutlined } from '@ant-design/icons'
|
||||
import { CheckOutlined, CloseOutlined, QuestionCircleOutlined, ReloadOutlined } from '@ant-design/icons'
|
||||
import { HStack } from '@renderer/components/Layout'
|
||||
import { DEFAULT_CONEXTCOUNT, DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE } from '@renderer/config/constant'
|
||||
import { useAssistant } from '@renderer/hooks/useAssistant'
|
||||
@ -19,6 +19,7 @@ import styled from 'styled-components'
|
||||
|
||||
interface Props {
|
||||
assistant: Assistant
|
||||
onClose: () => void
|
||||
}
|
||||
|
||||
const SettingsTab: FC<Props> = (props) => {
|
||||
@ -87,6 +88,10 @@ const SettingsTab: FC<Props> = (props) => {
|
||||
|
||||
return (
|
||||
<Container>
|
||||
<SettingsHeader>
|
||||
{t('settings.title')}
|
||||
<CloseIcon onClick={props.onClose} />
|
||||
</SettingsHeader>
|
||||
<SettingSubtitle>
|
||||
{t('settings.messages.model.title')}{' '}
|
||||
<Tooltip title={t('chat.settings.reset')}>
|
||||
@ -259,4 +264,21 @@ const SettingRowTitleSmall = styled(SettingRowTitle)`
|
||||
font-size: 13px;
|
||||
`
|
||||
|
||||
const SettingsHeader = styled.div`
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
padding: 10px 15px;
|
||||
border-bottom: 0.5px solid var(--color-border);
|
||||
margin-left: -15px;
|
||||
margin-right: -15px;
|
||||
`
|
||||
|
||||
const CloseIcon = styled(CloseOutlined)`
|
||||
font-size: 14px;
|
||||
cursor: pointer;
|
||||
color: var(--color-text-3);
|
||||
`
|
||||
|
||||
export default SettingsTab
|
||||
|
||||
@ -8,8 +8,8 @@ import {
|
||||
} from '@ant-design/icons'
|
||||
import { getModelLogo } from '@renderer/config/provider'
|
||||
import { PROVIDER_CONFIG } from '@renderer/config/provider'
|
||||
import { useTheme } from '@renderer/context/ThemeProvider'
|
||||
import { useProvider } from '@renderer/hooks/useProvider'
|
||||
import { useTheme } from '@renderer/providers/ThemeProvider'
|
||||
import { checkApi } from '@renderer/services/api'
|
||||
import { Provider } from '@renderer/types'
|
||||
import { Avatar, Button, Card, Divider, Flex, Input, Space, Switch } from 'antd'
|
||||
|
||||
40
src/renderer/src/providers/AiProvider.ts
Normal file
40
src/renderer/src/providers/AiProvider.ts
Normal file
@ -0,0 +1,40 @@
|
||||
import BaseProvider from '@renderer/providers/BaseProvider'
|
||||
import ProviderFactory from '@renderer/providers/ProviderFactory'
|
||||
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
export default class AiProvider {
|
||||
private sdk: BaseProvider
|
||||
|
||||
constructor(provider: Provider) {
|
||||
this.sdk = ProviderFactory.create(provider)
|
||||
}
|
||||
|
||||
public async completions(
|
||||
messages: Message[],
|
||||
assistant: Assistant,
|
||||
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
|
||||
): Promise<void> {
|
||||
return this.sdk.completions(messages, assistant, onChunk)
|
||||
}
|
||||
|
||||
public async translate(message: Message, assistant: Assistant): Promise<string> {
|
||||
return this.sdk.translate(message, assistant)
|
||||
}
|
||||
|
||||
public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
|
||||
return this.sdk.summaries(messages, assistant)
|
||||
}
|
||||
|
||||
public async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
|
||||
return this.sdk.suggestions(messages, assistant)
|
||||
}
|
||||
|
||||
public async check(): Promise<{ valid: boolean; error: Error | null }> {
|
||||
return this.sdk.check()
|
||||
}
|
||||
|
||||
public async models(): Promise<OpenAI.Models.Model[]> {
|
||||
return this.sdk.models()
|
||||
}
|
||||
}
|
||||
143
src/renderer/src/providers/AnthropicProvider.ts
Normal file
143
src/renderer/src/providers/AnthropicProvider.ts
Normal file
@ -0,0 +1,143 @@
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant'
|
||||
import { EVENT_NAMES } from '@renderer/services/event'
|
||||
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
|
||||
import { sum, takeRight } from 'lodash'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import BaseProvider from './BaseProvider'
|
||||
|
||||
export default class AnthropicProvider extends BaseProvider {
|
||||
private sdk: Anthropic
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
this.sdk = new Anthropic({ apiKey: provider.apiKey, baseURL: this.getBaseURL() })
|
||||
}
|
||||
|
||||
public async completions(
|
||||
messages: Message[],
|
||||
assistant: Assistant,
|
||||
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
|
||||
) {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const { contextCount, maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
const userMessages = takeRight(messages, contextCount + 1).map((message) => {
|
||||
return {
|
||||
role: message.role,
|
||||
content: message.content
|
||||
}
|
||||
})
|
||||
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
const stream = this.sdk.messages
|
||||
.stream({
|
||||
model: model.id,
|
||||
messages: userMessages.filter(Boolean) as MessageParam[],
|
||||
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
|
||||
temperature: assistant?.settings?.temperature,
|
||||
system: assistant.prompt,
|
||||
stream: true
|
||||
})
|
||||
.on('text', (text) => {
|
||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
|
||||
resolve()
|
||||
return stream.controller.abort()
|
||||
}
|
||||
onChunk({ text })
|
||||
})
|
||||
.on('finalMessage', (message) => {
|
||||
onChunk({
|
||||
text: '',
|
||||
usage: {
|
||||
prompt_tokens: message.usage.input_tokens,
|
||||
completion_tokens: message.usage.output_tokens,
|
||||
total_tokens: sum(Object.values(message.usage))
|
||||
}
|
||||
})
|
||||
resolve()
|
||||
})
|
||||
.on('error', (error) => reject(error))
|
||||
})
|
||||
}
|
||||
|
||||
public async translate(message: Message, assistant: Assistant) {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const messages = [
|
||||
{ role: 'system', content: assistant.prompt },
|
||||
{ role: 'user', content: message.content }
|
||||
]
|
||||
|
||||
const response = await this.sdk.messages.create({
|
||||
model: model.id,
|
||||
messages: messages.filter((m) => m.role === 'user') as MessageParam[],
|
||||
max_tokens: 4096,
|
||||
temperature: assistant?.settings?.temperature,
|
||||
system: assistant.prompt,
|
||||
stream: false
|
||||
})
|
||||
|
||||
return response.content[0].type === 'text' ? response.content[0].text : ''
|
||||
}
|
||||
|
||||
public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
|
||||
const model = getTopNamingModel() || assistant.model || getDefaultModel()
|
||||
|
||||
const userMessages = takeRight(messages, 5).map((message) => ({
|
||||
role: message.role,
|
||||
content: message.content
|
||||
}))
|
||||
|
||||
const systemMessage = {
|
||||
role: 'system',
|
||||
content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要使用标点符号和其他特殊符号。'
|
||||
}
|
||||
|
||||
const message = await this.sdk.messages.create({
|
||||
messages: userMessages as Anthropic.Messages.MessageParam[],
|
||||
model: model.id,
|
||||
system: systemMessage.content,
|
||||
stream: false,
|
||||
max_tokens: 4096
|
||||
})
|
||||
|
||||
return message.content[0].type === 'text' ? message.content[0].text : null
|
||||
}
|
||||
|
||||
public async suggestions(): Promise<Suggestion[]> {
|
||||
return []
|
||||
}
|
||||
|
||||
public async check(): Promise<{ valid: boolean; error: Error | null }> {
|
||||
const model = this.provider.models[0]
|
||||
|
||||
const body = {
|
||||
model: model.id,
|
||||
messages: [{ role: 'user', content: 'hi' }],
|
||||
max_tokens: 100,
|
||||
stream: false
|
||||
}
|
||||
|
||||
try {
|
||||
const message = await this.sdk.messages.create(body as MessageCreateParamsNonStreaming)
|
||||
return {
|
||||
valid: message.content.length > 0,
|
||||
error: null
|
||||
}
|
||||
} catch (error: any) {
|
||||
return {
|
||||
valid: false,
|
||||
error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async models(): Promise<OpenAI.Models.Model[]> {
|
||||
return []
|
||||
}
|
||||
}
|
||||
33
src/renderer/src/providers/BaseProvider.ts
Normal file
33
src/renderer/src/providers/BaseProvider.ts
Normal file
@ -0,0 +1,33 @@
|
||||
import { getOllamaKeepAliveTime } from '@renderer/hooks/useOllama'
|
||||
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
export default abstract class BaseProvider {
|
||||
protected provider: Provider
|
||||
protected host: string
|
||||
|
||||
constructor(provider: Provider) {
|
||||
this.provider = provider
|
||||
this.host = this.getBaseURL()
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
const host = this.provider.apiHost
|
||||
return host.endsWith('/') ? host : `${host}/v1/`
|
||||
}
|
||||
|
||||
public get keepAliveTime() {
|
||||
return this.provider.id === 'ollama' ? getOllamaKeepAliveTime() : undefined
|
||||
}
|
||||
|
||||
abstract completions(
|
||||
messages: Message[],
|
||||
assistant: Assistant,
|
||||
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
|
||||
): Promise<void>
|
||||
abstract translate(message: Message, assistant: Assistant): Promise<string>
|
||||
abstract summaries(messages: Message[], assistant: Assistant): Promise<string | null>
|
||||
abstract suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]>
|
||||
abstract check(): Promise<{ valid: boolean; error: Error | null }>
|
||||
abstract models(): Promise<OpenAI.Models.Model[]>
|
||||
}
|
||||
170
src/renderer/src/providers/GeminiProvider.ts
Normal file
170
src/renderer/src/providers/GeminiProvider.ts
Normal file
@ -0,0 +1,170 @@
|
||||
import { GoogleGenerativeAI } from '@google/generative-ai'
|
||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant'
|
||||
import { EVENT_NAMES } from '@renderer/services/event'
|
||||
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
|
||||
import axios from 'axios'
|
||||
import { isEmpty, takeRight } from 'lodash'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import BaseProvider from './BaseProvider'
|
||||
|
||||
export default class GeminiProvider extends BaseProvider {
|
||||
private sdk: GoogleGenerativeAI
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
this.sdk = new GoogleGenerativeAI(provider.apiKey)
|
||||
}
|
||||
|
||||
public async completions(
|
||||
messages: Message[],
|
||||
assistant: Assistant,
|
||||
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
|
||||
) {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const { contextCount, maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
const userMessages = takeRight(messages, contextCount + 1).map((message) => {
|
||||
return {
|
||||
role: message.role,
|
||||
content: message.content
|
||||
}
|
||||
})
|
||||
|
||||
const geminiModel = this.sdk.getGenerativeModel({
|
||||
model: model.id,
|
||||
systemInstruction: assistant.prompt,
|
||||
generationConfig: {
|
||||
maxOutputTokens: maxTokens,
|
||||
temperature: assistant?.settings?.temperature
|
||||
}
|
||||
})
|
||||
|
||||
const userLastMessage = userMessages.pop()
|
||||
|
||||
const chat = geminiModel.startChat({
|
||||
history: userMessages.map((message) => ({
|
||||
role: message.role === 'user' ? 'user' : 'model',
|
||||
parts: [{ text: message.content }]
|
||||
}))
|
||||
})
|
||||
|
||||
const userMessagesStream = await chat.sendMessageStream(userLastMessage?.content!)
|
||||
|
||||
for await (const chunk of userMessagesStream.stream) {
|
||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
|
||||
onChunk({
|
||||
text: chunk.text(),
|
||||
usage: {
|
||||
prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
|
||||
completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
|
||||
total_tokens: chunk.usageMetadata?.totalTokenCount || 0
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async translate(message: Message, assistant: Assistant) {
|
||||
const defaultModel = getDefaultModel()
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
const model = assistant.model || defaultModel
|
||||
|
||||
const geminiModel = this.sdk.getGenerativeModel({
|
||||
model: model.id,
|
||||
systemInstruction: assistant.prompt,
|
||||
generationConfig: {
|
||||
maxOutputTokens: maxTokens,
|
||||
temperature: assistant?.settings?.temperature
|
||||
}
|
||||
})
|
||||
|
||||
const { response } = await geminiModel.generateContent(message.content)
|
||||
|
||||
return response.text()
|
||||
}
|
||||
|
||||
public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
|
||||
const model = getTopNamingModel() || assistant.model || getDefaultModel()
|
||||
|
||||
const userMessages = takeRight(messages, 5).map((message) => ({
|
||||
role: message.role,
|
||||
content: message.content
|
||||
}))
|
||||
|
||||
const systemMessage = {
|
||||
role: 'system',
|
||||
content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要使用标点符号和其他特殊符号。'
|
||||
}
|
||||
|
||||
const geminiModel = this.sdk.getGenerativeModel({
|
||||
model: model.id,
|
||||
systemInstruction: systemMessage.content,
|
||||
generationConfig: {
|
||||
temperature: assistant?.settings?.temperature
|
||||
}
|
||||
})
|
||||
|
||||
const lastUserMessage = userMessages.pop()
|
||||
|
||||
const chat = await geminiModel.startChat({
|
||||
history: userMessages.map((message) => ({
|
||||
role: message.role === 'user' ? 'user' : 'model',
|
||||
parts: [{ text: message.content }]
|
||||
}))
|
||||
})
|
||||
|
||||
const { response } = await chat.sendMessage(lastUserMessage?.content!)
|
||||
|
||||
return response.text()
|
||||
}
|
||||
|
||||
public async suggestions(): Promise<Suggestion[]> {
|
||||
return []
|
||||
}
|
||||
|
||||
public async check(): Promise<{ valid: boolean; error: Error | null }> {
|
||||
const model = this.provider.models[0]
|
||||
|
||||
const body = {
|
||||
model: model.id,
|
||||
messages: [{ role: 'user', content: 'hi' }],
|
||||
max_tokens: 100,
|
||||
stream: false
|
||||
}
|
||||
|
||||
try {
|
||||
const geminiModel = this.sdk.getGenerativeModel({ model: body.model })
|
||||
const result = await geminiModel.generateContent(body.messages[0].content)
|
||||
return {
|
||||
valid: !isEmpty(result.response.text()),
|
||||
error: null
|
||||
}
|
||||
} catch (error: any) {
|
||||
return {
|
||||
valid: false,
|
||||
error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async models(): Promise<OpenAI.Models.Model[]> {
|
||||
try {
|
||||
const api = this.provider.apiHost + '/v1beta/models'
|
||||
const { data } = await axios.get(api, { params: { key: this.provider.apiKey } })
|
||||
return data.models.map(
|
||||
(m: any) =>
|
||||
({
|
||||
id: m.name.replace('models/', ''),
|
||||
name: m.displayName,
|
||||
description: m.description,
|
||||
object: 'model',
|
||||
created: Date.now(),
|
||||
owned_by: 'gemini'
|
||||
}) as OpenAI.Models.Model
|
||||
)
|
||||
} catch (error) {
|
||||
return []
|
||||
}
|
||||
}
|
||||
}
|
||||
185
src/renderer/src/providers/OpenAIProvider.ts
Normal file
185
src/renderer/src/providers/OpenAIProvider.ts
Normal file
@ -0,0 +1,185 @@
|
||||
import { isLocalAi } from '@renderer/config/env'
|
||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant'
|
||||
import { EVENT_NAMES } from '@renderer/services/event'
|
||||
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
|
||||
import { fileToBase64, removeQuotes } from '@renderer/utils'
|
||||
import { first, takeRight } from 'lodash'
|
||||
import OpenAI from 'openai'
|
||||
import {
|
||||
ChatCompletionContentPart,
|
||||
ChatCompletionCreateParamsNonStreaming,
|
||||
ChatCompletionMessageParam
|
||||
} from 'openai/resources'
|
||||
|
||||
import BaseProvider from './BaseProvider'
|
||||
|
||||
export default class OpenAIProvider extends BaseProvider {
|
||||
private sdk: OpenAI
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
this.sdk = new OpenAI({
|
||||
dangerouslyAllowBrowser: true,
|
||||
apiKey: provider.apiKey,
|
||||
baseURL: this.getBaseURL()
|
||||
})
|
||||
}
|
||||
|
||||
private async getMessageContent(message: Message): Promise<string | ChatCompletionContentPart[]> {
|
||||
const file = first(message.files)
|
||||
|
||||
if (!file) {
|
||||
return message.content
|
||||
}
|
||||
|
||||
if (file.type.includes('image')) {
|
||||
return [
|
||||
{ type: 'text', text: message.content },
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url: await fileToBase64(file)
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
return message.content
|
||||
}
|
||||
|
||||
async completions(
|
||||
messages: Message[],
|
||||
assistant: Assistant,
|
||||
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
|
||||
): Promise<void> {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const { contextCount, maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
|
||||
|
||||
const userMessages: ChatCompletionMessageParam[] = []
|
||||
|
||||
for (const message of takeRight(messages, contextCount + 1)) {
|
||||
userMessages.push({
|
||||
role: message.role,
|
||||
content: await this.getMessageContent(message)
|
||||
} as ChatCompletionMessageParam)
|
||||
}
|
||||
|
||||
// @ts-ignore key is not typed
|
||||
const stream = await this.sdk.chat.completions.create({
|
||||
model: model.id,
|
||||
messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[],
|
||||
stream: true,
|
||||
temperature: assistant?.settings?.temperature,
|
||||
max_tokens: maxTokens,
|
||||
keep_alive: this.keepAliveTime
|
||||
})
|
||||
|
||||
for await (const chunk of stream) {
|
||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
|
||||
onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage })
|
||||
}
|
||||
}
|
||||
|
||||
async translate(message: Message, assistant: Assistant) {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const messages = [
|
||||
{ role: 'system', content: assistant.prompt },
|
||||
{ role: 'user', content: message.content }
|
||||
]
|
||||
|
||||
// @ts-ignore key is not typed
|
||||
const response = await this.sdk.chat.completions.create({
|
||||
model: model.id,
|
||||
messages: messages as ChatCompletionMessageParam[],
|
||||
stream: false,
|
||||
keep_alive: this.keepAliveTime
|
||||
})
|
||||
|
||||
return response.choices[0].message?.content || ''
|
||||
}
|
||||
|
||||
public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
|
||||
const model = getTopNamingModel() || assistant.model || getDefaultModel()
|
||||
|
||||
const userMessages = takeRight(messages, 5).map((message) => ({
|
||||
role: message.role,
|
||||
content: message.content
|
||||
}))
|
||||
|
||||
const systemMessage = {
|
||||
role: 'system',
|
||||
content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要使用标点符号和其他特殊符号。'
|
||||
}
|
||||
|
||||
// @ts-ignore key is not typed
|
||||
const response = await this.sdk.chat.completions.create({
|
||||
model: model.id,
|
||||
messages: [systemMessage, ...(isLocalAi ? [first(userMessages)] : userMessages)] as ChatCompletionMessageParam[],
|
||||
stream: false,
|
||||
max_tokens: 50,
|
||||
keep_alive: this.keepAliveTime
|
||||
})
|
||||
|
||||
return removeQuotes(response.choices[0].message?.content?.substring(0, 50) || '')
|
||||
}
|
||||
|
||||
async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
|
||||
const model = assistant.model
|
||||
|
||||
if (!model) {
|
||||
return []
|
||||
}
|
||||
|
||||
const response: any = await this.sdk.request({
|
||||
method: 'post',
|
||||
path: '/advice_questions',
|
||||
body: {
|
||||
messages: messages.filter((m) => m.role === 'user').map((m) => ({ role: m.role, content: m.content })),
|
||||
model: model.id,
|
||||
max_tokens: 0,
|
||||
temperature: 0,
|
||||
n: 0
|
||||
}
|
||||
})
|
||||
|
||||
return response?.questions?.filter(Boolean)?.map((q: any) => ({ content: q })) || []
|
||||
}
|
||||
|
||||
public async check(): Promise<{ valid: boolean; error: Error | null }> {
|
||||
const model = this.provider.models[0]
|
||||
|
||||
const body = {
|
||||
model: model.id,
|
||||
messages: [{ role: 'user', content: 'hi' }],
|
||||
max_tokens: 100,
|
||||
stream: false
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await this.sdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming)
|
||||
|
||||
return {
|
||||
valid: Boolean(response?.choices[0].message),
|
||||
error: null
|
||||
}
|
||||
} catch (error: any) {
|
||||
return {
|
||||
valid: false,
|
||||
error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async models(): Promise<OpenAI.Models.Model[]> {
|
||||
try {
|
||||
const response = await this.sdk.models.list()
|
||||
return response.data
|
||||
} catch (error) {
|
||||
return []
|
||||
}
|
||||
}
|
||||
}
|
||||
19
src/renderer/src/providers/ProviderFactory.ts
Normal file
19
src/renderer/src/providers/ProviderFactory.ts
Normal file
@ -0,0 +1,19 @@
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
import AnthropicProvider from './AnthropicProvider'
|
||||
import BaseProvider from './BaseProvider'
|
||||
import GeminiProvider from './GeminiProvider'
|
||||
import OpenAIProvider from './OpenAIProvider'
|
||||
|
||||
export default class ProviderFactory {
|
||||
static create(provider: Provider): BaseProvider {
|
||||
switch (provider.id) {
|
||||
case 'anthropic':
|
||||
return new AnthropicProvider(provider)
|
||||
case 'gemini':
|
||||
return new GeminiProvider(provider)
|
||||
default:
|
||||
return new OpenAIProvider(provider)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1,358 +0,0 @@
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources'
|
||||
import { GoogleGenerativeAI } from '@google/generative-ai'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import { isLocalAi } from '@renderer/config/env'
|
||||
import { getOllamaKeepAliveTime } from '@renderer/hooks/useOllama'
|
||||
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
|
||||
import { removeQuotes } from '@renderer/utils'
|
||||
import axios from 'axios'
|
||||
import { first, isEmpty, sum, takeRight } from 'lodash'
|
||||
import OpenAI from 'openai'
|
||||
import { ChatCompletionCreateParamsNonStreaming, ChatCompletionMessageParam } from 'openai/resources'
|
||||
|
||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from './assistant'
|
||||
import { EVENT_NAMES } from './event'
|
||||
|
||||
export default class ProviderSDK {
|
||||
provider: Provider
|
||||
openaiSdk: OpenAI
|
||||
anthropicSdk: Anthropic
|
||||
geminiSdk: GoogleGenerativeAI
|
||||
|
||||
constructor(provider: Provider) {
|
||||
this.provider = provider
|
||||
const host = provider.apiHost
|
||||
const baseURL = host.endsWith('/') ? host : `${provider.apiHost}/v1/`
|
||||
this.anthropicSdk = new Anthropic({ apiKey: provider.apiKey, baseURL })
|
||||
this.openaiSdk = new OpenAI({ dangerouslyAllowBrowser: true, apiKey: provider.apiKey, baseURL })
|
||||
this.geminiSdk = new GoogleGenerativeAI(provider.apiKey)
|
||||
}
|
||||
|
||||
private get isAnthropic() {
|
||||
return this.provider.id === 'anthropic'
|
||||
}
|
||||
|
||||
private get isGemini() {
|
||||
return this.provider.id === 'gemini'
|
||||
}
|
||||
|
||||
private get keepAliveTime() {
|
||||
return this.provider.id === 'ollama' ? getOllamaKeepAliveTime() : undefined
|
||||
}
|
||||
|
||||
public async completions(
|
||||
messages: Message[],
|
||||
assistant: Assistant,
|
||||
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
|
||||
) {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const { contextCount, maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
|
||||
const userMessages = takeRight(messages, contextCount + 1).map((message) => {
|
||||
return {
|
||||
role: message.role,
|
||||
content: message.content
|
||||
}
|
||||
})
|
||||
|
||||
if (this.isAnthropic) {
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
const stream = this.anthropicSdk.messages
|
||||
.stream({
|
||||
model: model.id,
|
||||
messages: userMessages.filter(Boolean) as MessageParam[],
|
||||
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
|
||||
temperature: assistant?.settings?.temperature,
|
||||
system: assistant.prompt,
|
||||
stream: true
|
||||
})
|
||||
.on('text', (text) => {
|
||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
|
||||
resolve()
|
||||
return stream.controller.abort()
|
||||
}
|
||||
onChunk({ text })
|
||||
})
|
||||
.on('finalMessage', (message) => {
|
||||
onChunk({
|
||||
text: '',
|
||||
usage: {
|
||||
prompt_tokens: message.usage.input_tokens,
|
||||
completion_tokens: message.usage.output_tokens,
|
||||
total_tokens: sum(Object.values(message.usage))
|
||||
}
|
||||
})
|
||||
resolve()
|
||||
})
|
||||
.on('error', (error) => reject(error))
|
||||
})
|
||||
}
|
||||
|
||||
if (this.isGemini) {
|
||||
const geminiModel = this.geminiSdk.getGenerativeModel({
|
||||
model: model.id,
|
||||
systemInstruction: assistant.prompt,
|
||||
generationConfig: {
|
||||
maxOutputTokens: maxTokens,
|
||||
temperature: assistant?.settings?.temperature
|
||||
}
|
||||
})
|
||||
|
||||
const userLastMessage = userMessages.pop()
|
||||
|
||||
const chat = geminiModel.startChat({
|
||||
history: userMessages.map((message) => ({
|
||||
role: message.role === 'user' ? 'user' : 'model',
|
||||
parts: [{ text: message.content }]
|
||||
}))
|
||||
})
|
||||
|
||||
const userMessagesStream = await chat.sendMessageStream(userLastMessage?.content!)
|
||||
|
||||
for await (const chunk of userMessagesStream.stream) {
|
||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
|
||||
onChunk({
|
||||
text: chunk.text(),
|
||||
usage: {
|
||||
prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
|
||||
completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
|
||||
total_tokens: chunk.usageMetadata?.totalTokenCount || 0
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
const _userMessages = takeRight(messages, contextCount + 1).map((message) => {
|
||||
return {
|
||||
role: message.role,
|
||||
content: message.images
|
||||
? [
|
||||
{ type: 'text', text: message.content },
|
||||
...message.images!.map((image) => ({ type: 'image_url', image_url: image }))
|
||||
]
|
||||
: message.content
|
||||
}
|
||||
})
|
||||
|
||||
// @ts-ignore key is not typed
|
||||
const stream = await this.openaiSdk.chat.completions.create({
|
||||
model: model.id,
|
||||
messages: [systemMessage, ..._userMessages].filter(Boolean) as ChatCompletionMessageParam[],
|
||||
stream: true,
|
||||
temperature: assistant?.settings?.temperature,
|
||||
max_tokens: maxTokens,
|
||||
keep_alive: this.keepAliveTime
|
||||
})
|
||||
|
||||
for await (const chunk of stream) {
|
||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
|
||||
onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage })
|
||||
}
|
||||
}
|
||||
|
||||
public async translate(message: Message, assistant: Assistant) {
|
||||
const defaultModel = getDefaultModel()
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
const model = assistant.model || defaultModel
|
||||
const messages = [
|
||||
{ role: 'system', content: assistant.prompt },
|
||||
{ role: 'user', content: message.content }
|
||||
]
|
||||
|
||||
if (this.isAnthropic) {
|
||||
const response = await this.anthropicSdk.messages.create({
|
||||
model: model.id,
|
||||
messages: messages.filter((m) => m.role === 'user') as MessageParam[],
|
||||
max_tokens: 4096,
|
||||
temperature: assistant?.settings?.temperature,
|
||||
system: assistant.prompt,
|
||||
stream: false
|
||||
})
|
||||
|
||||
return response.content[0].type === 'text' ? response.content[0].text : ''
|
||||
}
|
||||
|
||||
if (this.isGemini) {
|
||||
const geminiModel = this.geminiSdk.getGenerativeModel({
|
||||
model: model.id,
|
||||
systemInstruction: assistant.prompt,
|
||||
generationConfig: {
|
||||
maxOutputTokens: maxTokens,
|
||||
temperature: assistant?.settings?.temperature
|
||||
}
|
||||
})
|
||||
|
||||
const { response } = await geminiModel.generateContent(message.content)
|
||||
|
||||
return response.text()
|
||||
}
|
||||
|
||||
// @ts-ignore key is not typed
|
||||
const response = await this.openaiSdk.chat.completions.create({
|
||||
model: model.id,
|
||||
messages: messages as ChatCompletionMessageParam[],
|
||||
stream: false,
|
||||
keep_alive: this.keepAliveTime
|
||||
})
|
||||
|
||||
return response.choices[0].message?.content || ''
|
||||
}
|
||||
|
||||
public async summaries(messages: Message[], assistant: Assistant): Promise<string | null> {
|
||||
const model = getTopNamingModel() || assistant.model || getDefaultModel()
|
||||
|
||||
const userMessages = takeRight(messages, 5).map((message) => ({
|
||||
role: message.role,
|
||||
content: message.content
|
||||
}))
|
||||
|
||||
const systemMessage = {
|
||||
role: 'system',
|
||||
content: '你是一名擅长会话的助理,你需要将用户的会话总结为 10 个字以内的标题,不要使用标点符号和其他特殊符号。'
|
||||
}
|
||||
|
||||
if (this.isAnthropic) {
|
||||
const message = await this.anthropicSdk.messages.create({
|
||||
messages: userMessages as Anthropic.Messages.MessageParam[],
|
||||
model: model.id,
|
||||
system: systemMessage.content,
|
||||
stream: false,
|
||||
max_tokens: 4096
|
||||
})
|
||||
|
||||
return message.content[0].type === 'text' ? message.content[0].text : null
|
||||
}
|
||||
|
||||
if (this.isGemini) {
|
||||
const geminiModel = this.geminiSdk.getGenerativeModel({
|
||||
model: model.id,
|
||||
systemInstruction: systemMessage.content,
|
||||
generationConfig: {
|
||||
temperature: assistant?.settings?.temperature
|
||||
}
|
||||
})
|
||||
|
||||
const lastUserMessage = userMessages.pop()
|
||||
|
||||
const chat = await geminiModel.startChat({
|
||||
history: userMessages.map((message) => ({
|
||||
role: message.role === 'user' ? 'user' : 'model',
|
||||
parts: [{ text: message.content }]
|
||||
}))
|
||||
})
|
||||
|
||||
const { response } = await chat.sendMessage(lastUserMessage?.content!)
|
||||
|
||||
return response.text()
|
||||
}
|
||||
|
||||
// @ts-ignore key is not typed
|
||||
const response = await this.openaiSdk.chat.completions.create({
|
||||
model: model.id,
|
||||
messages: [systemMessage, ...(isLocalAi ? [first(userMessages)] : userMessages)] as ChatCompletionMessageParam[],
|
||||
stream: false,
|
||||
max_tokens: 50,
|
||||
keep_alive: this.keepAliveTime
|
||||
})
|
||||
|
||||
return removeQuotes(response.choices[0].message?.content?.substring(0, 50) || '')
|
||||
}
|
||||
|
||||
public async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
|
||||
const model = assistant.model
|
||||
|
||||
if (!model) {
|
||||
return []
|
||||
}
|
||||
|
||||
const response: any = await this.openaiSdk.request({
|
||||
method: 'post',
|
||||
path: '/advice_questions',
|
||||
body: {
|
||||
messages: messages.filter((m) => m.role === 'user').map((m) => ({ role: m.role, content: m.content })),
|
||||
model: model.id,
|
||||
max_tokens: 0,
|
||||
temperature: 0,
|
||||
n: 0
|
||||
}
|
||||
})
|
||||
|
||||
return response?.questions?.filter(Boolean)?.map((q: any) => ({ content: q })) || []
|
||||
}
|
||||
|
||||
public async check(): Promise<{ valid: boolean; error: Error | null }> {
|
||||
const model = this.provider.models[0]
|
||||
|
||||
const body = {
|
||||
model: model.id,
|
||||
messages: [{ role: 'user', content: 'hi' }],
|
||||
max_tokens: 100,
|
||||
stream: false
|
||||
}
|
||||
|
||||
try {
|
||||
if (this.isAnthropic) {
|
||||
const message = await this.anthropicSdk.messages.create(body as MessageCreateParamsNonStreaming)
|
||||
return {
|
||||
valid: message.content.length > 0,
|
||||
error: null
|
||||
}
|
||||
}
|
||||
|
||||
if (this.isGemini) {
|
||||
const geminiModel = this.geminiSdk.getGenerativeModel({ model: body.model })
|
||||
const result = await geminiModel.generateContent(body.messages[0].content)
|
||||
return {
|
||||
valid: !isEmpty(result.response.text()),
|
||||
error: null
|
||||
}
|
||||
}
|
||||
|
||||
const response = await this.openaiSdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming)
|
||||
|
||||
return {
|
||||
valid: Boolean(response?.choices[0].message),
|
||||
error: null
|
||||
}
|
||||
} catch (error: any) {
|
||||
return {
|
||||
valid: false,
|
||||
error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async models(): Promise<OpenAI.Models.Model[]> {
|
||||
try {
|
||||
if (this.isAnthropic) {
|
||||
return []
|
||||
}
|
||||
|
||||
if (this.isGemini) {
|
||||
const api = this.provider.apiHost + '/v1beta/models'
|
||||
const { data } = await axios.get(api, { params: { key: this.provider.apiKey } })
|
||||
return data.models.map(
|
||||
(m: any) =>
|
||||
({
|
||||
id: m.name.replace('models/', ''),
|
||||
name: m.displayName,
|
||||
description: m.description,
|
||||
object: 'model',
|
||||
created: Date.now(),
|
||||
owned_by: 'gemini'
|
||||
}) as OpenAI.Models.Model
|
||||
)
|
||||
}
|
||||
|
||||
const response = await this.openaiSdk.models.list()
|
||||
return response.data
|
||||
} catch (error) {
|
||||
return []
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -6,6 +6,7 @@ import { uuid } from '@renderer/utils'
|
||||
import dayjs from 'dayjs'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
import AiProvider from '../providers/AiProvider'
|
||||
import {
|
||||
getAssistantProvider,
|
||||
getDefaultModel,
|
||||
@ -15,7 +16,6 @@ import {
|
||||
} from './assistant'
|
||||
import { EVENT_NAMES, EventEmitter } from './event'
|
||||
import { filterMessages } from './messages'
|
||||
import ProviderSDK from './ProviderSDK'
|
||||
|
||||
export async function fetchChatCompletion({
|
||||
messages,
|
||||
@ -33,7 +33,7 @@ export async function fetchChatCompletion({
|
||||
const provider = getAssistantProvider(assistant)
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const providerSdk = new ProviderSDK(provider)
|
||||
const AI = new AiProvider(provider)
|
||||
|
||||
store.dispatch(setGenerating(true))
|
||||
|
||||
@ -61,7 +61,7 @@ export async function fetchChatCompletion({
|
||||
}, 1000)
|
||||
|
||||
try {
|
||||
await providerSdk.completions(filterMessages(messages), assistant, ({ text, usage }) => {
|
||||
await AI.completions(filterMessages(messages), assistant, ({ text, usage }) => {
|
||||
message.content = message.content + text || ''
|
||||
message.usage = usage
|
||||
onResponse({ ...message, status: 'pending' })
|
||||
@ -103,10 +103,10 @@ export async function fetchTranslate({ message, assistant }: { message: Message;
|
||||
return ''
|
||||
}
|
||||
|
||||
const providerSdk = new ProviderSDK(provider)
|
||||
const AI = new AiProvider(provider)
|
||||
|
||||
try {
|
||||
return await providerSdk.translate(message, assistant)
|
||||
return await AI.translate(message, assistant)
|
||||
} catch (error: any) {
|
||||
return ''
|
||||
}
|
||||
@ -120,10 +120,10 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages:
|
||||
return null
|
||||
}
|
||||
|
||||
const providerSdk = new ProviderSDK(provider)
|
||||
const AI = new AiProvider(provider)
|
||||
|
||||
try {
|
||||
return await providerSdk.summaries(filterMessages(messages), assistant)
|
||||
return await AI.summaries(filterMessages(messages), assistant)
|
||||
} catch (error: any) {
|
||||
return null
|
||||
}
|
||||
@ -136,10 +136,8 @@ export async function fetchSuggestions({
|
||||
messages: Message[]
|
||||
assistant: Assistant
|
||||
}): Promise<Suggestion[]> {
|
||||
console.debug('fetchSuggestions', messages, assistant)
|
||||
const provider = getAssistantProvider(assistant)
|
||||
const providerSdk = new ProviderSDK(provider)
|
||||
console.debug('fetchSuggestions', provider)
|
||||
const AI = new AiProvider(provider)
|
||||
const model = assistant.model
|
||||
|
||||
if (!model) {
|
||||
@ -155,7 +153,7 @@ export async function fetchSuggestions({
|
||||
}
|
||||
|
||||
try {
|
||||
return await providerSdk.suggestions(messages, assistant)
|
||||
return await AI.suggestions(messages, assistant)
|
||||
} catch (error: any) {
|
||||
return []
|
||||
}
|
||||
@ -183,9 +181,9 @@ export async function checkApi(provider: Provider) {
|
||||
return false
|
||||
}
|
||||
|
||||
const providerSdk = new ProviderSDK(provider)
|
||||
const AI = new AiProvider(provider)
|
||||
|
||||
const { valid } = await providerSdk.check()
|
||||
const { valid } = await AI.check()
|
||||
|
||||
window.message[valid ? 'success' : 'error']({
|
||||
key: 'api-check',
|
||||
@ -204,10 +202,10 @@ function hasApiKey(provider: Provider) {
|
||||
}
|
||||
|
||||
export async function fetchModels(provider: Provider) {
|
||||
const providerSdk = new ProviderSDK(provider)
|
||||
const AI = new AiProvider(provider)
|
||||
|
||||
try {
|
||||
return await providerSdk.models()
|
||||
return await AI.models()
|
||||
} catch (error) {
|
||||
return []
|
||||
}
|
||||
|
||||
@ -20,14 +20,15 @@ export type AssistantSettings = {
|
||||
|
||||
export type Message = {
|
||||
id: string
|
||||
assistantId: string
|
||||
role: 'user' | 'assistant'
|
||||
content: string
|
||||
images?: string[]
|
||||
assistantId: string
|
||||
topicId: string
|
||||
modelId?: string
|
||||
createdAt: string
|
||||
status: 'sending' | 'pending' | 'success' | 'paused' | 'error'
|
||||
modelId?: string
|
||||
files?: File[]
|
||||
images?: string[]
|
||||
usage?: OpenAI.Completions.CompletionUsage
|
||||
type?: 'text' | '@'
|
||||
}
|
||||
|
||||
@ -223,3 +223,18 @@ export function getBriefInfo(text: string, maxLength: number = 50): string {
|
||||
// 截取前面的内容,并在末尾添加 "..."
|
||||
return truncatedText + '...'
|
||||
}
|
||||
|
||||
export async function fileToBase64(file: File): Promise<string> {
|
||||
return new Promise((resolve, reject) => {
|
||||
try {
|
||||
const reader = new FileReader()
|
||||
reader.onload = (e: ProgressEvent<FileReader>) => {
|
||||
const result = e.target?.result
|
||||
resolve(typeof result === 'string' ? result : '')
|
||||
}
|
||||
reader.readAsDataURL(file)
|
||||
} catch (error: any) {
|
||||
reject(error)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user