feat: add ollama provider

This commit is contained in:
kangfenmao 2024-07-10 15:26:44 +08:00
parent 8009e05c80
commit 17826fd2d1
10 changed files with 230 additions and 73 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.0 KiB

View File

@ -2,9 +2,9 @@ import { useSystemProviders } from '@renderer/hooks/useProvider'
import { Provider } from '@renderer/types'
import { FC, useState } from 'react'
import styled from 'styled-components'
import ProviderModals from './components/ProviderModals'
import { Avatar } from 'antd'
import { getProviderLogo } from '@renderer/services/provider'
import ProviderModels from './components/ProviderModels'
const ProviderSettings: FC = () => {
const providers = useSystemProviders()
@ -23,7 +23,7 @@ const ProviderSettings: FC = () => {
</ProviderListItem>
))}
</ProviderListContainer>
<ProviderModals provider={selectedProvider} />
<ProviderModels provider={selectedProvider} />
</Container>
)
}

View File

@ -0,0 +1,124 @@
import { TopView } from '@renderer/components/TopView'
import { useProvider } from '@renderer/hooks/useProvider'
import { Model, Provider } from '@renderer/types'
import { getDefaultGroupName } from '@renderer/utils'
import { Button, Form, FormProps, Input, Modal } from 'antd'
import { find } from 'lodash'
import { useState } from 'react'
interface ShowParams {
title: string
provider: Provider
}
interface Props extends ShowParams {
resolve: (data: any) => void
}
type FieldType = {
provider: string
id: string
name?: string
group?: string
}
const PopupContainer: React.FC<Props> = ({ title, provider, resolve }) => {
const [open, setOpen] = useState(true)
const [form] = Form.useForm()
const { addModel, models } = useProvider(provider.id)
const onOk = () => {
setOpen(false)
}
const onCancel = () => {
setOpen(false)
}
const onClose = () => {
resolve({})
}
const onFinish: FormProps<FieldType>['onFinish'] = (values) => {
if (find(models, { id: values.id })) {
Modal.error({ title: 'Error', content: 'Model ID already exists' })
return
}
const model: Model = {
id: values.id,
provider: provider.id,
name: values.name ? values.name : values.id.toUpperCase(),
group: getDefaultGroupName(values.group || values.id),
temperature: 0.7
}
addModel(model)
resolve(model)
}
return (
<Modal
title={title}
open={open}
onOk={onOk}
onCancel={onCancel}
maskClosable={false}
afterClose={onClose}
footer={null}>
<Form
form={form}
labelCol={{ flex: '110px' }}
labelAlign="left"
colon={false}
style={{ marginTop: 25 }}
onFinish={onFinish}>
<Form.Item label="Provider" name="provider" initialValue={provider.id} rules={[{ required: true }]}>
<Input placeholder="Provider Name" disabled />
</Form.Item>
<Form.Item label="Model ID" name="id" tooltip="Example: gpt-3.5-turbo" rules={[{ required: true }]}>
<Input
placeholder="Required e.g. gpt-3.5-turbo"
spellCheck={false}
onChange={(e) => {
form.setFieldValue('name', e.target.value.toUpperCase())
form.setFieldValue('group', getDefaultGroupName(e.target.value))
}}
/>
</Form.Item>
<Form.Item label="Model Name" tooltip="Example: GPT-3.5" name="name">
<Input placeholder="Optional e.g. GPT-4" spellCheck={false} />
</Form.Item>
<Form.Item label="Group Name" tooltip="Example: ChatGPT" name="group">
<Input placeholder="Optional e.g. OpenAI" spellCheck={false} />
</Form.Item>
<Form.Item label=" ">
<Button type="primary" htmlType="submit">
Add Model
</Button>
</Form.Item>
</Form>
</Modal>
)
}
export default class ModalAddPopup {
static topviewId = 0
static hide() {
TopView.hide(this.topviewId)
}
static show(props: ShowParams) {
return new Promise<any>((resolve) => {
this.topviewId = TopView.show(
<PopupContainer
{...props}
resolve={(v) => {
resolve(v)
this.hide()
}}
/>
)
})
}
}

View File

