diff --git a/package.json b/package.json index c3430d2d..5d444bb5 100644 --- a/package.json +++ b/package.json @@ -37,6 +37,7 @@ "@electron-toolkit/eslint-config-prettier": "^2.0.0", "@electron-toolkit/eslint-config-ts": "^1.0.1", "@electron-toolkit/tsconfig": "^1.0.1", + "@google/generative-ai": "^0.16.0", "@hello-pangea/dnd": "^16.6.0", "@kangfenmao/keyv-storage": "^0.1.0", "@reduxjs/toolkit": "^2.2.5", @@ -47,6 +48,7 @@ "@vitejs/plugin-react": "^4.2.1", "ahooks": "^3.8.0", "antd": "^5.18.3", + "axios": "^1.7.3", "browser-image-compression": "^2.0.2", "dayjs": "^1.11.11", "dotenv-cli": "^7.4.2", diff --git a/src/renderer/src/assets/images/models/embedding.png b/src/renderer/src/assets/images/models/embedding.png new file mode 100644 index 00000000..00e49036 Binary files /dev/null and b/src/renderer/src/assets/images/models/embedding.png differ diff --git a/src/renderer/src/assets/images/models/gemini.png b/src/renderer/src/assets/images/models/gemini.png new file mode 100644 index 00000000..05bce507 Binary files /dev/null and b/src/renderer/src/assets/images/models/gemini.png differ diff --git a/src/renderer/src/assets/images/models/palm.svg b/src/renderer/src/assets/images/models/palm.svg new file mode 100644 index 00000000..5c345fe1 --- /dev/null +++ b/src/renderer/src/assets/images/models/palm.svg @@ -0,0 +1,67 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/renderer/src/assets/images/providers/gemini.png b/src/renderer/src/assets/images/providers/gemini.png new file mode 100644 index 00000000..05bce507 Binary files /dev/null and b/src/renderer/src/assets/images/providers/gemini.png differ diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index ac46e15a..c8569bf8 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -33,6 +33,22 @@ export const SYSTEM_MODELS: Record = { enabled: true } ], + gemini: [ + { + id: 'gemini-1.5-flash', + provider: 'gemini', + name: 'Gemini 1.5 Flash', + group: 'Gemini 1.5', + enabled: true + }, + { + id: 'gemini-1.5-pro-exp-0801', + provider: 'gemini', + name: 'Gemini 1.5 Pro Experimental 0801', + group: 'Gemini 1.5', + enabled: true + } + ], silicon: [ { id: 'Qwen/Qwen2-7B-Instruct', diff --git a/src/renderer/src/config/provider.ts b/src/renderer/src/config/provider.ts index c01e276f..6d912179 100644 --- a/src/renderer/src/config/provider.ts +++ b/src/renderer/src/config/provider.ts @@ -3,10 +3,13 @@ import ChatGLMModelLogo from '@renderer/assets/images/models/chatglm.jpeg' import ChatGPTModelLogo from '@renderer/assets/images/models/chatgpt.jpeg' import ClaudeModelLogo from '@renderer/assets/images/models/claude.png' import DeepSeekModelLogo from '@renderer/assets/images/models/deepseek.png' +import EmbeddingModelLogo from '@renderer/assets/images/models/embedding.png' +import GeminiModelLogo from '@renderer/assets/images/models/gemini.png' import GemmaModelLogo from '@renderer/assets/images/models/gemma.jpeg' import LlamaModelLogo from '@renderer/assets/images/models/llama.jpeg' import MicrosoftModelLogo from '@renderer/assets/images/models/microsoft.png' import MixtralModelLogo from '@renderer/assets/images/models/mixtral.jpeg' +import PalmModelLogo from '@renderer/assets/images/models/palm.svg' import QwenModelLogo from '@renderer/assets/images/models/qwen.png' import YiModelLogo from '@renderer/assets/images/models/yi.svg' import AiHubMixProviderLogo from '@renderer/assets/images/providers/aihubmix.jpg' @@ -14,6 +17,7 @@ import AnthropicProviderLogo from '@renderer/assets/images/providers/anthropic.j import BaichuanProviderLogo from '@renderer/assets/images/providers/baichuan.png' import DashScopeProviderLogo from '@renderer/assets/images/providers/dashscope.png' import DeepSeekProviderLogo from '@renderer/assets/images/providers/deepseek.png' +import GeminiProviderLogo from '@renderer/assets/images/providers/gemini.png' import GroqProviderLogo from '@renderer/assets/images/providers/groq.png' import MoonshotProviderLogo from '@renderer/assets/images/providers/moonshot.jpeg' import MoonshotModelLogo from '@renderer/assets/images/providers/moonshot.jpeg' @@ -52,6 +56,8 @@ export function getProviderLogo(providerId: string) { return AnthropicProviderLogo case 'aihubmix': return AiHubMixProviderLogo + case 'gemini': + return GeminiProviderLogo default: return undefined } @@ -75,7 +81,11 @@ export function getModelLogo(modelId: string) { moonshot: MoonshotModelLogo, phi: MicrosoftModelLogo, baichuan: BaichuanModelLogo, - claude: ClaudeModelLogo + claude: ClaudeModelLogo, + gemini: GeminiModelLogo, + embedding: EmbeddingModelLogo, + bison: PalmModelLogo, + palm: PalmModelLogo } for (const key in logoMap) { @@ -242,5 +252,17 @@ export const PROVIDER_CONFIG = { docs: 'https://doc.aihubmix.com/', models: 'https://aihubmix.com/models' } + }, + gemini: { + api: { + url: 'https://generativelanguage.googleapis.com', + editable: false + }, + websites: { + official: 'https://gemini.google.com/', + apiKey: 'https://aistudio.google.com/app/apikey', + docs: 'https://ai.google.dev/gemini-api/docs', + models: 'https://ai.google.dev/gemini-api/docs/models/gemini' + } } } diff --git a/src/renderer/src/i18n/index.ts b/src/renderer/src/i18n/index.ts index 0d0a7641..fe22c8c0 100644 --- a/src/renderer/src/i18n/index.ts +++ b/src/renderer/src/i18n/index.ts @@ -106,6 +106,7 @@ const resources = { }, provider: { openai: 'OpenAI', + gemini: 'Gemini', deepseek: 'DeepSeek', moonshot: 'Moonshot', silicon: 'SiliconFlow', @@ -323,6 +324,7 @@ const resources = { }, provider: { openai: 'OpenAI', + gemini: 'Gemini', deepseek: '深度求索', moonshot: '月之暗面', silicon: '硅基流动', diff --git a/src/renderer/src/services/ProviderSDK.ts b/src/renderer/src/services/ProviderSDK.ts index fe17ae41..76df0640 100644 --- a/src/renderer/src/services/ProviderSDK.ts +++ b/src/renderer/src/services/ProviderSDK.ts @@ -1,10 +1,12 @@ import Anthropic from '@anthropic-ai/sdk' import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources' +import { GoogleGenerativeAI } from '@google/generative-ai' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { getOllamaKeepAliveTime } from '@renderer/hooks/useOllama' import { Assistant, Message, Provider, Suggestion } from '@renderer/types' import { removeQuotes } from '@renderer/utils' -import { sum, takeRight } from 'lodash' +import axios from 'axios' +import { isEmpty, sum, takeRight } from 'lodash' import OpenAI from 'openai' import { ChatCompletionCreateParamsNonStreaming, ChatCompletionMessageParam } from 'openai/resources' @@ -15,6 +17,7 @@ export default class ProviderSDK { provider: Provider openaiSdk: OpenAI anthropicSdk: Anthropic + geminiSdk: GoogleGenerativeAI constructor(provider: Provider) { this.provider = provider @@ -22,12 +25,17 @@ export default class ProviderSDK { const baseURL = host.endsWith('/') ? host : `${provider.apiHost}/v1/` this.anthropicSdk = new Anthropic({ apiKey: provider.apiKey, baseURL }) this.openaiSdk = new OpenAI({ dangerouslyAllowBrowser: true, apiKey: provider.apiKey, baseURL }) + this.geminiSdk = new GoogleGenerativeAI(provider.apiKey) } private get isAnthropic() { return this.provider.id === 'anthropic' } + private get isGemini() { + return this.provider.id === 'gemini' + } + private get keepAliveTime() { return this.provider.id === 'ollama' ? getOllamaKeepAliveTime() : undefined } @@ -42,7 +50,6 @@ export default class ProviderSDK { const { contextCount, maxTokens } = getAssistantSettings(assistant) const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined - const userMessages = takeRight(messages, contextCount + 1).map((message) => ({ role: message.role, content: message.content @@ -66,25 +73,64 @@ export default class ProviderSDK { } }) ) - } else { - // @ts-ignore key is not typed - const stream = await this.openaiSdk.chat.completions.create({ + return + } + + if (this.isGemini) { + const geminiModel = this.geminiSdk.getGenerativeModel({ model: model.id, - messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[], - stream: true, - temperature: assistant?.settings?.temperature, - max_tokens: maxTokens, - keep_alive: this.keepAliveTime + systemInstruction: assistant.prompt, + generationConfig: { + maxOutputTokens: maxTokens, + temperature: assistant?.settings?.temperature + } }) - for await (const chunk of stream) { + + const userLastMessage = userMessages.pop() + + const chat = geminiModel.startChat({ + history: userMessages.map((message) => ({ + role: message.role === 'user' ? 'user' : 'model', + parts: [{ text: message.content }] + })) + }) + + const userMessagesStream = await chat.sendMessageStream(userLastMessage?.content!) + + for await (const chunk of userMessagesStream.stream) { if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break - onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage }) + onChunk({ + text: chunk.text(), + usage: { + prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0, + completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0, + total_tokens: chunk.usageMetadata?.totalTokenCount || 0 + } + }) } + + return + } + + // @ts-ignore key is not typed + const stream = await this.openaiSdk.chat.completions.create({ + model: model.id, + messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[], + stream: true, + temperature: assistant?.settings?.temperature, + max_tokens: maxTokens, + keep_alive: this.keepAliveTime + }) + + for await (const chunk of stream) { + if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break + onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage }) } } public async translate(message: Message, assistant: Assistant) { const defaultModel = getDefaultModel() + const { maxTokens } = getAssistantSettings(assistant) const model = assistant.model || defaultModel const messages = [ { role: 'system', content: assistant.prompt }, @@ -99,17 +145,34 @@ export default class ProviderSDK { temperature: assistant?.settings?.temperature, stream: false }) + return response.content[0].type === 'text' ? response.content[0].text : '' - } else { - // @ts-ignore key is not typed - const response = await this.openaiSdk.chat.completions.create({ - model: model.id, - messages: messages as ChatCompletionMessageParam[], - stream: false, - keep_alive: this.keepAliveTime - }) - return response.choices[0].message?.content || '' } + + if (this.isGemini) { + const geminiModel = this.geminiSdk.getGenerativeModel({ + model: model.id, + systemInstruction: assistant.prompt, + generationConfig: { + maxOutputTokens: maxTokens, + temperature: assistant?.settings?.temperature + } + }) + + const { response } = await geminiModel.generateContent(message.content) + + return response.text() + } + + // @ts-ignore key is not typed + const response = await this.openaiSdk.chat.completions.create({ + model: model.id, + messages: messages as ChatCompletionMessageParam[], + stream: false, + keep_alive: this.keepAliveTime + }) + + return response.choices[0].message?.content || '' } public async summaries(messages: Message[], assistant: Assistant): Promise { @@ -134,18 +197,41 @@ export default class ProviderSDK { }) return message.content[0].type === 'text' ? message.content[0].text : null - } else { - // @ts-ignore key is not typed - const response = await this.openaiSdk.chat.completions.create({ + } + + if (this.isGemini) { + const geminiModel = this.geminiSdk.getGenerativeModel({ model: model.id, - messages: [systemMessage, ...userMessages] as ChatCompletionMessageParam[], - stream: false, - max_tokens: 50, - keep_alive: this.keepAliveTime + systemInstruction: systemMessage.content, + generationConfig: { + temperature: assistant?.settings?.temperature + } }) - return removeQuotes(response.choices[0].message?.content || '') + const lastUserMessage = userMessages.pop() + + const chat = await geminiModel.startChat({ + history: userMessages.map((message) => ({ + role: message.role === 'user' ? 'user' : 'model', + parts: [{ text: message.content }] + })) + }) + + const { response } = await chat.sendMessage(lastUserMessage?.content!) + + return response.text() } + + // @ts-ignore key is not typed + const response = await this.openaiSdk.chat.completions.create({ + model: model.id, + messages: [systemMessage, ...userMessages] as ChatCompletionMessageParam[], + stream: false, + max_tokens: 50, + keep_alive: this.keepAliveTime + }) + + return removeQuotes(response.choices[0].message?.content || '') } public async suggestions(messages: Message[], assistant: Assistant): Promise { @@ -172,6 +258,7 @@ export default class ProviderSDK { public async check(): Promise<{ valid: boolean; error: Error | null }> { const model = this.provider.models[0] + const body = { model: model.id, messages: [{ role: 'user', content: 'hi' }], @@ -182,13 +269,32 @@ export default class ProviderSDK { try { if (this.isAnthropic) { const message = await this.anthropicSdk.messages.create(body as MessageCreateParamsNonStreaming) - return { valid: message.content.length > 0, error: null } - } else { - const response = await this.openaiSdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming) - return { valid: Boolean(response?.choices[0].message), error: null } + return { + valid: message.content.length > 0, + error: null + } + } + + if (this.isGemini) { + const geminiModel = this.geminiSdk.getGenerativeModel({ model: body.model }) + const result = await geminiModel.generateContent(body.messages[0].content) + return { + valid: !isEmpty(result.response.text()), + error: null + } + } + + const response = await this.openaiSdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming) + + return { + valid: Boolean(response?.choices[0].message), + error: null } } catch (error: any) { - return { valid: false, error } + return { + valid: false, + error + } } } @@ -198,6 +304,22 @@ export default class ProviderSDK { return [] } + if (this.isGemini) { + const api = this.provider.apiHost + '/v1beta/models' + const { data } = await axios.get(api, { params: { key: this.provider.apiKey } }) + return data.models.map( + (m: any) => + ({ + id: m.name.replace('models/', ''), + name: m.displayName, + description: m.description, + object: 'model', + created: Date.now(), + owned_by: 'gemini' + }) as OpenAI.Models.Model + ) + } + const response = await this.openaiSdk.models.list() return response.data } catch (error) { diff --git a/src/renderer/src/store/index.ts b/src/renderer/src/store/index.ts index 8084636a..a0766664 100644 --- a/src/renderer/src/store/index.ts +++ b/src/renderer/src/store/index.ts @@ -22,7 +22,7 @@ const persistedReducer = persistReducer( { key: 'cherry-studio', storage, - version: 20, + version: 21, blacklist: ['runtime'], migrate }, diff --git a/src/renderer/src/store/llm.ts b/src/renderer/src/store/llm.ts index b2a848b2..a8ad36cc 100644 --- a/src/renderer/src/store/llm.ts +++ b/src/renderer/src/store/llm.ts @@ -31,6 +31,15 @@ const initialState: LlmState = { isSystem: true, enabled: true }, + { + id: 'gemini', + name: 'Gemini', + apiKey: '', + apiHost: 'https://generativelanguage.googleapis.com', + models: SYSTEM_MODELS.gemini.filter((m) => m.enabled), + isSystem: true, + enabled: false + }, { id: 'silicon', name: 'Silicon', diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts index d49c863f..b678ab6a 100644 --- a/src/renderer/src/store/migrate.ts +++ b/src/renderer/src/store/migrate.ts @@ -296,6 +296,26 @@ const migrateConfig = { fontSize: 14 } } + }, + '21': (state: RootState) => { + return { + ...state, + llm: { + ...state.llm, + providers: [ + ...state.llm.providers, + { + id: 'gemini', + name: 'Gemini', + apiKey: '', + apiHost: 'https://generativelanguage.googleapis.com', + models: SYSTEM_MODELS.gemini.filter((m) => m.enabled), + isSystem: true, + enabled: false + } + ] + } + } } } diff --git a/yarn.lock b/yarn.lock index 49b8f8c6..c1b5f01b 100644 --- a/yarn.lock +++ b/yarn.lock @@ -961,6 +961,13 @@ __metadata: languageName: node linkType: hard +"@google/generative-ai@npm:^0.16.0": + version: 0.16.0 + resolution: "@google/generative-ai@npm:0.16.0" + checksum: 10c0/5d561a41cb7be60fc9b49965b66359e15df907bf6679009de7917beff138ba69d4a0772ab2a9d6f0e543d658d72bd19b83e6abdb87a6cdfa402a8764b08eed4c + languageName: node + linkType: hard + "@hello-pangea/dnd@npm:^16.6.0": version: 16.6.0 resolution: "@hello-pangea/dnd@npm:16.6.0" @@ -3099,6 +3106,17 @@ __metadata: languageName: node linkType: hard +"axios@npm:^1.7.3": + version: 1.7.3 + resolution: "axios@npm:1.7.3" + dependencies: + follow-redirects: "npm:^1.15.6" + form-data: "npm:^4.0.0" + proxy-from-env: "npm:^1.1.0" + checksum: 10c0/a18cbe559203efa05fb1fec2d1898e23bf6329bd2575784ee32aa11b5bbe1d54b9f472c49a261294125519cf62aa4fe5ef6e647bb7482eafc15bffe15ab314ce + languageName: node + linkType: hard + "bail@npm:^2.0.0": version: 2.0.2 resolution: "bail@npm:2.0.2" @@ -3446,6 +3464,7 @@ __metadata: "@electron-toolkit/preload": "npm:^3.0.0" "@electron-toolkit/tsconfig": "npm:^1.0.1" "@electron-toolkit/utils": "npm:^3.0.0" + "@google/generative-ai": "npm:^0.16.0" "@hello-pangea/dnd": "npm:^16.6.0" "@kangfenmao/keyv-storage": "npm:^0.1.0" "@reduxjs/toolkit": "npm:^2.2.5" @@ -3457,6 +3476,7 @@ __metadata: "@vitejs/plugin-react": "npm:^4.2.1" ahooks: "npm:^3.8.0" antd: "npm:^5.18.3" + axios: "npm:^1.7.3" browser-image-compression: "npm:^2.0.2" dayjs: "npm:^1.11.11" dotenv-cli: "npm:^7.4.2" @@ -5037,6 +5057,16 @@ __metadata: languageName: node linkType: hard +"follow-redirects@npm:^1.15.6": + version: 1.15.6 + resolution: "follow-redirects@npm:1.15.6" + peerDependenciesMeta: + debug: + optional: true + checksum: 10c0/9ff767f0d7be6aa6870c82ac79cf0368cd73e01bbc00e9eb1c2a16fbb198ec105e3c9b6628bb98e9f3ac66fe29a957b9645bcb9a490bb7aa0d35f908b6b85071 + languageName: node + linkType: hard + "for-each@npm:^0.3.3": version: 0.3.3 resolution: "for-each@npm:0.3.3" @@ -8251,6 +8281,13 @@ __metadata: languageName: node linkType: hard +"proxy-from-env@npm:^1.1.0": + version: 1.1.0 + resolution: "proxy-from-env@npm:1.1.0" + checksum: 10c0/fe7dd8b1bdbbbea18d1459107729c3e4a2243ca870d26d34c2c1bcd3e4425b7bcc5112362df2d93cc7fb9746f6142b5e272fd1cc5c86ddf8580175186f6ad42b + languageName: node + linkType: hard + "pump@npm:^3.0.0": version: 3.0.0 resolution: "pump@npm:3.0.0"