From 442ef89ce0fcf85a3e5ef31d5be9e2e5291e1572 Mon Sep 17 00:00:00 2001 From: ousugo Date: Thu, 13 Mar 2025 02:11:28 +0800 Subject: [PATCH] feat(GeminiProvider): Add isGemmaModel function and update model handling Introduce isGemmaModel function to identify Gemma models and adjust system instruction handling in GeminiProvider based on model type. Ensure proper message formatting for Gemma models during chat initialization. --- src/renderer/src/config/models.ts | 8 +++++++ src/renderer/src/providers/GeminiProvider.ts | 25 ++++++++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index e6b5aa6a..e7605e67 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -1992,3 +1992,11 @@ export function getOpenAIWebSearchParams(assistant: Assistant, model: Model): Re return {} } + +export function isGemmaModel(model?: Model): boolean { + if (!model) { + return false + } + + return model.id.includes('gemma-') || model.group === 'Gemma' +} diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index 96dcf8b3..c6a368bf 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -13,7 +13,7 @@ import { SafetySetting, TextPart } from '@google/generative-ai' -import { isWebSearchModel } from '@renderer/config/models' +import { isGemmaModel, isWebSearchModel } from '@renderer/config/models' import { getStoreSetting } from '@renderer/hooks/useSettings' import i18n from '@renderer/i18n' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' @@ -205,7 +205,7 @@ export default class GeminiProvider extends BaseProvider { const geminiModel = this.sdk.getGenerativeModel( { model: model.id, - systemInstruction: assistant.prompt, + ...(isGemmaModel(model) ? {} : { systemInstruction: assistant.prompt }), safetySettings: this.getSafetySettings(model.id), tools: tools, generationConfig: { @@ -221,6 +221,27 @@ export default class GeminiProvider extends BaseProvider { const chat = geminiModel.startChat({ history }) const messageContents = await this.getMessageContents(userLastMessage!) + if (isGemmaModel(model) && assistant.prompt) { + const isFirstMessage = history.length === 0 + if (isFirstMessage) { + const systemMessage = { + role: 'user', + parts: [ + { + text: + 'user\n' + + assistant.prompt + + '\n' + + 'user\n' + + messageContents.parts[0].text + + '' + } + ] + } + messageContents.parts = systemMessage.parts + } + } + const start_time_millsec = new Date().getTime() const { abortController, cleanup } = this.createAbortController(userLastMessage?.id) const { signal } = abortController