diff --git a/src/renderer/src/hooks/useAssistant.ts b/src/renderer/src/hooks/useAssistant.ts index 7546953a..a7c37e04 100644 --- a/src/renderer/src/hooks/useAssistant.ts +++ b/src/renderer/src/hooks/useAssistant.ts @@ -30,11 +30,6 @@ export function useAssistants() { } } -export function useDefaultModel() { - const defaultModel = useAppSelector((state) => state.llm.defaultModel) - return { defaultModel } -} - export function useAssistant(id: string) { const assistant = useAppSelector((state) => state.assistants.assistants.find((a) => a.id === id) as Assistant) const dispatch = useAppDispatch() @@ -60,3 +55,8 @@ export function useAssistant(id: string) { } } } + +export function useDefaultModel() { + const defaultModel = useAppSelector((state) => state.llm.defaultModel) + return { defaultModel } +} diff --git a/src/renderer/src/hooks/useProvider.ts b/src/renderer/src/hooks/useProvider.ts index 650396fc..6427ecbc 100644 --- a/src/renderer/src/hooks/useProvider.ts +++ b/src/renderer/src/hooks/useProvider.ts @@ -31,10 +31,6 @@ export function useProviderByAssistant(assistant: Assistant) { return provider } -export function useDefaultProvider() { - return useAppSelector((state) => state.llm.providers.find((p) => p.isDefault)) -} - export function useSystemProviders() { - return useAppSelector((state) => state.llm.providers.filter((p) => p.isSystem)) + return useAppSelector((state) => state.llm.providers.filter((p) => p.isSystem)) as unknown as Provider } diff --git a/src/renderer/src/pages/home/HomePage.tsx b/src/renderer/src/pages/home/HomePage.tsx index 7ccf76d9..e8e558eb 100644 --- a/src/renderer/src/pages/home/HomePage.tsx +++ b/src/renderer/src/pages/home/HomePage.tsx @@ -2,13 +2,13 @@ import { Navbar, NavbarLeft, NavbarRight } from '@renderer/components/app/Navbar import { useAssistants } from '@renderer/hooks/useAssistant' import { FC, useState } from 'react' import styled from 'styled-components' -import Chat from './components/Chat/Chat' +import Chat from './components/Chat' import Assistants from './components/Assistants' import { uuid } from '@renderer/utils' import { getDefaultAssistant } from '@renderer/services/assistant' import { useShowRightSidebar } from '@renderer/hooks/useStore' import { Tooltip } from 'antd' -import NavigationCenter from './components/Chat/NavigationCenter' +import Navigation from './components/Navigation' const HomePage: FC = () => { const { assistants, addAssistant } = useAssistants() @@ -30,7 +30,7 @@ const HomePage: FC = () => { - + diff --git a/src/renderer/src/pages/home/components/Chat/Chat.tsx b/src/renderer/src/pages/home/components/Chat.tsx similarity index 100% rename from src/renderer/src/pages/home/components/Chat/Chat.tsx rename to src/renderer/src/pages/home/components/Chat.tsx diff --git a/src/renderer/src/pages/home/components/Chat/Conversations.tsx b/src/renderer/src/pages/home/components/Conversations.tsx similarity index 89% rename from src/renderer/src/pages/home/components/Chat/Conversations.tsx rename to src/renderer/src/pages/home/components/Conversations.tsx index 347d735b..71d60b9d 100644 --- a/src/renderer/src/pages/home/components/Chat/Conversations.tsx +++ b/src/renderer/src/pages/home/components/Conversations.tsx @@ -38,20 +38,17 @@ const Conversations: FC = ({ assistant, topic }) => { const autoRenameTopic = useCallback(async () => { if (topic.name === DEFAULT_TOPIC_NAME && messages.length >= 2) { - const summaryText = await fetchConversationSummary({ messages }) - if (summaryText) { - updateTopic({ ...topic, name: summaryText }) - } + const summaryText = await fetchConversationSummary({ messages, assistant }) + summaryText && updateTopic({ ...topic, name: summaryText }) } - }, [messages, topic, updateTopic]) + }, [assistant, messages, topic, updateTopic]) useEffect(() => { const unsubscribes = [ EventEmitter.on(EVENT_NAMES.SEND_MESSAGE, async (msg: Message) => { console.debug({ assistant, provider, message: msg, topic }) - return onSendMessage(msg) - fetchChatCompletion({ assistant, provider, message: msg, topic, onResponse: setLastMessage }) + fetchChatCompletion({ assistant, message: msg, topic, onResponse: setLastMessage }) }), EventEmitter.on(EVENT_NAMES.AI_CHAT_COMPLETION, async (msg: Message) => { setLastMessage(null) @@ -66,12 +63,12 @@ const Conversations: FC = ({ assistant, topic }) => { }) ] return () => unsubscribes.forEach((unsub) => unsub()) - }, [assistant, autoRenameTopic, onSendMessage, topic, updateTopic]) + }, [assistant, autoRenameTopic, onSendMessage, provider, topic, updateTopic]) useEffect(() => { runAsyncFunction(async () => { const messages = await LocalStorage.getTopicMessages(topic.id) - setMessages(messages) + setMessages(messages || []) }) }, [topic.id]) diff --git a/src/renderer/src/pages/home/components/Chat/Inputbar.tsx b/src/renderer/src/pages/home/components/Inputbar.tsx similarity index 100% rename from src/renderer/src/pages/home/components/Chat/Inputbar.tsx rename to src/renderer/src/pages/home/components/Inputbar.tsx diff --git a/src/renderer/src/pages/home/components/Chat/Message.tsx b/src/renderer/src/pages/home/components/Message.tsx similarity index 100% rename from src/renderer/src/pages/home/components/Chat/Message.tsx rename to src/renderer/src/pages/home/components/Message.tsx diff --git a/src/renderer/src/pages/home/components/Chat/NavigationCenter.tsx b/src/renderer/src/pages/home/components/Navigation.tsx similarity index 100% rename from src/renderer/src/pages/home/components/Chat/NavigationCenter.tsx rename to src/renderer/src/pages/home/components/Navigation.tsx diff --git a/src/renderer/src/pages/home/components/Chat/TopicList.tsx b/src/renderer/src/pages/home/components/TopicList.tsx similarity index 99% rename from src/renderer/src/pages/home/components/Chat/TopicList.tsx rename to src/renderer/src/pages/home/components/TopicList.tsx index 9198efac..a1bc4c24 100644 --- a/src/renderer/src/pages/home/components/Chat/TopicList.tsx +++ b/src/renderer/src/pages/home/components/TopicList.tsx @@ -29,7 +29,7 @@ const TopicList: FC = ({ assistant, activeTopic, setActiveTopic }) => { if (currentTopic.current) { const messages = await LocalStorage.getTopicMessages(currentTopic.current.id) if (messages.length >= 2) { - const summaryText = await fetchConversationSummary({ messages }) + const summaryText = await fetchConversationSummary({ messages, assistant }) if (summaryText) { updateTopic({ ...currentTopic.current, name: summaryText }) } diff --git a/src/renderer/src/pages/settings/components/ModalProviderSetting.tsx b/src/renderer/src/pages/settings/components/ModalProviderSetting.tsx index f2602a10..c979bdf1 100644 --- a/src/renderer/src/pages/settings/components/ModalProviderSetting.tsx +++ b/src/renderer/src/pages/settings/components/ModalProviderSetting.tsx @@ -39,7 +39,13 @@ const ModalProviderSetting: FC = ({ provider }) => { {provider.name} API Key - setApiKey(e.target.value)} onBlur={onUpdateApiKey} /> + setApiKey(e.target.value)} + onBlur={onUpdateApiKey} + spellCheck={false} + /> API Host void } -export async function fetchChatCompletion({ - message, - topic, - assistant, - provider, - onResponse -}: FetchChatCompletionParams) { - const openaiProvider = new OpenAI({ +const getOpenAiProvider = (provider: Provider) => { + return new OpenAI({ dangerouslyAllowBrowser: true, apiKey: provider.apiKey, baseURL: `${provider.apiHost}/v1/` }) +} + +export async function fetchChatCompletion({ message, topic, assistant, onResponse }: FetchChatCompletionParams) { + const provider = getAssistantProvider(assistant) + const openaiProvider = getOpenAiProvider(provider) const stream = await openaiProvider.chat.completions.create({ - model: assistant.model?.name || '', + model: assistant.model?.id || '', messages: [ { role: 'system', content: assistant.prompt }, { role: 'user', content: message.content } @@ -59,9 +58,13 @@ export async function fetchChatCompletion({ interface FetchConversationSummaryParams { messages: Message[] + assistant: Assistant } -export async function fetchConversationSummary({ messages }: FetchConversationSummaryParams) { +export async function fetchConversationSummary({ messages, assistant }: FetchConversationSummaryParams) { + const provider = getAssistantProvider(assistant) + const openaiProvider = getOpenAiProvider(provider) + const userMessages: ChatCompletionMessageParam[] = messages.map((message) => ({ role: 'user', content: message.content @@ -74,7 +77,7 @@ export async function fetchConversationSummary({ messages }: FetchConversationSu } const response = await openaiProvider.chat.completions.create({ - model: 'Qwen/Qwen2-7B-Instruct', + model: assistant.model?.id || '', messages: [systemMessage, ...userMessages], stream: false }) diff --git a/src/renderer/src/services/assistant.ts b/src/renderer/src/services/assistant.ts index 770c0a07..4c3a307a 100644 --- a/src/renderer/src/services/assistant.ts +++ b/src/renderer/src/services/assistant.ts @@ -1,5 +1,6 @@ -import { Assistant } from '@renderer/types' +import { Assistant, Provider } from '@renderer/types' import { getDefaultTopic } from './topic' +import store from '@renderer/store' export function getDefaultAssistant(): Assistant { return { @@ -10,3 +11,13 @@ export function getDefaultAssistant(): Assistant { topics: [getDefaultTopic()] } } + +export function getAssistantProvider(assistant: Assistant) { + const providers = store.getState().llm.providers + return providers.find((p) => p.id === assistant.id) || getDefaultProvider() +} + +export function getDefaultProvider() { + const provider = store.getState().llm.providers.find((p) => p.isSystem) + return provider as Provider +} diff --git a/src/renderer/src/store/index.ts b/src/renderer/src/store/index.ts index 1395efdb..f75c5966 100644 --- a/src/renderer/src/store/index.ts +++ b/src/renderer/src/store/index.ts @@ -12,15 +12,18 @@ const rootReducer = combineReducers({ llm }) +const persistedReducer = persistReducer( + { + key: 'cherry-ai', + storage, + version: 1 + }, + rootReducer +) + const store = configureStore({ - reducer: persistReducer( - { - key: 'cherry-ai', - storage, - version: 1 - }, - rootReducer - ), + // @ts-ignore store type is unknown + reducer: persistedReducer as typeof rootReducer, middleware: (getDefaultMiddleware) => { return getDefaultMiddleware({ serializableCheck: { diff --git a/src/renderer/src/store/llm.ts b/src/renderer/src/store/llm.ts index f8096051..bf7ea8bd 100644 --- a/src/renderer/src/store/llm.ts +++ b/src/renderer/src/store/llm.ts @@ -38,7 +38,7 @@ const initialState: LlmState = { id: 'groq', name: 'Groq', apiKey: '', - apiHost: 'https://api.groq.com', + apiHost: 'https://api.groq.com/openai', isSystem: true, models: SYSTEM_MODELS.groq.filter((m) => m.defaultEnabled) } diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index f60e01e5..894f7306 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -36,7 +36,6 @@ export type Provider = { apiHost: string models: Model[] isSystem?: boolean - isDefault?: boolean } export type Model = {