fix(MessageOperations): Improve message pause functionality and error handling
- Update pauseMessage method to handle both askId and messageId - Add loading state reset when pausing messages - Enhance error handling in providers with abort error detection - Modify ApiService to handle aborted requests gracefully - Add comprehensive isAbortError utility function
This commit is contained in:
parent
12d40713a9
commit
f5d3c07161
@ -10,6 +10,7 @@ import {
|
||||
selectTopicLoading,
|
||||
selectTopicMessages,
|
||||
setStreamMessage,
|
||||
setTopicLoading,
|
||||
updateMessage,
|
||||
updateMessages
|
||||
} from '@renderer/store/messages'
|
||||
@ -155,14 +156,18 @@ export function useMessageOperations(topic: Topic) {
|
||||
* 暂停消息生成
|
||||
*/
|
||||
const pauseMessage = useCallback(
|
||||
async (messageId: string) => {
|
||||
// 存的是用户消息的id,也就是助手消息的askId
|
||||
async (askId: string, messageId: string) => {
|
||||
// 1. 调用 abort
|
||||
abortCompletion(messageId)
|
||||
|
||||
abortCompletion(askId)
|
||||
console.log('messageId', messageId)
|
||||
// 2. 更新消息状态
|
||||
await editMessage(messageId, { status: 'paused' })
|
||||
|
||||
// 3. 清理流式消息
|
||||
// 3.更改loading状态
|
||||
dispatch(setTopicLoading({ topicId: topic.id, loading: false }))
|
||||
|
||||
// 4. 清理流式消息
|
||||
clearStreamMessageAction(messageId)
|
||||
},
|
||||
[editMessage, clearStreamMessageAction]
|
||||
@ -173,15 +178,13 @@ export function useMessageOperations(topic: Topic) {
|
||||
const streamMessages = store.getState().messages.streamMessagesByTopic[topic.id]
|
||||
if (streamMessages) {
|
||||
// 获取所有流式消息的 askId
|
||||
const askIds = new Set(
|
||||
Object.values(streamMessages)
|
||||
.map((msg) => msg.askId)
|
||||
.filter(Boolean)
|
||||
)
|
||||
const askIds = Object.values(streamMessages)
|
||||
.map((msg) => [msg.askId, msg.id])
|
||||
.filter(([askId, id]) => askId && id)
|
||||
|
||||
// 对每个 askId 执行暂停
|
||||
for (const askId of askIds) {
|
||||
await pauseMessage(askId)
|
||||
for (const [askId, id] of askIds) {
|
||||
await pauseMessage(askId, id)
|
||||
}
|
||||
}
|
||||
}, [topic.id, pauseMessage])
|
||||
|
||||
@ -208,7 +208,7 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
const { signal } = abortController
|
||||
const toolResponses: MCPToolResponse[] = []
|
||||
|
||||
const processStream = async (body: MessageCreateParamsNonStreaming) => {
|
||||
const processStream = (body: MessageCreateParamsNonStreaming) => {
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
const toolCalls: ToolUseBlock[] = []
|
||||
let hasThinkingContent = false
|
||||
@ -326,7 +326,12 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
})
|
||||
}
|
||||
|
||||
await processStream(body).finally(cleanup)
|
||||
await processStream(body)
|
||||
.catch((error) => {
|
||||
// 不加这个错误抛不出来
|
||||
throw error
|
||||
})
|
||||
.finally(cleanup)
|
||||
}
|
||||
|
||||
public async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
|
||||
|
||||
@ -160,13 +160,20 @@ export default abstract class BaseProvider {
|
||||
addAbortController(messageId, () => abortController.abort())
|
||||
}
|
||||
|
||||
const cleanup = () => {
|
||||
if (messageId) {
|
||||
removeAbortController(messageId)
|
||||
}
|
||||
}
|
||||
|
||||
abortController.signal.addEventListener('abort', () => {
|
||||
// 兼容
|
||||
cleanup()
|
||||
})
|
||||
|
||||
return {
|
||||
abortController,
|
||||
cleanup: () => {
|
||||
if (messageId) {
|
||||
removeAbortController(messageId)
|
||||
}
|
||||
}
|
||||
cleanup
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -197,9 +197,10 @@ export default class GeminiProvider extends BaseProvider {
|
||||
const messageContents = await this.getMessageContents(userLastMessage!)
|
||||
|
||||
const start_time_millsec = new Date().getTime()
|
||||
|
||||
const { abortController, cleanup } = this.createAbortController(userLastMessage?.id)
|
||||
const { signal } = abortController
|
||||
if (!streamOutput) {
|
||||
const { response } = await chat.sendMessage(messageContents.parts)
|
||||
const { response } = await chat.sendMessage(messageContents.parts, { signal })
|
||||
const time_completion_millsec = new Date().getTime() - start_time_millsec
|
||||
onChunk({
|
||||
text: response.candidates?.[0].content.parts[0].text,
|
||||
@ -218,13 +219,8 @@ export default class GeminiProvider extends BaseProvider {
|
||||
return
|
||||
}
|
||||
|
||||
const lastUserMessage = userMessages.findLast((m) => m.role === 'user')
|
||||
const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
|
||||
const { signal } = abortController
|
||||
|
||||
const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal })
|
||||
let time_first_token_millsec = 0
|
||||
|
||||
const processStream = async (stream: GenerateContentStreamResult) => {
|
||||
for await (const chunk of stream.stream) {
|
||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
|
||||
@ -297,7 +293,6 @@ export default class GeminiProvider extends BaseProvider {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
await processStream(userMessagesStream).finally(cleanup)
|
||||
}
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import i18n from '@renderer/i18n'
|
||||
import store from '@renderer/store'
|
||||
import { setGenerating } from '@renderer/store/runtime'
|
||||
import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||
import { formatMessageError } from '@renderer/utils/error'
|
||||
import { formatMessageError, isAbortError } from '@renderer/utils/error'
|
||||
import { cloneDeep, findLast, isEmpty } from 'lodash'
|
||||
|
||||
import AiProvider from '../providers/AiProvider'
|
||||
@ -116,12 +116,18 @@ export async function fetchChatCompletion({
|
||||
// Set metrics.completion_tokens
|
||||
if (message.metrics && message?.usage?.completion_tokens) {
|
||||
if (!message.metrics?.completion_tokens) {
|
||||
message.metrics.completion_tokens = message.usage.completion_tokens
|
||||
message = {
|
||||
...message,
|
||||
metrics: {
|
||||
...message.metrics,
|
||||
completion_tokens: message.usage.completion_tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error: any) {
|
||||
console.log('error', error)
|
||||
if (isAbortError(error)) return
|
||||
message.status = 'error'
|
||||
message.error = formatMessageError(error)
|
||||
}
|
||||
|
||||
@ -345,7 +345,6 @@ export const sendMessage =
|
||||
onResponse: async (msg) => {
|
||||
// 允许在回调外维护一个最新的消息状态,每次都更新这个对象,但只通过节流函数分发到Redux
|
||||
const updateMessage = { ...msg, status: msg.status || 'pending', content: msg.content || '' }
|
||||
// 创建节流函数,限制Redux更新频率
|
||||
// 使用节流函数更新Redux
|
||||
throttledDispatch(
|
||||
assistant,
|
||||
|
||||
@ -62,3 +62,30 @@ export function formatMessageError(error: any): Record<string, any> {
|
||||
export function getErrorMessage(error: any): string {
|
||||
return error?.message || error?.toString() || ''
|
||||
}
|
||||
|
||||
export const isAbortError = (error: any): boolean => {
|
||||
// 检查错误消息
|
||||
if (error?.message === 'Request was aborted.') {
|
||||
return true
|
||||
}
|
||||
|
||||
// 检查是否为 DOMException 类型的中止错误
|
||||
if (error instanceof DOMException && error.name === 'AbortError') {
|
||||
return true
|
||||
}
|
||||
console.log(
|
||||
typeof error === 'object',
|
||||
error.message === 'Request was aborted.' || error?.message?.includes('signal is aborted without reason')
|
||||
)
|
||||
// 检查 OpenAI 特定的错误结构
|
||||
if (
|
||||
(error &&
|
||||
typeof error === 'object' &&
|
||||
(error.message === 'Request was aborted.' || error?.message?.includes('signal is aborted without reason'))) ||
|
||||
error.stack?.includes('OpenAI.makeRequest')
|
||||
) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user