refactor(Gemini): migrate generative-ai sdk to genai sdk (#4939)

* refactor(GeminiService): migrate to new Google GenAI SDK and update file handling methods

- Updated import statements to use the new Google GenAI SDK.
- Refactored file upload, retrieval, and deletion methods to align with the new SDK's API.
- Adjusted type definitions and response handling for improved type safety and clarity.
- Enhanced file listing and processing logic to utilize async iteration for better performance.

* refactor(GeminiProvider): update message handling and integrate abort signal support

- Refactored message content handling to align with updated type definitions, ensuring consistent use of Content type.
- Enhanced abort signal management for chat requests, allowing for better control over ongoing operations.
- Improved message processing logic to streamline user message history handling and response generation.
- Adjusted type definitions for message contents to enhance type safety and clarity.

* refactor(electron.vite.config): replace direct import of Vite React plugin with dynamic import

* fix(Gemini): clean up unused methods and improve property access

* fix(typecheck): update color properties to use CSS variables

* feat: 修改画图逻辑 (revise the image-generation logic)

* fix: import viteReact

---------

Co-authored-by: eeee0717 <chentao020717Work@outlook.com>
This commit is contained in:
Authored by SuYao on 2025-04-16 23:13:22 +08:00; committed by GitHub
parent 35c50b54a8
commit 24ddd69cd5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 39025 additions and 1518 deletions

File diff suppressed because one or more lines are too long

View File

@@ -1,4 +1,4 @@
import react from '@vitejs/plugin-react' import viteReact from '@vitejs/plugin-react'
import { defineConfig, externalizeDepsPlugin } from 'electron-vite' import { defineConfig, externalizeDepsPlugin } from 'electron-vite'
import { resolve } from 'path' import { resolve } from 'path'
import { visualizer } from 'rollup-plugin-visualizer' import { visualizer } from 'rollup-plugin-visualizer'
@@ -6,7 +6,7 @@ import { visualizer } from 'rollup-plugin-visualizer'
const visualizerPlugin = (type: 'renderer' | 'main') => { const visualizerPlugin = (type: 'renderer' | 'main') => {
return process.env[`VISUALIZER_${type.toUpperCase()}`] ? [visualizer({ open: true })] : [] return process.env[`VISUALIZER_${type.toUpperCase()}`] ? [visualizer({ open: true })] : []
} }
// const viteReact = await import('@vitejs/plugin-react')
export default defineConfig({ export default defineConfig({
main: { main: {
plugins: [ plugins: [
@@ -51,7 +51,7 @@ export default defineConfig({
}, },
renderer: { renderer: {
plugins: [ plugins: [
react({ viteReact({
babel: { babel: {
plugins: [ plugins: [
[ [

View File

@@ -64,7 +64,6 @@
"@cherrystudio/embedjs-openai": "^0.1.28", "@cherrystudio/embedjs-openai": "^0.1.28",
"@electron-toolkit/utils": "^3.0.0", "@electron-toolkit/utils": "^3.0.0",
"@electron/notarize": "^2.5.0", "@electron/notarize": "^2.5.0",
"@google/generative-ai": "^0.24.0",
"@langchain/community": "^0.3.36", "@langchain/community": "^0.3.36",
"@mozilla/readability": "^0.6.0", "@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15", "@notionhq/client": "^2.2.15",
@@ -74,6 +73,7 @@
"@xyflow/react": "^12.4.4", "@xyflow/react": "^12.4.4",
"adm-zip": "^0.5.16", "adm-zip": "^0.5.16",
"async-mutex": "^0.5.0", "async-mutex": "^0.5.0",
"bufferutil": "^4.0.9",
"color": "^5.0.0", "color": "^5.0.0",
"diff": "^7.0.0", "diff": "^7.0.0",
"docx": "^9.0.2", "docx": "^9.0.2",
@@ -96,6 +96,7 @@
"turndown-plugin-gfm": "^1.0.2", "turndown-plugin-gfm": "^1.0.2",
"undici": "^7.4.0", "undici": "^7.4.0",
"webdav": "^5.8.0", "webdav": "^5.8.0",
"ws": "^8.18.1",
"zipread": "^1.3.3" "zipread": "^1.3.3"
}, },
"devDependencies": { "devDependencies": {
@@ -112,7 +113,7 @@
"@emotion/is-prop-valid": "^1.3.1", "@emotion/is-prop-valid": "^1.3.1",
"@eslint-react/eslint-plugin": "^1.36.1", "@eslint-react/eslint-plugin": "^1.36.1",
"@eslint/js": "^9.22.0", "@eslint/js": "^9.22.0",
"@google/genai": "^0.4.0", "@google/genai": "patch:@google/genai@npm%3A0.8.0#~/.yarn/patches/@google-genai-npm-0.8.0-450d0d9a7d.patch",
"@hello-pangea/dnd": "^16.6.0", "@hello-pangea/dnd": "^16.6.0",
"@kangfenmao/keyv-storage": "^0.1.0", "@kangfenmao/keyv-storage": "^0.1.0",
"@modelcontextprotocol/sdk": "^1.9.0", "@modelcontextprotocol/sdk": "^1.9.0",
@@ -133,7 +134,8 @@
"@types/react-dom": "^19.0.4", "@types/react-dom": "^19.0.4",
"@types/react-infinite-scroll-component": "^5.0.0", "@types/react-infinite-scroll-component": "^5.0.0",
"@types/tinycolor2": "^1", "@types/tinycolor2": "^1",
"@vitejs/plugin-react": "^4.2.1", "@types/ws": "^8",
"@vitejs/plugin-react": "4.3.4",
"analytics": "^0.8.16", "analytics": "^0.8.16",
"antd": "^5.22.5", "antd": "^5.22.5",
"applescript": "^1.0.0", "applescript": "^1.0.0",

View File

@@ -1,4 +1,4 @@
import { FileMetadataResponse, FileState, GoogleAIFileManager } from '@google/generative-ai/server' import { File, FileState, GoogleGenAI, Pager } from '@google/genai'
import { FileType } from '@types' import { FileType } from '@types'
import fs from 'fs' import fs from 'fs'
@@ -8,11 +8,15 @@ export class GeminiService {
private static readonly FILE_LIST_CACHE_KEY = 'gemini_file_list' private static readonly FILE_LIST_CACHE_KEY = 'gemini_file_list'
private static readonly CACHE_DURATION = 3000 private static readonly CACHE_DURATION = 3000
static async uploadFile(_: Electron.IpcMainInvokeEvent, file: FileType, apiKey: string) { static async uploadFile(_: Electron.IpcMainInvokeEvent, file: FileType, apiKey: string): Promise<File> {
const fileManager = new GoogleAIFileManager(apiKey) const sdk = new GoogleGenAI({ vertexai: false, apiKey })
const uploadResult = await fileManager.uploadFile(file.path, { const uploadResult = await sdk.files.upload({
file: file.path,
config: {
mimeType: 'application/pdf', mimeType: 'application/pdf',
name: file.id,
displayName: file.origin_name displayName: file.origin_name
}
}) })
return uploadResult return uploadResult
} }
@@ -24,40 +28,42 @@ export class GeminiService {
} }
} }
static async retrieveFile( static async retrieveFile(_: Electron.IpcMainInvokeEvent, file: FileType, apiKey: string): Promise<File | undefined> {
_: Electron.IpcMainInvokeEvent, const sdk = new GoogleGenAI({ vertexai: false, apiKey })
file: FileType,
apiKey: string
): Promise<FileMetadataResponse | undefined> {
const fileManager = new GoogleAIFileManager(apiKey)
const cachedResponse = CacheService.get<any>(GeminiService.FILE_LIST_CACHE_KEY) const cachedResponse = CacheService.get<any>(GeminiService.FILE_LIST_CACHE_KEY)
if (cachedResponse) { if (cachedResponse) {
return GeminiService.processResponse(cachedResponse, file) return GeminiService.processResponse(cachedResponse, file)
} }
const response = await fileManager.listFiles() const response = await sdk.files.list()
CacheService.set(GeminiService.FILE_LIST_CACHE_KEY, response, GeminiService.CACHE_DURATION) CacheService.set(GeminiService.FILE_LIST_CACHE_KEY, response, GeminiService.CACHE_DURATION)
return GeminiService.processResponse(response, file) return GeminiService.processResponse(response, file)
} }
private static processResponse(response: any, file: FileType) { private static async processResponse(response: Pager<File>, file: FileType) {
if (response.files) { for await (const f of response) {
return response.files if (f.state === FileState.ACTIVE) {
.filter((file) => file.state === FileState.ACTIVE) if (f.displayName === file.origin_name && Number(f.sizeBytes) === file.size) {
.find((i) => i.displayName === file.origin_name && Number(i.sizeBytes) === file.size) return f
} }
}
}
return undefined return undefined
} }
static async listFiles(_: Electron.IpcMainInvokeEvent, apiKey: string) { static async listFiles(_: Electron.IpcMainInvokeEvent, apiKey: string): Promise<File[]> {
const fileManager = new GoogleAIFileManager(apiKey) const sdk = new GoogleGenAI({ vertexai: false, apiKey })
return await fileManager.listFiles() const files: File[] = []
for await (const f of await sdk.files.list()) {
files.push(f)
}
return files
} }
static async deleteFile(_: Electron.IpcMainInvokeEvent, apiKey: string, fileId: string) { static async deleteFile(_: Electron.IpcMainInvokeEvent, fileId: string, apiKey: string) {
const fileManager = new GoogleAIFileManager(apiKey) const sdk = new GoogleGenAI({ vertexai: false, apiKey })
await fileManager.deleteFile(fileId) await sdk.files.delete({ name: fileId })
} }
} }

View File

@@ -1,6 +1,6 @@
import { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' import { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import { ElectronAPI } from '@electron-toolkit/preload' import { ElectronAPI } from '@electron-toolkit/preload'
import type { FileMetadataResponse, ListFilesResponse, UploadFileResponse } from '@google/generative-ai/server' import type { File } from '@google/genai'
import type { GetMCPPromptResponse, MCPPrompt, MCPResource, MCPServer, MCPTool } from '@renderer/types' import type { GetMCPPromptResponse, MCPPrompt, MCPResource, MCPServer, MCPTool } from '@renderer/types'
import { AppInfo, FileType, KnowledgeBaseParams, KnowledgeItem, LanguageVarious, WebDavConfig } from '@renderer/types' import { AppInfo, FileType, KnowledgeBaseParams, KnowledgeItem, LanguageVarious, WebDavConfig } from '@renderer/types'
import type { LoaderReturn } from '@shared/config/types' import type { LoaderReturn } from '@shared/config/types'
@@ -119,11 +119,11 @@ declare global {
resetMinimumSize: () => Promise<void> resetMinimumSize: () => Promise<void>
} }
gemini: { gemini: {
uploadFile: (file: FileType, apiKey: string) => Promise<UploadFileResponse> uploadFile: (file: FileType, apiKey: string) => Promise<File>
retrieveFile: (file: FileType, apiKey: string) => Promise<FileMetadataResponse | undefined> retrieveFile: (file: FileType, apiKey: string) => Promise<File | undefined>
base64File: (file: FileType) => Promise<{ data: string; mimeType: string }> base64File: (file: FileType) => Promise<{ data: string; mimeType: string }>
listFiles: (apiKey: string) => Promise<ListFilesResponse> listFiles: (apiKey: string) => Promise<File[]>
deleteFile: (apiKey: string, fileId: string) => Promise<void> deleteFile: (fileId: string, apiKey: string) => Promise<void>
} }
selectionMenu: { selectionMenu: {
action: (action: string) => Promise<void> action: (action: string) => Promise<void>

View File

@@ -1,5 +1,5 @@
import { DeleteOutlined } from '@ant-design/icons' import { DeleteOutlined } from '@ant-design/icons'
import type { FileMetadataResponse } from '@google/generative-ai/server' import type { File } from '@google/genai'
import { useProvider } from '@renderer/hooks/useProvider' import { useProvider } from '@renderer/hooks/useProvider'
import { runAsyncFunction } from '@renderer/utils' import { runAsyncFunction } from '@renderer/utils'
import { MB } from '@shared/config/constant' import { MB } from '@shared/config/constant'
@@ -16,11 +16,11 @@ interface GeminiFilesProps {
const GeminiFiles: FC<GeminiFilesProps> = ({ id }) => { const GeminiFiles: FC<GeminiFilesProps> = ({ id }) => {
const { provider } = useProvider(id) const { provider } = useProvider(id)
const [files, setFiles] = useState<FileMetadataResponse[]>([]) const [files, setFiles] = useState<File[]>([])
const [loading, setLoading] = useState(false) const [loading, setLoading] = useState(false)
const fetchFiles = useCallback(async () => { const fetchFiles = useCallback(async () => {
const { files } = await window.api.gemini.listFiles(provider.apiKey) const files = await window.api.gemini.listFiles(provider.apiKey)
files && setFiles(files.filter((file) => file.state === 'ACTIVE')) files && setFiles(files.filter((file) => file.state === 'ACTIVE'))
}, [provider]) }, [provider])
@@ -60,14 +60,14 @@ const GeminiFiles: FC<GeminiFilesProps> = ({ id }) => {
key={file.name} key={file.name}
fileInfo={{ fileInfo={{
name: file.displayName, name: file.displayName,
ext: `.${file.name.split('.').pop()}`, ext: `.${file.name?.split('.').pop()}`,
extra: `${dayjs(file.createTime).format('MM-DD HH:mm')} · ${(parseInt(file.sizeBytes) / MB).toFixed(2)} MB`, extra: `${dayjs(file.createTime).format('MM-DD HH:mm')} · ${(parseInt(file.sizeBytes || '0') / MB).toFixed(2)} MB`,
actions: ( actions: (
<DeleteOutlined <DeleteOutlined
style={{ cursor: 'pointer', color: 'var(--color-error)' }} style={{ cursor: 'pointer', color: 'var(--color-error)' }}
onClick={() => { onClick={() => {
setFiles(files.filter((f) => f.name !== file.name)) setFiles(files.filter((f) => f.name !== file.name))
window.api.gemini.deleteFile(provider.apiKey, file.name).catch((error) => { window.api.gemini.deleteFile(file.name!, provider.apiKey).catch((error) => {
console.error('Failed to delete file:', error) console.error('Failed to delete file:', error)
setFiles((prev) => [...prev, file]) setFiles((prev) => [...prev, file])
}) })

View File

@@ -173,13 +173,13 @@ const ServerName = styled.div`
const ServerDescription = styled.div` const ServerDescription = styled.div`
font-size: 0.85rem; font-size: 0.85rem;
color: ${(props) => props.theme.colors?.textSecondary || '#8c8c8c'}; color: var(--color-text-2);
margin-bottom: 3px; margin-bottom: 3px;
` `
const ServerUrl = styled.div` const ServerUrl = styled.div`
font-size: 0.8rem; font-size: 0.8rem;
color: ${(props) => props.theme.colors?.textTertiary || '#bfbfbf'}; color: var(--color-text-3);
white-space: nowrap; white-space: nowrap;
overflow: hidden; overflow: hidden;
text-overflow: ellipsis; text-overflow: ellipsis;

View File

@@ -1,25 +1,18 @@
import {
ContentListUnion,
createPartFromBase64,
FinishReason,
GenerateContentResponse,
GoogleGenAI
} from '@google/genai'
import { import {
Content, Content,
FileDataPart, File,
GenerateContentStreamResult, GenerateContentConfig,
GoogleGenerativeAI, GenerateContentResponse,
GoogleGenAI,
HarmBlockThreshold, HarmBlockThreshold,
HarmCategory, HarmCategory,
InlineDataPart, Modality,
Part, Part,
RequestOptions, PartUnion,
SafetySetting, SafetySetting,
TextPart, ToolListUnion
Tool } from '@google/genai'
} from '@google/generative-ai' import { isGemmaModel, isGenerateImageModel, isVisionModel, isWebSearchModel } from '@renderer/config/models'
import { isGemmaModel, isVisionModel, isWebSearchModel } from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings' import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n' import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
@@ -39,22 +32,15 @@ import axios from 'axios'
import { flatten, isEmpty, takeRight } from 'lodash' import { flatten, isEmpty, takeRight } from 'lodash'
import OpenAI from 'openai' import OpenAI from 'openai'
import { ChunkCallbackData, CompletionsParams } from '.' import { CompletionsParams } from '.'
import BaseProvider from './BaseProvider' import BaseProvider from './BaseProvider'
export default class GeminiProvider extends BaseProvider { export default class GeminiProvider extends BaseProvider {
private sdk: GoogleGenerativeAI private sdk: GoogleGenAI
private requestOptions: RequestOptions
private imageSdk: GoogleGenAI
constructor(provider: Provider) { constructor(provider: Provider) {
super(provider) super(provider)
this.sdk = new GoogleGenerativeAI(this.apiKey) this.sdk = new GoogleGenAI({ vertexai: false, apiKey: this.apiKey, httpOptions: { baseUrl: this.getBaseURL() } })
/// this sdk is experimental
this.imageSdk = new GoogleGenAI({ apiKey: this.apiKey, httpOptions: { baseUrl: this.getBaseURL() } })
this.requestOptions = {
baseUrl: this.getBaseURL()
}
} }
public getBaseURL(): string { public getBaseURL(): string {
@@ -76,31 +62,31 @@ export default class GeminiProvider extends BaseProvider {
inlineData: { inlineData: {
data, data,
mimeType mimeType
} as Part['inlineData']
} }
} as InlineDataPart
} }
// Retrieve file from Gemini uploaded files // Retrieve file from Gemini uploaded files
const fileMetadata = await window.api.gemini.retrieveFile(file, this.apiKey) const fileMetadata: File | undefined = await window.api.gemini.retrieveFile(file, this.apiKey)
if (fileMetadata) { if (fileMetadata) {
return { return {
fileData: { fileData: {
fileUri: fileMetadata.uri, fileUri: fileMetadata.uri,
mimeType: fileMetadata.mimeType mimeType: fileMetadata.mimeType
} as Part['fileData']
} }
} as FileDataPart
} }
// If file is not found, upload it to Gemini // If file is not found, upload it to Gemini
const uploadResult = await window.api.gemini.uploadFile(file, this.apiKey) const result = await window.api.gemini.uploadFile(file, this.apiKey)
return { return {
fileData: { fileData: {
fileUri: uploadResult.file.uri, fileUri: result.uri,
mimeType: uploadResult.file.mimeType mimeType: result.mimeType
} as Part['fileData']
} }
} as FileDataPart
} }
/** /**
@@ -125,8 +111,8 @@ export default class GeminiProvider extends BaseProvider {
inlineData: { inlineData: {
data: base64Data, data: base64Data,
mimeType: mimeType mimeType: mimeType
} } as Part['inlineData']
} as InlineDataPart) })
} }
} }
} }
@@ -139,8 +125,8 @@ export default class GeminiProvider extends BaseProvider {
inlineData: { inlineData: {
data: base64Data.base64, data: base64Data.base64,
mimeType: base64Data.mime mimeType: base64Data.mime
} } as Part['inlineData']
} as InlineDataPart) })
} }
if (file.ext === '.pdf') { if (file.ext === '.pdf') {
@@ -152,13 +138,13 @@ export default class GeminiProvider extends BaseProvider {
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim() const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
parts.push({ parts.push({
text: file.origin_name + '\n' + fileContent text: file.origin_name + '\n' + fileContent
} as TextPart) })
} }
} }
return { return {
role, role,
parts parts: parts
} }
} }
@@ -204,10 +190,13 @@ export default class GeminiProvider extends BaseProvider {
* @param onChunk - The onChunk callback * @param onChunk - The onChunk callback
* @param onFilterMessages - The onFilterMessages callback * @param onFilterMessages - The onFilterMessages callback
*/ */
public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) { public async completions({
if (assistant.enableGenerateImage) { messages,
await this.generateImageExp({ messages, assistant, onFilterMessages, onChunk }) assistant,
} else { mcpTools,
onChunk,
onFilterMessages
}: CompletionsParams): Promise<void> {
const defaultModel = getDefaultModel() const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel const model = assistant.model || defaultModel
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
@@ -232,7 +221,7 @@ export default class GeminiProvider extends BaseProvider {
} }
// const tools = mcpToolsToGeminiTools(mcpTools) // const tools = mcpToolsToGeminiTools(mcpTools)
const tools: Tool[] = [] const tools: ToolListUnion = []
const toolResponses: MCPToolResponse[] = [] const toolResponses: MCPToolResponse[] = []
if (!WebSearchService.isOverwriteEnabled() && assistant.enableWebSearch && isWebSearchModel(model)) { if (!WebSearchService.isOverwriteEnabled() && assistant.enableWebSearch && isWebSearchModel(model)) {
@@ -242,55 +231,81 @@ export default class GeminiProvider extends BaseProvider {
}) })
} }
const geminiModel = this.sdk.getGenerativeModel( const generateContentConfig: GenerateContentConfig = {
{ responseModalities: [Modality.TEXT, Modality.IMAGE],
model: model.id, responseMimeType: 'text/plain',
...(isGemmaModel(model) ? {} : { systemInstruction: systemInstruction }),
safetySettings: this.getSafetySettings(model.id), safetySettings: this.getSafetySettings(model.id),
tools: tools, // generate image don't need system instruction
generationConfig: { systemInstruction: isGemmaModel(model) || isGenerateImageModel(model) ? undefined : systemInstruction,
maxOutputTokens: maxTokens,
temperature: assistant?.settings?.temperature, temperature: assistant?.settings?.temperature,
topP: assistant?.settings?.topP, topP: assistant?.settings?.topP,
maxOutputTokens: maxTokens,
tools: tools,
...this.getCustomParameters(assistant) ...this.getCustomParameters(assistant)
} }
},
this.requestOptions
)
const chat = geminiModel.startChat({ history }) const messageContents: Content = await this.getMessageContents(userLastMessage!)
const messageContents = await this.getMessageContents(userLastMessage!)
const chat = this.sdk.chats.create({
model: model.id,
config: generateContentConfig,
history: history
})
if (isGemmaModel(model) && assistant.prompt) { if (isGemmaModel(model) && assistant.prompt) {
const isFirstMessage = history.length === 0 const isFirstMessage = history.length === 0
if (isFirstMessage) { if (isFirstMessage && messageContents) {
const systemMessage = { const systemMessage = [
role: 'user',
parts: [
{ {
text: text:
'<start_of_turn>user\n' + '<start_of_turn>user\n' +
systemInstruction + systemInstruction +
'<end_of_turn>\n' + '<end_of_turn>\n' +
'<start_of_turn>user\n' + '<start_of_turn>user\n' +
messageContents.parts[0].text + (messageContents?.parts?.[0] as Part).text +
'<end_of_turn>' '<end_of_turn>'
} }
] ] as Part[]
if (messageContents && messageContents.parts) {
messageContents.parts[0] = systemMessage[0]
} }
messageContents.parts = systemMessage.parts
} }
} }
const start_time_millsec = new Date().getTime() const start_time_millsec = new Date().getTime()
const { abortController, cleanup } = this.createAbortController(userLastMessage?.id)
const { signal } = abortController const { cleanup, abortController } = this.createAbortController(userLastMessage?.id, true)
const signalProxy = {
_originalSignal: abortController.signal,
addEventListener: (eventName: string, listener: () => void) => {
if (eventName === 'abort') {
abortController.signal.addEventListener('abort', listener)
}
},
removeEventListener: (eventName: string, listener: () => void) => {
if (eventName === 'abort') {
abortController.signal.removeEventListener('abort', listener)
}
},
get aborted() {
return abortController.signal.aborted
}
}
if (!streamOutput) { if (!streamOutput) {
const { response } = await chat.sendMessage(messageContents.parts, { signal }) const response = await chat.sendMessage({
message: messageContents as PartUnion,
config: {
...generateContentConfig,
httpOptions: {
signal: signalProxy as any
}
}
})
const time_completion_millsec = new Date().getTime() - start_time_millsec const time_completion_millsec = new Date().getTime() - start_time_millsec
onChunk({ onChunk({
text: response.candidates?.[0].content.parts[0].text, text: response.text,
usage: { usage: {
prompt_tokens: response.usageMetadata?.promptTokenCount || 0, prompt_tokens: response.usageMetadata?.promptTokenCount || 0,
completion_tokens: response.usageMetadata?.candidatesTokenCount || 0, completion_tokens: response.usageMetadata?.candidatesTokenCount || 0,
@@ -306,7 +321,15 @@ export default class GeminiProvider extends BaseProvider {
return return
} }
const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal }) const userMessagesStream = await chat.sendMessageStream({
message: messageContents as PartUnion,
config: {
...generateContentConfig,
httpOptions: {
signal: signalProxy as any
}
}
})
let time_first_token_millsec = 0 let time_first_token_millsec = 0
const processToolUses = async (content: string, idx: number) => { const processToolUses = async (content: string, idx: number) => {
@@ -321,17 +344,27 @@ export default class GeminiProvider extends BaseProvider {
) )
if (toolResults && toolResults.length > 0) { if (toolResults && toolResults.length > 0) {
history.push(messageContents) history.push(messageContents)
const newChat = geminiModel.startChat({ history }) const newChat = this.sdk.chats.create({
const newStream = await newChat.sendMessageStream(flatten(toolResults.map((ts) => (ts as Content).parts)), { model: model.id,
signal config: generateContentConfig,
history: history as Content[]
})
const newStream = await newChat.sendMessageStream({
message: flatten(toolResults.map((ts) => (ts as Content).parts)) as PartUnion,
config: {
...generateContentConfig,
httpOptions: {
signal: signalProxy as any
}
}
}) })
await processStream(newStream, idx + 1) await processStream(newStream, idx + 1)
} }
} }
const processStream = async (stream: GenerateContentStreamResult, idx: number) => { const processStream = async (stream: AsyncGenerator<GenerateContentResponse>, idx: number) => {
let content = '' let content = ''
for await (const chunk of stream.stream) { for await (const chunk of stream) {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
if (time_first_token_millsec == 0) { if (time_first_token_millsec == 0) {
@@ -340,11 +373,14 @@ export default class GeminiProvider extends BaseProvider {
const time_completion_millsec = new Date().getTime() - start_time_millsec const time_completion_millsec = new Date().getTime() - start_time_millsec
content += chunk.text() if (chunk.text !== undefined) {
processToolUses(content, idx) content += chunk.text
}
await processToolUses(content, idx)
const generateImage = this.processGeminiImageResponse(chunk)
onChunk({ onChunk({
text: chunk.text(), text: chunk.text !== undefined ? chunk.text : '',
usage: { usage: {
prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0, prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0, completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
@@ -356,14 +392,14 @@ export default class GeminiProvider extends BaseProvider {
time_first_token_millsec time_first_token_millsec
}, },
search: chunk.candidates?.[0]?.groundingMetadata, search: chunk.candidates?.[0]?.groundingMetadata,
mcpToolResponse: toolResponses mcpToolResponse: toolResponses,
generateImage: generateImage
}) })
} }
} }
await processStream(userMessagesStream, 0).finally(cleanup) await processStream(userMessagesStream, 0).finally(cleanup)
} }
}
/** /**
* Translate a message * Translate a message
@@ -372,39 +408,51 @@ export default class GeminiProvider extends BaseProvider {
* @param onResponse - The onResponse callback * @param onResponse - The onResponse callback
* @returns The translated message * @returns The translated message
*/ */
async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) { public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
const defaultModel = getDefaultModel() const defaultModel = getDefaultModel()
const { maxTokens } = getAssistantSettings(assistant) const { maxTokens } = getAssistantSettings(assistant)
const model = assistant.model || defaultModel const model = assistant.model || defaultModel
const geminiModel = this.sdk.getGenerativeModel(
{
model: model.id,
...(isGemmaModel(model) ? {} : { systemInstruction: assistant.prompt }),
generationConfig: {
maxOutputTokens: maxTokens,
temperature: assistant?.settings?.temperature
}
},
this.requestOptions
)
const content = const content =
isGemmaModel(model) && assistant.prompt isGemmaModel(model) && assistant.prompt
? `<start_of_turn>user\n${assistant.prompt}<end_of_turn>\n<start_of_turn>user\n${message.content}<end_of_turn>` ? `<start_of_turn>user\n${assistant.prompt}<end_of_turn>\n<start_of_turn>user\n${message.content}<end_of_turn>`
: message.content : message.content
if (!onResponse) { if (!onResponse) {
const { response } = await geminiModel.generateContent(content) const response = await this.sdk.models.generateContent({
return response.text() model: model.id,
config: {
maxOutputTokens: maxTokens,
temperature: assistant?.settings?.temperature,
systemInstruction: isGemmaModel(model) ? undefined : assistant.prompt
},
contents: [
{
role: 'user',
parts: [{ text: content }]
}
]
})
return response.text || ''
} }
const response = await geminiModel.generateContentStream(content) const response = await this.sdk.models.generateContentStream({
model: model.id,
config: {
maxOutputTokens: maxTokens,
temperature: assistant?.settings?.temperature,
systemInstruction: isGemmaModel(model) ? undefined : assistant.prompt
},
contents: [
{
role: 'user',
parts: [{ text: content }]
}
]
})
let text = '' let text = ''
for await (const chunk of response.stream) { for await (const chunk of response) {
text += chunk.text() text += chunk.text
onResponse(text) onResponse(text)
} }
@@ -442,25 +490,24 @@ export default class GeminiProvider extends BaseProvider {
content: userMessageContent content: userMessageContent
} }
const geminiModel = this.sdk.getGenerativeModel(
{
model: model.id,
...(isGemmaModel(model) ? {} : { systemInstruction: systemMessage.content }),
generationConfig: {
temperature: assistant?.settings?.temperature
}
},
this.requestOptions
)
const chat = await geminiModel.startChat()
const content = isGemmaModel(model) const content = isGemmaModel(model)
? `<start_of_turn>user\n${systemMessage.content}<end_of_turn>\n<start_of_turn>user\n${userMessage.content}<end_of_turn>` ? `<start_of_turn>user\n${systemMessage.content}<end_of_turn>\n<start_of_turn>user\n${userMessage.content}<end_of_turn>`
: userMessage.content : userMessage.content
const { response } = await chat.sendMessage(content) const response = await this.sdk.models.generateContent({
model: model.id,
config: {
systemInstruction: isGemmaModel(model) ? undefined : systemMessage.content
},
contents: [
{
role: 'user',
parts: [{ text: content }]
}
]
})
return removeSpecialCharactersForTopicName(response.text()) return removeSpecialCharactersForTopicName(response.text || '')
} }
/** /**
@@ -471,24 +518,23 @@ export default class GeminiProvider extends BaseProvider {
*/ */
public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> { public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
const model = getDefaultModel() const model = getDefaultModel()
const systemMessage = { role: 'system', content: prompt } const MessageContent = isGemmaModel(model)
const geminiModel = this.sdk.getGenerativeModel(
{
model: model.id,
...(isGemmaModel(model) ? {} : { systemInstruction: systemMessage.content })
},
this.requestOptions
)
const chat = await geminiModel.startChat()
const messageContent = isGemmaModel(model)
? `<start_of_turn>user\n${prompt}<end_of_turn>\n<start_of_turn>user\n${content}<end_of_turn>` ? `<start_of_turn>user\n${prompt}<end_of_turn>\n<start_of_turn>user\n${content}<end_of_turn>`
: content : content
const response = await this.sdk.models.generateContent({
model: model.id,
config: {
systemInstruction: isGemmaModel(model) ? undefined : prompt
},
contents: [
{
role: 'user',
parts: [{ text: MessageContent }]
}
]
})
const { response } = await chat.sendMessage(messageContent) return response.text || ''
return response.text()
} }
/** /**
@@ -518,24 +564,28 @@ export default class GeminiProvider extends BaseProvider {
content: messages.map((m) => m.content).join('\n') content: messages.map((m) => m.content).join('\n')
} }
const geminiModel = this.sdk.getGenerativeModel( const content = isGemmaModel(model)
{ ? `<start_of_turn>user\n${systemMessage.content}<end_of_turn>\n<start_of_turn>user\n${userMessage.content}<end_of_turn>`
: userMessage.content
const response = await this.sdk.models.generateContent({
model: model.id, model: model.id,
systemInstruction: systemMessage.content, config: {
generationConfig: { systemInstruction: isGemmaModel(model) ? undefined : systemMessage.content,
temperature: assistant?.settings?.temperature temperature: assistant?.settings?.temperature,
} httpOptions: {
},
{
...this.requestOptions,
timeout: 20 * 1000 timeout: 20 * 1000
} }
) },
contents: [
{
role: 'user',
parts: [{ text: content }]
}
]
})
const chat = await geminiModel.startChat() return response.text || ''
const { response } = await chat.sendMessage(userMessage.content)
return response.text()
} }
/** /**
@ -546,144 +596,13 @@ export default class GeminiProvider extends BaseProvider {
return [] return []
} }
/**
*
* @param messages -
* @param assistant -
* @param onChunk -
* @param onFilterMessages -
* @returns Promise<void>
*/
private async generateImageExp({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, streamOutput, maxTokens } = getAssistantSettings(assistant)
const userMessages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
onFilterMessages(userMessages)
const userLastMessage = userMessages.pop()
if (!userLastMessage) {
throw new Error('No user message found')
}
const history: Content[] = []
for (const message of userMessages) {
history.push(await this.getMessageContents(message))
}
const userLastMessageContent = await this.getMessageContents(userLastMessage)
const allContents = [...history, userLastMessageContent]
let contents: ContentListUnion = allContents.length > 0 ? (allContents as ContentListUnion) : []
contents = await this.addImageFileToContents(userLastMessage, contents)
if (!streamOutput) {
const response = await this.callGeminiGenerateContent(model.id, contents, maxTokens)
const { isValid, message } = this.isValidGeminiResponse(response)
if (!isValid) {
throw new Error(`Gemini API error: ${message}`)
}
this.processGeminiImageResponse(response, onChunk)
return
}
const response = await this.callGeminiGenerateContentStream(model.id, contents, maxTokens)
for await (const chunk of response) {
this.processGeminiImageResponse(chunk, onChunk)
}
}
/**
*
* @param message -
* @param contents -
* @returns
*/
private async addImageFileToContents(message: Message, contents: ContentListUnion): Promise<ContentListUnion> {
if (message.files && message.files.length > 0) {
const file = message.files[0]
const fileContent = await window.api.file.base64Image(file.id + file.ext)
if (fileContent && fileContent.base64) {
const contentsArray = Array.isArray(contents) ? contents : [contents]
return [...contentsArray, createPartFromBase64(fileContent.base64, fileContent.mime)]
}
}
return contents
}
/**
* Gemini API生成内容
* @param modelId - ID
* @param contents -
* @returns
*/
private async callGeminiGenerateContent(
modelId: string,
contents: ContentListUnion,
maxTokens?: number
): Promise<GenerateContentResponse> {
try {
return await this.imageSdk.models.generateContent({
model: modelId,
contents: contents,
config: {
responseModalities: ['Text', 'Image'],
responseMimeType: 'text/plain',
maxOutputTokens: maxTokens
}
})
} catch (error) {
console.error('Gemini API error:', error)
throw error
}
}
private async callGeminiGenerateContentStream(
modelId: string,
contents: ContentListUnion,
maxTokens?: number
): Promise<AsyncGenerator<GenerateContentResponse>> {
try {
return await this.imageSdk.models.generateContentStream({
model: modelId,
contents: contents,
config: {
responseModalities: ['Text', 'Image'],
responseMimeType: 'text/plain',
maxOutputTokens: maxTokens
}
})
} catch (error) {
console.error('Gemini API error:', error)
throw error
}
}
/**
* Gemini响应是否有效
* @param response - Gemini响应
* @returns
*/
private isValidGeminiResponse(response: GenerateContentResponse): { isValid: boolean; message: string } {
return {
isValid: response?.candidates?.[0]?.finishReason === FinishReason.STOP ? true : false,
message: response?.candidates?.[0]?.finishReason || ''
}
}
/** /**
* Gemini图像响应 * Gemini图像响应
* @param response - Gemini响应 * @param response - Gemini响应
* @param onChunk - * @param onChunk -
*/ */
private processGeminiImageResponse(response: any, onChunk: (chunk: ChunkCallbackData) => void): void { private processGeminiImageResponse(chunk: GenerateContentResponse): { type: 'base64'; images: string[] } | undefined {
const parts = response.candidates[0].content.parts const parts = chunk.candidates?.[0]?.content?.parts
if (!parts) { if (!parts) {
return return
} }
@ -695,31 +614,13 @@ export default class GeminiProvider extends BaseProvider {
return null return null
} }
const dataPrefix = `data:${part.inlineData.mimeType || 'image/png'};base64,` const dataPrefix = `data:${part.inlineData.mimeType || 'image/png'};base64,`
return part.inlineData.data.startsWith('data:') ? part.inlineData.data : dataPrefix + part.inlineData.data return part.inlineData.data?.startsWith('data:') ? part.inlineData.data : dataPrefix + part.inlineData.data
}) })
// 提取文本数据 return {
const text = parts
.filter((part: Part) => part.text !== undefined)
.map((part: Part) => part.text)
.join('')
// 返回结果
onChunk({
text,
generateImage: {
type: 'base64', type: 'base64',
images images: images.filter((image) => image !== null)
},
usage: {
prompt_tokens: response.usageMetadata?.promptTokenCount || 0,
completion_tokens: response.usageMetadata?.candidatesTokenCount || 0,
total_tokens: response.usageMetadata?.totalTokenCount || 0
},
metrics: {
completion_tokens: response.usageMetadata?.candidatesTokenCount
} }
})
} }
/** /**
@ -732,18 +633,16 @@ export default class GeminiProvider extends BaseProvider {
return { valid: false, error: new Error('No model found') } return { valid: false, error: new Error('No model found') }
} }
const body = {
model: model.id,
messages: [{ role: 'user', content: 'hi' }],
max_tokens: 100,
stream: false
}
try { try {
const geminiModel = this.sdk.getGenerativeModel({ model: body.model }, this.requestOptions) const result = await this.sdk.models.generateContent({
const result = await geminiModel.generateContent(body.messages[0].content) model: model.id,
contents: [{ role: 'user', parts: [{ text: 'hi' }] }],
config: {
maxOutputTokens: 100
}
})
return { return {
valid: !isEmpty(result.response.text()), valid: !isEmpty(result.text),
error: null error: null
} }
} catch (error: any) { } catch (error: any) {
@ -785,7 +684,10 @@ export default class GeminiProvider extends BaseProvider {
* @returns The embedding dimensions * @returns The embedding dimensions
*/ */
public async getEmbeddingDimensions(model: Model): Promise<number> { public async getEmbeddingDimensions(model: Model): Promise<number> {
const data = await this.sdk.getGenerativeModel({ model: model.id }, this.requestOptions).embedContent('hi') const data = await this.sdk.models.embedContent({
return data.embedding.values.length model: model.id,
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
})
return data.embeddings?.[0]?.values?.length || 0
} }
} }

View File

@ -1,4 +1,4 @@
import type { GroundingMetadata } from '@google/generative-ai' import type { GroundingMetadata } from '@google/genai'
import BaseProvider from '@renderer/providers/AiProvider/BaseProvider' import BaseProvider from '@renderer/providers/AiProvider/BaseProvider'
import ProviderFactory from '@renderer/providers/AiProvider/ProviderFactory' import ProviderFactory from '@renderer/providers/AiProvider/ProviderFactory'
import type { import type {

View File

@ -282,6 +282,7 @@ export async function fetchChatCompletion({
} }
} }
console.log('message', message)
// Emit chat completion event // Emit chat completion event
EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message) EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message)
onResponse(message) onResponse(message)

View File

@ -1,4 +1,4 @@
import { GroundingMetadata } from '@google/generative-ai' import { GroundingMetadata } from '@google/genai'
import OpenAI from 'openai' import OpenAI from 'openai'
import React from 'react' import React from 'react'
import { BuiltinTheme } from 'shiki' import { BuiltinTheme } from 'shiki'

View File

@ -71,8 +71,8 @@ export function withGeminiGrounding(message: Message) {
let content = message.content let content = message.content
groundingSupports.forEach((support) => { groundingSupports.forEach((support) => {
const text = support?.segment const text = support?.segment?.text
const indices = support?.groundingChunckIndices const indices = support?.groundingChunkIndices
if (!text || !indices) return if (!text || !indices) return

View File

@ -1,186 +1,165 @@
import { ContentBlockParam, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources' import { ContentBlockParam, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
import { MessageParam } from '@anthropic-ai/sdk/resources' import { MessageParam } from '@anthropic-ai/sdk/resources'
import { import { Content, FunctionCall, Part } from '@google/genai'
ArraySchema,
BaseSchema,
BooleanSchema,
EnumStringSchema,
FunctionCall,
FunctionDeclaration,
FunctionDeclarationSchema,
FunctionDeclarationSchemaProperty,
IntegerSchema,
NumberSchema,
ObjectSchema,
SchemaType,
SimpleStringSchema,
Tool as geminiTool
} from '@google/generative-ai'
import { Content, Part } from '@google/generative-ai'
import store from '@renderer/store' import store from '@renderer/store'
import { MCPCallToolResponse, MCPServer, MCPTool, MCPToolResponse } from '@renderer/types' import { MCPCallToolResponse, MCPServer, MCPTool, MCPToolResponse } from '@renderer/types'
import { import { ChatCompletionContentPart, ChatCompletionMessageParam, ChatCompletionMessageToolCall } from 'openai/resources'
ChatCompletionContentPart,
ChatCompletionMessageParam,
ChatCompletionMessageToolCall,
ChatCompletionTool
} from 'openai/resources'
import { ChunkCallbackData, CompletionsParams } from '../providers/AiProvider' import { ChunkCallbackData, CompletionsParams } from '../providers/AiProvider'
const ensureValidSchema = (obj: Record<string, any>): FunctionDeclarationSchemaProperty => { // const ensureValidSchema = (obj: Record<string, any>) => {
// Filter out unsupported keys for Gemini // // Filter out unsupported keys for Gemini
const filteredObj = filterUnsupportedKeys(obj) // const filteredObj = filterUnsupportedKeys(obj)
// Handle base schema properties // // Handle base schema properties
const baseSchema = { // const baseSchema = {
description: filteredObj.description, // description: filteredObj.description,
nullable: filteredObj.nullable // nullable: filteredObj.nullable
} as BaseSchema // } as BaseSchema
// Handle string type // // Handle string type
if (filteredObj.type?.toLowerCase() === SchemaType.STRING) { // if (filteredObj.type?.toLowerCase() === SchemaType.STRING) {
if (filteredObj.enum && Array.isArray(filteredObj.enum)) { // if (filteredObj.enum && Array.isArray(filteredObj.enum)) {
return { // return {
...baseSchema, // ...baseSchema,
type: SchemaType.STRING, // type: SchemaType.STRING,
format: 'enum', // format: 'enum',
enum: filteredObj.enum as string[] // enum: filteredObj.enum as string[]
} as EnumStringSchema // } as EnumStringSchema
} // }
return { // return {
...baseSchema, // ...baseSchema,
type: SchemaType.STRING, // type: SchemaType.STRING,
format: filteredObj.format === 'date-time' ? 'date-time' : undefined // format: filteredObj.format === 'date-time' ? 'date-time' : undefined
} as SimpleStringSchema // } as SimpleStringSchema
} // }
// Handle number type // // Handle number type
if (filteredObj.type?.toLowerCase() === SchemaType.NUMBER) { // if (filteredObj.type?.toLowerCase() === SchemaType.NUMBER) {
return { // return {
...baseSchema, // ...baseSchema,
type: SchemaType.NUMBER, // type: SchemaType.NUMBER,
format: ['float', 'double'].includes(filteredObj.format) ? (filteredObj.format as 'float' | 'double') : undefined // format: ['float', 'double'].includes(filteredObj.format) ? (filteredObj.format as 'float' | 'double') : undefined
} as NumberSchema // } as NumberSchema
} // }
// Handle integer type // // Handle integer type
if (filteredObj.type?.toLowerCase() === SchemaType.INTEGER) { // if (filteredObj.type?.toLowerCase() === SchemaType.INTEGER) {
return { // return {
...baseSchema, // ...baseSchema,
type: SchemaType.INTEGER, // type: SchemaType.INTEGER,
format: ['int32', 'int64'].includes(filteredObj.format) ? (filteredObj.format as 'int32' | 'int64') : undefined // format: ['int32', 'int64'].includes(filteredObj.format) ? (filteredObj.format as 'int32' | 'int64') : undefined
} as IntegerSchema // } as IntegerSchema
} // }
// Handle boolean type // // Handle boolean type
if (filteredObj.type?.toLowerCase() === SchemaType.BOOLEAN) { // if (filteredObj.type?.toLowerCase() === SchemaType.BOOLEAN) {
return { // return {
...baseSchema, // ...baseSchema,
type: SchemaType.BOOLEAN // type: SchemaType.BOOLEAN
} as BooleanSchema // } as BooleanSchema
} // }
// Handle array type // // Handle array type
if (filteredObj.type?.toLowerCase() === SchemaType.ARRAY) { // if (filteredObj.type?.toLowerCase() === SchemaType.ARRAY) {
return { // return {
...baseSchema, // ...baseSchema,
type: SchemaType.ARRAY, // type: SchemaType.ARRAY,
items: filteredObj.items // items: filteredObj.items
? ensureValidSchema(filteredObj.items as Record<string, any>) // ? ensureValidSchema(filteredObj.items as Record<string, any>)
: ({ type: SchemaType.STRING } as SimpleStringSchema), // : ({ type: SchemaType.STRING } as SimpleStringSchema),
minItems: filteredObj.minItems, // minItems: filteredObj.minItems,
maxItems: filteredObj.maxItems // maxItems: filteredObj.maxItems
} as ArraySchema // } as ArraySchema
} // }
// Handle object type (default) // // Handle object type (default)
const properties = filteredObj.properties // const properties = filteredObj.properties
? Object.fromEntries( // ? Object.fromEntries(
Object.entries(filteredObj.properties).map(([key, value]) => [ // Object.entries(filteredObj.properties).map(([key, value]) => [
key, // key,
ensureValidSchema(value as Record<string, any>) // ensureValidSchema(value as Record<string, any>)
]) // ])
) // )
: { _empty: { type: SchemaType.STRING } as SimpleStringSchema } // Ensure properties is never empty // : { _empty: { type: SchemaType.STRING } as SimpleStringSchema } // Ensure properties is never empty
return { // return {
...baseSchema, // ...baseSchema,
type: SchemaType.OBJECT, // type: SchemaType.OBJECT,
properties, // properties,
required: Array.isArray(filteredObj.required) ? filteredObj.required : undefined // required: Array.isArray(filteredObj.required) ? filteredObj.required : undefined
} as ObjectSchema // } as ObjectSchema
} // }
function filterUnsupportedKeys(obj: Record<string, any>): Record<string, any> { // function filterUnsupportedKeys(obj: Record<string, any>): Record<string, any> {
const supportedBaseKeys = ['description', 'nullable'] // const supportedBaseKeys = ['description', 'nullable']
const supportedStringKeys = [...supportedBaseKeys, 'type', 'format', 'enum'] // const supportedStringKeys = [...supportedBaseKeys, 'type', 'format', 'enum']
const supportedNumberKeys = [...supportedBaseKeys, 'type', 'format'] // const supportedNumberKeys = [...supportedBaseKeys, 'type', 'format']
const supportedBooleanKeys = [...supportedBaseKeys, 'type'] // const supportedBooleanKeys = [...supportedBaseKeys, 'type']
const supportedArrayKeys = [...supportedBaseKeys, 'type', 'items', 'minItems', 'maxItems'] // const supportedArrayKeys = [...supportedBaseKeys, 'type', 'items', 'minItems', 'maxItems']
const supportedObjectKeys = [...supportedBaseKeys, 'type', 'properties', 'required'] // const supportedObjectKeys = [...supportedBaseKeys, 'type', 'properties', 'required']
const filtered: Record<string, any> = {} // const filtered: Record<string, any> = {}
let keysToKeep: string[] // let keysToKeep: string[]
if (obj.type?.toLowerCase() === SchemaType.STRING) { // if (obj.type?.toLowerCase() === SchemaType.STRING) {
keysToKeep = supportedStringKeys // keysToKeep = supportedStringKeys
} else if (obj.type?.toLowerCase() === SchemaType.NUMBER) { // } else if (obj.type?.toLowerCase() === SchemaType.NUMBER) {
keysToKeep = supportedNumberKeys // keysToKeep = supportedNumberKeys
} else if (obj.type?.toLowerCase() === SchemaType.INTEGER) { // } else if (obj.type?.toLowerCase() === SchemaType.INTEGER) {
keysToKeep = supportedNumberKeys // keysToKeep = supportedNumberKeys
} else if (obj.type?.toLowerCase() === SchemaType.BOOLEAN) { // } else if (obj.type?.toLowerCase() === SchemaType.BOOLEAN) {
keysToKeep = supportedBooleanKeys // keysToKeep = supportedBooleanKeys
} else if (obj.type?.toLowerCase() === SchemaType.ARRAY) { // } else if (obj.type?.toLowerCase() === SchemaType.ARRAY) {
keysToKeep = supportedArrayKeys // keysToKeep = supportedArrayKeys
} else { // } else {
// Default to object type // // Default to object type
keysToKeep = supportedObjectKeys // keysToKeep = supportedObjectKeys
} // }
// copy supported keys // // copy supported keys
for (const key of keysToKeep) { // for (const key of keysToKeep) {
if (obj[key] !== undefined) { // if (obj[key] !== undefined) {
filtered[key] = obj[key] // filtered[key] = obj[key]
} // }
} // }
return filtered // return filtered
} // }
function filterPropertieAttributes(tool: MCPTool, filterNestedObj: boolean = false): Record<string, object> { // function filterPropertieAttributes(tool: MCPTool, filterNestedObj: boolean = false): Record<string, object> {
const properties = tool.inputSchema.properties // const properties = tool.inputSchema.properties
if (!properties) { // if (!properties) {
return {} // return {}
} // }
// For OpenAI, we don't need to validate as strictly // // For OpenAI, we don't need to validate as strictly
if (!filterNestedObj) { // if (!filterNestedObj) {
return properties // return properties
} // }
const processedProperties = Object.fromEntries( // const processedProperties = Object.fromEntries(
Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)]) // Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
) // )
return processedProperties // return processedProperties
} // }
export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> { // export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
return mcpTools.map((tool) => ({ // return mcpTools.map((tool) => ({
type: 'function', // type: 'function',
name: tool.name, // name: tool.name,
function: { // function: {
name: tool.id, // name: tool.id,
description: tool.description, // description: tool.description,
parameters: { // parameters: {
type: 'object', // type: 'object',
properties: filterPropertieAttributes(tool) // properties: filterPropertieAttributes(tool)
} // }
} // }
})) // }))
} // }
export function openAIToolsToMcpTool( export function openAIToolsToMcpTool(
mcpTools: MCPTool[] | undefined, mcpTools: MCPTool[] | undefined,
@ -277,35 +256,35 @@ export function anthropicToolUseToMcpTool(mcpTools: MCPTool[] | undefined, toolU
return tool return tool
} }
export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTool[] { // export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTool[] {
if (!mcpTools || mcpTools.length === 0) { // if (!mcpTools || mcpTools.length === 0) {
// No tools available // // No tools available
return [] // return []
} // }
const functions: FunctionDeclaration[] = [] // const functions: FunctionDeclaration[] = []
for (const tool of mcpTools) { // for (const tool of mcpTools) {
const properties = filterPropertieAttributes(tool, true) // const properties = filterPropertieAttributes(tool, true)
const functionDeclaration: FunctionDeclaration = { // const functionDeclaration: FunctionDeclaration = {
name: tool.id, // name: tool.id,
description: tool.description, // description: tool.description,
parameters: { // parameters: {
type: SchemaType.OBJECT, // type: SchemaType.OBJECT,
properties: // properties:
Object.keys(properties).length > 0 // Object.keys(properties).length > 0
? Object.fromEntries( // ? Object.fromEntries(
Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)]) // Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
) // )
: { _empty: { type: SchemaType.STRING } as SimpleStringSchema } // : { _empty: { type: SchemaType.STRING } as SimpleStringSchema }
} as FunctionDeclarationSchema // } as FunctionDeclarationSchema
} // }
functions.push(functionDeclaration) // functions.push(functionDeclaration)
} // }
const tool: geminiTool = { // const tool: geminiTool = {
functionDeclarations: functions // functionDeclarations: functions
} // }
return [tool] // return [tool]
} // }
export function geminiFunctionCallToMcpTool( export function geminiFunctionCallToMcpTool(
mcpTools: MCPTool[] | undefined, mcpTools: MCPTool[] | undefined,

1671
yarn.lock

File diff suppressed because it is too large Load Diff