feat(knowledge base): enhance knowledge base with rerank model

This commit is contained in:
eeee0717 2025-03-18 17:20:01 +08:00 committed by 亢奋猫
parent 359f6e36e9
commit b50f8a4c11
20 changed files with 297 additions and 22 deletions

View File

@ -189,6 +189,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle('knowledge-base:add', KnowledgeService.add) ipcMain.handle('knowledge-base:add', KnowledgeService.add)
ipcMain.handle('knowledge-base:remove', KnowledgeService.remove) ipcMain.handle('knowledge-base:remove', KnowledgeService.remove)
ipcMain.handle('knowledge-base:search', KnowledgeService.search) ipcMain.handle('knowledge-base:search', KnowledgeService.search)
ipcMain.handle('knowledge-base:rerank', KnowledgeService.rerank)
// window // window
ipcMain.handle('window:set-minimum-size', (_, width: number, height: number) => { ipcMain.handle('window:set-minimum-size', (_, width: number, height: number) => {

View File

@ -0,0 +1,20 @@
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types'
/**
 * Common base for reranker implementations.
 *
 * Holds the knowledge-base parameters and exposes the HTTP headers that
 * concrete rerankers send with their requests.
 */
export default abstract class BaseReranker {
  protected base: KnowledgeBaseParams

  /**
   * @param base Knowledge-base parameters; `rerankModel` must be set.
   * @throws Error when `base.rerankModel` is missing or empty.
   */
  constructor(base: KnowledgeBaseParams) {
    const hasRerankModel = Boolean(base.rerankModel)
    if (hasRerankModel === false) {
      throw new Error('Rerank model is required')
    }
    this.base = base
  }

  /**
   * Reorders `searchResults` by relevance to `query`.
   *
   * @param query The search text.
   * @param searchResults Chunks to rerank.
   * @returns The reranked chunks.
   */
  abstract rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]>

  /**
   * Builds the JSON request headers, using `base.apiKey` as a bearer token.
   */
  public defaultHeaders() {
    const bearerToken = `Bearer ${this.base.apiKey}`
    return {
      Authorization: bearerToken,
      'Content-Type': 'application/json'
    }
  }
}

View File

@ -0,0 +1,13 @@
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
/**
 * Placeholder reranker: its `rerank` always rejects with
 * "Method not implemented.".
 */
export default class DefaultReranker extends BaseReranker {
  // No explicit constructor: the implicit one forwards `base` to BaseReranker unchanged.

  /**
   * @returns A promise that always rejects, because this class performs no reranking.
   */
  rerank(): Promise<ExtractChunkData[]> {
    return Promise.reject(new Error('Method not implemented.'))
  }
}

View File

@ -0,0 +1,15 @@
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
import RerankerFactory from './RerankerFactory'
/**
 * Thin facade over the reranker implementation that `RerankerFactory`
 * creates for the given knowledge-base parameters.
 */
export default class Reranker {
  private readonly delegate: BaseReranker

  /**
   * @param base Knowledge-base parameters passed to `RerankerFactory.create`.
   */
  constructor(base: KnowledgeBaseParams) {
    this.delegate = RerankerFactory.create(base)
  }

  /**
   * Forwards the query and search results to the underlying reranker.
   *
   * @param query The search text.
   * @param searchResults Chunks to rerank.
   * @returns Whatever the underlying reranker resolves with.
   */
  public async rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> {
    const reranked = await this.delegate.rerank(query, searchResults)
    return reranked
  }
}

View File

@ -0,0 +1,14 @@
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
import DefaultReranker from './DefaultReranker'
import SiliconFlowReranker from './SiliconFlowReranker'
/**
 * Chooses the reranker implementation from `base.rerankModelProvider`.
 */
export default class RerankerFactory {
  /**
   * @param base Knowledge-base parameters, passed unchanged to the chosen reranker.
   * @returns A `SiliconFlowReranker` when the provider is `'silicon'`,
   *          otherwise a `DefaultReranker`.
   */
  static create(base: KnowledgeBaseParams): BaseReranker {
    switch (base.rerankModelProvider) {
      case 'silicon':
        return new SiliconFlowReranker(base)
      default:
        return new DefaultReranker(base)
    }
  }
}

View File

@ -0,0 +1,46 @@
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types'
import axios from 'axios'
import BaseReranker from './BaseReranker'
/**
 * Reranker that posts the query and chunk texts to `{baseURL}/rerank` and
 * reorders the chunks by the relevance scores in the response.
 *
 * NOTE(review): `apiKey` and `baseURL` are read from the same params object
 * that carries the embedding settings; if the rerank model belongs to another
 * provider these may be the wrong credentials — confirm against the caller.
 */
export default class SiliconFlowReranker extends BaseReranker {
  /**
   * Scores `searchResults` against `query` and returns them best-first.
   *
   * Only chunks whose index appears in the response are returned, each with
   * its `score` replaced by the returned relevance score (0 when missing).
   *
   * @param query The search text.
   * @param searchResults Chunks to rerank; an empty list short-circuits to `[]`.
   * @returns The scored chunks sorted by descending score.
   * @throws Error when the response has no `results` array; request failures
   *         from axios propagate unchanged.
   */
  public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
    // Nothing to rank: skip the network round trip entirely.
    if (searchResults.length === 0) {
      return []
    }

    // Strip trailing slashes so a base URL such as ".../v1/" does not become ".../v1//rerank".
    const url = `${this.base.baseURL.replace(/\/+$/, '')}/rerank`

    const { data } = await axios.post(
      url,
      {
        model: this.base.rerankModel,
        query,
        documents: searchResults.map((doc) => doc.pageContent),
        top_n: this.base.topN,
        // NOTE(review): chunkSize is forwarded as max_chunks_per_doc; the field name suggests a
        // chunk count rather than a chunk size — confirm against the provider's API reference.
        max_chunks_per_doc: this.base.chunkSize,
        overlap_tokens: this.base.chunkOverlap
      },
      {
        headers: this.defaultHeaders()
      }
    )

    const rerankResults: unknown = data?.results
    if (!Array.isArray(rerankResults)) {
      throw new Error('Rerank response does not contain a "results" array')
    }

    // Map each original document index to its relevance score.
    const scoreByIndex = new Map<number, number>()
    for (const result of rerankResults) {
      if (typeof result?.index === 'number') {
        scoreByIndex.set(result.index, result.relevance_score || 0)
      }
    }

    // Keep only documents the service scored (top_n may drop some), then sort best-first.
    const reranked: ExtractChunkData[] = []
    searchResults.forEach((doc, index) => {
      const score = scoreByIndex.get(index)
      if (score !== undefined) {
        reranked.push({ ...doc, score })
      }
    })
    return reranked.sort((a, b) => b.score - a.score)
  }
}

View File

@ -23,6 +23,7 @@ import { SitemapLoader } from '@llm-tools/embedjs-loader-sitemap'
import { WebLoader } from '@llm-tools/embedjs-loader-web' import { WebLoader } from '@llm-tools/embedjs-loader-web'
import { AzureOpenAiEmbeddings, OpenAiEmbeddings } from '@llm-tools/embedjs-openai' import { AzureOpenAiEmbeddings, OpenAiEmbeddings } from '@llm-tools/embedjs-openai'
import { addFileLoader } from '@main/loader' import { addFileLoader } from '@main/loader'
import Reranker from '@main/reranker/Reranker'
import { proxyManager } from '@main/services/ProxyManager' import { proxyManager } from '@main/services/ProxyManager'
import { windowService } from '@main/services/WindowService' import { windowService } from '@main/services/WindowService'
import { getInstanceName } from '@main/utils' import { getInstanceName } from '@main/utils'
@ -482,6 +483,13 @@ class KnowledgeService {
const ragApplication = await this.getRagApplication(base) const ragApplication = await this.getRagApplication(base)
return await ragApplication.search(search) return await ragApplication.search(search)
} }
/**
 * Reranks previously retrieved chunks for a search string by delegating to a
 * `Reranker` built from the given knowledge-base parameters.
 *
 * @param _ IPC invoke event (unused).
 * @param payload The search text, knowledge-base parameters and chunks to rerank.
 * @returns The chunks as returned by the reranker.
 */
public rerank = async (
  _: Electron.IpcMainInvokeEvent,
  payload: { search: string; base: KnowledgeBaseParams; results: ExtractChunkData[] }
): Promise<ExtractChunkData[]> => {
  const { search, base, results } = payload
  const reranker = new Reranker(base)
  return reranker.rerank(search, results)
}
} }
export default new KnowledgeService() export default new KnowledgeService()

