diff --git a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx index c284fbaa..9edbf52b 100644 --- a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx +++ b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx @@ -24,6 +24,7 @@ import FileManager from '@renderer/services/FileManager' import { estimateTextTokens as estimateTxtTokens } from '@renderer/services/TokenService' import { translateText } from '@renderer/services/TranslateService' import store, { useAppDispatch, useAppSelector } from '@renderer/store' +import { abortCompletion } from '@renderer/store/abortController' import { setGenerating, setSearching } from '@renderer/store/runtime' import { Assistant, FileType, KnowledgeBase, Message, Model, Topic } from '@renderer/types' import { classNames, delay, getFileExtension, uuid } from '@renderer/utils' @@ -85,7 +86,7 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic }) => { const [selectedKnowledgeBases, setSelectedKnowledgeBases] = useState([]) const [mentionModels, setMentionModels] = useState([]) const [isMentionPopupOpen, setIsMentionPopupOpen] = useState(false) - + const currentMessageId = useRef() const isVision = useMemo(() => isVisionModel(model), [model]) const supportExts = useMemo(() => [...textExts, ...documentExts, ...(isVision ? imageExts : [])], [isVision]) @@ -133,7 +134,7 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic }) => { if (mentionModels.length > 0) { message.mentions = mentionModels } - + currentMessageId.current = message.id EventEmitter.emit(EVENT_NAMES.SEND_MESSAGE, message) setText('') @@ -274,6 +275,9 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic }) => { } const onPause = () => { + if (currentMessageId.current) { + abortCompletion(currentMessageId.current) + } window.keyv.set(EVENT_NAMES.CHAT_COMPLETION_PAUSED, true) store.dispatch(setGenerating(false)) } diff --git a/src/renderer/src/providers/AnthropicProvider.ts b/src/renderer/src/providers/AnthropicProvider.ts index c13b0bd7..d844a19a 100644 --- a/src/renderer/src/providers/AnthropicProvider.ts +++ b/src/renderer/src/providers/AnthropicProvider.ts @@ -6,6 +6,7 @@ import i18n from '@renderer/i18n' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { EVENT_NAMES } from '@renderer/services/EventService' import { filterContextMessages } from '@renderer/services/MessagesService' +import { addAbortController, removeAbortController } from '@renderer/store/abortController' import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types' import { removeSpecialCharacters } from '@renderer/utils' import { first, flatten, sum, takeRight } from 'lodash' @@ -13,7 +14,6 @@ import OpenAI from 'openai' import { CompletionsParams } from '.' import BaseProvider from './BaseProvider' - export default class AnthropicProvider extends BaseProvider { private sdk: Anthropic @@ -107,10 +107,16 @@ export default class AnthropicProvider extends BaseProvider { } }) } - + const abortController = new AbortController() + const { signal } = abortController + // 获取最后一条用户消息的 ID 作为 askId + const lastUserMessage = _messages.findLast((m) => m.role === 'user') + if (lastUserMessage?.id) { + addAbortController(lastUserMessage.id, () => abortController.abort()) + } return new Promise((resolve, reject) => { const stream = this.sdk.messages - .stream({ ...body, stream: true }) + .stream({ ...body, stream: true }, { signal }) .on('text', (text) => { if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) { stream.controller.abort() @@ -146,6 +152,10 @@ export default class AnthropicProvider extends BaseProvider { resolve() }) .on('error', (error) => reject(error)) + }).finally(() => { + if (lastUserMessage?.id) { + removeAbortController(lastUserMessage.id) + } }) } diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index 179e2da3..b2e38b90 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -16,6 +16,7 @@ import i18n from '@renderer/i18n' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { EVENT_NAMES } from '@renderer/services/EventService' import { filterContextMessages } from '@renderer/services/MessagesService' +import { addAbortController, removeAbortController } from '@renderer/store/abortController' import { Assistant, FileType, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types' import { removeSpecialCharacters } from '@renderer/utils' import axios from 'axios' @@ -24,7 +25,6 @@ import OpenAI from 'openai' import { CompletionsParams } from '.' import BaseProvider from './BaseProvider' - export default class GeminiProvider extends BaseProvider { private sdk: GoogleGenerativeAI private requestOptions: RequestOptions @@ -204,7 +204,19 @@ export default class GeminiProvider extends BaseProvider { return } - const userMessagesStream = await chat.sendMessageStream(messageContents.parts) + const abortController = new AbortController() + const { signal } = abortController + // 获取最后一条用户消息的 ID 作为 askId + const lastUserMessage = userMessages.findLast((m) => m.role === 'user') + if (lastUserMessage?.id) { + addAbortController(lastUserMessage.id, () => abortController.abort()) + } + + const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal }).finally(() => { + if (lastUserMessage?.id) { + removeAbortController(lastUserMessage.id) + } + }) let time_first_token_millsec = 0 for await (const chunk of userMessagesStream.stream) { diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts index 8ef689ba..76cc4909 100644 --- a/src/renderer/src/providers/OpenAIProvider.ts +++ b/src/renderer/src/providers/OpenAIProvider.ts @@ -4,6 +4,7 @@ import i18n from '@renderer/i18n' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' import { EVENT_NAMES } from '@renderer/services/EventService' import { filterContextMessages } from '@renderer/services/MessagesService' +import { addAbortController, removeAbortController } from '@renderer/store/abortController' import { Assistant, FileTypes, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types' import { removeSpecialCharacters } from '@renderer/utils' import { takeRight } from 'lodash' @@ -213,21 +214,40 @@ export default class OpenAIProvider extends BaseProvider { let time_first_token_millsec = 0 let time_first_content_millsec = 0 const start_time_millsec = new Date().getTime() + const abortController = new AbortController() + const { signal } = abortController - // @ts-ignore key is not typed - const stream = await this.sdk.chat.completions.create({ - model: model.id, - messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[], - temperature: this.getTemperature(assistant, model), - top_p: this.getTopP(assistant, model), - max_tokens: maxTokens, - keep_alive: this.keepAliveTime, - stream: isSupportStreamOutput(), - ...this.getReasoningEffort(assistant, model), - ...getOpenAIWebSearchParams(assistant, model), - ...this.getProviderSpecificParameters(assistant, model), - ...this.getCustomParameters(assistant) - }) + // 获取最后一条用户消息的 ID 作为 askId + const lastUserMessage = _messages.findLast((m) => m.role === 'user') + if (lastUserMessage?.id) { + addAbortController(lastUserMessage.id, () => abortController.abort()) + } + + const stream = await this.sdk.chat.completions + // @ts-ignore key is not typed + .create( + { + model: model.id, + messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[], + temperature: this.getTemperature(assistant, model), + top_p: this.getTopP(assistant, model), + max_tokens: maxTokens, + keep_alive: this.keepAliveTime, + stream: isSupportStreamOutput(), + ...this.getReasoningEffort(assistant, model), + ...getOpenAIWebSearchParams(assistant, model), + ...this.getProviderSpecificParameters(assistant, model), + ...this.getCustomParameters(assistant) + }, + { + signal + } + ) + .finally(() => { + if (lastUserMessage?.id) { + removeAbortController(lastUserMessage.id) + } + }) if (!isSupportStreamOutput()) { const time_completion_millsec = new Date().getTime() - start_time_millsec diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index daf7b4fb..88f53051 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -1,5 +1,6 @@ import i18n from '@renderer/i18n' import store from '@renderer/store' +import { addAbortController } from '@renderer/store/abortController' import { setGenerating } from '@renderer/store/runtime' import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types' import { formatMessageError } from '@renderer/utils/error' @@ -16,7 +17,6 @@ import { import { EVENT_NAMES, EventEmitter } from './EventService' import { filterMessages, filterUsefulMessages } from './MessagesService' import { estimateMessagesUsage } from './TokenService' - export async function fetchChatCompletion({ message, messages, @@ -37,18 +37,14 @@ export async function fetchChatCompletion({ onResponse({ ...message }) - // Handle paused state - let paused = false - const timer = setInterval(() => { - if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) { - paused = true - message.status = 'paused' - EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message) - store.dispatch(setGenerating(false)) - onResponse({ ...message, status: 'paused' }) - clearInterval(timer) - } - }, 1000) + const pauseFn = (message: Message) => { + message.status = 'paused' + EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message) + store.dispatch(setGenerating(false)) + onResponse({ ...message, status: 'paused' }) + } + + addAbortController(message.askId ?? message.id, pauseFn.bind(null, message)) try { let _messages: Message[] = [] @@ -97,12 +93,6 @@ export async function fetchChatCompletion({ message.error = formatMessageError(error) } - timer && clearInterval(timer) - - if (paused) { - return message - } - // Update message status message.status = window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED) ? 'paused' : message.status diff --git a/src/renderer/src/store/abortController.ts b/src/renderer/src/store/abortController.ts new file mode 100644 index 00000000..2978e8f2 --- /dev/null +++ b/src/renderer/src/store/abortController.ts @@ -0,0 +1,25 @@ +export const abortMap = new Map void>() + +export const addAbortController = (messageId: string, abortFn: () => void) => { + let callback = abortFn + const existingCallback = abortMap.get(messageId) + if (existingCallback) { + callback = () => { + existingCallback?.() + abortFn() + } + } + abortMap.set(messageId, callback) +} + +export const removeAbortController = (messageId: string) => { + abortMap.delete(messageId) +} + +export const abortCompletion = (messageId: string) => { + const abortFn = abortMap.get(messageId) + if (abortFn) { + abortFn() + removeAbortController(messageId) + } +}