From a7a82be08397235346e8d74d099c442ec51df4e7 Mon Sep 17 00:00:00 2001 From: Asurada <43401755+ousugo@users.noreply.github.com> Date: Thu, 27 Feb 2025 17:00:01 +0800 Subject: [PATCH] feat: Add model editing functionality to provider settings (#2243) --- src/renderer/src/hooks/useProvider.ts | 4 +- src/renderer/src/i18n/locales/en-us.json | 3 +- src/renderer/src/i18n/locales/ja-jp.json | 3 +- src/renderer/src/i18n/locales/ru-ru.json | 3 +- src/renderer/src/i18n/locales/zh-cn.json | 3 +- src/renderer/src/i18n/locales/zh-tw.json | 3 +- .../ProviderSettings/ProviderSetting.tsx | 194 ++++++++++++++---- src/renderer/src/store/llm.ts | 18 +- 8 files changed, 180 insertions(+), 51 deletions(-) diff --git a/src/renderer/src/hooks/useProvider.ts b/src/renderer/src/hooks/useProvider.ts index dc314eae..95a8a8fa 100644 --- a/src/renderer/src/hooks/useProvider.ts +++ b/src/renderer/src/hooks/useProvider.ts @@ -5,6 +5,7 @@ import { addProvider, removeModel, removeProvider, + updateModel, updateProvider, updateProviders } from '@renderer/store/llm' @@ -51,7 +52,8 @@ export function useProvider(id: string) { models: provider?.models || [], updateProvider: (provider: Provider) => dispatch(updateProvider(provider)), addModel: (model: Model) => dispatch(addModel({ providerId: id, model })), - removeModel: (model: Model) => dispatch(removeModel({ providerId: id, model })) + removeModel: (model: Model) => dispatch(removeModel({ providerId: id, model })), + updateModel: (model: Model) => dispatch(updateModel({ providerId: id, model })) } } diff --git a/src/renderer/src/i18n/locales/en-us.json b/src/renderer/src/i18n/locales/en-us.json index 0524006c..dd12ef58 100644 --- a/src/renderer/src/i18n/locales/en-us.json +++ b/src/renderer/src/i18n/locales/en-us.json @@ -472,7 +472,8 @@ "vision": "Vision" }, "vision": "Vision", - "websearch": "WebSearch" + "websearch": "WebSearch", + "edit": "Edit Model" }, "ollama": { "keep_alive_time.description": "The time in minutes to keep the connection alive, default is 5 minutes.", diff --git a/src/renderer/src/i18n/locales/ja-jp.json b/src/renderer/src/i18n/locales/ja-jp.json index 110cb92e..56fead70 100644 --- a/src/renderer/src/i18n/locales/ja-jp.json +++ b/src/renderer/src/i18n/locales/ja-jp.json @@ -472,7 +472,8 @@ "vision": "画像" }, "vision": "画像", - "websearch": "ウェブ検索" + "websearch": "ウェブ検索", + "edit": "モデルを編集" }, "ollama": { "keep_alive_time.description": "モデルがメモリに保持される時間(デフォルト:5分)", diff --git a/src/renderer/src/i18n/locales/ru-ru.json b/src/renderer/src/i18n/locales/ru-ru.json index 3adf4dc8..71d276b9 100644 --- a/src/renderer/src/i18n/locales/ru-ru.json +++ b/src/renderer/src/i18n/locales/ru-ru.json @@ -472,7 +472,8 @@ "vision": "Изображение" }, "vision": "Визуальные", - "websearch": "Веб-поисковые" + "websearch": "Веб-поисковые", + "edit": "Редактировать модель" }, "ollama": { "keep_alive_time.description": "Время в минутах, в течение которого модель остается активной, по умолчанию 5 минут.", diff --git a/src/renderer/src/i18n/locales/zh-cn.json b/src/renderer/src/i18n/locales/zh-cn.json index 3991331f..682bb7c8 100644 --- a/src/renderer/src/i18n/locales/zh-cn.json +++ b/src/renderer/src/i18n/locales/zh-cn.json @@ -472,7 +472,8 @@ "vision": "图像" }, "vision": "视觉", - "websearch": "联网" + "websearch": "联网", + "edit": "编辑模型" }, "ollama": { "keep_alive_time.description": "对话后模型在内存中保持的时间(默认:5分钟)", diff --git a/src/renderer/src/i18n/locales/zh-tw.json b/src/renderer/src/i18n/locales/zh-tw.json index 8ac8eba4..27a2c9e3 100644 --- a/src/renderer/src/i18n/locales/zh-tw.json +++ b/src/renderer/src/i18n/locales/zh-tw.json @@ -472,7 +472,8 @@ "vision": "圖像" }, "vision": "視覺", - "websearch": "網路搜索" + "websearch": "網路搜索", + "edit": "編輯模型" }, "ollama": { "keep_alive_time.description": "對話後模型在記憶體中保持的時間(預設為 5 分鐘)。", diff --git a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx index d907f4c3..3d550131 100644 --- a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx +++ b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx @@ -22,9 +22,10 @@ import { isProviderSupportAuth, isProviderSupportCharge } from '@renderer/servic import { useAppDispatch } from '@renderer/store' import { setModel } from '@renderer/store/assistants' import { Model, ModelType, Provider } from '@renderer/types' +import { getDefaultGroupName } from '@renderer/utils' import { formatApiHost } from '@renderer/utils/api' import { providerCharge } from '@renderer/utils/oauth' -import { Avatar, Button, Card, Checkbox, Divider, Flex, Input, Popover, Space, Switch } from 'antd' +import { Avatar, Button, Card, Checkbox, Divider, Flex, Form, Input, Modal, Space, Switch } from 'antd' import Link from 'antd/es/typography/Link' import { groupBy, isEmpty } from 'lodash' import { FC, useEffect, useState } from 'react' @@ -51,6 +52,129 @@ interface Props { provider: Provider } +interface ModelEditContentProps { + model: Model + onUpdateModel: (model: Model) => void + open: boolean + onClose: () => void +} + +const ModelEditContent: FC = ({ model, onUpdateModel, open, onClose }) => { + const [form] = Form.useForm() + const { t } = useTranslation() + + const onFinish = (values: any) => { + const updatedModel = { + ...model, + id: values.id || model.id, + name: values.name || model.name, + group: values.group || model.group + } + onUpdateModel(updatedModel) + onClose() + } + + return ( + { + if (visible) { + form.getFieldInstance('id')?.focus() + } + }}> +
+ + { + const value = e.target.value + form.setFieldValue('name', value) + form.setFieldValue('group', getDefaultGroupName(value)) + }} + /> + + + + + + + + + + + +
+ {t('models.type.select')}: + {(() => { + const defaultTypes = [ + ...(isVisionModel(model) ? ['vision'] : []), + ...(isEmbeddingModel(model) ? ['embedding'] : []), + ...(isReasoningModel(model) ? ['reasoning'] : []) + ] as ModelType[] + + // 合并现有选择和默认类型 + const selectedTypes = [...new Set([...(model.type || []), ...defaultTypes])] + + return ( + onUpdateModel({ ...model, type: types as ModelType[] })} + options={[ + { + label: t('models.type.vision'), + value: 'vision', + disabled: isVisionModel(model) && !selectedTypes.includes('vision') + }, + { + label: t('models.type.embedding'), + value: 'embedding', + disabled: isEmbeddingModel(model) && !selectedTypes.includes('embedding') + }, + { + label: t('models.type.reasoning'), + value: 'reasoning', + disabled: isReasoningModel(model) && !selectedTypes.includes('reasoning') + } + ]} + /> + ) + })()} +
+ +
+ ) +} + const ProviderSetting: FC = ({ provider: _provider }) => { const { provider } = useProvider(_provider.id) const [apiKey, setApiKey] = useState(provider.apiKey) @@ -76,6 +200,8 @@ const ProviderSetting: FC = ({ provider: _provider }) => { const modelsWebsite = providerConfig?.websites?.models const configedApiHost = providerConfig?.api?.url + const [editingModel, setEditingModel] = useState(null) + const onUpdateApiKey = () => { if (apiKey !== provider.apiKey) { updateProvider({ ...provider, apiKey }) @@ -164,67 +290,42 @@ const ProviderSetting: FC = ({ provider: _provider }) => { return formatApiHost(apiHost) + 'chat/completions' } - const onUpdateModelTypes = (model: Model, types: ModelType[]) => { + const onUpdateModel = (updatedModel: Model) => { const updatedModels = models.map((m) => { - if (m.id === model.id) { - return { ...m, type: types } + if (m.id === updatedModel.id) { + return updatedModel } return m }) updateProvider({ ...provider, models: updatedModels }) + // Update assistants using this model assistants.forEach((assistant) => { - if (assistant?.model?.id === model.id && assistant.model.provider === provider.id) { + if (assistant?.model?.id === updatedModel.id && assistant.model.provider === provider.id) { dispatch( setModel({ assistantId: assistant.id, - model: { ...model, type: types } + model: updatedModel }) ) } }) - if (defaultModel?.id === model.id && defaultModel?.provider === provider.id) { - setDefaultModel({ ...defaultModel, type: types }) + // Update default model if needed + if (defaultModel?.id === updatedModel.id && defaultModel?.provider === provider.id) { + setDefaultModel(updatedModel) } } const modelTypeContent = (model: Model) => { - // 获取默认选中的类型 - const defaultTypes = [ - ...(isVisionModel(model) ? ['vision'] : []), - ...(isEmbeddingModel(model) ? ['embedding'] : []), - ...(isReasoningModel(model) ? ['reasoning'] : []) - ] as ModelType[] - - // 合并现有选择和默认类型 - const selectedTypes = [...new Set([...(model.type || []), ...defaultTypes])] - return ( -
- onUpdateModelTypes(model, types as ModelType[])} - options={[ - { - label: t('models.type.vision'), - value: 'vision', - disabled: isVisionModel(model) && !selectedTypes.includes('vision') - }, - { - label: t('models.type.embedding'), - value: 'embedding', - disabled: isEmbeddingModel(model) && !selectedTypes.includes('embedding') - }, - { - label: t('models.type.reasoning'), - value: 'reasoning', - disabled: isReasoningModel(model) && !selectedTypes.includes('reasoning') - } - ]} - /> -
+ setEditingModel(null)} + /> ) } @@ -355,9 +456,7 @@ const ProviderSetting: FC = ({ provider: _provider }) => { {model?.name} - - - + setEditingModel(model)} /> removeModel(model)} /> @@ -386,6 +485,7 @@ const ProviderSetting: FC = ({ provider: _provider }) => { {t('button.add')} + {models.map((model) => modelTypeContent(model))} ) } @@ -434,4 +534,10 @@ const ProviderName = styled.span` font-weight: 500; ` +const TypeTitle = styled.div` + margin-bottom: 12px; + font-size: 14px; + font-weight: 600; +` + export default ProviderSetting diff --git a/src/renderer/src/store/llm.ts b/src/renderer/src/store/llm.ts index 60f4477c..1a6fe425 100644 --- a/src/renderer/src/store/llm.ts +++ b/src/renderer/src/store/llm.ts @@ -502,6 +502,21 @@ const settingsSlice = createSlice({ }, setLMStudioKeepAliveTime: (state, action: PayloadAction) => { state.settings.lmstudio.keepAliveTime = action.payload + }, + updateModel: ( + state, + action: PayloadAction<{ + providerId: string + model: Model + }> + ) => { + const provider = state.providers.find((p) => p.id === action.payload.providerId) + if (provider) { + const modelIndex = provider.models.findIndex((m) => m.id === action.payload.model.id) + if (modelIndex !== -1) { + provider.models[modelIndex] = action.payload.model + } + } } } }) @@ -517,7 +532,8 @@ export const { setTopicNamingModel, setTranslateModel, setOllamaKeepAliveTime, - setLMStudioKeepAliveTime + setLMStudioKeepAliveTime, + updateModel } = settingsSlice.actions export default settingsSlice.reducer