feat: support jina reranker (#3658)
This commit is contained in:
parent
6fd5ff991d
commit
ea990e78a5
48
src/main/reranker/JinaReranker.ts
Normal file
48
src/main/reranker/JinaReranker.ts
Normal file
@ -0,0 +1,48 @@
|
||||
import { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
import axios from 'axios'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
|
||||
export default class JinaReranker extends BaseReranker {
|
||||
constructor(base: KnowledgeBaseParams) {
|
||||
super(base)
|
||||
}
|
||||
|
||||
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
||||
const baseURL = this.base?.rerankBaseURL?.endsWith('/')
|
||||
? this.base.rerankBaseURL.slice(0, -1)
|
||||
: this.base.rerankBaseURL
|
||||
const url = `${baseURL}/rerank`
|
||||
|
||||
const requestBody = {
|
||||
model: this.base.rerankModel,
|
||||
query,
|
||||
documents: searchResults.map((doc) => doc.pageContent),
|
||||
top_n: this.base.topN
|
||||
}
|
||||
|
||||
try {
|
||||
const { data } = await axios.post(url, requestBody, { headers: this.defaultHeaders() })
|
||||
|
||||
const rerankResults = data.results
|
||||
console.log(rerankResults)
|
||||
const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0]))
|
||||
return searchResults
|
||||
.map((doc: ExtractChunkData, index: number) => {
|
||||
const score = resultMap.get(index)
|
||||
if (score === undefined) return undefined
|
||||
|
||||
return {
|
||||
...doc,
|
||||
score
|
||||
}
|
||||
})
|
||||
.filter((doc): doc is ExtractChunkData => doc !== undefined)
|
||||
.sort((a, b) => b.score - a.score)
|
||||
} catch (error) {
|
||||
console.error('Jina Reranker API 错误:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2,12 +2,15 @@ import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import BaseReranker from './BaseReranker'
|
||||
import DefaultReranker from './DefaultReranker'
|
||||
import JinaReranker from './JinaReranker'
|
||||
import SiliconFlowReranker from './SiliconFlowReranker'
|
||||
|
||||
export default class RerankerFactory {
|
||||
static create(base: KnowledgeBaseParams): BaseReranker {
|
||||
if (base.rerankModelProvider === 'silicon') {
|
||||
return new SiliconFlowReranker(base)
|
||||
} else if (base.rerankModelProvider === 'jina') {
|
||||
return new JinaReranker(base)
|
||||
}
|
||||
return new DefaultReranker(base)
|
||||
}
|
||||
|
||||
@ -15,20 +15,17 @@ export default class SiliconFlowReranker extends BaseReranker {
|
||||
: this.base.rerankBaseURL
|
||||
const url = `${baseURL}/rerank`
|
||||
|
||||
const { data } = await axios.post(
|
||||
url,
|
||||
{
|
||||
const requestBody = {
|
||||
model: this.base.rerankModel,
|
||||
query,
|
||||
documents: searchResults.map((doc) => doc.pageContent),
|
||||
top_n: this.base.topN,
|
||||
max_chunks_per_doc: this.base.chunkSize,
|
||||
overlap_tokens: this.base.chunkOverlap
|
||||
},
|
||||
{
|
||||
headers: this.defaultHeaders()
|
||||
}
|
||||
)
|
||||
|
||||
try {
|
||||
const { data } = await axios.post(url, requestBody, { headers: this.defaultHeaders() })
|
||||
|
||||
const rerankResults = data.results
|
||||
const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0]))
|
||||
@ -45,5 +42,9 @@ export default class SiliconFlowReranker extends BaseReranker {
|
||||
})
|
||||
.filter((doc): doc is ExtractChunkData => doc !== undefined)
|
||||
.sort((a, b) => b.score - a.score)
|
||||
} catch (error) {
|
||||
console.error('SiliconFlow Reranker API 错误:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -142,10 +142,10 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
||||
name="rerankModel"
|
||||
label={t('models.rerank_model')}
|
||||
tooltip={{ title: t('models.rerank_model_tooltip'), placement: 'right' }}
|
||||
initialValue={getModelUniqId(base.rerankModel) || undefined}
|
||||
rules={[{ required: false, message: t('message.error.enter.model') }]}>
|
||||
<Select
|
||||
style={{ width: '100%' }}
|
||||
defaultValue={getModelUniqId(base.rerankModel) || undefined}
|
||||
options={rerankSelectOptions}
|
||||
placeholder={t('settings.models.empty')}
|
||||
allowClear
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user