diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index e6b5aa6a..e7605e67 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -1992,3 +1992,11 @@ export function getOpenAIWebSearchParams(assistant: Assistant, model: Model): Re return {} } + +export function isGemmaModel(model?: Model): boolean { + if (!model) { + return false + } + + return model.id.includes('gemma-') || model.group === 'Gemma' +} diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index 96dcf8b3..c6a368bf 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -13,7 +13,7 @@ import { SafetySetting, TextPart } from '@google/generative-ai' -import { isWebSearchModel } from '@renderer/config/models' +import { isGemmaModel, isWebSearchModel } from '@renderer/config/models' import { getStoreSetting } from '@renderer/hooks/useSettings' import i18n from '@renderer/i18n' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' @@ -205,7 +205,7 @@ export default class GeminiProvider extends BaseProvider { const geminiModel = this.sdk.getGenerativeModel( { model: model.id, - systemInstruction: assistant.prompt, + ...(isGemmaModel(model) ? {} : { systemInstruction: assistant.prompt }), safetySettings: this.getSafetySettings(model.id), tools: tools, generationConfig: { @@ -221,6 +221,27 @@ export default class GeminiProvider extends BaseProvider { const chat = geminiModel.startChat({ history }) const messageContents = await this.getMessageContents(userLastMessage!) + if (isGemmaModel(model) && assistant.prompt) { + const isFirstMessage = history.length === 0 + if (isFirstMessage) { + const systemMessage = { + role: 'user', + parts: [ + { + text: + 'user\n' + + assistant.prompt + + '\n' + + 'user\n' + + messageContents.parts[0].text + + '' + } + ] + } + messageContents.parts = systemMessage.parts + } + } + const start_time_millsec = new Date().getTime() const { abortController, cleanup } = this.createAbortController(userLastMessage?.id) const { signal } = abortController