feat(knowledge base): enhance knowledge base with rerank model

This commit is contained in:
eeee0717 2025-03-18 17:20:01 +08:00 committed by 亢奋猫
parent 359f6e36e9
commit b50f8a4c11
20 changed files with 297 additions and 22 deletions

View File

@ -189,6 +189,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle('knowledge-base:add', KnowledgeService.add)
ipcMain.handle('knowledge-base:remove', KnowledgeService.remove)
ipcMain.handle('knowledge-base:search', KnowledgeService.search)
ipcMain.handle('knowledge-base:rerank', KnowledgeService.rerank)
// window
ipcMain.handle('window:set-minimum-size', (_, width: number, height: number) => {

View File

@ -0,0 +1,20 @@
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types'
/**
 * Shared base for reranker implementations.
 *
 * Holds the knowledge base parameters and exposes the HTTP headers that
 * concrete rerankers send with their requests.
 */
export default abstract class BaseReranker {
  protected base: KnowledgeBaseParams

  /**
   * @param base - Knowledge base parameters; `rerankModel` must be set.
   * @throws Error when `base.rerankModel` is falsy.
   */
  constructor(base: KnowledgeBaseParams) {
    const { rerankModel } = base
    if (!rerankModel) throw new Error('Rerank model is required')
    this.base = base
  }

  /**
   * Produces the reranked list for `query` from `searchResults`.
   * Implemented by subclasses.
   */
  abstract rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]>

  /**
   * @returns JSON request headers carrying `base.apiKey` as a bearer token.
   */
  public defaultHeaders() {
    const bearerToken = `Bearer ${this.base.apiKey}`
    return {
      Authorization: bearerToken,
      'Content-Type': 'application/json'
    }
  }
}

View File

@ -0,0 +1,13 @@
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
/**
 * Placeholder reranker with no reranking logic of its own:
 * every call to `rerank` rejects.
 */
export default class DefaultReranker extends BaseReranker {
  constructor(base: KnowledgeBaseParams) {
    super(base)
  }

  /**
   * @returns A promise that always rejects with "Method not implemented.".
   */
  rerank(): Promise<ExtractChunkData[]> {
    return Promise.reject(new Error('Method not implemented.'))
  }
}

View File

@ -0,0 +1,15 @@
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
import RerankerFactory from './RerankerFactory'
/**
 * Facade over the reranker implementation that `RerankerFactory` selects
 * for the given knowledge base parameters.
 */
export default class Reranker {
  private sdk: BaseReranker

  /**
   * @param base - Knowledge base parameters handed to `RerankerFactory.create`.
   */
  constructor(base: KnowledgeBaseParams) {
    const implementation = RerankerFactory.create(base)
    this.sdk = implementation
  }

  /**
   * Delegates to the selected implementation.
   *
   * @param query - The search query.
   * @param searchResults - Results to rerank.
   * @returns Whatever the underlying implementation resolves with.
   */
  public async rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> {
    const reranked = await this.sdk.rerank(query, searchResults)
    return reranked
  }
}

View File

@ -0,0 +1,14 @@
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
import DefaultReranker from './DefaultReranker'
import SiliconFlowReranker from './SiliconFlowReranker'
/**
 * Chooses the reranker implementation from `base.rerankModelProvider`.
 */
export default class RerankerFactory {
  /**
   * @param base - Knowledge base parameters forwarded to the reranker constructor.
   * @returns A `SiliconFlowReranker` for provider `'silicon'`, otherwise a `DefaultReranker`.
   */
  static create(base: KnowledgeBaseParams): BaseReranker {
    switch (base.rerankModelProvider) {
      case 'silicon':
        return new SiliconFlowReranker(base)
      default:
        return new DefaultReranker(base)
    }
  }
}

View File

@ -0,0 +1,46 @@
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types'
import axios from 'axios'
import BaseReranker from './BaseReranker'
/**
 * Reranker that posts the query and documents to `${base.baseURL}/rerank`
 * and reorders the search results by the returned relevance scores.
 */
export default class SiliconFlowReranker extends BaseReranker {
  constructor(base: KnowledgeBaseParams) {
    super(base)
  }

  /**
   * Reranks `searchResults` for `query`.
   *
   * Only results whose index appears in the response are kept (the service may
   * return fewer entries than were sent, e.g. because of `top_n`); each kept
   * result has its `score` replaced by the returned `relevance_score` and the
   * list is sorted by descending score.
   *
   * @param query - The search query.
   * @param searchResults - Results to rerank; an empty list short-circuits to `[]`.
   * @returns The reranked results, highest score first.
   * @throws Error when the response carries no `results` array; request
   *   failures from axios propagate unchanged.
   */
  public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
    // Nothing to rank: skip the network round trip instead of posting an empty document list.
    if (searchResults.length === 0) {
      return []
    }

    // NOTE(review): the URL and API key come from `base` — confirm they belong to the
    // rerank model's provider and not only to the embedding provider.
    const url = `${this.base.baseURL}/rerank`

    const { data } = await axios.post(
      url,
      {
        model: this.base.rerankModel,
        query,
        documents: searchResults.map((doc) => doc.pageContent),
        top_n: this.base.topN,
        max_chunks_per_doc: this.base.chunkSize,
        overlap_tokens: this.base.chunkOverlap
      },
      {
        headers: this.defaultHeaders()
      }
    )

    // Fail with a descriptive error instead of a TypeError on `.map` of undefined.
    const rerankResults: unknown = data?.results
    if (!Array.isArray(rerankResults)) {
      throw new Error('Invalid rerank response: missing "results" array')
    }

    // Map each original document index to its relevance score (0 when the score is absent).
    const scoreByIndex = new Map<number, number>()
    for (const result of rerankResults) {
      scoreByIndex.set(result.index, result.relevance_score || 0)
    }

    return searchResults
      .map((doc: ExtractChunkData, index: number) => {
        const score = scoreByIndex.get(index)
        if (score === undefined) return undefined
        return {
          ...doc,
          score
        }
      })
      .filter((doc): doc is ExtractChunkData => doc !== undefined)
      .sort((a, b) => b.score - a.score)
  }
}

