fix: check provider connection use the last model

This commit is contained in:
kangfenmao 2024-12-24 09:33:35 +08:00
parent d558572d97
commit 37477587b6
4 changed files with 22 additions and 9 deletions

View File

@ -1,12 +1,13 @@
import Anthropic from '@anthropic-ai/sdk' import Anthropic from '@anthropic-ai/sdk'
import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources' import { MessageCreateParamsNonStreaming, MessageParam } from '@anthropic-ai/sdk/resources'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { isEmbeddingModel } from '@renderer/config/models'
import { SUMMARIZE_PROMPT } from '@renderer/config/prompts' import { SUMMARIZE_PROMPT } from '@renderer/config/prompts'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import { EVENT_NAMES } from '@renderer/services/EventService' import { EVENT_NAMES } from '@renderer/services/EventService'
import { filterContextMessages } from '@renderer/services/MessagesService' import { filterContextMessages } from '@renderer/services/MessagesService'
import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types' import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types'
import { first, flatten, sum, takeRight } from 'lodash' import { first, flatten, last, sum, takeRight } from 'lodash'
import OpenAI from 'openai' import OpenAI from 'openai'
import { CompletionsParams } from '.' import { CompletionsParams } from '.'
@ -234,7 +235,11 @@ export default class AnthropicProvider extends BaseProvider {
} }
public async check(): Promise<{ valid: boolean; error: Error | null }> { public async check(): Promise<{ valid: boolean; error: Error | null }> {
const model = this.provider.models[0] const model = last(this.provider.models.filter((m) => !isEmbeddingModel(m)))
if (!model) {
return { valid: false, error: new Error('No model found') }
}
const body = { const body = {
model: model.id, model: model.id,

View File

@ -8,13 +8,14 @@ import {
RequestOptions, RequestOptions,
TextPart TextPart
} from '@google/generative-ai' } from '@google/generative-ai'
import { isEmbeddingModel } from '@renderer/config/models'
import { SUMMARIZE_PROMPT } from '@renderer/config/prompts' import { SUMMARIZE_PROMPT } from '@renderer/config/prompts'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import { EVENT_NAMES } from '@renderer/services/EventService' import { EVENT_NAMES } from '@renderer/services/EventService'
import { filterContextMessages } from '@renderer/services/MessagesService' import { filterContextMessages } from '@renderer/services/MessagesService'
import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types' import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types'
import axios from 'axios' import axios from 'axios'
import { first, isEmpty, takeRight } from 'lodash' import { first, isEmpty, last, takeRight } from 'lodash'
import OpenAI from 'openai' import OpenAI from 'openai'
import { CompletionsParams } from '.' import { CompletionsParams } from '.'
@ -240,7 +241,11 @@ export default class GeminiProvider extends BaseProvider {
} }
public async check(): Promise<{ valid: boolean; error: Error | null }> { public async check(): Promise<{ valid: boolean; error: Error | null }> {
const model = this.provider.models[0] const model = last(this.provider.models.filter((m) => !isEmbeddingModel(m)))
if (!model) {
return { valid: false, error: new Error('No model found') }
}
const body = { const body = {
model: model.id, model: model.id,

View File

@ -1,11 +1,11 @@
import { isSupportedModel, isVisionModel } from '@renderer/config/models' import { isEmbeddingModel, isSupportedModel, isVisionModel } from '@renderer/config/models'
import { SUMMARIZE_PROMPT } from '@renderer/config/prompts' import { SUMMARIZE_PROMPT } from '@renderer/config/prompts'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import { EVENT_NAMES } from '@renderer/services/EventService' import { EVENT_NAMES } from '@renderer/services/EventService'
import { filterContextMessages } from '@renderer/services/MessagesService' import { filterContextMessages } from '@renderer/services/MessagesService'
import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types' import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
import { removeQuotes } from '@renderer/utils' import { removeQuotes } from '@renderer/utils'
import { takeRight } from 'lodash' import { last, takeRight } from 'lodash'
import OpenAI, { AzureOpenAI } from 'openai' import OpenAI, { AzureOpenAI } from 'openai'
import { import {
ChatCompletionContentPart, ChatCompletionContentPart,
@ -277,7 +277,11 @@ export default class OpenAIProvider extends BaseProvider {
} }
public async check(): Promise<{ valid: boolean; error: Error | null }> { public async check(): Promise<{ valid: boolean; error: Error | null }> {
const model = this.provider.models[0] const model = last(this.provider.models.filter((m) => !isEmbeddingModel(m)))
if (!model) {
return { valid: false, error: new Error('No model found') }
}
const body = { const body = {
model: model.id, model: model.id,

View File

@ -185,7 +185,6 @@ export async function fetchSuggestions({
} }
export async function checkApi(provider: Provider) { export async function checkApi(provider: Provider) {
const model = provider.models[0]
const key = 'api-check' const key = 'api-check'
const style = { marginTop: '3vh' } const style = { marginTop: '3vh' }
@ -201,7 +200,7 @@ export async function checkApi(provider: Provider) {
return false return false
} }
if (!model) { if (isEmpty(provider.models)) {
window.message.error({ content: i18n.t('message.error.enter.model'), key, style }) window.message.error({ content: i18n.t('message.error.enter.model'), key, style })
return false return false
} }