diff --git a/package.json b/package.json
index c3430d2d..5d444bb5 100644
--- a/package.json
+++ b/package.json
@@ -37,6 +37,7 @@
"@electron-toolkit/eslint-config-prettier": "^2.0.0",
"@electron-toolkit/eslint-config-ts": "^1.0.1",
"@electron-toolkit/tsconfig": "^1.0.1",
+ "@google/generative-ai": "^0.16.0",
"@hello-pangea/dnd": "^16.6.0",
"@kangfenmao/keyv-storage": "^0.1.0",
"@reduxjs/toolkit": "^2.2.5",
@@ -47,6 +48,7 @@
"@vitejs/plugin-react": "^4.2.1",
"ahooks": "^3.8.0",
"antd": "^5.18.3",
+ "axios": "^1.7.3",
"browser-image-compression": "^2.0.2",
"dayjs": "^1.11.11",
"dotenv-cli": "^7.4.2",
diff --git a/src/renderer/src/assets/images/models/embedding.png b/src/renderer/src/assets/images/models/embedding.png
new file mode 100644
index 00000000..00e49036
Binary files /dev/null and b/src/renderer/src/assets/images/models/embedding.png differ
diff --git a/src/renderer/src/assets/images/models/gemini.png b/src/renderer/src/assets/images/models/gemini.png
new file mode 100644
index 00000000..05bce507
Binary files /dev/null and b/src/renderer/src/assets/images/models/gemini.png differ
diff --git a/src/renderer/src/assets/images/models/palm.svg b/src/renderer/src/assets/images/models/palm.svg
new file mode 100644
index 00000000..5c345fe1
--- /dev/null
+++ b/src/renderer/src/assets/images/models/palm.svg
@@ -0,0 +1,67 @@
+ [SVG markup for the PaLM logo (67 added lines) omitted]
diff --git a/src/renderer/src/assets/images/providers/gemini.png b/src/renderer/src/assets/images/providers/gemini.png
new file mode 100644
index 00000000..05bce507
Binary files /dev/null and b/src/renderer/src/assets/images/providers/gemini.png differ
diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts
index ac46e15a..c8569bf8 100644
--- a/src/renderer/src/config/models.ts
+++ b/src/renderer/src/config/models.ts
@@ -33,6 +33,22 @@ export const SYSTEM_MODELS: Record<string, Model[]> = {
enabled: true
}
],
+ gemini: [
+ {
+ id: 'gemini-1.5-flash',
+ provider: 'gemini',
+ name: 'Gemini 1.5 Flash',
+ group: 'Gemini 1.5',
+ enabled: true
+ },
+ {
+ id: 'gemini-1.5-pro-exp-0801',
+ provider: 'gemini',
+ name: 'Gemini 1.5 Pro Experimental 0801',
+ group: 'Gemini 1.5',
+ enabled: true
+ }
+ ],
silicon: [
{
id: 'Qwen/Qwen2-7B-Instruct',
diff --git a/src/renderer/src/config/provider.ts b/src/renderer/src/config/provider.ts
index c01e276f..6d912179 100644
--- a/src/renderer/src/config/provider.ts
+++ b/src/renderer/src/config/provider.ts
@@ -3,10 +3,13 @@ import ChatGLMModelLogo from '@renderer/assets/images/models/chatglm.jpeg'
import ChatGPTModelLogo from '@renderer/assets/images/models/chatgpt.jpeg'
import ClaudeModelLogo from '@renderer/assets/images/models/claude.png'
import DeepSeekModelLogo from '@renderer/assets/images/models/deepseek.png'
+import EmbeddingModelLogo from '@renderer/assets/images/models/embedding.png'
+import GeminiModelLogo from '@renderer/assets/images/models/gemini.png'
import GemmaModelLogo from '@renderer/assets/images/models/gemma.jpeg'
import LlamaModelLogo from '@renderer/assets/images/models/llama.jpeg'
import MicrosoftModelLogo from '@renderer/assets/images/models/microsoft.png'
import MixtralModelLogo from '@renderer/assets/images/models/mixtral.jpeg'
+import PalmModelLogo from '@renderer/assets/images/models/palm.svg'
import QwenModelLogo from '@renderer/assets/images/models/qwen.png'
import YiModelLogo from '@renderer/assets/images/models/yi.svg'
import AiHubMixProviderLogo from '@renderer/assets/images/providers/aihubmix.jpg'
@@ -14,6 +17,7 @@ import AnthropicProviderLogo from '@renderer/assets/images/providers/anthropic.j
import BaichuanProviderLogo from '@renderer/assets/images/providers/baichuan.png'
import DashScopeProviderLogo from '@renderer/assets/images/providers/dashscope.png'
import DeepSeekProviderLogo from '@renderer/assets/images/providers/deepseek.png'
+import GeminiProviderLogo from '@renderer/assets/images/providers/gemini.png'
import GroqProviderLogo from '@renderer/assets/images/providers/groq.png'
import MoonshotProviderLogo from '@renderer/assets/images/providers/moonshot.jpeg'
import MoonshotModelLogo from '@renderer/assets/images/providers/moonshot.jpeg'
@@ -52,6 +56,8 @@ export function getProviderLogo(providerId: string) {
return AnthropicProviderLogo
case 'aihubmix':
return AiHubMixProviderLogo
+ case 'gemini':
+ return GeminiProviderLogo
default:
return undefined
}
@@ -75,7 +81,11 @@ export function getModelLogo(modelId: string) {
moonshot: MoonshotModelLogo,
phi: MicrosoftModelLogo,
baichuan: BaichuanModelLogo,
- claude: ClaudeModelLogo
+ claude: ClaudeModelLogo,
+ gemini: GeminiModelLogo,
+ embedding: EmbeddingModelLogo,
+ bison: PalmModelLogo,
+ palm: PalmModelLogo
}
for (const key in logoMap) {
@@ -242,5 +252,17 @@ export const PROVIDER_CONFIG = {
docs: 'https://doc.aihubmix.com/',
models: 'https://aihubmix.com/models'
}
+ },
+ gemini: {
+ api: {
+ url: 'https://generativelanguage.googleapis.com',
+ editable: false
+ },
+ websites: {
+ official: 'https://gemini.google.com/',
+ apiKey: 'https://aistudio.google.com/app/apikey',
+ docs: 'https://ai.google.dev/gemini-api/docs',
+ models: 'https://ai.google.dev/gemini-api/docs/models/gemini'
+ }
}
}
diff --git a/src/renderer/src/i18n/index.ts b/src/renderer/src/i18n/index.ts
index 0d0a7641..fe22c8c0 100644
--- a/src/renderer/src/i18n/index.ts
+++ b/src/renderer/src/i18n/index.ts
@@ -106,6 +106,7 @@ const resources = {
},
provider: {
openai: 'OpenAI',
+ gemini: 'Gemini',
deepseek: 'DeepSeek',
moonshot: 'Moonshot',
silicon: 'SiliconFlow',
@@ -323,6 +324,7 @@ const resources = {
},
provider: {
openai: 'OpenAI',
+ gemini: 'Gemini',
deepseek: '深度求索',
moonshot: '月之暗面',
silicon: '硅基流动',
diff --git a/src/renderer/src/services/ProviderSDK.ts b/src/renderer/src/services/ProviderSDK.ts
index fe17ae41..76df0640 100644
--- a/src/renderer/src/services/ProviderSDK.ts
+++ b/src/renderer/src/services/ProviderSDK.ts
@@ -1,10 +1,12 @@
import Anthropic from '@anthropic-ai/sdk'
import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources'
+import { GoogleGenerativeAI } from '@google/generative-ai'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { getOllamaKeepAliveTime } from '@renderer/hooks/useOllama'
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
import { removeQuotes } from '@renderer/utils'
-import { sum, takeRight } from 'lodash'
+import axios from 'axios'
+import { isEmpty, sum, takeRight } from 'lodash'
import OpenAI from 'openai'
import { ChatCompletionCreateParamsNonStreaming, ChatCompletionMessageParam } from 'openai/resources'
@@ -15,6 +17,7 @@ export default class ProviderSDK {
provider: Provider
openaiSdk: OpenAI
anthropicSdk: Anthropic
+ geminiSdk: GoogleGenerativeAI
constructor(provider: Provider) {
this.provider = provider
@@ -22,12 +25,17 @@ export default class ProviderSDK {
const baseURL = host.endsWith('/') ? host : `${provider.apiHost}/v1/`
this.anthropicSdk = new Anthropic({ apiKey: provider.apiKey, baseURL })
this.openaiSdk = new OpenAI({ dangerouslyAllowBrowser: true, apiKey: provider.apiKey, baseURL })
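+ // Unlike the OpenAI/Anthropic clients, the Gemini SDK takes only the API key;
+ // its endpoint is fixed (PROVIDER_CONFIG below marks the URL as non-editable).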
+ this.geminiSdk = new GoogleGenerativeAI(provider.apiKey)
}
private get isAnthropic() {
return this.provider.id === 'anthropic'
}
+ private get isGemini() {
+ return this.provider.id === 'gemini'
+ }
+
private get keepAliveTime() {
return this.provider.id === 'ollama' ? getOllamaKeepAliveTime() : undefined
}
@@ -42,7 +50,6 @@ export default class ProviderSDK {
const { contextCount, maxTokens } = getAssistantSettings(assistant)
const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
-
const userMessages = takeRight(messages, contextCount + 1).map((message) => ({
role: message.role,
content: message.content
@@ -66,25 +73,64 @@ export default class ProviderSDK {
}
})
)
- } else {
- // @ts-ignore key is not typed
- const stream = await this.openaiSdk.chat.completions.create({
+ return
+ }
+
+ if (this.isGemini) {
+ const geminiModel = this.geminiSdk.getGenerativeModel({
model: model.id,
- messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[],
- stream: true,
- temperature: assistant?.settings?.temperature,
- max_tokens: maxTokens,
- keep_alive: this.keepAliveTime
+ systemInstruction: assistant.prompt,
+ generationConfig: {
+ maxOutputTokens: maxTokens,
+ temperature: assistant?.settings?.temperature
+ }
})
- for await (const chunk of stream) {
+
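+ // Gemini takes prior turns as chat history plus one new message, so split
+ // off the final user turn here and replay the rest as history below.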
+ const userLastMessage = userMessages.pop()
+
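+ // Gemini only knows 'user' and 'model' roles; assistant turns map to
+ // 'model', and the system prompt travels via systemInstruction above.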
+ const chat = geminiModel.startChat({
+ history: userMessages.map((message) => ({
+ role: message.role === 'user' ? 'user' : 'model',
+ parts: [{ text: message.content }]
+ }))
+ })
+
+ const userMessagesStream = await chat.sendMessageStream(userLastMessage?.content!)
+
+ for await (const chunk of userMessagesStream.stream) {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
- onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage })
+ onChunk({
+ text: chunk.text(),
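+ // Adapt Gemini's usageMetadata to the OpenAI-style usage shape
+ // the rest of the app expects.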
+ usage: {
+ prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
+ completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
+ total_tokens: chunk.usageMetadata?.totalTokenCount || 0
+ }
+ })
}
+
+ return
+ }
+
+ // @ts-ignore key is not typed
+ const stream = await this.openaiSdk.chat.completions.create({
+ model: model.id,
+ messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[],
+ stream: true,
+ temperature: assistant?.settings?.temperature,
+ max_tokens: maxTokens,
+ keep_alive: this.keepAliveTime
+ })
+
+ for await (const chunk of stream) {
+ if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
+ onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage })
}
}
public async translate(message: Message, assistant: Assistant) {
const defaultModel = getDefaultModel()
+ const { maxTokens } = getAssistantSettings(assistant)
const model = assistant.model || defaultModel
const messages = [
{ role: 'system', content: assistant.prompt },
@@ -99,17 +145,34 @@ export default class ProviderSDK {
temperature: assistant?.settings?.temperature,
stream: false
})
+
return response.content[0].type === 'text' ? response.content[0].text : ''
- } else {
- // @ts-ignore key is not typed
- const response = await this.openaiSdk.chat.completions.create({
- model: model.id,
- messages: messages as ChatCompletionMessageParam[],
- stream: false,
- keep_alive: this.keepAliveTime
- })
- return response.choices[0].message?.content || ''
}
+
+ if (this.isGemini) {
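+ // Translation is single-turn, so a plain generateContent call is enough;
+ // the translate prompt rides along as systemInstruction.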
+ const geminiModel = this.geminiSdk.getGenerativeModel({
+ model: model.id,
+ systemInstruction: assistant.prompt,
+ generationConfig: {
+ maxOutputTokens: maxTokens,
+ temperature: assistant?.settings?.temperature
+ }
+ })
+
+ const { response } = await geminiModel.generateContent(message.content)
+
+ return response.text()
+ }
+
+ // @ts-ignore key is not typed
+ const response = await this.openaiSdk.chat.completions.create({
+ model: model.id,
+ messages: messages as ChatCompletionMessageParam[],
+ stream: false,
+ keep_alive: this.keepAliveTime
+ })
+
+ return response.choices[0].message?.content || ''
}
public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
@@ -134,18 +197,41 @@ export default class ProviderSDK {
})
return message.content[0].type === 'text' ? message.content[0].text : null
- } else {
- // @ts-ignore key is not typed
- const response = await this.openaiSdk.chat.completions.create({
+ }
+
+ if (this.isGemini) {
+ const geminiModel = this.geminiSdk.getGenerativeModel({
model: model.id,
- messages: [systemMessage, ...userMessages] as ChatCompletionMessageParam[],
- stream: false,
- max_tokens: 50,
- keep_alive: this.keepAliveTime
+ systemInstruction: systemMessage.content,
+ generationConfig: {
+ temperature: assistant?.settings?.temperature
+ }
})
- return removeQuotes(response.choices[0].message?.content || '')
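+ // Same pattern as completions: send the last user message on its own
+ // and replay earlier turns as chat history.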
+ const lastUserMessage = userMessages.pop()
+
+ const chat = geminiModel.startChat({
+ history: userMessages.map((message) => ({
+ role: message.role === 'user' ? 'user' : 'model',
+ parts: [{ text: message.content }]
+ }))
+ })
+
+ const { response } = await chat.sendMessage(lastUserMessage?.content!)
+
+ return response.text()
}
+
+ // @ts-ignore key is not typed
+ const response = await this.openaiSdk.chat.completions.create({
+ model: model.id,
+ messages: [systemMessage, ...userMessages] as ChatCompletionMessageParam[],
+ stream: false,
+ max_tokens: 50,
+ keep_alive: this.keepAliveTime
+ })
+
+ return removeQuotes(response.choices[0].message?.content || '')
}
public async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
@@ -172,6 +258,7 @@ export default class ProviderSDK {
public async check(): Promise<{ valid: boolean; error: Error | null }> {
const model = this.provider.models[0]
+
const body = {
model: model.id,
messages: [{ role: 'user', content: 'hi' }],
@@ -182,13 +269,32 @@ export default class ProviderSDK {
try {
if (this.isAnthropic) {
const message = await this.anthropicSdk.messages.create(body as MessageCreateParamsNonStreaming)
- return { valid: message.content.length > 0, error: null }
- } else {
- const response = await this.openaiSdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming)
- return { valid: Boolean(response?.choices[0].message), error: null }
+ return {
+ valid: message.content.length > 0,
+ error: null
+ }
+ }
+
+ if (this.isGemini) {
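+ // Minimal liveness probe: one generateContent call; a non-empty
+ // reply means the key and model are usable.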
+ const geminiModel = this.geminiSdk.getGenerativeModel({ model: body.model })
+ const result = await geminiModel.generateContent(body.messages[0].content)
+ return {
+ valid: !isEmpty(result.response.text()),
+ error: null
+ }
+ }
+
+ const response = await this.openaiSdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming)
+
+ return {
+ valid: Boolean(response?.choices[0].message),
+ error: null
}
} catch (error: any) {
- return { valid: false, error }
+ return {
+ valid: false,
+ error
+ }
}
}
@@ -198,6 +304,22 @@ export default class ProviderSDK {
return []
}
+ if (this.isGemini) {
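+ // The SDK has no model-listing helper, so query the REST endpoint
+ // directly and adapt each entry to the OpenAI Model shape.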
+ const api = this.provider.apiHost + '/v1beta/models'
+ const { data } = await axios.get(api, { params: { key: this.provider.apiKey } })
+ return data.models.map(
+ (m: any) =>
+ ({
+ id: m.name.replace('models/', ''),
+ name: m.displayName,
+ description: m.description,
+ object: 'model',
+ created: Date.now(),
+ owned_by: 'gemini'
+ }) as OpenAI.Models.Model
+ )
+ }
+
const response = await this.openaiSdk.models.list()
return response.data
} catch (error) {
diff --git a/src/renderer/src/store/index.ts b/src/renderer/src/store/index.ts
index 8084636a..a0766664 100644
--- a/src/renderer/src/store/index.ts
+++ b/src/renderer/src/store/index.ts
@@ -22,7 +22,7 @@ const persistedReducer = persistReducer(
{
key: 'cherry-studio',
storage,
- version: 20,
+ version: 21,
blacklist: ['runtime'],
migrate
},
diff --git a/src/renderer/src/store/llm.ts b/src/renderer/src/store/llm.ts
index b2a848b2..a8ad36cc 100644
--- a/src/renderer/src/store/llm.ts
+++ b/src/renderer/src/store/llm.ts
@@ -31,6 +31,15 @@ const initialState: LlmState = {
isSystem: true,
enabled: true
},
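+ // New built-in provider; ships disabled until the user supplies an API key.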
+ {
+ id: 'gemini',
+ name: 'Gemini',
+ apiKey: '',
+ apiHost: 'https://generativelanguage.googleapis.com',
+ models: SYSTEM_MODELS.gemini.filter((m) => m.enabled),
+ isSystem: true,
+ enabled: false
+ },
{
id: 'silicon',
name: 'Silicon',
diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts
index d49c863f..b678ab6a 100644
--- a/src/renderer/src/store/migrate.ts
+++ b/src/renderer/src/store/migrate.ts
@@ -296,6 +296,26 @@ const migrateConfig = {
fontSize: 14
}
}
+ },
+ '21': (state: RootState) => {
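+ // Adds the Gemini provider to existing installs; mirrors the new
+ // initialState entry in store/llm.ts.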
+ return {
+ ...state,
+ llm: {
+ ...state.llm,
+ providers: [
+ ...state.llm.providers,
+ {
+ id: 'gemini',
+ name: 'Gemini',
+ apiKey: '',
+ apiHost: 'https://generativelanguage.googleapis.com',
+ models: SYSTEM_MODELS.gemini.filter((m) => m.enabled),
+ isSystem: true,
+ enabled: false
+ }
+ ]
+ }
+ }
}
}
diff --git a/yarn.lock b/yarn.lock
index 49b8f8c6..c1b5f01b 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -961,6 +961,13 @@ __metadata:
languageName: node
linkType: hard
+"@google/generative-ai@npm:^0.16.0":
+ version: 0.16.0
+ resolution: "@google/generative-ai@npm:0.16.0"
+ checksum: 10c0/5d561a41cb7be60fc9b49965b66359e15df907bf6679009de7917beff138ba69d4a0772ab2a9d6f0e543d658d72bd19b83e6abdb87a6cdfa402a8764b08eed4c
+ languageName: node
+ linkType: hard
+
"@hello-pangea/dnd@npm:^16.6.0":
version: 16.6.0
resolution: "@hello-pangea/dnd@npm:16.6.0"
@@ -3099,6 +3106,17 @@ __metadata:
languageName: node
linkType: hard
+"axios@npm:^1.7.3":
+ version: 1.7.3
+ resolution: "axios@npm:1.7.3"
+ dependencies:
+ follow-redirects: "npm:^1.15.6"
+ form-data: "npm:^4.0.0"
+ proxy-from-env: "npm:^1.1.0"
+ checksum: 10c0/a18cbe559203efa05fb1fec2d1898e23bf6329bd2575784ee32aa11b5bbe1d54b9f472c49a261294125519cf62aa4fe5ef6e647bb7482eafc15bffe15ab314ce
+ languageName: node
+ linkType: hard
+
"bail@npm:^2.0.0":
version: 2.0.2
resolution: "bail@npm:2.0.2"
@@ -3446,6 +3464,7 @@ __metadata:
"@electron-toolkit/preload": "npm:^3.0.0"
"@electron-toolkit/tsconfig": "npm:^1.0.1"
"@electron-toolkit/utils": "npm:^3.0.0"
+ "@google/generative-ai": "npm:^0.16.0"
"@hello-pangea/dnd": "npm:^16.6.0"
"@kangfenmao/keyv-storage": "npm:^0.1.0"
"@reduxjs/toolkit": "npm:^2.2.5"
@@ -3457,6 +3476,7 @@ __metadata:
"@vitejs/plugin-react": "npm:^4.2.1"
ahooks: "npm:^3.8.0"
antd: "npm:^5.18.3"
+ axios: "npm:^1.7.3"
browser-image-compression: "npm:^2.0.2"
dayjs: "npm:^1.11.11"
dotenv-cli: "npm:^7.4.2"
@@ -5037,6 +5057,16 @@ __metadata:
languageName: node
linkType: hard
+"follow-redirects@npm:^1.15.6":
+ version: 1.15.6
+ resolution: "follow-redirects@npm:1.15.6"
+ peerDependenciesMeta:
+ debug:
+ optional: true
+ checksum: 10c0/9ff767f0d7be6aa6870c82ac79cf0368cd73e01bbc00e9eb1c2a16fbb198ec105e3c9b6628bb98e9f3ac66fe29a957b9645bcb9a490bb7aa0d35f908b6b85071
+ languageName: node
+ linkType: hard
+
"for-each@npm:^0.3.3":
version: 0.3.3
resolution: "for-each@npm:0.3.3"
@@ -8251,6 +8281,13 @@ __metadata:
languageName: node
linkType: hard
+"proxy-from-env@npm:^1.1.0":
+ version: 1.1.0
+ resolution: "proxy-from-env@npm:1.1.0"
+ checksum: 10c0/fe7dd8b1bdbbbea18d1459107729c3e4a2243ca870d26d34c2c1bcd3e4425b7bcc5112362df2d93cc7fb9746f6142b5e272fd1cc5c86ddf8580175186f6ad42b
+ languageName: node
+ linkType: hard
+
"pump@npm:^3.0.0":
version: 3.0.0
resolution: "pump@npm:3.0.0"