@ -1,8 +1,8 @@
import { Avatar, Button, Modal } from 'antd'
import { Avatar, Button, Empty, Modal } from 'antd'
import { useState } from 'react'
import { TopView } from '../TopView'
import { TopView } from '../../../components/TopView'
import { Model, Provider } from '@renderer/types'
import { groupBy } from 'lodash'
import { groupBy, isEmpty, uniqBy } from 'lodash'
import styled from 'styled-components'
import { MinusOutlined, PlusOutlined } from '@ant-design/icons'
import { useProvider } from '@renderer/hooks/useProvider'
@ -19,10 +19,11 @@ interface Props extends ShowParams {
const PopupContainer: React.FC<Props> = ({ provider: _provider, resolve }) => {
const [open, setOpen] = useState(true)
const { provider, addModel, removeModel } = useProvider(_provider.id)
const { provider, models, addModel, removeModel } = useProvider(_provider.id)
const systemModels = SYSTEM_MODELS[_provider.id]
const systemModelGroups = groupBy(systemModels, 'group')
const systemModels = SYSTEM_MODELS[_provider.id] || []
const allModels = uniqBy([...systemModels, ...models], 'id')
const systemModelGroups = groupBy(allModels, 'group')
const onOk = () => {
setOpen(false)
@ -79,6 +80,7 @@ const PopupContainer: React.FC<Props> = ({ provider: _provider, resolve }) => {
})}
</div>
))}
{isEmpty(allModels) && <Empty image={Empty.PRESENTED_IMAGE_SIMPLE} description="No models" />}
</ListContainer>
</Modal>
)
@ -124,7 +126,7 @@ const ListItemName = styled.div`
margin-left: 6px;
`
export default class ModalListPopup {
export default class ModelListPopup {
static topviewId = 0
static hide() {
TopView.hide(this.topviewId)

View File

@ -1,18 +1,20 @@
import { Provider } from '@renderer/types'
import { FC, useEffect, useState } from 'react'
import styled from 'styled-components'
import { Avatar, Button, Card, Divider, Input } from 'antd'
import { Avatar, Button, Card, Divider, Flex, Input } from 'antd'
import { useProvider } from '@renderer/hooks/useProvider'
import ModalListPopup from '@renderer/components/Popups/ModalListPopup'
import { groupBy } from 'lodash'
import { SettingContainer, SettingSubtitle, SettingTitle } from './SettingComponent'
import { getModelLogo } from '@renderer/services/provider'
import { EditOutlined, PlusOutlined } from '@ant-design/icons'
import ModalAddPopup from './ModelAddPopup'
import ModelListPopup from './ModelListPopup'
interface Props {
provider: Provider
}
const ProviderModals: FC<Props> = ({ provider }) => {
const ProviderModels: FC<Props> = ({ provider }) => {
const [apiKey, setApiKey] = useState(provider.apiKey)
const [apiHost, setApiHost] = useState(provider.apiHost)
const { updateProvider, models } = useProvider(provider.id)
@ -32,8 +34,12 @@ const ProviderModals: FC<Props> = ({ provider }) => {
updateProvider({ ...provider, apiHost })
}
const onAddModal = () => {
ModalListPopup.show({ provider })
const onManageModel = () => {
ModelListPopup.show({ provider })
}
const onAddModel = () => {
ModalAddPopup.show({ title: 'Add Model', provider })
}
return (
@ -66,9 +72,14 @@ const ProviderModals: FC<Props> = ({ provider }) => {
))}
</Card>
))}
<Button type="primary" style={{ width: '100px', marginTop: '10px' }} onClick={onAddModal}>
Edit Models
<Flex gap={10} style={{ marginTop: '10px' }}>
<Button type="primary" onClick={onManageModel} icon={<EditOutlined />}>
Manage
</Button>
<Button type="default" onClick={onAddModel} icon={<PlusOutlined />}>
Add
</Button>
</Flex>
</SettingContainer>
)
}
@ -81,4 +92,4 @@ const ModelListItem = styled.div`
padding: 5px 0;
`
export default ProviderModals
export default ProviderModels

View File

