refactor: Update completions method signatures and enhance documentation
- Reordered parameters in completions methods across AiProvider, AnthropicProvider, GeminiProvider, and OpenAIProvider to improve consistency. - Added detailed JSDoc comments for methods to clarify parameter usage and functionality. - Ensured mcpTools parameter is consistently included in completions method signatures.
This commit is contained in:
parent
18b7618a8d
commit
a39ff78758
@ -19,11 +19,11 @@ export default class AiProvider {
|
||||
public async completions({
|
||||
messages,
|
||||
assistant,
|
||||
mcpTools,
|
||||
onChunk,
|
||||
onFilterMessages,
|
||||
mcpTools
|
||||
onFilterMessages
|
||||
}: CompletionsParams): Promise<void> {
|
||||
return this.sdk.completions({ messages, assistant, onChunk, onFilterMessages, mcpTools })
|
||||
return this.sdk.completions({ messages, assistant, mcpTools, onChunk, onFilterMessages })
|
||||
}
|
||||
|
||||
public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void): Promise<string> {
|
||||
|
||||
@ -54,6 +54,11 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
return this.provider.apiHost
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the message parameter
|
||||
* @param message - The message
|
||||
* @returns The message parameter
|
||||
*/
|
||||
private async getMessageParam(message: Message): Promise<MessageParam> {
|
||||
const parts: MessageParam['content'] = [
|
||||
{
|
||||
@ -74,6 +79,7 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
|
||||
parts.push({
|
||||
@ -89,18 +95,32 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the temperature
|
||||
* @param assistant - The assistant
|
||||
* @param model - The model
|
||||
* @returns The temperature
|
||||
*/
|
||||
private getTemperature(assistant: Assistant, model: Model) {
|
||||
if (isReasoningModel(model)) return undefined
|
||||
|
||||
return assistant?.settings?.temperature
|
||||
return isReasoningModel(model) ? undefined : assistant?.settings?.temperature
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the top P
|
||||
* @param assistant - The assistant
|
||||
* @param model - The model
|
||||
* @returns The top P
|
||||
*/
|
||||
private getTopP(assistant: Assistant, model: Model) {
|
||||
if (isReasoningModel(model)) return undefined
|
||||
|
||||
return assistant?.settings?.topP
|
||||
return isReasoningModel(model) ? undefined : assistant?.settings?.topP
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the reasoning effort
|
||||
* @param assistant - The assistant
|
||||
* @param model - The model
|
||||
* @returns The reasoning effort
|
||||
*/
|
||||
private getReasoningEffort(assistant: Assistant, model: Model): ReasoningConfig | undefined {
|
||||
if (!isReasoningModel(model)) {
|
||||
return undefined
|
||||
@ -134,7 +154,15 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
}
|
||||
}
|
||||
|
||||
public async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams) {
|
||||
/**
|
||||
* Generate completions
|
||||
* @param messages - The messages
|
||||
* @param assistant - The assistant
|
||||
* @param mcpTools - The MCP tools
|
||||
* @param onChunk - The onChunk callback
|
||||
* @param onFilterMessages - The onFilterMessages callback
|
||||
*/
|
||||
public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
||||
@ -192,6 +220,7 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
text = textBlock.text
|
||||
}
|
||||
}
|
||||
|
||||
return onChunk({
|
||||
text,
|
||||
reasoning_content,
|
||||
@ -271,6 +300,7 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
.on('finalMessage', async (message) => {
|
||||
if (toolCalls.length > 0) {
|
||||
const toolCallResults: ToolResultBlockParam[] = []
|
||||
|
||||
for (const toolCall of toolCalls) {
|
||||
const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
|
||||
if (mcpTool) {
|
||||
@ -338,6 +368,13 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
.finally(cleanup)
|
||||
}
|
||||
|
||||
/**
|
||||
* Translate a message
|
||||
* @param message - The message
|
||||
* @param assistant - The assistant
|
||||
* @param onResponse - The onResponse callback
|
||||
* @returns The translated message
|
||||
*/
|
||||
public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
@ -375,6 +412,12 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Summarize a message
|
||||
* @param messages - The messages
|
||||
* @param assistant - The assistant
|
||||
* @returns The summary
|
||||
*/
|
||||
public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
|
||||
const model = getTopNamingModel() || assistant.model || getDefaultModel()
|
||||
|
||||
@ -417,6 +460,12 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
return removeSpecialCharactersForTopicName(content)
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate text
|
||||
* @param prompt - The prompt
|
||||
* @param content - The content
|
||||
* @returns The generated text
|
||||
*/
|
||||
public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
|
||||
const model = getDefaultModel()
|
||||
|
||||
@ -436,14 +485,27 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
return message.content[0].type === 'text' ? message.content[0].text : ''
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate an image
|
||||
* @returns The generated image
|
||||
*/
|
||||
public async generateImage(): Promise<string[]> {
|
||||
return []
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate suggestions
|
||||
* @returns The suggestions
|
||||
*/
|
||||
public async suggestions(): Promise<Suggestion[]> {
|
||||
return []
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model is valid
|
||||
* @param model - The model
|
||||
* @returns The validity of the model
|
||||
*/
|
||||
public async check(model: Model): Promise<{ valid: boolean; error: Error | null }> {
|
||||
if (!model) {
|
||||
return { valid: false, error: new Error('No model found') }
|
||||
@ -470,6 +532,10 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the models
|
||||
* @returns The models
|
||||
*/
|
||||
public async models(): Promise<OpenAI.Models.Model[]> {
|
||||
return []
|
||||
}
|
||||
|
||||
@ -51,6 +51,11 @@ export default class GeminiProvider extends BaseProvider {
|
||||
return this.provider.apiHost
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle a PDF file
|
||||
* @param file - The file
|
||||
* @returns The part
|
||||
*/
|
||||
private async handlePdfFile(file: FileType): Promise<Part> {
|
||||
const smallFileSize = 20 * 1024 * 1024
|
||||
const isSmallFile = file.size < smallFileSize
|
||||
@ -88,6 +93,11 @@ export default class GeminiProvider extends BaseProvider {
|
||||
} as FileDataPart
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the message contents
|
||||
* @param message - The message
|
||||
* @returns The message contents
|
||||
*/
|
||||
private async getMessageContents(message: Message): Promise<Content> {
|
||||
const role = message.role === 'user' ? 'user' : 'model'
|
||||
|
||||
@ -123,6 +133,11 @@ export default class GeminiProvider extends BaseProvider {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the safety settings
|
||||
* @param modelId - The model ID
|
||||
* @returns The safety settings
|
||||
*/
|
||||
private getSafetySettings(modelId: string): SafetySetting[] {
|
||||
const safetyThreshold = modelId.includes('gemini-2.0-flash-exp')
|
||||
? ('OFF' as HarmBlockThreshold)
|
||||
@ -152,7 +167,15 @@ export default class GeminiProvider extends BaseProvider {
|
||||
]
|
||||
}
|
||||
|
||||
public async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams) {
|
||||
/**
|
||||
* Generate completions
|
||||
* @param messages - The messages
|
||||
* @param assistant - The assistant
|
||||
* @param mcpTools - The MCP tools
|
||||
* @param onChunk - The onChunk callback
|
||||
* @param onFilterMessages - The onFilterMessages callback
|
||||
*/
|
||||
public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
||||
@ -167,9 +190,11 @@ export default class GeminiProvider extends BaseProvider {
|
||||
for (const message of userMessages) {
|
||||
history.push(await this.getMessageContents(message))
|
||||
}
|
||||
|
||||
mcpTools = filterMCPTools(mcpTools, userLastMessage?.enabledMCPs)
|
||||
const tools = mcpToolsToGeminiTools(mcpTools)
|
||||
const toolResponses: MCPToolResponse[] = []
|
||||
|
||||
if (assistant.enableWebSearch && isWebSearchModel(model)) {
|
||||
tools.push({
|
||||
// @ts-ignore googleSearch is not a valid tool for Gemini
|
||||
@ -199,6 +224,7 @@ export default class GeminiProvider extends BaseProvider {
|
||||
const start_time_millsec = new Date().getTime()
|
||||
const { abortController, cleanup } = this.createAbortController(userLastMessage?.id)
|
||||
const { signal } = abortController
|
||||
|
||||
if (!streamOutput) {
|
||||
const { response } = await chat.sendMessage(messageContents.parts, { signal })
|
||||
const time_completion_millsec = new Date().getTime() - start_time_millsec
|
||||
@ -221,15 +247,19 @@ export default class GeminiProvider extends BaseProvider {
|
||||
|
||||
const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal })
|
||||
let time_first_token_millsec = 0
|
||||
|
||||
const processStream = async (stream: GenerateContentStreamResult, idx: number) => {
|
||||
for await (const chunk of stream.stream) {
|
||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
|
||||
|
||||
if (time_first_token_millsec == 0) {
|
||||
time_first_token_millsec = new Date().getTime() - start_time_millsec
|
||||
}
|
||||
|
||||
const time_completion_millsec = new Date().getTime() - start_time_millsec
|
||||
|
||||
const functionCalls = chunk.functionCalls()
|
||||
|
||||
if (functionCalls) {
|
||||
const fcallParts: FunctionCallPart[] = []
|
||||
const fcRespParts: FunctionResponsePart[] = []
|
||||
@ -266,6 +296,7 @@ export default class GeminiProvider extends BaseProvider {
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if (fcRespParts) {
|
||||
history.push(messageContents)
|
||||
history.push({
|
||||
@ -295,9 +326,17 @@ export default class GeminiProvider extends BaseProvider {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
await processStream(userMessagesStream, 0).finally(cleanup)
|
||||
}
|
||||
|
||||
/**
|
||||
* Translate a message
|
||||
* @param message - The message
|
||||
* @param assistant - The assistant
|
||||
* @param onResponse - The onResponse callback
|
||||
* @returns The translated message
|
||||
*/
|
||||
async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
|
||||
const defaultModel = getDefaultModel()
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
@ -332,6 +371,12 @@ export default class GeminiProvider extends BaseProvider {
|
||||
return text
|
||||
}
|
||||
|
||||
/**
|
||||
* Summarize a message
|
||||
* @param messages - The messages
|
||||
* @param assistant - The assistant
|
||||
* @returns The summary
|
||||
*/
|
||||
public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
|
||||
const model = getTopNamingModel() || assistant.model || getDefaultModel()
|
||||
|
||||
@ -375,6 +420,12 @@ export default class GeminiProvider extends BaseProvider {
|
||||
return removeSpecialCharactersForTopicName(response.text())
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate text
|
||||
* @param prompt - The prompt
|
||||
* @param content - The content
|
||||
* @returns The generated text
|
||||
*/
|
||||
public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
|
||||
const model = getDefaultModel()
|
||||
const systemMessage = { role: 'system', content: prompt }
|
||||
@ -387,14 +438,27 @@ export default class GeminiProvider extends BaseProvider {
|
||||
return response.text()
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate suggestions
|
||||
* @returns The suggestions
|
||||
*/
|
||||
public async suggestions(): Promise<Suggestion[]> {
|
||||
return []
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate an image
|
||||
* @returns The generated image
|
||||
*/
|
||||
public async generateImage(): Promise<string[]> {
|
||||
return []
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model is valid
|
||||
* @param model - The model
|
||||
* @returns The validity of the model
|
||||
*/
|
||||
public async check(model: Model): Promise<{ valid: boolean; error: Error | null }> {
|
||||
if (!model) {
|
||||
return { valid: false, error: new Error('No model found') }
|
||||
@ -422,12 +486,17 @@ export default class GeminiProvider extends BaseProvider {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the models
|
||||
* @returns The models
|
||||
*/
|
||||
public async models(): Promise<OpenAI.Models.Model[]> {
|
||||
try {
|
||||
const api = this.provider.apiHost + '/v1beta/models'
|
||||
const { data } = await axios.get(api, { params: { key: this.apiKey } })
|
||||
|
||||
return data.models.map(
|
||||
(m: any) =>
|
||||
(m) =>
|
||||
({
|
||||
id: m.name.replace('models/', ''),
|
||||
name: m.displayName,
|
||||
@ -442,6 +511,11 @@ export default class GeminiProvider extends BaseProvider {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the embedding dimensions
|
||||
* @param model - The model
|
||||
* @returns The embedding dimensions
|
||||
*/
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
const data = await this.sdk.getGenerativeModel({ model: model.id }, this.requestOptions).embedContent('hi')
|
||||
return data.embedding.values.length
|
||||
|
||||
@ -288,12 +288,12 @@ export default class OpenAIProvider extends BaseProvider {
|
||||
* Generate completions for the assistant
|
||||
* @param messages - The messages
|
||||
* @param assistant - The assistant
|
||||
* @param mcpTools - The MCP tools
|
||||
* @param onChunk - The onChunk callback
|
||||
* @param onFilterMessages - The onFilterMessages callback
|
||||
* @param mcpTools - The MCP tools
|
||||
* @returns The completions
|
||||
*/
|
||||
async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams): Promise<void> {
|
||||
async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user