feat: gemini reasoning budget support (#5052)

* feat(models): add Gemini 2.5 reasoning model identification and integrate reasoning effort logic in GeminiProvider

* feat(AiProvider): enhance usage tracking by adding thoughts_tokens and updating usage types
This commit is contained in:
SuYao 2025-04-19 01:27:20 +08:00 committed by GitHub
parent 0a28df132d
commit 3360905275
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 77 additions and 13 deletions

View File

@ -2224,7 +2224,8 @@ export function isSupportedReasoningEffortModel(model?: Model): boolean {
model.id.includes('claude-3-7-sonnet') ||
model.id.includes('claude-3.7-sonnet') ||
isOpenAIoSeries(model) ||
isGrokReasoningModel(model)
isGrokReasoningModel(model) ||
isGemini25ReasoningModel(model)
) {
return true
}
@ -2251,6 +2252,18 @@ export function isGrokReasoningModel(model?: Model): boolean {
return false
}
/**
 * Whether the model belongs to the Gemini 2.5 family, which supports a
 * configurable thinking (reasoning) budget.
 */
export function isGemini25ReasoningModel(model?: Model): boolean {
  return model ? model.id.includes('gemini-2.5') : false
}
export function isReasoningModel(model?: Model): boolean {
if (!model) {
return false
@ -2264,7 +2277,7 @@ export function isReasoningModel(model?: Model): boolean {
return true
}
if (model.id.includes('gemini-2.5')) {
if (isGemini25ReasoningModel(model)) {
return true
}

View File

@ -10,9 +10,16 @@ import {
Part,
PartUnion,
SafetySetting,
ThinkingConfig,
ToolListUnion
} from '@google/genai'
import { isGemmaModel, isGenerateImageModel, isVisionModel, isWebSearchModel } from '@renderer/config/models'
import {
isGemini25ReasoningModel,
isGemmaModel,
isGenerateImageModel,
isVisionModel,
isWebSearchModel
} from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
@ -35,6 +42,8 @@ import OpenAI from 'openai'
import { CompletionsParams } from '.'
import BaseProvider from './BaseProvider'
// Discrete reasoning effort levels (mirrors the identical alias in OpenAIProvider).
type ReasoningEffort = 'low' | 'medium' | 'high'
export default class GeminiProvider extends BaseProvider {
private sdk: GoogleGenAI
@ -182,6 +191,41 @@ export default class GeminiProvider extends BaseProvider {
]
}
/**
 * Map the assistant's `reasoning_effort` setting onto a Gemini thinking budget.
 *
 * @param assistant - The assistant whose settings carry the reasoning effort
 * @param model - The target model; only Gemini 2.5 models get a thinking config
 * @returns `{ thinkingConfig }` for Gemini 2.5 models (budget 0 when the effort
 *          is unset/unknown), or `{}` for all other models
 */
private getReasoningEffort(assistant: Assistant, model: Model) {
  if (!isGemini25ReasoningModel(model)) {
    return {}
  }
  const effortRatios: Record<ReasoningEffort, number> = {
    high: 1,
    medium: 0.5,
    low: 0.2
  }
  const effort = assistant?.settings?.reasoning_effort as ReasoningEffort
  const effortRatio = effortRatios[effort]
  if (!effortRatio) {
    // Unknown or unset effort: disable thinking entirely.
    // (Validate BEFORE computing the budget — the original multiplied an
    // undefined ratio first, producing a NaN intermediate.)
    const thinkingConfig: ThinkingConfig = {
      thinkingBudget: 0
    }
    return { thinkingConfig }
  }
  // Budget cap per https://ai.google.dev/gemini-api/docs/thinking
  const maxBudgetToken = 24576
  // Scale the cap by the effort ratio, never dropping below 1024 tokens.
  const budgetTokens = Math.max(1024, Math.trunc(maxBudgetToken * effortRatio))
  const thinkingConfig: ThinkingConfig = {
    thinkingBudget: budgetTokens,
    includeThoughts: true
  }
  return { thinkingConfig }
}
/**
* Generate completions
* @param messages - The messages
@ -241,6 +285,7 @@ export default class GeminiProvider extends BaseProvider {
topP: assistant?.settings?.topP,
maxOutputTokens: maxTokens,
tools: tools,
...this.getReasoningEffort(assistant, model),
...this.getCustomParameters(assistant)
}
@ -308,6 +353,7 @@ export default class GeminiProvider extends BaseProvider {
text: response.text,
usage: {
prompt_tokens: response.usageMetadata?.promptTokenCount || 0,
thoughts_tokens: response.usageMetadata?.thoughtsTokenCount || 0,
completion_tokens: response.usageMetadata?.candidatesTokenCount || 0,
total_tokens: response.usageMetadata?.totalTokenCount || 0
},
@ -384,6 +430,7 @@ export default class GeminiProvider extends BaseProvider {
usage: {
prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
thoughts_tokens: chunk.usageMetadata?.thoughtsTokenCount || 0,
total_tokens: chunk.usageMetadata?.totalTokenCount || 0
},
metrics: {

View File

@ -46,7 +46,7 @@ import {
import { CompletionsParams } from '.'
import BaseProvider from './BaseProvider'
type ReasoningEffort = 'high' | 'medium' | 'low'
type ReasoningEffort = 'low' | 'medium' | 'high'
export default class OpenAIProvider extends BaseProvider {
private sdk: OpenAI

View File

@ -11,14 +11,15 @@ import type {
Metrics,
Model,
Provider,
Suggestion
Suggestion,
Usage
} from '@renderer/types'
import OpenAI from 'openai'
export interface ChunkCallbackData {
text?: string
reasoning_content?: string
usage?: OpenAI.Completions.CompletionUsage
usage?: Usage
metrics?: Metrics
// Zhipu web search
webSearch?: any[]

View File

@ -1,6 +1,5 @@
import { Assistant, FileType, FileTypes, Message } from '@renderer/types'
import { Assistant, FileType, FileTypes, Message, Usage } from '@renderer/types'
import { flatten, takeRight } from 'lodash'
import { CompletionUsage } from 'openai/resources'
import { approximateTokenSize } from 'tokenx'
import { getAssistantSettings } from './AssistantService'
@ -52,7 +51,7 @@ export function estimateImageTokens(file: FileType) {
return Math.floor(file.size / 100)
}
export async function estimateMessageUsage(message: Message): Promise<CompletionUsage> {
export async function estimateMessageUsage(message: Message): Promise<Usage> {
let imageTokens = 0
if (message.files) {
@ -80,17 +79,17 @@ export async function estimateMessagesUsage({
}: {
assistant: Assistant
messages: Message[]
}): Promise<CompletionUsage> {
}): Promise<Usage> {
const outputMessage = messages.pop()!
const prompt_tokens = await estimateHistoryTokens(assistant, messages)
const { completion_tokens } = await estimateMessageUsage(outputMessage)
return {
prompt_tokens: await estimateHistoryTokens(assistant, messages),
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens
} as CompletionUsage
} as Usage
}
export async function estimateHistoryTokens(assistant: Assistant, msgs: Message[]) {

View File

@ -63,7 +63,7 @@ export type Message = {
model?: Model
files?: FileType[]
images?: string[]
usage?: OpenAI.Completions.CompletionUsage
usage?: Usage
metrics?: Metrics
knowledgeBaseIds?: string[]
type: 'text' | '@' | 'clear'
@ -97,6 +97,10 @@ export type Message = {
foldSelected?: boolean
}
// OpenAI token-usage payload extended with Gemini "thinking" accounting.
export type Usage = OpenAI.Completions.CompletionUsage & {
// Tokens consumed by model reasoning (Gemini thoughtsTokenCount); absent for providers without it.
thoughts_tokens?: number
}
export type Metrics = {
completion_tokens?: number
time_completion_millsec?: number