View File

@ -90,6 +90,15 @@ declare global {
base: KnowledgeBaseParams base: KnowledgeBaseParams
}) => Promise<void> }) => Promise<void>
search: ({ search, base }: { search: string; base: KnowledgeBaseParams }) => Promise<ExtractChunkData[]> search: ({ search, base }: { search: string; base: KnowledgeBaseParams }) => Promise<ExtractChunkData[]>
rerank: ({
search,
base,
results
}: {
search: string
base: KnowledgeBaseParams
results: ExtractChunkData[]
}) => Promise<ExtractChunkData[]>
} }
window: { window: {
setMinimumSize: (width: number, height: number) => Promise<void> setMinimumSize: (width: number, height: number) => Promise<void>

View File

@ -1,4 +1,5 @@
import { electronAPI } from '@electron-toolkit/preload' import { electronAPI } from '@electron-toolkit/preload'
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
import { FileType, KnowledgeBaseParams, KnowledgeItem, MCPServer, Shortcut, WebDavConfig } from '@types' import { FileType, KnowledgeBaseParams, KnowledgeItem, MCPServer, Shortcut, WebDavConfig } from '@types'
import { contextBridge, ipcRenderer, OpenDialogOptions, shell } from 'electron' import { contextBridge, ipcRenderer, OpenDialogOptions, shell } from 'electron'
@ -75,7 +76,9 @@ const api = {
remove: ({ uniqueId, uniqueIds, base }: { uniqueId: string; uniqueIds: string[]; base: KnowledgeBaseParams }) => remove: ({ uniqueId, uniqueIds, base }: { uniqueId: string; uniqueIds: string[]; base: KnowledgeBaseParams }) =>
ipcRenderer.invoke('knowledge-base:remove', { uniqueId, uniqueIds, base }), ipcRenderer.invoke('knowledge-base:remove', { uniqueId, uniqueIds, base }),
search: ({ search, base }: { search: string; base: KnowledgeBaseParams }) => search: ({ search, base }: { search: string; base: KnowledgeBaseParams }) =>
ipcRenderer.invoke('knowledge-base:search', { search, base }) ipcRenderer.invoke('knowledge-base:search', { search, base }),
rerank: ({ search, base, results }: { search: string; base: KnowledgeBaseParams; results: ExtractChunkData[] }) =>
ipcRenderer.invoke('knowledge-base:rerank', { search, base, results })
}, },
window: { window: {
setMinimumSize: (width: number, height: number) => ipcRenderer.invoke('window:set-minimum-size', width, height), setMinimumSize: (width: number, height: number) => ipcRenderer.invoke('window:set-minimum-size', width, height),

View File

@ -176,6 +176,10 @@ export const REASONING_REGEX =
// Embedding models // Embedding models
export const EMBEDDING_REGEX = /(?:^text-|embed|bge-|e5-|LLM2Vec|retrieval|uae-|gte-|jina-clip|jina-embeddings)/i export const EMBEDDING_REGEX = /(?:^text-|embed|bge-|e5-|LLM2Vec|retrieval|uae-|gte-|jina-clip|jina-embeddings)/i
// Rerank models: case-insensitive, unanchored match against the model id.
// NOTE(review): "retrieval" also appears in EMBEDDING_REGEX and "rerank" in NOT_SUPPORTED_REGEX,
// so some ids match more than one of these patterns — confirm the overlap is intended.
export const RERANKING_REGEX = /(?:rerank|re-rank|re-ranker|re-ranking|retrieval|retriever)/i
export const NOT_SUPPORTED_REGEX = /(?:^tts|rerank|whisper|speech)/i export const NOT_SUPPORTED_REGEX = /(?:^tts|rerank|whisper|speech)/i
// Tool calling models // Tool calling models
@ -1880,6 +1884,13 @@ export function isEmbeddingModel(model: Model): boolean {
return EMBEDDING_REGEX.test(model.id) || model.type?.includes('embedding') || false return EMBEDDING_REGEX.test(model.id) || model.type?.includes('embedding') || false
} }
/**
 * Tells whether a model's id matches `RERANKING_REGEX`.
 *
 * @param model The model to inspect; a missing model yields `false`.
 * @returns `true` when the id matches the rerank pattern.
 */
export function isRerankModel(model: Model): boolean {
  return model ? RERANKING_REGEX.test(model.id) : false
}
export function isVisionModel(model: Model): boolean { export function isVisionModel(model: Model): boolean {
if (!model) { if (!model) {
return false return false

View File

@ -362,7 +362,11 @@
"title": "Knowledge Base", "title": "Knowledge Base",
"url_added": "URL added", "url_added": "URL added",
"url_placeholder": "Enter URL, multiple URLs separated by Enter", "url_placeholder": "Enter URL, multiple URLs separated by Enter",
"urls": "URLs" "urls": "URLs",
"topN": "Number of results returned",
"topN_placeholder": "Not set",
"topN_too_large_or_small": "The number of results returned cannot be greater than 100 or less than 1.",
"topN_tooltip": "The number of matching results returned; the larger the value, the more matching results, but also the more tokens consumed."
}, },
"languages": { "languages": {
"arabic": "Arabic", "arabic": "Arabic",
@ -533,7 +537,9 @@
"function_calling": "Function Calling" "function_calling": "Function Calling"
}, },
"vision": "Vision", "vision": "Vision",
"websearch": "WebSearch" "websearch": "WebSearch",
"rerank_model": "Reordering Model",
"rerank_model_tooltip": "Click the Manage button in Settings -> Model Services to add."
}, },
"navbar": { "navbar": {
"expand": "Expand Dialog", "expand": "Expand Dialog",
@ -1083,4 +1089,4 @@
"visualization": "Visualization" "visualization": "Visualization"
} }
} }
} }

