From 4296f49e667d72de035102c53121f13233e5bbbe Mon Sep 17 00:00:00 2001 From: kangfenmao Date: Thu, 4 Jul 2024 18:04:21 +0800 Subject: [PATCH] feat: select model for assistant --- src/renderer/src/config/models.ts | 144 +++++++++++++----- src/renderer/src/hooks/useAssistant.ts | 13 +- src/renderer/src/hooks/useProvider.ts | 10 +- src/renderer/src/pages/home/HomePage.tsx | 5 +- .../home/components/Chat/Conversations.tsx | 6 +- .../home/components/Chat/NavigationCenter.tsx | 45 ++++++ .../components/ModalProviderSetting.tsx | 13 -- src/renderer/src/services/api.ts | 23 ++- src/renderer/src/services/provider.ts | 7 - src/renderer/src/store/assistants.ts | 24 ++- src/renderer/src/store/llm.ts | 17 +-- src/renderer/src/types/index.ts | 3 +- 12 files changed, 231 insertions(+), 79 deletions(-) create mode 100644 src/renderer/src/pages/home/components/Chat/NavigationCenter.tsx delete mode 100644 src/renderer/src/services/provider.ts diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index b3c06e20..5a3d8715 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -1,222 +1,294 @@ import { Model } from '@renderer/types' -export const SYSTEM_MODELS: Record = { +type SystemModel = Model & { defaultEnabled: boolean } + +export const SYSTEM_MODELS: Record = { openai: [ { id: 'gpt-3.5-turbo', + provider: 'openai', name: 'gpt-3.5-turbo', group: 'GPT 3.5', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: true }, { id: 'gpt-3.5-turbo-0301', + provider: 'openai', name: 'gpt-3.5-turbo', group: 'GPT 3.5', - temperature: 0.3 + temperature: 0.3, + defaultEnabled: false }, { id: 'gpt-4', + provider: 'openai', name: 'gpt-4', group: 'GPT 4', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: true }, { id: 'gpt-4-0314', + provider: 'openai', name: 'gpt-4', group: 'GPT 4', - temperature: 0.3 + temperature: 0.3, + defaultEnabled: false }, { id: 'gpt-4-32k', + provider: 'openai', name: 'gpt-4-32k', group: 'GPT 4', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: 'gpt-4-32k-0314', + provider: 'openai', name: 'gpt-4-32k', group: 'GPT 4', - temperature: 0.3 + temperature: 0.3, + defaultEnabled: false } ], silicon: [ { id: 'deepseek-ai/DeepSeek-V2-Chat', + provider: 'silicon', name: 'DeepSeek-V2-Chat', group: 'DeepSeek', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: true }, { id: 'deepseek-ai/DeepSeek-Coder-V2-Instruct', + provider: 'silicon', name: 'DeepSeek-Coder-V2-Instruct', group: 'DeepSeek', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: true }, { id: 'deepseek-ai/deepseek-llm-67b-chat', + provider: 'silicon', name: 'deepseek-llm-67b-chat', group: 'DeepSeek', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: 'google/gemma-2-27b-it', + provider: 'silicon', name: 'gemma-2-27b-it', group: 'Gemma', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: 'google/gemma-2-9b-it', + provider: 'silicon', name: 'gemma-2-9b-it', group: 'Gemma', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: 'Qwen/Qwen2-7B-Instruct', + provider: 'silicon', name: 'Qwen2-7B-Instruct', group: 'Qwen2', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: true }, { id: 'Qwen/Qwen2-1.5B-Instruct', + provider: 'silicon', name: 'Qwen2-1.5B-Instruct', group: 'Qwen2', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: 'Qwen/Qwen1.5-7B-Chat', + provider: 'silicon', name: 'Qwen1.5-7B-Chat', group: 'Qwen1.5', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: 'Qwen/Qwen2-72B-Instruct', + provider: 'silicon', name: 'Qwen2-72B-Instruct', group: 'Qwen2', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: true }, { id: 'Qwen/Qwen2-57B-A14B-Instruct', + provider: 'silicon', name: 'Qwen2-57B-A14B-Instruct', group: 'Qwen2', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: 'Qwen/Qwen1.5-110B-Chat', + provider: 'silicon', name: 'Qwen1.5-110B-Chat', group: 'Qwen1.5', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: 'Qwen/Qwen1.5-32B-Chat', + provider: 'silicon', name: 'Qwen1.5-32B-Chat', group: 'Qwen1.5', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: 'Qwen/Qwen1.5-14B-Chat', + provider: 'silicon', name: 'Qwen1.5-14B-Chat', group: 'Qwen1.5', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: 'THUDM/glm-4-9b-chat', + provider: 'silicon', name: 'glm-4-9b-chat', group: 'GLM', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: true }, { id: 'THUDM/chatglm3-6b', + provider: 'silicon', name: 'chatglm3-6b', group: 'GLM', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: '01-ai/Yi-1.5-9B-Chat-16K', + provider: 'silicon', name: 'Yi-1.5-9B-Chat-16K', group: 'Yi', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: '01-ai/Yi-1.5-6B-Chat', + provider: 'silicon', name: 'Yi-1.5-6B-Chat', group: 'Yi', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: '01-ai/Yi-1.5-34B-Chat-16K', + provider: 'silicon', name: 'Yi-1.5-34B-Chat-16K', group: 'Yi', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: 'OpenAI/GPT-4o', + provider: 'silicon', name: 'GPT-4o', group: 'OpenAI', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: 'OpenAI/GPT-3.5 Turbo', + provider: 'silicon', name: 'GPT-3.5 Turbo', group: 'OpenAI', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: 'Anthropic/claude-3-5-sonnet', + provider: 'silicon', name: 'claude-3-5-sonnet', group: 'Claude', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: 'meta-llama/Meta-Llama-3-8B-Instruct', + provider: 'silicon', name: 'Meta-Llama-3-8B-Instruct', group: 'Meta Llama', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: 'meta-llama/Meta-Llama-3-70B-Instruct', + provider: 'silicon', name: 'Meta-Llama-3-70B-Instruct', group: 'Meta Llama', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false } ], deepseek: [ { id: 'deepseek-chat', + provider: 'deepseek', name: 'deepseek-chat', group: 'Deepseek Chat', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: true }, { id: 'deepseek-coder', + provider: 'deepseek', name: 'deepseek-coder', group: 'Deepseek Coder', - temperature: 1.0 + temperature: 1.0, + defaultEnabled: true } ], groq: [ { id: 'llama3-8b-8192', + provider: 'groq', name: 'LLaMA3 8b', group: 'Llama3', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: 'llama3-70b-8192', + provider: 'groq', name: 'LLaMA3 70b', group: 'Llama3', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: true }, { id: 'mixtral-8x7b-32768', + provider: 'groq', name: 'Mixtral 8x7b', group: 'Mixtral', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false }, { id: 'gemma-7b-it', + provider: 'groq', name: 'Gemma 7b', group: 'Gemma', - temperature: 0.7 + temperature: 0.7, + defaultEnabled: false } ] } diff --git a/src/renderer/src/hooks/useAssistant.ts b/src/renderer/src/hooks/useAssistant.ts index 42bfa9cc..7546953a 100644 --- a/src/renderer/src/hooks/useAssistant.ts +++ b/src/renderer/src/hooks/useAssistant.ts @@ -3,12 +3,13 @@ import { addTopic as _addTopic, removeAllTopics as _removeAllTopics, removeTopic as _removeTopic, + setModel as _setModel, updateTopic as _updateTopic, addAssistant, removeAssistant, updateAssistant } from '@renderer/store/assistants' -import { Assistant, Topic } from '@renderer/types' +import { Assistant, Model, Topic } from '@renderer/types' import localforage from 'localforage' export function useAssistants() { @@ -29,12 +30,19 @@ export function useAssistants() { } } +export function useDefaultModel() { + const defaultModel = useAppSelector((state) => state.llm.defaultModel) + return { defaultModel } +} + export function useAssistant(id: string) { const assistant = useAppSelector((state) => state.assistants.assistants.find((a) => a.id === id) as Assistant) const dispatch = useAppDispatch() + const { defaultModel } = useDefaultModel() return { assistant, + model: assistant?.model ?? defaultModel, addTopic: (topic: Topic) => { dispatch(_addTopic({ assistantId: assistant.id, topic })) }, @@ -46,6 +54,9 @@ export function useAssistant(id: string) { }, removeAllTopics: () => { dispatch(_removeAllTopics({ assistantId: assistant.id })) + }, + setModel: (model: Model) => { + dispatch(_setModel({ assistantId: assistant.id, model })) } } } diff --git a/src/renderer/src/hooks/useProvider.ts b/src/renderer/src/hooks/useProvider.ts index 0046163e..650396fc 100644 --- a/src/renderer/src/hooks/useProvider.ts +++ b/src/renderer/src/hooks/useProvider.ts @@ -4,7 +4,8 @@ import { removeModel as _removeModel, updateProvider as _updateProvider } from '@renderer/store/llm' -import { Model, Provider } from '@renderer/types' +import { Assistant, Model, Provider } from '@renderer/types' +import { useDefaultModel } from './useAssistant' export function useProviders() { return useAppSelector((state) => state.llm.providers) @@ -23,6 +24,13 @@ export function useProvider(id: string) { } } +export function useProviderByAssistant(assistant: Assistant) { + const { defaultModel } = useDefaultModel() + const model = assistant.model || defaultModel + const { provider } = useProvider(model.provider) + return provider +} + export function useDefaultProvider() { return useAppSelector((state) => state.llm.providers.find((p) => p.isDefault)) } diff --git a/src/renderer/src/pages/home/HomePage.tsx b/src/renderer/src/pages/home/HomePage.tsx index f2c8bec5..7ccf76d9 100644 --- a/src/renderer/src/pages/home/HomePage.tsx +++ b/src/renderer/src/pages/home/HomePage.tsx @@ -1,4 +1,4 @@ -import { Navbar, NavbarCenter, NavbarLeft, NavbarRight } from '@renderer/components/app/Navbar' +import { Navbar, NavbarLeft, NavbarRight } from '@renderer/components/app/Navbar' import { useAssistants } from '@renderer/hooks/useAssistant' import { FC, useState } from 'react' import styled from 'styled-components' @@ -8,6 +8,7 @@ import { uuid } from '@renderer/utils' import { getDefaultAssistant } from '@renderer/services/assistant' import { useShowRightSidebar } from '@renderer/hooks/useStore' import { Tooltip } from 'antd' +import NavigationCenter from './components/Chat/NavigationCenter' const HomePage: FC = () => { const { assistants, addAssistant } = useAssistants() @@ -29,7 +30,7 @@ const HomePage: FC = () => { - {activeAssistant?.name} + diff --git a/src/renderer/src/pages/home/components/Chat/Conversations.tsx b/src/renderer/src/pages/home/components/Chat/Conversations.tsx index 33c5e960..347d735b 100644 --- a/src/renderer/src/pages/home/components/Chat/Conversations.tsx +++ b/src/renderer/src/pages/home/components/Chat/Conversations.tsx @@ -11,6 +11,7 @@ import { useAssistant } from '@renderer/hooks/useAssistant' import { DEFAULT_TOPIC_NAME } from '@renderer/config/constant' import { runAsyncFunction } from '@renderer/utils' import LocalStorage from '@renderer/services/storage' +import { useProviderByAssistant } from '@renderer/hooks/useProvider' interface Props { assistant: Assistant @@ -21,6 +22,7 @@ const Conversations: FC = ({ assistant, topic }) => { const [messages, setMessages] = useState([]) const [lastMessage, setLastMessage] = useState(null) const { updateTopic } = useAssistant(assistant.id) + const provider = useProviderByAssistant(assistant) const onSendMessage = useCallback( (message: Message) => { @@ -46,8 +48,10 @@ const Conversations: FC = ({ assistant, topic }) => { useEffect(() => { const unsubscribes = [ EventEmitter.on(EVENT_NAMES.SEND_MESSAGE, async (msg: Message) => { + console.debug({ assistant, provider, message: msg, topic }) + return onSendMessage(msg) - fetchChatCompletion({ assistant, message: msg, topic, onResponse: setLastMessage }) + fetchChatCompletion({ assistant, provider, message: msg, topic, onResponse: setLastMessage }) }), EventEmitter.on(EVENT_NAMES.AI_CHAT_COMPLETION, async (msg: Message) => { setLastMessage(null) diff --git a/src/renderer/src/pages/home/components/Chat/NavigationCenter.tsx b/src/renderer/src/pages/home/components/Chat/NavigationCenter.tsx new file mode 100644 index 00000000..7244ebb9 --- /dev/null +++ b/src/renderer/src/pages/home/components/Chat/NavigationCenter.tsx @@ -0,0 +1,45 @@ +import { NavbarCenter } from '@renderer/components/app/Navbar' +import { useAssistant } from '@renderer/hooks/useAssistant' +import { useProviders } from '@renderer/hooks/useProvider' +import { Assistant } from '@renderer/types' +import { Button, Dropdown, MenuProps } from 'antd' +import { FC } from 'react' +import styled from 'styled-components' + +interface Props { + activeAssistant: Assistant +} + +const NavigationCenter: FC = ({ activeAssistant }) => { + const providers = useProviders() + const { model, setModel } = useAssistant(activeAssistant.id) + + const items: MenuProps['items'] = providers.map((p) => ({ + key: p.id, + label: p.name, + type: 'group', + children: p.models.map((m) => ({ + key: m.id, + label: m.name, + onClick: () => setModel(m) + })) + })) + + return ( + + {activeAssistant?.name} + + + + + ) +} + +const DropdownMenu = styled(Dropdown)` + -webkit-app-region: none; + margin-left: 10px; +` + +export default NavigationCenter diff --git a/src/renderer/src/pages/settings/components/ModalProviderSetting.tsx b/src/renderer/src/pages/settings/components/ModalProviderSetting.tsx index 7d8d2d68..f2602a10 100644 --- a/src/renderer/src/pages/settings/components/ModalProviderSetting.tsx +++ b/src/renderer/src/pages/settings/components/ModalProviderSetting.tsx @@ -13,7 +13,6 @@ interface Props { const ModalProviderSetting: FC = ({ provider }) => { const [apiKey, setApiKey] = useState(provider.apiKey) const [apiHost, setApiHost] = useState(provider.apiHost) - const [apiPath, setApiPath] = useState(provider.apiPath) const { updateProvider, models } = useProvider(provider.id) const modelGroups = groupBy(models, 'group') @@ -21,7 +20,6 @@ const ModalProviderSetting: FC = ({ provider }) => { useEffect(() => { setApiKey(provider.apiKey) setApiHost(provider.apiHost) - setApiPath(provider.apiPath) }, [provider]) const onUpdateApiKey = () => { @@ -32,10 +30,6 @@ const ModalProviderSetting: FC = ({ provider }) => { updateProvider({ ...provider, apiHost }) } - const onUpdateApiPath = () => { - updateProvider({ ...provider, apiHost }) - } - const onAddModal = () => { ModalListPopup.show({ provider }) } @@ -53,13 +47,6 @@ const ModalProviderSetting: FC = ({ provider }) => { onChange={(e) => setApiHost(e.target.value)} onBlur={onUpdateApiHost} /> - API Path - setApiPath(e.target.value)} - onBlur={onUpdateApiPath} - /> Models {Object.keys(modelGroups).map((group) => ( diff --git a/src/renderer/src/services/api.ts b/src/renderer/src/services/api.ts index 283256e0..c1ed83c3 100644 --- a/src/renderer/src/services/api.ts +++ b/src/renderer/src/services/api.ts @@ -1,19 +1,32 @@ -import { Assistant, Message, Topic } from '@renderer/types' -import { openaiProvider } from './provider' +import { Assistant, Message, Provider, Topic } from '@renderer/types' import { uuid } from '@renderer/utils' import { EVENT_NAMES, EventEmitter } from './event' import { ChatCompletionMessageParam, ChatCompletionSystemMessageParam } from 'openai/resources' +import OpenAI from 'openai' interface FetchChatCompletionParams { message: Message - assistant: Assistant topic: Topic + assistant: Assistant + provider: Provider onResponse: (message: Message) => void } -export async function fetchChatCompletion({ message, assistant, topic, onResponse }: FetchChatCompletionParams) { +export async function fetchChatCompletion({ + message, + topic, + assistant, + provider, + onResponse +}: FetchChatCompletionParams) { + const openaiProvider = new OpenAI({ + dangerouslyAllowBrowser: true, + apiKey: provider.apiKey, + baseURL: `${provider.apiHost}/v1/` + }) + const stream = await openaiProvider.chat.completions.create({ - model: 'Qwen/Qwen2-7B-Instruct', + model: assistant.model?.name || '', messages: [ { role: 'system', content: assistant.prompt }, { role: 'user', content: message.content } diff --git a/src/renderer/src/services/provider.ts b/src/renderer/src/services/provider.ts deleted file mode 100644 index 3126a12d..00000000 --- a/src/renderer/src/services/provider.ts +++ /dev/null @@ -1,7 +0,0 @@ -import OpenAI from 'openai' - -export const openaiProvider = new OpenAI({ - dangerouslyAllowBrowser: true, - apiKey: 'sk-cmxcwkapuoxpddlytqpuxxszyqymqgrcxremulcdlgcgabtq', - baseURL: 'https://api.siliconflow.cn/v1' -}) diff --git a/src/renderer/src/store/assistants.ts b/src/renderer/src/store/assistants.ts index d8371689..c139ec78 100644 --- a/src/renderer/src/store/assistants.ts +++ b/src/renderer/src/store/assistants.ts @@ -2,7 +2,7 @@ import { createSlice, PayloadAction } from '@reduxjs/toolkit' import { getDefaultAssistant } from '@renderer/services/assistant' import LocalStorage from '@renderer/services/storage' import { getDefaultTopic } from '@renderer/services/topic' -import { Assistant, Topic } from '@renderer/types' +import { Assistant, Model, Topic } from '@renderer/types' import { uniqBy } from 'lodash' export interface AssistantsState { @@ -69,11 +69,29 @@ const assistantsSlice = createSlice({ } return assistant }) + }, + setModel: (state, action: PayloadAction<{ assistantId: string; model: Model }>) => { + state.assistants = state.assistants.map((assistant) => + assistant.id === action.payload.assistantId + ? { + ...assistant, + model: action.payload.model + } + : assistant + ) } } }) -export const { addAssistant, removeAssistant, updateAssistant, addTopic, removeTopic, updateTopic, removeAllTopics } = - assistantsSlice.actions +export const { + addAssistant, + removeAssistant, + updateAssistant, + addTopic, + removeTopic, + updateTopic, + removeAllTopics, + setModel +} = assistantsSlice.actions export default assistantsSlice.reducer diff --git a/src/renderer/src/store/llm.ts b/src/renderer/src/store/llm.ts index e2d133bc..f8096051 100644 --- a/src/renderer/src/store/llm.ts +++ b/src/renderer/src/store/llm.ts @@ -1,9 +1,11 @@ import { createSlice, PayloadAction } from '@reduxjs/toolkit' +import { SYSTEM_MODELS } from '@renderer/config/models' import { Model, Provider } from '@renderer/types' import { uniqBy } from 'lodash' export interface LlmState { providers: Provider[] + defaultModel: Model } const initialState: LlmState = { @@ -13,38 +15,35 @@ const initialState: LlmState = { name: 'OpenAI', apiKey: '', apiHost: 'https://api.openai.com', - apiPath: '/v1/chat/completions', isSystem: true, - models: [] + models: SYSTEM_MODELS.openai.filter((m) => m.defaultEnabled) }, { id: 'silicon', name: 'Silicon', apiKey: '', apiHost: 'https://api.siliconflow.cn', - apiPath: '/v1/chat/completions', isSystem: true, - models: [] + models: SYSTEM_MODELS.silicon.filter((m) => m.defaultEnabled) }, { id: 'deepseek', name: 'deepseek', apiKey: '', apiHost: 'https://api.deepseek.com', - apiPath: '/v1/chat/completions', isSystem: true, - models: [] + models: SYSTEM_MODELS.deepseek.filter((m) => m.defaultEnabled) }, { id: 'groq', name: 'Groq', apiKey: '', apiHost: 'https://api.groq.com', - apiPath: '/v1/chat/completions', isSystem: true, - models: [] + models: SYSTEM_MODELS.groq.filter((m) => m.defaultEnabled) } - ] + ], + defaultModel: SYSTEM_MODELS.openai[0] } const settingsSlice = createSlice({ diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index 38b33706..f60e01e5 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -4,6 +4,7 @@ export type Assistant = { description: string prompt: string topics: Topic[] + model?: Model } export type Message = { @@ -33,7 +34,6 @@ export type Provider = { name: string apiKey: string apiHost: string - apiPath: string models: Model[] isSystem?: boolean isDefault?: boolean @@ -41,6 +41,7 @@ export type Provider = { export type Model = { id: string + provider: string name: string group: string temperature: number