diff --git a/src/main/services/KnowledgeService.ts b/src/main/services/KnowledgeService.ts index f910f3ac..2ca38a3a 100644 --- a/src/main/services/KnowledgeService.ts +++ b/src/main/services/KnowledgeService.ts @@ -9,7 +9,8 @@ import { DocxLoader, ExcelLoader, PptLoader } from '@llm-tools/embedjs-loader-ms import { PdfLoader } from '@llm-tools/embedjs-loader-pdf' import { SitemapLoader } from '@llm-tools/embedjs-loader-sitemap' import { WebLoader } from '@llm-tools/embedjs-loader-web' -import { OpenAiEmbeddings } from '@llm-tools/embedjs-openai' +import { AzureOpenAiEmbeddings, OpenAiEmbeddings } from '@llm-tools/embedjs-openai' +import { getInstanceName } from '@main/utils' import { FileType, KnowledgeBaseParams, KnowledgeItem } from '@types' import { app } from 'electron' @@ -30,19 +31,29 @@ class KnowledgeService { id, model, apiKey, + apiVersion, baseURL, dimensions }: KnowledgeBaseParams): Promise => { return new RAGApplicationBuilder() .setModel('NO_MODEL') .setEmbeddingModel( - new OpenAiEmbeddings({ - model, - apiKey, - configuration: { baseURL }, - dimensions, - batchSize: 20 - }) + apiVersion + ? new AzureOpenAiEmbeddings({ + azureOpenAIApiKey: apiKey, + azureOpenAIApiVersion: apiVersion, + azureOpenAIApiDeploymentName: model, + azureOpenAIApiInstanceName: getInstanceName(baseURL), + dimensions, + batchSize: 15 + }) + : new OpenAiEmbeddings({ + model, + apiKey, + configuration: { baseURL }, + dimensions, + batchSize: 15 + }) ) .setVectorDatabase(new LibSqlDb({ path: path.join(this.storageDir, id) })) .build() diff --git a/src/main/utils/index.ts b/src/main/utils/index.ts index 07ee7d11..5abef025 100644 --- a/src/main/utils/index.ts +++ b/src/main/utils/index.ts @@ -14,3 +14,11 @@ export function getDataPath() { } return dataPath } + +export function getInstanceName(baseURL: string) { + try { + return new URL(baseURL).host.split('.')[0] + } catch (error) { + return '' + } +} diff --git a/src/renderer/src/queue/KnowledgeQueue.ts b/src/renderer/src/queue/KnowledgeQueue.ts index 7c8d597b..128f9e6a 100644 --- a/src/renderer/src/queue/KnowledgeQueue.ts +++ b/src/renderer/src/queue/KnowledgeQueue.ts @@ -7,44 +7,10 @@ import { KnowledgeItem } from '@renderer/types' class KnowledgeQueue { private processing: Map = new Map() - private pollingInterval: NodeJS.Timeout | null = null - // private readonly POLLING_INTERVAL = 5000 - private readonly MAX_RETRIES = 2 + private readonly MAX_RETRIES = 1 constructor() { this.checkAllBases().catch(console.error) - this.startPolling() - } - - private startPolling(): void { - if (this.pollingInterval) return - - const state = store.getState() - state.knowledge.bases.forEach((base) => { - base.items.forEach((item) => { - if (item.processingStatus === 'processing') { - store.dispatch( - updateItemProcessingStatus({ - baseId: base.id, - itemId: item.id, - status: 'pending', - progress: 0 - }) - ) - } - }) - }) - - // this.pollingInterval = setInterval(() => { - // this.checkAllBases() - // }, this.POLLING_INTERVAL) - } - - private stopPolling(): void { - if (this.pollingInterval) { - clearInterval(this.pollingInterval) - this.pollingInterval = null - } } public async checkAllBases(): Promise { @@ -111,7 +77,6 @@ class KnowledgeQueue { } stopAllProcessing(): void { - this.stopPolling() for (const baseId of this.processing.keys()) { this.processing.set(baseId, false) } diff --git a/src/renderer/src/services/KnowledgeService.ts b/src/renderer/src/services/KnowledgeService.ts index ad565228..31aa8ec9 100644 --- a/src/renderer/src/services/KnowledgeService.ts +++ b/src/renderer/src/services/KnowledgeService.ts @@ -25,6 +25,7 @@ export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams model: base.model.id, dimensions: base.dimensions, apiKey: aiProvider.getApiKey(), + apiVersion: provider.apiVersion, baseURL: host } } diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index 51496c74..15dc21fe 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -216,5 +216,6 @@ export type KnowledgeBaseParams = { model: string dimensions: number apiKey: string + apiVersion?: string baseURL: string }