From 38b52a2ee62f5bc3fc3ffd9b7b67ec8fe00f354f Mon Sep 17 00:00:00 2001 From: ousugo Date: Thu, 13 Mar 2025 02:22:32 +0800 Subject: [PATCH] refactor(GeminiProvider): Enhance message handling for Gemma models --- src/renderer/src/providers/GeminiProvider.ts | 34 +++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index c6a368bf..f0740ff0 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -366,7 +366,7 @@ export default class GeminiProvider extends BaseProvider { const geminiModel = this.sdk.getGenerativeModel( { model: model.id, - systemInstruction: assistant.prompt, + ...(isGemmaModel(model) ? {} : { systemInstruction: assistant.prompt }), generationConfig: { maxOutputTokens: maxTokens, temperature: assistant?.settings?.temperature @@ -375,12 +375,17 @@ export default class GeminiProvider extends BaseProvider { this.requestOptions ) + const content = + isGemmaModel(model) && assistant.prompt + ? `user\n${assistant.prompt}\nuser\n${message.content}` + : message.content + if (!onResponse) { - const { response } = await geminiModel.generateContent(message.content) + const { response } = await geminiModel.generateContent(content) return response.text() } - const response = await geminiModel.generateContentStream(message.content) + const response = await geminiModel.generateContentStream(content) let text = '' @@ -426,7 +431,7 @@ export default class GeminiProvider extends BaseProvider { const geminiModel = this.sdk.getGenerativeModel( { model: model.id, - systemInstruction: systemMessage.content, + ...(isGemmaModel(model) ? {} : { systemInstruction: systemMessage.content }), generationConfig: { temperature: assistant?.settings?.temperature } @@ -435,8 +440,11 @@ export default class GeminiProvider extends BaseProvider { ) const chat = await geminiModel.startChat() + const content = isGemmaModel(model) + ? `user\n${systemMessage.content}\nuser\n${userMessage.content}` + : userMessage.content - const { response } = await chat.sendMessage(userMessage.content) + const { response } = await chat.sendMessage(content) return removeSpecialCharactersForTopicName(response.text()) } @@ -451,10 +459,20 @@ export default class GeminiProvider extends BaseProvider { const model = getDefaultModel() const systemMessage = { role: 'system', content: prompt } - const geminiModel = this.sdk.getGenerativeModel({ model: model.id }, this.requestOptions) + const geminiModel = this.sdk.getGenerativeModel( + { + model: model.id, + ...(isGemmaModel(model) ? {} : { systemInstruction: systemMessage.content }) + }, + this.requestOptions + ) - const chat = await geminiModel.startChat({ systemInstruction: systemMessage.content }) - const { response } = await chat.sendMessage(content) + const chat = await geminiModel.startChat() + const messageContent = isGemmaModel(model) + ? `user\n${prompt}\nuser\n${content}` + : content + + const { response } = await chat.sendMessage(messageContent) return response.text() }