feat(knowledge base): enhance knowledge base with rerank model
This commit is contained in:
parent
359f6e36e9
commit
b50f8a4c11
@ -189,6 +189,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
ipcMain.handle('knowledge-base:add', KnowledgeService.add)
|
||||
ipcMain.handle('knowledge-base:remove', KnowledgeService.remove)
|
||||
ipcMain.handle('knowledge-base:search', KnowledgeService.search)
|
||||
ipcMain.handle('knowledge-base:rerank', KnowledgeService.rerank)
|
||||
|
||||
// window
|
||||
ipcMain.handle('window:set-minimum-size', (_, width: number, height: number) => {
|
||||
|
||||
20
src/main/reranker/BaseReranker.ts
Normal file
20
src/main/reranker/BaseReranker.ts
Normal file
@ -0,0 +1,20 @@
|
||||
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
export default abstract class BaseReranker {
|
||||
protected base: KnowledgeBaseParams
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
if (!base.rerankModel) {
|
||||
throw new Error('Rerank model is required')
|
||||
}
|
||||
this.base = base
|
||||
}
|
||||
abstract rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]>
|
||||
|
||||
public defaultHeaders() {
|
||||
return {
|
||||
Authorization: `Bearer ${this.base.apiKey}`,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
}
|
||||
}
|
||||
13
src/main/reranker/DefaultReranker.ts
Normal file
13
src/main/reranker/DefaultReranker.ts
Normal file
@ -0,0 +1,13 @@
|
||||
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
|
||||
export default class DefaultReranker extends BaseReranker {
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
super(base)
|
||||
}
|
||||
async rerank(): Promise<ExtractChunkData[]> {
|
||||
throw new Error('Method not implemented.')
|
||||
}
|
||||
}
|
||||
15
src/main/reranker/Reranker.ts
Normal file
15
src/main/reranker/Reranker.ts
Normal file
@ -0,0 +1,15 @@
|
||||
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
import RerankerFactory from './RerankerFactory'
|
||||
|
||||
export default class Reranker {
|
||||
private sdk: BaseReranker
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
this.sdk = RerankerFactory.create(base)
|
||||
}
|
||||
public async rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> {
|
||||
return this.sdk.rerank(query, searchResults)
|
||||
}
|
||||
}
|
||||
14
src/main/reranker/RerankerFactory.ts
Normal file
14
src/main/reranker/RerankerFactory.ts
Normal file
@ -0,0 +1,14 @@
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
import DefaultReranker from './DefaultReranker'
|
||||
import SiliconFlowReranker from './SiliconFlowReranker'
|
||||
|
||||
export default class RerankerFactory {
|
||||
static create(base: KnowledgeBaseParams): BaseReranker {
|
||||
if (base.rerankModelProvider === 'silicon') {
|
||||
return new SiliconFlowReranker(base)
|
||||
}
|
||||
return new DefaultReranker(base)
|
||||
}
|
||||
}
|
||||
46
src/main/reranker/SiliconFlowReranker.ts
Normal file
46
src/main/reranker/SiliconFlowReranker.ts
Normal file
@ -0,0 +1,46 @@
|
||||
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
import axios from 'axios'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
|
||||
export default class SiliconFlowReranker extends BaseReranker {
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
super(base)
|
||||
}
|
||||
|
||||
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
||||
const url = `${this.base.baseURL}/rerank`
|
||||
|
||||
const { data } = await axios.post(
|
||||
url,
|
||||
{
|
||||
model: this.base.rerankModel,
|
||||
query,
|
||||
documents: searchResults.map((doc) => doc.pageContent),
|
||||
top_n: this.base.topN,
|
||||
max_chunks_per_doc: this.base.chunkSize,
|
||||
overlap_tokens: this.base.chunkOverlap
|
||||
},
|
||||
{
|
||||
headers: this.defaultHeaders()
|
||||
}
|
||||
)
|
||||
|
||||
const rerankResults = data.results
|
||||
const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0]))
|
||||
|
||||
return searchResults
|
||||
.map((doc: ExtractChunkData, index: number) => {
|
||||
const score = resultMap.get(index)
|
||||
if (score === undefined) return undefined
|
||||
|
||||
return {
|
||||
...doc,
|
||||
score
|
||||
}
|
||||
})
|
||||
.filter((doc): doc is ExtractChunkData => doc !== undefined)
|
||||
.sort((a, b) => b.score - a.score)
|
||||
}
|
||||
}
|
||||
@ -23,6 +23,7 @@ import { SitemapLoader } from '@llm-tools/embedjs-loader-sitemap'
|
||||
import { WebLoader } from '@llm-tools/embedjs-loader-web'
|
||||
import { AzureOpenAiEmbeddings, OpenAiEmbeddings } from '@llm-tools/embedjs-openai'
|
||||
import { addFileLoader } from '@main/loader'
|
||||
import Reranker from '@main/reranker/Reranker'
|
||||
import { proxyManager } from '@main/services/ProxyManager'
|
||||
import { windowService } from '@main/services/WindowService'
|
||||
import { getInstanceName } from '@main/utils'
|
||||
@ -482,6 +483,13 @@ class KnowledgeService {
|
||||
const ragApplication = await this.getRagApplication(base)
|
||||
return await ragApplication.search(search)
|
||||
}
|
||||
|
||||
public rerank = async (
|
||||
_: Electron.IpcMainInvokeEvent,
|
||||
{ search, base, results }: { search: string; base: KnowledgeBaseParams; results: ExtractChunkData[] }
|
||||
): Promise<ExtractChunkData[]> => {
|
||||
return await new Reranker(base).rerank(search, results)
|
||||
}
|
||||
}
|
||||
|
||||
export default new KnowledgeService()
|
||||
|
||||
9
src/preload/index.d.ts
vendored
9
src/preload/index.d.ts
vendored
@ -90,6 +90,15 @@ declare global {
|
||||
base: KnowledgeBaseParams
|
||||
}) => Promise<void>
|
||||
search: ({ search, base }: { search: string; base: KnowledgeBaseParams }) => Promise<ExtractChunkData[]>
|
||||
rerank: ({
|
||||
search,
|
||||
base,
|
||||
results
|
||||
}: {
|
||||
search: string
|
||||
base: KnowledgeBaseParams
|
||||
results: ExtractChunkData[]
|
||||
}) => Promise<ExtractChunkData[]>
|
||||
}
|
||||
window: {
|
||||
setMinimumSize: (width: number, height: number) => Promise<void>
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import { electronAPI } from '@electron-toolkit/preload'
|
||||
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
|
||||
import { FileType, KnowledgeBaseParams, KnowledgeItem, MCPServer, Shortcut, WebDavConfig } from '@types'
|
||||
import { contextBridge, ipcRenderer, OpenDialogOptions, shell } from 'electron'
|
||||
|
||||
@ -75,7 +76,9 @@ const api = {
|
||||
remove: ({ uniqueId, uniqueIds, base }: { uniqueId: string; uniqueIds: string[]; base: KnowledgeBaseParams }) =>
|
||||
ipcRenderer.invoke('knowledge-base:remove', { uniqueId, uniqueIds, base }),
|
||||
search: ({ search, base }: { search: string; base: KnowledgeBaseParams }) =>
|
||||
ipcRenderer.invoke('knowledge-base:search', { search, base })
|
||||
ipcRenderer.invoke('knowledge-base:search', { search, base }),
|
||||
rerank: ({ search, base, results }: { search: string; base: KnowledgeBaseParams; results: ExtractChunkData[] }) =>
|
||||
ipcRenderer.invoke('knowledge-base:rerank', { search, base, results })
|
||||
},
|
||||
window: {
|
||||
setMinimumSize: (width: number, height: number) => ipcRenderer.invoke('window:set-minimum-size', width, height),
|
||||
|
||||
@ -176,6 +176,10 @@ export const REASONING_REGEX =
|
||||
|
||||
// Embedding models
|
||||
export const EMBEDDING_REGEX = /(?:^text-|embed|bge-|e5-|LLM2Vec|retrieval|uae-|gte-|jina-clip|jina-embeddings)/i
|
||||
|
||||
// Rerank models
|
||||
export const RERANKING_REGEX = /(?:rerank|re-rank|re-ranker|re-ranking|retrieval|retriever)/i
|
||||
|
||||
export const NOT_SUPPORTED_REGEX = /(?:^tts|rerank|whisper|speech)/i
|
||||
|
||||
// Tool calling models
|
||||
@ -1880,6 +1884,13 @@ export function isEmbeddingModel(model: Model): boolean {
|
||||
return EMBEDDING_REGEX.test(model.id) || model.type?.includes('embedding') || false
|
||||
}
|
||||
|
||||
export function isRerankModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
return RERANKING_REGEX.test(model.id) || false
|
||||
}
|
||||
|
||||
export function isVisionModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
|
||||
@ -362,7 +362,11 @@
|
||||
"title": "Knowledge Base",
|
||||
"url_added": "URL added",
|
||||
"url_placeholder": "Enter URL, multiple URLs separated by Enter",
|
||||
"urls": "URLs"
|
||||
"urls": "URLs",
|
||||
"topN": "Number of results returned",
|
||||
"topN_placeholder": "Not set",
|
||||
"topN__too_large_or_small": "The number of results returned cannot be greater than 100 or less than 1.",
|
||||
"topN_tooltip": "The number of matching results returned; the larger the value, the more matching results, but also the more tokens consumed."
|
||||
},
|
||||
"languages": {
|
||||
"arabic": "Arabic",
|
||||
@ -533,7 +537,9 @@
|
||||
"function_calling": "Function Calling"
|
||||
},
|
||||
"vision": "Vision",
|
||||
"websearch": "WebSearch"
|
||||
"websearch": "WebSearch",
|
||||
"rerank_model": "Reordering Model",
|
||||
"rerank_model_tooltip": "Click the Manage button in Settings -> Model Services to add."
|
||||
},
|
||||
"navbar": {
|
||||
"expand": "Expand Dialog",
|
||||
|
||||
@ -362,7 +362,11 @@
|
||||
"title": "ナレッジベース",
|
||||
"url_added": "URLが追加されました",
|
||||
"url_placeholder": "URLを入力, 複数のURLはEnterで区切る",
|
||||
"urls": "URL"
|
||||
"urls": "URL",
|
||||
"topN": "返却される結果の数",
|
||||
"topN_placeholder": "未設定",
|
||||
"topN__too_large_or_small": "結果の数は100より大きくてはならず、1より小さくてはなりません。",
|
||||
"topN_tooltip": "返されるマッチ結果の数は、数値が大きいほどマッチ結果が多くなりますが、消費されるトークンも増えます。"
|
||||
},
|
||||
"languages": {
|
||||
"arabic": "アラビア語",
|
||||
@ -533,7 +537,9 @@
|
||||
"function_calling": "関数呼び出し"
|
||||
},
|
||||
"vision": "画像",
|
||||
"websearch": "ウェブ検索"
|
||||
"websearch": "ウェブ検索",
|
||||
"rerank_model": "再順序付けモデル",
|
||||
"rerank_model_tooltip": "設定->モデルサービスに移動し、管理ボタンをクリックして追加します。"
|
||||
},
|
||||
"navbar": {
|
||||
"expand": "ダイアログを展開",
|
||||
|
||||
@ -362,7 +362,11 @@
|
||||
"title": "База знаний",
|
||||
"url_added": "URL добавлен",
|
||||
"url_placeholder": "Введите URL, несколько URL через Enter",
|
||||
"urls": "URL-адреса"
|
||||
"urls": "URL-адреса",
|
||||
"topN": "Количество возвращаемых результатов",
|
||||
"topN_placeholder": "Не установлено",
|
||||
"topN__too_large_or_small": "Количество возвращаемых результатов не может быть больше 100 или меньше 1.",
|
||||
"topN_tooltip": "Количество возвращаемых совпадений; чем больше значение, тем больше совпадений, но и потребление токенов тоже возрастает."
|
||||
},
|
||||
"languages": {
|
||||
"arabic": "Арабский",
|
||||
@ -539,7 +543,9 @@
|
||||
"function_calling": "Вызов функции"
|
||||
},
|
||||
"vision": "Визуальные",
|
||||
"websearch": "Веб-поисковые"
|
||||
"websearch": "Веб-поисковые",
|
||||
"rerank_model": "Модель переупорядочивания",
|
||||
"rerank_model_tooltip": "В настройках -> Служба модели нажмите кнопку \"Управление\", чтобы добавить."
|
||||
},
|
||||
"navbar": {
|
||||
"expand": "Развернуть диалоговое окно",
|
||||
|
||||
@ -359,6 +359,10 @@
|
||||
"threshold_placeholder": "未设置",
|
||||
"threshold_too_large_or_small": "阈值不能大于1或小于0",
|
||||
"threshold_tooltip": "用于衡量用户问题与知识库内容之间的相关性(0-1)",
|
||||
"topN": "返回结果数量",
|
||||
"topN_placeholder": "未设置",
|
||||
"topN__too_large_or_small": "返回结果数量不能大于100或小于1",
|
||||
"topN_tooltip": "返回的匹配结果数量,数值越大,匹配结果越多,但消耗的 Token 也越多",
|
||||
"title": "知识库",
|
||||
"url_added": "网址已添加",
|
||||
"url_placeholder": "请输入网址, 多个网址用回车分隔",
|
||||
@ -510,6 +514,8 @@
|
||||
"embedding": "嵌入",
|
||||
"embedding_model": "嵌入模型",
|
||||
"embedding_model_tooltip": "在设置->模型服务中点击管理按钮添加",
|
||||
"rerank_model": "重排序模型",
|
||||
"rerank_model_tooltip": "在设置->模型服务中点击管理按钮添加",
|
||||
"free": "免费",
|
||||
"no_matches": "无可用模型",
|
||||
"parameter_name": "参数名称",
|
||||
|
||||
@ -362,7 +362,11 @@
|
||||
"title": "知識庫",
|
||||
"url_added": "網址已新增",
|
||||
"url_placeholder": "請輸入網址,多個網址用換行符號分隔",
|
||||
"urls": "網址"
|
||||
"urls": "網址",
|
||||
"topN": "返回結果數量",
|
||||
"topN_placeholder": "未設定",
|
||||
"topN__too_large_or_small": "返回結果數量不能大於100或小於1",
|
||||
"topN_tooltip": "返回的匹配結果數量,數值越大,匹配結果越多,但消耗的 Token 也越多"
|
||||
},
|
||||
"languages": {
|
||||
"arabic": "阿拉伯文",
|
||||
@ -533,7 +537,9 @@
|
||||
"function_calling": "函數調用"
|
||||
},
|
||||
"vision": "視覺",
|
||||
"websearch": "網路搜尋"
|
||||
"websearch": "網路搜尋",
|
||||
"rerank_model": "重排序模型",
|
||||
"rerank_model_tooltip": "在設定->模型服務中點擊管理按鈕添加"
|
||||
},
|
||||
"navbar": {
|
||||
"expand": "伸縮對話框",
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { TopView } from '@renderer/components/TopView'
|
||||
import { isEmbeddingModel } from '@renderer/config/models'
|
||||
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
|
||||
import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
|
||||
import { useProviders } from '@renderer/hooks/useProvider'
|
||||
import AiProvider from '@renderer/providers/AiProvider'
|
||||
@ -20,6 +20,7 @@ interface ShowParams {
|
||||
interface FormData {
|
||||
name: string
|
||||
model: string
|
||||
rerankModel: string
|
||||
}
|
||||
|
||||
interface Props extends ShowParams {
|
||||
@ -37,6 +38,11 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
|
||||
.map((p) => p.models)
|
||||
.flat()
|
||||
.filter((model) => isEmbeddingModel(model))
|
||||
const rerankModels = providers
|
||||
.map((p) => p.models)
|
||||
.flat()
|
||||
.filter((model) => isRerankModel(model))
|
||||
console.log('rerankModels', rerankModels)
|
||||
const nameInputRef = useRef<any>(null)
|
||||
|
||||
const selectOptions = providers
|
||||
@ -53,10 +59,25 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
|
||||
}))
|
||||
.filter((group) => group.options.length > 0)
|
||||
|
||||
const rerankSelectOptions = providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.map((p) => ({
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
options: sortBy(p.models, 'name')
|
||||
.filter((model) => isRerankModel(model))
|
||||
.map((m) => ({
|
||||
label: m.name,
|
||||
value: getModelUniqId(m)
|
||||
}))
|
||||
}))
|
||||
.filter((group) => group.options.length > 0)
|
||||
|
||||
const onOk = async () => {
|
||||
try {
|
||||
const values = await form.validateFields()
|
||||
const selectedModel = find(allModels, JSON.parse(values.model)) as Model
|
||||
const selectedRerankModel = find(rerankModels, JSON.parse(values.rerankModel)) as Model
|
||||
|
||||
if (selectedModel) {
|
||||
setLoading(true)
|
||||
@ -82,6 +103,7 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
|
||||
id: nanoid(),
|
||||
name: values.name,
|
||||
model: selectedModel,
|
||||
rerankModel: selectedRerankModel,
|
||||
dimensions,
|
||||
items: [],
|
||||
created_at: Date.now(),
|
||||
@ -134,6 +156,14 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
|
||||
rules={[{ required: true, message: t('message.error.enter.model') }]}>
|
||||
<Select style={{ width: '100%' }} options={selectOptions} placeholder={t('settings.models.empty')} />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="rerankModel"
|
||||
label={t('models.rerank_model')}
|
||||
tooltip={{ title: t('models.rerank_model_tooltip'), placement: 'right' }}
|
||||
rules={[{ required: false, message: t('message.error.enter.model') }]}>
|
||||
<Select style={{ width: '100%' }} options={rerankSelectOptions} placeholder={t('settings.models.empty')} />
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</Modal>
|
||||
)
|
||||
|
||||
@ -41,8 +41,16 @@ const PopupContainer: React.FC<Props> = ({ base, resolve }) => {
|
||||
search: value,
|
||||
base: getKnowledgeBaseParams(base)
|
||||
})
|
||||
let rerankResult = searchResults
|
||||
if (base.rerankModel) {
|
||||
rerankResult = await window.api.knowledgeBase.rerank({
|
||||
search: value,
|
||||
base: getKnowledgeBaseParams(base),
|
||||
results: searchResults
|
||||
})
|
||||
}
|
||||
const results = await Promise.all(
|
||||
searchResults.map(async (item) => {
|
||||
rerankResult.map(async (item) => {
|
||||
const file = await getFileFromUrl(item.metadata.source)
|
||||
return { ...item, file }
|
||||
})
|
||||
|
||||
@ -2,7 +2,7 @@ import { WarningOutlined } from '@ant-design/icons'
|
||||
import { TopView } from '@renderer/components/TopView'
|
||||
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
|
||||
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
|
||||
import { isEmbeddingModel } from '@renderer/config/models'
|
||||
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
|
||||
import { useKnowledge } from '@renderer/hooks/useKnowledge'
|
||||
import { useProviders } from '@renderer/hooks/useProvider'
|
||||
import { getModelUniqId } from '@renderer/services/ModelService'
|
||||
@ -23,6 +23,8 @@ interface FormData {
|
||||
chunkSize?: number
|
||||
chunkOverlap?: number
|
||||
threshold?: number
|
||||
rerankModel?: string
|
||||
topN?: number
|
||||
}
|
||||
|
||||
interface Props extends ShowParams {
|
||||
@ -59,6 +61,20 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
||||
}))
|
||||
.filter((group) => group.options.length > 0)
|
||||
|
||||
const rerankSelectOptions = providers
|
||||
.filter((p) => p.models.length > 0)
|
||||
.map((p) => ({
|
||||
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||
title: p.name,
|
||||
options: sortBy(p.models, 'name')
|
||||
.filter((model) => isRerankModel(model))
|
||||
.map((m) => ({
|
||||
label: m.name,
|
||||
value: getModelUniqId(m)
|
||||
}))
|
||||
}))
|
||||
.filter((group) => group.options.length > 0)
|
||||
|
||||
const onOk = async () => {
|
||||
try {
|
||||
const values = await form.validateFields()
|
||||
@ -68,7 +84,11 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
||||
documentCount: values.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT,
|
||||
chunkSize: values.chunkSize,
|
||||
chunkOverlap: values.chunkOverlap,
|
||||
threshold: values.threshold ?? undefined
|
||||
threshold: values.threshold ?? undefined,
|
||||
rerankModel: values.rerankModel
|
||||
? providers.flatMap((p) => p.models).find((m) => getModelUniqId(m) === values.rerankModel)
|
||||
: undefined,
|
||||
topN: values.topN
|
||||
}
|
||||
updateKnowledgeBase(newBase)
|
||||
setOpen(false)
|
||||
@ -116,6 +136,20 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
||||
<Select style={{ width: '100%' }} options={selectOptions} placeholder={t('settings.models.empty')} disabled />
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="rerankModel"
|
||||
label={t('models.rerank_model')}
|
||||
initialValue={getModelUniqId(base.rerankModel)}
|
||||
tooltip={{ title: t('models.rerank_model_tooltip'), placement: 'right' }}
|
||||
rules={[{ required: false, message: t('message.error.enter.model') }]}>
|
||||
<Select
|
||||
style={{ width: '100%' }}
|
||||
options={rerankSelectOptions}
|
||||
placeholder={t('settings.models.empty')}
|
||||
allowClear
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
name="documentCount"
|
||||
label={t('knowledge.document_count')}
|
||||
@ -193,6 +227,22 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
||||
]}>
|
||||
<InputNumber placeholder={t('knowledge.threshold_placeholder')} step={0.1} style={{ width: '100%' }} />
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
name="topN"
|
||||
label={t('knowledge.topN')}
|
||||
initialValue={base.topN}
|
||||
rules={[
|
||||
{
|
||||
validator(_, value) {
|
||||
if (value && (value < 0 || value > 10)) {
|
||||
return Promise.reject(new Error(t('knowledge.topN_too_large_or_small')))
|
||||
}
|
||||
return Promise.resolve()
|
||||
}
|
||||
}
|
||||
]}>
|
||||
<InputNumber placeholder={t('knowledge.topN_placeholder')} style={{ width: '100%' }} />
|
||||
</Form.Item>
|
||||
</Form>
|
||||
<Alert message={t('knowledge.chunk_size_change_warning')} type="warning" showIcon icon={<WarningOutlined />} />
|
||||
</Modal>
|
||||
|
||||
@ -39,7 +39,10 @@ export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams
|
||||
apiVersion: provider.apiVersion,
|
||||
baseURL: host,
|
||||
chunkSize,
|
||||
chunkOverlap: base.chunkOverlap
|
||||
chunkOverlap: base.chunkOverlap,
|
||||
rerankModel: base.rerankModel?.id,
|
||||
rerankModelProvider: base.rerankModel?.provider,
|
||||
topN: base.topN
|
||||
}
|
||||
}
|
||||
|
||||
@ -92,8 +95,17 @@ export const getKnowledgeBaseReference = async (base: KnowledgeBase, message: Me
|
||||
})
|
||||
)
|
||||
|
||||
const _searchResults = await Promise.all(
|
||||
searchResults.map(async (item) => {
|
||||
let rerankResults = searchResults
|
||||
if (base.rerankModel) {
|
||||
rerankResults = await window.api.knowledgeBase.rerank({
|
||||
search: message.content,
|
||||
base: getKnowledgeBaseParams(base),
|
||||
results: searchResults
|
||||
})
|
||||
}
|
||||
|
||||
const processdResults = await Promise.all(
|
||||
rerankResults.map(async (item) => {
|
||||
const file = await getFileFromUrl(item.metadata.source)
|
||||
return { ...item, file }
|
||||
})
|
||||
@ -102,7 +114,7 @@ export const getKnowledgeBaseReference = async (base: KnowledgeBase, message: Me
|
||||
const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT
|
||||
|
||||
const references = await Promise.all(
|
||||
take(_searchResults, documentCount).map(async (item, index) => {
|
||||
take(processdResults, documentCount).map(async (item, index) => {
|
||||
const baseItem = base.items.find((i) => i.uniqueId === item.metadata.uniqueLoaderId)
|
||||
return {
|
||||
id: index + 1,
|
||||
|
||||
@ -254,6 +254,8 @@ export interface KnowledgeBase {
|
||||
chunkSize?: number
|
||||
chunkOverlap?: number
|
||||
threshold?: number
|
||||
rerankModel?: Model
|
||||
topN?: number
|
||||
}
|
||||
|
||||
export type KnowledgeBaseParams = {
|
||||
@ -265,6 +267,9 @@ export type KnowledgeBaseParams = {
|
||||
baseURL: string
|
||||
chunkSize?: number
|
||||
chunkOverlap?: number
|
||||
rerankModel?: string
|
||||
rerankModelProvider?: string
|
||||
topN?: number
|
||||
}
|
||||
|
||||
export type GenerateImageParams = {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user