diff --git a/src/renderer/src/hooks/useAssistant.ts b/src/renderer/src/hooks/useAssistant.ts
index 7546953a..a7c37e04 100644
--- a/src/renderer/src/hooks/useAssistant.ts
+++ b/src/renderer/src/hooks/useAssistant.ts
@@ -30,11 +30,6 @@ export function useAssistants() {
}
}
-export function useDefaultModel() {
- const defaultModel = useAppSelector((state) => state.llm.defaultModel)
- return { defaultModel }
-}
-
export function useAssistant(id: string) {
const assistant = useAppSelector((state) => state.assistants.assistants.find((a) => a.id === id) as Assistant)
const dispatch = useAppDispatch()
@@ -60,3 +55,8 @@ export function useAssistant(id: string) {
}
}
}
+
+export function useDefaultModel() {
+ const defaultModel = useAppSelector((state) => state.llm.defaultModel)
+ return { defaultModel }
+}
diff --git a/src/renderer/src/hooks/useProvider.ts b/src/renderer/src/hooks/useProvider.ts
index 650396fc..6427ecbc 100644
--- a/src/renderer/src/hooks/useProvider.ts
+++ b/src/renderer/src/hooks/useProvider.ts
@@ -31,10 +31,6 @@ export function useProviderByAssistant(assistant: Assistant) {
return provider
}
-export function useDefaultProvider() {
- return useAppSelector((state) => state.llm.providers.find((p) => p.isDefault))
-}
-
export function useSystemProviders() {
- return useAppSelector((state) => state.llm.providers.filter((p) => p.isSystem))
+ return useAppSelector((state) => state.llm.providers.filter((p) => p.isSystem)) as unknown as Provider
}
diff --git a/src/renderer/src/pages/home/HomePage.tsx b/src/renderer/src/pages/home/HomePage.tsx
index 7ccf76d9..e8e558eb 100644
--- a/src/renderer/src/pages/home/HomePage.tsx
+++ b/src/renderer/src/pages/home/HomePage.tsx
@@ -2,13 +2,13 @@ import { Navbar, NavbarLeft, NavbarRight } from '@renderer/components/app/Navbar
import { useAssistants } from '@renderer/hooks/useAssistant'
import { FC, useState } from 'react'
import styled from 'styled-components'
-import Chat from './components/Chat/Chat'
+import Chat from './components/Chat'
import Assistants from './components/Assistants'
import { uuid } from '@renderer/utils'
import { getDefaultAssistant } from '@renderer/services/assistant'
import { useShowRightSidebar } from '@renderer/hooks/useStore'
import { Tooltip } from 'antd'
-import NavigationCenter from './components/Chat/NavigationCenter'
+import Navigation from './components/Navigation'
const HomePage: FC = () => {
const { assistants, addAssistant } = useAssistants()
@@ -30,7 +30,7 @@ const HomePage: FC = () => {
-
+
diff --git a/src/renderer/src/pages/home/components/Chat/Chat.tsx b/src/renderer/src/pages/home/components/Chat.tsx
similarity index 100%
rename from src/renderer/src/pages/home/components/Chat/Chat.tsx
rename to src/renderer/src/pages/home/components/Chat.tsx
diff --git a/src/renderer/src/pages/home/components/Chat/Conversations.tsx b/src/renderer/src/pages/home/components/Conversations.tsx
similarity index 89%
rename from src/renderer/src/pages/home/components/Chat/Conversations.tsx
rename to src/renderer/src/pages/home/components/Conversations.tsx
index 347d735b..71d60b9d 100644
--- a/src/renderer/src/pages/home/components/Chat/Conversations.tsx
+++ b/src/renderer/src/pages/home/components/Conversations.tsx
@@ -38,20 +38,17 @@ const Conversations: FC = ({ assistant, topic }) => {
const autoRenameTopic = useCallback(async () => {
if (topic.name === DEFAULT_TOPIC_NAME && messages.length >= 2) {
- const summaryText = await fetchConversationSummary({ messages })
- if (summaryText) {
- updateTopic({ ...topic, name: summaryText })
- }
+ const summaryText = await fetchConversationSummary({ messages, assistant })
+ summaryText && updateTopic({ ...topic, name: summaryText })
}
- }, [messages, topic, updateTopic])
+ }, [assistant, messages, topic, updateTopic])
useEffect(() => {
const unsubscribes = [
EventEmitter.on(EVENT_NAMES.SEND_MESSAGE, async (msg: Message) => {
console.debug({ assistant, provider, message: msg, topic })
- return
onSendMessage(msg)
- fetchChatCompletion({ assistant, provider, message: msg, topic, onResponse: setLastMessage })
+ fetchChatCompletion({ assistant, message: msg, topic, onResponse: setLastMessage })
}),
EventEmitter.on(EVENT_NAMES.AI_CHAT_COMPLETION, async (msg: Message) => {
setLastMessage(null)
@@ -66,12 +63,12 @@ const Conversations: FC = ({ assistant, topic }) => {
})
]
return () => unsubscribes.forEach((unsub) => unsub())
- }, [assistant, autoRenameTopic, onSendMessage, topic, updateTopic])
+ }, [assistant, autoRenameTopic, onSendMessage, provider, topic, updateTopic])
useEffect(() => {
runAsyncFunction(async () => {
const messages = await LocalStorage.getTopicMessages(topic.id)
- setMessages(messages)
+ setMessages(messages || [])
})
}, [topic.id])
diff --git a/src/renderer/src/pages/home/components/Chat/Inputbar.tsx b/src/renderer/src/pages/home/components/Inputbar.tsx
similarity index 100%
rename from src/renderer/src/pages/home/components/Chat/Inputbar.tsx
rename to src/renderer/src/pages/home/components/Inputbar.tsx
diff --git a/src/renderer/src/pages/home/components/Chat/Message.tsx b/src/renderer/src/pages/home/components/Message.tsx
similarity index 100%
rename from src/renderer/src/pages/home/components/Chat/Message.tsx
rename to src/renderer/src/pages/home/components/Message.tsx
diff --git a/src/renderer/src/pages/home/components/Chat/NavigationCenter.tsx b/src/renderer/src/pages/home/components/Navigation.tsx
similarity index 100%
rename from src/renderer/src/pages/home/components/Chat/NavigationCenter.tsx
rename to src/renderer/src/pages/home/components/Navigation.tsx
diff --git a/src/renderer/src/pages/home/components/Chat/TopicList.tsx b/src/renderer/src/pages/home/components/TopicList.tsx
similarity index 99%
rename from src/renderer/src/pages/home/components/Chat/TopicList.tsx
rename to src/renderer/src/pages/home/components/TopicList.tsx
index 9198efac..a1bc4c24 100644
--- a/src/renderer/src/pages/home/components/Chat/TopicList.tsx
+++ b/src/renderer/src/pages/home/components/TopicList.tsx
@@ -29,7 +29,7 @@ const TopicList: FC = ({ assistant, activeTopic, setActiveTopic }) => {
if (currentTopic.current) {
const messages = await LocalStorage.getTopicMessages(currentTopic.current.id)
if (messages.length >= 2) {
- const summaryText = await fetchConversationSummary({ messages })
+ const summaryText = await fetchConversationSummary({ messages, assistant })
if (summaryText) {
updateTopic({ ...currentTopic.current, name: summaryText })
}
diff --git a/src/renderer/src/pages/settings/components/ModalProviderSetting.tsx b/src/renderer/src/pages/settings/components/ModalProviderSetting.tsx
index f2602a10..c979bdf1 100644
--- a/src/renderer/src/pages/settings/components/ModalProviderSetting.tsx
+++ b/src/renderer/src/pages/settings/components/ModalProviderSetting.tsx
@@ -39,7 +39,13 @@ const ModalProviderSetting: FC = ({ provider }) => {
{provider.name}
API Key
- setApiKey(e.target.value)} onBlur={onUpdateApiKey} />
+ setApiKey(e.target.value)}
+ onBlur={onUpdateApiKey}
+ spellCheck={false}
+ />
API Host
void
}
-export async function fetchChatCompletion({
- message,
- topic,
- assistant,
- provider,
- onResponse
-}: FetchChatCompletionParams) {
- const openaiProvider = new OpenAI({
+const getOpenAiProvider = (provider: Provider) => {
+ return new OpenAI({
dangerouslyAllowBrowser: true,
apiKey: provider.apiKey,
baseURL: `${provider.apiHost}/v1/`
})
+}
+
+export async function fetchChatCompletion({ message, topic, assistant, onResponse }: FetchChatCompletionParams) {
+ const provider = getAssistantProvider(assistant)
+ const openaiProvider = getOpenAiProvider(provider)
const stream = await openaiProvider.chat.completions.create({
- model: assistant.model?.name || '',
+ model: assistant.model?.id || '',
messages: [
{ role: 'system', content: assistant.prompt },
{ role: 'user', content: message.content }
@@ -59,9 +58,13 @@ export async function fetchChatCompletion({
interface FetchConversationSummaryParams {
messages: Message[]
+ assistant: Assistant
}
-export async function fetchConversationSummary({ messages }: FetchConversationSummaryParams) {
+export async function fetchConversationSummary({ messages, assistant }: FetchConversationSummaryParams) {
+ const provider = getAssistantProvider(assistant)
+ const openaiProvider = getOpenAiProvider(provider)
+
const userMessages: ChatCompletionMessageParam[] = messages.map((message) => ({
role: 'user',
content: message.content
@@ -74,7 +77,7 @@ export async function fetchConversationSummary({ messages }: FetchConversationSu
}
const response = await openaiProvider.chat.completions.create({
- model: 'Qwen/Qwen2-7B-Instruct',
+ model: assistant.model?.id || '',
messages: [systemMessage, ...userMessages],
stream: false
})
diff --git a/src/renderer/src/services/assistant.ts b/src/renderer/src/services/assistant.ts
index 770c0a07..4c3a307a 100644
--- a/src/renderer/src/services/assistant.ts
+++ b/src/renderer/src/services/assistant.ts
@@ -1,5 +1,6 @@
-import { Assistant } from '@renderer/types'
+import { Assistant, Provider } from '@renderer/types'
import { getDefaultTopic } from './topic'
+import store from '@renderer/store'
export function getDefaultAssistant(): Assistant {
return {
@@ -10,3 +11,13 @@ export function getDefaultAssistant(): Assistant {
topics: [getDefaultTopic()]
}
}
+
+export function getAssistantProvider(assistant: Assistant) {
+ const providers = store.getState().llm.providers
+ return providers.find((p) => p.id === assistant.id) || getDefaultProvider()
+}
+
+export function getDefaultProvider() {
+ const provider = store.getState().llm.providers.find((p) => p.isSystem)
+ return provider as Provider
+}
diff --git a/src/renderer/src/store/index.ts b/src/renderer/src/store/index.ts
index 1395efdb..f75c5966 100644
--- a/src/renderer/src/store/index.ts
+++ b/src/renderer/src/store/index.ts
@@ -12,15 +12,18 @@ const rootReducer = combineReducers({
llm
})
+const persistedReducer = persistReducer(
+ {
+ key: 'cherry-ai',
+ storage,
+ version: 1
+ },
+ rootReducer
+)
+
const store = configureStore({
- reducer: persistReducer(
- {
- key: 'cherry-ai',
- storage,
- version: 1
- },
- rootReducer
- ),
+ // @ts-ignore store type is unknown
+ reducer: persistedReducer as typeof rootReducer,
middleware: (getDefaultMiddleware) => {
return getDefaultMiddleware({
serializableCheck: {
diff --git a/src/renderer/src/store/llm.ts b/src/renderer/src/store/llm.ts
index f8096051..bf7ea8bd 100644
--- a/src/renderer/src/store/llm.ts
+++ b/src/renderer/src/store/llm.ts
@@ -38,7 +38,7 @@ const initialState: LlmState = {
id: 'groq',
name: 'Groq',
apiKey: '',
- apiHost: 'https://api.groq.com',
+ apiHost: 'https://api.groq.com/openai',
isSystem: true,
models: SYSTEM_MODELS.groq.filter((m) => m.defaultEnabled)
}
diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts
index f60e01e5..894f7306 100644
--- a/src/renderer/src/types/index.ts
+++ b/src/renderer/src/types/index.ts
@@ -36,7 +36,6 @@ export type Provider = {
apiHost: string
models: Model[]
isSystem?: boolean
- isDefault?: boolean
}
export type Model = {