From f5d3c07161fa9c4197ab898461bb17ea03ebdbe8 Mon Sep 17 00:00:00 2001 From: MyPrototypeWhat Date: Tue, 11 Mar 2025 23:39:03 +0800 Subject: [PATCH] fix(MessageOperations): Improve message pause functionality and error handling - Update pauseMessage method to handle both askId and messageId - Add loading state reset when pausing messages - Enhance error handling in providers with abort error detection - Modify ApiService to handle aborted requests gracefully - Add comprehensive isAbortError utility function --- .../src/hooks/useMessageOperations.ts | 25 +++++++++-------- .../src/providers/AnthropicProvider.ts | 9 +++++-- src/renderer/src/providers/BaseProvider.ts | 17 ++++++++---- src/renderer/src/providers/GeminiProvider.ts | 11 +++----- src/renderer/src/services/ApiService.ts | 12 ++++++--- src/renderer/src/store/messages.ts | 1 - src/renderer/src/utils/error.ts | 27 +++++++++++++++++++ 7 files changed, 72 insertions(+), 30 deletions(-) diff --git a/src/renderer/src/hooks/useMessageOperations.ts b/src/renderer/src/hooks/useMessageOperations.ts index 4ffe598e..ef7fc2c7 100644 --- a/src/renderer/src/hooks/useMessageOperations.ts +++ b/src/renderer/src/hooks/useMessageOperations.ts @@ -10,6 +10,7 @@ import { selectTopicLoading, selectTopicMessages, setStreamMessage, + setTopicLoading, updateMessage, updateMessages } from '@renderer/store/messages' @@ -155,14 +156,18 @@ export function useMessageOperations(topic: Topic) { * 暂停消息生成 */ const pauseMessage = useCallback( - async (messageId: string) => { + // 存的是用户消息的id,也就是助手消息的askId + async (askId: string, messageId: string) => { // 1. 调用 abort - abortCompletion(messageId) - + abortCompletion(askId) + console.log('messageId', messageId) // 2. 更新消息状态 await editMessage(messageId, { status: 'paused' }) - // 3. 清理流式消息 + // 3.更改loading状态 + dispatch(setTopicLoading({ topicId: topic.id, loading: false })) + + // 4. 清理流式消息 clearStreamMessageAction(messageId) }, [editMessage, clearStreamMessageAction] @@ -173,15 +178,13 @@ export function useMessageOperations(topic: Topic) { const streamMessages = store.getState().messages.streamMessagesByTopic[topic.id] if (streamMessages) { // 获取所有流式消息的 askId - const askIds = new Set( - Object.values(streamMessages) - .map((msg) => msg.askId) - .filter(Boolean) - ) + const askIds = Object.values(streamMessages) + .map((msg) => [msg.askId, msg.id]) + .filter(([askId, id]) => askId && id) // 对每个 askId 执行暂停 - for (const askId of askIds) { - await pauseMessage(askId) + for (const [askId, id] of askIds) { + await pauseMessage(askId, id) } } }, [topic.id, pauseMessage]) diff --git a/src/renderer/src/providers/AnthropicProvider.ts b/src/renderer/src/providers/AnthropicProvider.ts index dc8ffaad..b2e84761 100644 --- a/src/renderer/src/providers/AnthropicProvider.ts +++ b/src/renderer/src/providers/AnthropicProvider.ts @@ -208,7 +208,7 @@ export default class AnthropicProvider extends BaseProvider { const { signal } = abortController const toolResponses: MCPToolResponse[] = [] - const processStream = async (body: MessageCreateParamsNonStreaming) => { + const processStream = (body: MessageCreateParamsNonStreaming) => { return new Promise((resolve, reject) => { const toolCalls: ToolUseBlock[] = [] let hasThinkingContent = false @@ -326,7 +326,12 @@ export default class AnthropicProvider extends BaseProvider { }) } - await processStream(body).finally(cleanup) + await processStream(body) + .catch((error) => { + // 不加这个错误抛不出来 + throw error + }) + .finally(cleanup) } public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) { diff --git a/src/renderer/src/providers/BaseProvider.ts b/src/renderer/src/providers/BaseProvider.ts index 71309c60..36137800 100644 --- a/src/renderer/src/providers/BaseProvider.ts +++ b/src/renderer/src/providers/BaseProvider.ts @@ -160,13 +160,20 @@ export default abstract class BaseProvider { addAbortController(messageId, () => abortController.abort()) } + const cleanup = () => { + if (messageId) { + removeAbortController(messageId) + } + } + + abortController.signal.addEventListener('abort', () => { + // 兼容 + cleanup() + }) + return { abortController, - cleanup: () => { - if (messageId) { - removeAbortController(messageId) - } - } + cleanup } } } diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index 778f78d3..fd90a4f5 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -197,9 +197,10 @@ export default class GeminiProvider extends BaseProvider { const messageContents = await this.getMessageContents(userLastMessage!) const start_time_millsec = new Date().getTime() - + const { abortController, cleanup } = this.createAbortController(userLastMessage?.id) + const { signal } = abortController if (!streamOutput) { - const { response } = await chat.sendMessage(messageContents.parts) + const { response } = await chat.sendMessage(messageContents.parts, { signal }) const time_completion_millsec = new Date().getTime() - start_time_millsec onChunk({ text: response.candidates?.[0].content.parts[0].text, @@ -218,13 +219,8 @@ export default class GeminiProvider extends BaseProvider { return } - const lastUserMessage = userMessages.findLast((m) => m.role === 'user') - const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id) - const { signal } = abortController - const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal }) let time_first_token_millsec = 0 - const processStream = async (stream: GenerateContentStreamResult) => { for await (const chunk of stream.stream) { if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break @@ -297,7 +293,6 @@ export default class GeminiProvider extends BaseProvider { }) } } - await processStream(userMessagesStream).finally(cleanup) } diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index 2da9dcd7..c1004434 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -3,7 +3,7 @@ import i18n from '@renderer/i18n' import store from '@renderer/store' import { setGenerating } from '@renderer/store/runtime' import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types' -import { formatMessageError } from '@renderer/utils/error' +import { formatMessageError, isAbortError } from '@renderer/utils/error' import { cloneDeep, findLast, isEmpty } from 'lodash' import AiProvider from '../providers/AiProvider' @@ -116,12 +116,18 @@ export async function fetchChatCompletion({ // Set metrics.completion_tokens if (message.metrics && message?.usage?.completion_tokens) { if (!message.metrics?.completion_tokens) { - message.metrics.completion_tokens = message.usage.completion_tokens + message = { + ...message, + metrics: { + ...message.metrics, + completion_tokens: message.usage.completion_tokens + } + } } } } } catch (error: any) { - console.log('error', error) + if (isAbortError(error)) return message.status = 'error' message.error = formatMessageError(error) } diff --git a/src/renderer/src/store/messages.ts b/src/renderer/src/store/messages.ts index 54365d4b..416920cc 100644 --- a/src/renderer/src/store/messages.ts +++ b/src/renderer/src/store/messages.ts @@ -345,7 +345,6 @@ export const sendMessage = onResponse: async (msg) => { // 允许在回调外维护一个最新的消息状态,每次都更新这个对象,但只通过节流函数分发到Redux const updateMessage = { ...msg, status: msg.status || 'pending', content: msg.content || '' } - // 创建节流函数,限制Redux更新频率 // 使用节流函数更新Redux throttledDispatch( assistant, diff --git a/src/renderer/src/utils/error.ts b/src/renderer/src/utils/error.ts index 77d5da4e..ecd3c35d 100644 --- a/src/renderer/src/utils/error.ts +++ b/src/renderer/src/utils/error.ts @@ -62,3 +62,30 @@ export function formatMessageError(error: any): Record { export function getErrorMessage(error: any): string { return error?.message || error?.toString() || '' } + +export const isAbortError = (error: any): boolean => { + // 检查错误消息 + if (error?.message === 'Request was aborted.') { + return true + } + + // 检查是否为 DOMException 类型的中止错误 + if (error instanceof DOMException && error.name === 'AbortError') { + return true + } + console.log( + typeof error === 'object', + error.message === 'Request was aborted.' || error?.message?.includes('signal is aborted without reason') + ) + // 检查 OpenAI 特定的错误结构 + if ( + (error && + typeof error === 'object' && + (error.message === 'Request was aborted.' || error?.message?.includes('signal is aborted without reason'))) || + error.stack?.includes('OpenAI.makeRequest') + ) { + return true + } + + return false +}