View File

@ -23,6 +23,7 @@ import { SitemapLoader } from '@llm-tools/embedjs-loader-sitemap'
import { WebLoader } from '@llm-tools/embedjs-loader-web'
import { AzureOpenAiEmbeddings, OpenAiEmbeddings } from '@llm-tools/embedjs-openai'
import { addFileLoader } from '@main/loader'
import Reranker from '@main/reranker/Reranker'
import { proxyManager } from '@main/services/ProxyManager'
import { windowService } from '@main/services/WindowService'
import { getInstanceName } from '@main/utils'
@ -482,6 +483,13 @@ class KnowledgeService {
const ragApplication = await this.getRagApplication(base)
return await ragApplication.search(search)
}
/**
 * Handler registered for the `knowledge-base:rerank` IPC channel.
 * Builds a `Reranker` from `base` and reranks `results` for the query `search`.
 *
 * @param _ - IPC invoke event (unused).
 * @returns The list resolved by `Reranker.rerank`.
 */
public rerank = async (
  _: Electron.IpcMainInvokeEvent,
  { search, base, results }: { search: string; base: KnowledgeBaseParams; results: ExtractChunkData[] }
): Promise<ExtractChunkData[]> => {
  const reranker = new Reranker(base)
  return reranker.rerank(search, results)
}
}
export default new KnowledgeService()

View File

@ -90,6 +90,15 @@ declare global {
base: KnowledgeBaseParams
}) => Promise<void>
search: ({ search, base }: { search: string; base: KnowledgeBaseParams }) => Promise<ExtractChunkData[]>
/**
 * Reranks previously retrieved `results` for the query `search`, using the
 * knowledge base parameters in `base`; resolves with the reranked chunks.
 */
