fix: messages pause bug (#3343)

* refactor: Simplify message resend logic and enhance abort controller handling

- Updated MessageMenubar to streamline message resend functionality.
- Improved abort controller management in BaseProvider and related services.
- Adjusted sendMessage to handle both single and multiple assistant messages.
- Enhanced logging for better debugging and tracking of message flow.

* feat: Enhance message handling and queue management

- Updated Inputbar to include mentions in dispatched messages.
- Introduced appendMessage action to manage message insertion at specific positions in the state.
- Improved sendMessage logic to handle mentions and maintain message order.
- Refactored getTopicQueue to accept options for better queue configuration.

* refactor: Improve abort handling and message operations

- Refactored useMessageOperations to streamline message pausing logic.
- Enhanced abort controller in BaseProvider to handle abort events more effectively.
- Updated OpenAIProvider to utilize new abort handling mechanism.
- Adjusted fetchChatCompletion to set message status based on abort conditions.
- Improved message dispatching in sendMessage for better queue management.

* refactor: Enhance signal promise handling in BaseProvider and OpenAIProvider

- Updated signal handling in BaseProvider to use a structured signalPromise object for better clarity and management.
- Adjusted error handling in OpenAIProvider to correctly catch and throw errors from the signalPromise.
- Improved overall abort handling logic to ensure robust message operations.

* fix:lint
This commit is contained in:
MyPrototypeWhat 2025-03-14 17:57:33 +08:00 committed by GitHub
parent aa6ecb4814
commit 8faececa4c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 194 additions and 108 deletions

View File

@ -158,34 +158,35 @@ export function useMessageOperations(topic: Topic) {
/**
*
*/
const pauseMessage = useCallback(
// 存的是用户消息的id也就是助手消息的askId
async (message: Message) => {
// 1. 调用 abort
message.askId && abortCompletion(message.askId)
// const pauseMessage = useCallback(
// // 存的是用户消息的id也就是助手消息的askId
// async (message: Message) => {
// // 1. 调用 abort
// 2. 更新消息状态
await editMessage(message.id, { status: 'paused', content: message.content })
// // 2. 更新消息状态,
// // await editMessage(message.id, { status: 'paused', content: message.content })
// 3.更改loading状态
dispatch(setTopicLoading({ topicId: message.topicId, loading: false }))
// // 3.更改loading状态
// dispatch(setTopicLoading({ topicId: message.topicId, loading: false }))
// 4. 清理流式消息
clearStreamMessageAction(message.id)
},
[editMessage, dispatch, clearStreamMessageAction]
)
// // 4. 清理流式消息
// // clearStreamMessageAction(message.id)
// },
// [editMessage, dispatch, clearStreamMessageAction]
// )
const pauseMessages = useCallback(async () => {
// 暂停的消息不需要在这更改status,通过catch判断abort错误之后设置message.status
const streamMessages = store.getState().messages.streamMessagesByTopic[topic.id]
if (!streamMessages) return
// 不需要重复暂停
const askIds = [...new Set(Object.values(streamMessages).map((m) => m?.askId))]
if (streamMessages) {
const streamMessagesList = Object.values(streamMessages).filter((msg) => msg?.askId && msg?.id)
for (const message of streamMessagesList) {
message && (await pauseMessage(message))
}
for (const askId of askIds) {
askId && abortCompletion(askId)
}
}, [pauseMessage, topic.id])
dispatch(setTopicLoading({ topicId: topic.id, loading: false }))
}, [topic.id, dispatch])
/**
* /
@ -213,7 +214,7 @@ export function useMessageOperations(topic: Topic) {
clearStreamMessage: clearStreamMessageAction,
createNewContext,
clearTopicMessages: clearTopicMessagesAction,
pauseMessage,
// pauseMessage,
pauseMessages,
resumeMessage
}

View File

@ -175,7 +175,11 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
userMessage.usage = await estimateMessageUsage(userMessage)
currentMessageId.current = userMessage.id
dispatch(_sendMessage(userMessage, assistant, topic))
dispatch(
_sendMessage(userMessage, assistant, topic, {
mentions: mentionModels
})
)
// Clear input
setText('')

View File

@ -92,10 +92,10 @@ const MessageMenubar: FC<Props> = (props) => {
// Resend all grouped messages
if (!isEmpty(groupdMessages)) {
for (const assistantMessage of groupdMessages) {
const _model = assistantMessage.model || assistantModel
await resendMessage({ ...assistantMessage, model: _model }, assistant)
}
// for (const assistantMessage of groupdMessages) {
// const _model = assistantMessage.model || assistantModel
await resendMessage(message, assistant)
// }
return
}

View File

@ -10,7 +10,6 @@ import { isReasoningModel } from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import { EVENT_NAMES } from '@renderer/services/EventService'
import {
filterContextMessages,
filterEmptyMessages,
@ -241,13 +240,13 @@ export default class AnthropicProvider extends BaseProvider {
return new Promise<void>((resolve, reject) => {
const toolCalls: ToolUseBlock[] = []
let hasThinkingContent = false
const stream = this.sdk.messages
this.sdk.messages
.stream({ ...body, stream: true }, { signal })
.on('text', (text) => {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
stream.controller.abort()
return resolve()
}
// if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
// stream.controller.abort()
// return resolve()
// }
if (time_first_token_millsec == 0) {
time_first_token_millsec = new Date().getTime() - start_time_millsec
@ -357,15 +356,13 @@ export default class AnthropicProvider extends BaseProvider {
resolve()
})
.on('error', (error) => reject(error))
.on('abort', () => {
reject(new Error('Request was aborted.'))
})
})
}
await processStream(body, 0)
.catch((error) => {
// 不加这个错误抛不出来
throw error
})
.finally(cleanup)
await processStream(body, 0).finally(cleanup)
}
/**

View File

@ -160,24 +160,45 @@ export default abstract class BaseProvider {
)
}
protected createAbortController(messageId?: string) {
protected createAbortController(messageId?: string, isAddEventListener?: boolean) {
const abortController = new AbortController()
const abortFn = () => abortController.abort()
if (messageId) {
addAbortController(messageId, () => abortController.abort())
addAbortController(messageId, abortFn)
}
const cleanup = () => {
if (messageId) {
removeAbortController(messageId)
signalPromise.resolve?.(undefined)
removeAbortController(messageId, abortFn)
}
}
const signalPromise: {
resolve: (value: unknown) => void
promise: Promise<unknown>
} = {
resolve: () => {},
promise: Promise.resolve()
}
abortController.signal.addEventListener('abort', () => {
// 兼容
cleanup()
})
if (isAddEventListener) {
signalPromise.promise = new Promise((resolve, reject) => {
signalPromise.resolve = resolve
if (abortController.signal.aborted) {
reject(new Error('Request was aborted.'))
}
// 捕获abort事件,有些abort事件必须
abortController.signal.addEventListener('abort', () => {
reject(new Error('Request was aborted.'))
})
})
return {
abortController,
cleanup,
signalPromise
}
}
return {
abortController,
cleanup

View File

@ -326,7 +326,7 @@ export default class GeminiProvider extends BaseProvider {
})
const newChat = geminiModel.startChat({ history })
const newStream = await newChat.sendMessageStream(fcRespParts, { signal })
await processStream(newStream, idx + 1).finally(cleanup)
await processStream(newStream, idx + 1)
}
}

View File

@ -357,7 +357,6 @@ export default class OpenAIProvider extends BaseProvider {
}
const userMessages: ChatCompletionMessageParam[] = []
const _messages = filterUserRoleStartMessages(
filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 1)))
)
@ -414,7 +413,7 @@ export default class OpenAIProvider extends BaseProvider {
let time_first_content_millsec = 0
const start_time_millsec = new Date().getTime()
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
const { abortController, cleanup, signalPromise } = this.createAbortController(lastUserMessage?.id, true)
const { signal } = abortController
mcpTools = filterMCPTools(mcpTools, lastUserMessage?.enabledMCPs)
@ -425,7 +424,6 @@ export default class OpenAIProvider extends BaseProvider {
) as ChatCompletionMessageParam[]
const toolResponses: MCPToolResponse[] = []
const processStream = async (stream: any, idx: number) => {
if (!isSupportStreamOutput()) {
const time_completion_millsec = new Date().getTime() - start_time_millsec
@ -593,6 +591,10 @@ export default class OpenAIProvider extends BaseProvider {
)
await processStream(stream, 0).finally(cleanup)
// 捕获signal的错误
await signalPromise?.promise?.catch((error) => {
throw error
})
}
/**

View File

@ -34,12 +34,6 @@ export async function fetchChatCompletion({
const webSearchProvider = WebSearchService.getWebSearchProvider()
const AI = new AiProvider(provider)
// store.dispatch(setGenerating(true))
// onResponse({ ...message })
// addAbortController(message.askId ?? message.id)
try {
let _messages: Message[] = []
let isFirstChunk = true
@ -70,7 +64,6 @@ export async function fetchChatCompletion({
}
const allMCPTools = await window.api.mcp.listTools()
await AI.completions({
messages: filterUsefulMessages(messages),
assistant,
@ -127,9 +120,12 @@ export async function fetchChatCompletion({
}
}
} catch (error: any) {
if (isAbortError(error)) return
message.status = 'error'
message.error = formatMessageError(error)
if (isAbortError(error)) {
message.status = 'paused'
} else {
message.status = 'error'
message.error = formatMessageError(error)
}
}
// Emit chat completion event

View File

@ -53,7 +53,7 @@ export function filterEmptyMessages(messages: Message[]): Message[] {
}
export function filterUsefulMessages(messages: Message[]): Message[] {
const _messages = messages
const _messages = [...messages]
const groupedMessages = getGroupedMessages(messages)
Object.entries(groupedMessages).forEach(([key, messages]) => {

View File

@ -6,6 +6,7 @@ import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
import { getAssistantMessage, resetAssistantMessage } from '@renderer/services/MessagesService'
import type { AppDispatch, RootState } from '@renderer/store'
import type { Assistant, Message, Topic } from '@renderer/types'
import { Model } from '@renderer/types'
import { clearTopicQueue, getTopicQueue, waitForTopicQueue } from '@renderer/utils/queue'
import { throttle } from 'lodash'
@ -105,6 +106,29 @@ const messagesSlice = createSlice({
state.messagesByTopic[topicId].push(messages)
}
},
appendMessage: (
state,
action: PayloadAction<{ topicId: string; messages: Message | Message[]; position?: number }>
) => {
const { topicId, messages, position } = action.payload
if (!state.messagesByTopic[topicId]) {
state.messagesByTopic[topicId] = []
}
// 确保消息数组存在并且拿到引用
const messagesList = state.messagesByTopic[topicId]
// 要插入的消息
const messagesToInsert = Array.isArray(messages) ? messages : [messages]
if (position !== undefined && position >= 0 && position <= messagesList.length) {
// 如果指定了位置,在特定位置插入消息
messagesList.splice(position, 0, ...messagesToInsert)
} else {
// 否则默认添加到末尾
messagesList.push(...messagesToInsert)
}
},
updateMessage: (
state,
action: PayloadAction<{ topicId: string; messageId: string; updates: Partial<Message> }>
@ -233,8 +257,9 @@ export const sendMessage =
assistant: Assistant,
topic: Topic,
options?: {
resendAssistantMessage?: Message
resendAssistantMessage?: Message | Message[]
isMentionModel?: boolean
mentions?: Model[]
}
) =>
async (dispatch: AppDispatch, getState: () => RootState) => {
@ -255,17 +280,27 @@ export const sendMessage =
if (options?.resendAssistantMessage) {
// 直接使用传入的助手消息,进行重置
const messageToReset = options.resendAssistantMessage
const { model, id } = messageToReset
const resetMessage = resetAssistantMessage(messageToReset, model)
// 更新状态
dispatch(updateMessage({ topicId: topic.id, messageId: id, updates: resetMessage }))
// 使用重置后的消息
assistantMessages.push(resetMessage)
if (Array.isArray(messageToReset)) {
assistantMessages = messageToReset.map((m) => {
const { model, id } = m
const resetMessage = resetAssistantMessage(m, model)
// 更新状态
dispatch(updateMessage({ topicId: topic.id, messageId: id, updates: resetMessage }))
// 使用重置后的消息
return resetMessage
})
} else {
const { model, id } = messageToReset
const resetMessage = resetAssistantMessage(messageToReset, model)
// 更新状态
dispatch(updateMessage({ topicId: topic.id, messageId: id, updates: resetMessage }))
// 使用重置后的消息
assistantMessages.push(resetMessage)
}
} else {
// 不是重发情况
if (userMessage.mentions?.length) {
// 为每个被 mention 的模型创建一个助手消息
assistantMessages = userMessage.mentions.map((m) => {
// 为每个被 mention 的模型创建一个助手消息
if (options?.mentions?.length) {
assistantMessages = options?.mentions.map((m) => {
const assistantMessage = getAssistantMessage({ assistant: { ...assistant, model: m }, topic })
assistantMessage.model = m
assistantMessage.askId = userMessage.id
@ -280,19 +315,36 @@ export const sendMessage =
assistantMessages.push(assistantMessage)
}
// 获取当前消息列表
const currentMessages = getState().messages.messagesByTopic[topic.id]
// 最后一个具有相同askId的助手消息在其后插入
let position: number | undefined
if (options?.isMentionModel) {
const lastAssistantIndex = currentMessages.findLastIndex(
(m) => m.role === 'assistant' && m.askId === userMessage.id
)
if (lastAssistantIndex !== -1) {
position = lastAssistantIndex + 1
}
}
dispatch(
addMessage({
appendMessage({
topicId: topic.id,
messages: !options?.isMentionModel ? [userMessage, ...assistantMessages] : assistantMessages
messages: !options?.isMentionModel ? [userMessage, ...assistantMessages] : assistantMessages,
position
})
)
}
for (const assistantMessage of assistantMessages) {
// for of会收到await 影响,在暂停的时候会因为异步的原因有概率拿不到数据
dispatch(setStreamMessage({ topicId: topic.id, message: assistantMessage }))
}
const queue = getTopicQueue(topic.id)
for (const assistantMessage of assistantMessages) {
// Set as stream message instead of adding to messages
dispatch(setStreamMessage({ topicId: topic.id, message: assistantMessage }))
// Sync user message with database
const state = getState()
@ -303,7 +355,7 @@ export const sendMessage =
}
// 保证请求有序,防止请求静态,限制并发数量
await queue.add(async () => {
queue.add(async () => {
try {
const messages = getState().messages.messagesByTopic[topic.id]
if (!messages) {
@ -324,13 +376,30 @@ export const sendMessage =
// 节流
const throttledDispatch = throttle(handleResponseMessageUpdate, 100, { trailing: true }) // 100ms的节流时间应足够平衡用户体验和性能
// 寻找当前正在处理的消息在消息列表中的位置
// const messageIndex = messages.findIndex((m) => m.id === assistantMessage.id)
const handleMessages = (): Message[] => {
// 找到对应的用户消息位置
const userMessageIndex = messages.findIndex((m) => m.id === assistantMessage.askId)
const messageIndex = messages.findIndex((m) => m.id === assistantMessage.id)
if (userMessageIndex !== -1) {
// 先截取到用户消息为止的所有消息,再进行过滤
const messagesUpToUser = messages.slice(0, userMessageIndex + 1)
return messagesUpToUser.filter((m) => !m.status?.includes('ing'))
}
// 如果找不到对应的用户消息,使用原有逻辑
// 按理说不会找不到 先注释掉看看
// if (messageIndex !== -1) {
// const messagesUpToAssistant = messages.slice(0, messageIndex)
// return messagesUpToAssistant.filter((m) => !m.status?.includes('ing'))
// }
// 没有找到消息索引的情况,过滤所有消息
return messages.filter((m) => !m.status?.includes('ing'))
}
await fetchChatCompletion({
message: { ...assistantMessage },
messages: messages
.filter((m) => !m.status?.includes('ing'))
.slice(0, messageIndex !== -1 ? messageIndex : undefined),
messages: handleMessages(),
assistant: assistantWithModel,
onResponse: async (msg) => {
// 允许在回调外维护一个最新的消息状态每次都更新这个对象但只通过节流函数分发到Redux
@ -362,13 +431,14 @@ export const sendMessage =
}
})
}
// 等待所有请求完成,设置loading
await queue.onIdle()
dispatch(setTopicLoading({ topicId: topic.id, loading: false }))
} catch (error: any) {
console.error('Error in sendMessage:', error)
dispatch(setError(error.message))
dispatch(setTopicLoading({ topicId: topic.id, loading: false }))
} finally {
// 等待所有请求完成,设置loading
await waitForTopicQueue(topic.id)
dispatch(setTopicLoading({ topicId: topic.id, loading: false }))
}
}
@ -385,7 +455,7 @@ export const resendMessage =
// 如果是用户消息,直接重发
if (message.role === 'user') {
// 查找此用户消息对应的助手消息
const assistantMessage = topicMessages.find((m) => m.role === 'assistant' && m.askId === message.id)
const assistantMessage = topicMessages.filter((m) => m.role === 'assistant' && m.askId === message.id)
return dispatch(
sendMessage(message, assistant, topic, {
resendAssistantMessage: assistantMessage,
@ -408,6 +478,7 @@ export const resendMessage =
return dispatch(sendMessage(userMessage, assistant, topic, { isMentionModel }))
}
console.log('assistantMessage', message)
dispatch(
sendMessage(userMessage, assistant, topic, {
resendAssistantMessage: message
@ -521,7 +592,8 @@ export const {
loadTopicMessages,
setStreamMessage,
commitStreamMessage,
clearStreamMessage
clearStreamMessage,
appendMessage
} = messagesSlice.actions
export default messagesSlice.reducer

View File

@ -1,25 +1,22 @@
export const abortMap = new Map<string, () => void>()
export const abortMap = new Map<string, (() => void)[]>()
export const addAbortController = (id: string, abortFn: () => void) => {
let callback = abortFn
const existingCallback = abortMap.get(id)
if (existingCallback) {
callback = () => {
existingCallback?.()
abortFn()
}
}
abortMap.set(id, callback)
abortMap.set(id, [...(abortMap.get(id) || []), abortFn])
}
export const removeAbortController = (id: string) => {
abortMap.delete(id)
export const removeAbortController = (id: string, abortFn: () => void) => {
const callbackArr = abortMap.get(id)
if (abortFn) {
callbackArr?.splice(callbackArr?.indexOf(abortFn), 1)
} else abortMap.delete(id)
}
export const abortCompletion = (id: string) => {
const abortFn = abortMap.get(id)
if (abortFn) {
abortFn()
removeAbortController(id)
const abortFns = abortMap.get(id)
if (abortFns?.length) {
for (const fn of [...abortFns]) {
fn()
removeAbortController(id, fn)
}
}
}

View File

@ -8,13 +8,9 @@ const requestQueues: { [topicId: string]: PQueue } = {}
* @param topicId The ID of the topic
* @returns A PQueue instance for the topic
*/
export const getTopicQueue = (topicId: string): PQueue => {
export const getTopicQueue = (topicId: string, options = {}): PQueue => {
if (!requestQueues[topicId]) {
requestQueues[topicId] = new PQueue({
concurrency: 4,
timeout: 1000 * 60 * 5, // 5 minutes
throwOnTimeout: false
})
requestQueues[topicId] = new PQueue(options)
}
return requestQueues[topicId]
}