refactor: Update completions method signatures and enhance documentation

- Reordered parameters in completions methods across AiProvider, AnthropicProvider, GeminiProvider, and OpenAIProvider to improve consistency.
- Added detailed JSDoc comments for methods to clarify parameter usage and functionality.
- Ensured the mcpTools parameter is consistently included in completions method signatures.
kangfenmao 2025-03-14 10:44:03 +08:00
parent 18b7618a8d
commit a39ff78758
4 changed files with 154 additions and 14 deletions
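
The reordering is easiest to see in the shape of the shared params object. Below is a minimal sketch of CompletionsParams in the new canonical order; the field types are assumptions inferred from how the fields are used in the diffs that follow, not the repository's actual definition:

// Sketch only: types inferred from usage in the diffs below.
interface CompletionsParams {
  messages: Message[]   // conversation to complete
  assistant: Assistant  // carries the selected model and settings
  mcpTools?: MCPTool[]  // MCP tools, now consistently third in every signature
  onChunk: (chunk: { text?: string; reasoning_content?: string }) => void // streaming callback
  onFilterMessages?: (messages: Message[]) => void // reports the context-filtered messages
}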

View File

@@ -19,11 +19,11 @@ export default class AiProvider {
   public async completions({
     messages,
     assistant,
+    mcpTools,
     onChunk,
-    onFilterMessages,
-    mcpTools
+    onFilterMessages
   }: CompletionsParams): Promise<void> {
-    return this.sdk.completions({ messages, assistant, onChunk, onFilterMessages, mcpTools })
+    return this.sdk.completions({ messages, assistant, mcpTools, onChunk, onFilterMessages })
   }
 
   public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void): Promise<string> {
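
Because completions receives a single destructured object rather than positional arguments, the reorder is source-compatible: property order in a TypeScript object literal does not matter, so existing call sites keep working unchanged. A hypothetical call through the AiProvider facade (constructor details are assumed, not shown in this diff):

// Hypothetical usage; AiProvider construction details are assumed.
const ai = new AiProvider(provider)

await ai.completions({
  messages,
  assistant,
  mcpTools, // written third to mirror the new signature, though any order compiles
  onChunk: (chunk) => process.stdout.write(chunk.text ?? ''),
  onFilterMessages: (kept) => console.debug(`${kept.length} messages kept after context filtering`)
})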

View File

@@ -54,6 +54,11 @@ export default class AnthropicProvider extends BaseProvider {
     return this.provider.apiHost
   }
 
+  /**
+   * Get the message parameter
+   * @param message - The message
+   * @returns The message parameter
+   */
   private async getMessageParam(message: Message): Promise<MessageParam> {
     const parts: MessageParam['content'] = [
       {
@@ -74,6 +79,7 @@ export default class AnthropicProvider extends BaseProvider {
           }
         })
       }
+
       if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
         const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
         parts.push({
@@ -89,18 +95,32 @@ export default class AnthropicProvider extends BaseProvider {
     }
   }
 
+  /**
+   * Get the temperature
+   * @param assistant - The assistant
+   * @param model - The model
+   * @returns The temperature
+   */
   private getTemperature(assistant: Assistant, model: Model) {
-    if (isReasoningModel(model)) return undefined
-    return assistant?.settings?.temperature
+    return isReasoningModel(model) ? undefined : assistant?.settings?.temperature
   }
 
+  /**
+   * Get the top P
+   * @param assistant - The assistant
+   * @param model - The model
+   * @returns The top P
+   */
   private getTopP(assistant: Assistant, model: Model) {
-    if (isReasoningModel(model)) return undefined
-    return assistant?.settings?.topP
+    return isReasoningModel(model) ? undefined : assistant?.settings?.topP
  }
 
+  /**
+   * Get the reasoning effort
+   * @param assistant - The assistant
+   * @param model - The model
+   * @returns The reasoning effort
+   */
   private getReasoningEffort(assistant: Assistant, model: Model): ReasoningConfig | undefined {
     if (!isReasoningModel(model)) {
       return undefined
@@ -134,7 +154,15 @@ export default class AnthropicProvider extends BaseProvider {
     }
   }
 
-  public async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams) {
+  /**
+   * Generate completions
+   * @param messages - The messages
+   * @param assistant - The assistant
+   * @param mcpTools - The MCP tools
+   * @param onChunk - The onChunk callback
+   * @param onFilterMessages - The onFilterMessages callback
+   */
+  public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
     const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
@@ -192,6 +220,7 @@ export default class AnthropicProvider extends BaseProvider {
           text = textBlock.text
         }
       }
+
       return onChunk({
         text,
         reasoning_content,
@@ -271,6 +300,7 @@ export default class AnthropicProvider extends BaseProvider {
       .on('finalMessage', async (message) => {
         if (toolCalls.length > 0) {
           const toolCallResults: ToolResultBlockParam[] = []
+
           for (const toolCall of toolCalls) {
             const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
             if (mcpTool) {
@@ -338,6 +368,13 @@ export default class AnthropicProvider extends BaseProvider {
       .finally(cleanup)
   }
 
+  /**
+   * Translate a message
+   * @param message - The message
+   * @param assistant - The assistant
+   * @param onResponse - The onResponse callback
+   * @returns The translated message
+   */
   public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
@@ -375,6 +412,12 @@ export default class AnthropicProvider extends BaseProvider {
     })
   }
 
+  /**
+   * Summarize a message
+   * @param messages - The messages
+   * @param assistant - The assistant
+   * @returns The summary
+   */
   public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
     const model = getTopNamingModel() || assistant.model || getDefaultModel()
@@ -417,6 +460,12 @@ export default class AnthropicProvider extends BaseProvider {
     return removeSpecialCharactersForTopicName(content)
   }
 
+  /**
+   * Generate text
+   * @param prompt - The prompt
+   * @param content - The content
+   * @returns The generated text
+   */
   public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
     const model = getDefaultModel()
@@ -436,14 +485,27 @@ export default class AnthropicProvider extends BaseProvider {
     return message.content[0].type === 'text' ? message.content[0].text : ''
   }
 
+  /**
+   * Generate an image
+   * @returns The generated image
+   */
   public async generateImage(): Promise<string[]> {
     return []
   }
 
+  /**
+   * Generate suggestions
+   * @returns The suggestions
+   */
   public async suggestions(): Promise<Suggestion[]> {
     return []
   }
 
+  /**
+   * Check if the model is valid
+   * @param model - The model
+   * @returns The validity of the model
+   */
   public async check(model: Model): Promise<{ valid: boolean; error: Error | null }> {
     if (!model) {
       return { valid: false, error: new Error('No model found') }
@@ -470,6 +532,10 @@ export default class AnthropicProvider extends BaseProvider {
     }
   }
 
+  /**
+   * Get the models
+   * @returns The models
+   */
   public async models(): Promise<OpenAI.Models.Model[]> {
     return []
   }
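
A side note on the getTemperature/getTopP rewrites above: they are behavior-preserving. Reasoning models still get undefined, which omits the sampling parameter from the request entirely; all other models still fall through to the assistant's settings. A standalone sketch of the equivalence, using the same types as the diff:

// Both forms return undefined for reasoning models, so no explicit
// temperature is sent for them; otherwise the assistant setting is used.
function getTemperature(assistant: Assistant, model: Model): number | undefined {
  // Before: if (isReasoningModel(model)) return undefined
  //         return assistant?.settings?.temperature
  // After, as a single expression with identical behavior:
  return isReasoningModel(model) ? undefined : assistant?.settings?.temperature
}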

View File

@@ -51,6 +51,11 @@ export default class GeminiProvider extends BaseProvider {
     return this.provider.apiHost
   }
 
+  /**
+   * Handle a PDF file
+   * @param file - The file
+   * @returns The part
+   */
   private async handlePdfFile(file: FileType): Promise<Part> {
     const smallFileSize = 20 * 1024 * 1024
     const isSmallFile = file.size < smallFileSize
@@ -88,6 +93,11 @@ export default class GeminiProvider extends BaseProvider {
     } as FileDataPart
   }
 
+  /**
+   * Get the message contents
+   * @param message - The message
+   * @returns The message contents
+   */
   private async getMessageContents(message: Message): Promise<Content> {
     const role = message.role === 'user' ? 'user' : 'model'
@@ -123,6 +133,11 @@ export default class GeminiProvider extends BaseProvider {
     }
   }
 
+  /**
+   * Get the safety settings
+   * @param modelId - The model ID
+   * @returns The safety settings
+   */
   private getSafetySettings(modelId: string): SafetySetting[] {
     const safetyThreshold = modelId.includes('gemini-2.0-flash-exp')
       ? ('OFF' as HarmBlockThreshold)
@@ -152,7 +167,15 @@ export default class GeminiProvider extends BaseProvider {
     ]
   }
 
-  public async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams) {
+  /**
+   * Generate completions
+   * @param messages - The messages
+   * @param assistant - The assistant
+   * @param mcpTools - The MCP tools
+   * @param onChunk - The onChunk callback
+   * @param onFilterMessages - The onFilterMessages callback
+   */
+  public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
     const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
@@ -167,9 +190,11 @@ export default class GeminiProvider extends BaseProvider {
     for (const message of userMessages) {
       history.push(await this.getMessageContents(message))
     }
+
     mcpTools = filterMCPTools(mcpTools, userLastMessage?.enabledMCPs)
     const tools = mcpToolsToGeminiTools(mcpTools)
     const toolResponses: MCPToolResponse[] = []
+
     if (assistant.enableWebSearch && isWebSearchModel(model)) {
       tools.push({
         // @ts-ignore googleSearch is not a valid tool for Gemini
@@ -199,6 +224,7 @@ export default class GeminiProvider extends BaseProvider {
     const start_time_millsec = new Date().getTime()
     const { abortController, cleanup } = this.createAbortController(userLastMessage?.id)
     const { signal } = abortController
+
     if (!streamOutput) {
       const { response } = await chat.sendMessage(messageContents.parts, { signal })
       const time_completion_millsec = new Date().getTime() - start_time_millsec
@@ -221,15 +247,19 @@ export default class GeminiProvider extends BaseProvider {
     const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal })
 
     let time_first_token_millsec = 0
 
     const processStream = async (stream: GenerateContentStreamResult, idx: number) => {
       for await (const chunk of stream.stream) {
         if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
+
         if (time_first_token_millsec == 0) {
           time_first_token_millsec = new Date().getTime() - start_time_millsec
         }
+
         const time_completion_millsec = new Date().getTime() - start_time_millsec
+
         const functionCalls = chunk.functionCalls()
+
         if (functionCalls) {
           const fcallParts: FunctionCallPart[] = []
           const fcRespParts: FunctionResponsePart[] = []
@@ -266,6 +296,7 @@ export default class GeminiProvider extends BaseProvider {
               )
             }
           }
+
           if (fcRespParts) {
             history.push(messageContents)
             history.push({
@@ -295,9 +326,17 @@ export default class GeminiProvider extends BaseProvider {
             })
           }
         }
+
     await processStream(userMessagesStream, 0).finally(cleanup)
   }
 
+  /**
+   * Translate a message
+   * @param message - The message
+   * @param assistant - The assistant
+   * @param onResponse - The onResponse callback
+   * @returns The translated message
+   */
   async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
     const defaultModel = getDefaultModel()
     const { maxTokens } = getAssistantSettings(assistant)
@@ -332,6 +371,12 @@ export default class GeminiProvider extends BaseProvider {
     return text
   }
 
+  /**
+   * Summarize a message
+   * @param messages - The messages
+   * @param assistant - The assistant
+   * @returns The summary
+   */
   public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
     const model = getTopNamingModel() || assistant.model || getDefaultModel()
@@ -375,6 +420,12 @@ export default class GeminiProvider extends BaseProvider {
     return removeSpecialCharactersForTopicName(response.text())
   }
 
+  /**
+   * Generate text
+   * @param prompt - The prompt
+   * @param content - The content
+   * @returns The generated text
+   */
   public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
     const model = getDefaultModel()
     const systemMessage = { role: 'system', content: prompt }
@@ -387,14 +438,27 @@ export default class GeminiProvider extends BaseProvider {
     return response.text()
   }
 
+  /**
+   * Generate suggestions
+   * @returns The suggestions
+   */
   public async suggestions(): Promise<Suggestion[]> {
     return []
   }
 
+  /**
+   * Generate an image
+   * @returns The generated image
+   */
   public async generateImage(): Promise<string[]> {
     return []
   }
 
+  /**
+   * Check if the model is valid
+   * @param model - The model
+   * @returns The validity of the model
+   */
   public async check(model: Model): Promise<{ valid: boolean; error: Error | null }> {
     if (!model) {
       return { valid: false, error: new Error('No model found') }
@@ -422,12 +486,17 @@ export default class GeminiProvider extends BaseProvider {
     }
   }
 
+  /**
+   * Get the models
+   * @returns The models
+   */
   public async models(): Promise<OpenAI.Models.Model[]> {
     try {
       const api = this.provider.apiHost + '/v1beta/models'
       const { data } = await axios.get(api, { params: { key: this.apiKey } })
+
       return data.models.map(
-        (m: any) =>
+        (m) =>
          ({
            id: m.name.replace('models/', ''),
            name: m.displayName,
@@ -442,6 +511,11 @@ export default class GeminiProvider extends BaseProvider {
     }
   }
 
+  /**
+   * Get the embedding dimensions
+   * @param model - The model
+   * @returns The embedding dimensions
+   */
   public async getEmbeddingDimensions(model: Model): Promise<number> {
     const data = await this.sdk.getGenerativeModel({ model: model.id }, this.requestOptions).embedContent('hi')
     return data.embedding.values.length
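
Note how getEmbeddingDimensions works: it embeds a throwaway string and measures the length of the returned vector, so no per-model dimension table needs to be maintained. Hypothetical usage (provider construction and the Model cast are assumed):

// Hypothetical usage; GeminiProvider construction details are assumed.
const gemini = new GeminiProvider(provider)
const dims = await gemini.getEmbeddingDimensions({ id: 'text-embedding-004' } as Model)
console.log(dims) // 768 for text-embedding-004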

View File

@@ -288,12 +288,12 @@ export default class OpenAIProvider extends BaseProvider {
    * Generate completions for the assistant
    * @param messages - The messages
    * @param assistant - The assistant
+   * @param mcpTools - The MCP tools
    * @param onChunk - The onChunk callback
    * @param onFilterMessages - The onFilterMessages callback
-   * @param mcpTools - The MCP tools
    * @returns The completions
    */
-  async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams): Promise<void> {
+  async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
     const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)