refactor(Gemini): migrate generative-ai sdk to genai sdk (#4939)

* refactor(GeminiService): migrate to new Google GenAI SDK and update file handling methods

- Updated import statements to use the new Google GenAI SDK.
- Refactored file upload, retrieval, and deletion methods to align with the new SDK's API.
- Adjusted type definitions and response handling for improved type safety and clarity.
- Enhanced file listing and processing to use async iteration over the SDK's paginated file list (see the sketch after this list).
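
For reference, a minimal sketch of the new listing pattern (mirroring the GeminiService change in the diff below; listActiveFiles is an illustrative name, not code from the commit):

import { File, FileState, GoogleGenAI } from '@google/genai'

// Collect ACTIVE files by async-iterating the SDK's Pager, which fetches
// pages lazily instead of returning one bulk listFiles() response.
async function listActiveFiles(apiKey: string): Promise<File[]> {
  const sdk = new GoogleGenAI({ vertexai: false, apiKey })
  const files: File[] = []
  for await (const f of await sdk.files.list()) {
    if (f.state === FileState.ACTIVE) files.push(f)
  }
  return files
}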

* refactor(GeminiProvider): update message handling and integrate abort signal support

- Refactored message content handling to align with the updated type definitions, ensuring consistent use of the Content type.
- Enhanced abort signal handling so in-flight chat requests can be cancelled cleanly (see the signal-proxy sketch after this list).
- Improved message processing logic to streamline user message history handling and response generation.
- Adjusted type definitions for message contents to enhance type safety and clarity.
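
A condensed sketch of the signal bridge this adds in GeminiProvider (the proxy forwards only 'abort' listeners to the SDK through httpOptions.signal; the plain AbortController here stands in for the one returned by the provider's createAbortController helper):

// Bridge an AbortController to the listener shape the GenAI SDK consumes.
const abortController = new AbortController()
const signalProxy = {
  addEventListener(eventName: string, listener: () => void) {
    if (eventName === 'abort') abortController.signal.addEventListener('abort', listener)
  },
  removeEventListener(eventName: string, listener: () => void) {
    if (eventName === 'abort') abortController.signal.removeEventListener('abort', listener)
  },
  get aborted() {
    return abortController.signal.aborted
  }
}
// Passed per request as: config: { httpOptions: { signal: signalProxy as any } }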

* refactor(electron.vite.config): replace direct import of Vite React plugin with dynamic import
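
The dynamic-import form this step introduced (later restored to a static import by the "fix: import viteReact" commit below, with the dynamic variant left commented out) would look roughly like this; note the module's default export must be unwrapped:

// Sketch only: top-level await in electron.vite.config.ts.
const viteReact = (await import('@vitejs/plugin-react')).default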

* fix(Gemini): clean up unused methods and improve property access

* fix(typecheck): update color properties to use CSS variables

* feat: rework image generation logic
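
The rework drops the separate generateImageExp path: image parts now arrive on the normal chat stream, and each chunk's inline data is converted to a data: URL. A sketch of that extraction step (extractImages is an illustrative name for what processGeminiImageResponse does in the diff below):

import { GenerateContentResponse } from '@google/genai'

// Turn a streamed chunk's inline image parts into data: URLs; parts without
// inlineData (plain text, function calls) are skipped.
function extractImages(chunk: GenerateContentResponse): string[] {
  const parts = chunk.candidates?.[0]?.content?.parts ?? []
  return parts
    .filter((p) => p.inlineData?.data)
    .map((p) => {
      const data = p.inlineData!.data!
      const prefix = `data:${p.inlineData!.mimeType || 'image/png'};base64,`
      return data.startsWith('data:') ? data : prefix + data
    })
}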

* fix: import viteReact

---------

Co-authored-by: eeee0717 <chentao020717Work@outlook.com>
Authored by SuYao on 2025-04-16 23:13:22 +08:00; committed by GitHub
parent 35c50b54a8
commit 24ddd69cd5
14 changed files with 39025 additions and 1518 deletions

File diff suppressed because one or more lines are too long

View File

@@ -1,4 +1,4 @@
import react from '@vitejs/plugin-react'
import viteReact from '@vitejs/plugin-react'
import { defineConfig, externalizeDepsPlugin } from 'electron-vite'
import { resolve } from 'path'
import { visualizer } from 'rollup-plugin-visualizer'
@@ -6,7 +6,7 @@ import { visualizer } from 'rollup-plugin-visualizer'
const visualizerPlugin = (type: 'renderer' | 'main') => {
return process.env[`VISUALIZER_${type.toUpperCase()}`] ? [visualizer({ open: true })] : []
}
// const viteReact = await import('@vitejs/plugin-react')
export default defineConfig({
main: {
plugins: [
@@ -51,7 +51,7 @@ export default defineConfig({
},
renderer: {
plugins: [
react({
viteReact({
babel: {
plugins: [
[

View File

@@ -64,7 +64,6 @@
"@cherrystudio/embedjs-openai": "^0.1.28",
"@electron-toolkit/utils": "^3.0.0",
"@electron/notarize": "^2.5.0",
"@google/generative-ai": "^0.24.0",
"@langchain/community": "^0.3.36",
"@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15",
@@ -74,6 +73,7 @@
"@xyflow/react": "^12.4.4",
"adm-zip": "^0.5.16",
"async-mutex": "^0.5.0",
"bufferutil": "^4.0.9",
"color": "^5.0.0",
"diff": "^7.0.0",
"docx": "^9.0.2",
@@ -96,6 +96,7 @@
"turndown-plugin-gfm": "^1.0.2",
"undici": "^7.4.0",
"webdav": "^5.8.0",
"ws": "^8.18.1",
"zipread": "^1.3.3"
},
"devDependencies": {
@@ -112,7 +113,7 @@
"@emotion/is-prop-valid": "^1.3.1",
"@eslint-react/eslint-plugin": "^1.36.1",
"@eslint/js": "^9.22.0",
"@google/genai": "^0.4.0",
"@google/genai": "patch:@google/genai@npm%3A0.8.0#~/.yarn/patches/@google-genai-npm-0.8.0-450d0d9a7d.patch",
"@hello-pangea/dnd": "^16.6.0",
"@kangfenmao/keyv-storage": "^0.1.0",
"@modelcontextprotocol/sdk": "^1.9.0",
@@ -133,7 +134,8 @@
"@types/react-dom": "^19.0.4",
"@types/react-infinite-scroll-component": "^5.0.0",
"@types/tinycolor2": "^1",
"@vitejs/plugin-react": "^4.2.1",
"@types/ws": "^8",
"@vitejs/plugin-react": "4.3.4",
"analytics": "^0.8.16",
"antd": "^5.22.5",
"applescript": "^1.0.0",

View File

@@ -1,4 +1,4 @@
import { FileMetadataResponse, FileState, GoogleAIFileManager } from '@google/generative-ai/server'
import { File, FileState, GoogleGenAI, Pager } from '@google/genai'
import { FileType } from '@types'
import fs from 'fs'
@@ -8,11 +8,15 @@ export class GeminiService {
private static readonly FILE_LIST_CACHE_KEY = 'gemini_file_list'
private static readonly CACHE_DURATION = 3000
static async uploadFile(_: Electron.IpcMainInvokeEvent, file: FileType, apiKey: string) {
const fileManager = new GoogleAIFileManager(apiKey)
const uploadResult = await fileManager.uploadFile(file.path, {
static async uploadFile(_: Electron.IpcMainInvokeEvent, file: FileType, apiKey: string): Promise<File> {
const sdk = new GoogleGenAI({ vertexai: false, apiKey })
const uploadResult = await sdk.files.upload({
file: file.path,
config: {
mimeType: 'application/pdf',
name: file.id,
displayName: file.origin_name
}
})
return uploadResult
}
@@ -24,40 +28,42 @@ export class GeminiService {
}
}
static async retrieveFile(
_: Electron.IpcMainInvokeEvent,
file: FileType,
apiKey: string
): Promise<FileMetadataResponse | undefined> {
const fileManager = new GoogleAIFileManager(apiKey)
static async retrieveFile(_: Electron.IpcMainInvokeEvent, file: FileType, apiKey: string): Promise<File | undefined> {
const sdk = new GoogleGenAI({ vertexai: false, apiKey })
const cachedResponse = CacheService.get<any>(GeminiService.FILE_LIST_CACHE_KEY)
if (cachedResponse) {
return GeminiService.processResponse(cachedResponse, file)
}
const response = await fileManager.listFiles()
const response = await sdk.files.list()
CacheService.set(GeminiService.FILE_LIST_CACHE_KEY, response, GeminiService.CACHE_DURATION)
return GeminiService.processResponse(response, file)
}
private static processResponse(response: any, file: FileType) {
if (response.files) {
return response.files
.filter((file) => file.state === FileState.ACTIVE)
.find((i) => i.displayName === file.origin_name && Number(i.sizeBytes) === file.size)
private static async processResponse(response: Pager<File>, file: FileType) {
for await (const f of response) {
if (f.state === FileState.ACTIVE) {
if (f.displayName === file.origin_name && Number(f.sizeBytes) === file.size) {
return f
}
}
}
return undefined
}
static async listFiles(_: Electron.IpcMainInvokeEvent, apiKey: string) {
const fileManager = new GoogleAIFileManager(apiKey)
return await fileManager.listFiles()
static async listFiles(_: Electron.IpcMainInvokeEvent, apiKey: string): Promise<File[]> {
const sdk = new GoogleGenAI({ vertexai: false, apiKey })
const files: File[] = []
for await (const f of await sdk.files.list()) {
files.push(f)
}
return files
}
static async deleteFile(_: Electron.IpcMainInvokeEvent, apiKey: string, fileId: string) {
const fileManager = new GoogleAIFileManager(apiKey)
await fileManager.deleteFile(fileId)
static async deleteFile(_: Electron.IpcMainInvokeEvent, fileId: string, apiKey: string) {
const sdk = new GoogleGenAI({ vertexai: false, apiKey })
await sdk.files.delete({ name: fileId })
}
}

View File

@@ -1,6 +1,6 @@
import { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import { ElectronAPI } from '@electron-toolkit/preload'
import type { FileMetadataResponse, ListFilesResponse, UploadFileResponse } from '@google/generative-ai/server'
import type { File } from '@google/genai'
import type { GetMCPPromptResponse, MCPPrompt, MCPResource, MCPServer, MCPTool } from '@renderer/types'
import { AppInfo, FileType, KnowledgeBaseParams, KnowledgeItem, LanguageVarious, WebDavConfig } from '@renderer/types'
import type { LoaderReturn } from '@shared/config/types'
@@ -119,11 +119,11 @@ declare global {
resetMinimumSize: () => Promise<void>
}
gemini: {
uploadFile: (file: FileType, apiKey: string) => Promise<UploadFileResponse>
retrieveFile: (file: FileType, apiKey: string) => Promise<FileMetadataResponse | undefined>
uploadFile: (file: FileType, apiKey: string) => Promise<File>
retrieveFile: (file: FileType, apiKey: string) => Promise<File | undefined>
base64File: (file: FileType) => Promise<{ data: string; mimeType: string }>
listFiles: (apiKey: string) => Promise<ListFilesResponse>
deleteFile: (apiKey: string, fileId: string) => Promise<void>
listFiles: (apiKey: string) => Promise<File[]>
deleteFile: (fileId: string, apiKey: string) => Promise<void>
}
selectionMenu: {
action: (action: string) => Promise<void>
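
One call-site consequence of the typings above: deleteFile's arguments flipped to fileId first, apiKey second, as the GeminiFiles diff below shows:

// New order (fileId, apiKey); the old signature was (apiKey, fileId).
await window.api.gemini.deleteFile(file.name!, provider.apiKey)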

View File

@@ -1,5 +1,5 @@
import { DeleteOutlined } from '@ant-design/icons'
import type { FileMetadataResponse } from '@google/generative-ai/server'
import type { File } from '@google/genai'
import { useProvider } from '@renderer/hooks/useProvider'
import { runAsyncFunction } from '@renderer/utils'
import { MB } from '@shared/config/constant'
@@ -16,11 +16,11 @@ interface GeminiFilesProps {
const GeminiFiles: FC<GeminiFilesProps> = ({ id }) => {
const { provider } = useProvider(id)
const [files, setFiles] = useState<FileMetadataResponse[]>([])
const [files, setFiles] = useState<File[]>([])
const [loading, setLoading] = useState(false)
const fetchFiles = useCallback(async () => {
const { files } = await window.api.gemini.listFiles(provider.apiKey)
const files = await window.api.gemini.listFiles(provider.apiKey)
files && setFiles(files.filter((file) => file.state === 'ACTIVE'))
}, [provider])
@@ -60,14 +60,14 @@ const GeminiFiles: FC<GeminiFilesProps> = ({ id }) => {
key={file.name}
fileInfo={{
name: file.displayName,
ext: `.${file.name.split('.').pop()}`,
extra: `${dayjs(file.createTime).format('MM-DD HH:mm')} · ${(parseInt(file.sizeBytes) / MB).toFixed(2)} MB`,
ext: `.${file.name?.split('.').pop()}`,
extra: `${dayjs(file.createTime).format('MM-DD HH:mm')} · ${(parseInt(file.sizeBytes || '0') / MB).toFixed(2)} MB`,
actions: (
<DeleteOutlined
style={{ cursor: 'pointer', color: 'var(--color-error)' }}
onClick={() => {
setFiles(files.filter((f) => f.name !== file.name))
window.api.gemini.deleteFile(provider.apiKey, file.name).catch((error) => {
window.api.gemini.deleteFile(file.name!, provider.apiKey).catch((error) => {
console.error('Failed to delete file:', error)
setFiles((prev) => [...prev, file])
})

View File

@@ -173,13 +173,13 @@ const ServerName = styled.div`
const ServerDescription = styled.div`
font-size: 0.85rem;
color: ${(props) => props.theme.colors?.textSecondary || '#8c8c8c'};
color: var(--color-text-2);
margin-bottom: 3px;
`
const ServerUrl = styled.div`
font-size: 0.8rem;
color: ${(props) => props.theme.colors?.textTertiary || '#bfbfbf'};
color: var(--color-text-3);
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;

View File

@@ -1,25 +1,18 @@
import {
ContentListUnion,
createPartFromBase64,
FinishReason,
GenerateContentResponse,
GoogleGenAI
} from '@google/genai'
import {
Content,
FileDataPart,
GenerateContentStreamResult,
GoogleGenerativeAI,
File,
GenerateContentConfig,
GenerateContentResponse,
GoogleGenAI,
HarmBlockThreshold,
HarmCategory,
InlineDataPart,
Modality,
Part,
RequestOptions,
PartUnion,
SafetySetting,
TextPart,
Tool
} from '@google/generative-ai'
import { isGemmaModel, isVisionModel, isWebSearchModel } from '@renderer/config/models'
ToolListUnion
} from '@google/genai'
import { isGemmaModel, isGenerateImageModel, isVisionModel, isWebSearchModel } from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
@@ -39,22 +32,15 @@ import axios from 'axios'
import { flatten, isEmpty, takeRight } from 'lodash'
import OpenAI from 'openai'
import { ChunkCallbackData, CompletionsParams } from '.'
import { CompletionsParams } from '.'
import BaseProvider from './BaseProvider'
export default class GeminiProvider extends BaseProvider {
private sdk: GoogleGenerativeAI
private requestOptions: RequestOptions
private imageSdk: GoogleGenAI
private sdk: GoogleGenAI
constructor(provider: Provider) {
super(provider)
this.sdk = new GoogleGenerativeAI(this.apiKey)
/// this sdk is experimental
this.imageSdk = new GoogleGenAI({ apiKey: this.apiKey, httpOptions: { baseUrl: this.getBaseURL() } })
this.requestOptions = {
baseUrl: this.getBaseURL()
}
this.sdk = new GoogleGenAI({ vertexai: false, apiKey: this.apiKey, httpOptions: { baseUrl: this.getBaseURL() } })
}
public getBaseURL(): string {
@@ -76,31 +62,31 @@ export default class GeminiProvider extends BaseProvider {
inlineData: {
data,
mimeType
} as Part['inlineData']
}
} as InlineDataPart
}
// Retrieve file from Gemini uploaded files
const fileMetadata = await window.api.gemini.retrieveFile(file, this.apiKey)
const fileMetadata: File | undefined = await window.api.gemini.retrieveFile(file, this.apiKey)
if (fileMetadata) {
return {
fileData: {
fileUri: fileMetadata.uri,
mimeType: fileMetadata.mimeType
} as Part['fileData']
}
} as FileDataPart
}
// If file is not found, upload it to Gemini
const uploadResult = await window.api.gemini.uploadFile(file, this.apiKey)
const result = await window.api.gemini.uploadFile(file, this.apiKey)
return {
fileData: {
fileUri: uploadResult.file.uri,
mimeType: uploadResult.file.mimeType
fileUri: result.uri,
mimeType: result.mimeType
} as Part['fileData']
}
} as FileDataPart
}
/**
@@ -125,8 +111,8 @@
inlineData: {
data: base64Data,
mimeType: mimeType
}
} as InlineDataPart)
} as Part['inlineData']
})
}
}
}
@@ -139,8 +125,8 @@
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
}
} as InlineDataPart)
} as Part['inlineData']
})
}
if (file.ext === '.pdf') {
@@ -152,13 +138,13 @@
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
parts.push({
text: file.origin_name + '\n' + fileContent
} as TextPart)
})
}
}
return {
role,
parts
parts: parts
}
}
@@ -204,10 +190,13 @@
* @param onChunk - The onChunk callback
* @param onFilterMessages - The onFilterMessages callback
*/
public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) {
if (assistant.enableGenerateImage) {
await this.generateImageExp({ messages, assistant, onFilterMessages, onChunk })
} else {
public async completions({
messages,
assistant,
mcpTools,
onChunk,
onFilterMessages
}: CompletionsParams): Promise<void> {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
@@ -232,7 +221,7 @@
}
// const tools = mcpToolsToGeminiTools(mcpTools)
const tools: Tool[] = []
const tools: ToolListUnion = []
const toolResponses: MCPToolResponse[] = []
if (!WebSearchService.isOverwriteEnabled() && assistant.enableWebSearch && isWebSearchModel(model)) {
@@ -242,55 +231,81 @@
})
}
const geminiModel = this.sdk.getGenerativeModel(
{
model: model.id,
...(isGemmaModel(model) ? {} : { systemInstruction: systemInstruction }),
const generateContentConfig: GenerateContentConfig = {
responseModalities: [Modality.TEXT, Modality.IMAGE],
responseMimeType: 'text/plain',
safetySettings: this.getSafetySettings(model.id),
tools: tools,
generationConfig: {
maxOutputTokens: maxTokens,
// generate image don't need system instruction
systemInstruction: isGemmaModel(model) || isGenerateImageModel(model) ? undefined : systemInstruction,
temperature: assistant?.settings?.temperature,
topP: assistant?.settings?.topP,
maxOutputTokens: maxTokens,
tools: tools,
...this.getCustomParameters(assistant)
}
},
this.requestOptions
)
const chat = geminiModel.startChat({ history })
const messageContents = await this.getMessageContents(userLastMessage!)
const messageContents: Content = await this.getMessageContents(userLastMessage!)
const chat = this.sdk.chats.create({
model: model.id,
config: generateContentConfig,
history: history
})
if (isGemmaModel(model) && assistant.prompt) {
const isFirstMessage = history.length === 0
if (isFirstMessage) {
const systemMessage = {
role: 'user',
parts: [
if (isFirstMessage && messageContents) {
const systemMessage = [
{
text:
'<start_of_turn>user\n' +
systemInstruction +
'<end_of_turn>\n' +
'<start_of_turn>user\n' +
messageContents.parts[0].text +
(messageContents?.parts?.[0] as Part).text +
'<end_of_turn>'
}
]
] as Part[]
if (messageContents && messageContents.parts) {
messageContents.parts[0] = systemMessage[0]
}
messageContents.parts = systemMessage.parts
}
}
const start_time_millsec = new Date().getTime()
const { abortController, cleanup } = this.createAbortController(userLastMessage?.id)
const { signal } = abortController
const { cleanup, abortController } = this.createAbortController(userLastMessage?.id, true)
const signalProxy = {
_originalSignal: abortController.signal,
addEventListener: (eventName: string, listener: () => void) => {
if (eventName === 'abort') {
abortController.signal.addEventListener('abort', listener)
}
},
removeEventListener: (eventName: string, listener: () => void) => {
if (eventName === 'abort') {
abortController.signal.removeEventListener('abort', listener)
}
},
get aborted() {
return abortController.signal.aborted
}
}
if (!streamOutput) {
const { response } = await chat.sendMessage(messageContents.parts, { signal })
const response = await chat.sendMessage({
message: messageContents as PartUnion,
config: {
...generateContentConfig,
httpOptions: {
signal: signalProxy as any
}
}
})
const time_completion_millsec = new Date().getTime() - start_time_millsec
onChunk({
text: response.candidates?.[0].content.parts[0].text,
text: response.text,
usage: {
prompt_tokens: response.usageMetadata?.promptTokenCount || 0,
completion_tokens: response.usageMetadata?.candidatesTokenCount || 0,
@@ -306,7 +321,15 @@
return
}
const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal })
const userMessagesStream = await chat.sendMessageStream({
message: messageContents as PartUnion,
config: {
...generateContentConfig,
httpOptions: {
signal: signalProxy as any
}
}
})
let time_first_token_millsec = 0
const processToolUses = async (content: string, idx: number) => {
@@ -321,17 +344,27 @@
)
if (toolResults && toolResults.length > 0) {
history.push(messageContents)
const newChat = geminiModel.startChat({ history })
const newStream = await newChat.sendMessageStream(flatten(toolResults.map((ts) => (ts as Content).parts)), {
signal
const newChat = this.sdk.chats.create({
model: model.id,
config: generateContentConfig,
history: history as Content[]
})
const newStream = await newChat.sendMessageStream({
message: flatten(toolResults.map((ts) => (ts as Content).parts)) as PartUnion,
config: {
...generateContentConfig,
httpOptions: {
signal: signalProxy as any
}
}
})
await processStream(newStream, idx + 1)
}
}
const processStream = async (stream: GenerateContentStreamResult, idx: number) => {
const processStream = async (stream: AsyncGenerator<GenerateContentResponse>, idx: number) => {
let content = ''
for await (const chunk of stream.stream) {
for await (const chunk of stream) {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
if (time_first_token_millsec == 0) {
@@ -340,11 +373,14 @@
const time_completion_millsec = new Date().getTime() - start_time_millsec
content += chunk.text()
processToolUses(content, idx)
if (chunk.text !== undefined) {
content += chunk.text
}
await processToolUses(content, idx)
const generateImage = this.processGeminiImageResponse(chunk)
onChunk({
text: chunk.text(),
text: chunk.text !== undefined ? chunk.text : '',
usage: {
prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
@@ -356,14 +392,14 @@
time_first_token_millsec
},
search: chunk.candidates?.[0]?.groundingMetadata,
mcpToolResponse: toolResponses
mcpToolResponse: toolResponses,
generateImage: generateImage
})
}
}
await processStream(userMessagesStream, 0).finally(cleanup)
}
}
/**
* Translate a message
@@ -372,39 +408,51 @@
* @param onResponse - The onResponse callback
* @returns The translated message
*/
async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
const defaultModel = getDefaultModel()
const { maxTokens } = getAssistantSettings(assistant)
const model = assistant.model || defaultModel
const geminiModel = this.sdk.getGenerativeModel(
{
model: model.id,
...(isGemmaModel(model) ? {} : { systemInstruction: assistant.prompt }),
generationConfig: {
maxOutputTokens: maxTokens,
temperature: assistant?.settings?.temperature
}
},
this.requestOptions
)
const content =
isGemmaModel(model) && assistant.prompt
? `<start_of_turn>user\n${assistant.prompt}<end_of_turn>\n<start_of_turn>user\n${message.content}<end_of_turn>`
: message.content
if (!onResponse) {
const { response } = await geminiModel.generateContent(content)
return response.text()
const response = await this.sdk.models.generateContent({
model: model.id,
config: {
maxOutputTokens: maxTokens,
temperature: assistant?.settings?.temperature,
systemInstruction: isGemmaModel(model) ? undefined : assistant.prompt
},
contents: [
{
role: 'user',
parts: [{ text: content }]
}
]
})
return response.text || ''
}
const response = await geminiModel.generateContentStream(content)
const response = await this.sdk.models.generateContentStream({
model: model.id,
config: {
maxOutputTokens: maxTokens,
temperature: assistant?.settings?.temperature,
systemInstruction: isGemmaModel(model) ? undefined : assistant.prompt
},
contents: [
{
role: 'user',
parts: [{ text: content }]
}
]
})
let text = ''
for await (const chunk of response.stream) {
text += chunk.text()
for await (const chunk of response) {
text += chunk.text
onResponse(text)
}
@@ -442,25 +490,24 @@
content: userMessageContent
}
const geminiModel = this.sdk.getGenerativeModel(
{
model: model.id,
...(isGemmaModel(model) ? {} : { systemInstruction: systemMessage.content }),
generationConfig: {
temperature: assistant?.settings?.temperature
}
},
this.requestOptions
)
const chat = await geminiModel.startChat()
const content = isGemmaModel(model)
? `<start_of_turn>user\n${systemMessage.content}<end_of_turn>\n<start_of_turn>user\n${userMessage.content}<end_of_turn>`
: userMessage.content
const { response } = await chat.sendMessage(content)
const response = await this.sdk.models.generateContent({
model: model.id,
config: {
systemInstruction: isGemmaModel(model) ? undefined : systemMessage.content
},
contents: [
{
role: 'user',
parts: [{ text: content }]
}
]
})
return removeSpecialCharactersForTopicName(response.text())
return removeSpecialCharactersForTopicName(response.text || '')
}
/**
@@ -471,24 +518,23 @@
*/
public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
const model = getDefaultModel()
const systemMessage = { role: 'system', content: prompt }
const geminiModel = this.sdk.getGenerativeModel(
{
model: model.id,
...(isGemmaModel(model) ? {} : { systemInstruction: systemMessage.content })
},
this.requestOptions
)
const chat = await geminiModel.startChat()
const messageContent = isGemmaModel(model)
const MessageContent = isGemmaModel(model)
? `<start_of_turn>user\n${prompt}<end_of_turn>\n<start_of_turn>user\n${content}<end_of_turn>`
: content
const response = await this.sdk.models.generateContent({
model: model.id,
config: {
systemInstruction: isGemmaModel(model) ? undefined : prompt
},
contents: [
{
role: 'user',
parts: [{ text: MessageContent }]
}
]
})
const { response } = await chat.sendMessage(messageContent)
return response.text()
return response.text || ''
}
/**
@@ -518,24 +564,28 @@
content: messages.map((m) => m.content).join('\n')
}
const geminiModel = this.sdk.getGenerativeModel(
{
const content = isGemmaModel(model)
? `<start_of_turn>user\n${systemMessage.content}<end_of_turn>\n<start_of_turn>user\n${userMessage.content}<end_of_turn>`
: userMessage.content
const response = await this.sdk.models.generateContent({
model: model.id,
systemInstruction: systemMessage.content,
generationConfig: {
temperature: assistant?.settings?.temperature
}
},
{
...this.requestOptions,
config: {
systemInstruction: isGemmaModel(model) ? undefined : systemMessage.content,
temperature: assistant?.settings?.temperature,
httpOptions: {
timeout: 20 * 1000
}
)
},
contents: [
{
role: 'user',
parts: [{ text: content }]
}
]
})
const chat = await geminiModel.startChat()
const { response } = await chat.sendMessage(userMessage.content)
return response.text()
return response.text || ''
}
/**
@@ -546,144 +596,13 @@
return []
}
/**
*
* @param messages -
* @param assistant -
* @param onChunk -
* @param onFilterMessages -
* @returns Promise<void>
*/
private async generateImageExp({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, streamOutput, maxTokens } = getAssistantSettings(assistant)
const userMessages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
onFilterMessages(userMessages)
const userLastMessage = userMessages.pop()
if (!userLastMessage) {
throw new Error('No user message found')
}
const history: Content[] = []
for (const message of userMessages) {
history.push(await this.getMessageContents(message))
}
const userLastMessageContent = await this.getMessageContents(userLastMessage)
const allContents = [...history, userLastMessageContent]
let contents: ContentListUnion = allContents.length > 0 ? (allContents as ContentListUnion) : []
contents = await this.addImageFileToContents(userLastMessage, contents)
if (!streamOutput) {
const response = await this.callGeminiGenerateContent(model.id, contents, maxTokens)
const { isValid, message } = this.isValidGeminiResponse(response)
if (!isValid) {
throw new Error(`Gemini API error: ${message}`)
}
this.processGeminiImageResponse(response, onChunk)
return
}
const response = await this.callGeminiGenerateContentStream(model.id, contents, maxTokens)
for await (const chunk of response) {
this.processGeminiImageResponse(chunk, onChunk)
}
}
/**
*
* @param message -
* @param contents -
* @returns
*/
private async addImageFileToContents(message: Message, contents: ContentListUnion): Promise<ContentListUnion> {
if (message.files && message.files.length > 0) {
const file = message.files[0]
const fileContent = await window.api.file.base64Image(file.id + file.ext)
if (fileContent && fileContent.base64) {
const contentsArray = Array.isArray(contents) ? contents : [contents]
return [...contentsArray, createPartFromBase64(fileContent.base64, fileContent.mime)]
}
}
return contents
}
/**
* Generate content via the Gemini API
* @param modelId - the model ID
* @param contents -
* @returns
*/
private async callGeminiGenerateContent(
modelId: string,
contents: ContentListUnion,
maxTokens?: number
): Promise<GenerateContentResponse> {
try {
return await this.imageSdk.models.generateContent({
model: modelId,
contents: contents,
config: {
responseModalities: ['Text', 'Image'],
responseMimeType: 'text/plain',
maxOutputTokens: maxTokens
}
})
} catch (error) {
console.error('Gemini API error:', error)
throw error
}
}
private async callGeminiGenerateContentStream(
modelId: string,
contents: ContentListUnion,
maxTokens?: number
): Promise<AsyncGenerator<GenerateContentResponse>> {
try {
return await this.imageSdk.models.generateContentStream({
model: modelId,
contents: contents,
config: {
responseModalities: ['Text', 'Image'],
responseMimeType: 'text/plain',
maxOutputTokens: maxTokens
}
})
} catch (error) {
console.error('Gemini API error:', error)
throw error
}
}
/**
* Check whether a Gemini response is valid
* @param response - the Gemini response
* @returns
*/
private isValidGeminiResponse(response: GenerateContentResponse): { isValid: boolean; message: string } {
return {
isValid: response?.candidates?.[0]?.finishReason === FinishReason.STOP ? true : false,
message: response?.candidates?.[0]?.finishReason || ''
}
}
/**
* Process the Gemini image response
* @param response - the Gemini response
* @param onChunk - the onChunk callback
*/
private processGeminiImageResponse(response: any, onChunk: (chunk: ChunkCallbackData) => void): void {
const parts = response.candidates[0].content.parts
private processGeminiImageResponse(chunk: GenerateContentResponse): { type: 'base64'; images: string[] } | undefined {
const parts = chunk.candidates?.[0]?.content?.parts
if (!parts) {
return
}
@@ -695,31 +614,13 @@
return null
}
const dataPrefix = `data:${part.inlineData.mimeType || 'image/png'};base64,`
return part.inlineData.data.startsWith('data:') ? part.inlineData.data : dataPrefix + part.inlineData.data
return part.inlineData.data?.startsWith('data:') ? part.inlineData.data : dataPrefix + part.inlineData.data
})
// Extract the text data
const text = parts
.filter((part: Part) => part.text !== undefined)
.map((part: Part) => part.text)
.join('')
// Return the result
onChunk({
text,
generateImage: {
return {
type: 'base64',
images
},
usage: {
prompt_tokens: response.usageMetadata?.promptTokenCount || 0,
completion_tokens: response.usageMetadata?.candidatesTokenCount || 0,
total_tokens: response.usageMetadata?.totalTokenCount || 0
},
metrics: {
completion_tokens: response.usageMetadata?.candidatesTokenCount
images: images.filter((image) => image !== null)
}
})
}
/**
@@ -732,18 +633,16 @@
return { valid: false, error: new Error('No model found') }
}
const body = {
model: model.id,
messages: [{ role: 'user', content: 'hi' }],
max_tokens: 100,
stream: false
}
try {
const geminiModel = this.sdk.getGenerativeModel({ model: body.model }, this.requestOptions)
const result = await geminiModel.generateContent(body.messages[0].content)
const result = await this.sdk.models.generateContent({
model: model.id,
contents: [{ role: 'user', parts: [{ text: 'hi' }] }],
config: {
maxOutputTokens: 100
}
})
return {
valid: !isEmpty(result.response.text()),
valid: !isEmpty(result.text),
error: null
}
} catch (error: any) {
@@ -785,7 +684,10 @@
* @returns The embedding dimensions
*/
public async getEmbeddingDimensions(model: Model): Promise<number> {
const data = await this.sdk.getGenerativeModel({ model: model.id }, this.requestOptions).embedContent('hi')
return data.embedding.values.length
const data = await this.sdk.models.embedContent({
model: model.id,
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
})
return data.embeddings?.[0]?.values?.length || 0
}
}

View File

@@ -1,4 +1,4 @@
import type { GroundingMetadata } from '@google/generative-ai'
import type { GroundingMetadata } from '@google/genai'
import BaseProvider from '@renderer/providers/AiProvider/BaseProvider'
import ProviderFactory from '@renderer/providers/AiProvider/ProviderFactory'
import type {

View File

@@ -282,6 +282,7 @@ export async function fetchChatCompletion({
}
}
console.log('message', message)
// Emit chat completion event
EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message)
onResponse(message)

View File

@@ -1,4 +1,4 @@
import { GroundingMetadata } from '@google/generative-ai'
import { GroundingMetadata } from '@google/genai'
import OpenAI from 'openai'
import React from 'react'
import { BuiltinTheme } from 'shiki'

View File

@@ -71,8 +71,8 @@ export function withGeminiGrounding(message: Message) {
let content = message.content
groundingSupports.forEach((support) => {
const text = support?.segment
const indices = support?.groundingChunckIndices
const text = support?.segment?.text
const indices = support?.groundingChunkIndices
if (!text || !indices) return

View File

@@ -1,186 +1,165 @@
import { ContentBlockParam, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
import { MessageParam } from '@anthropic-ai/sdk/resources'
import {
ArraySchema,
BaseSchema,
BooleanSchema,
EnumStringSchema,
FunctionCall,
FunctionDeclaration,
FunctionDeclarationSchema,
FunctionDeclarationSchemaProperty,
IntegerSchema,
NumberSchema,
ObjectSchema,
SchemaType,
SimpleStringSchema,
Tool as geminiTool
} from '@google/generative-ai'
import { Content, Part } from '@google/generative-ai'
import { Content, FunctionCall, Part } from '@google/genai'
import store from '@renderer/store'
import { MCPCallToolResponse, MCPServer, MCPTool, MCPToolResponse } from '@renderer/types'
import {
ChatCompletionContentPart,
ChatCompletionMessageParam,
ChatCompletionMessageToolCall,
ChatCompletionTool
} from 'openai/resources'
import { ChatCompletionContentPart, ChatCompletionMessageParam, ChatCompletionMessageToolCall } from 'openai/resources'
import { ChunkCallbackData, CompletionsParams } from '../providers/AiProvider'
const ensureValidSchema = (obj: Record<string, any>): FunctionDeclarationSchemaProperty => {
// Filter out unsupported keys for Gemini
const filteredObj = filterUnsupportedKeys(obj)
// const ensureValidSchema = (obj: Record<string, any>) => {
// // Filter out unsupported keys for Gemini
// const filteredObj = filterUnsupportedKeys(obj)
// Handle base schema properties
const baseSchema = {
description: filteredObj.description,
nullable: filteredObj.nullable
} as BaseSchema
// // Handle base schema properties
// const baseSchema = {
// description: filteredObj.description,
// nullable: filteredObj.nullable
// } as BaseSchema
// Handle string type
if (filteredObj.type?.toLowerCase() === SchemaType.STRING) {
if (filteredObj.enum && Array.isArray(filteredObj.enum)) {
return {
...baseSchema,
type: SchemaType.STRING,
format: 'enum',
enum: filteredObj.enum as string[]
} as EnumStringSchema
}
return {
...baseSchema,
type: SchemaType.STRING,
format: filteredObj.format === 'date-time' ? 'date-time' : undefined
} as SimpleStringSchema
}
// // Handle string type
// if (filteredObj.type?.toLowerCase() === SchemaType.STRING) {
// if (filteredObj.enum && Array.isArray(filteredObj.enum)) {
// return {
// ...baseSchema,
// type: SchemaType.STRING,
// format: 'enum',
// enum: filteredObj.enum as string[]
// } as EnumStringSchema
// }
// return {
// ...baseSchema,
// type: SchemaType.STRING,
// format: filteredObj.format === 'date-time' ? 'date-time' : undefined
// } as SimpleStringSchema
// }
// Handle number type
if (filteredObj.type?.toLowerCase() === SchemaType.NUMBER) {
return {
...baseSchema,
type: SchemaType.NUMBER,
format: ['float', 'double'].includes(filteredObj.format) ? (filteredObj.format as 'float' | 'double') : undefined
} as NumberSchema
}
// // Handle number type
// if (filteredObj.type?.toLowerCase() === SchemaType.NUMBER) {
// return {
// ...baseSchema,
// type: SchemaType.NUMBER,
// format: ['float', 'double'].includes(filteredObj.format) ? (filteredObj.format as 'float' | 'double') : undefined
// } as NumberSchema
// }
// Handle integer type
if (filteredObj.type?.toLowerCase() === SchemaType.INTEGER) {
return {
...baseSchema,
type: SchemaType.INTEGER,
format: ['int32', 'int64'].includes(filteredObj.format) ? (filteredObj.format as 'int32' | 'int64') : undefined
} as IntegerSchema
}
// // Handle integer type
// if (filteredObj.type?.toLowerCase() === SchemaType.INTEGER) {
// return {
// ...baseSchema,
// type: SchemaType.INTEGER,
// format: ['int32', 'int64'].includes(filteredObj.format) ? (filteredObj.format as 'int32' | 'int64') : undefined
// } as IntegerSchema
// }
// Handle boolean type
if (filteredObj.type?.toLowerCase() === SchemaType.BOOLEAN) {
return {
...baseSchema,
type: SchemaType.BOOLEAN
} as BooleanSchema
}
// // Handle boolean type
// if (filteredObj.type?.toLowerCase() === SchemaType.BOOLEAN) {
// return {
// ...baseSchema,
// type: SchemaType.BOOLEAN
// } as BooleanSchema
// }
// Handle array type
if (filteredObj.type?.toLowerCase() === SchemaType.ARRAY) {
return {
...baseSchema,
type: SchemaType.ARRAY,
items: filteredObj.items
? ensureValidSchema(filteredObj.items as Record<string, any>)
: ({ type: SchemaType.STRING } as SimpleStringSchema),
minItems: filteredObj.minItems,
maxItems: filteredObj.maxItems
} as ArraySchema
}
// // Handle array type
// if (filteredObj.type?.toLowerCase() === SchemaType.ARRAY) {
// return {
// ...baseSchema,
// type: SchemaType.ARRAY,
// items: filteredObj.items
// ? ensureValidSchema(filteredObj.items as Record<string, any>)
// : ({ type: SchemaType.STRING } as SimpleStringSchema),
// minItems: filteredObj.minItems,
// maxItems: filteredObj.maxItems
// } as ArraySchema
// }
// Handle object type (default)
const properties = filteredObj.properties
? Object.fromEntries(
Object.entries(filteredObj.properties).map(([key, value]) => [
key,
ensureValidSchema(value as Record<string, any>)
])
)
: { _empty: { type: SchemaType.STRING } as SimpleStringSchema } // Ensure properties is never empty
// // Handle object type (default)
// const properties = filteredObj.properties
// ? Object.fromEntries(
// Object.entries(filteredObj.properties).map(([key, value]) => [
// key,
// ensureValidSchema(value as Record<string, any>)
// ])
// )
// : { _empty: { type: SchemaType.STRING } as SimpleStringSchema } // Ensure properties is never empty
return {
...baseSchema,
type: SchemaType.OBJECT,
properties,
required: Array.isArray(filteredObj.required) ? filteredObj.required : undefined
} as ObjectSchema
}
// return {
// ...baseSchema,
// type: SchemaType.OBJECT,
// properties,
// required: Array.isArray(filteredObj.required) ? filteredObj.required : undefined
// } as ObjectSchema
// }
function filterUnsupportedKeys(obj: Record<string, any>): Record<string, any> {
const supportedBaseKeys = ['description', 'nullable']
const supportedStringKeys = [...supportedBaseKeys, 'type', 'format', 'enum']
const supportedNumberKeys = [...supportedBaseKeys, 'type', 'format']
const supportedBooleanKeys = [...supportedBaseKeys, 'type']
const supportedArrayKeys = [...supportedBaseKeys, 'type', 'items', 'minItems', 'maxItems']
const supportedObjectKeys = [...supportedBaseKeys, 'type', 'properties', 'required']
// function filterUnsupportedKeys(obj: Record<string, any>): Record<string, any> {
// const supportedBaseKeys = ['description', 'nullable']
// const supportedStringKeys = [...supportedBaseKeys, 'type', 'format', 'enum']
// const supportedNumberKeys = [...supportedBaseKeys, 'type', 'format']
// const supportedBooleanKeys = [...supportedBaseKeys, 'type']
// const supportedArrayKeys = [...supportedBaseKeys, 'type', 'items', 'minItems', 'maxItems']
// const supportedObjectKeys = [...supportedBaseKeys, 'type', 'properties', 'required']
const filtered: Record<string, any> = {}
// const filtered: Record<string, any> = {}
let keysToKeep: string[]
// let keysToKeep: string[]
if (obj.type?.toLowerCase() === SchemaType.STRING) {
keysToKeep = supportedStringKeys
} else if (obj.type?.toLowerCase() === SchemaType.NUMBER) {
keysToKeep = supportedNumberKeys
} else if (obj.type?.toLowerCase() === SchemaType.INTEGER) {
keysToKeep = supportedNumberKeys
} else if (obj.type?.toLowerCase() === SchemaType.BOOLEAN) {
keysToKeep = supportedBooleanKeys
} else if (obj.type?.toLowerCase() === SchemaType.ARRAY) {
keysToKeep = supportedArrayKeys
} else {
// Default to object type
keysToKeep = supportedObjectKeys
}
// if (obj.type?.toLowerCase() === SchemaType.STRING) {
// keysToKeep = supportedStringKeys
// } else if (obj.type?.toLowerCase() === SchemaType.NUMBER) {
// keysToKeep = supportedNumberKeys
// } else if (obj.type?.toLowerCase() === SchemaType.INTEGER) {
// keysToKeep = supportedNumberKeys
// } else if (obj.type?.toLowerCase() === SchemaType.BOOLEAN) {
// keysToKeep = supportedBooleanKeys
// } else if (obj.type?.toLowerCase() === SchemaType.ARRAY) {
// keysToKeep = supportedArrayKeys
// } else {
// // Default to object type
// keysToKeep = supportedObjectKeys
// }
// copy supported keys
for (const key of keysToKeep) {
if (obj[key] !== undefined) {
filtered[key] = obj[key]
}
}
// // copy supported keys
// for (const key of keysToKeep) {
// if (obj[key] !== undefined) {
// filtered[key] = obj[key]
// }
// }
return filtered
}
// return filtered
// }
function filterPropertieAttributes(tool: MCPTool, filterNestedObj: boolean = false): Record<string, object> {
const properties = tool.inputSchema.properties
if (!properties) {
return {}
}
// function filterPropertieAttributes(tool: MCPTool, filterNestedObj: boolean = false): Record<string, object> {
// const properties = tool.inputSchema.properties
// if (!properties) {
// return {}
// }
// For OpenAI, we don't need to validate as strictly
if (!filterNestedObj) {
return properties
}
// // For OpenAI, we don't need to validate as strictly
// if (!filterNestedObj) {
// return properties
// }
const processedProperties = Object.fromEntries(
Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
)
// const processedProperties = Object.fromEntries(
// Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
// )
return processedProperties
}
// return processedProperties
// }
export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
return mcpTools.map((tool) => ({
type: 'function',
name: tool.name,
function: {
name: tool.id,
description: tool.description,
parameters: {
type: 'object',
properties: filterPropertieAttributes(tool)
}
}
}))
}
// export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
// return mcpTools.map((tool) => ({
// type: 'function',
// name: tool.name,
// function: {
// name: tool.id,
// description: tool.description,
// parameters: {
// type: 'object',
// properties: filterPropertieAttributes(tool)
// }
// }
// }))
// }
export function openAIToolsToMcpTool(
mcpTools: MCPTool[] | undefined,
@@ -277,35 +256,35 @@ export function anthropicToolUseToMcpTool(mcpTools: MCPTool[] | undefined, toolU
return tool
}
export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTool[] {
if (!mcpTools || mcpTools.length === 0) {
// No tools available
return []
}
const functions: FunctionDeclaration[] = []
// export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTool[] {
// if (!mcpTools || mcpTools.length === 0) {
// // No tools available
// return []
// }
// const functions: FunctionDeclaration[] = []
for (const tool of mcpTools) {
const properties = filterPropertieAttributes(tool, true)
const functionDeclaration: FunctionDeclaration = {
name: tool.id,
description: tool.description,
parameters: {
type: SchemaType.OBJECT,
properties:
Object.keys(properties).length > 0
? Object.fromEntries(
Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
)
: { _empty: { type: SchemaType.STRING } as SimpleStringSchema }
} as FunctionDeclarationSchema
}
functions.push(functionDeclaration)
}
const tool: geminiTool = {
functionDeclarations: functions
}
return [tool]
}
// for (const tool of mcpTools) {
// const properties = filterPropertieAttributes(tool, true)
// const functionDeclaration: FunctionDeclaration = {
// name: tool.id,
// description: tool.description,
// parameters: {
// type: SchemaType.OBJECT,
// properties:
// Object.keys(properties).length > 0
// ? Object.fromEntries(
// Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
// )
// : { _empty: { type: SchemaType.STRING } as SimpleStringSchema }
// } as FunctionDeclarationSchema
// }
// functions.push(functionDeclaration)
// }
// const tool: geminiTool = {
// functionDeclarations: functions
// }
// return [tool]
// }
export function geminiFunctionCallToMcpTool(
mcpTools: MCPTool[] | undefined,

yarn.lock: 1671 changed lines (diff suppressed because it is too large)