feat: select model for assistant

This commit is contained in:
kangfenmao 2024-07-04 18:04:21 +08:00
parent da3e10cf04
commit 4296f49e66
12 changed files with 231 additions and 79 deletions

View File

@ -1,222 +1,294 @@
import { Model } from '@renderer/types'
export const SYSTEM_MODELS: Record<string, Model[]> = {
type SystemModel = Model & { defaultEnabled: boolean }
export const SYSTEM_MODELS: Record<string, SystemModel[]> = {
openai: [
{
id: 'gpt-3.5-turbo',
provider: 'openai',
name: 'gpt-3.5-turbo',
group: 'GPT 3.5',
temperature: 0.7
temperature: 0.7,
defaultEnabled: true
},
{
id: 'gpt-3.5-turbo-0301',
provider: 'openai',
name: 'gpt-3.5-turbo',
group: 'GPT 3.5',
temperature: 0.3
temperature: 0.3,
defaultEnabled: false
},
{
id: 'gpt-4',
provider: 'openai',
name: 'gpt-4',
group: 'GPT 4',
temperature: 0.7
temperature: 0.7,
defaultEnabled: true
},
{
id: 'gpt-4-0314',
provider: 'openai',
name: 'gpt-4',
group: 'GPT 4',
temperature: 0.3
temperature: 0.3,
defaultEnabled: false
},
{
id: 'gpt-4-32k',
provider: 'openai',
name: 'gpt-4-32k',
group: 'GPT 4',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: 'gpt-4-32k-0314',
provider: 'openai',
name: 'gpt-4-32k',
group: 'GPT 4',
temperature: 0.3
temperature: 0.3,
defaultEnabled: false
}
],
silicon: [
{
id: 'deepseek-ai/DeepSeek-V2-Chat',
provider: 'silicon',
name: 'DeepSeek-V2-Chat',
group: 'DeepSeek',
temperature: 0.7
temperature: 0.7,
defaultEnabled: true
},
{
id: 'deepseek-ai/DeepSeek-Coder-V2-Instruct',
provider: 'silicon',
name: 'DeepSeek-Coder-V2-Instruct',
group: 'DeepSeek',
temperature: 0.7
temperature: 0.7,
defaultEnabled: true
},
{
id: 'deepseek-ai/deepseek-llm-67b-chat',
provider: 'silicon',
name: 'deepseek-llm-67b-chat',
group: 'DeepSeek',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: 'google/gemma-2-27b-it',
provider: 'silicon',
name: 'gemma-2-27b-it',
group: 'Gemma',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: 'google/gemma-2-9b-it',
provider: 'silicon',
name: 'gemma-2-9b-it',
group: 'Gemma',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: 'Qwen/Qwen2-7B-Instruct',
provider: 'silicon',
name: 'Qwen2-7B-Instruct',
group: 'Qwen2',
temperature: 0.7
temperature: 0.7,
defaultEnabled: true
},
{
id: 'Qwen/Qwen2-1.5B-Instruct',
provider: 'silicon',
name: 'Qwen2-1.5B-Instruct',
group: 'Qwen2',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: 'Qwen/Qwen1.5-7B-Chat',
provider: 'silicon',
name: 'Qwen1.5-7B-Chat',
group: 'Qwen1.5',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: 'Qwen/Qwen2-72B-Instruct',
provider: 'silicon',
name: 'Qwen2-72B-Instruct',
group: 'Qwen2',
temperature: 0.7
temperature: 0.7,
defaultEnabled: true
},
{
id: 'Qwen/Qwen2-57B-A14B-Instruct',
provider: 'silicon',
name: 'Qwen2-57B-A14B-Instruct',
group: 'Qwen2',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: 'Qwen/Qwen1.5-110B-Chat',
provider: 'silicon',
name: 'Qwen1.5-110B-Chat',
group: 'Qwen1.5',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: 'Qwen/Qwen1.5-32B-Chat',
provider: 'silicon',
name: 'Qwen1.5-32B-Chat',
group: 'Qwen1.5',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: 'Qwen/Qwen1.5-14B-Chat',
provider: 'silicon',
name: 'Qwen1.5-14B-Chat',
group: 'Qwen1.5',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: 'THUDM/glm-4-9b-chat',
provider: 'silicon',
name: 'glm-4-9b-chat',
group: 'GLM',
temperature: 0.7
temperature: 0.7,
defaultEnabled: true
},
{
id: 'THUDM/chatglm3-6b',
provider: 'silicon',
name: 'chatglm3-6b',
group: 'GLM',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: '01-ai/Yi-1.5-9B-Chat-16K',
provider: 'silicon',
name: 'Yi-1.5-9B-Chat-16K',
group: 'Yi',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: '01-ai/Yi-1.5-6B-Chat',
provider: 'silicon',
name: 'Yi-1.5-6B-Chat',
group: 'Yi',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: '01-ai/Yi-1.5-34B-Chat-16K',
provider: 'silicon',
name: 'Yi-1.5-34B-Chat-16K',
group: 'Yi',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: 'OpenAI/GPT-4o',
provider: 'silicon',
name: 'GPT-4o',
group: 'OpenAI',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: 'OpenAI/GPT-3.5 Turbo',
provider: 'silicon',
name: 'GPT-3.5 Turbo',
group: 'OpenAI',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: 'Anthropic/claude-3-5-sonnet',
provider: 'silicon',
name: 'claude-3-5-sonnet',
group: 'Claude',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: 'meta-llama/Meta-Llama-3-8B-Instruct',
provider: 'silicon',
name: 'Meta-Llama-3-8B-Instruct',
group: 'Meta Llama',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: 'meta-llama/Meta-Llama-3-70B-Instruct',
provider: 'silicon',
name: 'Meta-Llama-3-70B-Instruct',
group: 'Meta Llama',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
}
],
deepseek: [
{
id: 'deepseek-chat',
provider: 'deepseek',
name: 'deepseek-chat',
group: 'Deepseek Chat',
temperature: 0.7
temperature: 0.7,
defaultEnabled: true
},
{
id: 'deepseek-coder',
provider: 'deepseek',
name: 'deepseek-coder',
group: 'Deepseek Coder',
temperature: 1.0
temperature: 1.0,
defaultEnabled: true
}
],
groq: [
{
id: 'llama3-8b-8192',
provider: 'groq',
name: 'LLaMA3 8b',
group: 'Llama3',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: 'llama3-70b-8192',
provider: 'groq',
name: 'LLaMA3 70b',
group: 'Llama3',
temperature: 0.7
temperature: 0.7,
defaultEnabled: true
},
{
id: 'mixtral-8x7b-32768',
provider: 'groq',
name: 'Mixtral 8x7b',
group: 'Mixtral',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
},
{
id: 'gemma-7b-it',
provider: 'groq',
name: 'Gemma 7b',
group: 'Gemma',
temperature: 0.7
temperature: 0.7,
defaultEnabled: false
}
]
}

View File

@ -3,12 +3,13 @@ import {
addTopic as _addTopic,
removeAllTopics as _removeAllTopics,
removeTopic as _removeTopic,
setModel as _setModel,
updateTopic as _updateTopic,
addAssistant,
removeAssistant,
updateAssistant
} from '@renderer/store/assistants'
import { Assistant, Topic } from '@renderer/types'
import { Assistant, Model, Topic } from '@renderer/types'
import localforage from 'localforage'
export function useAssistants() {
@ -29,12 +30,19 @@ export function useAssistants() {
}
}
export function useDefaultModel() {
const defaultModel = useAppSelector((state) => state.llm.defaultModel)
return { defaultModel }
}
export function useAssistant(id: string) {
const assistant = useAppSelector((state) => state.assistants.assistants.find((a) => a.id === id) as Assistant)
const dispatch = useAppDispatch()
const { defaultModel } = useDefaultModel()
return {
assistant,
model: assistant?.model ?? defaultModel,
addTopic: (topic: Topic) => {
dispatch(_addTopic({ assistantId: assistant.id, topic }))
},
@ -46,6 +54,9 @@ export function useAssistant(id: string) {
},
removeAllTopics: () => {
dispatch(_removeAllTopics({ assistantId: assistant.id }))
},
setModel: (model: Model) => {
dispatch(_setModel({ assistantId: assistant.id, model }))
}
}
}

View File

@ -4,7 +4,8 @@ import {
removeModel as _removeModel,
updateProvider as _updateProvider
} from '@renderer/store/llm'
import { Model, Provider } from '@renderer/types'
import { Assistant, Model, Provider } from '@renderer/types'
import { useDefaultModel } from './useAssistant'
export function useProviders() {
return useAppSelector((state) => state.llm.providers)
@ -23,6 +24,13 @@ export function useProvider(id: string) {
}
}
export function useProviderByAssistant(assistant: Assistant) {
const { defaultModel } = useDefaultModel()
const model = assistant.model || defaultModel
const { provider } = useProvider(model.provider)
return provider
}
export function useDefaultProvider() {
return useAppSelector((state) => state.llm.providers.find((p) => p.isDefault))
}

View File

@ -1,4 +1,4 @@
import { Navbar, NavbarCenter, NavbarLeft, NavbarRight } from '@renderer/components/app/Navbar'
import { Navbar, NavbarLeft, NavbarRight } from '@renderer/components/app/Navbar'
import { useAssistants } from '@renderer/hooks/useAssistant'
import { FC, useState } from 'react'
import styled from 'styled-components'
@ -8,6 +8,7 @@ import { uuid } from '@renderer/utils'
import { getDefaultAssistant } from '@renderer/services/assistant'
import { useShowRightSidebar } from '@renderer/hooks/useStore'
import { Tooltip } from 'antd'
import NavigationCenter from './components/Chat/NavigationCenter'
const HomePage: FC = () => {
const { assistants, addAssistant } = useAssistants()
@ -29,7 +30,7 @@ const HomePage: FC = () => {
<i className="iconfont icon-a-addchat"></i>
</NewButton>
</NavbarLeft>
<NavbarCenter style={{ border: 'none' }}>{activeAssistant?.name}</NavbarCenter>
<NavigationCenter activeAssistant={activeAssistant} />
<NavbarRight style={{ justifyContent: 'flex-end', padding: 5 }}>
<Tooltip placement="left" title={showRightSidebar ? 'Hide Topics' : 'Show Topics'} arrow>
<NewButton onClick={setShowRightSidebar}>

View File

@ -11,6 +11,7 @@ import { useAssistant } from '@renderer/hooks/useAssistant'
import { DEFAULT_TOPIC_NAME } from '@renderer/config/constant'
import { runAsyncFunction } from '@renderer/utils'
import LocalStorage from '@renderer/services/storage'
import { useProviderByAssistant } from '@renderer/hooks/useProvider'
interface Props {
assistant: Assistant
@ -21,6 +22,7 @@ const Conversations: FC<Props> = ({ assistant, topic }) => {
const [messages, setMessages] = useState<Message[]>([])
const [lastMessage, setLastMessage] = useState<Message | null>(null)
const { updateTopic } = useAssistant(assistant.id)
const provider = useProviderByAssistant(assistant)
const onSendMessage = useCallback(
(message: Message) => {
@ -46,8 +48,10 @@ const Conversations: FC<Props> = ({ assistant, topic }) => {
useEffect(() => {
const unsubscribes = [
EventEmitter.on(EVENT_NAMES.SEND_MESSAGE, async (msg: Message) => {
console.debug({ assistant, provider, message: msg, topic })
return
onSendMessage(msg)
fetchChatCompletion({ assistant, message: msg, topic, onResponse: setLastMessage })
fetchChatCompletion({ assistant, provider, message: msg, topic, onResponse: setLastMessage })
}),
EventEmitter.on(EVENT_NAMES.AI_CHAT_COMPLETION, async (msg: Message) => {
setLastMessage(null)

View File

@ -0,0 +1,45 @@
import { NavbarCenter } from '@renderer/components/app/Navbar'
import { useAssistant } from '@renderer/hooks/useAssistant'
import { useProviders } from '@renderer/hooks/useProvider'
import { Assistant } from '@renderer/types'
import { Button, Dropdown, MenuProps } from 'antd'
import { FC } from 'react'
import styled from 'styled-components'
interface Props {
activeAssistant: Assistant
}
const NavigationCenter: FC<Props> = ({ activeAssistant }) => {
const providers = useProviders()
const { model, setModel } = useAssistant(activeAssistant.id)
const items: MenuProps['items'] = providers.map((p) => ({
key: p.id,
label: p.name,
type: 'group',
children: p.models.map((m) => ({
key: m.id,
label: m.name,
onClick: () => setModel(m)
}))
}))
return (
<NavbarCenter style={{ border: 'none' }}>
{activeAssistant?.name}
<DropdownMenu menu={{ items }} trigger={['click']}>
<Button size="small" type="primary" ghost style={{ fontSize: '11px' }}>
{model ? model.name : 'Select Model'}
</Button>
</DropdownMenu>
</NavbarCenter>
)
}
const DropdownMenu = styled(Dropdown)`
-webkit-app-region: none;
margin-left: 10px;
`
export default NavigationCenter

View File

@ -13,7 +13,6 @@ interface Props {
const ModalProviderSetting: FC<Props> = ({ provider }) => {
const [apiKey, setApiKey] = useState(provider.apiKey)
const [apiHost, setApiHost] = useState(provider.apiHost)
const [apiPath, setApiPath] = useState(provider.apiPath)
const { updateProvider, models } = useProvider(provider.id)
const modelGroups = groupBy(models, 'group')
@ -21,7 +20,6 @@ const ModalProviderSetting: FC<Props> = ({ provider }) => {
useEffect(() => {
setApiKey(provider.apiKey)
setApiHost(provider.apiHost)
setApiPath(provider.apiPath)
}, [provider])
const onUpdateApiKey = () => {
@ -32,10 +30,6 @@ const ModalProviderSetting: FC<Props> = ({ provider }) => {
updateProvider({ ...provider, apiHost })
}
const onUpdateApiPath = () => {
updateProvider({ ...provider, apiHost })
}
const onAddModal = () => {
ModalListPopup.show({ provider })
}
@ -53,13 +47,6 @@ const ModalProviderSetting: FC<Props> = ({ provider }) => {
onChange={(e) => setApiHost(e.target.value)}
onBlur={onUpdateApiHost}
/>
<SubTitle>API Path</SubTitle>
<Input
value={apiPath}
placeholder="API Path"
onChange={(e) => setApiPath(e.target.value)}
onBlur={onUpdateApiPath}
/>
<SubTitle>Models</SubTitle>
{Object.keys(modelGroups).map((group) => (
<Card key={group} type="inner" title={group} style={{ marginBottom: '10px' }} size="small">

View File

@ -1,19 +1,32 @@
import { Assistant, Message, Topic } from '@renderer/types'
import { openaiProvider } from './provider'
import { Assistant, Message, Provider, Topic } from '@renderer/types'
import { uuid } from '@renderer/utils'
import { EVENT_NAMES, EventEmitter } from './event'
import { ChatCompletionMessageParam, ChatCompletionSystemMessageParam } from 'openai/resources'
import OpenAI from 'openai'
interface FetchChatCompletionParams {
message: Message
assistant: Assistant
topic: Topic
assistant: Assistant
provider: Provider
onResponse: (message: Message) => void
}
export async function fetchChatCompletion({ message, assistant, topic, onResponse }: FetchChatCompletionParams) {
export async function fetchChatCompletion({
message,
topic,
assistant,
provider,
onResponse
}: FetchChatCompletionParams) {
const openaiProvider = new OpenAI({
dangerouslyAllowBrowser: true,
apiKey: provider.apiKey,
baseURL: `${provider.apiHost}/v1/`
})
const stream = await openaiProvider.chat.completions.create({
model: 'Qwen/Qwen2-7B-Instruct',
model: assistant.model?.name || '',
messages: [
{ role: 'system', content: assistant.prompt },
{ role: 'user', content: message.content }

View File

@ -1,7 +0,0 @@
import OpenAI from 'openai'
export const openaiProvider = new OpenAI({
dangerouslyAllowBrowser: true,
apiKey: 'sk-cmxcwkapuoxpddlytqpuxxszyqymqgrcxremulcdlgcgabtq',
baseURL: 'https://api.siliconflow.cn/v1'
})

View File

@ -2,7 +2,7 @@ import { createSlice, PayloadAction } from '@reduxjs/toolkit'
import { getDefaultAssistant } from '@renderer/services/assistant'
import LocalStorage from '@renderer/services/storage'
import { getDefaultTopic } from '@renderer/services/topic'
import { Assistant, Topic } from '@renderer/types'
import { Assistant, Model, Topic } from '@renderer/types'
import { uniqBy } from 'lodash'
export interface AssistantsState {
@ -69,11 +69,29 @@ const assistantsSlice = createSlice({
}
return assistant
})
},
setModel: (state, action: PayloadAction<{ assistantId: string; model: Model }>) => {
state.assistants = state.assistants.map((assistant) =>
assistant.id === action.payload.assistantId
? {
...assistant,
model: action.payload.model
}
: assistant
)
}
}
})
export const { addAssistant, removeAssistant, updateAssistant, addTopic, removeTopic, updateTopic, removeAllTopics } =
assistantsSlice.actions
export const {
addAssistant,
removeAssistant,
updateAssistant,
addTopic,
removeTopic,
updateTopic,
removeAllTopics,
setModel
} = assistantsSlice.actions
export default assistantsSlice.reducer

View File

@ -1,9 +1,11 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit'
import { SYSTEM_MODELS } from '@renderer/config/models'
import { Model, Provider } from '@renderer/types'
import { uniqBy } from 'lodash'
export interface LlmState {
providers: Provider[]
defaultModel: Model
}
const initialState: LlmState = {
@ -13,38 +15,35 @@ const initialState: LlmState = {
name: 'OpenAI',
apiKey: '',
apiHost: 'https://api.openai.com',
apiPath: '/v1/chat/completions',
isSystem: true,
models: []
models: SYSTEM_MODELS.openai.filter((m) => m.defaultEnabled)
},
{
id: 'silicon',
name: 'Silicon',
apiKey: '',
apiHost: 'https://api.siliconflow.cn',
apiPath: '/v1/chat/completions',
isSystem: true,
models: []
models: SYSTEM_MODELS.silicon.filter((m) => m.defaultEnabled)
},
{
id: 'deepseek',
name: 'deepseek',
apiKey: '',
apiHost: 'https://api.deepseek.com',
apiPath: '/v1/chat/completions',
isSystem: true,
models: []
models: SYSTEM_MODELS.deepseek.filter((m) => m.defaultEnabled)
},
{
id: 'groq',
name: 'Groq',
apiKey: '',
apiHost: 'https://api.groq.com',
apiPath: '/v1/chat/completions',
isSystem: true,
models: []
models: SYSTEM_MODELS.groq.filter((m) => m.defaultEnabled)
}
]
],
defaultModel: SYSTEM_MODELS.openai[0]
}
const settingsSlice = createSlice({

View File

@ -4,6 +4,7 @@ export type Assistant = {
description: string
prompt: string
topics: Topic[]
model?: Model
}
export type Message = {
@ -33,7 +34,6 @@ export type Provider = {
name: string
apiKey: string
apiHost: string
apiPath: string
models: Model[]
isSystem?: boolean
isDefault?: boolean
@ -41,6 +41,7 @@ export type Provider = {
export type Model = {
id: string
provider: string
name: string
group: string
temperature: number