refactor: Update completions method signatures and enhance documentation

- Reordered parameters in completions methods across AiProvider, AnthropicProvider, GeminiProvider, and OpenAIProvider to improve consistency.
- Added detailed JSDoc comments for methods to clarify parameter usage and functionality.
- Ensured mcpTools parameter is consistently included in completions method signatures.
kangfenmao 2025-03-14 10:44:03 +08:00
parent 18b7618a8d
commit a39ff78758
4 changed files with 154 additions and 14 deletions
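The reorder described above touches only the order of fields in a single destructured options object, so existing call sites are unaffected. As a rough illustration (not the project's real definitions: Message, Assistant, MCPTool and the exact CompletionsParams shape are stubbed here as assumptions), the old and new orderings look like this in TypeScript:

// Sketch only: the real types live in the repo's own definitions.
type Message = { role: string; content: string }
type Assistant = { model?: { id: string } }
type MCPTool = { name: string }

// Assumed shape of CompletionsParams, based on the fields destructured in the diff below.
interface CompletionsParams {
  messages: Message[]
  assistant: Assistant
  mcpTools?: MCPTool[]
  onChunk: (chunk: { text?: string }) => void
  onFilterMessages?: (messages: Message[]) => void
}

// Old order: completions({ messages, assistant, onChunk, onFilterMessages, mcpTools })
// New order, applied consistently in all four providers: inputs first, callbacks last.
async function completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
  // Reordering fields of one destructured parameter is cosmetic for callers;
  // object literals at call sites keep working unchanged.
  onFilterMessages?.(messages)
  onChunk({ text: `${messages.length} messages, ${mcpTools?.length ?? 0} MCP tools, model ${assistant.model?.id ?? 'default'}` })
}

// Example call; property order in the literal does not affect behaviour.
void completions({
  messages: [{ role: 'user', content: 'hi' }],
  assistant: {},
  mcpTools: [],
  onChunk: (chunk) => console.log(chunk.text)
})

The diff below shows the same reorder, plus the new JSDoc blocks, as applied to each provider.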

@@ -19,11 +19,11 @@ export default class AiProvider {
   public async completions({
     messages,
     assistant,
+    mcpTools,
     onChunk,
-    onFilterMessages,
-    mcpTools
+    onFilterMessages
   }: CompletionsParams): Promise<void> {
-    return this.sdk.completions({ messages, assistant, onChunk, onFilterMessages, mcpTools })
+    return this.sdk.completions({ messages, assistant, mcpTools, onChunk, onFilterMessages })
   }

   public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void): Promise<string> {

@@ -54,6 +54,11 @@ export default class AnthropicProvider extends BaseProvider {
     return this.provider.apiHost
   }

+  /**
+   * Get the message parameter
+   * @param message - The message
+   * @returns The message parameter
+   */
   private async getMessageParam(message: Message): Promise<MessageParam> {
     const parts: MessageParam['content'] = [
       {
@@ -74,6 +79,7 @@ export default class AnthropicProvider extends BaseProvider {
         }
       })
     }
+
     if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
       const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
       parts.push({
@@ -89,18 +95,32 @@ export default class AnthropicProvider extends BaseProvider {
       }
     }

+  /**
+   * Get the temperature
+   * @param assistant - The assistant
+   * @param model - The model
+   * @returns The temperature
+   */
   private getTemperature(assistant: Assistant, model: Model) {
-    if (isReasoningModel(model)) return undefined
-    return assistant?.settings?.temperature
+    return isReasoningModel(model) ? undefined : assistant?.settings?.temperature
   }

+  /**
+   * Get the top P
+   * @param assistant - The assistant
+   * @param model - The model
+   * @returns The top P
+   */
   private getTopP(assistant: Assistant, model: Model) {
-    if (isReasoningModel(model)) return undefined
-    return assistant?.settings?.topP
+    return isReasoningModel(model) ? undefined : assistant?.settings?.topP
   }

+  /**
+   * Get the reasoning effort
+   * @param assistant - The assistant
+   * @param model - The model
+   * @returns The reasoning effort
+   */
   private getReasoningEffort(assistant: Assistant, model: Model): ReasoningConfig | undefined {
     if (!isReasoningModel(model)) {
       return undefined
@@ -134,7 +154,15 @@ export default class AnthropicProvider extends BaseProvider {
     }
   }

-  public async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams) {
+  /**
+   * Generate completions
+   * @param messages - The messages
+   * @param assistant - The assistant
+   * @param mcpTools - The MCP tools
+   * @param onChunk - The onChunk callback
+   * @param onFilterMessages - The onFilterMessages callback
+   */
+  public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
     const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
@@ -192,6 +220,7 @@ export default class AnthropicProvider extends BaseProvider {
           text = textBlock.text
         }
       }
+
       return onChunk({
         text,
         reasoning_content,
@@ -271,6 +300,7 @@ export default class AnthropicProvider extends BaseProvider {
         .on('finalMessage', async (message) => {
           if (toolCalls.length > 0) {
             const toolCallResults: ToolResultBlockParam[] = []
+
             for (const toolCall of toolCalls) {
               const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
               if (mcpTool) {
@@ -338,6 +368,13 @@ export default class AnthropicProvider extends BaseProvider {
       .finally(cleanup)
   }

+  /**
+   * Translate a message
+   * @param message - The message
+   * @param assistant - The assistant
+   * @param onResponse - The onResponse callback
+   * @returns The translated message
+   */
   public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
@@ -375,6 +412,12 @@ export default class AnthropicProvider extends BaseProvider {
     })
   }

+  /**
+   * Summarize a message
+   * @param messages - The messages
+   * @param assistant - The assistant
+   * @returns The summary
+   */
   public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
     const model = getTopNamingModel() || assistant.model || getDefaultModel()
@@ -417,6 +460,12 @@ export default class AnthropicProvider extends BaseProvider {
     return removeSpecialCharactersForTopicName(content)
   }

+  /**
+   * Generate text
+   * @param prompt - The prompt
+   * @param content - The content
+   * @returns The generated text
+   */
   public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
     const model = getDefaultModel()
@@ -436,14 +485,27 @@ export default class AnthropicProvider extends BaseProvider {
     return message.content[0].type === 'text' ? message.content[0].text : ''
   }

+  /**
+   * Generate an image
+   * @returns The generated image
+   */
   public async generateImage(): Promise<string[]> {
     return []
   }

+  /**
+   * Generate suggestions
+   * @returns The suggestions
+   */
   public async suggestions(): Promise<Suggestion[]> {
     return []
   }

+  /**
+   * Check if the model is valid
+   * @param model - The model
+   * @returns The validity of the model
+   */
   public async check(model: Model): Promise<{ valid: boolean; error: Error | null }> {
     if (!model) {
       return { valid: false, error: new Error('No model found') }
@@ -470,6 +532,10 @@ export default class AnthropicProvider extends BaseProvider {
     }
   }

+  /**
+   * Get the models
+   * @returns The models
+   */
   public async models(): Promise<OpenAI.Models.Model[]> {
     return []
   }

@@ -51,6 +51,11 @@ export default class GeminiProvider extends BaseProvider {
     return this.provider.apiHost
   }

+  /**
+   * Handle a PDF file
+   * @param file - The file
+   * @returns The part
+   */
   private async handlePdfFile(file: FileType): Promise<Part> {
     const smallFileSize = 20 * 1024 * 1024
     const isSmallFile = file.size < smallFileSize
@@ -88,6 +93,11 @@ export default class GeminiProvider extends BaseProvider {
     } as FileDataPart
   }

+  /**
+   * Get the message contents
+   * @param message - The message
+   * @returns The message contents
+   */
   private async getMessageContents(message: Message): Promise<Content> {
     const role = message.role === 'user' ? 'user' : 'model'
@@ -123,6 +133,11 @@ export default class GeminiProvider extends BaseProvider {
     }
   }

+  /**
+   * Get the safety settings
+   * @param modelId - The model ID
+   * @returns The safety settings
+   */
   private getSafetySettings(modelId: string): SafetySetting[] {
     const safetyThreshold = modelId.includes('gemini-2.0-flash-exp')
       ? ('OFF' as HarmBlockThreshold)
@@ -152,7 +167,15 @@ export default class GeminiProvider extends BaseProvider {
     ]
   }

-  public async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams) {
+  /**
+   * Generate completions
+   * @param messages - The messages
+   * @param assistant - The assistant
+   * @param mcpTools - The MCP tools
+   * @param onChunk - The onChunk callback
+   * @param onFilterMessages - The onFilterMessages callback
+   */
+  public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
     const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
@@ -167,9 +190,11 @@ export default class GeminiProvider extends BaseProvider {
     for (const message of userMessages) {
       history.push(await this.getMessageContents(message))
     }
+
     mcpTools = filterMCPTools(mcpTools, userLastMessage?.enabledMCPs)
     const tools = mcpToolsToGeminiTools(mcpTools)
     const toolResponses: MCPToolResponse[] = []
+
     if (assistant.enableWebSearch && isWebSearchModel(model)) {
       tools.push({
         // @ts-ignore googleSearch is not a valid tool for Gemini
@@ -199,6 +224,7 @@ export default class GeminiProvider extends BaseProvider {
     const start_time_millsec = new Date().getTime()
     const { abortController, cleanup } = this.createAbortController(userLastMessage?.id)
     const { signal } = abortController
+
     if (!streamOutput) {
       const { response } = await chat.sendMessage(messageContents.parts, { signal })
       const time_completion_millsec = new Date().getTime() - start_time_millsec
@@ -221,15 +247,19 @@ export default class GeminiProvider extends BaseProvider {
     const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal })

     let time_first_token_millsec = 0
+
     const processStream = async (stream: GenerateContentStreamResult, idx: number) => {
       for await (const chunk of stream.stream) {
         if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
+
         if (time_first_token_millsec == 0) {
           time_first_token_millsec = new Date().getTime() - start_time_millsec
         }
+
         const time_completion_millsec = new Date().getTime() - start_time_millsec
         const functionCalls = chunk.functionCalls()
+
         if (functionCalls) {
           const fcallParts: FunctionCallPart[] = []
           const fcRespParts: FunctionResponsePart[] = []
@@ -266,6 +296,7 @@ export default class GeminiProvider extends BaseProvider {
             )
           }
         }
+
         if (fcRespParts) {
           history.push(messageContents)
           history.push({
@@ -295,9 +326,17 @@ export default class GeminiProvider extends BaseProvider {
         })
       }
     }
+
     await processStream(userMessagesStream, 0).finally(cleanup)
   }

+  /**
+   * Translate a message
+   * @param message - The message
+   * @param assistant - The assistant
+   * @param onResponse - The onResponse callback
+   * @returns The translated message
+   */
   async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
     const defaultModel = getDefaultModel()
     const { maxTokens } = getAssistantSettings(assistant)
@@ -332,6 +371,12 @@ export default class GeminiProvider extends BaseProvider {
     return text
   }

+  /**
+   * Summarize a message
+   * @param messages - The messages
+   * @param assistant - The assistant
+   * @returns The summary
+   */
   public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
     const model = getTopNamingModel() || assistant.model || getDefaultModel()
@@ -375,6 +420,12 @@ export default class GeminiProvider extends BaseProvider {
     return removeSpecialCharactersForTopicName(response.text())
   }

+  /**
+   * Generate text
+   * @param prompt - The prompt
+   * @param content - The content
+   * @returns The generated text
+   */
   public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
     const model = getDefaultModel()
     const systemMessage = { role: 'system', content: prompt }
@@ -387,14 +438,27 @@ export default class GeminiProvider extends BaseProvider {
     return response.text()
   }

+  /**
+   * Generate suggestions
+   * @returns The suggestions
+   */
   public async suggestions(): Promise<Suggestion[]> {
     return []
   }

+  /**
+   * Generate an image
+   * @returns The generated image
+   */
   public async generateImage(): Promise<string[]> {
     return []
   }

+  /**
+   * Check if the model is valid
+   * @param model - The model
+   * @returns The validity of the model
+   */
   public async check(model: Model): Promise<{ valid: boolean; error: Error | null }> {
     if (!model) {
       return { valid: false, error: new Error('No model found') }
@@ -422,12 +486,17 @@ export default class GeminiProvider extends BaseProvider {
     }
   }

+  /**
+   * Get the models
+   * @returns The models
+   */
   public async models(): Promise<OpenAI.Models.Model[]> {
     try {
       const api = this.provider.apiHost + '/v1beta/models'
       const { data } = await axios.get(api, { params: { key: this.apiKey } })
       return data.models.map(
-        (m: any) =>
+        (m) =>
           ({
             id: m.name.replace('models/', ''),
             name: m.displayName,
@@ -442,6 +511,11 @@ export default class GeminiProvider extends BaseProvider {
     }
   }

+  /**
+   * Get the embedding dimensions
+   * @param model - The model
+   * @returns The embedding dimensions
+   */
   public async getEmbeddingDimensions(model: Model): Promise<number> {
     const data = await this.sdk.getGenerativeModel({ model: model.id }, this.requestOptions).embedContent('hi')
     return data.embedding.values.length

@@ -288,12 +288,12 @@ export default class OpenAIProvider extends BaseProvider {
    * Generate completions for the assistant
    * @param messages - The messages
    * @param assistant - The assistant
+   * @param mcpTools - The MCP tools
    * @param onChunk - The onChunk callback
    * @param onFilterMessages - The onFilterMessages callback
-   * @param mcpTools - The MCP tools
    * @returns The completions
    */
-  async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams): Promise<void> {
+  async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
     const defaultModel = getDefaultModel()
     const model = assistant.model || defaultModel
     const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)