View File

@ -362,7 +362,11 @@
"title": "ナレッジベース", "title": "ナレッジベース",
"url_added": "URLが追加されました", "url_added": "URLが追加されました",
"url_placeholder": "URLを入力, 複数のURLはEnterで区切る", "url_placeholder": "URLを入力, 複数のURLはEnterで区切る",
"urls": "URL" "urls": "URL",
"topN": "返却される結果の数",
"topN_placeholder": "未設定",
"topN_too_large_or_small": "結果の数は100より大きくてはならず、1より小さくてはなりません。",
"topN_tooltip": "返されるマッチ結果の数は、数値が大きいほどマッチ結果が多くなりますが、消費されるトークンも増えます。"
}, },
"languages": { "languages": {
"arabic": "アラビア語", "arabic": "アラビア語",
@ -533,7 +537,9 @@
"function_calling": "関数呼び出し" "function_calling": "関数呼び出し"
}, },
"vision": "画像", "vision": "画像",
"websearch": "ウェブ検索" "websearch": "ウェブ検索",
"rerank_model": "再順序付けモデル",
"rerank_model_tooltip": "設定->モデルサービスに移動し、管理ボタンをクリックして追加します。"
}, },
"navbar": { "navbar": {
"expand": "ダイアログを展開", "expand": "ダイアログを展開",
@ -1083,4 +1089,4 @@
"visualization": "可視化" "visualization": "可視化"
} }
} }
} }

