feat(knowledge base): enhance knowledge base with rerank model
This commit is contained in:
parent
359f6e36e9
commit
b50f8a4c11
@ -189,6 +189,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
|||||||
ipcMain.handle('knowledge-base:add', KnowledgeService.add)
|
ipcMain.handle('knowledge-base:add', KnowledgeService.add)
|
||||||
ipcMain.handle('knowledge-base:remove', KnowledgeService.remove)
|
ipcMain.handle('knowledge-base:remove', KnowledgeService.remove)
|
||||||
ipcMain.handle('knowledge-base:search', KnowledgeService.search)
|
ipcMain.handle('knowledge-base:search', KnowledgeService.search)
|
||||||
|
ipcMain.handle('knowledge-base:rerank', KnowledgeService.rerank)
|
||||||
|
|
||||||
// window
|
// window
|
||||||
ipcMain.handle('window:set-minimum-size', (_, width: number, height: number) => {
|
ipcMain.handle('window:set-minimum-size', (_, width: number, height: number) => {
|
||||||
|
|||||||
20
src/main/reranker/BaseReranker.ts
Normal file
20
src/main/reranker/BaseReranker.ts
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
|
||||||
|
import { KnowledgeBaseParams } from '@types'
|
||||||
|
|
||||||
|
export default abstract class BaseReranker {
|
||||||
|
protected base: KnowledgeBaseParams
|
||||||
|
constructor(base: KnowledgeBaseParams) {
|
||||||
|
if (!base.rerankModel) {
|
||||||
|
throw new Error('Rerank model is required')
|
||||||
|
}
|
||||||
|
this.base = base
|
||||||
|
}
|
||||||
|
abstract rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]>
|
||||||
|
|
||||||
|
public defaultHeaders() {
|
||||||
|
return {
|
||||||
|
Authorization: `Bearer ${this.base.apiKey}`,
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
13
src/main/reranker/DefaultReranker.ts
Normal file
13
src/main/reranker/DefaultReranker.ts
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
|
||||||
|
import { KnowledgeBaseParams } from '@types'
|
||||||
|
|
||||||
|
import BaseReranker from './BaseReranker'
|
||||||
|
|
||||||
|
export default class DefaultReranker extends BaseReranker {
|
||||||
|
constructor(base: KnowledgeBaseParams) {
|
||||||
|
super(base)
|
||||||
|
}
|
||||||
|
async rerank(): Promise<ExtractChunkData[]> {
|
||||||
|
throw new Error('Method not implemented.')
|
||||||
|
}
|
||||||
|
}
|
||||||
15
src/main/reranker/Reranker.ts
Normal file
15
src/main/reranker/Reranker.ts
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
|
||||||
|
import { KnowledgeBaseParams } from '@types'
|
||||||
|
|
||||||
|
import BaseReranker from './BaseReranker'
|
||||||
|
import RerankerFactory from './RerankerFactory'
|
||||||
|
|
||||||
|
export default class Reranker {
|
||||||
|
private sdk: BaseReranker
|
||||||
|
constructor(base: KnowledgeBaseParams) {
|
||||||
|
this.sdk = RerankerFactory.create(base)
|
||||||
|
}
|
||||||
|
public async rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> {
|
||||||
|
return this.sdk.rerank(query, searchResults)
|
||||||
|
}
|
||||||
|
}
|
||||||
14
src/main/reranker/RerankerFactory.ts
Normal file
14
src/main/reranker/RerankerFactory.ts
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
import { KnowledgeBaseParams } from '@types'
|
||||||
|
|
||||||
|
import BaseReranker from './BaseReranker'
|
||||||
|
import DefaultReranker from './DefaultReranker'
|
||||||
|
import SiliconFlowReranker from './SiliconFlowReranker'
|
||||||
|
|
||||||
|
export default class RerankerFactory {
|
||||||
|
static create(base: KnowledgeBaseParams): BaseReranker {
|
||||||
|
if (base.rerankModelProvider === 'silicon') {
|
||||||
|
return new SiliconFlowReranker(base)
|
||||||
|
}
|
||||||
|
return new DefaultReranker(base)
|
||||||
|
}
|
||||||
|
}
|
||||||
46
src/main/reranker/SiliconFlowReranker.ts
Normal file
46
src/main/reranker/SiliconFlowReranker.ts
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
|
||||||
|
import { KnowledgeBaseParams } from '@types'
|
||||||
|
import axios from 'axios'
|
||||||
|
|
||||||
|
import BaseReranker from './BaseReranker'
|
||||||
|
|
||||||
|
export default class SiliconFlowReranker extends BaseReranker {
|
||||||
|
constructor(base: KnowledgeBaseParams) {
|
||||||
|
super(base)
|
||||||
|
}
|
||||||
|
|
||||||
|
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
||||||
|
const url = `${this.base.baseURL}/rerank`
|
||||||
|
|
||||||
|
const { data } = await axios.post(
|
||||||
|
url,
|
||||||
|
{
|
||||||
|
model: this.base.rerankModel,
|
||||||
|
query,
|
||||||
|
documents: searchResults.map((doc) => doc.pageContent),
|
||||||
|
top_n: this.base.topN,
|
||||||
|
max_chunks_per_doc: this.base.chunkSize,
|
||||||
|
overlap_tokens: this.base.chunkOverlap
|
||||||
|
},
|
||||||
|
{
|
||||||
|
headers: this.defaultHeaders()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
const rerankResults = data.results
|
||||||
|
const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0]))
|
||||||
|
|
||||||
|
return searchResults
|
||||||
|
.map((doc: ExtractChunkData, index: number) => {
|
||||||
|
const score = resultMap.get(index)
|
||||||
|
if (score === undefined) return undefined
|
||||||
|
|
||||||
|
return {
|
||||||
|
...doc,
|
||||||
|
score
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.filter((doc): doc is ExtractChunkData => doc !== undefined)
|
||||||
|
.sort((a, b) => b.score - a.score)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -23,6 +23,7 @@ import { SitemapLoader } from '@llm-tools/embedjs-loader-sitemap'
|
|||||||
import { WebLoader } from '@llm-tools/embedjs-loader-web'
|
import { WebLoader } from '@llm-tools/embedjs-loader-web'
|
||||||
import { AzureOpenAiEmbeddings, OpenAiEmbeddings } from '@llm-tools/embedjs-openai'
|
import { AzureOpenAiEmbeddings, OpenAiEmbeddings } from '@llm-tools/embedjs-openai'
|
||||||
import { addFileLoader } from '@main/loader'
|
import { addFileLoader } from '@main/loader'
|
||||||
|
import Reranker from '@main/reranker/Reranker'
|
||||||
import { proxyManager } from '@main/services/ProxyManager'
|
import { proxyManager } from '@main/services/ProxyManager'
|
||||||
import { windowService } from '@main/services/WindowService'
|
import { windowService } from '@main/services/WindowService'
|
||||||
import { getInstanceName } from '@main/utils'
|
import { getInstanceName } from '@main/utils'
|
||||||
@ -482,6 +483,13 @@ class KnowledgeService {
|
|||||||
const ragApplication = await this.getRagApplication(base)
|
const ragApplication = await this.getRagApplication(base)
|
||||||
return await ragApplication.search(search)
|
return await ragApplication.search(search)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public rerank = async (
|
||||||
|
_: Electron.IpcMainInvokeEvent,
|
||||||
|
{ search, base, results }: { search: string; base: KnowledgeBaseParams; results: ExtractChunkData[] }
|
||||||
|
): Promise<ExtractChunkData[]> => {
|
||||||
|
return await new Reranker(base).rerank(search, results)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export default new KnowledgeService()
|
export default new KnowledgeService()
|
||||||
|
|||||||
9
src/preload/index.d.ts
vendored
9
src/preload/index.d.ts
vendored
@ -90,6 +90,15 @@ declare global {
|
|||||||
base: KnowledgeBaseParams
|
base: KnowledgeBaseParams
|
||||||
}) => Promise<void>
|
}) => Promise<void>
|
||||||
search: ({ search, base }: { search: string; base: KnowledgeBaseParams }) => Promise<ExtractChunkData[]>
|
search: ({ search, base }: { search: string; base: KnowledgeBaseParams }) => Promise<ExtractChunkData[]>
|
||||||
|
rerank: ({
|
||||||
|
search,
|
||||||
|
base,
|
||||||
|
results
|
||||||
|
}: {
|
||||||
|
search: string
|
||||||
|
base: KnowledgeBaseParams
|
||||||
|
results: ExtractChunkData[]
|
||||||
|
}) => Promise<ExtractChunkData[]>
|
||||||
}
|
}
|
||||||
window: {
|
window: {
|
||||||
setMinimumSize: (width: number, height: number) => Promise<void>
|
setMinimumSize: (width: number, height: number) => Promise<void>
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import { electronAPI } from '@electron-toolkit/preload'
|
import { electronAPI } from '@electron-toolkit/preload'
|
||||||
|
import type { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
|
||||||
import { FileType, KnowledgeBaseParams, KnowledgeItem, MCPServer, Shortcut, WebDavConfig } from '@types'
|
import { FileType, KnowledgeBaseParams, KnowledgeItem, MCPServer, Shortcut, WebDavConfig } from '@types'
|
||||||
import { contextBridge, ipcRenderer, OpenDialogOptions, shell } from 'electron'
|
import { contextBridge, ipcRenderer, OpenDialogOptions, shell } from 'electron'
|
||||||
|
|
||||||
@ -75,7 +76,9 @@ const api = {
|
|||||||
remove: ({ uniqueId, uniqueIds, base }: { uniqueId: string; uniqueIds: string[]; base: KnowledgeBaseParams }) =>
|
remove: ({ uniqueId, uniqueIds, base }: { uniqueId: string; uniqueIds: string[]; base: KnowledgeBaseParams }) =>
|
||||||
ipcRenderer.invoke('knowledge-base:remove', { uniqueId, uniqueIds, base }),
|
ipcRenderer.invoke('knowledge-base:remove', { uniqueId, uniqueIds, base }),
|
||||||
search: ({ search, base }: { search: string; base: KnowledgeBaseParams }) =>
|
search: ({ search, base }: { search: string; base: KnowledgeBaseParams }) =>
|
||||||
ipcRenderer.invoke('knowledge-base:search', { search, base })
|
ipcRenderer.invoke('knowledge-base:search', { search, base }),
|
||||||
|
rerank: ({ search, base, results }: { search: string; base: KnowledgeBaseParams; results: ExtractChunkData[] }) =>
|
||||||
|
ipcRenderer.invoke('knowledge-base:rerank', { search, base, results })
|
||||||
},
|
},
|
||||||
window: {
|
window: {
|
||||||
setMinimumSize: (width: number, height: number) => ipcRenderer.invoke('window:set-minimum-size', width, height),
|
setMinimumSize: (width: number, height: number) => ipcRenderer.invoke('window:set-minimum-size', width, height),
|
||||||
|
|||||||
@ -176,6 +176,10 @@ export const REASONING_REGEX =
|
|||||||
|
|
||||||
// Embedding models
|
// Embedding models
|
||||||
export const EMBEDDING_REGEX = /(?:^text-|embed|bge-|e5-|LLM2Vec|retrieval|uae-|gte-|jina-clip|jina-embeddings)/i
|
export const EMBEDDING_REGEX = /(?:^text-|embed|bge-|e5-|LLM2Vec|retrieval|uae-|gte-|jina-clip|jina-embeddings)/i
|
||||||
|
|
||||||
|
// Rerank models
|
||||||
|
export const RERANKING_REGEX = /(?:rerank|re-rank|re-ranker|re-ranking|retrieval|retriever)/i
|
||||||
|
|
||||||
export const NOT_SUPPORTED_REGEX = /(?:^tts|rerank|whisper|speech)/i
|
export const NOT_SUPPORTED_REGEX = /(?:^tts|rerank|whisper|speech)/i
|
||||||
|
|
||||||
// Tool calling models
|
// Tool calling models
|
||||||
@ -1880,6 +1884,13 @@ export function isEmbeddingModel(model: Model): boolean {
|
|||||||
return EMBEDDING_REGEX.test(model.id) || model.type?.includes('embedding') || false
|
return EMBEDDING_REGEX.test(model.id) || model.type?.includes('embedding') || false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function isRerankModel(model: Model): boolean {
|
||||||
|
if (!model) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return RERANKING_REGEX.test(model.id) || false
|
||||||
|
}
|
||||||
|
|
||||||
export function isVisionModel(model: Model): boolean {
|
export function isVisionModel(model: Model): boolean {
|
||||||
if (!model) {
|
if (!model) {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@ -362,7 +362,11 @@
|
|||||||
"title": "Knowledge Base",
|
"title": "Knowledge Base",
|
||||||
"url_added": "URL added",
|
"url_added": "URL added",
|
||||||
"url_placeholder": "Enter URL, multiple URLs separated by Enter",
|
"url_placeholder": "Enter URL, multiple URLs separated by Enter",
|
||||||
"urls": "URLs"
|
"urls": "URLs",
|
||||||
|
"topN": "Number of results returned",
|
||||||
|
"topN_placeholder": "Not set",
|
||||||
|
"topN__too_large_or_small": "The number of results returned cannot be greater than 100 or less than 1.",
|
||||||
|
"topN_tooltip": "The number of matching results returned; the larger the value, the more matching results, but also the more tokens consumed."
|
||||||
},
|
},
|
||||||
"languages": {
|
"languages": {
|
||||||
"arabic": "Arabic",
|
"arabic": "Arabic",
|
||||||
@ -533,7 +537,9 @@
|
|||||||
"function_calling": "Function Calling"
|
"function_calling": "Function Calling"
|
||||||
},
|
},
|
||||||
"vision": "Vision",
|
"vision": "Vision",
|
||||||
"websearch": "WebSearch"
|
"websearch": "WebSearch",
|
||||||
|
"rerank_model": "Reordering Model",
|
||||||
|
"rerank_model_tooltip": "Click the Manage button in Settings -> Model Services to add."
|
||||||
},
|
},
|
||||||
"navbar": {
|
"navbar": {
|
||||||
"expand": "Expand Dialog",
|
"expand": "Expand Dialog",
|
||||||
|
|||||||
@ -362,7 +362,11 @@
|
|||||||
"title": "ナレッジベース",
|
"title": "ナレッジベース",
|
||||||
"url_added": "URLが追加されました",
|
"url_added": "URLが追加されました",
|
||||||
"url_placeholder": "URLを入力, 複数のURLはEnterで区切る",
|
"url_placeholder": "URLを入力, 複数のURLはEnterで区切る",
|
||||||
"urls": "URL"
|
"urls": "URL",
|
||||||
|
"topN": "返却される結果の数",
|
||||||
|
"topN_placeholder": "未設定",
|
||||||
|
"topN__too_large_or_small": "結果の数は100より大きくてはならず、1より小さくてはなりません。",
|
||||||
|
"topN_tooltip": "返されるマッチ結果の数は、数値が大きいほどマッチ結果が多くなりますが、消費されるトークンも増えます。"
|
||||||
},
|
},
|
||||||
"languages": {
|
"languages": {
|
||||||
"arabic": "アラビア語",
|
"arabic": "アラビア語",
|
||||||
@ -533,7 +537,9 @@
|
|||||||
"function_calling": "関数呼び出し"
|
"function_calling": "関数呼び出し"
|
||||||
},
|
},
|
||||||
"vision": "画像",
|
"vision": "画像",
|
||||||
"websearch": "ウェブ検索"
|
"websearch": "ウェブ検索",
|
||||||
|
"rerank_model": "再順序付けモデル",
|
||||||
|
"rerank_model_tooltip": "設定->モデルサービスに移動し、管理ボタンをクリックして追加します。"
|
||||||
},
|
},
|
||||||
"navbar": {
|
"navbar": {
|
||||||
"expand": "ダイアログを展開",
|
"expand": "ダイアログを展開",
|
||||||
|
|||||||
@ -362,7 +362,11 @@
|
|||||||
"title": "База знаний",
|
"title": "База знаний",
|
||||||
"url_added": "URL добавлен",
|
"url_added": "URL добавлен",
|
||||||
"url_placeholder": "Введите URL, несколько URL через Enter",
|
"url_placeholder": "Введите URL, несколько URL через Enter",
|
||||||
"urls": "URL-адреса"
|
"urls": "URL-адреса",
|
||||||
|
"topN": "Количество возвращаемых результатов",
|
||||||
|
"topN_placeholder": "Не установлено",
|
||||||
|
"topN__too_large_or_small": "Количество возвращаемых результатов не может быть больше 100 или меньше 1.",
|
||||||
|
"topN_tooltip": "Количество возвращаемых совпадений; чем больше значение, тем больше совпадений, но и потребление токенов тоже возрастает."
|
||||||
},
|
},
|
||||||
"languages": {
|
"languages": {
|
||||||
"arabic": "Арабский",
|
"arabic": "Арабский",
|
||||||
@ -539,7 +543,9 @@
|
|||||||
"function_calling": "Вызов функции"
|
"function_calling": "Вызов функции"
|
||||||
},
|
},
|
||||||
"vision": "Визуальные",
|
"vision": "Визуальные",
|
||||||
"websearch": "Веб-поисковые"
|
"websearch": "Веб-поисковые",
|
||||||
|
"rerank_model": "Модель переупорядочивания",
|
||||||
|
"rerank_model_tooltip": "В настройках -> Служба модели нажмите кнопку \"Управление\", чтобы добавить."
|
||||||
},
|
},
|
||||||
"navbar": {
|
"navbar": {
|
||||||
"expand": "Развернуть диалоговое окно",
|
"expand": "Развернуть диалоговое окно",
|
||||||
|
|||||||
@ -359,6 +359,10 @@
|
|||||||
"threshold_placeholder": "未设置",
|
"threshold_placeholder": "未设置",
|
||||||
"threshold_too_large_or_small": "阈值不能大于1或小于0",
|
"threshold_too_large_or_small": "阈值不能大于1或小于0",
|
||||||
"threshold_tooltip": "用于衡量用户问题与知识库内容之间的相关性(0-1)",
|
"threshold_tooltip": "用于衡量用户问题与知识库内容之间的相关性(0-1)",
|
||||||
|
"topN": "返回结果数量",
|
||||||
|
"topN_placeholder": "未设置",
|
||||||
|
"topN__too_large_or_small": "返回结果数量不能大于100或小于1",
|
||||||
|
"topN_tooltip": "返回的匹配结果数量,数值越大,匹配结果越多,但消耗的 Token 也越多",
|
||||||
"title": "知识库",
|
"title": "知识库",
|
||||||
"url_added": "网址已添加",
|
"url_added": "网址已添加",
|
||||||
"url_placeholder": "请输入网址, 多个网址用回车分隔",
|
"url_placeholder": "请输入网址, 多个网址用回车分隔",
|
||||||
@ -510,6 +514,8 @@
|
|||||||
"embedding": "嵌入",
|
"embedding": "嵌入",
|
||||||
"embedding_model": "嵌入模型",
|
"embedding_model": "嵌入模型",
|
||||||
"embedding_model_tooltip": "在设置->模型服务中点击管理按钮添加",
|
"embedding_model_tooltip": "在设置->模型服务中点击管理按钮添加",
|
||||||
|
"rerank_model": "重排序模型",
|
||||||
|
"rerank_model_tooltip": "在设置->模型服务中点击管理按钮添加",
|
||||||
"free": "免费",
|
"free": "免费",
|
||||||
"no_matches": "无可用模型",
|
"no_matches": "无可用模型",
|
||||||
"parameter_name": "参数名称",
|
"parameter_name": "参数名称",
|
||||||
|
|||||||
@ -362,7 +362,11 @@
|
|||||||
"title": "知識庫",
|
"title": "知識庫",
|
||||||
"url_added": "網址已新增",
|
"url_added": "網址已新增",
|
||||||
"url_placeholder": "請輸入網址,多個網址用換行符號分隔",
|
"url_placeholder": "請輸入網址,多個網址用換行符號分隔",
|
||||||
"urls": "網址"
|
"urls": "網址",
|
||||||
|
"topN": "返回結果數量",
|
||||||
|
"topN_placeholder": "未設定",
|
||||||
|
"topN__too_large_or_small": "返回結果數量不能大於100或小於1",
|
||||||
|
"topN_tooltip": "返回的匹配結果數量,數值越大,匹配結果越多,但消耗的 Token 也越多"
|
||||||
},
|
},
|
||||||
"languages": {
|
"languages": {
|
||||||
"arabic": "阿拉伯文",
|
"arabic": "阿拉伯文",
|
||||||
@ -533,7 +537,9 @@
|
|||||||
"function_calling": "函數調用"
|
"function_calling": "函數調用"
|
||||||
},
|
},
|
||||||
"vision": "視覺",
|
"vision": "視覺",
|
||||||
"websearch": "網路搜尋"
|
"websearch": "網路搜尋",
|
||||||
|
"rerank_model": "重排序模型",
|
||||||
|
"rerank_model_tooltip": "在設定->模型服務中點擊管理按鈕添加"
|
||||||
},
|
},
|
||||||
"navbar": {
|
"navbar": {
|
||||||
"expand": "伸縮對話框",
|
"expand": "伸縮對話框",
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import { TopView } from '@renderer/components/TopView'
|
import { TopView } from '@renderer/components/TopView'
|
||||||
import { isEmbeddingModel } from '@renderer/config/models'
|
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
|
||||||
import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
|
import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
|
||||||
import { useProviders } from '@renderer/hooks/useProvider'
|
import { useProviders } from '@renderer/hooks/useProvider'
|
||||||
import AiProvider from '@renderer/providers/AiProvider'
|
import AiProvider from '@renderer/providers/AiProvider'
|
||||||
@ -20,6 +20,7 @@ interface ShowParams {
|
|||||||
interface FormData {
|
interface FormData {
|
||||||
name: string
|
name: string
|
||||||
model: string
|
model: string
|
||||||
|
rerankModel: string
|
||||||
}
|
}
|
||||||
|
|
||||||
interface Props extends ShowParams {
|
interface Props extends ShowParams {
|
||||||
@ -37,6 +38,11 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
|
|||||||
.map((p) => p.models)
|
.map((p) => p.models)
|
||||||
.flat()
|
.flat()
|
||||||
.filter((model) => isEmbeddingModel(model))
|
.filter((model) => isEmbeddingModel(model))
|
||||||
|
const rerankModels = providers
|
||||||
|
.map((p) => p.models)
|
||||||
|
.flat()
|
||||||
|
.filter((model) => isRerankModel(model))
|
||||||
|
console.log('rerankModels', rerankModels)
|
||||||
const nameInputRef = useRef<any>(null)
|
const nameInputRef = useRef<any>(null)
|
||||||
|
|
||||||
const selectOptions = providers
|
const selectOptions = providers
|
||||||
@ -53,10 +59,25 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
|
|||||||
}))
|
}))
|
||||||
.filter((group) => group.options.length > 0)
|
.filter((group) => group.options.length > 0)
|
||||||
|
|
||||||
|
const rerankSelectOptions = providers
|
||||||
|
.filter((p) => p.models.length > 0)
|
||||||
|
.map((p) => ({
|
||||||
|
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||||
|
title: p.name,
|
||||||
|
options: sortBy(p.models, 'name')
|
||||||
|
.filter((model) => isRerankModel(model))
|
||||||
|
.map((m) => ({
|
||||||
|
label: m.name,
|
||||||
|
value: getModelUniqId(m)
|
||||||
|
}))
|
||||||
|
}))
|
||||||
|
.filter((group) => group.options.length > 0)
|
||||||
|
|
||||||
const onOk = async () => {
|
const onOk = async () => {
|
||||||
try {
|
try {
|
||||||
const values = await form.validateFields()
|
const values = await form.validateFields()
|
||||||
const selectedModel = find(allModels, JSON.parse(values.model)) as Model
|
const selectedModel = find(allModels, JSON.parse(values.model)) as Model
|
||||||
|
const selectedRerankModel = find(rerankModels, JSON.parse(values.rerankModel)) as Model
|
||||||
|
|
||||||
if (selectedModel) {
|
if (selectedModel) {
|
||||||
setLoading(true)
|
setLoading(true)
|
||||||
@ -82,6 +103,7 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
|
|||||||
id: nanoid(),
|
id: nanoid(),
|
||||||
name: values.name,
|
name: values.name,
|
||||||
model: selectedModel,
|
model: selectedModel,
|
||||||
|
rerankModel: selectedRerankModel,
|
||||||
dimensions,
|
dimensions,
|
||||||
items: [],
|
items: [],
|
||||||
created_at: Date.now(),
|
created_at: Date.now(),
|
||||||
@ -134,6 +156,14 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
|
|||||||
rules={[{ required: true, message: t('message.error.enter.model') }]}>
|
rules={[{ required: true, message: t('message.error.enter.model') }]}>
|
||||||
<Select style={{ width: '100%' }} options={selectOptions} placeholder={t('settings.models.empty')} />
|
<Select style={{ width: '100%' }} options={selectOptions} placeholder={t('settings.models.empty')} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
name="rerankModel"
|
||||||
|
label={t('models.rerank_model')}
|
||||||
|
tooltip={{ title: t('models.rerank_model_tooltip'), placement: 'right' }}
|
||||||
|
rules={[{ required: false, message: t('message.error.enter.model') }]}>
|
||||||
|
<Select style={{ width: '100%' }} options={rerankSelectOptions} placeholder={t('settings.models.empty')} />
|
||||||
|
</Form.Item>
|
||||||
</Form>
|
</Form>
|
||||||
</Modal>
|
</Modal>
|
||||||
)
|
)
|
||||||
|
|||||||
@ -41,8 +41,16 @@ const PopupContainer: React.FC<Props> = ({ base, resolve }) => {
|
|||||||
search: value,
|
search: value,
|
||||||
base: getKnowledgeBaseParams(base)
|
base: getKnowledgeBaseParams(base)
|
||||||
})
|
})
|
||||||
|
let rerankResult = searchResults
|
||||||
|
if (base.rerankModel) {
|
||||||
|
rerankResult = await window.api.knowledgeBase.rerank({
|
||||||
|
search: value,
|
||||||
|
base: getKnowledgeBaseParams(base),
|
||||||
|
results: searchResults
|
||||||
|
})
|
||||||
|
}
|
||||||
const results = await Promise.all(
|
const results = await Promise.all(
|
||||||
searchResults.map(async (item) => {
|
rerankResult.map(async (item) => {
|
||||||
const file = await getFileFromUrl(item.metadata.source)
|
const file = await getFileFromUrl(item.metadata.source)
|
||||||
return { ...item, file }
|
return { ...item, file }
|
||||||
})
|
})
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import { WarningOutlined } from '@ant-design/icons'
|
|||||||
import { TopView } from '@renderer/components/TopView'
|
import { TopView } from '@renderer/components/TopView'
|
||||||
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
|
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
|
||||||
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
|
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
|
||||||
import { isEmbeddingModel } from '@renderer/config/models'
|
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
|
||||||
import { useKnowledge } from '@renderer/hooks/useKnowledge'
|
import { useKnowledge } from '@renderer/hooks/useKnowledge'
|
||||||
import { useProviders } from '@renderer/hooks/useProvider'
|
import { useProviders } from '@renderer/hooks/useProvider'
|
||||||
import { getModelUniqId } from '@renderer/services/ModelService'
|
import { getModelUniqId } from '@renderer/services/ModelService'
|
||||||
@ -23,6 +23,8 @@ interface FormData {
|
|||||||
chunkSize?: number
|
chunkSize?: number
|
||||||
chunkOverlap?: number
|
chunkOverlap?: number
|
||||||
threshold?: number
|
threshold?: number
|
||||||
|
rerankModel?: string
|
||||||
|
topN?: number
|
||||||
}
|
}
|
||||||
|
|
||||||
interface Props extends ShowParams {
|
interface Props extends ShowParams {
|
||||||
@ -59,6 +61,20 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
|||||||
}))
|
}))
|
||||||
.filter((group) => group.options.length > 0)
|
.filter((group) => group.options.length > 0)
|
||||||
|
|
||||||
|
const rerankSelectOptions = providers
|
||||||
|
.filter((p) => p.models.length > 0)
|
||||||
|
.map((p) => ({
|
||||||
|
label: p.isSystem ? t(`provider.${p.id}`) : p.name,
|
||||||
|
title: p.name,
|
||||||
|
options: sortBy(p.models, 'name')
|
||||||
|
.filter((model) => isRerankModel(model))
|
||||||
|
.map((m) => ({
|
||||||
|
label: m.name,
|
||||||
|
value: getModelUniqId(m)
|
||||||
|
}))
|
||||||
|
}))
|
||||||
|
.filter((group) => group.options.length > 0)
|
||||||
|
|
||||||
const onOk = async () => {
|
const onOk = async () => {
|
||||||
try {
|
try {
|
||||||
const values = await form.validateFields()
|
const values = await form.validateFields()
|
||||||
@ -68,7 +84,11 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
|||||||
documentCount: values.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT,
|
documentCount: values.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT,
|
||||||
chunkSize: values.chunkSize,
|
chunkSize: values.chunkSize,
|
||||||
chunkOverlap: values.chunkOverlap,
|
chunkOverlap: values.chunkOverlap,
|
||||||
threshold: values.threshold ?? undefined
|
threshold: values.threshold ?? undefined,
|
||||||
|
rerankModel: values.rerankModel
|
||||||
|
? providers.flatMap((p) => p.models).find((m) => getModelUniqId(m) === values.rerankModel)
|
||||||
|
: undefined,
|
||||||
|
topN: values.topN
|
||||||
}
|
}
|
||||||
updateKnowledgeBase(newBase)
|
updateKnowledgeBase(newBase)
|
||||||
setOpen(false)
|
setOpen(false)
|
||||||
@ -116,6 +136,20 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
|||||||
<Select style={{ width: '100%' }} options={selectOptions} placeholder={t('settings.models.empty')} disabled />
|
<Select style={{ width: '100%' }} options={selectOptions} placeholder={t('settings.models.empty')} disabled />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
|
||||||
|
<Form.Item
|
||||||
|
name="rerankModel"
|
||||||
|
label={t('models.rerank_model')}
|
||||||
|
initialValue={getModelUniqId(base.rerankModel)}
|
||||||
|
tooltip={{ title: t('models.rerank_model_tooltip'), placement: 'right' }}
|
||||||
|
rules={[{ required: false, message: t('message.error.enter.model') }]}>
|
||||||
|
<Select
|
||||||
|
style={{ width: '100%' }}
|
||||||
|
options={rerankSelectOptions}
|
||||||
|
placeholder={t('settings.models.empty')}
|
||||||
|
allowClear
|
||||||
|
/>
|
||||||
|
</Form.Item>
|
||||||
|
|
||||||
<Form.Item
|
<Form.Item
|
||||||
name="documentCount"
|
name="documentCount"
|
||||||
label={t('knowledge.document_count')}
|
label={t('knowledge.document_count')}
|
||||||
@ -193,6 +227,22 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
|||||||
]}>
|
]}>
|
||||||
<InputNumber placeholder={t('knowledge.threshold_placeholder')} step={0.1} style={{ width: '100%' }} />
|
<InputNumber placeholder={t('knowledge.threshold_placeholder')} step={0.1} style={{ width: '100%' }} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
<Form.Item
|
||||||
|
name="topN"
|
||||||
|
label={t('knowledge.topN')}
|
||||||
|
initialValue={base.topN}
|
||||||
|
rules={[
|
||||||
|
{
|
||||||
|
validator(_, value) {
|
||||||
|
if (value && (value < 0 || value > 10)) {
|
||||||
|
return Promise.reject(new Error(t('knowledge.topN_too_large_or_small')))
|
||||||
|
}
|
||||||
|
return Promise.resolve()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]}>
|
||||||
|
<InputNumber placeholder={t('knowledge.topN_placeholder')} style={{ width: '100%' }} />
|
||||||
|
</Form.Item>
|
||||||
</Form>
|
</Form>
|
||||||
<Alert message={t('knowledge.chunk_size_change_warning')} type="warning" showIcon icon={<WarningOutlined />} />
|
<Alert message={t('knowledge.chunk_size_change_warning')} type="warning" showIcon icon={<WarningOutlined />} />
|
||||||
</Modal>
|
</Modal>
|
||||||
|
|||||||
@ -39,7 +39,10 @@ export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams
|
|||||||
apiVersion: provider.apiVersion,
|
apiVersion: provider.apiVersion,
|
||||||
baseURL: host,
|
baseURL: host,
|
||||||
chunkSize,
|
chunkSize,
|
||||||
chunkOverlap: base.chunkOverlap
|
chunkOverlap: base.chunkOverlap,
|
||||||
|
rerankModel: base.rerankModel?.id,
|
||||||
|
rerankModelProvider: base.rerankModel?.provider,
|
||||||
|
topN: base.topN
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -92,8 +95,17 @@ export const getKnowledgeBaseReference = async (base: KnowledgeBase, message: Me
|
|||||||
})
|
})
|
||||||
)
|
)
|
||||||
|
|
||||||
const _searchResults = await Promise.all(
|
let rerankResults = searchResults
|
||||||
searchResults.map(async (item) => {
|
if (base.rerankModel) {
|
||||||
|
rerankResults = await window.api.knowledgeBase.rerank({
|
||||||
|
search: message.content,
|
||||||
|
base: getKnowledgeBaseParams(base),
|
||||||
|
results: searchResults
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const processdResults = await Promise.all(
|
||||||
|
rerankResults.map(async (item) => {
|
||||||
const file = await getFileFromUrl(item.metadata.source)
|
const file = await getFileFromUrl(item.metadata.source)
|
||||||
return { ...item, file }
|
return { ...item, file }
|
||||||
})
|
})
|
||||||
@ -102,7 +114,7 @@ export const getKnowledgeBaseReference = async (base: KnowledgeBase, message: Me
|
|||||||
const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT
|
const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT
|
||||||
|
|
||||||
const references = await Promise.all(
|
const references = await Promise.all(
|
||||||
take(_searchResults, documentCount).map(async (item, index) => {
|
take(processdResults, documentCount).map(async (item, index) => {
|
||||||
const baseItem = base.items.find((i) => i.uniqueId === item.metadata.uniqueLoaderId)
|
const baseItem = base.items.find((i) => i.uniqueId === item.metadata.uniqueLoaderId)
|
||||||
return {
|
return {
|
||||||
id: index + 1,
|
id: index + 1,
|
||||||
|
|||||||
@ -254,6 +254,8 @@ export interface KnowledgeBase {
|
|||||||
chunkSize?: number
|
chunkSize?: number
|
||||||
chunkOverlap?: number
|
chunkOverlap?: number
|
||||||
threshold?: number
|
threshold?: number
|
||||||
|
rerankModel?: Model
|
||||||
|
topN?: number
|
||||||
}
|
}
|
||||||
|
|
||||||
export type KnowledgeBaseParams = {
|
export type KnowledgeBaseParams = {
|
||||||
@ -265,6 +267,9 @@ export type KnowledgeBaseParams = {
|
|||||||
baseURL: string
|
baseURL: string
|
||||||
chunkSize?: number
|
chunkSize?: number
|
||||||
chunkOverlap?: number
|
chunkOverlap?: number
|
||||||
|
rerankModel?: string
|
||||||
|
rerankModelProvider?: string
|
||||||
|
topN?: number
|
||||||
}
|
}
|
||||||
|
|
||||||
export type GenerateImageParams = {
|
export type GenerateImageParams = {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user