refactor(reranker): 重构重排序功能以提高可维护性 (#4539)
* refactor(reranker): 重构重排序功能以提高可维护性
  - 将 BaseReranker 类中的公共逻辑提取到受保护的方法中
  - 优化了 JinaReranker、SiliconFlowReranker 和 VoyageReranker 的实现
  - 新增 getRerankUrl 和 getRerankResult 方法以提高代码复用性
  - 简化了重排序结果的处理逻辑
* refactor(reranker): 将 formatErrorMessage 方法的访问权限改为受保护
  - 将 formatErrorMessage 方法的访问权限从公共 (public) 改为受保护 (protected)
  - 这一更改限制了方法的访问范围，仅允许子类访问该方法
  - 有助于提高代码的封装性和安全性
This commit is contained in:
parent
1fcee6c829
commit
b1bd5d0531
@ -3,14 +3,60 @@ import { KnowledgeBaseParams } from '@types'
|
|||||||
|
|
||||||
export default abstract class BaseReranker {
|
export default abstract class BaseReranker {
|
||||||
protected base: KnowledgeBaseParams
|
protected base: KnowledgeBaseParams
|
||||||
|
|
||||||
/**
 * Keep a reference to the knowledge-base parameters used by this reranker.
 *
 * @param base Knowledge-base parameters; its `rerankModel` must be truthy.
 * @throws Error when `base.rerankModel` is missing or otherwise falsy.
 */
constructor(base: KnowledgeBaseParams) {
  const { rerankModel } = base
  // A reranker cannot build a request without a model, so reject early.
  if (!rerankModel) {
    throw new Error('Rerank model is required')
  }

  this.base = base
}
|
||||||
|
|
||||||
abstract rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]>
|
abstract rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]>
|
||||||
|
|
||||||
|
/**
 * Build the rerank request URL from the configured `rerankBaseURL`.
 *
 * Surrounding whitespace and every trailing slash are removed from the base URL, a `/v1`
 * segment is appended unless the base already ends with it, and `/rerank` is added last.
 *
 * @returns The rerank endpoint URL, in the form `<base>/v1/rerank`.
 * @throws Error when no rerank base URL is configured. Without this check the method
 *   would return the unusable string `undefined/rerank`.
 */
protected getRerankUrl(): string {
  // Strip all trailing slashes (not only one) so `https://host//` cannot become `https://host//v1/rerank`.
  const trimmedBase = this.base?.rerankBaseURL?.trim().replace(/\/+$/, '')
  if (!trimmedBase) {
    throw new Error('Rerank base URL is required')
  }

  // The /v1 segment is mandatory; without it the request returns 404.
  const versionedBase = trimmedBase.endsWith('/v1') ? trimmedBase : `${trimmedBase}/v1`

  return `${versionedBase}/rerank`
}
|
||||||
|
|
||||||
|
/**
 * Merge rerank scores into the original search results.
 *
 * Each rerank entry refers to a search result by its position (`index`). Search results that
 * have no matching rerank entry are dropped; the remaining ones are copied with their `score`
 * replaced by the rerank score and returned in descending score order.
 *
 * @param searchResults Search results, in the same order they were sent to the rerank API.
 * @param rerankResults Rerank entries, each pairing a position in `searchResults` with a relevance score.
 * @returns Copies of the matched search results, highest score first.
 * @protected
 */
protected getRerankResult(
  searchResults: ExtractChunkData[],
  rerankResults: Array<{
    index: number
    relevance_score: number
  }>
) {
  // Position -> score lookup; a falsy score is stored as 0, and a repeated index keeps its last score.
  const scoreByPosition = new Map<number, number>()
  for (const { index, relevance_score } of rerankResults) {
    scoreByPosition.set(index, relevance_score || 0)
  }

  const ranked: ExtractChunkData[] = []
  searchResults.forEach((doc: ExtractChunkData, position: number) => {
    const score = scoreByPosition.get(position)
    // Only results that the reranker actually scored are kept.
    if (score !== undefined) {
      ranked.push({ ...doc, score })
    }
  })

  ranked.sort((left, right) => right.score - left.score)
  return ranked
}
|
||||||
|
|
||||||
public defaultHeaders() {
|
public defaultHeaders() {
|
||||||
return {
|
return {
|
||||||
Authorization: `Bearer ${this.base.rerankApiKey}`,
|
Authorization: `Bearer ${this.base.rerankApiKey}`,
|
||||||
@ -18,7 +64,7 @@ export default abstract class BaseReranker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public formatErrorMessage(url: string, error: any, requestBody: any) {
|
protected formatErrorMessage(url: string, error: any, requestBody: any) {
|
||||||
const errorDetails = {
|
const errorDetails = {
|
||||||
url: url,
|
url: url,
|
||||||
message: error.message,
|
message: error.message,
|
||||||
|
|||||||
@ -10,16 +10,7 @@ export default class JinaReranker extends BaseReranker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
||||||
let baseURL = this.base?.rerankBaseURL?.endsWith('/')
|
const url = this.getRerankUrl()
|
||||||
? this.base.rerankBaseURL.slice(0, -1)
|
|
||||||
: this.base.rerankBaseURL
|
|
||||||
|
|
||||||
// 必须携带/v1,否则会404
|
|
||||||
if (baseURL && !baseURL.endsWith('/v1')) {
|
|
||||||
baseURL = `${baseURL}/v1`
|
|
||||||
}
|
|
||||||
|
|
||||||
const url = `${baseURL}/rerank`
|
|
||||||
|
|
||||||
const requestBody = {
|
const requestBody = {
|
||||||
model: this.base.rerankModel,
|
model: this.base.rerankModel,
|
||||||
@ -32,23 +23,9 @@ export default class JinaReranker extends BaseReranker {
|
|||||||
const { data } = await axios.post(url, requestBody, { headers: this.defaultHeaders() })
|
const { data } = await axios.post(url, requestBody, { headers: this.defaultHeaders() })
|
||||||
|
|
||||||
const rerankResults = data.results
|
const rerankResults = data.results
|
||||||
console.log(rerankResults)
|
return this.getRerankResult(searchResults, rerankResults)
|
||||||
const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0]))
|
|
||||||
return searchResults
|
|
||||||
.map((doc: ExtractChunkData, index: number) => {
|
|
||||||
const score = resultMap.get(index)
|
|
||||||
if (score === undefined) return undefined
|
|
||||||
|
|
||||||
return {
|
|
||||||
...doc,
|
|
||||||
score
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.filter((doc): doc is ExtractChunkData => doc !== undefined)
|
|
||||||
.sort((a, b) => b.score - a.score)
|
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
const errorDetails = this.formatErrorMessage(url, error, requestBody)
|
const errorDetails = this.formatErrorMessage(url, error, requestBody)
|
||||||
|
|
||||||
console.error('Jina Reranker API Error:', errorDetails)
|
console.error('Jina Reranker API Error:', errorDetails)
|
||||||
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
|
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -10,16 +10,7 @@ export default class SiliconFlowReranker extends BaseReranker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
||||||
let baseURL = this.base?.rerankBaseURL?.endsWith('/')
|
const url = this.getRerankUrl()
|
||||||
? this.base.rerankBaseURL.slice(0, -1)
|
|
||||||
: this.base.rerankBaseURL
|
|
||||||
|
|
||||||
// 必须携带/v1,否则会404
|
|
||||||
if (baseURL && !baseURL.endsWith('/v1')) {
|
|
||||||
baseURL = `${baseURL}/v1`
|
|
||||||
}
|
|
||||||
|
|
||||||
const url = `${baseURL}/rerank`
|
|
||||||
|
|
||||||
const requestBody = {
|
const requestBody = {
|
||||||
model: this.base.rerankModel,
|
model: this.base.rerankModel,
|
||||||
@ -34,20 +25,7 @@ export default class SiliconFlowReranker extends BaseReranker {
|
|||||||
const { data } = await axios.post(url, requestBody, { headers: this.defaultHeaders() })
|
const { data } = await axios.post(url, requestBody, { headers: this.defaultHeaders() })
|
||||||
|
|
||||||
const rerankResults = data.results
|
const rerankResults = data.results
|
||||||
const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0]))
|
return this.getRerankResult(searchResults, rerankResults)
|
||||||
|
|
||||||
return searchResults
|
|
||||||
.map((doc: ExtractChunkData, index: number) => {
|
|
||||||
const score = resultMap.get(index)
|
|
||||||
if (score === undefined) return undefined
|
|
||||||
|
|
||||||
return {
|
|
||||||
...doc,
|
|
||||||
score
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.filter((doc): doc is ExtractChunkData => doc !== undefined)
|
|
||||||
.sort((a, b) => b.score - a.score)
|
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
const errorDetails = this.formatErrorMessage(url, error, requestBody)
|
const errorDetails = this.formatErrorMessage(url, error, requestBody)
|
||||||
|
|
||||||
|
|||||||
@ -10,15 +10,7 @@ export default class VoyageReranker extends BaseReranker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
||||||
let baseURL = this.base?.rerankBaseURL?.endsWith('/')
|
const url = this.getRerankUrl()
|
||||||
? this.base.rerankBaseURL.slice(0, -1)
|
|
||||||
: this.base.rerankBaseURL
|
|
||||||
|
|
||||||
if (baseURL && !baseURL.endsWith('/v1')) {
|
|
||||||
baseURL = `${baseURL}/v1`
|
|
||||||
}
|
|
||||||
|
|
||||||
const url = `${baseURL}/rerank`
|
|
||||||
|
|
||||||
const requestBody = {
|
const requestBody = {
|
||||||
model: this.base.rerankModel,
|
model: this.base.rerankModel,
|
||||||
@ -37,21 +29,7 @@ export default class VoyageReranker extends BaseReranker {
|
|||||||
})
|
})
|
||||||
|
|
||||||
const rerankResults = data.data
|
const rerankResults = data.data
|
||||||
|
return this.getRerankResult(searchResults, rerankResults)
|
||||||
const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0]))
|
|
||||||
|
|
||||||
return searchResults
|
|
||||||
.map((doc: ExtractChunkData, index: number) => {
|
|
||||||
const score = resultMap.get(index)
|
|
||||||
if (score === undefined) return undefined
|
|
||||||
|
|
||||||
return {
|
|
||||||
...doc,
|
|
||||||
score
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.filter((doc): doc is ExtractChunkData => doc !== undefined)
|
|
||||||
.sort((a, b) => b.score - a.score)
|
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
const errorDetails = this.formatErrorMessage(url, error, requestBody)
|
const errorDetails = this.formatErrorMessage(url, error, requestBody)
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user