View File

@ -362,7 +362,11 @@
"title": "База знаний", "title": "База знаний",
"url_added": "URL добавлен", "url_added": "URL добавлен",
"url_placeholder": "Введите URL, несколько URL через Enter", "url_placeholder": "Введите URL, несколько URL через Enter",
"urls": "URL-адреса" "urls": "URL-адреса",
"topN": "Количество возвращаемых результатов",
"topN_placeholder": "Не установлено",
"topN_too_large_or_small": "Количество возвращаемых результатов не может быть больше 100 или меньше 1.",
"topN_tooltip": "Количество возвращаемых совпадений; чем больше значение, тем больше совпадений, но и потребление токенов тоже возрастает."
}, },
"languages": { "languages": {
"arabic": "Арабский", "arabic": "Арабский",
@ -539,7 +543,9 @@
"function_calling": "Вызов функции" "function_calling": "Вызов функции"
}, },
"vision": "Визуальные", "vision": "Визуальные",
"websearch": "Веб-поисковые" "websearch": "Веб-поисковые",
"rerank_model": "Модель переупорядочивания",
"rerank_model_tooltip": "В настройках -> Служба модели нажмите кнопку \"Управление\", чтобы добавить."
}, },
"navbar": { "navbar": {
"expand": "Развернуть диалоговое окно", "expand": "Развернуть диалоговое окно",
@ -1083,4 +1089,4 @@
"visualization": "Визуализация" "visualization": "Визуализация"
} }
} }
} }

