feat: add provider type

This commit is contained in:
kangfenmao 2024-11-18 13:04:46 +08:00
parent c33c0b20f2
commit 1b8a3885f7
12 changed files with 109 additions and 36 deletions

View File

@ -11,8 +11,12 @@ import FireworksProviderLogo from '@renderer/assets/images/providers/fireworks.p
import GithubProviderLogo from '@renderer/assets/images/providers/github.png'
import GoogleProviderLogo from '@renderer/assets/images/providers/google.png'
import GraphRagProviderLogo from '@renderer/assets/images/providers/graph-rag.png'
import GrokProviderLogo from '@renderer/assets/images/providers/grok.png'
import GroqProviderLogo from '@renderer/assets/images/providers/groq.png'
import HyperbolicProviderLogo from '@renderer/assets/images/providers/hyperbolic.png'
import JinaProviderLogo from '@renderer/assets/images/providers/jina.png'
import MinimaxProviderLogo from '@renderer/assets/images/providers/minimax.png'
import MistralProviderLogo from '@renderer/assets/images/providers/mistral.png'
import MoonshotProviderLogo from '@renderer/assets/images/providers/moonshot.png'
import NvidiaProviderLogo from '@renderer/assets/images/providers/nvidia.png'
import OcoolAiProviderLogo from '@renderer/assets/images/providers/ocoolai.png'
@ -24,10 +28,6 @@ import StepProviderLogo from '@renderer/assets/images/providers/step.png'
import TogetherProviderLogo from '@renderer/assets/images/providers/together.png'
import ZeroOneProviderLogo from '@renderer/assets/images/providers/zero-one.png'
import ZhipuProviderLogo from '@renderer/assets/images/providers/zhipu.png'
import GrokProviderLogo from '@renderer/assets/images/providers/grok.png'
import HyperbolicProviderLogo from '@renderer/assets/images/providers/hyperbolic.png'
import MistralProviderLogo from '@renderer/assets/images/providers/mistral.png'
import JinaProviderLogo from '@renderer/assets/images/providers/jina.png'
export function getProviderLogo(providerId: string) {
switch (providerId) {
@ -326,6 +326,7 @@ export const PROVIDER_CONFIG = {
},
websites: {
official: 'https://app.hyperbolic.xyz',
apiKey: 'https://app.hyperbolic.xyz/settings',
docs: 'https://docs.hyperbolic.xyz',
models: 'https://app.hyperbolic.xyz/models'
}
@ -336,6 +337,7 @@ export const PROVIDER_CONFIG = {
},
websites: {
official: 'https://mistral.ai',
apiKey: 'https://console.mistral.ai/api-keys/',
docs: 'https://docs.mistral.ai',
models: 'https://docs.mistral.ai/getting-started/models/models_overview'
}
@ -346,6 +348,7 @@ export const PROVIDER_CONFIG = {
},
websites: {
official: 'https://jina.ai',
apiKey: 'https://jina.ai/',
docs: 'https://jina.ai',
models: 'https://jina.ai'
}

View File

@ -377,8 +377,10 @@
"not_checked": "Not Checked",
"delete.title": "Delete Provider",
"delete.content": "Are you sure you want to delete this provider?",
"edit.name": "Provider Name",
"edit.name.placeholder": "Example: OpenAI",
"add.title": "Add Provider",
"add.name": "Provider Name",
"add.name.placeholder": "Example: OpenAI",
"add.type": "Provider Type",
"no_models": "Please add models first before checking the API connection"
}
},

View File

@ -377,8 +377,10 @@
"not_checked": "Не проверено",
"delete.title": "Удалить провайдер",
"delete.content": "Вы уверены, что хотите удалить этот провайдер?",
"edit.name": "Имя провайдера",
"edit.name.placeholder": "Пример: OpenAI",
"add.title": "Добавить провайдер",
"add.name": "Имя провайдера",
"add.name.placeholder": "Пример: OpenAI",
"add.type": "Тип провайдера",
"no_models": "Пожалуйста, добавьте модели перед проверкой соединения с API"
}
},

View File

@ -365,8 +365,10 @@
"not_checked": "未检查",
"delete.title": "删除提供商",
"delete.content": "确定要删除此模型提供商吗?",
"edit.name": "模型提供商名称",
"edit.name.placeholder": "例如 OpenAI",
"add.title": "添加提供商",
"add.name": "提供商名称",
"add.name.placeholder": "例如 OpenAI",
"add.type": "提供商类型",
"no_models": "请先添加模型再检查 API 连接"
}
},

