From ea990e78a5d42dfe33c8f2f895aaffb6f4b150ed Mon Sep 17 00:00:00 2001 From: Chen Tao <70054568+eeee0717@users.noreply.github.com> Date: Thu, 20 Mar 2025 22:32:54 +0800 Subject: [PATCH] feat: support jina reranker (#3658) --- src/main/reranker/JinaReranker.ts | 48 ++++++++++++++++ src/main/reranker/RerankerFactory.ts | 3 + src/main/reranker/SiliconFlowReranker.ts | 55 ++++++++++--------- .../components/KnowledgeSettingsPopup.tsx | 2 +- 4 files changed, 80 insertions(+), 28 deletions(-) create mode 100644 src/main/reranker/JinaReranker.ts diff --git a/src/main/reranker/JinaReranker.ts b/src/main/reranker/JinaReranker.ts new file mode 100644 index 00000000..dbee063c --- /dev/null +++ b/src/main/reranker/JinaReranker.ts @@ -0,0 +1,48 @@ +import { ExtractChunkData } from '@llm-tools/embedjs-interfaces' +import { KnowledgeBaseParams } from '@types' +import axios from 'axios' + +import BaseReranker from './BaseReranker' + +export default class JinaReranker extends BaseReranker { + constructor(base: KnowledgeBaseParams) { + super(base) + } + + public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise => { + const baseURL = this.base?.rerankBaseURL?.endsWith('/') + ? this.base.rerankBaseURL.slice(0, -1) + : this.base.rerankBaseURL + const url = `${baseURL}/rerank` + + const requestBody = { + model: this.base.rerankModel, + query, + documents: searchResults.map((doc) => doc.pageContent), + top_n: this.base.topN + } + + try { + const { data } = await axios.post(url, requestBody, { headers: this.defaultHeaders() }) + + const rerankResults = data.results + console.log(rerankResults) + const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0])) + return searchResults + .map((doc: ExtractChunkData, index: number) => { + const score = resultMap.get(index) + if (score === undefined) return undefined + + return { + ...doc, + score + } + }) + .filter((doc): doc is ExtractChunkData => doc !== undefined) + .sort((a, b) => b.score - a.score) + } catch (error) { + console.error('Jina Reranker API 错误:', error) + throw error + } + } +} diff --git a/src/main/reranker/RerankerFactory.ts b/src/main/reranker/RerankerFactory.ts index 2c15fe6c..0c2e8d7d 100644 --- a/src/main/reranker/RerankerFactory.ts +++ b/src/main/reranker/RerankerFactory.ts @@ -2,12 +2,15 @@ import { KnowledgeBaseParams } from '@types' import BaseReranker from './BaseReranker' import DefaultReranker from './DefaultReranker' +import JinaReranker from './JinaReranker' import SiliconFlowReranker from './SiliconFlowReranker' export default class RerankerFactory { static create(base: KnowledgeBaseParams): BaseReranker { if (base.rerankModelProvider === 'silicon') { return new SiliconFlowReranker(base) + } else if (base.rerankModelProvider === 'jina') { + return new JinaReranker(base) } return new DefaultReranker(base) } diff --git a/src/main/reranker/SiliconFlowReranker.ts b/src/main/reranker/SiliconFlowReranker.ts index 1730efda..ee82362e 100644 --- a/src/main/reranker/SiliconFlowReranker.ts +++ b/src/main/reranker/SiliconFlowReranker.ts @@ -15,35 +15,36 @@ export default class SiliconFlowReranker extends BaseReranker { : this.base.rerankBaseURL const url = `${baseURL}/rerank` - const { data } = await axios.post( - url, - { - model: this.base.rerankModel, - query, - documents: searchResults.map((doc) => doc.pageContent), - top_n: this.base.topN, - max_chunks_per_doc: this.base.chunkSize, - overlap_tokens: this.base.chunkOverlap - }, - { - headers: this.defaultHeaders() - } - ) + const requestBody = { + model: this.base.rerankModel, + query, + documents: searchResults.map((doc) => doc.pageContent), + top_n: this.base.topN, + max_chunks_per_doc: this.base.chunkSize, + overlap_tokens: this.base.chunkOverlap + } - const rerankResults = data.results - const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0])) + try { + const { data } = await axios.post(url, requestBody, { headers: this.defaultHeaders() }) - return searchResults - .map((doc: ExtractChunkData, index: number) => { - const score = resultMap.get(index) - if (score === undefined) return undefined + const rerankResults = data.results + const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0])) - return { - ...doc, - score - } - }) - .filter((doc): doc is ExtractChunkData => doc !== undefined) - .sort((a, b) => b.score - a.score) + return searchResults + .map((doc: ExtractChunkData, index: number) => { + const score = resultMap.get(index) + if (score === undefined) return undefined + + return { + ...doc, + score + } + }) + .filter((doc): doc is ExtractChunkData => doc !== undefined) + .sort((a, b) => b.score - a.score) + } catch (error) { + console.error('SiliconFlow Reranker API 错误:', error) + throw error + } } } diff --git a/src/renderer/src/pages/knowledge/components/KnowledgeSettingsPopup.tsx b/src/renderer/src/pages/knowledge/components/KnowledgeSettingsPopup.tsx index e1c1e8c1..56c00b66 100644 --- a/src/renderer/src/pages/knowledge/components/KnowledgeSettingsPopup.tsx +++ b/src/renderer/src/pages/knowledge/components/KnowledgeSettingsPopup.tsx @@ -142,10 +142,10 @@ const PopupContainer: React.FC = ({ base: _base, resolve }) => { name="rerankModel" label={t('models.rerank_model')} tooltip={{ title: t('models.rerank_model_tooltip'), placement: 'right' }} + initialValue={getModelUniqId(base.rerankModel) || undefined} rules={[{ required: false, message: t('message.error.enter.model') }]}>