rerank: ({
search,
base,
results
}: {
search: string
base: KnowledgeBaseParams
results: ExtractChunkData[]
}) => Promise<ExtractChunkData[]>
}
window: {
setMinimumSize: (width: number, height: number) => Promise<void>

View File

@ -1,4 +1,5 @@
import { electronAPI } from '@electron-toolkit/preload'
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
import { FileType, KnowledgeBaseParams, KnowledgeItem, MCPServer, Shortcut, WebDavConfig } from '@types'
import { contextBridge, ipcRenderer, OpenDialogOptions, shell } from 'electron'
@ -75,7 +76,9 @@ const api = {
remove: ({ uniqueId, uniqueIds, base }: { uniqueId: string; uniqueIds: string[]; base: KnowledgeBaseParams }) =>
ipcRenderer.invoke('knowledge-base:remove', { uniqueId, uniqueIds, base }),
search: ({ search, base }: { search: string; base: KnowledgeBaseParams }) =>
ipcRenderer.invoke('knowledge-base:search', { search, base })
ipcRenderer.invoke('knowledge-base:search', { search, base }),
// Forwards the query, knowledge base params and prior search results to the
// main process over the 'knowledge-base:rerank' IPC channel.
rerank: ({ search, base, results }: { search: string; base: KnowledgeBaseParams; results: ExtractChunkData[] }) =>
ipcRenderer.invoke('knowledge-base:rerank', { search, base, results })
},
window: {
setMinimumSize: (width: number, height: number) => ipcRenderer.invoke('window:set-minimum-size', width, height),

View File

@ -176,6 +176,10 @@ export const REASONING_REGEX =
// Embedding models
export const EMBEDDING_REGEX = /(?:^text-|embed|bge-|e5-|LLM2Vec|retrieval|uae-|gte-|jina-clip|jina-embeddings)/i
// Rerank models
// Case-insensitive match on ids containing "rerank" or "re-rank" (which already
// covers "re-ranker" and "re-ranking"), "retrieval" or "retriever".
export const RERANKING_REGEX = /(?:rerank|re-rank|retrieval|retriever)/i
export const NOT_SUPPORTED_REGEX = /(?:^tts|rerank|whisper|speech)/i
// Tool calling models
@ -1880,6 +1884,13 @@ export function isEmbeddingModel(model: Model): boolean {
return EMBEDDING_REGEX.test(model.id) || model.type?.includes('embedding') || false
}
/**
 * Tells whether a model is a rerank model, judged by its id.
 *
 * @param model - The model to inspect; a falsy value yields `false`.
 * @returns `true` when `model.id` matches `RERANKING_REGEX`.
 */
export function isRerankModel(model: Model): boolean {
  return model ? RERANKING_REGEX.test(model.id) : false
}
export function isVisionModel(model: Model): boolean {
if (!model) {
return false

View File

@ -362,7 +362,11 @@
"title": "Knowledge Base",
"url_added": "URL added",
"url_placeholder": "Enter URL, multiple URLs separated by Enter",
"urls": "URLs"
"urls": "URLs",
"topN": "Number of results returned",
"topN_placeholder": "Not set",
"topN_too_large_or_small": "The number of results returned cannot be greater than 100 or less than 1.",
"topN_tooltip": "The number of matching results returned; the larger the value, the more matching results, but also the more tokens consumed."
},
"languages": {
"arabic": "Arabic",
@ -533,7 +537,9 @@
"function_calling": "Function Calling"
},
"vision": "Vision",
"websearch": "WebSearch"
"websearch": "WebSearch",
"rerank_model": "Rerank Model",
"rerank_model_tooltip": "Click the Manage button in Settings -> Model Services to add."
},
"navbar": {
"expand": "Expand Dialog",
@ -1083,4 +1089,4 @@
"visualization": "Visualization"
}
}
}
}

View File

@ -362,7 +362,11 @@
"title": "ナレッジベース",
"url_added": "URLが追加されました",
"url_placeholder": "URLを入力, 複数のURLはEnterで区切る",
"urls": "URL"
"urls": "URL",
"topN": "返却される結果の数",
"topN_placeholder": "未設定",
"topN_too_large_or_small": "結果の数は100より大きくてはならず、1より小さくてはなりません。",
"topN_tooltip": "返されるマッチ結果の数は、数値が大きいほどマッチ結果が多くなりますが、消費されるトークンも増えます。"
},
"languages": {
"arabic": "アラビア語",
@ -533,7 +537,9 @@
"function_calling": "関数呼び出し"
},
"vision": "画像",
"websearch": "ウェブ検索"
"websearch": "ウェブ検索",
"rerank_model": "再順序付けモデル",
"rerank_model_tooltip": "設定->モデルサービスに移動し、管理ボタンをクリックして追加します。"
},
"navbar": {
"expand": "ダイアログを展開",
@ -1083,4 +1089,4 @@
"visualization": "可視化"
}
}
}
}

