diff --git a/src/renderer/src/providers/AiProvider.ts b/src/renderer/src/providers/AiProvider.ts index ee4d13a6..02425a13 100644 --- a/src/renderer/src/providers/AiProvider.ts +++ b/src/renderer/src/providers/AiProvider.ts @@ -19,11 +19,11 @@ export default class AiProvider { public async completions({ messages, assistant, + mcpTools, onChunk, - onFilterMessages, - mcpTools + onFilterMessages }: CompletionsParams): Promise { - return this.sdk.completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }) + return this.sdk.completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }) } public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void): Promise { diff --git a/src/renderer/src/providers/AnthropicProvider.ts b/src/renderer/src/providers/AnthropicProvider.ts index f2ffc64c..845a2cd9 100644 --- a/src/renderer/src/providers/AnthropicProvider.ts +++ b/src/renderer/src/providers/AnthropicProvider.ts @@ -54,6 +54,11 @@ export default class AnthropicProvider extends BaseProvider { return this.provider.apiHost } + /** + * Get the message parameter + * @param message - The message + * @returns The message parameter + */ private async getMessageParam(message: Message): Promise { const parts: MessageParam['content'] = [ { @@ -74,6 +79,7 @@ export default class AnthropicProvider extends BaseProvider { } }) } + if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) { const fileContent = await (await window.api.file.read(file.id + file.ext)).trim() parts.push({ @@ -89,18 +95,32 @@ export default class AnthropicProvider extends BaseProvider { } } + /** + * Get the temperature + * @param assistant - The assistant + * @param model - The model + * @returns The temperature + */ private getTemperature(assistant: Assistant, model: Model) { - if (isReasoningModel(model)) return undefined - - return assistant?.settings?.temperature + return isReasoningModel(model) ? undefined : assistant?.settings?.temperature } + /** + * Get the top P + * @param assistant - The assistant + * @param model - The model + * @returns The top P + */ private getTopP(assistant: Assistant, model: Model) { - if (isReasoningModel(model)) return undefined - - return assistant?.settings?.topP + return isReasoningModel(model) ? undefined : assistant?.settings?.topP } + /** + * Get the reasoning effort + * @param assistant - The assistant + * @param model - The model + * @returns The reasoning effort + */ private getReasoningEffort(assistant: Assistant, model: Model): ReasoningConfig | undefined { if (!isReasoningModel(model)) { return undefined @@ -134,7 +154,15 @@ export default class AnthropicProvider extends BaseProvider { } } - public async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams) { + /** + * Generate completions + * @param messages - The messages + * @param assistant - The assistant + * @param mcpTools - The MCP tools + * @param onChunk - The onChunk callback + * @param onFilterMessages - The onFilterMessages callback + */ + public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) { const defaultModel = getDefaultModel() const model = assistant.model || defaultModel const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) @@ -192,6 +220,7 @@ export default class AnthropicProvider extends BaseProvider { text = textBlock.text } } + return onChunk({ text, reasoning_content, @@ -271,6 +300,7 @@ export default class AnthropicProvider extends BaseProvider { .on('finalMessage', async (message) => { if (toolCalls.length > 0) { const toolCallResults: ToolResultBlockParam[] = [] + for (const toolCall of toolCalls) { const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall) if (mcpTool) { @@ -338,6 +368,13 @@ export default class AnthropicProvider extends BaseProvider { .finally(cleanup) } + /** + * Translate a message + * @param message - The message + * @param assistant - The assistant + * @param onResponse - The onResponse callback + * @returns The translated message + */ public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) { const defaultModel = getDefaultModel() const model = assistant.model || defaultModel @@ -375,6 +412,12 @@ export default class AnthropicProvider extends BaseProvider { }) } + /** + * Summarize a message + * @param messages - The messages + * @param assistant - The assistant + * @returns The summary + */ public async summaries(messages: Message[], assistant: Assistant): Promise { const model = getTopNamingModel() || assistant.model || getDefaultModel() @@ -417,6 +460,12 @@ export default class AnthropicProvider extends BaseProvider { return removeSpecialCharactersForTopicName(content) } + /** + * Generate text + * @param prompt - The prompt + * @param content - The content + * @returns The generated text + */ public async generateText({ prompt, content }: { prompt: string; content: string }): Promise { const model = getDefaultModel() @@ -436,14 +485,27 @@ export default class AnthropicProvider extends BaseProvider { return message.content[0].type === 'text' ? message.content[0].text : '' } + /** + * Generate an image + * @returns The generated image + */ public async generateImage(): Promise { return [] } + /** + * Generate suggestions + * @returns The suggestions + */ public async suggestions(): Promise { return [] } + /** + * Check if the model is valid + * @param model - The model + * @returns The validity of the model + */ public async check(model: Model): Promise<{ valid: boolean; error: Error | null }> { if (!model) { return { valid: false, error: new Error('No model found') } @@ -470,6 +532,10 @@ export default class AnthropicProvider extends BaseProvider { } } + /** + * Get the models + * @returns The models + */ public async models(): Promise { return [] } diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index 9c3ec59e..96dcf8b3 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -51,6 +51,11 @@ export default class GeminiProvider extends BaseProvider { return this.provider.apiHost } + /** + * Handle a PDF file + * @param file - The file + * @returns The part + */ private async handlePdfFile(file: FileType): Promise { const smallFileSize = 20 * 1024 * 1024 const isSmallFile = file.size < smallFileSize @@ -88,6 +93,11 @@ export default class GeminiProvider extends BaseProvider { } as FileDataPart } + /** + * Get the message contents + * @param message - The message + * @returns The message contents + */ private async getMessageContents(message: Message): Promise { const role = message.role === 'user' ? 'user' : 'model' @@ -123,6 +133,11 @@ export default class GeminiProvider extends BaseProvider { } } + /** + * Get the safety settings + * @param modelId - The model ID + * @returns The safety settings + */ private getSafetySettings(modelId: string): SafetySetting[] { const safetyThreshold = modelId.includes('gemini-2.0-flash-exp') ? ('OFF' as HarmBlockThreshold) @@ -152,7 +167,15 @@ export default class GeminiProvider extends BaseProvider { ] } - public async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams) { + /** + * Generate completions + * @param messages - The messages + * @param assistant - The assistant + * @param mcpTools - The MCP tools + * @param onChunk - The onChunk callback + * @param onFilterMessages - The onFilterMessages callback + */ + public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) { const defaultModel = getDefaultModel() const model = assistant.model || defaultModel const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) @@ -167,9 +190,11 @@ export default class GeminiProvider extends BaseProvider { for (const message of userMessages) { history.push(await this.getMessageContents(message)) } + mcpTools = filterMCPTools(mcpTools, userLastMessage?.enabledMCPs) const tools = mcpToolsToGeminiTools(mcpTools) const toolResponses: MCPToolResponse[] = [] + if (assistant.enableWebSearch && isWebSearchModel(model)) { tools.push({ // @ts-ignore googleSearch is not a valid tool for Gemini @@ -199,6 +224,7 @@ export default class GeminiProvider extends BaseProvider { const start_time_millsec = new Date().getTime() const { abortController, cleanup } = this.createAbortController(userLastMessage?.id) const { signal } = abortController + if (!streamOutput) { const { response } = await chat.sendMessage(messageContents.parts, { signal }) const time_completion_millsec = new Date().getTime() - start_time_millsec @@ -221,15 +247,19 @@ export default class GeminiProvider extends BaseProvider { const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal }) let time_first_token_millsec = 0 + const processStream = async (stream: GenerateContentStreamResult, idx: number) => { for await (const chunk of stream.stream) { if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break + if (time_first_token_millsec == 0) { time_first_token_millsec = new Date().getTime() - start_time_millsec } + const time_completion_millsec = new Date().getTime() - start_time_millsec const functionCalls = chunk.functionCalls() + if (functionCalls) { const fcallParts: FunctionCallPart[] = [] const fcRespParts: FunctionResponsePart[] = [] @@ -266,6 +296,7 @@ export default class GeminiProvider extends BaseProvider { ) } } + if (fcRespParts) { history.push(messageContents) history.push({ @@ -295,9 +326,17 @@ export default class GeminiProvider extends BaseProvider { }) } } + await processStream(userMessagesStream, 0).finally(cleanup) } + /** + * Translate a message + * @param message - The message + * @param assistant - The assistant + * @param onResponse - The onResponse callback + * @returns The translated message + */ async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) { const defaultModel = getDefaultModel() const { maxTokens } = getAssistantSettings(assistant) @@ -332,6 +371,12 @@ export default class GeminiProvider extends BaseProvider { return text } + /** + * Summarize a message + * @param messages - The messages + * @param assistant - The assistant + * @returns The summary + */ public async summaries(messages: Message[], assistant: Assistant): Promise { const model = getTopNamingModel() || assistant.model || getDefaultModel() @@ -375,6 +420,12 @@ export default class GeminiProvider extends BaseProvider { return removeSpecialCharactersForTopicName(response.text()) } + /** + * Generate text + * @param prompt - The prompt + * @param content - The content + * @returns The generated text + */ public async generateText({ prompt, content }: { prompt: string; content: string }): Promise { const model = getDefaultModel() const systemMessage = { role: 'system', content: prompt } @@ -387,14 +438,27 @@ export default class GeminiProvider extends BaseProvider { return response.text() } + /** + * Generate suggestions + * @returns The suggestions + */ public async suggestions(): Promise { return [] } + /** + * Generate an image + * @returns The generated image + */ public async generateImage(): Promise { return [] } + /** + * Check if the model is valid + * @param model - The model + * @returns The validity of the model + */ public async check(model: Model): Promise<{ valid: boolean; error: Error | null }> { if (!model) { return { valid: false, error: new Error('No model found') } @@ -422,12 +486,17 @@ export default class GeminiProvider extends BaseProvider { } } + /** + * Get the models + * @returns The models + */ public async models(): Promise { try { const api = this.provider.apiHost + '/v1beta/models' const { data } = await axios.get(api, { params: { key: this.apiKey } }) + return data.models.map( - (m: any) => + (m) => ({ id: m.name.replace('models/', ''), name: m.displayName, @@ -442,6 +511,11 @@ export default class GeminiProvider extends BaseProvider { } } + /** + * Get the embedding dimensions + * @param model - The model + * @returns The embedding dimensions + */ public async getEmbeddingDimensions(model: Model): Promise { const data = await this.sdk.getGenerativeModel({ model: model.id }, this.requestOptions).embedContent('hi') return data.embedding.values.length diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts index 00129350..7bd86362 100644 --- a/src/renderer/src/providers/OpenAIProvider.ts +++ b/src/renderer/src/providers/OpenAIProvider.ts @@ -288,12 +288,12 @@ export default class OpenAIProvider extends BaseProvider { * Generate completions for the assistant * @param messages - The messages * @param assistant - The assistant + * @param mcpTools - The MCP tools * @param onChunk - The onChunk callback * @param onFilterMessages - The onFilterMessages callback - * @param mcpTools - The MCP tools * @returns The completions */ - async completions({ messages, assistant, onChunk, onFilterMessages, mcpTools }: CompletionsParams): Promise { + async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams): Promise { const defaultModel = getDefaultModel() const model = assistant.model || defaultModel const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)