feat: add configurable request options to gemini provider

This commit is contained in:
kangfenmao 2024-11-26 13:15:25 +08:00
parent 09e86b35a5
commit 8de1197557

View File

@ -5,6 +5,7 @@ import {
HarmCategory,
InlineDataPart,
Part,
RequestOptions,
TextPart
} from '@google/generative-ai'
import { SUMMARIZE_PROMPT } from '@renderer/config/prompts'
@ -20,10 +21,14 @@ import BaseProvider from './BaseProvider'
export default class GeminiProvider extends BaseProvider {
private sdk: GoogleGenerativeAI
private requestOptions: RequestOptions
constructor(provider: Provider) {
super(provider)
this.sdk = new GoogleGenerativeAI(this.apiKey)
this.requestOptions = {
baseUrl: this.provider.apiHost
}
}
private async getMessageContents(message: Message): Promise<Content> {
@ -75,23 +80,26 @@ export default class GeminiProvider extends BaseProvider {
history.push(await this.getMessageContents(message))
}
const geminiModel = this.sdk.getGenerativeModel({
model: model.id,
systemInstruction: assistant.prompt,
generationConfig: {
maxOutputTokens: maxTokens,
temperature: assistant?.settings?.temperature
},
safetySettings: [
{ category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold: HarmBlockThreshold.BLOCK_NONE },
{
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold: HarmBlockThreshold.BLOCK_NONE
const geminiModel = this.sdk.getGenerativeModel(
{
model: model.id,
systemInstruction: assistant.prompt,
generationConfig: {
maxOutputTokens: maxTokens,
temperature: assistant?.settings?.temperature
},
{ category: HarmCategory.HARM_CATEGORY_HARASSMENT, threshold: HarmBlockThreshold.BLOCK_NONE },
{ category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold: HarmBlockThreshold.BLOCK_NONE }
]
})
safetySettings: [
{ category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold: HarmBlockThreshold.BLOCK_NONE },
{
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold: HarmBlockThreshold.BLOCK_NONE
},
{ category: HarmCategory.HARM_CATEGORY_HARASSMENT, threshold: HarmBlockThreshold.BLOCK_NONE },
{ category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold: HarmBlockThreshold.BLOCK_NONE }
]
},
this.requestOptions
)
const chat = geminiModel.startChat({ history })
const messageContents = await this.getMessageContents(userLastMessage!)
@ -129,14 +137,17 @@ export default class GeminiProvider extends BaseProvider {
const { maxTokens } = getAssistantSettings(assistant)
const model = assistant.model || defaultModel
const geminiModel = this.sdk.getGenerativeModel({
model: model.id,
systemInstruction: assistant.prompt,
generationConfig: {
maxOutputTokens: maxTokens,
temperature: assistant?.settings?.temperature
}
})
const geminiModel = this.sdk.getGenerativeModel(
{
model: model.id,
systemInstruction: assistant.prompt,
generationConfig: {
maxOutputTokens: maxTokens,
temperature: assistant?.settings?.temperature
}
},
this.requestOptions
)
const { response } = await geminiModel.generateContent(message.content)
@ -168,13 +179,16 @@ export default class GeminiProvider extends BaseProvider {
content: userMessageContent
}
const geminiModel = this.sdk.getGenerativeModel({
model: model.id,
systemInstruction: systemMessage.content,
generationConfig: {
temperature: assistant?.settings?.temperature
}
})
const geminiModel = this.sdk.getGenerativeModel(
{
model: model.id,
systemInstruction: systemMessage.content,
generationConfig: {
temperature: assistant?.settings?.temperature
}
},
this.requestOptions
)
const chat = await geminiModel.startChat()
@ -187,7 +201,7 @@ export default class GeminiProvider extends BaseProvider {
const model = getDefaultModel()
const systemMessage = { role: 'system', content: prompt }
const geminiModel = this.sdk.getGenerativeModel({ model: model.id })
const geminiModel = this.sdk.getGenerativeModel({ model: model.id }, this.requestOptions)
const chat = await geminiModel.startChat({ systemInstruction: systemMessage.content })
const { response } = await chat.sendMessage(content)
@ -214,7 +228,7 @@ export default class GeminiProvider extends BaseProvider {
}
try {
const geminiModel = this.sdk.getGenerativeModel({ model: body.model })
const geminiModel = this.sdk.getGenerativeModel({ model: body.model }, this.requestOptions)
const result = await geminiModel.generateContent(body.messages[0].content)
return {
valid: !isEmpty(result.response.text()),