View File

@ -362,7 +362,11 @@
"title": "База знаний",
"url_added": "URL добавлен",
"url_placeholder": "Введите URL, несколько URL через Enter",
"urls": "URL-адреса"
"urls": "URL-адреса",
"topN": "Количество возвращаемых результатов",
"topN_placeholder": "Не установлено",
"topN_too_large_or_small": "Количество возвращаемых результатов не может быть больше 100 или меньше 1.",
"topN_tooltip": "Количество возвращаемых совпадений; чем больше значение, тем больше совпадений, но и потребление токенов тоже возрастает."
},
"languages": {
"arabic": "Арабский",
@ -539,7 +543,9 @@
"function_calling": "Вызов функции"
},
"vision": "Визуальные",
"websearch": "Веб-поисковые"
"websearch": "Веб-поисковые",
"rerank_model": "Модель переупорядочивания",
"rerank_model_tooltip": "В настройках -> Служба модели нажмите кнопку \"Управление\", чтобы добавить."
},
"navbar": {
"expand": "Развернуть диалоговое окно",
@ -1083,4 +1089,4 @@
"visualization": "Визуализация"
}
}
}
}

View File

@ -359,6 +359,10 @@
"threshold_placeholder": "未设置",
"threshold_too_large_or_small": "阈值不能大于1或小于0",
"threshold_tooltip": "用于衡量用户问题与知识库内容之间的相关性0-1",
"topN": "返回结果数量",
"topN_placeholder": "未设置",
"topN_too_large_or_small": "返回结果数量不能大于100或小于1",
"topN_tooltip": "返回的匹配结果数量,数值越大,匹配结果越多,但消耗的 Token 也越多",
"title": "知识库",
"url_added": "网址已添加",
"url_placeholder": "请输入网址, 多个网址用回车分隔",
@ -510,6 +514,8 @@
"embedding": "嵌入",
"embedding_model": "嵌入模型",
"embedding_model_tooltip": "在设置->模型服务中点击管理按钮添加",
"rerank_model": "重排序模型",
"rerank_model_tooltip": "在设置->模型服务中点击管理按钮添加",
"free": "免费",
"no_matches": "无可用模型",
"parameter_name": "参数名称",
@ -1083,4 +1089,4 @@
"visualization": "可视化"
}
}
}
}

View File

@ -362,7 +362,11 @@
"title": "知識庫",
"url_added": "網址已新增",
"url_placeholder": "請輸入網址,多個網址用換行符號分隔",
"urls": "網址"
"urls": "網址",
"topN": "返回結果數量",
"topN_placeholder": "未設定",
"topN_too_large_or_small": "返回結果數量不能大於100或小於1",
"topN_tooltip": "返回的匹配結果數量,數值越大,匹配結果越多,但消耗的 Token 也越多"
},
"languages": {
"arabic": "阿拉伯文",
@ -533,7 +537,9 @@
"function_calling": "函數調用"
},
"vision": "視覺",
"websearch": "網路搜尋"
"websearch": "網路搜尋",
"rerank_model": "重排序模型",
"rerank_model_tooltip": "在設定->模型服務中點擊管理按鈕添加"
},
"navbar": {
"expand": "伸縮對話框",
@ -1083,4 +1089,4 @@
"visualization": "視覺化"
}
}
}
}

View File

