refactor(Gemini): migrate generative-ai sdk to genai sdk (#4939)
* refactor(GeminiService): migrate to new Google GenAI SDK and update file handling methods - Updated import statements to use the new Google GenAI SDK. - Refactored file upload, retrieval, and deletion methods to align with the new SDK's API. - Adjusted type definitions and response handling for improved type safety and clarity. - Enhanced file listing and processing logic to utilize async iteration for better performance. * refactor(GeminiProvider): update message handling and integrate abort signal support - Refactored message content handling to align with updated type definitions, ensuring consistent use of Content type. - Enhanced abort signal management for chat requests, allowing for better control over ongoing operations. - Improved message processing logic to streamline user message history handling and response generation. - Adjusted type definitions for message contents to enhance type safety and clarity. * refactor(electron.vite.config): replace direct import of Vite React plugin with dynamic import * fix(Gemini): clean up unused methods and improve property access * fix(typecheck): update color properties to use CSS variables * feat: 修改画图逻辑 * fix: import viteReact --------- Co-authored-by: eeee0717 <chentao020717Work@outlook.com>
This commit is contained in:
parent
35c50b54a8
commit
24ddd69cd5
37698
.yarn/patches/@google-genai-npm-0.8.0-450d0d9a7d.patch
vendored
Normal file
37698
.yarn/patches/@google-genai-npm-0.8.0-450d0d9a7d.patch
vendored
Normal file
File diff suppressed because one or more lines are too long
@ -1,4 +1,4 @@
|
||||
import react from '@vitejs/plugin-react'
|
||||
import viteReact from '@vitejs/plugin-react'
|
||||
import { defineConfig, externalizeDepsPlugin } from 'electron-vite'
|
||||
import { resolve } from 'path'
|
||||
import { visualizer } from 'rollup-plugin-visualizer'
|
||||
@ -6,7 +6,7 @@ import { visualizer } from 'rollup-plugin-visualizer'
|
||||
const visualizerPlugin = (type: 'renderer' | 'main') => {
|
||||
return process.env[`VISUALIZER_${type.toUpperCase()}`] ? [visualizer({ open: true })] : []
|
||||
}
|
||||
|
||||
// const viteReact = await import('@vitejs/plugin-react')
|
||||
export default defineConfig({
|
||||
main: {
|
||||
plugins: [
|
||||
@ -51,7 +51,7 @@ export default defineConfig({
|
||||
},
|
||||
renderer: {
|
||||
plugins: [
|
||||
react({
|
||||
viteReact({
|
||||
babel: {
|
||||
plugins: [
|
||||
[
|
||||
|
||||
@ -64,7 +64,6 @@
|
||||
"@cherrystudio/embedjs-openai": "^0.1.28",
|
||||
"@electron-toolkit/utils": "^3.0.0",
|
||||
"@electron/notarize": "^2.5.0",
|
||||
"@google/generative-ai": "^0.24.0",
|
||||
"@langchain/community": "^0.3.36",
|
||||
"@mozilla/readability": "^0.6.0",
|
||||
"@notionhq/client": "^2.2.15",
|
||||
@ -74,6 +73,7 @@
|
||||
"@xyflow/react": "^12.4.4",
|
||||
"adm-zip": "^0.5.16",
|
||||
"async-mutex": "^0.5.0",
|
||||
"bufferutil": "^4.0.9",
|
||||
"color": "^5.0.0",
|
||||
"diff": "^7.0.0",
|
||||
"docx": "^9.0.2",
|
||||
@ -96,6 +96,7 @@
|
||||
"turndown-plugin-gfm": "^1.0.2",
|
||||
"undici": "^7.4.0",
|
||||
"webdav": "^5.8.0",
|
||||
"ws": "^8.18.1",
|
||||
"zipread": "^1.3.3"
|
||||
},
|
||||
"devDependencies": {
|
||||
@ -112,7 +113,7 @@
|
||||
"@emotion/is-prop-valid": "^1.3.1",
|
||||
"@eslint-react/eslint-plugin": "^1.36.1",
|
||||
"@eslint/js": "^9.22.0",
|
||||
"@google/genai": "^0.4.0",
|
||||
"@google/genai": "patch:@google/genai@npm%3A0.8.0#~/.yarn/patches/@google-genai-npm-0.8.0-450d0d9a7d.patch",
|
||||
"@hello-pangea/dnd": "^16.6.0",
|
||||
"@kangfenmao/keyv-storage": "^0.1.0",
|
||||
"@modelcontextprotocol/sdk": "^1.9.0",
|
||||
@ -133,7 +134,8 @@
|
||||
"@types/react-dom": "^19.0.4",
|
||||
"@types/react-infinite-scroll-component": "^5.0.0",
|
||||
"@types/tinycolor2": "^1",
|
||||
"@vitejs/plugin-react": "^4.2.1",
|
||||
"@types/ws": "^8",
|
||||
"@vitejs/plugin-react": "4.3.4",
|
||||
"analytics": "^0.8.16",
|
||||
"antd": "^5.22.5",
|
||||
"applescript": "^1.0.0",
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { FileMetadataResponse, FileState, GoogleAIFileManager } from '@google/generative-ai/server'
|
||||
import { File, FileState, GoogleGenAI, Pager } from '@google/genai'
|
||||
import { FileType } from '@types'
|
||||
import fs from 'fs'
|
||||
|
||||
@ -8,11 +8,15 @@ export class GeminiService {
|
||||
private static readonly FILE_LIST_CACHE_KEY = 'gemini_file_list'
|
||||
private static readonly CACHE_DURATION = 3000
|
||||
|
||||
static async uploadFile(_: Electron.IpcMainInvokeEvent, file: FileType, apiKey: string) {
|
||||
const fileManager = new GoogleAIFileManager(apiKey)
|
||||
const uploadResult = await fileManager.uploadFile(file.path, {
|
||||
static async uploadFile(_: Electron.IpcMainInvokeEvent, file: FileType, apiKey: string): Promise<File> {
|
||||
const sdk = new GoogleGenAI({ vertexai: false, apiKey })
|
||||
const uploadResult = await sdk.files.upload({
|
||||
file: file.path,
|
||||
config: {
|
||||
mimeType: 'application/pdf',
|
||||
name: file.id,
|
||||
displayName: file.origin_name
|
||||
}
|
||||
})
|
||||
return uploadResult
|
||||
}
|
||||
@ -24,40 +28,42 @@ export class GeminiService {
|
||||
}
|
||||
}
|
||||
|
||||
static async retrieveFile(
|
||||
_: Electron.IpcMainInvokeEvent,
|
||||
file: FileType,
|
||||
apiKey: string
|
||||
): Promise<FileMetadataResponse | undefined> {
|
||||
const fileManager = new GoogleAIFileManager(apiKey)
|
||||
|
||||
static async retrieveFile(_: Electron.IpcMainInvokeEvent, file: FileType, apiKey: string): Promise<File | undefined> {
|
||||
const sdk = new GoogleGenAI({ vertexai: false, apiKey })
|
||||
const cachedResponse = CacheService.get<any>(GeminiService.FILE_LIST_CACHE_KEY)
|
||||
if (cachedResponse) {
|
||||
return GeminiService.processResponse(cachedResponse, file)
|
||||
}
|
||||
|
||||
const response = await fileManager.listFiles()
|
||||
const response = await sdk.files.list()
|
||||
CacheService.set(GeminiService.FILE_LIST_CACHE_KEY, response, GeminiService.CACHE_DURATION)
|
||||
|
||||
return GeminiService.processResponse(response, file)
|
||||
}
|
||||
|
||||
private static processResponse(response: any, file: FileType) {
|
||||
if (response.files) {
|
||||
return response.files
|
||||
.filter((file) => file.state === FileState.ACTIVE)
|
||||
.find((i) => i.displayName === file.origin_name && Number(i.sizeBytes) === file.size)
|
||||
private static async processResponse(response: Pager<File>, file: FileType) {
|
||||
for await (const f of response) {
|
||||
if (f.state === FileState.ACTIVE) {
|
||||
if (f.displayName === file.origin_name && Number(f.sizeBytes) === file.size) {
|
||||
return f
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
static async listFiles(_: Electron.IpcMainInvokeEvent, apiKey: string) {
|
||||
const fileManager = new GoogleAIFileManager(apiKey)
|
||||
return await fileManager.listFiles()
|
||||
static async listFiles(_: Electron.IpcMainInvokeEvent, apiKey: string): Promise<File[]> {
|
||||
const sdk = new GoogleGenAI({ vertexai: false, apiKey })
|
||||
const files: File[] = []
|
||||
for await (const f of await sdk.files.list()) {
|
||||
files.push(f)
|
||||
}
|
||||
return files
|
||||
}
|
||||
|
||||
static async deleteFile(_: Electron.IpcMainInvokeEvent, apiKey: string, fileId: string) {
|
||||
const fileManager = new GoogleAIFileManager(apiKey)
|
||||
await fileManager.deleteFile(fileId)
|
||||
static async deleteFile(_: Electron.IpcMainInvokeEvent, fileId: string, apiKey: string) {
|
||||
const sdk = new GoogleGenAI({ vertexai: false, apiKey })
|
||||
await sdk.files.delete({ name: fileId })
|
||||
}
|
||||
}
|
||||
|
||||
10
src/preload/index.d.ts
vendored
10
src/preload/index.d.ts
vendored
@ -1,6 +1,6 @@
|
||||
import { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import { ElectronAPI } from '@electron-toolkit/preload'
|
||||
import type { FileMetadataResponse, ListFilesResponse, UploadFileResponse } from '@google/generative-ai/server'
|
||||
import type { File } from '@google/genai'
|
||||
import type { GetMCPPromptResponse, MCPPrompt, MCPResource, MCPServer, MCPTool } from '@renderer/types'
|
||||
import { AppInfo, FileType, KnowledgeBaseParams, KnowledgeItem, LanguageVarious, WebDavConfig } from '@renderer/types'
|
||||
import type { LoaderReturn } from '@shared/config/types'
|
||||
@ -119,11 +119,11 @@ declare global {
|
||||
resetMinimumSize: () => Promise<void>
|
||||
}
|
||||
gemini: {
|
||||
uploadFile: (file: FileType, apiKey: string) => Promise<UploadFileResponse>
|
||||
retrieveFile: (file: FileType, apiKey: string) => Promise<FileMetadataResponse | undefined>
|
||||
uploadFile: (file: FileType, apiKey: string) => Promise<File>
|
||||
retrieveFile: (file: FileType, apiKey: string) => Promise<File | undefined>
|
||||
base64File: (file: FileType) => Promise<{ data: string; mimeType: string }>
|
||||
listFiles: (apiKey: string) => Promise<ListFilesResponse>
|
||||
deleteFile: (apiKey: string, fileId: string) => Promise<void>
|
||||
listFiles: (apiKey: string) => Promise<File[]>
|
||||
deleteFile: (fileId: string, apiKey: string) => Promise<void>
|
||||
}
|
||||
selectionMenu: {
|
||||
action: (action: string) => Promise<void>
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { DeleteOutlined } from '@ant-design/icons'
|
||||
import type { FileMetadataResponse } from '@google/generative-ai/server'
|
||||
import type { File } from '@google/genai'
|
||||
import { useProvider } from '@renderer/hooks/useProvider'
|
||||
import { runAsyncFunction } from '@renderer/utils'
|
||||
import { MB } from '@shared/config/constant'
|
||||
@ -16,11 +16,11 @@ interface GeminiFilesProps {
|
||||
|
||||
const GeminiFiles: FC<GeminiFilesProps> = ({ id }) => {
|
||||
const { provider } = useProvider(id)
|
||||
const [files, setFiles] = useState<FileMetadataResponse[]>([])
|
||||
const [files, setFiles] = useState<File[]>([])
|
||||
const [loading, setLoading] = useState(false)
|
||||
|
||||
const fetchFiles = useCallback(async () => {
|
||||
const { files } = await window.api.gemini.listFiles(provider.apiKey)
|
||||
const files = await window.api.gemini.listFiles(provider.apiKey)
|
||||
files && setFiles(files.filter((file) => file.state === 'ACTIVE'))
|
||||
}, [provider])
|
||||
|
||||
@ -60,14 +60,14 @@ const GeminiFiles: FC<GeminiFilesProps> = ({ id }) => {
|
||||
key={file.name}
|
||||
fileInfo={{
|
||||
name: file.displayName,
|
||||
ext: `.${file.name.split('.').pop()}`,
|
||||
extra: `${dayjs(file.createTime).format('MM-DD HH:mm')} · ${(parseInt(file.sizeBytes) / MB).toFixed(2)} MB`,
|
||||
ext: `.${file.name?.split('.').pop()}`,
|
||||
extra: `${dayjs(file.createTime).format('MM-DD HH:mm')} · ${(parseInt(file.sizeBytes || '0') / MB).toFixed(2)} MB`,
|
||||
actions: (
|
||||
<DeleteOutlined
|
||||
style={{ cursor: 'pointer', color: 'var(--color-error)' }}
|
||||
onClick={() => {
|
||||
setFiles(files.filter((f) => f.name !== file.name))
|
||||
window.api.gemini.deleteFile(provider.apiKey, file.name).catch((error) => {
|
||||
window.api.gemini.deleteFile(file.name!, provider.apiKey).catch((error) => {
|
||||
console.error('Failed to delete file:', error)
|
||||
setFiles((prev) => [...prev, file])
|
||||
})
|
||||
|
||||
@ -173,13 +173,13 @@ const ServerName = styled.div`
|
||||
|
||||
const ServerDescription = styled.div`
|
||||
font-size: 0.85rem;
|
||||
color: ${(props) => props.theme.colors?.textSecondary || '#8c8c8c'};
|
||||
color: var(--color-text-2);
|
||||
margin-bottom: 3px;
|
||||
`
|
||||
|
||||
const ServerUrl = styled.div`
|
||||
font-size: 0.8rem;
|
||||
color: ${(props) => props.theme.colors?.textTertiary || '#bfbfbf'};
|
||||
color: var(--color-text-3);
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
|
||||
@ -1,25 +1,18 @@
|
||||
import {
|
||||
ContentListUnion,
|
||||
createPartFromBase64,
|
||||
FinishReason,
|
||||
GenerateContentResponse,
|
||||
GoogleGenAI
|
||||
} from '@google/genai'
|
||||
import {
|
||||
Content,
|
||||
FileDataPart,
|
||||
GenerateContentStreamResult,
|
||||
GoogleGenerativeAI,
|
||||
File,
|
||||
GenerateContentConfig,
|
||||
GenerateContentResponse,
|
||||
GoogleGenAI,
|
||||
HarmBlockThreshold,
|
||||
HarmCategory,
|
||||
InlineDataPart,
|
||||
Modality,
|
||||
Part,
|
||||
RequestOptions,
|
||||
PartUnion,
|
||||
SafetySetting,
|
||||
TextPart,
|
||||
Tool
|
||||
} from '@google/generative-ai'
|
||||
import { isGemmaModel, isVisionModel, isWebSearchModel } from '@renderer/config/models'
|
||||
ToolListUnion
|
||||
} from '@google/genai'
|
||||
import { isGemmaModel, isGenerateImageModel, isVisionModel, isWebSearchModel } from '@renderer/config/models'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||
@ -39,22 +32,15 @@ import axios from 'axios'
|
||||
import { flatten, isEmpty, takeRight } from 'lodash'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import { ChunkCallbackData, CompletionsParams } from '.'
|
||||
import { CompletionsParams } from '.'
|
||||
import BaseProvider from './BaseProvider'
|
||||
|
||||
export default class GeminiProvider extends BaseProvider {
|
||||
private sdk: GoogleGenerativeAI
|
||||
private requestOptions: RequestOptions
|
||||
private imageSdk: GoogleGenAI
|
||||
private sdk: GoogleGenAI
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
this.sdk = new GoogleGenerativeAI(this.apiKey)
|
||||
/// this sdk is experimental
|
||||
this.imageSdk = new GoogleGenAI({ apiKey: this.apiKey, httpOptions: { baseUrl: this.getBaseURL() } })
|
||||
this.requestOptions = {
|
||||
baseUrl: this.getBaseURL()
|
||||
}
|
||||
this.sdk = new GoogleGenAI({ vertexai: false, apiKey: this.apiKey, httpOptions: { baseUrl: this.getBaseURL() } })
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
@ -76,31 +62,31 @@ export default class GeminiProvider extends BaseProvider {
|
||||
inlineData: {
|
||||
data,
|
||||
mimeType
|
||||
} as Part['inlineData']
|
||||
}
|
||||
} as InlineDataPart
|
||||
}
|
||||
|
||||
// Retrieve file from Gemini uploaded files
|
||||
const fileMetadata = await window.api.gemini.retrieveFile(file, this.apiKey)
|
||||
const fileMetadata: File | undefined = await window.api.gemini.retrieveFile(file, this.apiKey)
|
||||
|
||||
if (fileMetadata) {
|
||||
return {
|
||||
fileData: {
|
||||
fileUri: fileMetadata.uri,
|
||||
mimeType: fileMetadata.mimeType
|
||||
} as Part['fileData']
|
||||
}
|
||||
} as FileDataPart
|
||||
}
|
||||
|
||||
// If file is not found, upload it to Gemini
|
||||
const uploadResult = await window.api.gemini.uploadFile(file, this.apiKey)
|
||||
const result = await window.api.gemini.uploadFile(file, this.apiKey)
|
||||
|
||||
return {
|
||||
fileData: {
|
||||
fileUri: uploadResult.file.uri,
|
||||
mimeType: uploadResult.file.mimeType
|
||||
fileUri: result.uri,
|
||||
mimeType: result.mimeType
|
||||
} as Part['fileData']
|
||||
}
|
||||
} as FileDataPart
|
||||
}
|
||||
|
||||
/**
|
||||
@ -125,8 +111,8 @@ export default class GeminiProvider extends BaseProvider {
|
||||
inlineData: {
|
||||
data: base64Data,
|
||||
mimeType: mimeType
|
||||
}
|
||||
} as InlineDataPart)
|
||||
} as Part['inlineData']
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -139,8 +125,8 @@ export default class GeminiProvider extends BaseProvider {
|
||||
inlineData: {
|
||||
data: base64Data.base64,
|
||||
mimeType: base64Data.mime
|
||||
}
|
||||
} as InlineDataPart)
|
||||
} as Part['inlineData']
|
||||
})
|
||||
}
|
||||
|
||||
if (file.ext === '.pdf') {
|
||||
@ -152,13 +138,13 @@ export default class GeminiProvider extends BaseProvider {
|
||||
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
|
||||
parts.push({
|
||||
text: file.origin_name + '\n' + fileContent
|
||||
} as TextPart)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
role,
|
||||
parts
|
||||
parts: parts
|
||||
}
|
||||
}
|
||||
|
||||
@ -204,10 +190,13 @@ export default class GeminiProvider extends BaseProvider {
|
||||
* @param onChunk - The onChunk callback
|
||||
* @param onFilterMessages - The onFilterMessages callback
|
||||
*/
|
||||
public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) {
|
||||
if (assistant.enableGenerateImage) {
|
||||
await this.generateImageExp({ messages, assistant, onFilterMessages, onChunk })
|
||||
} else {
|
||||
public async completions({
|
||||
messages,
|
||||
assistant,
|
||||
mcpTools,
|
||||
onChunk,
|
||||
onFilterMessages
|
||||
}: CompletionsParams): Promise<void> {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
||||
@ -232,7 +221,7 @@ export default class GeminiProvider extends BaseProvider {
|
||||
}
|
||||
|
||||
// const tools = mcpToolsToGeminiTools(mcpTools)
|
||||
const tools: Tool[] = []
|
||||
const tools: ToolListUnion = []
|
||||
const toolResponses: MCPToolResponse[] = []
|
||||
|
||||
if (!WebSearchService.isOverwriteEnabled() && assistant.enableWebSearch && isWebSearchModel(model)) {
|
||||
@ -242,55 +231,81 @@ export default class GeminiProvider extends BaseProvider {
|
||||
})
|
||||
}
|
||||
|
||||
const geminiModel = this.sdk.getGenerativeModel(
|
||||
{
|
||||
model: model.id,
|
||||
...(isGemmaModel(model) ? {} : { systemInstruction: systemInstruction }),
|
||||
const generateContentConfig: GenerateContentConfig = {
|
||||
responseModalities: [Modality.TEXT, Modality.IMAGE],
|
||||
responseMimeType: 'text/plain',
|
||||
safetySettings: this.getSafetySettings(model.id),
|
||||
tools: tools,
|
||||
generationConfig: {
|
||||
maxOutputTokens: maxTokens,
|
||||
// generate image don't need system instruction
|
||||
systemInstruction: isGemmaModel(model) || isGenerateImageModel(model) ? undefined : systemInstruction,
|
||||
temperature: assistant?.settings?.temperature,
|
||||
topP: assistant?.settings?.topP,
|
||||
maxOutputTokens: maxTokens,
|
||||
tools: tools,
|
||||
...this.getCustomParameters(assistant)
|
||||
}
|
||||
},
|
||||
this.requestOptions
|
||||
)
|
||||
|
||||
const chat = geminiModel.startChat({ history })
|
||||
const messageContents = await this.getMessageContents(userLastMessage!)
|
||||
const messageContents: Content = await this.getMessageContents(userLastMessage!)
|
||||
|
||||
const chat = this.sdk.chats.create({
|
||||
model: model.id,
|
||||
config: generateContentConfig,
|
||||
history: history
|
||||
})
|
||||
|
||||
if (isGemmaModel(model) && assistant.prompt) {
|
||||
const isFirstMessage = history.length === 0
|
||||
if (isFirstMessage) {
|
||||
const systemMessage = {
|
||||
role: 'user',
|
||||
parts: [
|
||||
if (isFirstMessage && messageContents) {
|
||||
const systemMessage = [
|
||||
{
|
||||
text:
|
||||
'<start_of_turn>user\n' +
|
||||
systemInstruction +
|
||||
'<end_of_turn>\n' +
|
||||
'<start_of_turn>user\n' +
|
||||
messageContents.parts[0].text +
|
||||
(messageContents?.parts?.[0] as Part).text +
|
||||
'<end_of_turn>'
|
||||
}
|
||||
]
|
||||
] as Part[]
|
||||
if (messageContents && messageContents.parts) {
|
||||
messageContents.parts[0] = systemMessage[0]
|
||||
}
|
||||
messageContents.parts = systemMessage.parts
|
||||
}
|
||||
}
|
||||
|
||||
const start_time_millsec = new Date().getTime()
|
||||
const { abortController, cleanup } = this.createAbortController(userLastMessage?.id)
|
||||
const { signal } = abortController
|
||||
|
||||
const { cleanup, abortController } = this.createAbortController(userLastMessage?.id, true)
|
||||
const signalProxy = {
|
||||
_originalSignal: abortController.signal,
|
||||
|
||||
addEventListener: (eventName: string, listener: () => void) => {
|
||||
if (eventName === 'abort') {
|
||||
abortController.signal.addEventListener('abort', listener)
|
||||
}
|
||||
},
|
||||
removeEventListener: (eventName: string, listener: () => void) => {
|
||||
if (eventName === 'abort') {
|
||||
abortController.signal.removeEventListener('abort', listener)
|
||||
}
|
||||
},
|
||||
get aborted() {
|
||||
return abortController.signal.aborted
|
||||
}
|
||||
}
|
||||
|
||||
if (!streamOutput) {
|
||||
const { response } = await chat.sendMessage(messageContents.parts, { signal })
|
||||
const response = await chat.sendMessage({
|
||||
message: messageContents as PartUnion,
|
||||
config: {
|
||||
...generateContentConfig,
|
||||
httpOptions: {
|
||||
signal: signalProxy as any
|
||||
}
|
||||
}
|
||||
})
|
||||
const time_completion_millsec = new Date().getTime() - start_time_millsec
|
||||
onChunk({
|
||||
text: response.candidates?.[0].content.parts[0].text,
|
||||
text: response.text,
|
||||
usage: {
|
||||
prompt_tokens: response.usageMetadata?.promptTokenCount || 0,
|
||||
completion_tokens: response.usageMetadata?.candidatesTokenCount || 0,
|
||||
@ -306,7 +321,15 @@ export default class GeminiProvider extends BaseProvider {
|
||||
return
|
||||
}
|
||||
|
||||
const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal })
|
||||
const userMessagesStream = await chat.sendMessageStream({
|
||||
message: messageContents as PartUnion,
|
||||
config: {
|
||||
...generateContentConfig,
|
||||
httpOptions: {
|
||||
signal: signalProxy as any
|
||||
}
|
||||
}
|
||||
})
|
||||
let time_first_token_millsec = 0
|
||||
|
||||
const processToolUses = async (content: string, idx: number) => {
|
||||
@ -321,17 +344,27 @@ export default class GeminiProvider extends BaseProvider {
|
||||
)
|
||||
if (toolResults && toolResults.length > 0) {
|
||||
history.push(messageContents)
|
||||
const newChat = geminiModel.startChat({ history })
|
||||
const newStream = await newChat.sendMessageStream(flatten(toolResults.map((ts) => (ts as Content).parts)), {
|
||||
signal
|
||||
const newChat = this.sdk.chats.create({
|
||||
model: model.id,
|
||||
config: generateContentConfig,
|
||||
history: history as Content[]
|
||||
})
|
||||
const newStream = await newChat.sendMessageStream({
|
||||
message: flatten(toolResults.map((ts) => (ts as Content).parts)) as PartUnion,
|
||||
config: {
|
||||
...generateContentConfig,
|
||||
httpOptions: {
|
||||
signal: signalProxy as any
|
||||
}
|
||||
}
|
||||
})
|
||||
await processStream(newStream, idx + 1)
|
||||
}
|
||||
}
|
||||
|
||||
const processStream = async (stream: GenerateContentStreamResult, idx: number) => {
|
||||
const processStream = async (stream: AsyncGenerator<GenerateContentResponse>, idx: number) => {
|
||||
let content = ''
|
||||
for await (const chunk of stream.stream) {
|
||||
for await (const chunk of stream) {
|
||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
|
||||
|
||||
if (time_first_token_millsec == 0) {
|
||||
@ -340,11 +373,14 @@ export default class GeminiProvider extends BaseProvider {
|
||||
|
||||
const time_completion_millsec = new Date().getTime() - start_time_millsec
|
||||
|
||||
content += chunk.text()
|
||||
processToolUses(content, idx)
|
||||
if (chunk.text !== undefined) {
|
||||
content += chunk.text
|
||||
}
|
||||
await processToolUses(content, idx)
|
||||
const generateImage = this.processGeminiImageResponse(chunk)
|
||||
|
||||
onChunk({
|
||||
text: chunk.text(),
|
||||
text: chunk.text !== undefined ? chunk.text : '',
|
||||
usage: {
|
||||
prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
|
||||
completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
|
||||
@ -356,14 +392,14 @@ export default class GeminiProvider extends BaseProvider {
|
||||
time_first_token_millsec
|
||||
},
|
||||
search: chunk.candidates?.[0]?.groundingMetadata,
|
||||
mcpToolResponse: toolResponses
|
||||
mcpToolResponse: toolResponses,
|
||||
generateImage: generateImage
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
await processStream(userMessagesStream, 0).finally(cleanup)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Translate a message
|
||||
@ -372,39 +408,51 @@ export default class GeminiProvider extends BaseProvider {
|
||||
* @param onResponse - The onResponse callback
|
||||
* @returns The translated message
|
||||
*/
|
||||
async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
|
||||
public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
|
||||
const defaultModel = getDefaultModel()
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
const model = assistant.model || defaultModel
|
||||
|
||||
const geminiModel = this.sdk.getGenerativeModel(
|
||||
{
|
||||
model: model.id,
|
||||
...(isGemmaModel(model) ? {} : { systemInstruction: assistant.prompt }),
|
||||
generationConfig: {
|
||||
maxOutputTokens: maxTokens,
|
||||
temperature: assistant?.settings?.temperature
|
||||
}
|
||||
},
|
||||
this.requestOptions
|
||||
)
|
||||
|
||||
const content =
|
||||
isGemmaModel(model) && assistant.prompt
|
||||
? `<start_of_turn>user\n${assistant.prompt}<end_of_turn>\n<start_of_turn>user\n${message.content}<end_of_turn>`
|
||||
: message.content
|
||||
|
||||
if (!onResponse) {
|
||||
const { response } = await geminiModel.generateContent(content)
|
||||
return response.text()
|
||||
const response = await this.sdk.models.generateContent({
|
||||
model: model.id,
|
||||
config: {
|
||||
maxOutputTokens: maxTokens,
|
||||
temperature: assistant?.settings?.temperature,
|
||||
systemInstruction: isGemmaModel(model) ? undefined : assistant.prompt
|
||||
},
|
||||
contents: [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: content }]
|
||||
}
|
||||
]
|
||||
})
|
||||
return response.text || ''
|
||||
}
|
||||
|
||||
const response = await geminiModel.generateContentStream(content)
|
||||
|
||||
const response = await this.sdk.models.generateContentStream({
|
||||
model: model.id,
|
||||
config: {
|
||||
maxOutputTokens: maxTokens,
|
||||
temperature: assistant?.settings?.temperature,
|
||||
systemInstruction: isGemmaModel(model) ? undefined : assistant.prompt
|
||||
},
|
||||
contents: [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: content }]
|
||||
}
|
||||
]
|
||||
})
|
||||
let text = ''
|
||||
|
||||
for await (const chunk of response.stream) {
|
||||
text += chunk.text()
|
||||
for await (const chunk of response) {
|
||||
text += chunk.text
|
||||
onResponse(text)
|
||||
}
|
||||
|
||||
@ -442,25 +490,24 @@ export default class GeminiProvider extends BaseProvider {
|
||||
content: userMessageContent
|
||||
}
|
||||
|
||||
const geminiModel = this.sdk.getGenerativeModel(
|
||||
{
|
||||
model: model.id,
|
||||
...(isGemmaModel(model) ? {} : { systemInstruction: systemMessage.content }),
|
||||
generationConfig: {
|
||||
temperature: assistant?.settings?.temperature
|
||||
}
|
||||
},
|
||||
this.requestOptions
|
||||
)
|
||||
|
||||
const chat = await geminiModel.startChat()
|
||||
const content = isGemmaModel(model)
|
||||
? `<start_of_turn>user\n${systemMessage.content}<end_of_turn>\n<start_of_turn>user\n${userMessage.content}<end_of_turn>`
|
||||
: userMessage.content
|
||||
|
||||
const { response } = await chat.sendMessage(content)
|
||||
const response = await this.sdk.models.generateContent({
|
||||
model: model.id,
|
||||
config: {
|
||||
systemInstruction: isGemmaModel(model) ? undefined : systemMessage.content
|
||||
},
|
||||
contents: [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: content }]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
return removeSpecialCharactersForTopicName(response.text())
|
||||
return removeSpecialCharactersForTopicName(response.text || '')
|
||||
}
|
||||
|
||||
/**
|
||||
@ -471,24 +518,23 @@ export default class GeminiProvider extends BaseProvider {
|
||||
*/
|
||||
public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
|
||||
const model = getDefaultModel()
|
||||
const systemMessage = { role: 'system', content: prompt }
|
||||
|
||||
const geminiModel = this.sdk.getGenerativeModel(
|
||||
{
|
||||
model: model.id,
|
||||
...(isGemmaModel(model) ? {} : { systemInstruction: systemMessage.content })
|
||||
},
|
||||
this.requestOptions
|
||||
)
|
||||
|
||||
const chat = await geminiModel.startChat()
|
||||
const messageContent = isGemmaModel(model)
|
||||
const MessageContent = isGemmaModel(model)
|
||||
? `<start_of_turn>user\n${prompt}<end_of_turn>\n<start_of_turn>user\n${content}<end_of_turn>`
|
||||
: content
|
||||
const response = await this.sdk.models.generateContent({
|
||||
model: model.id,
|
||||
config: {
|
||||
systemInstruction: isGemmaModel(model) ? undefined : prompt
|
||||
},
|
||||
contents: [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: MessageContent }]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
const { response } = await chat.sendMessage(messageContent)
|
||||
|
||||
return response.text()
|
||||
return response.text || ''
|
||||
}
|
||||
|
||||
/**
|
||||
@ -518,24 +564,28 @@ export default class GeminiProvider extends BaseProvider {
|
||||
content: messages.map((m) => m.content).join('\n')
|
||||
}
|
||||
|
||||
const geminiModel = this.sdk.getGenerativeModel(
|
||||
{
|
||||
const content = isGemmaModel(model)
|
||||
? `<start_of_turn>user\n${systemMessage.content}<end_of_turn>\n<start_of_turn>user\n${userMessage.content}<end_of_turn>`
|
||||
: userMessage.content
|
||||
|
||||
const response = await this.sdk.models.generateContent({
|
||||
model: model.id,
|
||||
systemInstruction: systemMessage.content,
|
||||
generationConfig: {
|
||||
temperature: assistant?.settings?.temperature
|
||||
}
|
||||
},
|
||||
{
|
||||
...this.requestOptions,
|
||||
config: {
|
||||
systemInstruction: isGemmaModel(model) ? undefined : systemMessage.content,
|
||||
temperature: assistant?.settings?.temperature,
|
||||
httpOptions: {
|
||||
timeout: 20 * 1000
|
||||
}
|
||||
)
|
||||
},
|
||||
contents: [
|
||||
{
|
||||
role: 'user',
|
||||
parts: [{ text: content }]
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
const chat = await geminiModel.startChat()
|
||||
const { response } = await chat.sendMessage(userMessage.content)
|
||||
|
||||
return response.text()
|
||||
return response.text || ''
|
||||
}
|
||||
|
||||
/**
|
||||
@ -546,144 +596,13 @@ export default class GeminiProvider extends BaseProvider {
|
||||
return []
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成图像
|
||||
* @param messages - 消息列表
|
||||
* @param assistant - 助手配置
|
||||
* @param onChunk - 处理生成块的回调
|
||||
* @param onFilterMessages - 过滤消息的回调
|
||||
* @returns Promise<void>
|
||||
*/
|
||||
private async generateImageExp({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const { contextCount, streamOutput, maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
const userMessages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
|
||||
onFilterMessages(userMessages)
|
||||
|
||||
const userLastMessage = userMessages.pop()
|
||||
if (!userLastMessage) {
|
||||
throw new Error('No user message found')
|
||||
}
|
||||
|
||||
const history: Content[] = []
|
||||
|
||||
for (const message of userMessages) {
|
||||
history.push(await this.getMessageContents(message))
|
||||
}
|
||||
|
||||
const userLastMessageContent = await this.getMessageContents(userLastMessage)
|
||||
const allContents = [...history, userLastMessageContent]
|
||||
|
||||
let contents: ContentListUnion = allContents.length > 0 ? (allContents as ContentListUnion) : []
|
||||
|
||||
contents = await this.addImageFileToContents(userLastMessage, contents)
|
||||
|
||||
if (!streamOutput) {
|
||||
const response = await this.callGeminiGenerateContent(model.id, contents, maxTokens)
|
||||
|
||||
const { isValid, message } = this.isValidGeminiResponse(response)
|
||||
if (!isValid) {
|
||||
throw new Error(`Gemini API error: ${message}`)
|
||||
}
|
||||
|
||||
this.processGeminiImageResponse(response, onChunk)
|
||||
return
|
||||
}
|
||||
const response = await this.callGeminiGenerateContentStream(model.id, contents, maxTokens)
|
||||
|
||||
for await (const chunk of response) {
|
||||
this.processGeminiImageResponse(chunk, onChunk)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加图片文件到内容列表
|
||||
* @param message - 用户消息
|
||||
* @param contents - 内容列表
|
||||
* @returns 更新后的内容列表
|
||||
*/
|
||||
private async addImageFileToContents(message: Message, contents: ContentListUnion): Promise<ContentListUnion> {
|
||||
if (message.files && message.files.length > 0) {
|
||||
const file = message.files[0]
|
||||
const fileContent = await window.api.file.base64Image(file.id + file.ext)
|
||||
|
||||
if (fileContent && fileContent.base64) {
|
||||
const contentsArray = Array.isArray(contents) ? contents : [contents]
|
||||
return [...contentsArray, createPartFromBase64(fileContent.base64, fileContent.mime)]
|
||||
}
|
||||
}
|
||||
return contents
|
||||
}
|
||||
|
||||
/**
|
||||
* 调用Gemini API生成内容
|
||||
* @param modelId - 模型ID
|
||||
* @param contents - 内容列表
|
||||
* @returns 生成结果
|
||||
*/
|
||||
private async callGeminiGenerateContent(
|
||||
modelId: string,
|
||||
contents: ContentListUnion,
|
||||
maxTokens?: number
|
||||
): Promise<GenerateContentResponse> {
|
||||
try {
|
||||
return await this.imageSdk.models.generateContent({
|
||||
model: modelId,
|
||||
contents: contents,
|
||||
config: {
|
||||
responseModalities: ['Text', 'Image'],
|
||||
responseMimeType: 'text/plain',
|
||||
maxOutputTokens: maxTokens
|
||||
}
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('Gemini API error:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
private async callGeminiGenerateContentStream(
|
||||
modelId: string,
|
||||
contents: ContentListUnion,
|
||||
maxTokens?: number
|
||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
||||
try {
|
||||
return await this.imageSdk.models.generateContentStream({
|
||||
model: modelId,
|
||||
contents: contents,
|
||||
config: {
|
||||
responseModalities: ['Text', 'Image'],
|
||||
responseMimeType: 'text/plain',
|
||||
maxOutputTokens: maxTokens
|
||||
}
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('Gemini API error:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查Gemini响应是否有效
|
||||
* @param response - Gemini响应
|
||||
* @returns 是否有效
|
||||
*/
|
||||
private isValidGeminiResponse(response: GenerateContentResponse): { isValid: boolean; message: string } {
|
||||
return {
|
||||
isValid: response?.candidates?.[0]?.finishReason === FinishReason.STOP ? true : false,
|
||||
message: response?.candidates?.[0]?.finishReason || ''
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理Gemini图像响应
|
||||
* @param response - Gemini响应
|
||||
* @param onChunk - 处理生成块的回调
|
||||
*/
|
||||
private processGeminiImageResponse(response: any, onChunk: (chunk: ChunkCallbackData) => void): void {
|
||||
const parts = response.candidates[0].content.parts
|
||||
private processGeminiImageResponse(chunk: GenerateContentResponse): { type: 'base64'; images: string[] } | undefined {
|
||||
const parts = chunk.candidates?.[0]?.content?.parts
|
||||
if (!parts) {
|
||||
return
|
||||
}
|
||||
@ -695,31 +614,13 @@ export default class GeminiProvider extends BaseProvider {
|
||||
return null
|
||||
}
|
||||
const dataPrefix = `data:${part.inlineData.mimeType || 'image/png'};base64,`
|
||||
return part.inlineData.data.startsWith('data:') ? part.inlineData.data : dataPrefix + part.inlineData.data
|
||||
return part.inlineData.data?.startsWith('data:') ? part.inlineData.data : dataPrefix + part.inlineData.data
|
||||
})
|
||||
|
||||
// 提取文本数据
|
||||
const text = parts
|
||||
.filter((part: Part) => part.text !== undefined)
|
||||
.map((part: Part) => part.text)
|
||||
.join('')
|
||||
|
||||
// 返回结果
|
||||
onChunk({
|
||||
text,
|
||||
generateImage: {
|
||||
return {
|
||||
type: 'base64',
|
||||
images
|
||||
},
|
||||
usage: {
|
||||
prompt_tokens: response.usageMetadata?.promptTokenCount || 0,
|
||||
completion_tokens: response.usageMetadata?.candidatesTokenCount || 0,
|
||||
total_tokens: response.usageMetadata?.totalTokenCount || 0
|
||||
},
|
||||
metrics: {
|
||||
completion_tokens: response.usageMetadata?.candidatesTokenCount
|
||||
images: images.filter((image) => image !== null)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
@ -732,18 +633,16 @@ export default class GeminiProvider extends BaseProvider {
|
||||
return { valid: false, error: new Error('No model found') }
|
||||
}
|
||||
|
||||
const body = {
|
||||
model: model.id,
|
||||
messages: [{ role: 'user', content: 'hi' }],
|
||||
max_tokens: 100,
|
||||
stream: false
|
||||
}
|
||||
|
||||
try {
|
||||
const geminiModel = this.sdk.getGenerativeModel({ model: body.model }, this.requestOptions)
|
||||
const result = await geminiModel.generateContent(body.messages[0].content)
|
||||
const result = await this.sdk.models.generateContent({
|
||||
model: model.id,
|
||||
contents: [{ role: 'user', parts: [{ text: 'hi' }] }],
|
||||
config: {
|
||||
maxOutputTokens: 100
|
||||
}
|
||||
})
|
||||
return {
|
||||
valid: !isEmpty(result.response.text()),
|
||||
valid: !isEmpty(result.text),
|
||||
error: null
|
||||
}
|
||||
} catch (error: any) {
|
||||
@ -785,7 +684,10 @@ export default class GeminiProvider extends BaseProvider {
|
||||
* @returns The embedding dimensions
|
||||
*/
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
const data = await this.sdk.getGenerativeModel({ model: model.id }, this.requestOptions).embedContent('hi')
|
||||
return data.embedding.values.length
|
||||
const data = await this.sdk.models.embedContent({
|
||||
model: model.id,
|
||||
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
|
||||
})
|
||||
return data.embeddings?.[0]?.values?.length || 0
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import type { GroundingMetadata } from '@google/generative-ai'
|
||||
import type { GroundingMetadata } from '@google/genai'
|
||||
import BaseProvider from '@renderer/providers/AiProvider/BaseProvider'
|
||||
import ProviderFactory from '@renderer/providers/AiProvider/ProviderFactory'
|
||||
import type {
|
||||
|
||||
@ -282,6 +282,7 @@ export async function fetchChatCompletion({
|
||||
}
|
||||
}
|
||||
|
||||
console.log('message', message)
|
||||
// Emit chat completion event
|
||||
EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message)
|
||||
onResponse(message)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { GroundingMetadata } from '@google/generative-ai'
|
||||
import { GroundingMetadata } from '@google/genai'
|
||||
import OpenAI from 'openai'
|
||||
import React from 'react'
|
||||
import { BuiltinTheme } from 'shiki'
|
||||
|
||||
@ -71,8 +71,8 @@ export function withGeminiGrounding(message: Message) {
|
||||
let content = message.content
|
||||
|
||||
groundingSupports.forEach((support) => {
|
||||
const text = support?.segment
|
||||
const indices = support?.groundingChunckIndices
|
||||
const text = support?.segment?.text
|
||||
const indices = support?.groundingChunkIndices
|
||||
|
||||
if (!text || !indices) return
|
||||
|
||||
|
||||
@ -1,186 +1,165 @@
|
||||
import { ContentBlockParam, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
|
||||
import { MessageParam } from '@anthropic-ai/sdk/resources'
|
||||
import {
|
||||
ArraySchema,
|
||||
BaseSchema,
|
||||
BooleanSchema,
|
||||
EnumStringSchema,
|
||||
FunctionCall,
|
||||
FunctionDeclaration,
|
||||
FunctionDeclarationSchema,
|
||||
FunctionDeclarationSchemaProperty,
|
||||
IntegerSchema,
|
||||
NumberSchema,
|
||||
ObjectSchema,
|
||||
SchemaType,
|
||||
SimpleStringSchema,
|
||||
Tool as geminiTool
|
||||
} from '@google/generative-ai'
|
||||
import { Content, Part } from '@google/generative-ai'
|
||||
import { Content, FunctionCall, Part } from '@google/genai'
|
||||
import store from '@renderer/store'
|
||||
import { MCPCallToolResponse, MCPServer, MCPTool, MCPToolResponse } from '@renderer/types'
|
||||
import {
|
||||
ChatCompletionContentPart,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionMessageToolCall,
|
||||
ChatCompletionTool
|
||||
} from 'openai/resources'
|
||||
import { ChatCompletionContentPart, ChatCompletionMessageParam, ChatCompletionMessageToolCall } from 'openai/resources'
|
||||
|
||||
import { ChunkCallbackData, CompletionsParams } from '../providers/AiProvider'
|
||||
|
||||
const ensureValidSchema = (obj: Record<string, any>): FunctionDeclarationSchemaProperty => {
|
||||
// Filter out unsupported keys for Gemini
|
||||
const filteredObj = filterUnsupportedKeys(obj)
|
||||
// const ensureValidSchema = (obj: Record<string, any>) => {
|
||||
// // Filter out unsupported keys for Gemini
|
||||
// const filteredObj = filterUnsupportedKeys(obj)
|
||||
|
||||
// Handle base schema properties
|
||||
const baseSchema = {
|
||||
description: filteredObj.description,
|
||||
nullable: filteredObj.nullable
|
||||
} as BaseSchema
|
||||
// // Handle base schema properties
|
||||
// const baseSchema = {
|
||||
// description: filteredObj.description,
|
||||
// nullable: filteredObj.nullable
|
||||
// } as BaseSchema
|
||||
|
||||
// Handle string type
|
||||
if (filteredObj.type?.toLowerCase() === SchemaType.STRING) {
|
||||
if (filteredObj.enum && Array.isArray(filteredObj.enum)) {
|
||||
return {
|
||||
...baseSchema,
|
||||
type: SchemaType.STRING,
|
||||
format: 'enum',
|
||||
enum: filteredObj.enum as string[]
|
||||
} as EnumStringSchema
|
||||
}
|
||||
return {
|
||||
...baseSchema,
|
||||
type: SchemaType.STRING,
|
||||
format: filteredObj.format === 'date-time' ? 'date-time' : undefined
|
||||
} as SimpleStringSchema
|
||||
}
|
||||
// // Handle string type
|
||||
// if (filteredObj.type?.toLowerCase() === SchemaType.STRING) {
|
||||
// if (filteredObj.enum && Array.isArray(filteredObj.enum)) {
|
||||
// return {
|
||||
// ...baseSchema,
|
||||
// type: SchemaType.STRING,
|
||||
// format: 'enum',
|
||||
// enum: filteredObj.enum as string[]
|
||||
// } as EnumStringSchema
|
||||
// }
|
||||
// return {
|
||||
// ...baseSchema,
|
||||
// type: SchemaType.STRING,
|
||||
// format: filteredObj.format === 'date-time' ? 'date-time' : undefined
|
||||
// } as SimpleStringSchema
|
||||
// }
|
||||
|
||||
// Handle number type
|
||||
if (filteredObj.type?.toLowerCase() === SchemaType.NUMBER) {
|
||||
return {
|
||||
...baseSchema,
|
||||
type: SchemaType.NUMBER,
|
||||
format: ['float', 'double'].includes(filteredObj.format) ? (filteredObj.format as 'float' | 'double') : undefined
|
||||
} as NumberSchema
|
||||
}
|
||||
// // Handle number type
|
||||
// if (filteredObj.type?.toLowerCase() === SchemaType.NUMBER) {
|
||||
// return {
|
||||
// ...baseSchema,
|
||||
// type: SchemaType.NUMBER,
|
||||
// format: ['float', 'double'].includes(filteredObj.format) ? (filteredObj.format as 'float' | 'double') : undefined
|
||||
// } as NumberSchema
|
||||
// }
|
||||
|
||||
// Handle integer type
|
||||
if (filteredObj.type?.toLowerCase() === SchemaType.INTEGER) {
|
||||
return {
|
||||
...baseSchema,
|
||||
type: SchemaType.INTEGER,
|
||||
format: ['int32', 'int64'].includes(filteredObj.format) ? (filteredObj.format as 'int32' | 'int64') : undefined
|
||||
} as IntegerSchema
|
||||
}
|
||||
// // Handle integer type
|
||||
// if (filteredObj.type?.toLowerCase() === SchemaType.INTEGER) {
|
||||
// return {
|
||||
// ...baseSchema,
|
||||
// type: SchemaType.INTEGER,
|
||||
// format: ['int32', 'int64'].includes(filteredObj.format) ? (filteredObj.format as 'int32' | 'int64') : undefined
|
||||
// } as IntegerSchema
|
||||
// }
|
||||
|
||||
// Handle boolean type
|
||||
if (filteredObj.type?.toLowerCase() === SchemaType.BOOLEAN) {
|
||||
return {
|
||||
...baseSchema,
|
||||
type: SchemaType.BOOLEAN
|
||||
} as BooleanSchema
|
||||
}
|
||||
// // Handle boolean type
|
||||
// if (filteredObj.type?.toLowerCase() === SchemaType.BOOLEAN) {
|
||||
// return {
|
||||
// ...baseSchema,
|
||||
// type: SchemaType.BOOLEAN
|
||||
// } as BooleanSchema
|
||||
// }
|
||||
|
||||
// Handle array type
|
||||
if (filteredObj.type?.toLowerCase() === SchemaType.ARRAY) {
|
||||
return {
|
||||
...baseSchema,
|
||||
type: SchemaType.ARRAY,
|
||||
items: filteredObj.items
|
||||
? ensureValidSchema(filteredObj.items as Record<string, any>)
|
||||
: ({ type: SchemaType.STRING } as SimpleStringSchema),
|
||||
minItems: filteredObj.minItems,
|
||||
maxItems: filteredObj.maxItems
|
||||
} as ArraySchema
|
||||
}
|
||||
// // Handle array type
|
||||
// if (filteredObj.type?.toLowerCase() === SchemaType.ARRAY) {
|
||||
// return {
|
||||
// ...baseSchema,
|
||||
// type: SchemaType.ARRAY,
|
||||
// items: filteredObj.items
|
||||
// ? ensureValidSchema(filteredObj.items as Record<string, any>)
|
||||
// : ({ type: SchemaType.STRING } as SimpleStringSchema),
|
||||
// minItems: filteredObj.minItems,
|
||||
// maxItems: filteredObj.maxItems
|
||||
// } as ArraySchema
|
||||
// }
|
||||
|
||||
// Handle object type (default)
|
||||
const properties = filteredObj.properties
|
||||
? Object.fromEntries(
|
||||
Object.entries(filteredObj.properties).map(([key, value]) => [
|
||||
key,
|
||||
ensureValidSchema(value as Record<string, any>)
|
||||
])
|
||||
)
|
||||
: { _empty: { type: SchemaType.STRING } as SimpleStringSchema } // Ensure properties is never empty
|
||||
// // Handle object type (default)
|
||||
// const properties = filteredObj.properties
|
||||
// ? Object.fromEntries(
|
||||
// Object.entries(filteredObj.properties).map(([key, value]) => [
|
||||
// key,
|
||||
// ensureValidSchema(value as Record<string, any>)
|
||||
// ])
|
||||
// )
|
||||
// : { _empty: { type: SchemaType.STRING } as SimpleStringSchema } // Ensure properties is never empty
|
||||
|
||||
return {
|
||||
...baseSchema,
|
||||
type: SchemaType.OBJECT,
|
||||
properties,
|
||||
required: Array.isArray(filteredObj.required) ? filteredObj.required : undefined
|
||||
} as ObjectSchema
|
||||
}
|
||||
// return {
|
||||
// ...baseSchema,
|
||||
// type: SchemaType.OBJECT,
|
||||
// properties,
|
||||
// required: Array.isArray(filteredObj.required) ? filteredObj.required : undefined
|
||||
// } as ObjectSchema
|
||||
// }
|
||||
|
||||
function filterUnsupportedKeys(obj: Record<string, any>): Record<string, any> {
|
||||
const supportedBaseKeys = ['description', 'nullable']
|
||||
const supportedStringKeys = [...supportedBaseKeys, 'type', 'format', 'enum']
|
||||
const supportedNumberKeys = [...supportedBaseKeys, 'type', 'format']
|
||||
const supportedBooleanKeys = [...supportedBaseKeys, 'type']
|
||||
const supportedArrayKeys = [...supportedBaseKeys, 'type', 'items', 'minItems', 'maxItems']
|
||||
const supportedObjectKeys = [...supportedBaseKeys, 'type', 'properties', 'required']
|
||||
// function filterUnsupportedKeys(obj: Record<string, any>): Record<string, any> {
|
||||
// const supportedBaseKeys = ['description', 'nullable']
|
||||
// const supportedStringKeys = [...supportedBaseKeys, 'type', 'format', 'enum']
|
||||
// const supportedNumberKeys = [...supportedBaseKeys, 'type', 'format']
|
||||
// const supportedBooleanKeys = [...supportedBaseKeys, 'type']
|
||||
// const supportedArrayKeys = [...supportedBaseKeys, 'type', 'items', 'minItems', 'maxItems']
|
||||
// const supportedObjectKeys = [...supportedBaseKeys, 'type', 'properties', 'required']
|
||||
|
||||
const filtered: Record<string, any> = {}
|
||||
// const filtered: Record<string, any> = {}
|
||||
|
||||
let keysToKeep: string[]
|
||||
// let keysToKeep: string[]
|
||||
|
||||
if (obj.type?.toLowerCase() === SchemaType.STRING) {
|
||||
keysToKeep = supportedStringKeys
|
||||
} else if (obj.type?.toLowerCase() === SchemaType.NUMBER) {
|
||||
keysToKeep = supportedNumberKeys
|
||||
} else if (obj.type?.toLowerCase() === SchemaType.INTEGER) {
|
||||
keysToKeep = supportedNumberKeys
|
||||
} else if (obj.type?.toLowerCase() === SchemaType.BOOLEAN) {
|
||||
keysToKeep = supportedBooleanKeys
|
||||
} else if (obj.type?.toLowerCase() === SchemaType.ARRAY) {
|
||||
keysToKeep = supportedArrayKeys
|
||||
} else {
|
||||
// Default to object type
|
||||
keysToKeep = supportedObjectKeys
|
||||
}
|
||||
// if (obj.type?.toLowerCase() === SchemaType.STRING) {
|
||||
// keysToKeep = supportedStringKeys
|
||||
// } else if (obj.type?.toLowerCase() === SchemaType.NUMBER) {
|
||||
// keysToKeep = supportedNumberKeys
|
||||
// } else if (obj.type?.toLowerCase() === SchemaType.INTEGER) {
|
||||
// keysToKeep = supportedNumberKeys
|
||||
// } else if (obj.type?.toLowerCase() === SchemaType.BOOLEAN) {
|
||||
// keysToKeep = supportedBooleanKeys
|
||||
// } else if (obj.type?.toLowerCase() === SchemaType.ARRAY) {
|
||||
// keysToKeep = supportedArrayKeys
|
||||
// } else {
|
||||
// // Default to object type
|
||||
// keysToKeep = supportedObjectKeys
|
||||
// }
|
||||
|
||||
// copy supported keys
|
||||
for (const key of keysToKeep) {
|
||||
if (obj[key] !== undefined) {
|
||||
filtered[key] = obj[key]
|
||||
}
|
||||
}
|
||||
// // copy supported keys
|
||||
// for (const key of keysToKeep) {
|
||||
// if (obj[key] !== undefined) {
|
||||
// filtered[key] = obj[key]
|
||||
// }
|
||||
// }
|
||||
|
||||
return filtered
|
||||
}
|
||||
// return filtered
|
||||
// }
|
||||
|
||||
function filterPropertieAttributes(tool: MCPTool, filterNestedObj: boolean = false): Record<string, object> {
|
||||
const properties = tool.inputSchema.properties
|
||||
if (!properties) {
|
||||
return {}
|
||||
}
|
||||
// function filterPropertieAttributes(tool: MCPTool, filterNestedObj: boolean = false): Record<string, object> {
|
||||
// const properties = tool.inputSchema.properties
|
||||
// if (!properties) {
|
||||
// return {}
|
||||
// }
|
||||
|
||||
// For OpenAI, we don't need to validate as strictly
|
||||
if (!filterNestedObj) {
|
||||
return properties
|
||||
}
|
||||
// // For OpenAI, we don't need to validate as strictly
|
||||
// if (!filterNestedObj) {
|
||||
// return properties
|
||||
// }
|
||||
|
||||
const processedProperties = Object.fromEntries(
|
||||
Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
|
||||
)
|
||||
// const processedProperties = Object.fromEntries(
|
||||
// Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
|
||||
// )
|
||||
|
||||
return processedProperties
|
||||
}
|
||||
// return processedProperties
|
||||
// }
|
||||
|
||||
export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
|
||||
return mcpTools.map((tool) => ({
|
||||
type: 'function',
|
||||
name: tool.name,
|
||||
function: {
|
||||
name: tool.id,
|
||||
description: tool.description,
|
||||
parameters: {
|
||||
type: 'object',
|
||||
properties: filterPropertieAttributes(tool)
|
||||
}
|
||||
}
|
||||
}))
|
||||
}
|
||||
// export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
|
||||
// return mcpTools.map((tool) => ({
|
||||
// type: 'function',
|
||||
// name: tool.name,
|
||||
// function: {
|
||||
// name: tool.id,
|
||||
// description: tool.description,
|
||||
// parameters: {
|
||||
// type: 'object',
|
||||
// properties: filterPropertieAttributes(tool)
|
||||
// }
|
||||
// }
|
||||
// }))
|
||||
// }
|
||||
|
||||
export function openAIToolsToMcpTool(
|
||||
mcpTools: MCPTool[] | undefined,
|
||||
@ -277,35 +256,35 @@ export function anthropicToolUseToMcpTool(mcpTools: MCPTool[] | undefined, toolU
|
||||
return tool
|
||||
}
|
||||
|
||||
export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTool[] {
|
||||
if (!mcpTools || mcpTools.length === 0) {
|
||||
// No tools available
|
||||
return []
|
||||
}
|
||||
const functions: FunctionDeclaration[] = []
|
||||
// export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTool[] {
|
||||
// if (!mcpTools || mcpTools.length === 0) {
|
||||
// // No tools available
|
||||
// return []
|
||||
// }
|
||||
// const functions: FunctionDeclaration[] = []
|
||||
|
||||
for (const tool of mcpTools) {
|
||||
const properties = filterPropertieAttributes(tool, true)
|
||||
const functionDeclaration: FunctionDeclaration = {
|
||||
name: tool.id,
|
||||
description: tool.description,
|
||||
parameters: {
|
||||
type: SchemaType.OBJECT,
|
||||
properties:
|
||||
Object.keys(properties).length > 0
|
||||
? Object.fromEntries(
|
||||
Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
|
||||
)
|
||||
: { _empty: { type: SchemaType.STRING } as SimpleStringSchema }
|
||||
} as FunctionDeclarationSchema
|
||||
}
|
||||
functions.push(functionDeclaration)
|
||||
}
|
||||
const tool: geminiTool = {
|
||||
functionDeclarations: functions
|
||||
}
|
||||
return [tool]
|
||||
}
|
||||
// for (const tool of mcpTools) {
|
||||
// const properties = filterPropertieAttributes(tool, true)
|
||||
// const functionDeclaration: FunctionDeclaration = {
|
||||
// name: tool.id,
|
||||
// description: tool.description,
|
||||
// parameters: {
|
||||
// type: SchemaType.OBJECT,
|
||||
// properties:
|
||||
// Object.keys(properties).length > 0
|
||||
// ? Object.fromEntries(
|
||||
// Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
|
||||
// )
|
||||
// : { _empty: { type: SchemaType.STRING } as SimpleStringSchema }
|
||||
// } as FunctionDeclarationSchema
|
||||
// }
|
||||
// functions.push(functionDeclaration)
|
||||
// }
|
||||
// const tool: geminiTool = {
|
||||
// functionDeclarations: functions
|
||||
// }
|
||||
// return [tool]
|
||||
// }
|
||||
|
||||
export function geminiFunctionCallToMcpTool(
|
||||
mcpTools: MCPTool[] | undefined,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user