diff --git a/src/main/reranker/BaseReranker.ts b/src/main/reranker/BaseReranker.ts index 58a64269..4109f539 100644 --- a/src/main/reranker/BaseReranker.ts +++ b/src/main/reranker/BaseReranker.ts @@ -3,14 +3,60 @@ import { KnowledgeBaseParams } from '@types' export default abstract class BaseReranker { protected base: KnowledgeBaseParams + constructor(base: KnowledgeBaseParams) { if (!base.rerankModel) { throw new Error('Rerank model is required') } this.base = base } + abstract rerank(query: string, searchResults: ExtractChunkData[]): Promise + /** + * Get Rerank Request Url + */ + protected getRerankUrl() { + let baseURL = this.base?.rerankBaseURL?.endsWith('/') + ? this.base.rerankBaseURL.slice(0, -1) + : this.base.rerankBaseURL + // 必须携带/v1,否则会404 + if (baseURL && !baseURL.endsWith('/v1')) { + baseURL = `${baseURL}/v1` + } + + return `${baseURL}/rerank` + } + + /** + * Get Rerank Result + * @param searchResults + * @param rerankResults + * @protected + */ + protected getRerankResult( + searchResults: ExtractChunkData[], + rerankResults: Array<{ + index: number + relevance_score: number + }> + ) { + const resultMap = new Map(rerankResults.map((result) => [result.index, result.relevance_score || 0])) + + return searchResults + .map((doc: ExtractChunkData, index: number) => { + const score = resultMap.get(index) + if (score === undefined) return undefined + + return { + ...doc, + score + } + }) + .filter((doc): doc is ExtractChunkData => doc !== undefined) + .sort((a, b) => b.score - a.score) + } + public defaultHeaders() { return { Authorization: `Bearer ${this.base.rerankApiKey}`, @@ -18,7 +64,7 @@ export default abstract class BaseReranker { } } - public formatErrorMessage(url: string, error: any, requestBody: any) { + protected formatErrorMessage(url: string, error: any, requestBody: any) { const errorDetails = { url: url, message: error.message, diff --git a/src/main/reranker/JinaReranker.ts b/src/main/reranker/JinaReranker.ts index 718774ee..207ddcb9 100644 --- a/src/main/reranker/JinaReranker.ts +++ b/src/main/reranker/JinaReranker.ts @@ -10,16 +10,7 @@ export default class JinaReranker extends BaseReranker { } public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise => { - let baseURL = this.base?.rerankBaseURL?.endsWith('/') - ? this.base.rerankBaseURL.slice(0, -1) - : this.base.rerankBaseURL - - // 必须携带/v1,否则会404 - if (baseURL && !baseURL.endsWith('/v1')) { - baseURL = `${baseURL}/v1` - } - - const url = `${baseURL}/rerank` + const url = this.getRerankUrl() const requestBody = { model: this.base.rerankModel, @@ -32,23 +23,9 @@ export default class JinaReranker extends BaseReranker { const { data } = await axios.post(url, requestBody, { headers: this.defaultHeaders() }) const rerankResults = data.results - console.log(rerankResults) - const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0])) - return searchResults - .map((doc: ExtractChunkData, index: number) => { - const score = resultMap.get(index) - if (score === undefined) return undefined - - return { - ...doc, - score - } - }) - .filter((doc): doc is ExtractChunkData => doc !== undefined) - .sort((a, b) => b.score - a.score) + return this.getRerankResult(searchResults, rerankResults) } catch (error: any) { const errorDetails = this.formatErrorMessage(url, error, requestBody) - console.error('Jina Reranker API Error:', errorDetails) throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`) } diff --git a/src/main/reranker/SiliconFlowReranker.ts b/src/main/reranker/SiliconFlowReranker.ts index d37f547b..0a27cf7e 100644 --- a/src/main/reranker/SiliconFlowReranker.ts +++ b/src/main/reranker/SiliconFlowReranker.ts @@ -10,16 +10,7 @@ export default class SiliconFlowReranker extends BaseReranker { } public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise => { - let baseURL = this.base?.rerankBaseURL?.endsWith('/') - ? this.base.rerankBaseURL.slice(0, -1) - : this.base.rerankBaseURL - - // 必须携带/v1,否则会404 - if (baseURL && !baseURL.endsWith('/v1')) { - baseURL = `${baseURL}/v1` - } - - const url = `${baseURL}/rerank` + const url = this.getRerankUrl() const requestBody = { model: this.base.rerankModel, @@ -34,20 +25,7 @@ export default class SiliconFlowReranker extends BaseReranker { const { data } = await axios.post(url, requestBody, { headers: this.defaultHeaders() }) const rerankResults = data.results - const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0])) - - return searchResults - .map((doc: ExtractChunkData, index: number) => { - const score = resultMap.get(index) - if (score === undefined) return undefined - - return { - ...doc, - score - } - }) - .filter((doc): doc is ExtractChunkData => doc !== undefined) - .sort((a, b) => b.score - a.score) + return this.getRerankResult(searchResults, rerankResults) } catch (error: any) { const errorDetails = this.formatErrorMessage(url, error, requestBody) diff --git a/src/main/reranker/VoyageReranker.ts b/src/main/reranker/VoyageReranker.ts index 0cfc024e..a2c0f5f8 100644 --- a/src/main/reranker/VoyageReranker.ts +++ b/src/main/reranker/VoyageReranker.ts @@ -10,15 +10,7 @@ export default class VoyageReranker extends BaseReranker { } public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise => { - let baseURL = this.base?.rerankBaseURL?.endsWith('/') - ? this.base.rerankBaseURL.slice(0, -1) - : this.base.rerankBaseURL - - if (baseURL && !baseURL.endsWith('/v1')) { - baseURL = `${baseURL}/v1` - } - - const url = `${baseURL}/rerank` + const url = this.getRerankUrl() const requestBody = { model: this.base.rerankModel, @@ -37,21 +29,7 @@ export default class VoyageReranker extends BaseReranker { }) const rerankResults = data.data - - const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0])) - - return searchResults - .map((doc: ExtractChunkData, index: number) => { - const score = resultMap.get(index) - if (score === undefined) return undefined - - return { - ...doc, - score - } - }) - .filter((doc): doc is ExtractChunkData => doc !== undefined) - .sort((a, b) => b.score - a.score) + return this.getRerankResult(searchResults, rerankResults) } catch (error: any) { const errorDetails = this.formatErrorMessage(url, error, requestBody)