View File

@ -359,6 +359,10 @@
"threshold_placeholder": "未设置", "threshold_placeholder": "未设置",
"threshold_too_large_or_small": "阈值不能大于1或小于0", "threshold_too_large_or_small": "阈值不能大于1或小于0",
"threshold_tooltip": "用于衡量用户问题与知识库内容之间的相关性0-1", "threshold_tooltip": "用于衡量用户问题与知识库内容之间的相关性0-1",
"topN": "返回结果数量",
"topN_placeholder": "未设置",
"topN_too_large_or_small": "返回结果数量不能大于100或小于1",
"topN_tooltip": "返回的匹配结果数量,数值越大,匹配结果越多,但消耗的 Token 也越多",
"title": "知识库", "title": "知识库",
"url_added": "网址已添加", "url_added": "网址已添加",
"url_placeholder": "请输入网址, 多个网址用回车分隔", "url_placeholder": "请输入网址, 多个网址用回车分隔",
@ -510,6 +514,8 @@
"embedding": "嵌入", "embedding": "嵌入",
"embedding_model": "嵌入模型", "embedding_model": "嵌入模型",
"embedding_model_tooltip": "在设置->模型服务中点击管理按钮添加", "embedding_model_tooltip": "在设置->模型服务中点击管理按钮添加",
"rerank_model": "重排序模型",
"rerank_model_tooltip": "在设置->模型服务中点击管理按钮添加",
"free": "免费", "free": "免费",
"no_matches": "无可用模型", "no_matches": "无可用模型",
"parameter_name": "参数名称", "parameter_name": "参数名称",
@ -1083,4 +1089,4 @@
"visualization": "可视化" "visualization": "可视化"
} }
} }
} }

View File

