feat(GeminiProvider): Add isGemmaModel function and update model handling

Introduce an isGemmaModel function to identify Gemma models and adjust system instruction handling in GeminiProvider based on the model type. Since Gemma models do not accept a separate system instruction, the assistant prompt is folded into the first user message during chat initialization.
Authored by ousugo on 2025-03-13 02:11:28 +08:00, committed by Asurada
parent 762c901074
commit 442ef89ce0
2 changed files with 31 additions and 2 deletions


@@ -1992,3 +1992,11 @@ export function getOpenAIWebSearchParams(assistant: Assistant, model: Model): Re
   return {}
 }
+
+export function isGemmaModel(model?: Model): boolean {
+  if (!model) {
+    return false
+  }
+  return model.id.includes('gemma-') || model.group === 'Gemma'
+}
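
For reference, a minimal usage sketch of the new helper. The model objects below are hypothetical examples, not values from this commit, and the Model shape is assumed to be at least { id: string; group?: string }:

// Hypothetical models, for illustration only.
const gemma = { id: 'gemma-2-9b-it', group: 'Gemma' }
const gemini = { id: 'gemini-1.5-pro', group: 'Gemini' }

isGemmaModel(gemma)     // true: the id contains 'gemma-'
isGemmaModel(gemini)    // false: neither the id nor the group matches
isGemmaModel(undefined) // false: the early return guards a missing model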


@@ -13,7 +13,7 @@ import {
   SafetySetting,
   TextPart
 } from '@google/generative-ai'
-import { isWebSearchModel } from '@renderer/config/models'
+import { isGemmaModel, isWebSearchModel } from '@renderer/config/models'
 import { getStoreSetting } from '@renderer/hooks/useSettings'
 import i18n from '@renderer/i18n'
 import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
@@ -205,7 +205,7 @@ export default class GeminiProvider extends BaseProvider {
     const geminiModel = this.sdk.getGenerativeModel(
       {
         model: model.id,
-        systemInstruction: assistant.prompt,
+        ...(isGemmaModel(model) ? {} : { systemInstruction: assistant.prompt }),
         safetySettings: this.getSafetySettings(model.id),
         tools: tools,
         generationConfig: {
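
A note on the conditional spread used above: spreading an empty object drops the systemInstruction key entirely for Gemma models instead of passing it as undefined. A standalone sketch with placeholder values (not from this commit):

// Placeholder values, illustrating only the spread pattern.
const isGemma = true
const prompt = 'You are a helpful assistant.'

const params = {
  model: 'gemma-2-9b-it',
  ...(isGemma ? {} : { systemInstruction: prompt })
}

// With isGemma === true, params has no systemInstruction key:
// { model: 'gemma-2-9b-it' }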
@@ -221,6 +221,27 @@ export default class GeminiProvider extends BaseProvider {
     const chat = geminiModel.startChat({ history })
     const messageContents = await this.getMessageContents(userLastMessage!)

+    if (isGemmaModel(model) && assistant.prompt) {
+      const isFirstMessage = history.length === 0
+      if (isFirstMessage) {
+        const systemMessage = {
+          role: 'user',
+          parts: [
+            {
+              text:
+                '<start_of_turn>user\n' +
+                assistant.prompt +
+                '<end_of_turn>\n' +
+                '<start_of_turn>user\n' +
+                messageContents.parts[0].text +
+                '<end_of_turn>'
+            }
+          ]
+        }
+        messageContents.parts = systemMessage.parts
+      }
+    }
+
     const start_time_millsec = new Date().getTime()
     const { abortController, cleanup } = this.createAbortController(userLastMessage?.id)
     const { signal } = abortController
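
Sketch of what the Gemma branch above produces, using placeholder strings (the prompt and user text are hypothetical): because systemInstruction is omitted for Gemma models, the assistant prompt is inlined into the first user message with Gemma's turn markers.

// Hypothetical inputs, for illustration only.
const prompt = 'You are a helpful assistant.'
const userText = 'Hello!'

// Equivalent to the text assembled for messageContents.parts[0] above:
const wrapped =
  '<start_of_turn>user\n' +
  prompt +
  '<end_of_turn>\n' +
  '<start_of_turn>user\n' +
  userText +
  '<end_of_turn>'

// wrapped:
// <start_of_turn>user
// You are a helpful assistant.<end_of_turn>
// <start_of_turn>user
// Hello!<end_of_turn>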