@ -1,5 +1,5 @@
import { TopView } from '@renderer/components/TopView'
import { isEmbeddingModel } from '@renderer/config/models'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
import { useProviders } from '@renderer/hooks/useProvider'
import AiProvider from '@renderer/providers/AiProvider'
@ -20,6 +20,7 @@ interface ShowParams {
// Values produced by the popup's form.
interface FormData {
name: string
// Serialized model identifier; onOk passes it through JSON.parse.
model: string
// Serialized rerank model identifier.
// NOTE(review): the form item is declared with `required: false`, so this can
// be undefined at runtime — confirm whether the type should be optional.
rerankModel: string
}
interface Props extends ShowParams {
@ -37,6 +38,11 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
.map((p) => p.models)
.flat()
.filter((model) => isEmbeddingModel(model))
// Every rerank-capable model across all providers; used to resolve the
// rerank model selected in the form.
const rerankModels = providers
  .map((p) => p.models)
  .flat()
  .filter((model) => isRerankModel(model))
const nameInputRef = useRef<any>(null)
const selectOptions = providers
@ -53,10 +59,25 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
}))
.filter((group) => group.options.length > 0)
// Grouped <Select> options: one group per provider that has at least one
// rerank model, with the models sorted by name inside each group.
const rerankSelectOptions = providers
  .filter((provider) => provider.models.length > 0)
  .map((provider) => {
    const label = provider.isSystem ? t(`provider.${provider.id}`) : provider.name
    const options = sortBy(provider.models, 'name')
      .filter((candidate) => isRerankModel(candidate))
      .map((candidate) => ({
        label: candidate.name,
        value: getModelUniqId(candidate)
      }))
    return { label, title: provider.name, options }
  })
  .filter(({ options }) => options.length > 0)