@ -362,7 +362,11 @@
"title": "知識庫", "title": "知識庫",
"url_added": "網址已新增", "url_added": "網址已新增",
"url_placeholder": "請輸入網址,多個網址用換行符號分隔", "url_placeholder": "請輸入網址,多個網址用換行符號分隔",
"urls": "網址" "urls": "網址",
"topN": "返回結果數量",
"topN_placeholder": "未設定",
"topN_too_large_or_small": "返回結果數量不能大於100或小於1",
"topN_tooltip": "返回的匹配結果數量,數值越大,匹配結果越多,但消耗的 Token 也越多"
}, },
"languages": { "languages": {
"arabic": "阿拉伯文", "arabic": "阿拉伯文",
@ -533,7 +537,9 @@
"function_calling": "函數調用" "function_calling": "函數調用"
}, },
"vision": "視覺", "vision": "視覺",
"websearch": "網路搜尋" "websearch": "網路搜尋",
"rerank_model": "重排序模型",
"rerank_model_tooltip": "在設定->模型服務中點擊管理按鈕添加"
}, },
"navbar": { "navbar": {
"expand": "伸縮對話框", "expand": "伸縮對話框",
@ -1083,4 +1089,4 @@
"visualization": "視覺化" "visualization": "視覺化"
} }
} }
} }

View File