View File

@ -364,9 +364,11 @@
"remove_duplicate_keys": "移除重複密鑰",
"not_checked": "未檢查",
"delete.title": "刪除提供者",
".delete.content": "確定要刪除此提供者嗎?",
"edit.name": "提供者名稱",
"edit.name.placeholder": "例如OpenAI",
"delete.content": "確定要刪除此提供者嗎?",
"add.title": "添加提供者",
"add.name": "提供者名稱",
"add.name.placeholder": "例如OpenAI",
"add.type": "提供商類型",
"no_models": "請先添加模型再檢查 API 連接"
}
},

View File

@ -1,31 +1,32 @@
import { TopView } from '@renderer/components/TopView'
import { Provider } from '@renderer/types'
import { Input, Modal } from 'antd'
import { Provider, ProviderType } from '@renderer/types'
import { Divider, Form, Input, Modal, Select } from 'antd'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
interface Props {
provider?: Provider
resolve: (name: string) => void
resolve: (result: { name: string; type: ProviderType }) => void
}
const PopupContainer: React.FC<Props> = ({ provider, resolve }) => {
const [open, setOpen] = useState(true)
const [name, setName] = useState(provider?.name || '')
const [type, setType] = useState<ProviderType>(provider?.type || 'openai')
const { t } = useTranslation()
const onOk = () => {
setOpen(false)
resolve(name)
resolve({ name, type })
}
const onCancel = () => {
setOpen(false)
resolve('')
resolve({ name: '', type: 'openai' })
}
const onClose = () => {
resolve(name)
resolve({ name, type })
}
const buttonDisabled = name.length === 0
@ -39,15 +40,31 @@ const PopupContainer: React.FC<Props> = ({ provider, resolve }) => {
width={360}
closable={false}
centered
title={t('settings.provider.edit.name')}
title={t('settings.provider.add.title')}
okButtonProps={{ disabled: buttonDisabled }}>
<Divider style={{ margin: '8px 0' }} />
<Form layout="vertical" style={{ gap: 8 }}>
<Form.Item label={t('settings.provider.add.name')} style={{ marginBottom: 8 }}>
<Input
value={name}
onChange={(e) => setName(e.target.value.trim())}
placeholder={t('settings.provider.edit.name.placeholder')}
placeholder={t('settings.provider.add.name.placeholder')}
onKeyDown={(e) => e.key === 'Enter' && onOk()}
maxLength={32}
/>
</Form.Item>
<Form.Item label={t('settings.provider.add.type')} style={{ marginBottom: 0 }}>
<Select
value={type}
onChange={setType}
options={[
{ label: 'OpenAI', value: 'openai' },
{ label: 'Gemini', value: 'gemini' },
{ label: 'Anthropic', value: 'anthropic' }
]}
/>
</Form.Item>
</Form>
</Modal>
)
}
@ -58,7 +75,7 @@ export default class AddProviderPopup {
TopView.hide('AddProviderPopup')
}
static show(provider?: Provider) {
return new Promise<string>((resolve) => {
return new Promise<{ name: string; type: ProviderType }>((resolve) => {
TopView.show(
<PopupContainer
provider={provider}

View File

@ -31,15 +31,16 @@ const ProvidersList: FC = () => {
}
const onAddProvider = async () => {
const prividerName = await AddProviderPopup.show()
const { name: prividerName, type } = await AddProviderPopup.show()
if (!prividerName) {
if (!prividerName.trim()) {
return
}
const provider = {
id: uuid(),
name: prividerName,
name: prividerName.trim(),
type,
apiKey: '',
apiHost: '',
models: [],
@ -58,8 +59,8 @@ const ProvidersList: FC = () => {
key: 'edit',
icon: <EditOutlined />,
async onClick() {
const name = await AddProviderPopup.show(provider)
name && updateProvider({ ...provider, name })
const { name, type } = await AddProviderPopup.show(provider)
name && updateProvider({ ...provider, name, type })
}
},
{

View File

@ -7,7 +7,7 @@ import OpenAIProvider from './OpenAIProvider'
export default class ProviderFactory {
static create(provider: Provider): BaseProvider {
switch (provider.id) {
switch (provider.type) {
case 'anthropic':
return new AnthropicProvider(provider)
case 'gemini':
@ -19,5 +19,5 @@ export default class ProviderFactory {
}
export function isOpenAIProvider(provider: Provider) {
return !['anthropic', 'gemini'].includes(provider.id)
return !['anthropic', 'gemini'].includes(provider.type)
}

View File

@ -24,7 +24,7 @@ const persistedReducer = persistReducer(
{
key: 'cherry-studio',
storage,
version: 40,
version: 41,
blacklist: ['runtime'],
migrate
},

View File

@ -26,6 +26,7 @@ const initialState: LlmState = {
{
id: 'silicon',
name: 'Silicon',
type: 'openai',
apiKey: '',
apiHost: 'https://api.siliconflow.cn',
models: SYSTEM_MODELS.silicon,
@ -35,6 +36,7 @@ const initialState: LlmState = {
{
id: 'ollama',
name: 'Ollama',
type: 'openai',
apiKey: '',
apiHost: 'http://localhost:11434/v1/',
models: SYSTEM_MODELS.ollama,
@ -44,6 +46,7 @@ const initialState: LlmState = {
{
id: 'anthropic',
name: 'Anthropic',
type: 'anthropic',
apiKey: '',
apiHost: 'https://api.anthropic.com/',
models: SYSTEM_MODELS.anthropic,
@ -53,6 +56,7 @@ const initialState: LlmState = {
{
id: 'openai',
name: 'OpenAI',
type: 'openai',
apiKey: '',
apiHost: 'https://api.openai.com',
models: SYSTEM_MODELS.openai,
@ -62,6 +66,7 @@ const initialState: LlmState = {
{
id: 'azure-openai',
name: 'Azure OpenAI',
type: 'openai',
apiKey: '',
apiHost: '',
apiVersion: '',
@ -72,6 +77,7 @@ const initialState: LlmState = {
{
id: 'gemini',
name: 'Gemini',
type: 'gemini',
apiKey: '',
apiHost: 'https://generativelanguage.googleapis.com',
models: SYSTEM_MODELS.gemini,
@ -81,6 +87,7 @@ const initialState: LlmState = {
{
id: 'deepseek',
name: 'deepseek',
type: 'openai',
apiKey: '',
apiHost: 'https://api.deepseek.com',
models: SYSTEM_MODELS.deepseek,
@ -90,6 +97,7 @@ const initialState: LlmState = {
{
id: 'ocoolai',
name: 'ocoolAI',
type: 'openai',
apiKey: '',
apiHost: 'https://one.ooo.cool',
models: SYSTEM_MODELS.ocoolai,
@ -99,6 +107,7 @@ const initialState: LlmState = {
{
id: 'github',
name: 'Github Models',
type: 'openai',
apiKey: '',
apiHost: 'https://models.inference.ai.azure.com/',
models: SYSTEM_MODELS.github,
@ -108,6 +117,7 @@ const initialState: LlmState = {
{
id: 'yi',
name: 'Yi',
type: 'openai',
apiKey: '',
apiHost: 'https://api.lingyiwanwu.com',
models: SYSTEM_MODELS.yi,
@ -117,6 +127,7 @@ const initialState: LlmState = {
{
id: 'zhipu',
name: 'ZhiPu',
type: 'openai',
apiKey: '',
apiHost: 'https://open.bigmodel.cn/api/paas/v4/',
models: SYSTEM_MODELS.zhipu,
@ -126,6 +137,7 @@ const initialState: LlmState = {
{
id: 'moonshot',
name: 'Moonshot AI',
type: 'openai',
apiKey: '',
apiHost: 'https://api.moonshot.cn',
models: SYSTEM_MODELS.moonshot,
@ -135,6 +147,7 @@ const initialState: LlmState = {
{
id: 'baichuan',
name: 'BAICHUAN AI',
type: 'openai',
apiKey: '',
apiHost: 'https://api.baichuan-ai.com',
models: SYSTEM_MODELS.baichuan,
@ -144,6 +157,7 @@ const initialState: LlmState = {
{
id: 'dashscope',
name: 'Bailian',
type: 'openai',
apiKey: '',
apiHost: 'https://dashscope.aliyuncs.com/compatible-mode/v1/',
models: SYSTEM_MODELS.bailian,
@ -153,6 +167,7 @@ const initialState: LlmState = {
{
id: 'stepfun',
name: 'StepFun',
type: 'openai',
apiKey: '',
apiHost: 'https://api.stepfun.com',
models: SYSTEM_MODELS.stepfun,
@ -162,6 +177,7 @@ const initialState: LlmState = {
{
id: 'doubao',
name: 'doubao',
type: 'openai',
apiKey: '',
apiHost: 'https://ark.cn-beijing.volces.com/api/v3/',
models: SYSTEM_MODELS.doubao,
@ -171,6 +187,7 @@ const initialState: LlmState = {
{
id: 'minimax',
name: 'MiniMax',
type: 'openai',
apiKey: '',
apiHost: 'https://api.minimax.chat/v1/',
models: SYSTEM_MODELS.minimax,
@ -180,6 +197,7 @@ const initialState: LlmState = {
{
id: 'graphrag-kylin-mountain',
name: 'GraphRAG',
type: 'openai',
apiKey: '',
apiHost: '',
models: [],
@ -189,6 +207,7 @@ const initialState: LlmState = {
{
id: 'openrouter',
name: 'OpenRouter',
type: 'openai',
apiKey: '',
apiHost: 'https://openrouter.ai/api/v1/',
models: SYSTEM_MODELS.openrouter,
@ -198,6 +217,7 @@ const initialState: LlmState = {
{
id: 'groq',
name: 'Groq',
type: 'openai',
apiKey: '',
apiHost: 'https://api.groq.com/openai',
models: SYSTEM_MODELS.groq,
@ -207,6 +227,7 @@ const initialState: LlmState = {
{
id: 'together',
name: 'Together',
type: 'openai',
apiKey: '',
apiHost: 'https://api.together.xyz',
models: SYSTEM_MODELS.together,
@ -216,6 +237,7 @@ const initialState: LlmState = {
{
id: 'fireworks',
name: 'Fireworks',
type: 'openai',
apiKey: '',
apiHost: 'https://api.fireworks.ai/inference',
models: SYSTEM_MODELS.fireworks,
@ -225,6 +247,7 @@ const initialState: LlmState = {
{
id: 'zhinao',
name: 'zhinao',
type: 'openai',
apiKey: '',
apiHost: 'https://api.360.cn',
models: SYSTEM_MODELS.zhinao,
@ -234,6 +257,7 @@ const initialState: LlmState = {
{
id: 'hunyuan',
name: 'hunyuan',
type: 'openai',
apiKey: '',
apiHost: 'https://api.hunyuan.cloud.tencent.com',
models: SYSTEM_MODELS.hunyuan,
@ -243,6 +267,7 @@ const initialState: LlmState = {
{
id: 'nvidia',
name: 'nvidia',
type: 'openai',
apiKey: '',
apiHost: 'https://integrate.api.nvidia.com',
models: SYSTEM_MODELS.nvidia,
@ -252,6 +277,7 @@ const initialState: LlmState = {
{
id: 'grok',
name: 'Grok',
type: 'openai',
apiKey: '',
apiHost: 'https://api.x.ai',
models: SYSTEM_MODELS.grok,
@ -261,6 +287,7 @@ const initialState: LlmState = {
{
id: 'hyperbolic',
name: 'Hyperbolic',
type: 'openai',
apiKey: '',
apiHost: 'https://api.hyperbolic.xyz',
models: SYSTEM_MODELS.hyperbolic,
@ -270,6 +297,7 @@ const initialState: LlmState = {
{
id: 'mistral',
name: 'Mistral',
type: 'openai',
apiKey: '',
apiHost: 'https://api.mistral.ai',
models: SYSTEM_MODELS.mistral,
@ -288,6 +316,7 @@ const initialState: LlmState = {
{
id: 'aihubmix',
name: 'AiHubMix',
type: 'openai',
apiKey: '',
apiHost: 'https://aihubmix.com',
models: SYSTEM_MODELS.aihubmix,

View File

@ -677,6 +677,18 @@ const migrateConfig = {
'40': (state: RootState) => {
state.settings.tray = true
return state
},
'41': (state: RootState) => {
state.llm.providers.forEach((provider) => {
if (provider.id === 'gemini') {
provider.type = 'gemini'
} else if (provider.id === 'anthropic') {
provider.type = 'anthropic'
} else {
provider.type = 'openai'
}
})
return state
}
}

View File

@ -66,6 +66,7 @@ export type User = {
export type Provider = {
id: string
type: ProviderType
name: string
apiKey: string
apiHost: string
@ -75,6 +76,8 @@ export type Provider = {
isSystem?: boolean
}
export type ProviderType = 'openai' | 'anthropic' | 'gemini'
export type Model = {
id: string
provider: string