refactor(Gemini): migrate generative-ai sdk to genai sdk (#4939)
* refactor(GeminiService): migrate to new Google GenAI SDK and update file handling methods - Updated import statements to use the new Google GenAI SDK. - Refactored file upload, retrieval, and deletion methods to align with the new SDK's API. - Adjusted type definitions and response handling for improved type safety and clarity. - Enhanced file listing and processing logic to utilize async iteration for better performance. * refactor(GeminiProvider): update message handling and integrate abort signal support - Refactored message content handling to align with updated type definitions, ensuring consistent use of Content type. - Enhanced abort signal management for chat requests, allowing for better control over ongoing operations. - Improved message processing logic to streamline user message history handling and response generation. - Adjusted type definitions for message contents to enhance type safety and clarity. * refactor(electron.vite.config): replace direct import of Vite React plugin with dynamic import * fix(Gemini): clean up unused methods and improve property access * fix(typecheck): update color properties to use CSS variables * feat: 修改画图逻辑 * fix: import viteReact --------- Co-authored-by: eeee0717 <chentao020717Work@outlook.com>
This commit is contained in:
parent
35c50b54a8
commit
24ddd69cd5
37698
.yarn/patches/@google-genai-npm-0.8.0-450d0d9a7d.patch
vendored
Normal file
37698
.yarn/patches/@google-genai-npm-0.8.0-450d0d9a7d.patch
vendored
Normal file
File diff suppressed because one or more lines are too long
@ -1,4 +1,4 @@
|
|||||||
import react from '@vitejs/plugin-react'
|
import viteReact from '@vitejs/plugin-react'
|
||||||
import { defineConfig, externalizeDepsPlugin } from 'electron-vite'
|
import { defineConfig, externalizeDepsPlugin } from 'electron-vite'
|
||||||
import { resolve } from 'path'
|
import { resolve } from 'path'
|
||||||
import { visualizer } from 'rollup-plugin-visualizer'
|
import { visualizer } from 'rollup-plugin-visualizer'
|
||||||
@ -6,7 +6,7 @@ import { visualizer } from 'rollup-plugin-visualizer'
|
|||||||
const visualizerPlugin = (type: 'renderer' | 'main') => {
|
const visualizerPlugin = (type: 'renderer' | 'main') => {
|
||||||
return process.env[`VISUALIZER_${type.toUpperCase()}`] ? [visualizer({ open: true })] : []
|
return process.env[`VISUALIZER_${type.toUpperCase()}`] ? [visualizer({ open: true })] : []
|
||||||
}
|
}
|
||||||
|
// const viteReact = await import('@vitejs/plugin-react')
|
||||||
export default defineConfig({
|
export default defineConfig({
|
||||||
main: {
|
main: {
|
||||||
plugins: [
|
plugins: [
|
||||||
@ -51,7 +51,7 @@ export default defineConfig({
|
|||||||
},
|
},
|
||||||
renderer: {
|
renderer: {
|
||||||
plugins: [
|
plugins: [
|
||||||
react({
|
viteReact({
|
||||||
babel: {
|
babel: {
|
||||||
plugins: [
|
plugins: [
|
||||||
[
|
[
|
||||||
|
|||||||
@ -64,7 +64,6 @@
|
|||||||
"@cherrystudio/embedjs-openai": "^0.1.28",
|
"@cherrystudio/embedjs-openai": "^0.1.28",
|
||||||
"@electron-toolkit/utils": "^3.0.0",
|
"@electron-toolkit/utils": "^3.0.0",
|
||||||
"@electron/notarize": "^2.5.0",
|
"@electron/notarize": "^2.5.0",
|
||||||
"@google/generative-ai": "^0.24.0",
|
|
||||||
"@langchain/community": "^0.3.36",
|
"@langchain/community": "^0.3.36",
|
||||||
"@mozilla/readability": "^0.6.0",
|
"@mozilla/readability": "^0.6.0",
|
||||||
"@notionhq/client": "^2.2.15",
|
"@notionhq/client": "^2.2.15",
|
||||||
@ -74,6 +73,7 @@
|
|||||||
"@xyflow/react": "^12.4.4",
|
"@xyflow/react": "^12.4.4",
|
||||||
"adm-zip": "^0.5.16",
|
"adm-zip": "^0.5.16",
|
||||||
"async-mutex": "^0.5.0",
|
"async-mutex": "^0.5.0",
|
||||||
|
"bufferutil": "^4.0.9",
|
||||||
"color": "^5.0.0",
|
"color": "^5.0.0",
|
||||||
"diff": "^7.0.0",
|
"diff": "^7.0.0",
|
||||||
"docx": "^9.0.2",
|
"docx": "^9.0.2",
|
||||||
@ -96,6 +96,7 @@
|
|||||||
"turndown-plugin-gfm": "^1.0.2",
|
"turndown-plugin-gfm": "^1.0.2",
|
||||||
"undici": "^7.4.0",
|
"undici": "^7.4.0",
|
||||||
"webdav": "^5.8.0",
|
"webdav": "^5.8.0",
|
||||||
|
"ws": "^8.18.1",
|
||||||
"zipread": "^1.3.3"
|
"zipread": "^1.3.3"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
@ -112,7 +113,7 @@
|
|||||||
"@emotion/is-prop-valid": "^1.3.1",
|
"@emotion/is-prop-valid": "^1.3.1",
|
||||||
"@eslint-react/eslint-plugin": "^1.36.1",
|
"@eslint-react/eslint-plugin": "^1.36.1",
|
||||||
"@eslint/js": "^9.22.0",
|
"@eslint/js": "^9.22.0",
|
||||||
"@google/genai": "^0.4.0",
|
"@google/genai": "patch:@google/genai@npm%3A0.8.0#~/.yarn/patches/@google-genai-npm-0.8.0-450d0d9a7d.patch",
|
||||||
"@hello-pangea/dnd": "^16.6.0",
|
"@hello-pangea/dnd": "^16.6.0",
|
||||||
"@kangfenmao/keyv-storage": "^0.1.0",
|
"@kangfenmao/keyv-storage": "^0.1.0",
|
||||||
"@modelcontextprotocol/sdk": "^1.9.0",
|
"@modelcontextprotocol/sdk": "^1.9.0",
|
||||||
@ -133,7 +134,8 @@
|
|||||||
"@types/react-dom": "^19.0.4",
|
"@types/react-dom": "^19.0.4",
|
||||||
"@types/react-infinite-scroll-component": "^5.0.0",
|
"@types/react-infinite-scroll-component": "^5.0.0",
|
||||||
"@types/tinycolor2": "^1",
|
"@types/tinycolor2": "^1",
|
||||||
"@vitejs/plugin-react": "^4.2.1",
|
"@types/ws": "^8",
|
||||||
|
"@vitejs/plugin-react": "4.3.4",
|
||||||
"analytics": "^0.8.16",
|
"analytics": "^0.8.16",
|
||||||
"antd": "^5.22.5",
|
"antd": "^5.22.5",
|
||||||
"applescript": "^1.0.0",
|
"applescript": "^1.0.0",
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import { FileMetadataResponse, FileState, GoogleAIFileManager } from '@google/generative-ai/server'
|
import { File, FileState, GoogleGenAI, Pager } from '@google/genai'
|
||||||
import { FileType } from '@types'
|
import { FileType } from '@types'
|
||||||
import fs from 'fs'
|
import fs from 'fs'
|
||||||
|
|
||||||
@ -8,11 +8,15 @@ export class GeminiService {
|
|||||||
private static readonly FILE_LIST_CACHE_KEY = 'gemini_file_list'
|
private static readonly FILE_LIST_CACHE_KEY = 'gemini_file_list'
|
||||||
private static readonly CACHE_DURATION = 3000
|
private static readonly CACHE_DURATION = 3000
|
||||||
|
|
||||||
static async uploadFile(_: Electron.IpcMainInvokeEvent, file: FileType, apiKey: string) {
|
static async uploadFile(_: Electron.IpcMainInvokeEvent, file: FileType, apiKey: string): Promise<File> {
|
||||||
const fileManager = new GoogleAIFileManager(apiKey)
|
const sdk = new GoogleGenAI({ vertexai: false, apiKey })
|
||||||
const uploadResult = await fileManager.uploadFile(file.path, {
|
const uploadResult = await sdk.files.upload({
|
||||||
mimeType: 'application/pdf',
|
file: file.path,
|
||||||
displayName: file.origin_name
|
config: {
|
||||||
|
mimeType: 'application/pdf',
|
||||||
|
name: file.id,
|
||||||
|
displayName: file.origin_name
|
||||||
|
}
|
||||||
})
|
})
|
||||||
return uploadResult
|
return uploadResult
|
||||||
}
|
}
|
||||||
@ -24,40 +28,42 @@ export class GeminiService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static async retrieveFile(
|
static async retrieveFile(_: Electron.IpcMainInvokeEvent, file: FileType, apiKey: string): Promise<File | undefined> {
|
||||||
_: Electron.IpcMainInvokeEvent,
|
const sdk = new GoogleGenAI({ vertexai: false, apiKey })
|
||||||
file: FileType,
|
|
||||||
apiKey: string
|
|
||||||
): Promise<FileMetadataResponse | undefined> {
|
|
||||||
const fileManager = new GoogleAIFileManager(apiKey)
|
|
||||||
|
|
||||||
const cachedResponse = CacheService.get<any>(GeminiService.FILE_LIST_CACHE_KEY)
|
const cachedResponse = CacheService.get<any>(GeminiService.FILE_LIST_CACHE_KEY)
|
||||||
if (cachedResponse) {
|
if (cachedResponse) {
|
||||||
return GeminiService.processResponse(cachedResponse, file)
|
return GeminiService.processResponse(cachedResponse, file)
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await fileManager.listFiles()
|
const response = await sdk.files.list()
|
||||||
CacheService.set(GeminiService.FILE_LIST_CACHE_KEY, response, GeminiService.CACHE_DURATION)
|
CacheService.set(GeminiService.FILE_LIST_CACHE_KEY, response, GeminiService.CACHE_DURATION)
|
||||||
|
|
||||||
return GeminiService.processResponse(response, file)
|
return GeminiService.processResponse(response, file)
|
||||||
}
|
}
|
||||||
|
|
||||||
private static processResponse(response: any, file: FileType) {
|
private static async processResponse(response: Pager<File>, file: FileType) {
|
||||||
if (response.files) {
|
for await (const f of response) {
|
||||||
return response.files
|
if (f.state === FileState.ACTIVE) {
|
||||||
.filter((file) => file.state === FileState.ACTIVE)
|
if (f.displayName === file.origin_name && Number(f.sizeBytes) === file.size) {
|
||||||
.find((i) => i.displayName === file.origin_name && Number(i.sizeBytes) === file.size)
|
return f
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return undefined
|
return undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
static async listFiles(_: Electron.IpcMainInvokeEvent, apiKey: string) {
|
static async listFiles(_: Electron.IpcMainInvokeEvent, apiKey: string): Promise<File[]> {
|
||||||
const fileManager = new GoogleAIFileManager(apiKey)
|
const sdk = new GoogleGenAI({ vertexai: false, apiKey })
|
||||||
return await fileManager.listFiles()
|
const files: File[] = []
|
||||||
|
for await (const f of await sdk.files.list()) {
|
||||||
|
files.push(f)
|
||||||
|
}
|
||||||
|
return files
|
||||||
}
|
}
|
||||||
|
|
||||||
static async deleteFile(_: Electron.IpcMainInvokeEvent, apiKey: string, fileId: string) {
|
static async deleteFile(_: Electron.IpcMainInvokeEvent, fileId: string, apiKey: string) {
|
||||||
const fileManager = new GoogleAIFileManager(apiKey)
|
const sdk = new GoogleGenAI({ vertexai: false, apiKey })
|
||||||
await fileManager.deleteFile(fileId)
|
await sdk.files.delete({ name: fileId })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
10
src/preload/index.d.ts
vendored
10
src/preload/index.d.ts
vendored
@ -1,6 +1,6 @@
|
|||||||
import { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
import { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||||
import { ElectronAPI } from '@electron-toolkit/preload'
|
import { ElectronAPI } from '@electron-toolkit/preload'
|
||||||
import type { FileMetadataResponse, ListFilesResponse, UploadFileResponse } from '@google/generative-ai/server'
|
import type { File } from '@google/genai'
|
||||||
import type { GetMCPPromptResponse, MCPPrompt, MCPResource, MCPServer, MCPTool } from '@renderer/types'
|
import type { GetMCPPromptResponse, MCPPrompt, MCPResource, MCPServer, MCPTool } from '@renderer/types'
|
||||||
import { AppInfo, FileType, KnowledgeBaseParams, KnowledgeItem, LanguageVarious, WebDavConfig } from '@renderer/types'
|
import { AppInfo, FileType, KnowledgeBaseParams, KnowledgeItem, LanguageVarious, WebDavConfig } from '@renderer/types'
|
||||||
import type { LoaderReturn } from '@shared/config/types'
|
import type { LoaderReturn } from '@shared/config/types'
|
||||||
@ -119,11 +119,11 @@ declare global {
|
|||||||
resetMinimumSize: () => Promise<void>
|
resetMinimumSize: () => Promise<void>
|
||||||
}
|
}
|
||||||
gemini: {
|
gemini: {
|
||||||
uploadFile: (file: FileType, apiKey: string) => Promise<UploadFileResponse>
|
uploadFile: (file: FileType, apiKey: string) => Promise<File>
|
||||||
retrieveFile: (file: FileType, apiKey: string) => Promise<FileMetadataResponse | undefined>
|
retrieveFile: (file: FileType, apiKey: string) => Promise<File | undefined>
|
||||||
base64File: (file: FileType) => Promise<{ data: string; mimeType: string }>
|
base64File: (file: FileType) => Promise<{ data: string; mimeType: string }>
|
||||||
listFiles: (apiKey: string) => Promise<ListFilesResponse>
|
listFiles: (apiKey: string) => Promise<File[]>
|
||||||
deleteFile: (apiKey: string, fileId: string) => Promise<void>
|
deleteFile: (fileId: string, apiKey: string) => Promise<void>
|
||||||
}
|
}
|
||||||
selectionMenu: {
|
selectionMenu: {
|
||||||
action: (action: string) => Promise<void>
|
action: (action: string) => Promise<void>
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import { DeleteOutlined } from '@ant-design/icons'
|
import { DeleteOutlined } from '@ant-design/icons'
|
||||||
import type { FileMetadataResponse } from '@google/generative-ai/server'
|
import type { File } from '@google/genai'
|
||||||
import { useProvider } from '@renderer/hooks/useProvider'
|
import { useProvider } from '@renderer/hooks/useProvider'
|
||||||
import { runAsyncFunction } from '@renderer/utils'
|
import { runAsyncFunction } from '@renderer/utils'
|
||||||
import { MB } from '@shared/config/constant'
|
import { MB } from '@shared/config/constant'
|
||||||
@ -16,11 +16,11 @@ interface GeminiFilesProps {
|
|||||||
|
|
||||||
const GeminiFiles: FC<GeminiFilesProps> = ({ id }) => {
|
const GeminiFiles: FC<GeminiFilesProps> = ({ id }) => {
|
||||||
const { provider } = useProvider(id)
|
const { provider } = useProvider(id)
|
||||||
const [files, setFiles] = useState<FileMetadataResponse[]>([])
|
const [files, setFiles] = useState<File[]>([])
|
||||||
const [loading, setLoading] = useState(false)
|
const [loading, setLoading] = useState(false)
|
||||||
|
|
||||||
const fetchFiles = useCallback(async () => {
|
const fetchFiles = useCallback(async () => {
|
||||||
const { files } = await window.api.gemini.listFiles(provider.apiKey)
|
const files = await window.api.gemini.listFiles(provider.apiKey)
|
||||||
files && setFiles(files.filter((file) => file.state === 'ACTIVE'))
|
files && setFiles(files.filter((file) => file.state === 'ACTIVE'))
|
||||||
}, [provider])
|
}, [provider])
|
||||||
|
|
||||||
@ -60,14 +60,14 @@ const GeminiFiles: FC<GeminiFilesProps> = ({ id }) => {
|
|||||||
key={file.name}
|
key={file.name}
|
||||||
fileInfo={{
|
fileInfo={{
|
||||||
name: file.displayName,
|
name: file.displayName,
|
||||||
ext: `.${file.name.split('.').pop()}`,
|
ext: `.${file.name?.split('.').pop()}`,
|
||||||
extra: `${dayjs(file.createTime).format('MM-DD HH:mm')} · ${(parseInt(file.sizeBytes) / MB).toFixed(2)} MB`,
|
extra: `${dayjs(file.createTime).format('MM-DD HH:mm')} · ${(parseInt(file.sizeBytes || '0') / MB).toFixed(2)} MB`,
|
||||||
actions: (
|
actions: (
|
||||||
<DeleteOutlined
|
<DeleteOutlined
|
||||||
style={{ cursor: 'pointer', color: 'var(--color-error)' }}
|
style={{ cursor: 'pointer', color: 'var(--color-error)' }}
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
setFiles(files.filter((f) => f.name !== file.name))
|
setFiles(files.filter((f) => f.name !== file.name))
|
||||||
window.api.gemini.deleteFile(provider.apiKey, file.name).catch((error) => {
|
window.api.gemini.deleteFile(file.name!, provider.apiKey).catch((error) => {
|
||||||
console.error('Failed to delete file:', error)
|
console.error('Failed to delete file:', error)
|
||||||
setFiles((prev) => [...prev, file])
|
setFiles((prev) => [...prev, file])
|
||||||
})
|
})
|
||||||
|
|||||||
@ -173,13 +173,13 @@ const ServerName = styled.div`
|
|||||||
|
|
||||||
const ServerDescription = styled.div`
|
const ServerDescription = styled.div`
|
||||||
font-size: 0.85rem;
|
font-size: 0.85rem;
|
||||||
color: ${(props) => props.theme.colors?.textSecondary || '#8c8c8c'};
|
color: var(--color-text-2);
|
||||||
margin-bottom: 3px;
|
margin-bottom: 3px;
|
||||||
`
|
`
|
||||||
|
|
||||||
const ServerUrl = styled.div`
|
const ServerUrl = styled.div`
|
||||||
font-size: 0.8rem;
|
font-size: 0.8rem;
|
||||||
color: ${(props) => props.theme.colors?.textTertiary || '#bfbfbf'};
|
color: var(--color-text-3);
|
||||||
white-space: nowrap;
|
white-space: nowrap;
|
||||||
overflow: hidden;
|
overflow: hidden;
|
||||||
text-overflow: ellipsis;
|
text-overflow: ellipsis;
|
||||||
|
|||||||
@ -1,25 +1,18 @@
|
|||||||
import {
|
|
||||||
ContentListUnion,
|
|
||||||
createPartFromBase64,
|
|
||||||
FinishReason,
|
|
||||||
GenerateContentResponse,
|
|
||||||
GoogleGenAI
|
|
||||||
} from '@google/genai'
|
|
||||||
import {
|
import {
|
||||||
Content,
|
Content,
|
||||||
FileDataPart,
|
File,
|
||||||
GenerateContentStreamResult,
|
GenerateContentConfig,
|
||||||
GoogleGenerativeAI,
|
GenerateContentResponse,
|
||||||
|
GoogleGenAI,
|
||||||
HarmBlockThreshold,
|
HarmBlockThreshold,
|
||||||
HarmCategory,
|
HarmCategory,
|
||||||
InlineDataPart,
|
Modality,
|
||||||
Part,
|
Part,
|
||||||
RequestOptions,
|
PartUnion,
|
||||||
SafetySetting,
|
SafetySetting,
|
||||||
TextPart,
|
ToolListUnion
|
||||||
Tool
|
} from '@google/genai'
|
||||||
} from '@google/generative-ai'
|
import { isGemmaModel, isGenerateImageModel, isVisionModel, isWebSearchModel } from '@renderer/config/models'
|
||||||
import { isGemmaModel, isVisionModel, isWebSearchModel } from '@renderer/config/models'
|
|
||||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||||
import i18n from '@renderer/i18n'
|
import i18n from '@renderer/i18n'
|
||||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||||
@ -39,22 +32,15 @@ import axios from 'axios'
|
|||||||
import { flatten, isEmpty, takeRight } from 'lodash'
|
import { flatten, isEmpty, takeRight } from 'lodash'
|
||||||
import OpenAI from 'openai'
|
import OpenAI from 'openai'
|
||||||
|
|
||||||
import { ChunkCallbackData, CompletionsParams } from '.'
|
import { CompletionsParams } from '.'
|
||||||
import BaseProvider from './BaseProvider'
|
import BaseProvider from './BaseProvider'
|
||||||
|
|
||||||
export default class GeminiProvider extends BaseProvider {
|
export default class GeminiProvider extends BaseProvider {
|
||||||
private sdk: GoogleGenerativeAI
|
private sdk: GoogleGenAI
|
||||||
private requestOptions: RequestOptions
|
|
||||||
private imageSdk: GoogleGenAI
|
|
||||||
|
|
||||||
constructor(provider: Provider) {
|
constructor(provider: Provider) {
|
||||||
super(provider)
|
super(provider)
|
||||||
this.sdk = new GoogleGenerativeAI(this.apiKey)
|
this.sdk = new GoogleGenAI({ vertexai: false, apiKey: this.apiKey, httpOptions: { baseUrl: this.getBaseURL() } })
|
||||||
/// this sdk is experimental
|
|
||||||
this.imageSdk = new GoogleGenAI({ apiKey: this.apiKey, httpOptions: { baseUrl: this.getBaseURL() } })
|
|
||||||
this.requestOptions = {
|
|
||||||
baseUrl: this.getBaseURL()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public getBaseURL(): string {
|
public getBaseURL(): string {
|
||||||
@ -76,31 +62,31 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
inlineData: {
|
inlineData: {
|
||||||
data,
|
data,
|
||||||
mimeType
|
mimeType
|
||||||
}
|
} as Part['inlineData']
|
||||||
} as InlineDataPart
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Retrieve file from Gemini uploaded files
|
// Retrieve file from Gemini uploaded files
|
||||||
const fileMetadata = await window.api.gemini.retrieveFile(file, this.apiKey)
|
const fileMetadata: File | undefined = await window.api.gemini.retrieveFile(file, this.apiKey)
|
||||||
|
|
||||||
if (fileMetadata) {
|
if (fileMetadata) {
|
||||||
return {
|
return {
|
||||||
fileData: {
|
fileData: {
|
||||||
fileUri: fileMetadata.uri,
|
fileUri: fileMetadata.uri,
|
||||||
mimeType: fileMetadata.mimeType
|
mimeType: fileMetadata.mimeType
|
||||||
}
|
} as Part['fileData']
|
||||||
} as FileDataPart
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If file is not found, upload it to Gemini
|
// If file is not found, upload it to Gemini
|
||||||
const uploadResult = await window.api.gemini.uploadFile(file, this.apiKey)
|
const result = await window.api.gemini.uploadFile(file, this.apiKey)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
fileData: {
|
fileData: {
|
||||||
fileUri: uploadResult.file.uri,
|
fileUri: result.uri,
|
||||||
mimeType: uploadResult.file.mimeType
|
mimeType: result.mimeType
|
||||||
}
|
} as Part['fileData']
|
||||||
} as FileDataPart
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -125,8 +111,8 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
inlineData: {
|
inlineData: {
|
||||||
data: base64Data,
|
data: base64Data,
|
||||||
mimeType: mimeType
|
mimeType: mimeType
|
||||||
}
|
} as Part['inlineData']
|
||||||
} as InlineDataPart)
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -139,8 +125,8 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
inlineData: {
|
inlineData: {
|
||||||
data: base64Data.base64,
|
data: base64Data.base64,
|
||||||
mimeType: base64Data.mime
|
mimeType: base64Data.mime
|
||||||
}
|
} as Part['inlineData']
|
||||||
} as InlineDataPart)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
if (file.ext === '.pdf') {
|
if (file.ext === '.pdf') {
|
||||||
@ -152,13 +138,13 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
|
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
|
||||||
parts.push({
|
parts.push({
|
||||||
text: file.origin_name + '\n' + fileContent
|
text: file.origin_name + '\n' + fileContent
|
||||||
} as TextPart)
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
role,
|
role,
|
||||||
parts
|
parts: parts
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -204,165 +190,215 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
* @param onChunk - The onChunk callback
|
* @param onChunk - The onChunk callback
|
||||||
* @param onFilterMessages - The onFilterMessages callback
|
* @param onFilterMessages - The onFilterMessages callback
|
||||||
*/
|
*/
|
||||||
public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) {
|
public async completions({
|
||||||
if (assistant.enableGenerateImage) {
|
messages,
|
||||||
await this.generateImageExp({ messages, assistant, onFilterMessages, onChunk })
|
assistant,
|
||||||
} else {
|
mcpTools,
|
||||||
const defaultModel = getDefaultModel()
|
onChunk,
|
||||||
const model = assistant.model || defaultModel
|
onFilterMessages
|
||||||
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
}: CompletionsParams): Promise<void> {
|
||||||
|
const defaultModel = getDefaultModel()
|
||||||
|
const model = assistant.model || defaultModel
|
||||||
|
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
||||||
|
|
||||||
const userMessages = filterUserRoleStartMessages(
|
const userMessages = filterUserRoleStartMessages(
|
||||||
filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
|
filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
|
||||||
)
|
)
|
||||||
onFilterMessages(userMessages)
|
onFilterMessages(userMessages)
|
||||||
|
|
||||||
const userLastMessage = userMessages.pop()
|
const userLastMessage = userMessages.pop()
|
||||||
|
|
||||||
const history: Content[] = []
|
const history: Content[] = []
|
||||||
|
|
||||||
for (const message of userMessages) {
|
for (const message of userMessages) {
|
||||||
history.push(await this.getMessageContents(message))
|
history.push(await this.getMessageContents(message))
|
||||||
}
|
}
|
||||||
|
|
||||||
let systemInstruction = assistant.prompt
|
let systemInstruction = assistant.prompt
|
||||||
|
|
||||||
if (mcpTools && mcpTools.length > 0) {
|
if (mcpTools && mcpTools.length > 0) {
|
||||||
systemInstruction = buildSystemPrompt(assistant.prompt || '', mcpTools)
|
systemInstruction = buildSystemPrompt(assistant.prompt || '', mcpTools)
|
||||||
}
|
}
|
||||||
|
|
||||||
// const tools = mcpToolsToGeminiTools(mcpTools)
|
// const tools = mcpToolsToGeminiTools(mcpTools)
|
||||||
const tools: Tool[] = []
|
const tools: ToolListUnion = []
|
||||||
const toolResponses: MCPToolResponse[] = []
|
const toolResponses: MCPToolResponse[] = []
|
||||||
|
|
||||||
if (!WebSearchService.isOverwriteEnabled() && assistant.enableWebSearch && isWebSearchModel(model)) {
|
if (!WebSearchService.isOverwriteEnabled() && assistant.enableWebSearch && isWebSearchModel(model)) {
|
||||||
tools.push({
|
tools.push({
|
||||||
// @ts-ignore googleSearch is not a valid tool for Gemini
|
// @ts-ignore googleSearch is not a valid tool for Gemini
|
||||||
googleSearch: {}
|
googleSearch: {}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
const geminiModel = this.sdk.getGenerativeModel(
|
const generateContentConfig: GenerateContentConfig = {
|
||||||
{
|
responseModalities: [Modality.TEXT, Modality.IMAGE],
|
||||||
model: model.id,
|
responseMimeType: 'text/plain',
|
||||||
...(isGemmaModel(model) ? {} : { systemInstruction: systemInstruction }),
|
safetySettings: this.getSafetySettings(model.id),
|
||||||
safetySettings: this.getSafetySettings(model.id),
|
// generate image don't need system instruction
|
||||||
tools: tools,
|
systemInstruction: isGemmaModel(model) || isGenerateImageModel(model) ? undefined : systemInstruction,
|
||||||
generationConfig: {
|
temperature: assistant?.settings?.temperature,
|
||||||
maxOutputTokens: maxTokens,
|
topP: assistant?.settings?.topP,
|
||||||
temperature: assistant?.settings?.temperature,
|
maxOutputTokens: maxTokens,
|
||||||
topP: assistant?.settings?.topP,
|
tools: tools,
|
||||||
...this.getCustomParameters(assistant)
|
...this.getCustomParameters(assistant)
|
||||||
|
}
|
||||||
|
|
||||||
|
const messageContents: Content = await this.getMessageContents(userLastMessage!)
|
||||||
|
|
||||||
|
const chat = this.sdk.chats.create({
|
||||||
|
model: model.id,
|
||||||
|
config: generateContentConfig,
|
||||||
|
history: history
|
||||||
|
})
|
||||||
|
|
||||||
|
if (isGemmaModel(model) && assistant.prompt) {
|
||||||
|
const isFirstMessage = history.length === 0
|
||||||
|
if (isFirstMessage && messageContents) {
|
||||||
|
const systemMessage = [
|
||||||
|
{
|
||||||
|
text:
|
||||||
|
'<start_of_turn>user\n' +
|
||||||
|
systemInstruction +
|
||||||
|
'<end_of_turn>\n' +
|
||||||
|
'<start_of_turn>user\n' +
|
||||||
|
(messageContents?.parts?.[0] as Part).text +
|
||||||
|
'<end_of_turn>'
|
||||||
}
|
}
|
||||||
},
|
] as Part[]
|
||||||
this.requestOptions
|
if (messageContents && messageContents.parts) {
|
||||||
)
|
messageContents.parts[0] = systemMessage[0]
|
||||||
|
|
||||||
const chat = geminiModel.startChat({ history })
|
|
||||||
const messageContents = await this.getMessageContents(userLastMessage!)
|
|
||||||
|
|
||||||
if (isGemmaModel(model) && assistant.prompt) {
|
|
||||||
const isFirstMessage = history.length === 0
|
|
||||||
if (isFirstMessage) {
|
|
||||||
const systemMessage = {
|
|
||||||
role: 'user',
|
|
||||||
parts: [
|
|
||||||
{
|
|
||||||
text:
|
|
||||||
'<start_of_turn>user\n' +
|
|
||||||
systemInstruction +
|
|
||||||
'<end_of_turn>\n' +
|
|
||||||
'<start_of_turn>user\n' +
|
|
||||||
messageContents.parts[0].text +
|
|
||||||
'<end_of_turn>'
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
messageContents.parts = systemMessage.parts
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const start_time_millsec = new Date().getTime()
|
const start_time_millsec = new Date().getTime()
|
||||||
const { abortController, cleanup } = this.createAbortController(userLastMessage?.id)
|
|
||||||
const { signal } = abortController
|
const { cleanup, abortController } = this.createAbortController(userLastMessage?.id, true)
|
||||||
|
const signalProxy = {
|
||||||
|
_originalSignal: abortController.signal,
|
||||||
|
|
||||||
|
addEventListener: (eventName: string, listener: () => void) => {
|
||||||
|
if (eventName === 'abort') {
|
||||||
|
abortController.signal.addEventListener('abort', listener)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
removeEventListener: (eventName: string, listener: () => void) => {
|
||||||
|
if (eventName === 'abort') {
|
||||||
|
abortController.signal.removeEventListener('abort', listener)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
get aborted() {
|
||||||
|
return abortController.signal.aborted
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!streamOutput) {
|
||||||
|
const response = await chat.sendMessage({
|
||||||
|
message: messageContents as PartUnion,
|
||||||
|
config: {
|
||||||
|
...generateContentConfig,
|
||||||
|
httpOptions: {
|
||||||
|
signal: signalProxy as any
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
const time_completion_millsec = new Date().getTime() - start_time_millsec
|
||||||
|
onChunk({
|
||||||
|
text: response.text,
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: response.usageMetadata?.promptTokenCount || 0,
|
||||||
|
completion_tokens: response.usageMetadata?.candidatesTokenCount || 0,
|
||||||
|
total_tokens: response.usageMetadata?.totalTokenCount || 0
|
||||||
|
},
|
||||||
|
metrics: {
|
||||||
|
completion_tokens: response.usageMetadata?.candidatesTokenCount,
|
||||||
|
time_completion_millsec,
|
||||||
|
time_first_token_millsec: 0
|
||||||
|
},
|
||||||
|
search: response.candidates?.[0]?.groundingMetadata
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const userMessagesStream = await chat.sendMessageStream({
|
||||||
|
message: messageContents as PartUnion,
|
||||||
|
config: {
|
||||||
|
...generateContentConfig,
|
||||||
|
httpOptions: {
|
||||||
|
signal: signalProxy as any
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
let time_first_token_millsec = 0
|
||||||
|
|
||||||
|
const processToolUses = async (content: string, idx: number) => {
|
||||||
|
const toolResults = await parseAndCallTools(
|
||||||
|
content,
|
||||||
|
toolResponses,
|
||||||
|
onChunk,
|
||||||
|
idx,
|
||||||
|
mcpToolCallResponseToGeminiMessage,
|
||||||
|
mcpTools,
|
||||||
|
isVisionModel(model)
|
||||||
|
)
|
||||||
|
if (toolResults && toolResults.length > 0) {
|
||||||
|
history.push(messageContents)
|
||||||
|
const newChat = this.sdk.chats.create({
|
||||||
|
model: model.id,
|
||||||
|
config: generateContentConfig,
|
||||||
|
history: history as Content[]
|
||||||
|
})
|
||||||
|
const newStream = await newChat.sendMessageStream({
|
||||||
|
message: flatten(toolResults.map((ts) => (ts as Content).parts)) as PartUnion,
|
||||||
|
config: {
|
||||||
|
...generateContentConfig,
|
||||||
|
httpOptions: {
|
||||||
|
signal: signalProxy as any
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
await processStream(newStream, idx + 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const processStream = async (stream: AsyncGenerator<GenerateContentResponse>, idx: number) => {
|
||||||
|
let content = ''
|
||||||
|
for await (const chunk of stream) {
|
||||||
|
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
|
||||||
|
|
||||||
|
if (time_first_token_millsec == 0) {
|
||||||
|
time_first_token_millsec = new Date().getTime() - start_time_millsec
|
||||||
|
}
|
||||||
|
|
||||||
if (!streamOutput) {
|
|
||||||
const { response } = await chat.sendMessage(messageContents.parts, { signal })
|
|
||||||
const time_completion_millsec = new Date().getTime() - start_time_millsec
|
const time_completion_millsec = new Date().getTime() - start_time_millsec
|
||||||
|
|
||||||
|
if (chunk.text !== undefined) {
|
||||||
|
content += chunk.text
|
||||||
|
}
|
||||||
|
await processToolUses(content, idx)
|
||||||
|
const generateImage = this.processGeminiImageResponse(chunk)
|
||||||
|
|
||||||
onChunk({
|
onChunk({
|
||||||
text: response.candidates?.[0].content.parts[0].text,
|
text: chunk.text !== undefined ? chunk.text : '',
|
||||||
usage: {
|
usage: {
|
||||||
prompt_tokens: response.usageMetadata?.promptTokenCount || 0,
|
prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
|
||||||
completion_tokens: response.usageMetadata?.candidatesTokenCount || 0,
|
completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
|
||||||
total_tokens: response.usageMetadata?.totalTokenCount || 0
|
total_tokens: chunk.usageMetadata?.totalTokenCount || 0
|
||||||
},
|
},
|
||||||
metrics: {
|
metrics: {
|
||||||
completion_tokens: response.usageMetadata?.candidatesTokenCount,
|
completion_tokens: chunk.usageMetadata?.candidatesTokenCount,
|
||||||
time_completion_millsec,
|
time_completion_millsec,
|
||||||
time_first_token_millsec: 0
|
time_first_token_millsec
|
||||||
},
|
},
|
||||||
search: response.candidates?.[0]?.groundingMetadata
|
search: chunk.candidates?.[0]?.groundingMetadata,
|
||||||
|
mcpToolResponse: toolResponses,
|
||||||
|
generateImage: generateImage
|
||||||
})
|
})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal })
|
|
||||||
let time_first_token_millsec = 0
|
|
||||||
|
|
||||||
const processToolUses = async (content: string, idx: number) => {
|
|
||||||
const toolResults = await parseAndCallTools(
|
|
||||||
content,
|
|
||||||
toolResponses,
|
|
||||||
onChunk,
|
|
||||||
idx,
|
|
||||||
mcpToolCallResponseToGeminiMessage,
|
|
||||||
mcpTools,
|
|
||||||
isVisionModel(model)
|
|
||||||
)
|
|
||||||
if (toolResults && toolResults.length > 0) {
|
|
||||||
history.push(messageContents)
|
|
||||||
const newChat = geminiModel.startChat({ history })
|
|
||||||
const newStream = await newChat.sendMessageStream(flatten(toolResults.map((ts) => (ts as Content).parts)), {
|
|
||||||
signal
|
|
||||||
})
|
|
||||||
await processStream(newStream, idx + 1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const processStream = async (stream: GenerateContentStreamResult, idx: number) => {
|
|
||||||
let content = ''
|
|
||||||
for await (const chunk of stream.stream) {
|
|
||||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
|
|
||||||
|
|
||||||
if (time_first_token_millsec == 0) {
|
|
||||||
time_first_token_millsec = new Date().getTime() - start_time_millsec
|
|
||||||
}
|
|
||||||
|
|
||||||
const time_completion_millsec = new Date().getTime() - start_time_millsec
|
|
||||||
|
|
||||||
content += chunk.text()
|
|
||||||
processToolUses(content, idx)
|
|
||||||
|
|
||||||
onChunk({
|
|
||||||
text: chunk.text(),
|
|
||||||
usage: {
|
|
||||||
prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
|
|
||||||
completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
|
|
||||||
total_tokens: chunk.usageMetadata?.totalTokenCount || 0
|
|
||||||
},
|
|
||||||
metrics: {
|
|
||||||
completion_tokens: chunk.usageMetadata?.candidatesTokenCount,
|
|
||||||
time_completion_millsec,
|
|
||||||
time_first_token_millsec
|
|
||||||
},
|
|
||||||
search: chunk.candidates?.[0]?.groundingMetadata,
|
|
||||||
mcpToolResponse: toolResponses
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
await processStream(userMessagesStream, 0).finally(cleanup)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
await processStream(userMessagesStream, 0).finally(cleanup)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -372,39 +408,51 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
* @param onResponse - The onResponse callback
|
* @param onResponse - The onResponse callback
|
||||||
* @returns The translated message
|
* @returns The translated message
|
||||||
*/
|
*/
|
||||||
async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
|
public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
|
||||||
const defaultModel = getDefaultModel()
|
const defaultModel = getDefaultModel()
|
||||||
const { maxTokens } = getAssistantSettings(assistant)
|
const { maxTokens } = getAssistantSettings(assistant)
|
||||||
const model = assistant.model || defaultModel
|
const model = assistant.model || defaultModel
|
||||||
|
|
||||||
const geminiModel = this.sdk.getGenerativeModel(
|
|
||||||
{
|
|
||||||
model: model.id,
|
|
||||||
...(isGemmaModel(model) ? {} : { systemInstruction: assistant.prompt }),
|
|
||||||
generationConfig: {
|
|
||||||
maxOutputTokens: maxTokens,
|
|
||||||
temperature: assistant?.settings?.temperature
|
|
||||||
}
|
|
||||||
},
|
|
||||||
this.requestOptions
|
|
||||||
)
|
|
||||||
|
|
||||||
const content =
|
const content =
|
||||||
isGemmaModel(model) && assistant.prompt
|
isGemmaModel(model) && assistant.prompt
|
||||||
? `<start_of_turn>user\n${assistant.prompt}<end_of_turn>\n<start_of_turn>user\n${message.content}<end_of_turn>`
|
? `<start_of_turn>user\n${assistant.prompt}<end_of_turn>\n<start_of_turn>user\n${message.content}<end_of_turn>`
|
||||||
: message.content
|
: message.content
|
||||||
|
|
||||||
if (!onResponse) {
|
if (!onResponse) {
|
||||||
const { response } = await geminiModel.generateContent(content)
|
const response = await this.sdk.models.generateContent({
|
||||||
return response.text()
|
model: model.id,
|
||||||
|
config: {
|
||||||
|
maxOutputTokens: maxTokens,
|
||||||
|
temperature: assistant?.settings?.temperature,
|
||||||
|
systemInstruction: isGemmaModel(model) ? undefined : assistant.prompt
|
||||||
|
},
|
||||||
|
contents: [
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
parts: [{ text: content }]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
return response.text || ''
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await geminiModel.generateContentStream(content)
|
const response = await this.sdk.models.generateContentStream({
|
||||||
|
model: model.id,
|
||||||
|
config: {
|
||||||
|
maxOutputTokens: maxTokens,
|
||||||
|
temperature: assistant?.settings?.temperature,
|
||||||
|
systemInstruction: isGemmaModel(model) ? undefined : assistant.prompt
|
||||||
|
},
|
||||||
|
contents: [
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
parts: [{ text: content }]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
let text = ''
|
let text = ''
|
||||||
|
|
||||||
for await (const chunk of response.stream) {
|
for await (const chunk of response) {
|
||||||
text += chunk.text()
|
text += chunk.text
|
||||||
onResponse(text)
|
onResponse(text)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -442,25 +490,24 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
content: userMessageContent
|
content: userMessageContent
|
||||||
}
|
}
|
||||||
|
|
||||||
const geminiModel = this.sdk.getGenerativeModel(
|
|
||||||
{
|
|
||||||
model: model.id,
|
|
||||||
...(isGemmaModel(model) ? {} : { systemInstruction: systemMessage.content }),
|
|
||||||
generationConfig: {
|
|
||||||
temperature: assistant?.settings?.temperature
|
|
||||||
}
|
|
||||||
},
|
|
||||||
this.requestOptions
|
|
||||||
)
|
|
||||||
|
|
||||||
const chat = await geminiModel.startChat()
|
|
||||||
const content = isGemmaModel(model)
|
const content = isGemmaModel(model)
|
||||||
? `<start_of_turn>user\n${systemMessage.content}<end_of_turn>\n<start_of_turn>user\n${userMessage.content}<end_of_turn>`
|
? `<start_of_turn>user\n${systemMessage.content}<end_of_turn>\n<start_of_turn>user\n${userMessage.content}<end_of_turn>`
|
||||||
: userMessage.content
|
: userMessage.content
|
||||||
|
|
||||||
const { response } = await chat.sendMessage(content)
|
const response = await this.sdk.models.generateContent({
|
||||||
|
model: model.id,
|
||||||
|
config: {
|
||||||
|
systemInstruction: isGemmaModel(model) ? undefined : systemMessage.content
|
||||||
|
},
|
||||||
|
contents: [
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
parts: [{ text: content }]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
return removeSpecialCharactersForTopicName(response.text())
|
return removeSpecialCharactersForTopicName(response.text || '')
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -471,24 +518,23 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
*/
|
*/
|
||||||
public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
|
public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
|
||||||
const model = getDefaultModel()
|
const model = getDefaultModel()
|
||||||
const systemMessage = { role: 'system', content: prompt }
|
const MessageContent = isGemmaModel(model)
|
||||||
|
|
||||||
const geminiModel = this.sdk.getGenerativeModel(
|
|
||||||
{
|
|
||||||
model: model.id,
|
|
||||||
...(isGemmaModel(model) ? {} : { systemInstruction: systemMessage.content })
|
|
||||||
},
|
|
||||||
this.requestOptions
|
|
||||||
)
|
|
||||||
|
|
||||||
const chat = await geminiModel.startChat()
|
|
||||||
const messageContent = isGemmaModel(model)
|
|
||||||
? `<start_of_turn>user\n${prompt}<end_of_turn>\n<start_of_turn>user\n${content}<end_of_turn>`
|
? `<start_of_turn>user\n${prompt}<end_of_turn>\n<start_of_turn>user\n${content}<end_of_turn>`
|
||||||
: content
|
: content
|
||||||
|
const response = await this.sdk.models.generateContent({
|
||||||
|
model: model.id,
|
||||||
|
config: {
|
||||||
|
systemInstruction: isGemmaModel(model) ? undefined : prompt
|
||||||
|
},
|
||||||
|
contents: [
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
parts: [{ text: MessageContent }]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
const { response } = await chat.sendMessage(messageContent)
|
return response.text || ''
|
||||||
|
|
||||||
return response.text()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -518,24 +564,28 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
content: messages.map((m) => m.content).join('\n')
|
content: messages.map((m) => m.content).join('\n')
|
||||||
}
|
}
|
||||||
|
|
||||||
const geminiModel = this.sdk.getGenerativeModel(
|
const content = isGemmaModel(model)
|
||||||
{
|
? `<start_of_turn>user\n${systemMessage.content}<end_of_turn>\n<start_of_turn>user\n${userMessage.content}<end_of_turn>`
|
||||||
model: model.id,
|
: userMessage.content
|
||||||
systemInstruction: systemMessage.content,
|
|
||||||
generationConfig: {
|
const response = await this.sdk.models.generateContent({
|
||||||
temperature: assistant?.settings?.temperature
|
model: model.id,
|
||||||
|
config: {
|
||||||
|
systemInstruction: isGemmaModel(model) ? undefined : systemMessage.content,
|
||||||
|
temperature: assistant?.settings?.temperature,
|
||||||
|
httpOptions: {
|
||||||
|
timeout: 20 * 1000
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
contents: [
|
||||||
...this.requestOptions,
|
{
|
||||||
timeout: 20 * 1000
|
role: 'user',
|
||||||
}
|
parts: [{ text: content }]
|
||||||
)
|
}
|
||||||
|
]
|
||||||
|
})
|
||||||
|
|
||||||
const chat = await geminiModel.startChat()
|
return response.text || ''
|
||||||
const { response } = await chat.sendMessage(userMessage.content)
|
|
||||||
|
|
||||||
return response.text()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -546,144 +596,13 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 生成图像
|
|
||||||
* @param messages - 消息列表
|
|
||||||
* @param assistant - 助手配置
|
|
||||||
* @param onChunk - 处理生成块的回调
|
|
||||||
* @param onFilterMessages - 过滤消息的回调
|
|
||||||
* @returns Promise<void>
|
|
||||||
*/
|
|
||||||
private async generateImageExp({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
|
|
||||||
const defaultModel = getDefaultModel()
|
|
||||||
const model = assistant.model || defaultModel
|
|
||||||
const { contextCount, streamOutput, maxTokens } = getAssistantSettings(assistant)
|
|
||||||
|
|
||||||
const userMessages = filterUserRoleStartMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
|
|
||||||
onFilterMessages(userMessages)
|
|
||||||
|
|
||||||
const userLastMessage = userMessages.pop()
|
|
||||||
if (!userLastMessage) {
|
|
||||||
throw new Error('No user message found')
|
|
||||||
}
|
|
||||||
|
|
||||||
const history: Content[] = []
|
|
||||||
|
|
||||||
for (const message of userMessages) {
|
|
||||||
history.push(await this.getMessageContents(message))
|
|
||||||
}
|
|
||||||
|
|
||||||
const userLastMessageContent = await this.getMessageContents(userLastMessage)
|
|
||||||
const allContents = [...history, userLastMessageContent]
|
|
||||||
|
|
||||||
let contents: ContentListUnion = allContents.length > 0 ? (allContents as ContentListUnion) : []
|
|
||||||
|
|
||||||
contents = await this.addImageFileToContents(userLastMessage, contents)
|
|
||||||
|
|
||||||
if (!streamOutput) {
|
|
||||||
const response = await this.callGeminiGenerateContent(model.id, contents, maxTokens)
|
|
||||||
|
|
||||||
const { isValid, message } = this.isValidGeminiResponse(response)
|
|
||||||
if (!isValid) {
|
|
||||||
throw new Error(`Gemini API error: ${message}`)
|
|
||||||
}
|
|
||||||
|
|
||||||
this.processGeminiImageResponse(response, onChunk)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
const response = await this.callGeminiGenerateContentStream(model.id, contents, maxTokens)
|
|
||||||
|
|
||||||
for await (const chunk of response) {
|
|
||||||
this.processGeminiImageResponse(chunk, onChunk)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 添加图片文件到内容列表
|
|
||||||
* @param message - 用户消息
|
|
||||||
* @param contents - 内容列表
|
|
||||||
* @returns 更新后的内容列表
|
|
||||||
*/
|
|
||||||
private async addImageFileToContents(message: Message, contents: ContentListUnion): Promise<ContentListUnion> {
|
|
||||||
if (message.files && message.files.length > 0) {
|
|
||||||
const file = message.files[0]
|
|
||||||
const fileContent = await window.api.file.base64Image(file.id + file.ext)
|
|
||||||
|
|
||||||
if (fileContent && fileContent.base64) {
|
|
||||||
const contentsArray = Array.isArray(contents) ? contents : [contents]
|
|
||||||
return [...contentsArray, createPartFromBase64(fileContent.base64, fileContent.mime)]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return contents
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 调用Gemini API生成内容
|
|
||||||
* @param modelId - 模型ID
|
|
||||||
* @param contents - 内容列表
|
|
||||||
* @returns 生成结果
|
|
||||||
*/
|
|
||||||
private async callGeminiGenerateContent(
|
|
||||||
modelId: string,
|
|
||||||
contents: ContentListUnion,
|
|
||||||
maxTokens?: number
|
|
||||||
): Promise<GenerateContentResponse> {
|
|
||||||
try {
|
|
||||||
return await this.imageSdk.models.generateContent({
|
|
||||||
model: modelId,
|
|
||||||
contents: contents,
|
|
||||||
config: {
|
|
||||||
responseModalities: ['Text', 'Image'],
|
|
||||||
responseMimeType: 'text/plain',
|
|
||||||
maxOutputTokens: maxTokens
|
|
||||||
}
|
|
||||||
})
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Gemini API error:', error)
|
|
||||||
throw error
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private async callGeminiGenerateContentStream(
|
|
||||||
modelId: string,
|
|
||||||
contents: ContentListUnion,
|
|
||||||
maxTokens?: number
|
|
||||||
): Promise<AsyncGenerator<GenerateContentResponse>> {
|
|
||||||
try {
|
|
||||||
return await this.imageSdk.models.generateContentStream({
|
|
||||||
model: modelId,
|
|
||||||
contents: contents,
|
|
||||||
config: {
|
|
||||||
responseModalities: ['Text', 'Image'],
|
|
||||||
responseMimeType: 'text/plain',
|
|
||||||
maxOutputTokens: maxTokens
|
|
||||||
}
|
|
||||||
})
|
|
||||||
} catch (error) {
|
|
||||||
console.error('Gemini API error:', error)
|
|
||||||
throw error
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 检查Gemini响应是否有效
|
|
||||||
* @param response - Gemini响应
|
|
||||||
* @returns 是否有效
|
|
||||||
*/
|
|
||||||
private isValidGeminiResponse(response: GenerateContentResponse): { isValid: boolean; message: string } {
|
|
||||||
return {
|
|
||||||
isValid: response?.candidates?.[0]?.finishReason === FinishReason.STOP ? true : false,
|
|
||||||
message: response?.candidates?.[0]?.finishReason || ''
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 处理Gemini图像响应
|
* 处理Gemini图像响应
|
||||||
* @param response - Gemini响应
|
* @param response - Gemini响应
|
||||||
* @param onChunk - 处理生成块的回调
|
* @param onChunk - 处理生成块的回调
|
||||||
*/
|
*/
|
||||||
private processGeminiImageResponse(response: any, onChunk: (chunk: ChunkCallbackData) => void): void {
|
private processGeminiImageResponse(chunk: GenerateContentResponse): { type: 'base64'; images: string[] } | undefined {
|
||||||
const parts = response.candidates[0].content.parts
|
const parts = chunk.candidates?.[0]?.content?.parts
|
||||||
if (!parts) {
|
if (!parts) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -695,31 +614,13 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
const dataPrefix = `data:${part.inlineData.mimeType || 'image/png'};base64,`
|
const dataPrefix = `data:${part.inlineData.mimeType || 'image/png'};base64,`
|
||||||
return part.inlineData.data.startsWith('data:') ? part.inlineData.data : dataPrefix + part.inlineData.data
|
return part.inlineData.data?.startsWith('data:') ? part.inlineData.data : dataPrefix + part.inlineData.data
|
||||||
})
|
})
|
||||||
|
|
||||||
// 提取文本数据
|
return {
|
||||||
const text = parts
|
type: 'base64',
|
||||||
.filter((part: Part) => part.text !== undefined)
|
images: images.filter((image) => image !== null)
|
||||||
.map((part: Part) => part.text)
|
}
|
||||||
.join('')
|
|
||||||
|
|
||||||
// 返回结果
|
|
||||||
onChunk({
|
|
||||||
text,
|
|
||||||
generateImage: {
|
|
||||||
type: 'base64',
|
|
||||||
images
|
|
||||||
},
|
|
||||||
usage: {
|
|
||||||
prompt_tokens: response.usageMetadata?.promptTokenCount || 0,
|
|
||||||
completion_tokens: response.usageMetadata?.candidatesTokenCount || 0,
|
|
||||||
total_tokens: response.usageMetadata?.totalTokenCount || 0
|
|
||||||
},
|
|
||||||
metrics: {
|
|
||||||
completion_tokens: response.usageMetadata?.candidatesTokenCount
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -732,18 +633,16 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
return { valid: false, error: new Error('No model found') }
|
return { valid: false, error: new Error('No model found') }
|
||||||
}
|
}
|
||||||
|
|
||||||
const body = {
|
|
||||||
model: model.id,
|
|
||||||
messages: [{ role: 'user', content: 'hi' }],
|
|
||||||
max_tokens: 100,
|
|
||||||
stream: false
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const geminiModel = this.sdk.getGenerativeModel({ model: body.model }, this.requestOptions)
|
const result = await this.sdk.models.generateContent({
|
||||||
const result = await geminiModel.generateContent(body.messages[0].content)
|
model: model.id,
|
||||||
|
contents: [{ role: 'user', parts: [{ text: 'hi' }] }],
|
||||||
|
config: {
|
||||||
|
maxOutputTokens: 100
|
||||||
|
}
|
||||||
|
})
|
||||||
return {
|
return {
|
||||||
valid: !isEmpty(result.response.text()),
|
valid: !isEmpty(result.text),
|
||||||
error: null
|
error: null
|
||||||
}
|
}
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
@ -785,7 +684,10 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
* @returns The embedding dimensions
|
* @returns The embedding dimensions
|
||||||
*/
|
*/
|
||||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||||
const data = await this.sdk.getGenerativeModel({ model: model.id }, this.requestOptions).embedContent('hi')
|
const data = await this.sdk.models.embedContent({
|
||||||
return data.embedding.values.length
|
model: model.id,
|
||||||
|
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
|
||||||
|
})
|
||||||
|
return data.embeddings?.[0]?.values?.length || 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import type { GroundingMetadata } from '@google/generative-ai'
|
import type { GroundingMetadata } from '@google/genai'
|
||||||
import BaseProvider from '@renderer/providers/AiProvider/BaseProvider'
|
import BaseProvider from '@renderer/providers/AiProvider/BaseProvider'
|
||||||
import ProviderFactory from '@renderer/providers/AiProvider/ProviderFactory'
|
import ProviderFactory from '@renderer/providers/AiProvider/ProviderFactory'
|
||||||
import type {
|
import type {
|
||||||
|
|||||||
@ -282,6 +282,7 @@ export async function fetchChatCompletion({
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
console.log('message', message)
|
||||||
// Emit chat completion event
|
// Emit chat completion event
|
||||||
EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message)
|
EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message)
|
||||||
onResponse(message)
|
onResponse(message)
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
import { GroundingMetadata } from '@google/generative-ai'
|
import { GroundingMetadata } from '@google/genai'
|
||||||
import OpenAI from 'openai'
|
import OpenAI from 'openai'
|
||||||
import React from 'react'
|
import React from 'react'
|
||||||
import { BuiltinTheme } from 'shiki'
|
import { BuiltinTheme } from 'shiki'
|
||||||
|
|||||||
@ -71,8 +71,8 @@ export function withGeminiGrounding(message: Message) {
|
|||||||
let content = message.content
|
let content = message.content
|
||||||
|
|
||||||
groundingSupports.forEach((support) => {
|
groundingSupports.forEach((support) => {
|
||||||
const text = support?.segment
|
const text = support?.segment?.text
|
||||||
const indices = support?.groundingChunckIndices
|
const indices = support?.groundingChunkIndices
|
||||||
|
|
||||||
if (!text || !indices) return
|
if (!text || !indices) return
|
||||||
|
|
||||||
|
|||||||
@ -1,186 +1,165 @@
|
|||||||
import { ContentBlockParam, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
|
import { ContentBlockParam, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
|
||||||
import { MessageParam } from '@anthropic-ai/sdk/resources'
|
import { MessageParam } from '@anthropic-ai/sdk/resources'
|
||||||
import {
|
import { Content, FunctionCall, Part } from '@google/genai'
|
||||||
ArraySchema,
|
|
||||||
BaseSchema,
|
|
||||||
BooleanSchema,
|
|
||||||
EnumStringSchema,
|
|
||||||
FunctionCall,
|
|
||||||
FunctionDeclaration,
|
|
||||||
FunctionDeclarationSchema,
|
|
||||||
FunctionDeclarationSchemaProperty,
|
|
||||||
IntegerSchema,
|
|
||||||
NumberSchema,
|
|
||||||
ObjectSchema,
|
|
||||||
SchemaType,
|
|
||||||
SimpleStringSchema,
|
|
||||||
Tool as geminiTool
|
|
||||||
} from '@google/generative-ai'
|
|
||||||
import { Content, Part } from '@google/generative-ai'
|
|
||||||
import store from '@renderer/store'
|
import store from '@renderer/store'
|
||||||
import { MCPCallToolResponse, MCPServer, MCPTool, MCPToolResponse } from '@renderer/types'
|
import { MCPCallToolResponse, MCPServer, MCPTool, MCPToolResponse } from '@renderer/types'
|
||||||
import {
|
import { ChatCompletionContentPart, ChatCompletionMessageParam, ChatCompletionMessageToolCall } from 'openai/resources'
|
||||||
ChatCompletionContentPart,
|
|
||||||
ChatCompletionMessageParam,
|
|
||||||
ChatCompletionMessageToolCall,
|
|
||||||
ChatCompletionTool
|
|
||||||
} from 'openai/resources'
|
|
||||||
|
|
||||||
import { ChunkCallbackData, CompletionsParams } from '../providers/AiProvider'
|
import { ChunkCallbackData, CompletionsParams } from '../providers/AiProvider'
|
||||||
|
|
||||||
const ensureValidSchema = (obj: Record<string, any>): FunctionDeclarationSchemaProperty => {
|
// const ensureValidSchema = (obj: Record<string, any>) => {
|
||||||
// Filter out unsupported keys for Gemini
|
// // Filter out unsupported keys for Gemini
|
||||||
const filteredObj = filterUnsupportedKeys(obj)
|
// const filteredObj = filterUnsupportedKeys(obj)
|
||||||
|
|
||||||
// Handle base schema properties
|
// // Handle base schema properties
|
||||||
const baseSchema = {
|
// const baseSchema = {
|
||||||
description: filteredObj.description,
|
// description: filteredObj.description,
|
||||||
nullable: filteredObj.nullable
|
// nullable: filteredObj.nullable
|
||||||
} as BaseSchema
|
// } as BaseSchema
|
||||||
|
|
||||||
// Handle string type
|
// // Handle string type
|
||||||
if (filteredObj.type?.toLowerCase() === SchemaType.STRING) {
|
// if (filteredObj.type?.toLowerCase() === SchemaType.STRING) {
|
||||||
if (filteredObj.enum && Array.isArray(filteredObj.enum)) {
|
// if (filteredObj.enum && Array.isArray(filteredObj.enum)) {
|
||||||
return {
|
// return {
|
||||||
...baseSchema,
|
// ...baseSchema,
|
||||||
type: SchemaType.STRING,
|
// type: SchemaType.STRING,
|
||||||
format: 'enum',
|
// format: 'enum',
|
||||||
enum: filteredObj.enum as string[]
|
// enum: filteredObj.enum as string[]
|
||||||
} as EnumStringSchema
|
// } as EnumStringSchema
|
||||||
}
|
// }
|
||||||
return {
|
// return {
|
||||||
...baseSchema,
|
// ...baseSchema,
|
||||||
type: SchemaType.STRING,
|
// type: SchemaType.STRING,
|
||||||
format: filteredObj.format === 'date-time' ? 'date-time' : undefined
|
// format: filteredObj.format === 'date-time' ? 'date-time' : undefined
|
||||||
} as SimpleStringSchema
|
// } as SimpleStringSchema
|
||||||
}
|
// }
|
||||||
|
|
||||||
// Handle number type
|
// // Handle number type
|
||||||
if (filteredObj.type?.toLowerCase() === SchemaType.NUMBER) {
|
// if (filteredObj.type?.toLowerCase() === SchemaType.NUMBER) {
|
||||||
return {
|
// return {
|
||||||
...baseSchema,
|
// ...baseSchema,
|
||||||
type: SchemaType.NUMBER,
|
// type: SchemaType.NUMBER,
|
||||||
format: ['float', 'double'].includes(filteredObj.format) ? (filteredObj.format as 'float' | 'double') : undefined
|
// format: ['float', 'double'].includes(filteredObj.format) ? (filteredObj.format as 'float' | 'double') : undefined
|
||||||
} as NumberSchema
|
// } as NumberSchema
|
||||||
}
|
// }
|
||||||
|
|
||||||
// Handle integer type
|
// // Handle integer type
|
||||||
if (filteredObj.type?.toLowerCase() === SchemaType.INTEGER) {
|
// if (filteredObj.type?.toLowerCase() === SchemaType.INTEGER) {
|
||||||
return {
|
// return {
|
||||||
...baseSchema,
|
// ...baseSchema,
|
||||||
type: SchemaType.INTEGER,
|
// type: SchemaType.INTEGER,
|
||||||
format: ['int32', 'int64'].includes(filteredObj.format) ? (filteredObj.format as 'int32' | 'int64') : undefined
|
// format: ['int32', 'int64'].includes(filteredObj.format) ? (filteredObj.format as 'int32' | 'int64') : undefined
|
||||||
} as IntegerSchema
|
// } as IntegerSchema
|
||||||
}
|
// }
|
||||||
|
|
||||||
// Handle boolean type
|
// // Handle boolean type
|
||||||
if (filteredObj.type?.toLowerCase() === SchemaType.BOOLEAN) {
|
// if (filteredObj.type?.toLowerCase() === SchemaType.BOOLEAN) {
|
||||||
return {
|
// return {
|
||||||
...baseSchema,
|
// ...baseSchema,
|
||||||
type: SchemaType.BOOLEAN
|
// type: SchemaType.BOOLEAN
|
||||||
} as BooleanSchema
|
// } as BooleanSchema
|
||||||
}
|
// }
|
||||||
|
|
||||||
// Handle array type
|
// // Handle array type
|
||||||
if (filteredObj.type?.toLowerCase() === SchemaType.ARRAY) {
|
// if (filteredObj.type?.toLowerCase() === SchemaType.ARRAY) {
|
||||||
return {
|
// return {
|
||||||
...baseSchema,
|
// ...baseSchema,
|
||||||
type: SchemaType.ARRAY,
|
// type: SchemaType.ARRAY,
|
||||||
items: filteredObj.items
|
// items: filteredObj.items
|
||||||
? ensureValidSchema(filteredObj.items as Record<string, any>)
|
// ? ensureValidSchema(filteredObj.items as Record<string, any>)
|
||||||
: ({ type: SchemaType.STRING } as SimpleStringSchema),
|
// : ({ type: SchemaType.STRING } as SimpleStringSchema),
|
||||||
minItems: filteredObj.minItems,
|
// minItems: filteredObj.minItems,
|
||||||
maxItems: filteredObj.maxItems
|
// maxItems: filteredObj.maxItems
|
||||||
} as ArraySchema
|
// } as ArraySchema
|
||||||
}
|
// }
|
||||||
|
|
||||||
// Handle object type (default)
|
// // Handle object type (default)
|
||||||
const properties = filteredObj.properties
|
// const properties = filteredObj.properties
|
||||||
? Object.fromEntries(
|
// ? Object.fromEntries(
|
||||||
Object.entries(filteredObj.properties).map(([key, value]) => [
|
// Object.entries(filteredObj.properties).map(([key, value]) => [
|
||||||
key,
|
// key,
|
||||||
ensureValidSchema(value as Record<string, any>)
|
// ensureValidSchema(value as Record<string, any>)
|
||||||
])
|
// ])
|
||||||
)
|
// )
|
||||||
: { _empty: { type: SchemaType.STRING } as SimpleStringSchema } // Ensure properties is never empty
|
// : { _empty: { type: SchemaType.STRING } as SimpleStringSchema } // Ensure properties is never empty
|
||||||
|
|
||||||
return {
|
// return {
|
||||||
...baseSchema,
|
// ...baseSchema,
|
||||||
type: SchemaType.OBJECT,
|
// type: SchemaType.OBJECT,
|
||||||
properties,
|
// properties,
|
||||||
required: Array.isArray(filteredObj.required) ? filteredObj.required : undefined
|
// required: Array.isArray(filteredObj.required) ? filteredObj.required : undefined
|
||||||
} as ObjectSchema
|
// } as ObjectSchema
|
||||||
}
|
// }
|
||||||
|
|
||||||
function filterUnsupportedKeys(obj: Record<string, any>): Record<string, any> {
|
// function filterUnsupportedKeys(obj: Record<string, any>): Record<string, any> {
|
||||||
const supportedBaseKeys = ['description', 'nullable']
|
// const supportedBaseKeys = ['description', 'nullable']
|
||||||
const supportedStringKeys = [...supportedBaseKeys, 'type', 'format', 'enum']
|
// const supportedStringKeys = [...supportedBaseKeys, 'type', 'format', 'enum']
|
||||||
const supportedNumberKeys = [...supportedBaseKeys, 'type', 'format']
|
// const supportedNumberKeys = [...supportedBaseKeys, 'type', 'format']
|
||||||
const supportedBooleanKeys = [...supportedBaseKeys, 'type']
|
// const supportedBooleanKeys = [...supportedBaseKeys, 'type']
|
||||||
const supportedArrayKeys = [...supportedBaseKeys, 'type', 'items', 'minItems', 'maxItems']
|
// const supportedArrayKeys = [...supportedBaseKeys, 'type', 'items', 'minItems', 'maxItems']
|
||||||
const supportedObjectKeys = [...supportedBaseKeys, 'type', 'properties', 'required']
|
// const supportedObjectKeys = [...supportedBaseKeys, 'type', 'properties', 'required']
|
||||||
|
|
||||||
const filtered: Record<string, any> = {}
|
// const filtered: Record<string, any> = {}
|
||||||
|
|
||||||
let keysToKeep: string[]
|
// let keysToKeep: string[]
|
||||||
|
|
||||||
if (obj.type?.toLowerCase() === SchemaType.STRING) {
|
// if (obj.type?.toLowerCase() === SchemaType.STRING) {
|
||||||
keysToKeep = supportedStringKeys
|
// keysToKeep = supportedStringKeys
|
||||||
} else if (obj.type?.toLowerCase() === SchemaType.NUMBER) {
|
// } else if (obj.type?.toLowerCase() === SchemaType.NUMBER) {
|
||||||
keysToKeep = supportedNumberKeys
|
// keysToKeep = supportedNumberKeys
|
||||||
} else if (obj.type?.toLowerCase() === SchemaType.INTEGER) {
|
// } else if (obj.type?.toLowerCase() === SchemaType.INTEGER) {
|
||||||
keysToKeep = supportedNumberKeys
|
// keysToKeep = supportedNumberKeys
|
||||||
} else if (obj.type?.toLowerCase() === SchemaType.BOOLEAN) {
|
// } else if (obj.type?.toLowerCase() === SchemaType.BOOLEAN) {
|
||||||
keysToKeep = supportedBooleanKeys
|
// keysToKeep = supportedBooleanKeys
|
||||||
} else if (obj.type?.toLowerCase() === SchemaType.ARRAY) {
|
// } else if (obj.type?.toLowerCase() === SchemaType.ARRAY) {
|
||||||
keysToKeep = supportedArrayKeys
|
// keysToKeep = supportedArrayKeys
|
||||||
} else {
|
// } else {
|
||||||
// Default to object type
|
// // Default to object type
|
||||||
keysToKeep = supportedObjectKeys
|
// keysToKeep = supportedObjectKeys
|
||||||
}
|
// }
|
||||||
|
|
||||||
// copy supported keys
|
// // copy supported keys
|
||||||
for (const key of keysToKeep) {
|
// for (const key of keysToKeep) {
|
||||||
if (obj[key] !== undefined) {
|
// if (obj[key] !== undefined) {
|
||||||
filtered[key] = obj[key]
|
// filtered[key] = obj[key]
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
return filtered
|
// return filtered
|
||||||
}
|
// }
|
||||||
|
|
||||||
function filterPropertieAttributes(tool: MCPTool, filterNestedObj: boolean = false): Record<string, object> {
|
// function filterPropertieAttributes(tool: MCPTool, filterNestedObj: boolean = false): Record<string, object> {
|
||||||
const properties = tool.inputSchema.properties
|
// const properties = tool.inputSchema.properties
|
||||||
if (!properties) {
|
// if (!properties) {
|
||||||
return {}
|
// return {}
|
||||||
}
|
// }
|
||||||
|
|
||||||
// For OpenAI, we don't need to validate as strictly
|
// // For OpenAI, we don't need to validate as strictly
|
||||||
if (!filterNestedObj) {
|
// if (!filterNestedObj) {
|
||||||
return properties
|
// return properties
|
||||||
}
|
// }
|
||||||
|
|
||||||
const processedProperties = Object.fromEntries(
|
// const processedProperties = Object.fromEntries(
|
||||||
Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
|
// Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
|
||||||
)
|
// )
|
||||||
|
|
||||||
return processedProperties
|
// return processedProperties
|
||||||
}
|
// }
|
||||||
|
|
||||||
export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
|
// export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array<ChatCompletionTool> {
|
||||||
return mcpTools.map((tool) => ({
|
// return mcpTools.map((tool) => ({
|
||||||
type: 'function',
|
// type: 'function',
|
||||||
name: tool.name,
|
// name: tool.name,
|
||||||
function: {
|
// function: {
|
||||||
name: tool.id,
|
// name: tool.id,
|
||||||
description: tool.description,
|
// description: tool.description,
|
||||||
parameters: {
|
// parameters: {
|
||||||
type: 'object',
|
// type: 'object',
|
||||||
properties: filterPropertieAttributes(tool)
|
// properties: filterPropertieAttributes(tool)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}))
|
// }))
|
||||||
}
|
// }
|
||||||
|
|
||||||
export function openAIToolsToMcpTool(
|
export function openAIToolsToMcpTool(
|
||||||
mcpTools: MCPTool[] | undefined,
|
mcpTools: MCPTool[] | undefined,
|
||||||
@ -277,35 +256,35 @@ export function anthropicToolUseToMcpTool(mcpTools: MCPTool[] | undefined, toolU
|
|||||||
return tool
|
return tool
|
||||||
}
|
}
|
||||||
|
|
||||||
export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTool[] {
|
// export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTool[] {
|
||||||
if (!mcpTools || mcpTools.length === 0) {
|
// if (!mcpTools || mcpTools.length === 0) {
|
||||||
// No tools available
|
// // No tools available
|
||||||
return []
|
// return []
|
||||||
}
|
// }
|
||||||
const functions: FunctionDeclaration[] = []
|
// const functions: FunctionDeclaration[] = []
|
||||||
|
|
||||||
for (const tool of mcpTools) {
|
// for (const tool of mcpTools) {
|
||||||
const properties = filterPropertieAttributes(tool, true)
|
// const properties = filterPropertieAttributes(tool, true)
|
||||||
const functionDeclaration: FunctionDeclaration = {
|
// const functionDeclaration: FunctionDeclaration = {
|
||||||
name: tool.id,
|
// name: tool.id,
|
||||||
description: tool.description,
|
// description: tool.description,
|
||||||
parameters: {
|
// parameters: {
|
||||||
type: SchemaType.OBJECT,
|
// type: SchemaType.OBJECT,
|
||||||
properties:
|
// properties:
|
||||||
Object.keys(properties).length > 0
|
// Object.keys(properties).length > 0
|
||||||
? Object.fromEntries(
|
// ? Object.fromEntries(
|
||||||
Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
|
// Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record<string, any>)])
|
||||||
)
|
// )
|
||||||
: { _empty: { type: SchemaType.STRING } as SimpleStringSchema }
|
// : { _empty: { type: SchemaType.STRING } as SimpleStringSchema }
|
||||||
} as FunctionDeclarationSchema
|
// } as FunctionDeclarationSchema
|
||||||
}
|
// }
|
||||||
functions.push(functionDeclaration)
|
// functions.push(functionDeclaration)
|
||||||
}
|
// }
|
||||||
const tool: geminiTool = {
|
// const tool: geminiTool = {
|
||||||
functionDeclarations: functions
|
// functionDeclarations: functions
|
||||||
}
|
// }
|
||||||
return [tool]
|
// return [tool]
|
||||||
}
|
// }
|
||||||
|
|
||||||
export function geminiFunctionCallToMcpTool(
|
export function geminiFunctionCallToMcpTool(
|
||||||
mcpTools: MCPTool[] | undefined,
|
mcpTools: MCPTool[] | undefined,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user