From 53f46218d32e95a6827bf3581b7dcf3c1fc58df0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BA=A2=E5=A5=8B=E7=8C=AB?= Date: Tue, 4 Feb 2025 15:41:40 +0800 Subject: [PATCH] feat: add oauth for siliconflow (#976) * wip: silicon oauth * feat: Add custom protocol handler for SiliconFlow OAuth login * feat: Improve SiliconFlow OAuth flow with dynamic key update * feat: Enhance OAuth and Provider Settings UI * feat: Refactor SiliconFlow OAuth and update localization strings * chore: Update provider localization and system provider configuration * feat: Add OAuth support for AIHubMix provider --- electron.vite.config.ts | 2 +- src/main/index.ts | 19 ++++++++ src/main/services/WindowService.ts | 13 +++++ .../src/components/OAuth/OAuthButton.tsx | 40 ++++++++++++++++ src/renderer/src/config/constant.ts | 2 + src/renderer/src/i18n/index.ts | 10 +++- src/renderer/src/i18n/locales/en-us.json | 7 ++- src/renderer/src/i18n/locales/ja-jp.json | 6 +++ src/renderer/src/i18n/locales/ru-ru.json | 7 ++- src/renderer/src/i18n/locales/zh-cn.json | 8 +++- src/renderer/src/i18n/locales/zh-tw.json | 8 +++- .../ProviderSettings/ProviderSetting.tsx | 43 ++++++++++------- src/renderer/src/services/ProviderService.ts | 6 +++ src/renderer/src/store/llm.ts | 20 ++++---- src/renderer/src/utils/oauth.ts | 48 ++++++++++++++++--- 15 files changed, 200 insertions(+), 39 deletions(-) create mode 100644 src/renderer/src/components/OAuth/OAuthButton.tsx diff --git a/electron.vite.config.ts b/electron.vite.config.ts index 4e0c3726..8c625e85 100644 --- a/electron.vite.config.ts +++ b/electron.vite.config.ts @@ -50,7 +50,7 @@ export default defineConfig({ } }, optimizeDeps: { - exclude: ['chunk-RK3FTE5R.js'] + exclude: ['chunk-PZ64DZKH.js'] } } }) diff --git a/src/main/index.ts b/src/main/index.ts index 466f1761..a3e221df 100644 --- a/src/main/index.ts +++ b/src/main/index.ts @@ -19,6 +19,25 @@ if (!app.requestSingleInstanceLock()) { app.whenReady().then(async () => { await updateUserDataPath() + // Register custom protocol + if (!app.isDefaultProtocolClient('cherrystudio')) { + app.setAsDefaultProtocolClient('cherrystudio') + } + + // Handle protocol open + app.on('open-url', (event, url) => { + event.preventDefault() + const parsedUrl = new URL(url) + if (parsedUrl.pathname === 'siliconflow.oauth.login') { + const code = parsedUrl.searchParams.get('code') + if (code) { + // Handle the OAuth code here + console.log('OAuth code received:', code) + // You can send this code to your renderer process via IPC if needed + } + } + }) + // Set app user model id for windows electronApp.setAppUserModelId(import.meta.env.VITE_MAIN_BUNDLE_ID || 'com.kangfenmao.CherryStudio') diff --git a/src/main/services/WindowService.ts b/src/main/services/WindowService.ts index 87d09069..e26501aa 100644 --- a/src/main/services/WindowService.ts +++ b/src/main/services/WindowService.ts @@ -163,6 +163,19 @@ export class WindowService { mainWindow.webContents.setWindowOpenHandler((details) => { const { url } = details + const oauthProviderUrls = ['https://account.siliconflow.cn'] + + if (oauthProviderUrls.some((url) => url.startsWith(url))) { + return { + action: 'allow', + overrideBrowserWindowOptions: { + webPreferences: { + partition: 'persist:webview' + } + } + } + } + if (url.includes('http://file/')) { const fileName = url.replace('http://file/', '') const storageDir = path.join(app.getPath('userData'), 'Data', 'Files') diff --git a/src/renderer/src/components/OAuth/OAuthButton.tsx b/src/renderer/src/components/OAuth/OAuthButton.tsx new file mode 100644 index 00000000..9b73385b --- /dev/null +++ b/src/renderer/src/components/OAuth/OAuthButton.tsx @@ -0,0 +1,40 @@ +import { useProvider } from '@renderer/hooks/useProvider' +import { Provider } from '@renderer/types' +import { oauthWithAihubmix, oauthWithSiliconFlow } from '@renderer/utils/oauth' +import { Button, ButtonProps } from 'antd' +import { FC } from 'react' +import { useTranslation } from 'react-i18next' + +interface Props extends ButtonProps { + provider: Provider +} + +const OAuthButton: FC = (props) => { + const { t } = useTranslation() + const { provider, updateProvider } = useProvider(props.provider.id) + + const onAuth = () => { + const onSuccess = (key: string) => { + if (key.trim()) { + updateProvider({ ...provider, apiKey: key }) + window.message.success(t('auth.get_key_success')) + } + } + + if (provider.id === 'silicon') { + oauthWithSiliconFlow(onSuccess) + } + + if (provider.id === 'aihubmix') { + oauthWithAihubmix(onSuccess) + } + } + + return ( + + ) +} + +export default OAuthButton diff --git a/src/renderer/src/config/constant.ts b/src/renderer/src/config/constant.ts index 35603b59..d73055d4 100644 --- a/src/renderer/src/config/constant.ts +++ b/src/renderer/src/config/constant.ts @@ -8,3 +8,5 @@ export const platform = window.electron?.process?.platform export const isMac = platform === 'darwin' export const isWindows = platform === 'win32' || platform === 'win64' export const isLinux = platform === 'linux' + +export const SILICON_CLIENT_ID = 'SFaJLLq0y6CAMoyDm81aMu' diff --git a/src/renderer/src/i18n/index.ts b/src/renderer/src/i18n/index.ts index 7f2eb32c..ac38e1dc 100644 --- a/src/renderer/src/i18n/index.ts +++ b/src/renderer/src/i18n/index.ts @@ -15,9 +15,17 @@ const resources = { 'ru-RU': ruRU } +export const getLanguage = () => { + return localStorage.getItem('language') || navigator.language || 'en-US' +} + +export const getLanguageCode = () => { + return getLanguage().split('-')[0] +} + i18n.use(initReactI18next).init({ resources, - lng: localStorage.getItem('language') || navigator.language || 'en-US', + lng: getLanguage(), fallbackLng: 'en-US', interpolation: { escapeValue: false diff --git a/src/renderer/src/i18n/locales/en-us.json b/src/renderer/src/i18n/locales/en-us.json index c110c304..7e75aa03 100644 --- a/src/renderer/src/i18n/locales/en-us.json +++ b/src/renderer/src/i18n/locales/en-us.json @@ -509,7 +509,6 @@ "provider.check": "Check", "provider.docs_check": "Check", "provider.docs_more_details": "for more details", - "provider.get_api_key": "Get API Key", "provider.search_placeholder": "Search model id or name", "proxy": { "mode": { @@ -688,6 +687,12 @@ "esc_back": "back", "copy_last_message": "Press C to copy" } + }, + "auth": { + "oauth_button": "Auth with {{provider}}", + "get_key": "Get", + "get_key_success": "API key automatically obtained successfully", + "login": "Login" } } } diff --git a/src/renderer/src/i18n/locales/ja-jp.json b/src/renderer/src/i18n/locales/ja-jp.json index 325d1f67..96af4735 100644 --- a/src/renderer/src/i18n/locales/ja-jp.json +++ b/src/renderer/src/i18n/locales/ja-jp.json @@ -672,6 +672,12 @@ "esc_back": "戻る", "copy_last_message": "C キーを押してコピー" } + }, + "auth": { + "oauth_button": "{{provider}}で認証", + "get_key": "取得", + "get_key_success": "APIキーの自動取得に成功しました", + "login": "認証" } } } diff --git a/src/renderer/src/i18n/locales/ru-ru.json b/src/renderer/src/i18n/locales/ru-ru.json index 43740687..08a3e945 100644 --- a/src/renderer/src/i18n/locales/ru-ru.json +++ b/src/renderer/src/i18n/locales/ru-ru.json @@ -506,7 +506,6 @@ "provider.check": "Проверить", "provider.docs_check": "Проверить", "provider.docs_more_details": "для получения дополнительной информации", - "provider.get_api_key": "Получить ключ API", "provider.search_placeholder": "Поиск по ID или имени модели", "proxy": { "mode": { @@ -685,6 +684,12 @@ "esc_back": "возвращения", "copy_last_message": "Нажмите C для копирования" } + }, + "auth": { + "oauth_button": "Авторизоваться с {{provider}}", + "get_key": "Получить", + "get_key_success": "Автоматический получение ключа API успешно", + "login": "Войти" } } } diff --git a/src/renderer/src/i18n/locales/zh-cn.json b/src/renderer/src/i18n/locales/zh-cn.json index f716d8bf..e6d89306 100644 --- a/src/renderer/src/i18n/locales/zh-cn.json +++ b/src/renderer/src/i18n/locales/zh-cn.json @@ -489,7 +489,7 @@ "delete.title": "删除提供商", "docs_check": "查看", "docs_more_details": "获取更多详情", - "get_api_key": "点击这里获取密钥", + "get_api_key": "获取密钥", "no_models": "请先添加模型再检查 API 连接", "not_checked": "未检查", "remove_duplicate_keys": "移除重复密钥", @@ -674,6 +674,12 @@ "esc_back": "返回", "copy_last_message": "按 C 键复制" } + }, + "auth": { + "oauth_button": "使用{{provider}}登录", + "get_key": "获取", + "get_key_success": "自动获取密钥成功", + "login": "登录" } } } diff --git a/src/renderer/src/i18n/locales/zh-tw.json b/src/renderer/src/i18n/locales/zh-tw.json index 1455f823..73ce63f9 100644 --- a/src/renderer/src/i18n/locales/zh-tw.json +++ b/src/renderer/src/i18n/locales/zh-tw.json @@ -488,7 +488,7 @@ "delete.title": "刪除提供者", "docs_check": "檢查", "docs_more_details": "查看更多細節", - "get_api_key": "獲取 API 密鑰", + "get_api_key": "獲取密鑰", "no_models": "請先添加模型再檢查 API 連接", "not_checked": "未檢查", "remove_duplicate_keys": "移除重複密鑰", @@ -673,6 +673,12 @@ "esc_back": "返回", "copy_last_message": "按 C 鍵複製" } + }, + "auth": { + "oauth_button": "使用{{provider}}登入", + "get_key": "獲取", + "get_key_success": "自動獲取密鑰成功", + "login": "登入" } } } diff --git a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx index 3615b9f1..12c09be4 100644 --- a/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx +++ b/src/renderer/src/pages/settings/ProviderSettings/ProviderSetting.tsx @@ -8,6 +8,7 @@ import { SettingOutlined } from '@ant-design/icons' import ModelTags from '@renderer/components/ModelTags' +import OAuthButton from '@renderer/components/OAuth/OAuthButton' import { EMBEDDING_REGEX, getModelLogo, VISION_REGEX } from '@renderer/config/models' import { PROVIDER_CONFIG } from '@renderer/config/providers' import { useTheme } from '@renderer/context/ThemeProvider' @@ -16,6 +17,7 @@ import { useProvider } from '@renderer/hooks/useProvider' import i18n from '@renderer/i18n' import { isOpenAIProvider } from '@renderer/providers/ProviderFactory' import { checkApi } from '@renderer/services/ApiService' +import { isProviderSupportAuth } from '@renderer/services/ProviderService' import { useAppDispatch } from '@renderer/store' import { setModel } from '@renderer/store/assistants' import { Model, ModelType, Provider } from '@renderer/types' @@ -61,17 +63,18 @@ const ProviderSetting: FC = ({ provider: _provider }) => { const { defaultModel, setDefaultModel } = useDefaultModel() const modelGroups = groupBy(models, 'group') + const isAzureOpenAI = provider.id === 'azure-openai' || provider.type === 'azure-openai' - useEffect(() => { - setApiKey(provider.apiKey) - setApiHost(provider.apiHost) - }, [provider]) + const providerConfig = PROVIDER_CONFIG[provider.id] + const officialWebsite = providerConfig?.websites?.official + const apiKeyWebsite = providerConfig?.websites?.apiKey + const docsWebsite = providerConfig?.websites?.docs + const modelsWebsite = providerConfig?.websites?.models + const configedApiHost = providerConfig?.api?.url const onUpdateApiKey = () => { - if (apiKey.trim()) { + if (apiKey !== provider.apiKey) { updateProvider({ ...provider, apiKey }) - } else { - setApiKey(provider.apiKey) } } @@ -138,13 +141,6 @@ const ProviderSetting: FC = ({ provider: _provider }) => { } } - const providerConfig = PROVIDER_CONFIG[provider.id] - const officialWebsite = providerConfig?.websites?.official - const apiKeyWebsite = providerConfig?.websites?.apiKey - const docsWebsite = providerConfig?.websites?.docs - const modelsWebsite = providerConfig?.websites?.models - const configedApiHost = providerConfig?.api?.url - const onReset = () => { setApiHost(configedApiHost) updateProvider({ ...provider, apiHost: configedApiHost }) @@ -201,16 +197,28 @@ const ProviderSetting: FC = ({ provider: _provider }) => { return value.replaceAll(',', ',').replaceAll(' ', ',').replaceAll(' ', '').replaceAll('\n', ',') } - const isAzureOpenAI = provider.id === 'azure-openai' || provider.type === 'azure-openai' + useEffect(() => { + setApiKey(provider.apiKey) + setApiHost(provider.apiHost) + }, [provider.apiKey, provider.apiHost]) + + // Save apiKey to provider when unmount + useEffect(() => { + return () => { + if (apiKey.trim() && apiKey !== provider.apiKey) { + updateProvider({ ...provider, apiKey }) + } + } + }, [apiKey, provider, updateProvider]) return ( - + {provider.isSystem ? t(`provider.${provider.id}`) : provider.name} {officialWebsite! && ( - + )} @@ -232,6 +240,7 @@ const ProviderSetting: FC = ({ provider: _provider }) => { type="password" autoFocus={provider.enabled && apiKey === ''} /> + {isProviderSupportAuth(provider) && }