From 8faececa4c5c8e7fe70206bf5dd27c2e149494de Mon Sep 17 00:00:00 2001 From: MyPrototypeWhat <43230886+MyPrototypeWhat@users.noreply.github.com> Date: Fri, 14 Mar 2025 17:57:33 +0800 Subject: [PATCH] fix: messages pause bug (#3343) * refactor: Simplify message resend logic and enhance abort controller handling - Updated MessageMenubar to streamline message resend functionality. - Improved abort controller management in BaseProvider and related services. - Adjusted sendMessage to handle both single and multiple assistant messages. - Enhanced logging for better debugging and tracking of message flow. * feat: Enhance message handling and queue management - Updated Inputbar to include mentions in dispatched messages. - Introduced appendMessage action to manage message insertion at specific positions in the state. - Improved sendMessage logic to handle mentions and maintain message order. - Refactored getTopicQueue to accept options for better queue configuration. * refactor: Improve abort handling and message operations - Refactored useMessageOperations to streamline message pausing logic. - Enhanced abort controller in BaseProvider to handle abort events more effectively. - Updated OpenAIProvider to utilize new abort handling mechanism. - Adjusted fetchChatCompletion to set message status based on abort conditions. - Improved message dispatching in sendMessage for better queue management. * refactor: Enhance signal promise handling in BaseProvider and OpenAIProvider - Updated signal handling in BaseProvider to use a structured signalPromise object for better clarity and management. - Adjusted error handling in OpenAIProvider to correctly catch and throw errors from the signalPromise. - Improved overall abort handling logic to ensure robust message operations. * fix:lint --- .../src/hooks/useMessageOperations.ts | 43 +++--- .../src/pages/home/Inputbar/Inputbar.tsx | 6 +- .../pages/home/Messages/MessageMenubar.tsx | 8 +- .../src/providers/AnthropicProvider.ts | 21 ++- src/renderer/src/providers/BaseProvider.ts | 37 ++++-- src/renderer/src/providers/GeminiProvider.ts | 2 +- src/renderer/src/providers/OpenAIProvider.ts | 8 +- src/renderer/src/services/ApiService.ts | 16 +-- src/renderer/src/services/MessagesService.ts | 2 +- src/renderer/src/store/messages.ts | 122 ++++++++++++++---- src/renderer/src/utils/abortController.ts | 29 ++--- src/renderer/src/utils/queue.ts | 8 +- 12 files changed, 194 insertions(+), 108 deletions(-) diff --git a/src/renderer/src/hooks/useMessageOperations.ts b/src/renderer/src/hooks/useMessageOperations.ts index caa02bda..7c6b2142 100644 --- a/src/renderer/src/hooks/useMessageOperations.ts +++ b/src/renderer/src/hooks/useMessageOperations.ts @@ -158,34 +158,35 @@ export function useMessageOperations(topic: Topic) { /** * 暂停消息生成 */ - const pauseMessage = useCallback( - // 存的是用户消息的id,也就是助手消息的askId - async (message: Message) => { - // 1. 调用 abort - message.askId && abortCompletion(message.askId) + // const pauseMessage = useCallback( + // // 存的是用户消息的id,也就是助手消息的askId + // async (message: Message) => { + // // 1. 调用 abort - // 2. 更新消息状态 - await editMessage(message.id, { status: 'paused', content: message.content }) + // // 2. 更新消息状态, + // // await editMessage(message.id, { status: 'paused', content: message.content }) - // 3.更改loading状态 - dispatch(setTopicLoading({ topicId: message.topicId, loading: false })) + // // 3.更改loading状态 + // dispatch(setTopicLoading({ topicId: message.topicId, loading: false })) - // 4. 清理流式消息 - clearStreamMessageAction(message.id) - }, - [editMessage, dispatch, clearStreamMessageAction] - ) + // // 4. 清理流式消息 + // // clearStreamMessageAction(message.id) + // }, + // [editMessage, dispatch, clearStreamMessageAction] + // ) const pauseMessages = useCallback(async () => { + // 暂停的消息不需要在这更改status,通过catch判断abort错误之后设置message.status const streamMessages = store.getState().messages.streamMessagesByTopic[topic.id] + if (!streamMessages) return + // 不需要重复暂停 + const askIds = [...new Set(Object.values(streamMessages).map((m) => m?.askId))] - if (streamMessages) { - const streamMessagesList = Object.values(streamMessages).filter((msg) => msg?.askId && msg?.id) - for (const message of streamMessagesList) { - message && (await pauseMessage(message)) - } + for (const askId of askIds) { + askId && abortCompletion(askId) } - }, [pauseMessage, topic.id]) + dispatch(setTopicLoading({ topicId: topic.id, loading: false })) + }, [topic.id, dispatch]) /** * 恢复/重发消息 @@ -213,7 +214,7 @@ export function useMessageOperations(topic: Topic) { clearStreamMessage: clearStreamMessageAction, createNewContext, clearTopicMessages: clearTopicMessagesAction, - pauseMessage, + // pauseMessage, pauseMessages, resumeMessage } diff --git a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx index 9a7b4e91..e3c70fb9 100644 --- a/src/renderer/src/pages/home/Inputbar/Inputbar.tsx +++ b/src/renderer/src/pages/home/Inputbar/Inputbar.tsx @@ -175,7 +175,11 @@ const Inputbar: FC = ({ assistant: _assistant, setActiveTopic, topic }) = userMessage.usage = await estimateMessageUsage(userMessage) currentMessageId.current = userMessage.id - dispatch(_sendMessage(userMessage, assistant, topic)) + dispatch( + _sendMessage(userMessage, assistant, topic, { + mentions: mentionModels + }) + ) // Clear input setText('') diff --git a/src/renderer/src/pages/home/Messages/MessageMenubar.tsx b/src/renderer/src/pages/home/Messages/MessageMenubar.tsx index d1cb1b25..b35b2942 100644 --- a/src/renderer/src/pages/home/Messages/MessageMenubar.tsx +++ b/src/renderer/src/pages/home/Messages/MessageMenubar.tsx @@ -92,10 +92,10 @@ const MessageMenubar: FC = (props) => { // Resend all grouped messages if (!isEmpty(groupdMessages)) { - for (const assistantMessage of groupdMessages) { - const _model = assistantMessage.model || assistantModel - await resendMessage({ ...assistantMessage, model: _model }, assistant) - } + // for (const assistantMessage of groupdMessages) { + // const _model = assistantMessage.model || assistantModel + await resendMessage(message, assistant) + // } return } diff --git a/src/renderer/src/providers/AnthropicProvider.ts b/src/renderer/src/providers/AnthropicProvider.ts index 845a2cd9..b98d6bce 100644 --- a/src/renderer/src/providers/AnthropicProvider.ts +++ b/src/renderer/src/providers/AnthropicProvider.ts @@ -10,7 +10,6 @@ import { isReasoningModel } from '@renderer/config/models' import { getStoreSetting } from '@renderer/hooks/useSettings' import i18n from '@renderer/i18n' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' -import { EVENT_NAMES } from '@renderer/services/EventService' import { filterContextMessages, filterEmptyMessages, @@ -241,13 +240,13 @@ export default class AnthropicProvider extends BaseProvider { return new Promise((resolve, reject) => { const toolCalls: ToolUseBlock[] = [] let hasThinkingContent = false - const stream = this.sdk.messages + this.sdk.messages .stream({ ...body, stream: true }, { signal }) .on('text', (text) => { - if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) { - stream.controller.abort() - return resolve() - } + // if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) { + // stream.controller.abort() + // return resolve() + // } if (time_first_token_millsec == 0) { time_first_token_millsec = new Date().getTime() - start_time_millsec @@ -357,15 +356,13 @@ export default class AnthropicProvider extends BaseProvider { resolve() }) .on('error', (error) => reject(error)) + .on('abort', () => { + reject(new Error('Request was aborted.')) + }) }) } - await processStream(body, 0) - .catch((error) => { - // 不加这个错误抛不出来 - throw error - }) - .finally(cleanup) + await processStream(body, 0).finally(cleanup) } /** diff --git a/src/renderer/src/providers/BaseProvider.ts b/src/renderer/src/providers/BaseProvider.ts index c9ce9a21..379f775c 100644 --- a/src/renderer/src/providers/BaseProvider.ts +++ b/src/renderer/src/providers/BaseProvider.ts @@ -160,24 +160,45 @@ export default abstract class BaseProvider { ) } - protected createAbortController(messageId?: string) { + protected createAbortController(messageId?: string, isAddEventListener?: boolean) { const abortController = new AbortController() + const abortFn = () => abortController.abort() if (messageId) { - addAbortController(messageId, () => abortController.abort()) + addAbortController(messageId, abortFn) } const cleanup = () => { if (messageId) { - removeAbortController(messageId) + signalPromise.resolve?.(undefined) + removeAbortController(messageId, abortFn) } } + const signalPromise: { + resolve: (value: unknown) => void + promise: Promise + } = { + resolve: () => {}, + promise: Promise.resolve() + } - abortController.signal.addEventListener('abort', () => { - // 兼容 - cleanup() - }) - + if (isAddEventListener) { + signalPromise.promise = new Promise((resolve, reject) => { + signalPromise.resolve = resolve + if (abortController.signal.aborted) { + reject(new Error('Request was aborted.')) + } + // 捕获abort事件,有些abort事件必须 + abortController.signal.addEventListener('abort', () => { + reject(new Error('Request was aborted.')) + }) + }) + return { + abortController, + cleanup, + signalPromise + } + } return { abortController, cleanup diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index f0740ff0..7817f795 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -326,7 +326,7 @@ export default class GeminiProvider extends BaseProvider { }) const newChat = geminiModel.startChat({ history }) const newStream = await newChat.sendMessageStream(fcRespParts, { signal }) - await processStream(newStream, idx + 1).finally(cleanup) + await processStream(newStream, idx + 1) } } diff --git a/src/renderer/src/providers/OpenAIProvider.ts b/src/renderer/src/providers/OpenAIProvider.ts index b31741ae..45689465 100644 --- a/src/renderer/src/providers/OpenAIProvider.ts +++ b/src/renderer/src/providers/OpenAIProvider.ts @@ -357,7 +357,6 @@ export default class OpenAIProvider extends BaseProvider { } const userMessages: ChatCompletionMessageParam[] = [] - const _messages = filterUserRoleStartMessages( filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 1))) ) @@ -414,7 +413,7 @@ export default class OpenAIProvider extends BaseProvider { let time_first_content_millsec = 0 const start_time_millsec = new Date().getTime() const lastUserMessage = _messages.findLast((m) => m.role === 'user') - const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id) + const { abortController, cleanup, signalPromise } = this.createAbortController(lastUserMessage?.id, true) const { signal } = abortController mcpTools = filterMCPTools(mcpTools, lastUserMessage?.enabledMCPs) @@ -425,7 +424,6 @@ export default class OpenAIProvider extends BaseProvider { ) as ChatCompletionMessageParam[] const toolResponses: MCPToolResponse[] = [] - const processStream = async (stream: any, idx: number) => { if (!isSupportStreamOutput()) { const time_completion_millsec = new Date().getTime() - start_time_millsec @@ -593,6 +591,10 @@ export default class OpenAIProvider extends BaseProvider { ) await processStream(stream, 0).finally(cleanup) + // 捕获signal的错误 + await signalPromise?.promise?.catch((error) => { + throw error + }) } /** diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index c1ce13c2..db6f5ce2 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -34,12 +34,6 @@ export async function fetchChatCompletion({ const webSearchProvider = WebSearchService.getWebSearchProvider() const AI = new AiProvider(provider) - // store.dispatch(setGenerating(true)) - - // onResponse({ ...message }) - - // addAbortController(message.askId ?? message.id) - try { let _messages: Message[] = [] let isFirstChunk = true @@ -70,7 +64,6 @@ export async function fetchChatCompletion({ } const allMCPTools = await window.api.mcp.listTools() - await AI.completions({ messages: filterUsefulMessages(messages), assistant, @@ -127,9 +120,12 @@ export async function fetchChatCompletion({ } } } catch (error: any) { - if (isAbortError(error)) return - message.status = 'error' - message.error = formatMessageError(error) + if (isAbortError(error)) { + message.status = 'paused' + } else { + message.status = 'error' + message.error = formatMessageError(error) + } } // Emit chat completion event diff --git a/src/renderer/src/services/MessagesService.ts b/src/renderer/src/services/MessagesService.ts index 1f3a18d9..c3a2b86a 100644 --- a/src/renderer/src/services/MessagesService.ts +++ b/src/renderer/src/services/MessagesService.ts @@ -53,7 +53,7 @@ export function filterEmptyMessages(messages: Message[]): Message[] { } export function filterUsefulMessages(messages: Message[]): Message[] { - const _messages = messages + const _messages = [...messages] const groupedMessages = getGroupedMessages(messages) Object.entries(groupedMessages).forEach(([key, messages]) => { diff --git a/src/renderer/src/store/messages.ts b/src/renderer/src/store/messages.ts index 75096253..3c55efe6 100644 --- a/src/renderer/src/store/messages.ts +++ b/src/renderer/src/store/messages.ts @@ -6,6 +6,7 @@ import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService' import { getAssistantMessage, resetAssistantMessage } from '@renderer/services/MessagesService' import type { AppDispatch, RootState } from '@renderer/store' import type { Assistant, Message, Topic } from '@renderer/types' +import { Model } from '@renderer/types' import { clearTopicQueue, getTopicQueue, waitForTopicQueue } from '@renderer/utils/queue' import { throttle } from 'lodash' @@ -105,6 +106,29 @@ const messagesSlice = createSlice({ state.messagesByTopic[topicId].push(messages) } }, + appendMessage: ( + state, + action: PayloadAction<{ topicId: string; messages: Message | Message[]; position?: number }> + ) => { + const { topicId, messages, position } = action.payload + if (!state.messagesByTopic[topicId]) { + state.messagesByTopic[topicId] = [] + } + + // 确保消息数组存在并且拿到引用 + const messagesList = state.messagesByTopic[topicId] + + // 要插入的消息 + const messagesToInsert = Array.isArray(messages) ? messages : [messages] + + if (position !== undefined && position >= 0 && position <= messagesList.length) { + // 如果指定了位置,在特定位置插入消息 + messagesList.splice(position, 0, ...messagesToInsert) + } else { + // 否则默认添加到末尾 + messagesList.push(...messagesToInsert) + } + }, updateMessage: ( state, action: PayloadAction<{ topicId: string; messageId: string; updates: Partial }> @@ -233,8 +257,9 @@ export const sendMessage = assistant: Assistant, topic: Topic, options?: { - resendAssistantMessage?: Message + resendAssistantMessage?: Message | Message[] isMentionModel?: boolean + mentions?: Model[] } ) => async (dispatch: AppDispatch, getState: () => RootState) => { @@ -255,17 +280,27 @@ export const sendMessage = if (options?.resendAssistantMessage) { // 直接使用传入的助手消息,进行重置 const messageToReset = options.resendAssistantMessage - const { model, id } = messageToReset - const resetMessage = resetAssistantMessage(messageToReset, model) - // 更新状态 - dispatch(updateMessage({ topicId: topic.id, messageId: id, updates: resetMessage })) - // 使用重置后的消息 - assistantMessages.push(resetMessage) + if (Array.isArray(messageToReset)) { + assistantMessages = messageToReset.map((m) => { + const { model, id } = m + const resetMessage = resetAssistantMessage(m, model) + // 更新状态 + dispatch(updateMessage({ topicId: topic.id, messageId: id, updates: resetMessage })) + // 使用重置后的消息 + return resetMessage + }) + } else { + const { model, id } = messageToReset + const resetMessage = resetAssistantMessage(messageToReset, model) + // 更新状态 + dispatch(updateMessage({ topicId: topic.id, messageId: id, updates: resetMessage })) + // 使用重置后的消息 + assistantMessages.push(resetMessage) + } } else { - // 不是重发情况 - if (userMessage.mentions?.length) { - // 为每个被 mention 的模型创建一个助手消息 - assistantMessages = userMessage.mentions.map((m) => { + // 为每个被 mention 的模型创建一个助手消息 + if (options?.mentions?.length) { + assistantMessages = options?.mentions.map((m) => { const assistantMessage = getAssistantMessage({ assistant: { ...assistant, model: m }, topic }) assistantMessage.model = m assistantMessage.askId = userMessage.id @@ -280,19 +315,36 @@ export const sendMessage = assistantMessages.push(assistantMessage) } + // 获取当前消息列表 + const currentMessages = getState().messages.messagesByTopic[topic.id] + + // 最后一个具有相同askId的助手消息,在其后插入 + let position: number | undefined + if (options?.isMentionModel) { + const lastAssistantIndex = currentMessages.findLastIndex( + (m) => m.role === 'assistant' && m.askId === userMessage.id + ) + if (lastAssistantIndex !== -1) { + position = lastAssistantIndex + 1 + } + } + dispatch( - addMessage({ + appendMessage({ topicId: topic.id, - messages: !options?.isMentionModel ? [userMessage, ...assistantMessages] : assistantMessages + messages: !options?.isMentionModel ? [userMessage, ...assistantMessages] : assistantMessages, + position }) ) } - + for (const assistantMessage of assistantMessages) { + // for of会收到await 影响,在暂停的时候会因为异步的原因有概率拿不到数据 + dispatch(setStreamMessage({ topicId: topic.id, message: assistantMessage })) + } const queue = getTopicQueue(topic.id) for (const assistantMessage of assistantMessages) { // Set as stream message instead of adding to messages - dispatch(setStreamMessage({ topicId: topic.id, message: assistantMessage })) // Sync user message with database const state = getState() @@ -303,7 +355,7 @@ export const sendMessage = } // 保证请求有序,防止请求静态,限制并发数量 - await queue.add(async () => { + queue.add(async () => { try { const messages = getState().messages.messagesByTopic[topic.id] if (!messages) { @@ -324,13 +376,30 @@ export const sendMessage = // 节流 const throttledDispatch = throttle(handleResponseMessageUpdate, 100, { trailing: true }) // 100ms的节流时间应足够平衡用户体验和性能 + // 寻找当前正在处理的消息在消息列表中的位置 + // const messageIndex = messages.findIndex((m) => m.id === assistantMessage.id) + const handleMessages = (): Message[] => { + // 找到对应的用户消息位置 + const userMessageIndex = messages.findIndex((m) => m.id === assistantMessage.askId) - const messageIndex = messages.findIndex((m) => m.id === assistantMessage.id) + if (userMessageIndex !== -1) { + // 先截取到用户消息为止的所有消息,再进行过滤 + const messagesUpToUser = messages.slice(0, userMessageIndex + 1) + return messagesUpToUser.filter((m) => !m.status?.includes('ing')) + } + + // 如果找不到对应的用户消息,使用原有逻辑 + // 按理说不会找不到 先注释掉看看 + // if (messageIndex !== -1) { + // const messagesUpToAssistant = messages.slice(0, messageIndex) + // return messagesUpToAssistant.filter((m) => !m.status?.includes('ing')) + // } + // 没有找到消息索引的情况,过滤所有消息 + return messages.filter((m) => !m.status?.includes('ing')) + } await fetchChatCompletion({ message: { ...assistantMessage }, - messages: messages - .filter((m) => !m.status?.includes('ing')) - .slice(0, messageIndex !== -1 ? messageIndex : undefined), + messages: handleMessages(), assistant: assistantWithModel, onResponse: async (msg) => { // 允许在回调外维护一个最新的消息状态,每次都更新这个对象,但只通过节流函数分发到Redux @@ -362,13 +431,14 @@ export const sendMessage = } }) } - // 等待所有请求完成,设置loading - await queue.onIdle() - dispatch(setTopicLoading({ topicId: topic.id, loading: false })) } catch (error: any) { console.error('Error in sendMessage:', error) dispatch(setError(error.message)) dispatch(setTopicLoading({ topicId: topic.id, loading: false })) + } finally { + // 等待所有请求完成,设置loading + await waitForTopicQueue(topic.id) + dispatch(setTopicLoading({ topicId: topic.id, loading: false })) } } @@ -385,7 +455,7 @@ export const resendMessage = // 如果是用户消息,直接重发 if (message.role === 'user') { // 查找此用户消息对应的助手消息 - const assistantMessage = topicMessages.find((m) => m.role === 'assistant' && m.askId === message.id) + const assistantMessage = topicMessages.filter((m) => m.role === 'assistant' && m.askId === message.id) return dispatch( sendMessage(message, assistant, topic, { resendAssistantMessage: assistantMessage, @@ -408,6 +478,7 @@ export const resendMessage = return dispatch(sendMessage(userMessage, assistant, topic, { isMentionModel })) } + console.log('assistantMessage', message) dispatch( sendMessage(userMessage, assistant, topic, { resendAssistantMessage: message @@ -521,7 +592,8 @@ export const { loadTopicMessages, setStreamMessage, commitStreamMessage, - clearStreamMessage + clearStreamMessage, + appendMessage } = messagesSlice.actions export default messagesSlice.reducer diff --git a/src/renderer/src/utils/abortController.ts b/src/renderer/src/utils/abortController.ts index 98195d80..8d3f59bd 100644 --- a/src/renderer/src/utils/abortController.ts +++ b/src/renderer/src/utils/abortController.ts @@ -1,25 +1,22 @@ -export const abortMap = new Map void>() +export const abortMap = new Map void)[]>() export const addAbortController = (id: string, abortFn: () => void) => { - let callback = abortFn - const existingCallback = abortMap.get(id) - if (existingCallback) { - callback = () => { - existingCallback?.() - abortFn() - } - } - abortMap.set(id, callback) + abortMap.set(id, [...(abortMap.get(id) || []), abortFn]) } -export const removeAbortController = (id: string) => { - abortMap.delete(id) +export const removeAbortController = (id: string, abortFn: () => void) => { + const callbackArr = abortMap.get(id) + if (abortFn) { + callbackArr?.splice(callbackArr?.indexOf(abortFn), 1) + } else abortMap.delete(id) } export const abortCompletion = (id: string) => { - const abortFn = abortMap.get(id) - if (abortFn) { - abortFn() - removeAbortController(id) + const abortFns = abortMap.get(id) + if (abortFns?.length) { + for (const fn of [...abortFns]) { + fn() + removeAbortController(id, fn) + } } } diff --git a/src/renderer/src/utils/queue.ts b/src/renderer/src/utils/queue.ts index 49144294..62e0af8d 100644 --- a/src/renderer/src/utils/queue.ts +++ b/src/renderer/src/utils/queue.ts @@ -8,13 +8,9 @@ const requestQueues: { [topicId: string]: PQueue } = {} * @param topicId The ID of the topic * @returns A PQueue instance for the topic */ -export const getTopicQueue = (topicId: string): PQueue => { +export const getTopicQueue = (topicId: string, options = {}): PQueue => { if (!requestQueues[topicId]) { - requestQueues[topicId] = new PQueue({ - concurrency: 4, - timeout: 1000 * 60 * 5, // 5 minutes - throwOnTimeout: false - }) + requestQueues[topicId] = new PQueue(options) } return requestQueues[topicId] }