feat: Add model editing functionality to provider settings (#2243)

This commit is contained in:
Asurada 2025-02-27 17:00:01 +08:00 committed by GitHub
parent c0117c25ac
commit a7a82be083
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 180 additions and 51 deletions

View File

@ -5,6 +5,7 @@ import {
addProvider,
removeModel,
removeProvider,
updateModel,
updateProvider,
updateProviders
} from '@renderer/store/llm'
@ -51,7 +52,8 @@ export function useProvider(id: string) {
models: provider?.models || [],
updateProvider: (provider: Provider) => dispatch(updateProvider(provider)),
addModel: (model: Model) => dispatch(addModel({ providerId: id, model })),
removeModel: (model: Model) => dispatch(removeModel({ providerId: id, model }))
removeModel: (model: Model) => dispatch(removeModel({ providerId: id, model })),
updateModel: (model: Model) => dispatch(updateModel({ providerId: id, model }))
}
}

View File

@ -472,7 +472,8 @@
"vision": "Vision"
},
"vision": "Vision",
"websearch": "WebSearch"
"websearch": "WebSearch",
"edit": "Edit Model"
},
"ollama": {
"keep_alive_time.description": "The time in minutes to keep the connection alive, default is 5 minutes.",

View File

@ -472,7 +472,8 @@
"vision": "画像"
},
"vision": "画像",
"websearch": "ウェブ検索"
"websearch": "ウェブ検索",
"edit": "モデルを編集"
},
"ollama": {
"keep_alive_time.description": "モデルがメモリに保持される時間デフォルト5分",

View File

@ -472,7 +472,8 @@
"vision": "Изображение"
},
"vision": "Визуальные",
"websearch": "Веб-поисковые"
"websearch": "Веб-поисковые",
"edit": "Редактировать модель"
},
"ollama": {
"keep_alive_time.description": "Время в минутах, в течение которого модель остается активной, по умолчанию 5 минут.",

View File

@ -472,7 +472,8 @@
"vision": "图像"
},
"vision": "视觉",
"websearch": "联网"
"websearch": "联网",
"edit": "编辑模型"
},
"ollama": {
"keep_alive_time.description": "对话后模型在内存中保持的时间默认5分钟",

View File

@ -472,7 +472,8 @@
"vision": "圖像"
},
"vision": "視覺",
"websearch": "網路搜索"
"websearch": "網路搜索",
"edit": "編輯模型"
},
"ollama": {
"keep_alive_time.description": "對話後模型在記憶體中保持的時間(預設為 5 分鐘)。",

View File

@ -22,9 +22,10 @@ import { isProviderSupportAuth, isProviderSupportCharge } from '@renderer/servic
import { useAppDispatch } from '@renderer/store'
import { setModel } from '@renderer/store/assistants'
import { Model, ModelType, Provider } from '@renderer/types'
import { getDefaultGroupName } from '@renderer/utils'
import { formatApiHost } from '@renderer/utils/api'
import { providerCharge } from '@renderer/utils/oauth'
import { Avatar, Button, Card, Checkbox, Divider, Flex, Input, Popover, Space, Switch } from 'antd'
import { Avatar, Button, Card, Checkbox, Divider, Flex, Form, Input, Modal, Space, Switch } from 'antd'
import Link from 'antd/es/typography/Link'
import { groupBy, isEmpty } from 'lodash'
import { FC, useEffect, useState } from 'react'
@ -51,6 +52,129 @@ interface Props {
provider: Provider
}
interface ModelEditContentProps {
model: Model
onUpdateModel: (model: Model) => void
open: boolean
onClose: () => void
}
const ModelEditContent: FC<ModelEditContentProps> = ({ model, onUpdateModel, open, onClose }) => {
const [form] = Form.useForm()
const { t } = useTranslation()
const onFinish = (values: any) => {
const updatedModel = {
...model,
id: values.id || model.id,
name: values.name || model.name,
group: values.group || model.group
}
onUpdateModel(updatedModel)
onClose()
}
return (
<Modal
title={t('models.edit')}
open={open}
onCancel={onClose}
footer={null}
maskClosable={false}
centered
afterOpenChange={(visible) => {
if (visible) {
form.getFieldInstance('id')?.focus()
}
}}>
<Form
form={form}
labelCol={{ flex: '110px' }}
labelAlign="left"
colon={false}
style={{ marginTop: 15 }}
initialValues={{
id: model.id,
name: model.name,
group: model.group
}}
onFinish={onFinish}>
<Form.Item
name="id"
label={t('settings.models.add.model_id')}
tooltip={t('settings.models.add.model_id.tooltip')}
rules={[{ required: true }]}>
<Input
placeholder={t('settings.models.add.model_id.placeholder')}
spellCheck={false}
maxLength={200}
onChange={(e) => {
const value = e.target.value
form.setFieldValue('name', value)
form.setFieldValue('group', getDefaultGroupName(value))
}}
/>
</Form.Item>
<Form.Item
name="name"
label={t('settings.models.add.model_name')}
tooltip={t('settings.models.add.model_name.tooltip')}>
<Input placeholder={t('settings.models.add.model_name.placeholder')} spellCheck={false} />
</Form.Item>
<Form.Item
name="group"
label={t('settings.models.add.group_name')}
tooltip={t('settings.models.add.group_name.tooltip')}>
<Input placeholder={t('settings.models.add.group_name.placeholder')} spellCheck={false} />
</Form.Item>
<Form.Item style={{ marginBottom: 15, textAlign: 'center' }}>
<Button type="primary" htmlType="submit" size="middle">
{t('common.save')}
</Button>
</Form.Item>
<Divider style={{ margin: '0 0 15px 0' }} />
<div>
<TypeTitle>{t('models.type.select')}:</TypeTitle>
{(() => {
const defaultTypes = [
...(isVisionModel(model) ? ['vision'] : []),
...(isEmbeddingModel(model) ? ['embedding'] : []),
...(isReasoningModel(model) ? ['reasoning'] : [])
] as ModelType[]
// 合并现有选择和默认类型
const selectedTypes = [...new Set([...(model.type || []), ...defaultTypes])]
return (
<Checkbox.Group
value={selectedTypes}
onChange={(types) => onUpdateModel({ ...model, type: types as ModelType[] })}
options={[
{
label: t('models.type.vision'),
value: 'vision',
disabled: isVisionModel(model) && !selectedTypes.includes('vision')
},
{
label: t('models.type.embedding'),
value: 'embedding',
disabled: isEmbeddingModel(model) && !selectedTypes.includes('embedding')
},
{
label: t('models.type.reasoning'),
value: 'reasoning',
disabled: isReasoningModel(model) && !selectedTypes.includes('reasoning')
}
]}
/>
)
})()}
</div>
</Form>
</Modal>
)
}
const ProviderSetting: FC<Props> = ({ provider: _provider }) => {
const { provider } = useProvider(_provider.id)
const [apiKey, setApiKey] = useState(provider.apiKey)
@ -76,6 +200,8 @@ const ProviderSetting: FC<Props> = ({ provider: _provider }) => {
const modelsWebsite = providerConfig?.websites?.models
const configedApiHost = providerConfig?.api?.url
const [editingModel, setEditingModel] = useState<Model | null>(null)
const onUpdateApiKey = () => {
if (apiKey !== provider.apiKey) {
updateProvider({ ...provider, apiKey })
@ -164,67 +290,42 @@ const ProviderSetting: FC<Props> = ({ provider: _provider }) => {
return formatApiHost(apiHost) + 'chat/completions'
}
const onUpdateModelTypes = (model: Model, types: ModelType[]) => {
const onUpdateModel = (updatedModel: Model) => {
const updatedModels = models.map((m) => {
if (m.id === model.id) {
return { ...m, type: types }
if (m.id === updatedModel.id) {
return updatedModel
}
return m
})
updateProvider({ ...provider, models: updatedModels })
// Update assistants using this model
assistants.forEach((assistant) => {
if (assistant?.model?.id === model.id && assistant.model.provider === provider.id) {
if (assistant?.model?.id === updatedModel.id && assistant.model.provider === provider.id) {
dispatch(
setModel({
assistantId: assistant.id,
model: { ...model, type: types }
model: updatedModel
})
)
}
})
if (defaultModel?.id === model.id && defaultModel?.provider === provider.id) {
setDefaultModel({ ...defaultModel, type: types })
// Update default model if needed
if (defaultModel?.id === updatedModel.id && defaultModel?.provider === provider.id) {
setDefaultModel(updatedModel)
}
}
const modelTypeContent = (model: Model) => {
// 获取默认选中的类型
const defaultTypes = [
...(isVisionModel(model) ? ['vision'] : []),
...(isEmbeddingModel(model) ? ['embedding'] : []),
...(isReasoningModel(model) ? ['reasoning'] : [])
] as ModelType[]
// 合并现有选择和默认类型
const selectedTypes = [...new Set([...(model.type || []), ...defaultTypes])]
return (
<div>
<Checkbox.Group
value={selectedTypes}
onChange={(types) => onUpdateModelTypes(model, types as ModelType[])}
options={[
{
label: t('models.type.vision'),
value: 'vision',
disabled: isVisionModel(model) && !selectedTypes.includes('vision')
},
{
label: t('models.type.embedding'),
value: 'embedding',
disabled: isEmbeddingModel(model) && !selectedTypes.includes('embedding')
},
{
label: t('models.type.reasoning'),
value: 'reasoning',
disabled: isReasoningModel(model) && !selectedTypes.includes('reasoning')
}
]}
<ModelEditContent
model={model}
onUpdateModel={onUpdateModel}
open={editingModel?.id === model.id}
onClose={() => setEditingModel(null)}
/>
</div>
)
}
@ -355,9 +456,7 @@ const ProviderSetting: FC<Props> = ({ provider: _provider }) => {
<span>{model?.name}</span>
<ModelTags model={model} />
</ModelNameRow>
<Popover content={modelTypeContent(model)} title={t('models.type.select')} trigger="click">
<SettingIcon />
</Popover>
<SettingIcon onClick={() => setEditingModel(model)} />
</ModelListHeader>
<RemoveIcon onClick={() => removeModel(model)} />
</ModelListItem>
@ -386,6 +485,7 @@ const ProviderSetting: FC<Props> = ({ provider: _provider }) => {
{t('button.add')}
</Button>
</Flex>
{models.map((model) => modelTypeContent(model))}
</SettingContainer>
)
}
@ -434,4 +534,10 @@ const ProviderName = styled.span`
font-weight: 500;
`
const TypeTitle = styled.div`
margin-bottom: 12px;
font-size: 14px;
font-weight: 600;
`
export default ProviderSetting

View File

@ -502,6 +502,21 @@ const settingsSlice = createSlice({
},
setLMStudioKeepAliveTime: (state, action: PayloadAction<number>) => {
state.settings.lmstudio.keepAliveTime = action.payload
},
updateModel: (
state,
action: PayloadAction<{
providerId: string
model: Model
}>
) => {
const provider = state.providers.find((p) => p.id === action.payload.providerId)
if (provider) {
const modelIndex = provider.models.findIndex((m) => m.id === action.payload.model.id)
if (modelIndex !== -1) {
provider.models[modelIndex] = action.payload.model
}
}
}
}
})
@ -517,7 +532,8 @@ export const {
setTopicNamingModel,
setTranslateModel,
setOllamaKeepAliveTime,
setLMStudioKeepAliveTime
setLMStudioKeepAliveTime,
updateModel
} = settingsSlice.actions
export default settingsSlice.reducer