diff --git a/src/renderer/src/providers/AnthropicProvider.ts b/src/renderer/src/providers/AnthropicProvider.ts index 4cb7ea60..676b6b26 100644 --- a/src/renderer/src/providers/AnthropicProvider.ts +++ b/src/renderer/src/providers/AnthropicProvider.ts @@ -1,12 +1,13 @@ import Anthropic from '@anthropic-ai/sdk' import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' +import { isEmbeddingModel } from '@renderer/config/models' import { SUMMARIZE_PROMPT } from '@renderer/config/prompts' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { EVENT_NAMES } from '@renderer/services/EventService' import { filterContextMessages } from '@renderer/services/MessagesService' import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types' -import { first, flatten, sum, takeRight } from 'lodash' +import { first, flatten, last, sum, takeRight } from 'lodash' import OpenAI from 'openai' import { CompletionsParams } from '.' @@ -234,7 +235,11 @@ export default class AnthropicProvider extends BaseProvider { } public async check(): Promise<{ valid: boolean; error: Error | null }> { - const model = this.provider.models[0] + const model = last(this.provider.models.filter((m) => !isEmbeddingModel(m))) + + if (!model) { + return { valid: false, error: new Error('No model found') } + } const body = { model: model.id, diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index f030efdf..7beb7086 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -8,13 +8,14 @@ import { RequestOptions, TextPart } from '@google/generative-ai' +import { isEmbeddingModel } from '@renderer/config/models' import { SUMMARIZE_PROMPT } from '@renderer/config/prompts' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { EVENT_NAMES } from '@renderer/services/EventService' import { filterContextMessages } from '@renderer/services/MessagesService' import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types' import axios from 'axios' -import { first, isEmpty, takeRight } from 'lodash' +import { first, isEmpty, last, takeRight } from 'lodash' import OpenAI from 'openai' import { CompletionsParams } from '.' @@ -240,7 +241,11 @@ export default class GeminiProvider extends BaseProvider { } public async check(): Promise<{ valid: boolean; error: Error | null }> { - const model = this.provider.models[0] + const model = last(this.provider.models.filter((m) => !isEmbeddingModel(m))) + + if (!model) { + return { valid: false, error: new Error('No model found') } + } const body = { model: model.id, diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts index f5ceed3b..c57e0d2e 100644 --- a/src/renderer/src/providers/OpenAIProvider.ts +++ b/src/renderer/src/providers/OpenAIProvider.ts @@ -1,11 +1,11 @@ -import { isSupportedModel, isVisionModel } from '@renderer/config/models' +import { isEmbeddingModel, isSupportedModel, isVisionModel } from '@renderer/config/models' import { SUMMARIZE_PROMPT } from '@renderer/config/prompts' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { EVENT_NAMES } from '@renderer/services/EventService' import { filterContextMessages } from '@renderer/services/MessagesService' import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types' import { removeQuotes } from '@renderer/utils' -import { takeRight } from 'lodash' +import { last, takeRight } from 'lodash' import OpenAI, { AzureOpenAI } from 'openai' import { ChatCompletionContentPart, @@ -277,7 +277,11 @@ export default class OpenAIProvider extends BaseProvider { } public async check(): Promise<{ valid: boolean; error: Error | null }> { - const model = this.provider.models[0] + const model = last(this.provider.models.filter((m) => !isEmbeddingModel(m))) + + if (!model) { + return { valid: false, error: new Error('No model found') } + } const body = { model: model.id, diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index 3c24c003..56f08493 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -185,7 +185,6 @@ export async function fetchSuggestions({ } export async function checkApi(provider: Provider) { - const model = provider.models[0] const key = 'api-check' const style = { marginTop: '3vh' } @@ -201,7 +200,7 @@ export async function checkApi(provider: Provider) { return false } - if (!model) { + if (isEmpty(provider.models)) { window.message.error({ content: i18n.t('message.error.enter.model'), key, style }) return false }