@ -4,6 +4,7 @@ import DeepSeekProviderLogo from '@renderer/assets/images/providers/deepseek.png
import YiProviderLogo from '@renderer/assets/images/providers/yi.svg'
import GroqProviderLogo from '@renderer/assets/images/providers/groq.png'
import ZhipuProviderLogo from '@renderer/assets/images/providers/zhipu.png'
import OllamaProviderLogo from '@renderer/assets/images/providers/ollama.png'
import ChatGPTModelLogo from '@renderer/assets/images/models/chatgpt.jpeg'
import ChatGLMModelLogo from '@renderer/assets/images/models/chatglm.jpeg'
import DeepSeekModelLogo from '@renderer/assets/images/models/deepseek.png'
@ -14,66 +15,42 @@ import LlamaModelLogo from '@renderer/assets/images/models/llama.jpeg'
import MixtralModelLogo from '@renderer/assets/images/models/mixtral.jpeg'
export function getProviderLogo(providerId: string) {
if (providerId === 'openai') {
switch (providerId) {
case 'openai':
return OpenAiProviderLogo
}
if (providerId === 'silicon') {
case 'silicon':
return SiliconFlowProviderLogo
}
if (providerId === 'deepseek') {
case 'deepseek':
return DeepSeekProviderLogo
}
if (providerId === 'yi') {
case 'yi':
return YiProviderLogo
}
if (providerId === 'groq') {
case 'groq':
return GroqProviderLogo
}
if (providerId === 'zhipu') {
case 'zhipu':
return ZhipuProviderLogo
}
case 'ollama':
return OllamaProviderLogo
default:
return ''
}
}
export function getModelLogo(modelId: string) {
const _modelId = modelId.toLowerCase()
if (_modelId.includes('gpt')) {
return ChatGPTModelLogo
const logoMap = {
gpt: ChatGPTModelLogo,
glm: ChatGLMModelLogo,
deepseek: DeepSeekModelLogo,
qwen: QwenModelLogo,
gemma: GemmaModelLogo,
'yi-': YiModelLogo,
llama: LlamaModelLogo,
mixtral: MixtralModelLogo
}
if (_modelId.includes('glm')) {
return ChatGLMModelLogo
for (const key in logoMap) {
if (modelId.toLowerCase().includes(key)) {
return logoMap[key]
}
if (_modelId.includes('deepseek')) {
return DeepSeekModelLogo
}
if (_modelId.includes('qwen')) {
return QwenModelLogo
}
if (_modelId.includes('gemma')) {
return GemmaModelLogo
}
if (_modelId.includes('yi-')) {
return YiModelLogo
}
if (_modelId.includes('llama')) {
return LlamaModelLogo
}
if (_modelId.includes('mixtral')) {
return MixtralModelLogo
}
return ''

View File

@ -19,7 +19,7 @@ const persistedReducer = persistReducer(
{
key: 'cherry-studio',
storage,
version: 3,
version: 4,
blacklist: ['runtime'],
migrate
},

View File

@ -60,6 +60,14 @@ const initialState: LlmState = {
apiHost: 'https://api.groq.com/openai',
isSystem: true,
models: SYSTEM_MODELS.groq.filter((m) => m.defaultEnabled)
},
{
id: 'ollama',
name: 'Ollama',
apiKey: '',
apiHost: 'http://localhost:11434/v1/',
isSystem: true,
models: []
}
]
}

View File

@ -42,6 +42,26 @@ const migrate = createMigrate({
]
}
}
},
// @ts-ignore store type is unknown
'4': (state: RootState) => {
return {
...state,
llm: {
...state.llm,
providers: [
...state.llm.providers,
{
id: 'ollama',
name: 'Ollama',
apiKey: '',
apiHost: 'http://localhost:11434/v1/',
isSystem: true,
models: []
}
]
}
}
}
})

View File

@ -66,3 +66,18 @@ export const compressImage = async (file: File) => {
useWebWorker: false
})
}
// Converts 'gpt-3.5-turbo-16k-0613' to 'GPT-3.5-Turbo'
// Converts 'qwen2:1.5b' to 'QWEN2'
export const getDefaultGroupName = (id: string) => {
if (id.includes(':')) {
return id.split(':')[0].toUpperCase()
}
if (id.includes('-')) {
const parts = id.split('-')
return parts[0].toUpperCase() + '-' + parts[1].toUpperCase()
}
return id.toUpperCase()
}