refactor(reranker): 重构重排序功能以提高可维护性 (#4539)

* refactor(reranker): 重构重排序功能以提高可维护性

- 将 BaseReranker 类中的公共逻辑提取到受保护的方法中
- 优化了 JinaReranker、SiliconFlowReranker 和 VoyageReranker 的实现
- 新增 getRerankUrl 和 getRerankResult 方法以提高代码复用性
- 简化了重排序结果的处理逻辑

* refactor(reranker): 将 formatErrorMessage 方法的访问权限改为受保护

- 将 formatErrorMessage 方法的访问权限从公共 (public) 改为受保护 (protected)
- 这一更改限制了方法的访问范围，仅允许该类及其子类访问该方法
- 有助于提高代码的封装性和安全性
This commit is contained in:
Hamm 2025-04-08 16:53:31 +08:00 committed by GitHub
parent 1fcee6c829
commit b1bd5d0531
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 53 additions and 74 deletions

View File

@ -3,14 +3,60 @@ import { KnowledgeBaseParams } from '@types'
export default abstract class BaseReranker {
protected base: KnowledgeBaseParams
/**
 * Store the knowledge base parameters used by this reranker.
 *
 * @param base Knowledge base parameters; `rerankModel` must be set.
 * @throws Error when `base.rerankModel` is falsy.
 */
constructor(base: KnowledgeBaseParams) {
  // A reranker cannot be used without a model, so reject the configuration up front.
  const hasRerankModel = Boolean(base.rerankModel)
  if (hasRerankModel === false) {
    throw new Error('Rerank model is required')
  }

  this.base = base
}
/**
 * Rerank the given search results for a query.
 *
 * Declared abstract: every concrete reranker must supply its own implementation.
 *
 * @param query The query string passed to the reranker.
 * @param searchResults The chunks to be reranked.
 * @returns A promise that resolves to the reranked chunks.
 */
abstract rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]>
/**
 * Build the rerank request URL from the configured `rerankBaseURL`.
 *
 * All trailing slashes are stripped from the base URL, a `/v1` segment is
 * appended unless the URL already ends with one, and `/rerank` is appended last.
 *
 * @returns The rerank request URL, e.g. `https://host/v1/rerank`.
 * @throws Error when `rerankBaseURL` is missing or empty after trimming.
 */
protected getRerankUrl(): string {
  // Strip every trailing slash so "https://host//" and "https://host/v1/" normalize the same way.
  const trimmedBaseURL = this.base.rerankBaseURL?.replace(/\/+$/, '')
  if (!trimmedBaseURL) {
    // Fail fast instead of producing an unusable URL such as "undefined/rerank".
    throw new Error('Rerank base URL is required')
  }

  // The /v1 segment must be present, otherwise the request returns 404.
  const versionedBaseURL = trimmedBaseURL.endsWith('/v1') ? trimmedBaseURL : `${trimmedBaseURL}/v1`
  return `${versionedBaseURL}/rerank`
}
/**
 * Attach rerank scores to the search results and order them by relevance.
 *
 * Each entry of `rerankResults` refers to a search result by its position in
 * `searchResults`. Search results that no entry refers to are dropped.
 *
 * @param searchResults Candidate chunks; their array positions are matched against `index`.
 * @param rerankResults Index/score pairs produced by the reranker.
 * @returns Copies of the matched chunks with `score` replaced, sorted by descending score.
 * @protected
 */
protected getRerankResult(
  searchResults: ExtractChunkData[],
  rerankResults: Array<{
    index: number
    relevance_score: number
  }>
) {
  // Position -> score lookup; a falsy relevance_score is stored as 0,
  // and a repeated index keeps the last score seen.
  const scoreByIndex = new Map<number, number>()
  rerankResults.forEach(({ index, relevance_score }) => {
    scoreByIndex.set(index, relevance_score || 0)
  })

  // Keep only the chunks the reranker scored, copying each one with its new score.
  const scoredChunks: ExtractChunkData[] = []
  searchResults.forEach((chunk, position) => {
    const score = scoreByIndex.get(position)
    if (score !== undefined) {
      scoredChunks.push({ ...chunk, score })
    }
  })

  return scoredChunks.sort((left, right) => right.score - left.score)
}
public defaultHeaders() {
return {
Authorization: `Bearer ${this.base.rerankApiKey}`,
@ -18,7 +64,7 @@ export default abstract class BaseReranker {
}
}
public formatErrorMessage(url: string, error: any, requestBody: any) {
protected formatErrorMessage(url: string, error: any, requestBody: any) {
const errorDetails = {
url: url,
message: error.message,

View File

@ -10,16 +10,7 @@ export default class JinaReranker extends BaseReranker {
}
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
let baseURL = this.base?.rerankBaseURL?.endsWith('/')
? this.base.rerankBaseURL.slice(0, -1)
: this.base.rerankBaseURL
// 必须携带/v1否则会404
if (baseURL && !baseURL.endsWith('/v1')) {
baseURL = `${baseURL}/v1`
}
const url = `${baseURL}/rerank`
const url = this.getRerankUrl()
const requestBody = {
model: this.base.rerankModel,
@ -32,23 +23,9 @@ export default class JinaReranker extends BaseReranker {
const { data } = await axios.post(url, requestBody, { headers: this.defaultHeaders() })
const rerankResults = data.results
console.log(rerankResults)
const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0]))
return searchResults
.map((doc: ExtractChunkData, index: number) => {
const score = resultMap.get(index)
if (score === undefined) return undefined
return {
...doc,
score
}
})
.filter((doc): doc is ExtractChunkData => doc !== undefined)
.sort((a, b) => b.score - a.score)
return this.getRerankResult(searchResults, rerankResults)
} catch (error: any) {
const errorDetails = this.formatErrorMessage(url, error, requestBody)
console.error('Jina Reranker API Error:', errorDetails)
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
}

View File

@ -10,16 +10,7 @@ export default class SiliconFlowReranker extends BaseReranker {
}
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
let baseURL = this.base?.rerankBaseURL?.endsWith('/')
? this.base.rerankBaseURL.slice(0, -1)
: this.base.rerankBaseURL
// 必须携带/v1否则会404
if (baseURL && !baseURL.endsWith('/v1')) {
baseURL = `${baseURL}/v1`
}
const url = `${baseURL}/rerank`
const url = this.getRerankUrl()
const requestBody = {
model: this.base.rerankModel,
@ -34,20 +25,7 @@ export default class SiliconFlowReranker extends BaseReranker {
const { data } = await axios.post(url, requestBody, { headers: this.defaultHeaders() })
const rerankResults = data.results
const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0]))
return searchResults
.map((doc: ExtractChunkData, index: number) => {
const score = resultMap.get(index)
if (score === undefined) return undefined
return {
...doc,
score
}
})
.filter((doc): doc is ExtractChunkData => doc !== undefined)
.sort((a, b) => b.score - a.score)
return this.getRerankResult(searchResults, rerankResults)
} catch (error: any) {
const errorDetails = this.formatErrorMessage(url, error, requestBody)

View File

@ -10,15 +10,7 @@ export default class VoyageReranker extends BaseReranker {
}
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
let baseURL = this.base?.rerankBaseURL?.endsWith('/')
? this.base.rerankBaseURL.slice(0, -1)
: this.base.rerankBaseURL
if (baseURL && !baseURL.endsWith('/v1')) {
baseURL = `${baseURL}/v1`
}
const url = `${baseURL}/rerank`
const url = this.getRerankUrl()
const requestBody = {
model: this.base.rerankModel,
@ -37,21 +29,7 @@ export default class VoyageReranker extends BaseReranker {
})
const rerankResults = data.data
const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0]))
return searchResults
.map((doc: ExtractChunkData, index: number) => {
const score = resultMap.get(index)
if (score === undefined) return undefined
return {
...doc,
score
}
})
.filter((doc): doc is ExtractChunkData => doc !== undefined)
.sort((a, b) => b.score - a.score)
return this.getRerankResult(searchResults, rerankResults)
} catch (error: any) {
const errorDetails = this.formatErrorMessage(url, error, requestBody)