From 302d7511dcf2d7f6d33147e3b34f1152a860c8aa Mon Sep 17 00:00:00 2001 From: kangfenmao Date: Tue, 8 Oct 2024 20:14:50 +0800 Subject: [PATCH] feat: add azure openai provider --- src/renderer/src/config/models.ts | 14 +++++++++++++ src/renderer/src/config/providers.ts | 15 ++++++++++++- src/renderer/src/i18n/en-us.json | 4 +++- src/renderer/src/i18n/zh-cn.json | 4 +++- src/renderer/src/i18n/zh-tw.json | 4 +++- .../ProviderSettings/ProviderSetting.tsx | 15 +++++++++++++ src/renderer/src/providers/OpenAIProvider.ts | 12 ++++++++++- src/renderer/src/store/index.ts | 2 +- src/renderer/src/store/llm.ts | 10 +++++++++ src/renderer/src/store/migrate.ts | 21 +++++++++++++++++++ src/renderer/src/types/index.ts | 1 + 11 files changed, 96 insertions(+), 6 deletions(-) diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index de6e5ea1..27625c6d 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -343,6 +343,20 @@ export const SYSTEM_MODELS: Record = { group: 'o1' } ], + 'azure-openai': [ + { + id: 'gpt-4o', + provider: 'openai', + name: ' GPT-4o', + group: 'GPT 4o' + }, + { + id: 'gpt-4o-mini', + provider: 'openai', + name: ' GPT-4o-mini', + group: 'GPT 4o' + } + ], gemini: [ { id: 'gemini-1.5-flash', diff --git a/src/renderer/src/config/providers.ts b/src/renderer/src/config/providers.ts index ee5f9053..deb98e07 100644 --- a/src/renderer/src/config/providers.ts +++ b/src/renderer/src/config/providers.ts @@ -1,4 +1,5 @@ import ZhinaoProviderLogo from '@renderer/assets/images/models/360.png' +import AzureProviderLogo from '@renderer/assets/images/models/microsoft.png' import AiHubMixProviderLogo from '@renderer/assets/images/providers/aihubmix.jpg' import AnthropicProviderLogo from '@renderer/assets/images/providers/anthropic.png' import BaichuanProviderLogo from '@renderer/assets/images/providers/baichuan.png' @@ -73,7 +74,8 @@ export function getProviderLogo(providerId: string) { return ZhinaoProviderLogo case 'nvidia': return NvidiaProviderLogo - + case 'azure-openai': + return AzureProviderLogo default: return undefined } @@ -336,5 +338,16 @@ export const PROVIDER_CONFIG = { docs: 'https://docs.api.nvidia.com/nim/reference/llm-apis', models: 'https://build.nvidia.com/nim' } + }, + 'azure-openai': { + api: { + url: '' + }, + websites: { + official: 'https://azure.microsoft.com/en-us/products/ai-services/openai-service', + apiKey: 'https://portal.azure.com/#view/Microsoft_Azure_ProjectOxford/CognitiveServicesHub/~/OpenAI', + docs: 'https://learn.microsoft.com/en-us/azure/ai-services/openai/', + models: 'https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models' + } } } diff --git a/src/renderer/src/i18n/en-us.json b/src/renderer/src/i18n/en-us.json index 1884c4f1..1370a72c 100644 --- a/src/renderer/src/i18n/en-us.json +++ b/src/renderer/src/i18n/en-us.json @@ -175,7 +175,8 @@ "minimax": "MiniMax", "graphrag-kylin-mountain": "GraphRAG", "github": "GitHub Models", - "ocoolai": "ocoolAI" + "ocoolai": "ocoolAI", + "azure-openai": "Azure OpenAI" }, "settings": { "title": "Settings", @@ -220,6 +221,7 @@ "provider.check": "Check", "provider.get_api_key": "Get API Key", "provider.api_host": "API Host", + "provider.api_version": "API Version", "provider.docs_check": "Check", "provider.docs_more_details": "for more details", "provider.search_placeholder": "Search model id or name", diff --git a/src/renderer/src/i18n/zh-cn.json b/src/renderer/src/i18n/zh-cn.json index b9c2860d..cb4ab68c 100644 --- a/src/renderer/src/i18n/zh-cn.json +++ b/src/renderer/src/i18n/zh-cn.json @@ -175,7 +175,8 @@ "minimax": "MiniMax", "graphrag-kylin-mountain": "GraphRAG", "github": "GitHub Models", - "ocoolai": "ocoolAI" + "ocoolai": "ocoolAI", + "azure-openai": "Azure OpenAI" }, "settings": { "title": "设置", @@ -220,6 +221,7 @@ "provider.check": "检查", "provider.get_api_key": "点击这里获取密钥", "provider.api_host": "API 地址", + "provider.api_version": "API 版本", "provider.docs_check": "查看", "provider.docs_more_details": "获取更多详情", "provider.search_placeholder": "搜索模型 ID 或名称", diff --git a/src/renderer/src/i18n/zh-tw.json b/src/renderer/src/i18n/zh-tw.json index 580dc895..a162d437 100644 --- a/src/renderer/src/i18n/zh-tw.json +++ b/src/renderer/src/i18n/zh-tw.json @@ -175,7 +175,8 @@ "minimax": "MiniMax", "graphrag-kylin-mountain": "GraphRAG", "github": "GitHub Models", - "ocoolai": "ocoolAI" + "ocoolai": "ocoolAI", + "azure-openai": "Azure OpenAI" }, "settings": { "title": "設定", @@ -220,6 +221,7 @@ "provider.check": "檢查", "provider.get_api_key": "獲取 API 密鑰", "provider.api_host": "API 主機地址", + "provider.api_version": "API 版本", "provider.docs_check": "檢查", "provider.docs_more_details": "查看更多細節", "provider.search_placeholder": "搜尋模型 ID 或名稱", diff --git a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx index 14d7123f..9b7ad9e4 100644 --- a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx +++ b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx @@ -41,6 +41,7 @@ const ProviderSetting: FC = ({ provider: _provider }) => { const { provider } = useProvider(_provider.id) const [apiKey, setApiKey] = useState(provider.apiKey) const [apiHost, setApiHost] = useState(provider.apiHost) + const [apiVersion, setApiVersion] = useState(provider.apiVersion) const [apiValid, setApiValid] = useState(false) const [apiChecking, setApiChecking] = useState(false) const { updateProvider, models, removeModel } = useProvider(provider.id) @@ -56,6 +57,7 @@ const ProviderSetting: FC = ({ provider: _provider }) => { const onUpdateApiKey = () => updateProvider({ ...provider, apiKey }) const onUpdateApiHost = () => updateProvider({ ...provider, apiHost }) + const onUpdateApiVersion = () => updateProvider({ ...provider, apiVersion }) const onManageModel = () => EditModelsPopup.show({ provider }) const onAddModel = () => AddModelPopup.show({ title: t('settings.models.add.add_model'), provider }) @@ -136,6 +138,19 @@ const ProviderSetting: FC = ({ provider: _provider }) => { )} + {provider.id === 'azure-openai' && ( + <> + {t('settings.provider.api_version')} + + setApiVersion(e.target.value)} + onBlur={onUpdateApiVersion} + /> + + + )} {provider.id === 'ollama' && } {provider.id === 'graphrag-kylin-mountain' && provider.models.length > 0 && ( diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts index a17d44c0..0f4ba21a 100644 --- a/src/renderer/src/providers/OpenAIProvider.ts +++ b/src/renderer/src/providers/OpenAIProvider.ts @@ -6,7 +6,7 @@ import { filterContextMessages } from '@renderer/services/messages' import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types' import { removeQuotes } from '@renderer/utils' import { first, takeRight } from 'lodash' -import OpenAI from 'openai' +import OpenAI, { AzureOpenAI } from 'openai' import { ChatCompletionContentPart, ChatCompletionCreateParamsNonStreaming, @@ -20,6 +20,16 @@ export default class OpenAIProvider extends BaseProvider { constructor(provider: Provider) { super(provider) + if (provider.id === 'azure-openai') { + this.sdk = new AzureOpenAI({ + dangerouslyAllowBrowser: true, + apiKey: provider.apiKey, + apiVersion: provider.apiVersion, + endpoint: provider.apiHost + }) + return + } + this.sdk = new OpenAI({ dangerouslyAllowBrowser: true, apiKey: provider.apiKey, diff --git a/src/renderer/src/store/index.ts b/src/renderer/src/store/index.ts index d3d57111..d098980d 100644 --- a/src/renderer/src/store/index.ts +++ b/src/renderer/src/store/index.ts @@ -22,7 +22,7 @@ const persistedReducer = persistReducer( { key: 'cherry-studio', storage, - version: 29, + version: 30, blacklist: ['runtime'], migrate }, diff --git a/src/renderer/src/store/llm.ts b/src/renderer/src/store/llm.ts index 9a0a4994..d7012b01 100644 --- a/src/renderer/src/store/llm.ts +++ b/src/renderer/src/store/llm.ts @@ -59,6 +59,16 @@ const initialState: LlmState = { isSystem: true, enabled: false }, + { + id: 'azure-openai', + name: 'Azure OpenAI', + apiKey: '', + apiHost: '', + apiVersion: '', + models: SYSTEM_MODELS['azure-openai'], + isSystem: true, + enabled: false + }, { id: 'gemini', name: 'Gemini', diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts index 6609ed91..73f025b0 100644 --- a/src/renderer/src/store/migrate.ts +++ b/src/renderer/src/store/migrate.ts @@ -504,6 +504,27 @@ const migrateConfig = { }) } } + }, + '30': (state: RootState) => { + return { + ...state, + llm: { + ...state.llm, + providers: [ + ...state.llm.providers, + { + id: 'azure-openai', + name: 'Azure OpenAI', + apiKey: '', + apiHost: '', + apiVersion: '', + models: SYSTEM_MODELS['azure-openai'], + isSystem: true, + enabled: false + } + ] + } + } } } diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index d6b6c54f..9b50f134 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -55,6 +55,7 @@ export type Provider = { name: string apiKey: string apiHost: string + apiVersion?: string models: Model[] enabled?: boolean isSystem?: boolean