feat: support jina reranker (#3658)
This commit is contained in:
parent
6fd5ff991d
commit
ea990e78a5
48
src/main/reranker/JinaReranker.ts
Normal file
48
src/main/reranker/JinaReranker.ts
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
import { ExtractChunkData } from '@llm-tools/embedjs-interfaces'
|
||||||
|
import { KnowledgeBaseParams } from '@types'
|
||||||
|
import axios from 'axios'
|
||||||
|
|
||||||
|
import BaseReranker from './BaseReranker'
|
||||||
|
|
||||||
|
export default class JinaReranker extends BaseReranker {
|
||||||
|
constructor(base: KnowledgeBaseParams) {
|
||||||
|
super(base)
|
||||||
|
}
|
||||||
|
|
||||||
|
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
|
||||||
|
const baseURL = this.base?.rerankBaseURL?.endsWith('/')
|
||||||
|
? this.base.rerankBaseURL.slice(0, -1)
|
||||||
|
: this.base.rerankBaseURL
|
||||||
|
const url = `${baseURL}/rerank`
|
||||||
|
|
||||||
|
const requestBody = {
|
||||||
|
model: this.base.rerankModel,
|
||||||
|
query,
|
||||||
|
documents: searchResults.map((doc) => doc.pageContent),
|
||||||
|
top_n: this.base.topN
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const { data } = await axios.post(url, requestBody, { headers: this.defaultHeaders() })
|
||||||
|
|
||||||
|
const rerankResults = data.results
|
||||||
|
console.log(rerankResults)
|
||||||
|
const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0]))
|
||||||
|
return searchResults
|
||||||
|
.map((doc: ExtractChunkData, index: number) => {
|
||||||
|
const score = resultMap.get(index)
|
||||||
|
if (score === undefined) return undefined
|
||||||
|
|
||||||
|
return {
|
||||||
|
...doc,
|
||||||
|
score
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.filter((doc): doc is ExtractChunkData => doc !== undefined)
|
||||||
|
.sort((a, b) => b.score - a.score)
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Jina Reranker API 错误:', error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -2,12 +2,15 @@ import { KnowledgeBaseParams } from '@types'
|
|||||||
|
|
||||||
import BaseReranker from './BaseReranker'
|
import BaseReranker from './BaseReranker'
|
||||||
import DefaultReranker from './DefaultReranker'
|
import DefaultReranker from './DefaultReranker'
|
||||||
|
import JinaReranker from './JinaReranker'
|
||||||
import SiliconFlowReranker from './SiliconFlowReranker'
|
import SiliconFlowReranker from './SiliconFlowReranker'
|
||||||
|
|
||||||
export default class RerankerFactory {
|
export default class RerankerFactory {
|
||||||
static create(base: KnowledgeBaseParams): BaseReranker {
|
static create(base: KnowledgeBaseParams): BaseReranker {
|
||||||
if (base.rerankModelProvider === 'silicon') {
|
if (base.rerankModelProvider === 'silicon') {
|
||||||
return new SiliconFlowReranker(base)
|
return new SiliconFlowReranker(base)
|
||||||
|
} else if (base.rerankModelProvider === 'jina') {
|
||||||
|
return new JinaReranker(base)
|
||||||
}
|
}
|
||||||
return new DefaultReranker(base)
|
return new DefaultReranker(base)
|
||||||
}
|
}
|
||||||
|
|||||||
@ -15,20 +15,17 @@ export default class SiliconFlowReranker extends BaseReranker {
|
|||||||
: this.base.rerankBaseURL
|
: this.base.rerankBaseURL
|
||||||
const url = `${baseURL}/rerank`
|
const url = `${baseURL}/rerank`
|
||||||
|
|
||||||
const { data } = await axios.post(
|
const requestBody = {
|
||||||
url,
|
|
||||||
{
|
|
||||||
model: this.base.rerankModel,
|
model: this.base.rerankModel,
|
||||||
query,
|
query,
|
||||||
documents: searchResults.map((doc) => doc.pageContent),
|
documents: searchResults.map((doc) => doc.pageContent),
|
||||||
top_n: this.base.topN,
|
top_n: this.base.topN,
|
||||||
max_chunks_per_doc: this.base.chunkSize,
|
max_chunks_per_doc: this.base.chunkSize,
|
||||||
overlap_tokens: this.base.chunkOverlap
|
overlap_tokens: this.base.chunkOverlap
|
||||||
},
|
|
||||||
{
|
|
||||||
headers: this.defaultHeaders()
|
|
||||||
}
|
}
|
||||||
)
|
|
||||||
|
try {
|
||||||
|
const { data } = await axios.post(url, requestBody, { headers: this.defaultHeaders() })
|
||||||
|
|
||||||
const rerankResults = data.results
|
const rerankResults = data.results
|
||||||
const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0]))
|
const resultMap = new Map(rerankResults.map((result: any) => [result.index, result.relevance_score || 0]))
|
||||||
@ -45,5 +42,9 @@ export default class SiliconFlowReranker extends BaseReranker {
|
|||||||
})
|
})
|
||||||
.filter((doc): doc is ExtractChunkData => doc !== undefined)
|
.filter((doc): doc is ExtractChunkData => doc !== undefined)
|
||||||
.sort((a, b) => b.score - a.score)
|
.sort((a, b) => b.score - a.score)
|
||||||
|
} catch (error) {
|
||||||
|
console.error('SiliconFlow Reranker API 错误:', error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -142,10 +142,10 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
|
|||||||
name="rerankModel"
|
name="rerankModel"
|
||||||
label={t('models.rerank_model')}
|
label={t('models.rerank_model')}
|
||||||
tooltip={{ title: t('models.rerank_model_tooltip'), placement: 'right' }}
|
tooltip={{ title: t('models.rerank_model_tooltip'), placement: 'right' }}
|
||||||
|
initialValue={getModelUniqId(base.rerankModel) || undefined}
|
||||||
rules={[{ required: false, message: t('message.error.enter.model') }]}>
|
rules={[{ required: false, message: t('message.error.enter.model') }]}>
|
||||||
<Select
|
<Select
|
||||||
style={{ width: '100%' }}
|
style={{ width: '100%' }}
|
||||||
defaultValue={getModelUniqId(base.rerankModel) || undefined}
|
|
||||||
options={rerankSelectOptions}
|
options={rerankSelectOptions}
|
||||||
placeholder={t('settings.models.empty')}
|
placeholder={t('settings.models.empty')}
|
||||||
allowClear
|
allowClear
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user