@ -1,5 +1,5 @@
import { TopView } from '@renderer/components/TopView' import { TopView } from '@renderer/components/TopView'
import { isEmbeddingModel } from '@renderer/config/models' import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { useKnowledgeBases } from '@renderer/hooks/useKnowledge' import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
import { useProviders } from '@renderer/hooks/useProvider' import { useProviders } from '@renderer/hooks/useProvider'
import AiProvider from '@renderer/providers/AiProvider' import AiProvider from '@renderer/providers/AiProvider'
@ -20,6 +20,7 @@ interface ShowParams {
interface FormData { interface FormData {
name: string name: string
model: string model: string
rerankModel: string
} }
interface Props extends ShowParams { interface Props extends ShowParams {
@ -37,6 +38,11 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
.map((p) => p.models) .map((p) => p.models)
.flat() .flat()
.filter((model) => isEmbeddingModel(model)) .filter((model) => isEmbeddingModel(model))
// All models, across every provider, whose id is recognised as a rerank model.
// (A leftover debug console.log that fired on every render was removed.)
const rerankModels = providers
  .map((p) => p.models)
  .flat()
  .filter((model) => isRerankModel(model))
const nameInputRef = useRef<any>(null) const nameInputRef = useRef<any>(null)
const selectOptions = providers const selectOptions = providers
@ -53,10 +59,25 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
})) }))
.filter((group) => group.options.length > 0) .filter((group) => group.options.length > 0)
const rerankSelectOptions = providers
.filter((p) => p.models.length > 0)
.map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name,
options: sortBy(p.models, 'name')
.filter((model) => isRerankModel(model))
.map((m) => ({
label: m.name,
value: getModelUniqId(m)
}))
}))
.filter((group) => group.options.length > 0)
const onOk = async () => { const onOk = async () => {
try { try {
const values = await form.validateFields() const values = await form.validateFields()
const selectedModel = find(allModels, JSON.parse(values.model)) as Model const selectedModel = find(allModels, JSON.parse(values.model)) as Model
// The rerank model field is optional. JSON.parse(undefined) throws a SyntaxError, so only
// resolve the model when one was actually selected; otherwise leave it undefined.
const selectedRerankModel = values.rerankModel
  ? (find(rerankModels, JSON.parse(values.rerankModel)) as Model)
  : undefined
if (selectedModel) { if (selectedModel) {
setLoading(true) setLoading(true)
@ -82,6 +103,7 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
id: nanoid(), id: nanoid(),
name: values.name, name: values.name,
model: selectedModel, model: selectedModel,
rerankModel: selectedRerankModel,
dimensions, dimensions,
items: [], items: [],
created_at: Date.now(), created_at: Date.now(),
@ -134,6 +156,14 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
rules={[{ required: true, message: t('message.error.enter.model') }]}> rules={[{ required: true, message: t('message.error.enter.model') }]}>
<Select style={{ width: '100%' }} options={selectOptions} placeholder={t('settings.models.empty')} /> <Select style={{ width: '100%' }} options={selectOptions} placeholder={t('settings.models.empty')} />
</Form.Item> </Form.Item>
<Form.Item
name="rerankModel"
label={t('models.rerank_model')}
tooltip={{ title: t('models.rerank_model_tooltip'), placement: 'right' }}
rules={[{ required: false, message: t('message.error.enter.model') }]}>
<Select style={{ width: '100%' }} options={rerankSelectOptions} placeholder={t('settings.models.empty')} />
</Form.Item>
</Form> </Form>
</Modal> </Modal>
) )

View File

@ -41,8 +41,16 @@ const PopupContainer: React.FC<Props> = ({ base, resolve }) => {
search: value, search: value,
base: getKnowledgeBaseParams(base) base: getKnowledgeBaseParams(base)
}) })
let rerankResult = searchResults
if (base.rerankModel) {
rerankResult = await window.api.knowledgeBase.rerank({
search: value,
base: getKnowledgeBaseParams(base),
results: searchResults
})
}
const results = await Promise.all( const results = await Promise.all(
searchResults.map(async (item) => { rerankResult.map(async (item) => {
const file = await getFileFromUrl(item.metadata.source) const file = await getFileFromUrl(item.metadata.source)
return { ...item, file } return { ...item, file }
}) })

View File

@ -2,7 +2,7 @@ import { WarningOutlined } from '@ant-design/icons'
import { TopView } from '@renderer/components/TopView' import { TopView } from '@renderer/components/TopView'
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant' import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
import { getEmbeddingMaxContext } from '@renderer/config/embedings' import { getEmbeddingMaxContext } from '@renderer/config/embedings'
import { isEmbeddingModel } from '@renderer/config/models' import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { useKnowledge } from '@renderer/hooks/useKnowledge' import { useKnowledge } from '@renderer/hooks/useKnowledge'
import { useProviders } from '@renderer/hooks/useProvider' import { useProviders } from '@renderer/hooks/useProvider'
import { getModelUniqId } from '@renderer/services/ModelService' import { getModelUniqId } from '@renderer/services/ModelService'
@ -23,6 +23,8 @@ interface FormData {
chunkSize?: number chunkSize?: number
chunkOverlap?: number chunkOverlap?: number
threshold?: number threshold?: number
rerankModel?: string
topN?: number
} }
interface Props extends ShowParams { interface Props extends ShowParams {
@ -59,6 +61,20 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
})) }))
.filter((group) => group.options.length > 0) .filter((group) => group.options.length > 0)
const rerankSelectOptions = providers
.filter((p) => p.models.length > 0)
.map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name,
options: sortBy(p.models, 'name')
.filter((model) => isRerankModel(model))
.map((m) => ({
label: m.name,
value: getModelUniqId(m)
}))
}))
.filter((group) => group.options.length > 0)
const onOk = async () => { const onOk = async () => {
try { try {
const values = await form.validateFields() const values = await form.validateFields()
@ -68,7 +84,11 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
documentCount: values.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, documentCount: values.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT,
chunkSize: values.chunkSize, chunkSize: values.chunkSize,
chunkOverlap: values.chunkOverlap, chunkOverlap: values.chunkOverlap,
threshold: values.threshold ?? undefined threshold: values.threshold ?? undefined,
rerankModel: values.rerankModel
? providers.flatMap((p) => p.models).find((m) => getModelUniqId(m) === values.rerankModel)
: undefined,
topN: values.topN
} }
updateKnowledgeBase(newBase) updateKnowledgeBase(newBase)
setOpen(false) setOpen(false)
@ -116,6 +136,20 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
<Select style={{ width: '100%' }} options={selectOptions} placeholder={t('settings.models.empty')} disabled /> <Select style={{ width: '100%' }} options={selectOptions} placeholder={t('settings.models.empty')} disabled />
</Form.Item> </Form.Item>
<Form.Item
name="rerankModel"
label={t('models.rerank_model')}
initialValue={getModelUniqId(base.rerankModel)}
tooltip={{ title: t('models.rerank_model_tooltip'), placement: 'right' }}
rules={[{ required: false, message: t('message.error.enter.model') }]}>
<Select
style={{ width: '100%' }}
options={rerankSelectOptions}
placeholder={t('settings.models.empty')}
allowClear
/>
</Form.Item>
<Form.Item <Form.Item
name="documentCount" name="documentCount"
label={t('knowledge.document_count')} label={t('knowledge.document_count')}
@ -193,6 +227,22 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
]}> ]}>
<InputNumber placeholder={t('knowledge.threshold_placeholder')} step={0.1} style={{ width: '100%' }} /> <InputNumber placeholder={t('knowledge.threshold_placeholder')} step={0.1} style={{ width: '100%' }} />
</Form.Item> </Form.Item>
{/* Optional number of results to return; bounds follow the 1–100 range stated in the validation message. */}
<Form.Item
  name="topN"
  label={t('knowledge.topN')}
  initialValue={base.topN}
  tooltip={{ title: t('knowledge.topN_tooltip'), placement: 'right' }}
  rules={[
    {
      validator(_, value) {
        // Compare against null/undefined explicitly so that 0 is rejected instead of
        // slipping through a truthiness check; an empty field stays valid ("not set").
        if (value != null && (value < 1 || value > 100)) {
          return Promise.reject(new Error(t('knowledge.topN_too_large_or_small')))
        }
        return Promise.resolve()
      }
    }
  ]}>
  <InputNumber placeholder={t('knowledge.topN_placeholder')} style={{ width: '100%' }} />
