fix: azure openai embedding

This commit is contained in:
kangfenmao 2024-12-27 14:02:53 +08:00
parent 4ac608052c
commit c409256ae9
5 changed files with 30 additions and 44 deletions

View File

@ -9,7 +9,8 @@ import { DocxLoader, ExcelLoader, PptLoader } from '@llm-tools/embedjs-loader-ms
import { PdfLoader } from '@llm-tools/embedjs-loader-pdf'
import { SitemapLoader } from '@llm-tools/embedjs-loader-sitemap'
import { WebLoader } from '@llm-tools/embedjs-loader-web'
import { OpenAiEmbeddings } from '@llm-tools/embedjs-openai'
import { AzureOpenAiEmbeddings, OpenAiEmbeddings } from '@llm-tools/embedjs-openai'
import { getInstanceName } from '@main/utils'
import { FileType, KnowledgeBaseParams, KnowledgeItem } from '@types'
import { app } from 'electron'
@ -30,18 +31,28 @@ class KnowledgeService {
id,
model,
apiKey,
apiVersion,
baseURL,
dimensions
}: KnowledgeBaseParams): Promise<RAGApplication> => {
return new RAGApplicationBuilder()
.setModel('NO_MODEL')
.setEmbeddingModel(
new OpenAiEmbeddings({
apiVersion
? new AzureOpenAiEmbeddings({
azureOpenAIApiKey: apiKey,
azureOpenAIApiVersion: apiVersion,
azureOpenAIApiDeploymentName: model,
azureOpenAIApiInstanceName: getInstanceName(baseURL),
dimensions,
batchSize: 15
})
: new OpenAiEmbeddings({
model,
apiKey,
configuration: { baseURL },
dimensions,
batchSize: 20
batchSize: 15
})
)
.setVectorDatabase(new LibSqlDb({ path: path.join(this.storageDir, id) }))

View File

@ -14,3 +14,11 @@ export function getDataPath() {
}
return dataPath
}
export function getInstanceName(baseURL: string) {
try {
return new URL(baseURL).host.split('.')[0]
} catch (error) {
return ''
}
}

View File

@ -7,44 +7,10 @@ import { KnowledgeItem } from '@renderer/types'
class KnowledgeQueue {
private processing: Map<string, boolean> = new Map()
private pollingInterval: NodeJS.Timeout | null = null
// private readonly POLLING_INTERVAL = 5000
private readonly MAX_RETRIES = 2
private readonly MAX_RETRIES = 1
constructor() {
this.checkAllBases().catch(console.error)
this.startPolling()
}
private startPolling(): void {
if (this.pollingInterval) return
const state = store.getState()
state.knowledge.bases.forEach((base) => {
base.items.forEach((item) => {
if (item.processingStatus === 'processing') {
store.dispatch(
updateItemProcessingStatus({
baseId: base.id,
itemId: item.id,
status: 'pending',
progress: 0
})
)
}
})
})
// this.pollingInterval = setInterval(() => {
// this.checkAllBases()
// }, this.POLLING_INTERVAL)
}
private stopPolling(): void {
if (this.pollingInterval) {
clearInterval(this.pollingInterval)
this.pollingInterval = null
}
}
public async checkAllBases(): Promise<void> {
@ -111,7 +77,6 @@ class KnowledgeQueue {
}
stopAllProcessing(): void {
this.stopPolling()
for (const baseId of this.processing.keys()) {
this.processing.set(baseId, false)
}

View File

@ -25,6 +25,7 @@ export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams
model: base.model.id,
dimensions: base.dimensions,
apiKey: aiProvider.getApiKey(),
apiVersion: provider.apiVersion,
baseURL: host
}
}

View File

@ -216,5 +216,6 @@ export type KnowledgeBaseParams = {
model: string
dimensions: number
apiKey: string
apiVersion?: string
baseURL: string
}