const onOk = async () => {
try {
const values = await form.validateFields()
const selectedModel = find(allModels, JSON.parse(values.model)) as Model
// The rerank model is optional in the form. JSON.parse(undefined) throws, so
// only resolve the model when one was actually picked.
const selectedRerankModel = values.rerankModel
  ? (find(rerankModels, JSON.parse(values.rerankModel)) as Model)
  : undefined
if (selectedModel) {
setLoading(true)
@ -82,6 +103,7 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
id: nanoid(),
name: values.name,
model: selectedModel,
rerankModel: selectedRerankModel,
dimensions,
items: [],
created_at: Date.now(),
@ -134,6 +156,14 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
rules={[{ required: true, message: t('message.error.enter.model') }]}>
<Select style={{ width: '100%' }} options={selectOptions} placeholder={t('settings.models.empty')} />
</Form.Item>
<Form.Item
name="rerankModel"
label={t('models.rerank_model')}
tooltip={{ title: t('models.rerank_model_tooltip'), placement: 'right' }}
rules={[{ required: false, message: t('message.error.enter.model') }]}>
<Select style={{ width: '100%' }} options={rerankSelectOptions} placeholder={t('settings.models.empty')} />
</Form.Item>
</Form>
</Modal>
)

View File

@ -41,8 +41,16 @@ const PopupContainer: React.FC<Props> = ({ base, resolve }) => {
search: value,
base: getKnowledgeBaseParams(base)
})
let rerankResult = searchResults
if (base.rerankModel) {
rerankResult = await window.api.knowledgeBase.rerank({
search: value,
base: getKnowledgeBaseParams(base),
results: searchResults
})
}
const results = await Promise.all(
searchResults.map(async (item) => {
rerankResult.map(async (item) => {
const file = await getFileFromUrl(item.metadata.source)
return { ...item, file }
})

View File

@ -2,7 +2,7 @@ import { WarningOutlined } from '@ant-design/icons'
import { TopView } from '@renderer/components/TopView'
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
import { isEmbeddingModel } from '@renderer/config/models'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { useKnowledge } from '@renderer/hooks/useKnowledge'
import { useProviders } from '@renderer/hooks/useProvider'
import { getModelUniqId } from '@renderer/services/ModelService'
@ -23,6 +23,8 @@ interface FormData {
chunkSize?: number
chunkOverlap?: number
threshold?: number
rerankModel?: string
topN?: number
}
interface Props extends ShowParams {
@ -59,6 +61,20 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
}))
.filter((group) => group.options.length > 0)
// Grouped <Select> options: one group per provider that has at least one
// rerank model, with the models sorted by name inside each group.
const rerankSelectOptions = providers
  .filter((provider) => provider.models.length > 0)
  .map((provider) => {
    const label = provider.isSystem ? t(`provider.${provider.id}`) : provider.name
    const options = sortBy(provider.models, 'name')
      .filter((candidate) => isRerankModel(candidate))
      .map((candidate) => ({
        label: candidate.name,
        value: getModelUniqId(candidate)
      }))
    return { label, title: provider.name, options }
  })
  .filter(({ options }) => options.length > 0)
const onOk = async () => {
try {
const values = await form.validateFields()
@ -68,7 +84,11 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
documentCount: values.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT,
chunkSize: values.chunkSize,
chunkOverlap: values.chunkOverlap,
threshold: values.threshold ?? undefined
threshold: values.threshold ?? undefined,
rerankModel: values.rerankModel
? providers.flatMap((p) => p.models).find((m) => getModelUniqId(m) === values.rerankModel)
: undefined,
topN: values.topN
}
updateKnowledgeBase(newBase)
setOpen(false)
@ -116,6 +136,20 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
<Select style={{ width: '100%' }} options={selectOptions} placeholder={t('settings.models.empty')} disabled />
</Form.Item>
<Form.Item
name="rerankModel"
label={t('models.rerank_model')}
initialValue={getModelUniqId(base.rerankModel)}
tooltip={{ title: t('models.rerank_model_tooltip'), placement: 'right' }}
rules={[{ required: false, message: t('message.error.enter.model') }]}>
<Select
style={{ width: '100%' }}
options={rerankSelectOptions}
placeholder={t('settings.models.empty')}
allowClear
/>
</Form.Item>
<Form.Item
name="documentCount"
label={t('knowledge.document_count')}
@ -193,6 +227,22 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
]}>
<InputNumber placeholder={t('knowledge.threshold_placeholder')} step={0.1} style={{ width: '100%' }} />
</Form.Item>
<Form.Item
name="topN"
label={t('knowledge.topN')}
initialValue={base.topN}
rules={[
{
// An empty field is accepted (topN is optional). Otherwise the value must lie
// in 1..100, the range stated by the validation message. The previous check
// used 0..10 and let 0 through because of the truthiness test.
validator(_, value) {
  const isEmpty = value === undefined || value === null
  if (!isEmpty && (value < 1 || value > 100)) {
    return Promise.reject(new Error(t('knowledge.topN_too_large_or_small')))
  }
  return Promise.resolve()
}
}
]}>
<InputNumber placeholder={t('knowledge.topN_placeholder')} style={{ width: '100%' }} />
</Form.Item>
</Form>
<Alert message={t('knowledge.chunk_size_change_warning')} type="warning" showIcon icon={<WarningOutlined />} />
</Modal>

View File

@ -39,7 +39,10 @@ export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams
apiVersion: provider.apiVersion,
baseURL: host,
chunkSize,
chunkOverlap: base.chunkOverlap
chunkOverlap: base.chunkOverlap,
rerankModel: base.rerankModel?.id,
rerankModelProvider: base.rerankModel?.provider,
topN: base.topN
}
}
@ -92,8 +95,17 @@ export const getKnowledgeBaseReference = async (base: KnowledgeBase, message: Me
})
)
const _searchResults = await Promise.all(
searchResults.map(async (item) => {
let rerankResults = searchResults
if (base.rerankModel) {
rerankResults = await window.api.knowledgeBase.rerank({
search: message.content,
base: getKnowledgeBaseParams(base),
results: searchResults
})
}
const processdResults = await Promise.all(
rerankResults.map(async (item) => {
const file = await getFileFromUrl(item.metadata.source)
return { ...item, file }
})
@ -102,7 +114,7 @@ export const getKnowledgeBaseReference = async (base: KnowledgeBase, message: Me
const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT
const references = await Promise.all(
take(_searchResults, documentCount).map(async (item, index) => {
take(processdResults, documentCount).map(async (item, index) => {
const baseItem = base.items.find((i) => i.uniqueId === item.metadata.uniqueLoaderId)
return {
id: index + 1,

View File

@ -254,6 +254,8 @@ export interface KnowledgeBase {
chunkSize?: number
chunkOverlap?: number
threshold?: number
rerankModel?: Model
topN?: number
}
export type KnowledgeBaseParams = {
@ -265,6 +267,9 @@ export type KnowledgeBaseParams = {
baseURL: string
chunkSize?: number
chunkOverlap?: number
rerankModel?: string
rerankModelProvider?: string
topN?: number
}
export type GenerateImageParams = {