</Form.Item>
</Form> </Form>
<Alert message={t('knowledge.chunk_size_change_warning')} type="warning" showIcon icon={<WarningOutlined />} /> <Alert message={t('knowledge.chunk_size_change_warning')} type="warning" showIcon icon={<WarningOutlined />} />
</Modal> </Modal>

View File

@ -39,7 +39,10 @@ export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams
apiVersion: provider.apiVersion, apiVersion: provider.apiVersion,
baseURL: host, baseURL: host,
chunkSize, chunkSize,
chunkOverlap: base.chunkOverlap chunkOverlap: base.chunkOverlap,
rerankModel: base.rerankModel?.id,
rerankModelProvider: base.rerankModel?.provider,
topN: base.topN
} }
} }
@ -92,8 +95,17 @@ export const getKnowledgeBaseReference = async (base: KnowledgeBase, message: Me
}) })
) )
const _searchResults = await Promise.all( let rerankResults = searchResults
searchResults.map(async (item) => { if (base.rerankModel) {
rerankResults = await window.api.knowledgeBase.rerank({
search: message.content,
base: getKnowledgeBaseParams(base),
results: searchResults
})
}
const processdResults = await Promise.all(
rerankResults.map(async (item) => {
const file = await getFileFromUrl(item.metadata.source) const file = await getFileFromUrl(item.metadata.source)
return { ...item, file } return { ...item, file }
}) })
@ -102,7 +114,7 @@ export const getKnowledgeBaseReference = async (base: KnowledgeBase, message: Me
const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT
const references = await Promise.all( const references = await Promise.all(
take(_searchResults, documentCount).map(async (item, index) => { take(processdResults, documentCount).map(async (item, index) => {
const baseItem = base.items.find((i) => i.uniqueId === item.metadata.uniqueLoaderId) const baseItem = base.items.find((i) => i.uniqueId === item.metadata.uniqueLoaderId)
return { return {
id: index + 1, id: index + 1,

View File

@ -254,6 +254,8 @@ export interface KnowledgeBase {
chunkSize?: number chunkSize?: number
chunkOverlap?: number chunkOverlap?: number
threshold?: number threshold?: number
rerankModel?: Model
topN?: number
} }
export type KnowledgeBaseParams = { export type KnowledgeBaseParams = {
@ -265,6 +267,9 @@ export type KnowledgeBaseParams = {
baseURL: string baseURL: string
chunkSize?: number chunkSize?: number
chunkOverlap?: number chunkOverlap?: number
rerankModel?: string
rerankModelProvider?: string
topN?: number
} }
export type GenerateImageParams = { export type GenerateImageParams = {