From b50f8a4c118ff313cc45c02f69e9bae97cd7ae3d Mon Sep 17 00:00:00 2001 From: eeee0717 Date: Tue, 18 Mar 2025 17:20:01 +0800 Subject: [PATCH] feat(knowledge base): enhance knowledge base with rerank model --- src/main/ipc.ts | 1 + src/main/reranker/BaseReranker.ts | 20 +++++++ src/main/reranker/DefaultReranker.ts | 13 +++++ src/main/reranker/Reranker.ts | 15 ++++++ src/main/reranker/RerankerFactory.ts | 14 +++++ src/main/reranker/SiliconFlowReranker.ts | 46 ++++++++++++++++ src/main/services/KnowledgeService.ts | 8 +++ src/preload/index.d.ts | 9 ++++ src/preload/index.ts | 5 +- src/renderer/src/config/models.ts | 11 ++++ src/renderer/src/i18n/locales/en-us.json | 12 +++-- src/renderer/src/i18n/locales/ja-jp.json | 12 +++-- src/renderer/src/i18n/locales/ru-ru.json | 12 +++-- src/renderer/src/i18n/locales/zh-cn.json | 8 ++- src/renderer/src/i18n/locales/zh-tw.json | 12 +++-- .../components/AddKnowledgePopup.tsx | 32 ++++++++++- .../components/KnowledgeSearchPopup.tsx | 10 +++- .../components/KnowledgeSettingsPopup.tsx | 54 ++++++++++++++++++- src/renderer/src/services/KnowledgeService.ts | 20 +++++-- src/renderer/src/types/index.ts | 5 ++ 20 files changed, 297 insertions(+), 22 deletions(-) create mode 100644 src/main/reranker/BaseReranker.ts create mode 100644 src/main/reranker/DefaultReranker.ts create mode 100644 src/main/reranker/Reranker.ts create mode 100644 src/main/reranker/RerankerFactory.ts create mode 100644 src/main/reranker/SiliconFlowReranker.ts diff --git a/src/main/ipc.ts b/src/main/ipc.ts index de5b9244..196484f5 100644 --- a/src/main/ipc.ts +++ b/src/main/ipc.ts @@ -189,6 +189,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle('knowledge-base:add', KnowledgeService.add) ipcMain.handle('knowledge-base:remove', KnowledgeService.remove) ipcMain.handle('knowledge-base:search', KnowledgeService.search) + ipcMain.handle('knowledge-base:rerank', KnowledgeService.rerank) // window ipcMain.handle('window:set-minimum-size', (_, width: number, height: number) => { diff --git a/src/main/reranker/BaseReranker.ts b/src/main/reranker/BaseReranker.ts new file mode 100644 index 00000000..469ed2e6 --- /dev/null +++ b/src/main/reranker/BaseReranker.ts @@ -0,0 +1,20 @@ +import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces' +import { KnowledgeBaseParams } from '@types' + +export default abstract class BaseReranker { + protected base: KnowledgeBaseParams + constructor(base: KnowledgeBaseParams) { + if (!base.rerankModel) { + throw new Error('Rerank model is required') + } + this.base = base + } + abstract rerank(query: string, searchResults: ExtractChunkData[]): Promise + + public defaultHeaders() { + return { + Authorization: `Bearer ${this.base.apiKey}`, + 'Content-Type': 'application/json' + } + } +} diff --git a/src/main/reranker/DefaultReranker.ts b/src/main/reranker/DefaultReranker.ts new file mode 100644 index 00000000..0bb07456 --- /dev/null +++ b/src/main/reranker/DefaultReranker.ts @@ -0,0 +1,13 @@ +import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces' +import { KnowledgeBaseParams } from '@types' + +import BaseReranker from './BaseReranker' + +export default class DefaultReranker extends BaseReranker { + constructor(base: KnowledgeBaseParams) { + super(base) + } + async rerank(): Promise { + throw new Error('Method not implemented.') + } +} diff --git a/src/main/reranker/Reranker.ts b/src/main/reranker/Reranker.ts new file mode 100644 index 00000000..c07b1567 --- /dev/null +++ b/src/main/reranker/Reranker.ts @@ -0,0 +1,15 @@ +import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces' +import { KnowledgeBaseParams } from '@types' + +import BaseReranker from './BaseReranker' +import RerankerFactory from './RerankerFactory' + +export default class Reranker { + private sdk: BaseReranker + constructor(base: KnowledgeBaseParams) { + this.sdk = RerankerFactory.create(base) + } + public async rerank(query: string, searchResults: ExtractChunkData[]): Promise { + return this.sdk.rerank(query, searchResults) + } +} diff --git a/src/main/reranker/RerankerFactory.ts b/src/main/reranker/RerankerFactory.ts new file mode 100644 index 00000000..2c15fe6c --- /dev/null +++ b/src/main/reranker/RerankerFactory.ts @@ -0,0 +1,14 @@ +import { KnowledgeBaseParams } from '@types' + +import BaseReranker from './BaseReranker' +import DefaultReranker from './DefaultReranker' +import SiliconFlowReranker from './SiliconFlowReranker' + +export default class RerankerFactory { + static create(base: KnowledgeBaseParams): BaseReranker { + if (base.rerankModelProvider === 'silicon') { + return new SiliconFlowReranker(base) + } + return new DefaultReranker(base) + } +} diff --git a/src/main/reranker/SiliconFlowReranker.ts b/src/main/reranker/SiliconFlowReranker.ts new file mode 100644 index 00000000..8fa3de35 --- /dev/null +++ b/src/main/reranker/SiliconFlowReranker.ts @@ -0,0 +1,46 @@ +import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces' +import { KnowledgeBaseParams } from '@types' +import axios from 'axios' + +import BaseReranker from './BaseReranker' + +export default class SiliconFlowReranker extends BaseReranker { + constructor(base: KnowledgeBaseParams) { + super(base) + } + + public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise => { + const url = `${this.base.baseURL}/rerank` + + const { data } = await axios.post( + url, + { + model: this.base.rerankModel, + query, + documents: searchResults.map((doc) => doc.pageContent), + top_n: this.base.topN, + max_chunks_per_doc: this.base.chunkSize, + overlap_tokens: this.base.chunkOverlap + }, + { + headers: this.defaultHeaders() + } + ) + + const rerankResults = data.results + const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0])) + + return searchResults + .map((doc: ExtractChunkData, index: number) => { + const score = resultMap.get(index) + if (score === undefined) return undefined + + return { + ...doc, + score + } + }) + .filter((doc): doc is ExtractChunkData => doc !== undefined) + .sort((a, b) => b.score - a.score) + } +} diff --git a/src/main/services/KnowledgeService.ts b/src/main/services/KnowledgeService.ts index 6e5d53ad..4a65516e 100644 --- a/src/main/services/KnowledgeService.ts +++ b/src/main/services/KnowledgeService.ts @@ -23,6 +23,7 @@ import { SitemapLoader } from '@llm-tools/embedjs-loader-sitemap' import { WebLoader } from '@llm-tools/embedjs-loader-web' import { AzureOpenAiEmbeddings, OpenAiEmbeddings } from '@llm-tools/embedjs-openai' import { addFileLoader } from '@main/loader' +import Reranker from '@main/reranker/Reranker' import { proxyManager } from '@main/services/ProxyManager' import { windowService } from '@main/services/WindowService' import { getInstanceName } from '@main/utils' @@ -482,6 +483,13 @@ class KnowledgeService { const ragApplication = await this.getRagApplication(base) return await ragApplication.search(search) } + + public rerank = async ( + _: Electron.IpcMainInvokeEvent, + { search, base, results }: { search: string; base: KnowledgeBaseParams; results: ExtractChunkData[] } + ): Promise => { + return await new Reranker(base).rerank(search, results) + } } export default new KnowledgeService() diff --git a/src/preload/index.d.ts b/src/preload/index.d.ts index 2de7c79f..28213713 100644 --- a/src/preload/index.d.ts +++ b/src/preload/index.d.ts @@ -90,6 +90,15 @@ declare global { base: KnowledgeBaseParams }) => Promise search: ({ search, base }: { search: string; base: KnowledgeBaseParams }) => Promise + rerank: ({ + search, + base, + results + }: { + search: string + base: KnowledgeBaseParams + results: ExtractChunkData[] + }) => Promise } window: { setMinimumSize: (width: number, height: number) => Promise diff --git a/src/preload/index.ts b/src/preload/index.ts index d4e4e23e..3ef2dcf7 100644 --- a/src/preload/index.ts +++ b/src/preload/index.ts @@ -1,4 +1,5 @@ import { electronAPI } from '@electron-toolkit/preload' +import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces' import { FileType, KnowledgeBaseParams, KnowledgeItem, MCPServer, Shortcut, WebDavConfig } from '@types' import { contextBridge, ipcRenderer, OpenDialogOptions, shell } from 'electron' @@ -75,7 +76,9 @@ const api = { remove: ({ uniqueId, uniqueIds, base }: { uniqueId: string; uniqueIds: string[]; base: KnowledgeBaseParams }) => ipcRenderer.invoke('knowledge-base:remove', { uniqueId, uniqueIds, base }), search: ({ search, base }: { search: string; base: KnowledgeBaseParams }) => - ipcRenderer.invoke('knowledge-base:search', { search, base }) + ipcRenderer.invoke('knowledge-base:search', { search, base }), + rerank: ({ search, base, results }: { search: string; base: KnowledgeBaseParams; results: ExtractChunkData[] }) => + ipcRenderer.invoke('knowledge-base:rerank', { search, base, results }) }, window: { setMinimumSize: (width: number, height: number) => ipcRenderer.invoke('window:set-minimum-size', width, height), diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index 3c128b45..cda55213 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -176,6 +176,10 @@ export const REASONING_REGEX = // Embedding models export const EMBEDDING_REGEX = /(?:^text-|embed|bge-|e5-|LLM2Vec|retrieval|uae-|gte-|jina-clip|jina-embeddings)/i + +// Rerank models +export const RERANKING_REGEX = /(?:rerank|re-rank|re-ranker|re-ranking|retrieval|retriever)/i + export const NOT_SUPPORTED_REGEX = /(?:^tts|rerank|whisper|speech)/i // Tool calling models @@ -1880,6 +1884,13 @@ export function isEmbeddingModel(model: Model): boolean { return EMBEDDING_REGEX.test(model.id) || model.type?.includes('embedding') || false } +export function isRerankModel(model: Model): boolean { + if (!model) { + return false + } + return RERANKING_REGEX.test(model.id) || false +} + export function isVisionModel(model: Model): boolean { if (!model) { return false diff --git a/src/renderer/src/i18n/locales/en-us.json b/src/renderer/src/i18n/locales/en-us.json index bd99fcc2..a4b167bd 100644 --- a/src/renderer/src/i18n/locales/en-us.json +++ b/src/renderer/src/i18n/locales/en-us.json @@ -362,7 +362,11 @@ "title": "Knowledge Base", "url_added": "URL added", "url_placeholder": "Enter URL, multiple URLs separated by Enter", - "urls": "URLs" + "urls": "URLs", + "topN": "Number of results returned", + "topN_placeholder": "Not set", + "topN__too_large_or_small": "The number of results returned cannot be greater than 100 or less than 1.", + "topN_tooltip": "The number of matching results returned; the larger the value, the more matching results, but also the more tokens consumed." }, "languages": { "arabic": "Arabic", @@ -533,7 +537,9 @@ "function_calling": "Function Calling" }, "vision": "Vision", - "websearch": "WebSearch" + "websearch": "WebSearch", + "rerank_model": "Reordering Model", + "rerank_model_tooltip": "Click the Manage button in Settings -> Model Services to add." }, "navbar": { "expand": "Expand Dialog", @@ -1083,4 +1089,4 @@ "visualization": "Visualization" } } -} +} \ No newline at end of file diff --git a/src/renderer/src/i18n/locales/ja-jp.json b/src/renderer/src/i18n/locales/ja-jp.json index e4c17297..c214811b 100644 --- a/src/renderer/src/i18n/locales/ja-jp.json +++ b/src/renderer/src/i18n/locales/ja-jp.json @@ -362,7 +362,11 @@ "title": "ナレッジベース", "url_added": "URLが追加されました", "url_placeholder": "URLを入力, 複数のURLはEnterで区切る", - "urls": "URL" + "urls": "URL", + "topN": "返却される結果の数", + "topN_placeholder": "未設定", + "topN__too_large_or_small": "結果の数は100より大きくてはならず、1より小さくてはなりません。", + "topN_tooltip": "返されるマッチ結果の数は、数値が大きいほどマッチ結果が多くなりますが、消費されるトークンも増えます。" }, "languages": { "arabic": "アラビア語", @@ -533,7 +537,9 @@ "function_calling": "関数呼び出し" }, "vision": "画像", - "websearch": "ウェブ検索" + "websearch": "ウェブ検索", + "rerank_model": "再順序付けモデル", + "rerank_model_tooltip": "設定->モデルサービスに移動し、管理ボタンをクリックして追加します。" }, "navbar": { "expand": "ダイアログを展開", @@ -1083,4 +1089,4 @@ "visualization": "可視化" } } -} +} \ No newline at end of file diff --git a/src/renderer/src/i18n/locales/ru-ru.json b/src/renderer/src/i18n/locales/ru-ru.json index fcb49595..c9472887 100644 --- a/src/renderer/src/i18n/locales/ru-ru.json +++ b/src/renderer/src/i18n/locales/ru-ru.json @@ -362,7 +362,11 @@ "title": "База знаний", "url_added": "URL добавлен", "url_placeholder": "Введите URL, несколько URL через Enter", - "urls": "URL-адреса" + "urls": "URL-адреса", + "topN": "Количество возвращаемых результатов", + "topN_placeholder": "Не установлено", + "topN__too_large_or_small": "Количество возвращаемых результатов не может быть больше 100 или меньше 1.", + "topN_tooltip": "Количество возвращаемых совпадений; чем больше значение, тем больше совпадений, но и потребление токенов тоже возрастает." }, "languages": { "arabic": "Арабский", @@ -539,7 +543,9 @@ "function_calling": "Вызов функции" }, "vision": "Визуальные", - "websearch": "Веб-поисковые" + "websearch": "Веб-поисковые", + "rerank_model": "Модель переупорядочивания", + "rerank_model_tooltip": "В настройках -> Служба модели нажмите кнопку \"Управление\", чтобы добавить." }, "navbar": { "expand": "Развернуть диалоговое окно", @@ -1083,4 +1089,4 @@ "visualization": "Визуализация" } } -} +} \ No newline at end of file diff --git a/src/renderer/src/i18n/locales/zh-cn.json b/src/renderer/src/i18n/locales/zh-cn.json index f383de5c..3ad77645 100644 --- a/src/renderer/src/i18n/locales/zh-cn.json +++ b/src/renderer/src/i18n/locales/zh-cn.json @@ -359,6 +359,10 @@ "threshold_placeholder": "未设置", "threshold_too_large_or_small": "阈值不能大于1或小于0", "threshold_tooltip": "用于衡量用户问题与知识库内容之间的相关性(0-1)", + "topN": "返回结果数量", + "topN_placeholder": "未设置", + "topN__too_large_or_small": "返回结果数量不能大于100或小于1", + "topN_tooltip": "返回的匹配结果数量,数值越大,匹配结果越多,但消耗的 Token 也越多", "title": "知识库", "url_added": "网址已添加", "url_placeholder": "请输入网址, 多个网址用回车分隔", @@ -510,6 +514,8 @@ "embedding": "嵌入", "embedding_model": "嵌入模型", "embedding_model_tooltip": "在设置->模型服务中点击管理按钮添加", + "rerank_model": "重排序模型", + "rerank_model_tooltip": "在设置->模型服务中点击管理按钮添加", "free": "免费", "no_matches": "无可用模型", "parameter_name": "参数名称", @@ -1083,4 +1089,4 @@ "visualization": "可视化" } } -} +} \ No newline at end of file diff --git a/src/renderer/src/i18n/locales/zh-tw.json b/src/renderer/src/i18n/locales/zh-tw.json index 9d5234f4..e91f4fd0 100644 --- a/src/renderer/src/i18n/locales/zh-tw.json +++ b/src/renderer/src/i18n/locales/zh-tw.json @@ -362,7 +362,11 @@ "title": "知識庫", "url_added": "網址已新增", "url_placeholder": "請輸入網址,多個網址用換行符號分隔", - "urls": "網址" + "urls": "網址", + "topN": "返回結果數量", + "topN_placeholder": "未設定", + "topN__too_large_or_small": "返回結果數量不能大於100或小於1", + "topN_tooltip": "返回的匹配結果數量,數值越大,匹配結果越多,但消耗的 Token 也越多" }, "languages": { "arabic": "阿拉伯文", @@ -533,7 +537,9 @@ "function_calling": "函數調用" }, "vision": "視覺", - "websearch": "網路搜尋" + "websearch": "網路搜尋", + "rerank_model": "重排序模型", + "rerank_model_tooltip": "在設定->模型服務中點擊管理按鈕添加" }, "navbar": { "expand": "伸縮對話框", @@ -1083,4 +1089,4 @@ "visualization": "視覺化" } } -} +} \ No newline at end of file diff --git a/src/renderer/src/pages/knowledge/components/AddKnowledgePopup.tsx b/src/renderer/src/pages/knowledge/components/AddKnowledgePopup.tsx index f9e135fa..eef6b71f 100644 --- a/src/renderer/src/pages/knowledge/components/AddKnowledgePopup.tsx +++ b/src/renderer/src/pages/knowledge/components/AddKnowledgePopup.tsx @@ -1,5 +1,5 @@ import { TopView } from '@renderer/components/TopView' -import { isEmbeddingModel } from '@renderer/config/models' +import { isEmbeddingModel, isRerankModel } from '@renderer/config/models' import { useKnowledgeBases } from '@renderer/hooks/useKnowledge' import { useProviders } from '@renderer/hooks/useProvider' import AiProvider from '@renderer/providers/AiProvider' @@ -20,6 +20,7 @@ interface ShowParams { interface FormData { name: string model: string + rerankModel: string } interface Props extends ShowParams { @@ -37,6 +38,11 @@ const PopupContainer: React.FC = ({ title, resolve }) => { .map((p) => p.models) .flat() .filter((model) => isEmbeddingModel(model)) + const rerankModels = providers + .map((p) => p.models) + .flat() + .filter((model) => isRerankModel(model)) + console.log('rerankModels', rerankModels) const nameInputRef = useRef(null) const selectOptions = providers @@ -53,10 +59,25 @@ const PopupContainer: React.FC = ({ title, resolve }) => { })) .filter((group) => group.options.length > 0) + const rerankSelectOptions = providers + .filter((p) => p.models.length > 0) + .map((p) => ({ + label: p.isSystem ? t(`provider.${p.id}`) : p.name, + title: p.name, + options: sortBy(p.models, 'name') + .filter((model) => isRerankModel(model)) + .map((m) => ({ + label: m.name, + value: getModelUniqId(m) + })) + })) + .filter((group) => group.options.length > 0) + const onOk = async () => { try { const values = await form.validateFields() const selectedModel = find(allModels, JSON.parse(values.model)) as Model + const selectedRerankModel = find(rerankModels, JSON.parse(values.rerankModel)) as Model if (selectedModel) { setLoading(true) @@ -82,6 +103,7 @@ const PopupContainer: React.FC = ({ title, resolve }) => { id: nanoid(), name: values.name, model: selectedModel, + rerankModel: selectedRerankModel, dimensions, items: [], created_at: Date.now(), @@ -134,6 +156,14 @@ const PopupContainer: React.FC = ({ title, resolve }) => { rules={[{ required: true, message: t('message.error.enter.model') }]}> + ) diff --git a/src/renderer/src/pages/knowledge/components/KnowledgeSearchPopup.tsx b/src/renderer/src/pages/knowledge/components/KnowledgeSearchPopup.tsx index e1ffa3cd..af976419 100644 --- a/src/renderer/src/pages/knowledge/components/KnowledgeSearchPopup.tsx +++ b/src/renderer/src/pages/knowledge/components/KnowledgeSearchPopup.tsx @@ -41,8 +41,16 @@ const PopupContainer: React.FC = ({ base, resolve }) => { search: value, base: getKnowledgeBaseParams(base) }) + let rerankResult = searchResults + if (base.rerankModel) { + rerankResult = await window.api.knowledgeBase.rerank({ + search: value, + base: getKnowledgeBaseParams(base), + results: searchResults + }) + } const results = await Promise.all( - searchResults.map(async (item) => { + rerankResult.map(async (item) => { const file = await getFileFromUrl(item.metadata.source) return { ...item, file } }) diff --git a/src/renderer/src/pages/knowledge/components/KnowledgeSettingsPopup.tsx b/src/renderer/src/pages/knowledge/components/KnowledgeSettingsPopup.tsx index 3a1e2b18..1036bc0f 100644 --- a/src/renderer/src/pages/knowledge/components/KnowledgeSettingsPopup.tsx +++ b/src/renderer/src/pages/knowledge/components/KnowledgeSettingsPopup.tsx @@ -2,7 +2,7 @@ import { WarningOutlined } from '@ant-design/icons' import { TopView } from '@renderer/components/TopView' import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant' import { getEmbeddingMaxContext } from '@renderer/config/embedings' -import { isEmbeddingModel } from '@renderer/config/models' +import { isEmbeddingModel, isRerankModel } from '@renderer/config/models' import { useKnowledge } from '@renderer/hooks/useKnowledge' import { useProviders } from '@renderer/hooks/useProvider' import { getModelUniqId } from '@renderer/services/ModelService' @@ -23,6 +23,8 @@ interface FormData { chunkSize?: number chunkOverlap?: number threshold?: number + rerankModel?: string + topN?: number } interface Props extends ShowParams { @@ -59,6 +61,20 @@ const PopupContainer: React.FC = ({ base: _base, resolve }) => { })) .filter((group) => group.options.length > 0) + const rerankSelectOptions = providers + .filter((p) => p.models.length > 0) + .map((p) => ({ + label: p.isSystem ? t(`provider.${p.id}`) : p.name, + title: p.name, + options: sortBy(p.models, 'name') + .filter((model) => isRerankModel(model)) + .map((m) => ({ + label: m.name, + value: getModelUniqId(m) + })) + })) + .filter((group) => group.options.length > 0) + const onOk = async () => { try { const values = await form.validateFields() @@ -68,7 +84,11 @@ const PopupContainer: React.FC = ({ base: _base, resolve }) => { documentCount: values.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, chunkSize: values.chunkSize, chunkOverlap: values.chunkOverlap, - threshold: values.threshold ?? undefined + threshold: values.threshold ?? undefined, + rerankModel: values.rerankModel + ? providers.flatMap((p) => p.models).find((m) => getModelUniqId(m) === values.rerankModel) + : undefined, + topN: values.topN } updateKnowledgeBase(newBase) setOpen(false) @@ -116,6 +136,20 @@ const PopupContainer: React.FC = ({ base: _base, resolve }) => { + + = ({ base: _base, resolve }) => { ]}> + 10)) { + return Promise.reject(new Error(t('knowledge.topN_too_large_or_small'))) + } + return Promise.resolve() + } + } + ]}> + + } /> diff --git a/src/renderer/src/services/KnowledgeService.ts b/src/renderer/src/services/KnowledgeService.ts index e112107f..12542cef 100644 --- a/src/renderer/src/services/KnowledgeService.ts +++ b/src/renderer/src/services/KnowledgeService.ts @@ -39,7 +39,10 @@ export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams apiVersion: provider.apiVersion, baseURL: host, chunkSize, - chunkOverlap: base.chunkOverlap + chunkOverlap: base.chunkOverlap, + rerankModel: base.rerankModel?.id, + rerankModelProvider: base.rerankModel?.provider, + topN: base.topN } } @@ -92,8 +95,17 @@ export const getKnowledgeBaseReference = async (base: KnowledgeBase, message: Me }) ) - const _searchResults = await Promise.all( - searchResults.map(async (item) => { + let rerankResults = searchResults + if (base.rerankModel) { + rerankResults = await window.api.knowledgeBase.rerank({ + search: message.content, + base: getKnowledgeBaseParams(base), + results: searchResults + }) + } + + const processdResults = await Promise.all( + rerankResults.map(async (item) => { const file = await getFileFromUrl(item.metadata.source) return { ...item, file } }) @@ -102,7 +114,7 @@ export const getKnowledgeBaseReference = async (base: KnowledgeBase, message: Me const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT const references = await Promise.all( - take(_searchResults, documentCount).map(async (item, index) => { + take(processdResults, documentCount).map(async (item, index) => { const baseItem = base.items.find((i) => i.uniqueId === item.metadata.uniqueLoaderId) return { id: index + 1, diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index e4aba0a2..0e965888 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -254,6 +254,8 @@ export interface KnowledgeBase { chunkSize?: number chunkOverlap?: number threshold?: number + rerankModel?: Model + topN?: number } export type KnowledgeBaseParams = { @@ -265,6 +267,9 @@ export type KnowledgeBaseParams = { baseURL: string chunkSize?: number chunkOverlap?: number + rerankModel?: string + rerankModelProvider?: string + topN?: number } export type GenerateImageParams = {