feat: gemini reasoning budget support (#5052)

* feat(models): add Gemini 2.5 reasoning model identification and integrate reasoning effort logic in GeminiProvider

* feat(AiProvider): enhance usage tracking by adding thoughts_tokens and updating usage types
SuYao authored 2025-04-19 01:27:20 +08:00, committed by GitHub
commit 3360905275, parent 0a28df132d
6 changed files with 77 additions and 13 deletions

File 1 of 6: @renderer/config/models

@@ -2224,7 +2224,8 @@ export function isSupportedReasoningEffortModel(model?: Model): boolean {
     model.id.includes('claude-3-7-sonnet') ||
     model.id.includes('claude-3.7-sonnet') ||
     isOpenAIoSeries(model) ||
-    isGrokReasoningModel(model)
+    isGrokReasoningModel(model) ||
+    isGemini25ReasoningModel(model)
   ) {
     return true
   }
@@ -2251,6 +2252,18 @@ export function isGrokReasoningModel(model?: Model): boolean {
   return false
 }
 
+export function isGemini25ReasoningModel(model?: Model): boolean {
+  if (!model) {
+    return false
+  }
+
+  if (model.id.includes('gemini-2.5')) {
+    return true
+  }
+
+  return false
+}
+
 export function isReasoningModel(model?: Model): boolean {
   if (!model) {
     return false
@@ -2264,7 +2277,7 @@ export function isReasoningModel(model?: Model): boolean {
     return true
   }
 
-  if (model.id.includes('gemini-2.5')) {
+  if (isGemini25ReasoningModel(model)) {
     return true
   }
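The new predicate is a plain substring match, so any current or future model ID containing 'gemini-2.5' is treated as a thinking-capable model. A quick sanity check, as a sketch (the one-field stub below is illustrative; the real Model type carries more fields):

// Illustrative stub: only the field the predicate actually reads.
type ModelStub = { id: string }

// Mirrors isGemini25ReasoningModel from the hunk above.
function isGemini25Reasoning(model?: ModelStub): boolean {
  return !!model && model.id.includes('gemini-2.5')
}

console.log(isGemini25Reasoning({ id: 'gemini-2.5-flash-preview-04-17' })) // true
console.log(isGemini25Reasoning({ id: 'gemini-2.0-flash' })) // false
console.log(isGemini25Reasoning(undefined)) // false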

File 2 of 6: GeminiProvider

@@ -10,9 +10,16 @@ import {
   Part,
   PartUnion,
   SafetySetting,
+  ThinkingConfig,
   ToolListUnion
 } from '@google/genai'
-import { isGemmaModel, isGenerateImageModel, isVisionModel, isWebSearchModel } from '@renderer/config/models'
+import {
+  isGemini25ReasoningModel,
+  isGemmaModel,
+  isGenerateImageModel,
+  isVisionModel,
+  isWebSearchModel
+} from '@renderer/config/models'
 import { getStoreSetting } from '@renderer/hooks/useSettings'
 import i18n from '@renderer/i18n'
 import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
@@ -35,6 +42,8 @@ import OpenAI from 'openai'
 
 import { CompletionsParams } from '.'
 import BaseProvider from './BaseProvider'
 
+type ReasoningEffort = 'low' | 'medium' | 'high'
+
 export default class GeminiProvider extends BaseProvider {
   private sdk: GoogleGenAI
@@ -182,6 +191,41 @@ export default class GeminiProvider extends BaseProvider {
     ]
   }
 
+  /**
+   * Get the reasoning effort for the assistant
+   * @param assistant - The assistant
+   * @param model - The model
+   * @returns The reasoning effort
+   */
+  private getReasoningEffort(assistant: Assistant, model: Model) {
+    if (isGemini25ReasoningModel(model)) {
+      const effortRatios: Record<ReasoningEffort, number> = {
+        high: 1,
+        medium: 0.5,
+        low: 0.2
+      }
+      const effort = assistant?.settings?.reasoning_effort as ReasoningEffort
+      const effortRatio = effortRatios[effort]
+      const maxBudgetToken = 24576 // https://ai.google.dev/gemini-api/docs/thinking
+      const budgetTokens = Math.max(1024, Math.trunc(maxBudgetToken * effortRatio))
+      if (!effortRatio) {
+        return {
+          thinkingConfig: {
+            thinkingBudget: 0
+          } as ThinkingConfig
+        }
+      }
+
+      return {
+        thinkingConfig: {
+          thinkingBudget: budgetTokens,
+          includeThoughts: true
+        } as ThinkingConfig
+      }
+    }
+
+    return {}
+  }
+
   /**
    * Generate completions
    * @param messages - The messages
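Working the budget arithmetic through for each preset (constants from the hunk above; the 24576 cap is the maximum thinking budget documented at the linked Gemini page):

// effort     ratio       Math.trunc(24576 * ratio)    thinkingBudget sent
// 'high'     1           24576                        24576
// 'medium'   0.5         12288                        12288
// 'low'      0.2         4915                         4915
// unset      undefined   NaN (computed, never used)   0 (thinking disabled)

Since budgetTokens is computed before the !effortRatio guard, an unset effort briefly produces Math.max(1024, NaN), which is NaN in JavaScript, but the early return with thinkingBudget: 0 means that value never reaches the request. The Math.max floor of 1024 only binds for ratios below roughly 0.042, so all three presets pass through unchanged.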
@@ -241,6 +285,7 @@
       topP: assistant?.settings?.topP,
       maxOutputTokens: maxTokens,
       tools: tools,
+      ...this.getReasoningEffort(assistant, model),
       ...this.getCustomParameters(assistant)
     }
@@ -308,6 +353,7 @@
       text: response.text,
       usage: {
         prompt_tokens: response.usageMetadata?.promptTokenCount || 0,
+        thoughts_tokens: response.usageMetadata?.thoughtsTokenCount || 0,
         completion_tokens: response.usageMetadata?.candidatesTokenCount || 0,
         total_tokens: response.usageMetadata?.totalTokenCount || 0
       },
@@ -384,6 +430,7 @@
         usage: {
           prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
           completion_tokens: chunk.usageMetadata?.candidatesTokenCount || 0,
+          thoughts_tokens: chunk.usageMetadata?.thoughtsTokenCount || 0,
           total_tokens: chunk.usageMetadata?.totalTokenCount || 0
         },
         metrics: {
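With thoughts_tokens carried separately, consumers can surface thinking cost next to visible output. Note the shape of the mapping above: completion_tokens (candidatesTokenCount) excludes thinking tokens, while total_tokens (totalTokenCount) includes them, per the Gemini docs linked in the hunk. A minimal sketch of such a consumer; the formatUsage helper is hypothetical, not part of this commit:

import type { Usage } from '@renderer/types'

// Hypothetical display helper, not part of this commit.
function formatUsage(usage: Usage): string {
  const thoughts = usage.thoughts_tokens ?? 0
  const base = `${usage.prompt_tokens} in / ${usage.completion_tokens} out`
  return thoughts > 0 ? `${base} (+${thoughts} thinking)` : base
}

console.log(formatUsage({ prompt_tokens: 120, completion_tokens: 350, total_tokens: 670, thoughts_tokens: 200 }))
// -> "120 in / 350 out (+200 thinking)"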

File 3 of 6: OpenAIProvider

@@ -46,7 +46,7 @@ import {
 import { CompletionsParams } from '.'
 import BaseProvider from './BaseProvider'
 
-type ReasoningEffort = 'high' | 'medium' | 'low'
+type ReasoningEffort = 'low' | 'medium' | 'high'
 
 export default class OpenAIProvider extends BaseProvider {
   private sdk: OpenAI

File 4 of 6: AiProvider chunk callback types

@@ -11,14 +11,15 @@ import type {
   Metrics,
   Model,
   Provider,
-  Suggestion
+  Suggestion,
+  Usage
 } from '@renderer/types'
 import OpenAI from 'openai'
 
 export interface ChunkCallbackData {
   text?: string
   reasoning_content?: string
-  usage?: OpenAI.Completions.CompletionUsage
+  usage?: Usage
   metrics?: Metrics
   // Zhipu web search
   webSearch?: any[]

File 5 of 6: token estimation service

@@ -1,6 +1,5 @@
-import { Assistant, FileType, FileTypes, Message } from '@renderer/types'
+import { Assistant, FileType, FileTypes, Message, Usage } from '@renderer/types'
 import { flatten, takeRight } from 'lodash'
-import { CompletionUsage } from 'openai/resources'
 import { approximateTokenSize } from 'tokenx'
 
 import { getAssistantSettings } from './AssistantService'
@@ -52,7 +51,7 @@ export function estimateImageTokens(file: FileType) {
   return Math.floor(file.size / 100)
 }
 
-export async function estimateMessageUsage(message: Message): Promise<CompletionUsage> {
+export async function estimateMessageUsage(message: Message): Promise<Usage> {
   let imageTokens = 0
 
   if (message.files) {
@@ -80,17 +79,17 @@
 }: {
   assistant: Assistant
   messages: Message[]
-}): Promise<CompletionUsage> {
+}): Promise<Usage> {
   const outputMessage = messages.pop()!
 
   const prompt_tokens = await estimateHistoryTokens(assistant, messages)
   const { completion_tokens } = await estimateMessageUsage(outputMessage)
 
   return {
-    prompt_tokens: await estimateHistoryTokens(assistant, messages),
+    prompt_tokens,
     completion_tokens,
     total_tokens: prompt_tokens + completion_tokens
-  } as CompletionUsage
+  } as Usage
 }
 
 export async function estimateHistoryTokens(assistant: Assistant, msgs: Message[]) {
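One caveat: estimateMessagesUsage pops the last message off the array it receives and scores it as the pending completion, so the caller's array is mutated. A usage sketch, assuming assistant and messages are already in scope:

// Pass a copy if the original array is still needed:
// estimateMessagesUsage calls messages.pop() on its argument.
const usage = await estimateMessagesUsage({
  assistant,
  messages: [...messages] // the last entry is scored as completion_tokens
})
console.log(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)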

File 6 of 6: @renderer/types

@@ -63,7 +63,7 @@ export type Message = {
   model?: Model
   files?: FileType[]
   images?: string[]
-  usage?: OpenAI.Completions.CompletionUsage
+  usage?: Usage
   metrics?: Metrics
   knowledgeBaseIds?: string[]
   type: 'text' | '@' | 'clear'
@@ -97,6 +97,10 @@
   foldSelected?: boolean
 }
 
+export type Usage = OpenAI.Completions.CompletionUsage & {
+  thoughts_tokens?: number
+}
+
 export type Metrics = {
   completion_tokens?: number
   time_completion_millsec?: number
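Because Usage is an intersection over OpenAI's CompletionUsage, every existing call site that reads prompt_tokens, completion_tokens, or total_tokens keeps compiling unchanged; only providers that report thinking metadata set the new optional field. A self-contained illustration (the type is restated from the hunk above; the numbers are made up):

import OpenAI from 'openai'

type Usage = OpenAI.Completions.CompletionUsage & {
  thoughts_tokens?: number
}

// Gemini-style usage: completion_tokens excludes thinking tokens,
// total_tokens (from totalTokenCount) includes them.
const geminiUsage: Usage = {
  prompt_tokens: 120,
  completion_tokens: 350,
  total_tokens: 670,
  thoughts_tokens: 200
}

// Providers without thinking metadata simply omit the field.
const plainUsage: Usage = { prompt_tokens: 120, completion_tokens: 350, total_tokens: 470 }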