diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index c4561bab..bd30dbdc 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -94,10 +94,12 @@ import WenxinModelLogoDark from '@renderer/assets/images/models/wenxin_dark.png' import YiModelLogo from '@renderer/assets/images/models/yi.png' import YiModelLogoDark from '@renderer/assets/images/models/yi_dark.png' import { Model } from '@renderer/types' +import OpenAI from 'openai' const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-turbo|dall|cogview/i const VISION_REGEX = /llava|moondream|minicpm|gemini-1.5|claude-3|vision|glm-4v|gpt-4|qwen-vl/i -const EMBEDDING_REGEX = /embedding/i +const EMBEDDING_REGEX = /embed|rerank/i +const NOT_SUPPORTED_REGEX = /embed|tts|rerank|whisper|speech/i export function getModelLogo(modelId: string) { const isLight = true @@ -665,3 +667,7 @@ export function isEmbeddingModel(model: Model): boolean { export function isVisionModel(model: Model): boolean { return VISION_REGEX.test(model.id) } + +export function isSupportedModel(model: OpenAI.Models.Model): boolean { + return !NOT_SUPPORTED_REGEX.test(model.id) +} diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts index 6cce6d90..8bfd947c 100644 --- a/src/renderer/src/providers/OpenAIProvider.ts +++ b/src/renderer/src/providers/OpenAIProvider.ts @@ -1,5 +1,5 @@ import { isLocalAi } from '@renderer/config/env' -import { isVisionModel } from '@renderer/config/models' +import { isSupportedModel, isVisionModel } from '@renderer/config/models' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant' import { EVENT_NAMES } from '@renderer/services/event' import { filterContextMessages } from '@renderer/services/messages' @@ -263,18 +263,29 @@ export default class OpenAIProvider extends BaseProvider { public async models(): Promise { try { - if (this.provider.id === 'github') { - // @ts-ignore key is not typed - return response.body.map((model) => ({ - id: model.name, - description: model.summary, - object: 'model', - owned_by: model.publisher - })) + const query: Record = {} + + if (this.provider.id === 'silicon') { + query.type = 'text' } - const response = await this.sdk.models.list() - return response.data + const response = await this.sdk.models.list({ query }) + + if (this.provider.id === 'github') { + // @ts-ignore key is not typed + return response.body + .map((model) => ({ + id: model.name, + description: model.summary, + object: 'model', + owned_by: model.publisher + })) + .filter(isSupportedModel) + } + + const models = response?.data || [] + + return models.filter(isSupportedModel) } catch (error) { return [] }