fix(MessageOperations): Improve message pause functionality and error handling

- Update pauseMessage method to handle both askId and messageId
- Add loading state reset when pausing messages
- Enhance error handling in providers with abort error detection
- Modify ApiService to handle aborted requests gracefully
- Add comprehensive isAbortError utility function
This commit is contained in:
MyPrototypeWhat 2025-03-11 23:39:03 +08:00 committed by 亢奋猫
parent 12d40713a9
commit f5d3c07161
7 changed files with 72 additions and 30 deletions

View File

@@ -10,6 +10,7 @@ import {
selectTopicLoading,
selectTopicMessages,
setStreamMessage,
setTopicLoading,
updateMessage,
updateMessages
} from '@renderer/store/messages'
@@ -155,14 +156,18 @@ export function useMessageOperations(topic: Topic) {
*
*/
const pauseMessage = useCallback(
async (messageId: string) => {
// 存的是用户消息的id也就是助手消息的askId
async (askId: string, messageId: string) => {
// 1. 调用 abort
abortCompletion(messageId)
abortCompletion(askId)
console.log('messageId', messageId)
// 2. 更新消息状态
await editMessage(messageId, { status: 'paused' })
// 3. 清理流式消息
// 3.更改loading状态
dispatch(setTopicLoading({ topicId: topic.id, loading: false }))
// 4. 清理流式消息
clearStreamMessageAction(messageId)
},
[editMessage, clearStreamMessageAction]
@@ -173,15 +178,13 @@
const streamMessages = store.getState().messages.streamMessagesByTopic[topic.id]
if (streamMessages) {
// 获取所有流式消息的 askId
const askIds = new Set(
Object.values(streamMessages)
.map((msg) => msg.askId)
.filter(Boolean)
)
const askIds = Object.values(streamMessages)
.map((msg) => [msg.askId, msg.id])
.filter(([askId, id]) => askId && id)
// 对每个 askId 执行暂停
for (const askId of askIds) {
await pauseMessage(askId)
for (const [askId, id] of askIds) {
await pauseMessage(askId, id)
}
}
}, [topic.id, pauseMessage])

View File

@@ -208,7 +208,7 @@ export default class AnthropicProvider extends BaseProvider {
const { signal } = abortController
const toolResponses: MCPToolResponse[] = []
const processStream = async (body: MessageCreateParamsNonStreaming) => {
const processStream = (body: MessageCreateParamsNonStreaming) => {
return new Promise<void>((resolve, reject) => {
const toolCalls: ToolUseBlock[] = []
let hasThinkingContent = false
@@ -326,7 +326,12 @@
})
}
await processStream(body).finally(cleanup)
await processStream(body)
.catch((error) => {
// 不加这个错误抛不出来
throw error
})
.finally(cleanup)
}
public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {

View File

@@ -160,13 +160,20 @@ export default abstract class BaseProvider {
addAbortController(messageId, () => abortController.abort())
}
const cleanup = () => {
if (messageId) {
removeAbortController(messageId)
}
}
abortController.signal.addEventListener('abort', () => {
// 兼容
cleanup()
})
return {
abortController,
cleanup: () => {
if (messageId) {
removeAbortController(messageId)
}
}
cleanup
}
}
}

View File

@@ -197,9 +197,10 @@ export default class GeminiProvider extends BaseProvider {
const messageContents = await this.getMessageContents(userLastMessage!)
const start_time_millsec = new Date().getTime()
const { abortController, cleanup } = this.createAbortController(userLastMessage?.id)
const { signal } = abortController
if (!streamOutput) {
const { response } = await chat.sendMessage(messageContents.parts)
const { response } = await chat.sendMessage(messageContents.parts, { signal })
const time_completion_millsec = new Date().getTime() - start_time_millsec
onChunk({
text: response.candidates?.[0].content.parts[0].text,
@@ -218,13 +219,8 @@
return
}
const lastUserMessage = userMessages.findLast((m) => m.role === 'user')
const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
const { signal } = abortController
const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal })
let time_first_token_millsec = 0
const processStream = async (stream: GenerateContentStreamResult) => {
for await (const chunk of stream.stream) {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
@@ -297,7 +293,6 @@
})
}
}
await processStream(userMessagesStream).finally(cleanup)
}

View File

@@ -3,7 +3,7 @@ import i18n from '@renderer/i18n'
import store from '@renderer/store'
import { setGenerating } from '@renderer/store/runtime'
import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types'
import { formatMessageError } from '@renderer/utils/error'
import { formatMessageError, isAbortError } from '@renderer/utils/error'
import { cloneDeep, findLast, isEmpty } from 'lodash'
import AiProvider from '../providers/AiProvider'
@@ -116,12 +116,18 @@ export async function fetchChatCompletion({
// Set metrics.completion_tokens
if (message.metrics && message?.usage?.completion_tokens) {
if (!message.metrics?.completion_tokens) {
message.metrics.completion_tokens = message.usage.completion_tokens
message = {
...message,
metrics: {
...message.metrics,
completion_tokens: message.usage.completion_tokens
}
}
}
}
}
} catch (error: any) {
console.log('error', error)
if (isAbortError(error)) return
message.status = 'error'
message.error = formatMessageError(error)
}

View File

@@ -345,7 +345,6 @@ export const sendMessage =
onResponse: async (msg) => {
// 允许在回调外维护一个最新的消息状态每次都更新这个对象但只通过节流函数分发到Redux
const updateMessage = { ...msg, status: msg.status || 'pending', content: msg.content || '' }
// 创建节流函数限制Redux更新频率
// 使用节流函数更新Redux
throttledDispatch(
assistant,

View File

@@ -62,3 +62,30 @@ export function formatMessageError(error: any): Record<string, any> {
/**
 * Extracts a human-readable message from an arbitrary thrown value.
 * Prefers `error.message`, falls back to the value's string form,
 * and finally to an empty string for nullish input.
 */
export function getErrorMessage(error: any): string {
  const message = error?.message
  if (message) {
    return message
  }
  return error?.toString() || ''
}
/**
 * Determines whether a thrown value represents a deliberately aborted
 * request (user cancellation) rather than a genuine failure.
 *
 * Recognized abort shapes:
 * - `DOMException` with name `'AbortError'` (fetch / AbortController);
 * - an object whose message is exactly `'Request was aborted.'` or
 *   contains `'signal is aborted without reason'`;
 * - an error whose stack points at `OpenAI.makeRequest` — NOTE(review):
 *   this last heuristic matches ANY error thrown through the OpenAI SDK,
 *   not only aborts; confirm this is intended.
 *
 * @param error - any thrown value; `null`/`undefined` are handled safely
 *                (the previous version dereferenced `error.message` in a
 *                debug `console.log` and threw a TypeError on nullish input)
 * @returns true when the error should be treated as an abort
 */
export const isAbortError = (error: any): boolean => {
  // Abort surfaced as a DOMException by fetch/AbortController
  if (error instanceof DOMException && error.name === 'AbortError') {
    return true
  }
  // Guard before property access so nullish/primitive inputs return false
  // instead of throwing.
  if (error && typeof error === 'object') {
    // Message emitted by the OpenAI SDK when a request is cancelled
    if (error.message === 'Request was aborted.') {
      return true
    }
    // Generic AbortSignal message, or an error raised inside OpenAI.makeRequest
    if (
      error.message?.includes('signal is aborted without reason') ||
      error.stack?.includes('OpenAI.makeRequest')
    ) {
      return true
    }
  }
  return false
}