fix: messages pause bug (#3343)
* refactor: Simplify message resend logic and enhance abort controller handling - Updated MessageMenubar to streamline message resend functionality. - Improved abort controller management in BaseProvider and related services. - Adjusted sendMessage to handle both single and multiple assistant messages. - Enhanced logging for better debugging and tracking of message flow. * feat: Enhance message handling and queue management - Updated Inputbar to include mentions in dispatched messages. - Introduced appendMessage action to manage message insertion at specific positions in the state. - Improved sendMessage logic to handle mentions and maintain message order. - Refactored getTopicQueue to accept options for better queue configuration. * refactor: Improve abort handling and message operations - Refactored useMessageOperations to streamline message pausing logic. - Enhanced abort controller in BaseProvider to handle abort events more effectively. - Updated OpenAIProvider to utilize new abort handling mechanism. - Adjusted fetchChatCompletion to set message status based on abort conditions. - Improved message dispatching in sendMessage for better queue management. * refactor: Enhance signal promise handling in BaseProvider and OpenAIProvider - Updated signal handling in BaseProvider to use a structured signalPromise object for better clarity and management. - Adjusted error handling in OpenAIProvider to correctly catch and throw errors from the signalPromise. - Improved overall abort handling logic to ensure robust message operations. * fix:lint
This commit is contained in:
parent
aa6ecb4814
commit
8faececa4c
@ -158,34 +158,35 @@ export function useMessageOperations(topic: Topic) {
|
|||||||
/**
|
/**
|
||||||
* 暂停消息生成
|
* 暂停消息生成
|
||||||
*/
|
*/
|
||||||
const pauseMessage = useCallback(
|
// const pauseMessage = useCallback(
|
||||||
// 存的是用户消息的id,也就是助手消息的askId
|
// // 存的是用户消息的id,也就是助手消息的askId
|
||||||
async (message: Message) => {
|
// async (message: Message) => {
|
||||||
// 1. 调用 abort
|
// // 1. 调用 abort
|
||||||
message.askId && abortCompletion(message.askId)
|
|
||||||
|
|
||||||
// 2. 更新消息状态
|
// // 2. 更新消息状态,
|
||||||
await editMessage(message.id, { status: 'paused', content: message.content })
|
// // await editMessage(message.id, { status: 'paused', content: message.content })
|
||||||
|
|
||||||
// 3.更改loading状态
|
// // 3.更改loading状态
|
||||||
dispatch(setTopicLoading({ topicId: message.topicId, loading: false }))
|
// dispatch(setTopicLoading({ topicId: message.topicId, loading: false }))
|
||||||
|
|
||||||
// 4. 清理流式消息
|
// // 4. 清理流式消息
|
||||||
clearStreamMessageAction(message.id)
|
// // clearStreamMessageAction(message.id)
|
||||||
},
|
// },
|
||||||
[editMessage, dispatch, clearStreamMessageAction]
|
// [editMessage, dispatch, clearStreamMessageAction]
|
||||||
)
|
// )
|
||||||
|
|
||||||
const pauseMessages = useCallback(async () => {
|
const pauseMessages = useCallback(async () => {
|
||||||
|
// 暂停的消息不需要在这更改status,通过catch判断abort错误之后设置message.status
|
||||||
const streamMessages = store.getState().messages.streamMessagesByTopic[topic.id]
|
const streamMessages = store.getState().messages.streamMessagesByTopic[topic.id]
|
||||||
|
if (!streamMessages) return
|
||||||
|
// 不需要重复暂停
|
||||||
|
const askIds = [...new Set(Object.values(streamMessages).map((m) => m?.askId))]
|
||||||
|
|
||||||
if (streamMessages) {
|
for (const askId of askIds) {
|
||||||
const streamMessagesList = Object.values(streamMessages).filter((msg) => msg?.askId && msg?.id)
|
askId && abortCompletion(askId)
|
||||||
for (const message of streamMessagesList) {
|
|
||||||
message && (await pauseMessage(message))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}, [pauseMessage, topic.id])
|
dispatch(setTopicLoading({ topicId: topic.id, loading: false }))
|
||||||
|
}, [topic.id, dispatch])
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 恢复/重发消息
|
* 恢复/重发消息
|
||||||
@ -213,7 +214,7 @@ export function useMessageOperations(topic: Topic) {
|
|||||||
clearStreamMessage: clearStreamMessageAction,
|
clearStreamMessage: clearStreamMessageAction,
|
||||||
createNewContext,
|
createNewContext,
|
||||||
clearTopicMessages: clearTopicMessagesAction,
|
clearTopicMessages: clearTopicMessagesAction,
|
||||||
pauseMessage,
|
// pauseMessage,
|
||||||
pauseMessages,
|
pauseMessages,
|
||||||
resumeMessage
|
resumeMessage
|
||||||
}
|
}
|
||||||
|
|||||||
@ -175,7 +175,11 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
|
|||||||
userMessage.usage = await estimateMessageUsage(userMessage)
|
userMessage.usage = await estimateMessageUsage(userMessage)
|
||||||
currentMessageId.current = userMessage.id
|
currentMessageId.current = userMessage.id
|
||||||
|
|
||||||
dispatch(_sendMessage(userMessage, assistant, topic))
|
dispatch(
|
||||||
|
_sendMessage(userMessage, assistant, topic, {
|
||||||
|
mentions: mentionModels
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
// Clear input
|
// Clear input
|
||||||
setText('')
|
setText('')
|
||||||
|
|||||||
@ -92,10 +92,10 @@ const MessageMenubar: FC<Props> = (props) => {
|
|||||||
|
|
||||||
// Resend all grouped messages
|
// Resend all grouped messages
|
||||||
if (!isEmpty(groupdMessages)) {
|
if (!isEmpty(groupdMessages)) {
|
||||||
for (const assistantMessage of groupdMessages) {
|
// for (const assistantMessage of groupdMessages) {
|
||||||
const _model = assistantMessage.model || assistantModel
|
// const _model = assistantMessage.model || assistantModel
|
||||||
await resendMessage({ ...assistantMessage, model: _model }, assistant)
|
await resendMessage(message, assistant)
|
||||||
}
|
// }
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -10,7 +10,6 @@ import { isReasoningModel } from '@renderer/config/models'
|
|||||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||||
import i18n from '@renderer/i18n'
|
import i18n from '@renderer/i18n'
|
||||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||||
import { EVENT_NAMES } from '@renderer/services/EventService'
|
|
||||||
import {
|
import {
|
||||||
filterContextMessages,
|
filterContextMessages,
|
||||||
filterEmptyMessages,
|
filterEmptyMessages,
|
||||||
@ -241,13 +240,13 @@ export default class AnthropicProvider extends BaseProvider {
|
|||||||
return new Promise<void>((resolve, reject) => {
|
return new Promise<void>((resolve, reject) => {
|
||||||
const toolCalls: ToolUseBlock[] = []
|
const toolCalls: ToolUseBlock[] = []
|
||||||
let hasThinkingContent = false
|
let hasThinkingContent = false
|
||||||
const stream = this.sdk.messages
|
this.sdk.messages
|
||||||
.stream({ ...body, stream: true }, { signal })
|
.stream({ ...body, stream: true }, { signal })
|
||||||
.on('text', (text) => {
|
.on('text', (text) => {
|
||||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
|
// if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
|
||||||
stream.controller.abort()
|
// stream.controller.abort()
|
||||||
return resolve()
|
// return resolve()
|
||||||
}
|
// }
|
||||||
|
|
||||||
if (time_first_token_millsec == 0) {
|
if (time_first_token_millsec == 0) {
|
||||||
time_first_token_millsec = new Date().getTime() - start_time_millsec
|
time_first_token_millsec = new Date().getTime() - start_time_millsec
|
||||||
@ -357,15 +356,13 @@ export default class AnthropicProvider extends BaseProvider {
|
|||||||
resolve()
|
resolve()
|
||||||
})
|
})
|
||||||
.on('error', (error) => reject(error))
|
.on('error', (error) => reject(error))
|
||||||
|
.on('abort', () => {
|
||||||
|
reject(new Error('Request was aborted.'))
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
await processStream(body, 0)
|
await processStream(body, 0).finally(cleanup)
|
||||||
.catch((error) => {
|
|
||||||
// 不加这个错误抛不出来
|
|
||||||
throw error
|
|
||||||
})
|
|
||||||
.finally(cleanup)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -160,24 +160,45 @@ export default abstract class BaseProvider {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
protected createAbortController(messageId?: string) {
|
protected createAbortController(messageId?: string, isAddEventListener?: boolean) {
|
||||||
const abortController = new AbortController()
|
const abortController = new AbortController()
|
||||||
|
const abortFn = () => abortController.abort()
|
||||||
|
|
||||||
if (messageId) {
|
if (messageId) {
|
||||||
addAbortController(messageId, () => abortController.abort())
|
addAbortController(messageId, abortFn)
|
||||||
}
|
}
|
||||||
|
|
||||||
const cleanup = () => {
|
const cleanup = () => {
|
||||||
if (messageId) {
|
if (messageId) {
|
||||||
removeAbortController(messageId)
|
signalPromise.resolve?.(undefined)
|
||||||
|
removeAbortController(messageId, abortFn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const signalPromise: {
|
||||||
|
resolve: (value: unknown) => void
|
||||||
|
promise: Promise<unknown>
|
||||||
|
} = {
|
||||||
|
resolve: () => {},
|
||||||
|
promise: Promise.resolve()
|
||||||
|
}
|
||||||
|
|
||||||
abortController.signal.addEventListener('abort', () => {
|
if (isAddEventListener) {
|
||||||
// 兼容
|
signalPromise.promise = new Promise((resolve, reject) => {
|
||||||
cleanup()
|
signalPromise.resolve = resolve
|
||||||
})
|
if (abortController.signal.aborted) {
|
||||||
|
reject(new Error('Request was aborted.'))
|
||||||
|
}
|
||||||
|
// 捕获abort事件,有些abort事件必须
|
||||||
|
abortController.signal.addEventListener('abort', () => {
|
||||||
|
reject(new Error('Request was aborted.'))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
return {
|
||||||
|
abortController,
|
||||||
|
cleanup,
|
||||||
|
signalPromise
|
||||||
|
}
|
||||||
|
}
|
||||||
return {
|
return {
|
||||||
abortController,
|
abortController,
|
||||||
cleanup
|
cleanup
|
||||||
|
|||||||
@ -326,7 +326,7 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
})
|
})
|
||||||
const newChat = geminiModel.startChat({ history })
|
const newChat = geminiModel.startChat({ history })
|
||||||
const newStream = await newChat.sendMessageStream(fcRespParts, { signal })
|
const newStream = await newChat.sendMessageStream(fcRespParts, { signal })
|
||||||
await processStream(newStream, idx + 1).finally(cleanup)
|
await processStream(newStream, idx + 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -357,7 +357,6 @@ export default class OpenAIProvider extends BaseProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const userMessages: ChatCompletionMessageParam[] = []
|
const userMessages: ChatCompletionMessageParam[] = []
|
||||||
|
|
||||||
const _messages = filterUserRoleStartMessages(
|
const _messages = filterUserRoleStartMessages(
|
||||||
filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 1)))
|
filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 1)))
|
||||||
)
|
)
|
||||||
@ -414,7 +413,7 @@ export default class OpenAIProvider extends BaseProvider {
|
|||||||
let time_first_content_millsec = 0
|
let time_first_content_millsec = 0
|
||||||
const start_time_millsec = new Date().getTime()
|
const start_time_millsec = new Date().getTime()
|
||||||
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
|
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
|
||||||
const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
|
const { abortController, cleanup, signalPromise } = this.createAbortController(lastUserMessage?.id, true)
|
||||||
const { signal } = abortController
|
const { signal } = abortController
|
||||||
|
|
||||||
mcpTools = filterMCPTools(mcpTools, lastUserMessage?.enabledMCPs)
|
mcpTools = filterMCPTools(mcpTools, lastUserMessage?.enabledMCPs)
|
||||||
@ -425,7 +424,6 @@ export default class OpenAIProvider extends BaseProvider {
|
|||||||
) as ChatCompletionMessageParam[]
|
) as ChatCompletionMessageParam[]
|
||||||
|
|
||||||
const toolResponses: MCPToolResponse[] = []
|
const toolResponses: MCPToolResponse[] = []
|
||||||
|
|
||||||
const processStream = async (stream: any, idx: number) => {
|
const processStream = async (stream: any, idx: number) => {
|
||||||
if (!isSupportStreamOutput()) {
|
if (!isSupportStreamOutput()) {
|
||||||
const time_completion_millsec = new Date().getTime() - start_time_millsec
|
const time_completion_millsec = new Date().getTime() - start_time_millsec
|
||||||
@ -593,6 +591,10 @@ export default class OpenAIProvider extends BaseProvider {
|
|||||||
)
|
)
|
||||||
|
|
||||||
await processStream(stream, 0).finally(cleanup)
|
await processStream(stream, 0).finally(cleanup)
|
||||||
|
// 捕获signal的错误
|
||||||
|
await signalPromise?.promise?.catch((error) => {
|
||||||
|
throw error
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -34,12 +34,6 @@ export async function fetchChatCompletion({
|
|||||||
const webSearchProvider = WebSearchService.getWebSearchProvider()
|
const webSearchProvider = WebSearchService.getWebSearchProvider()
|
||||||
const AI = new AiProvider(provider)
|
const AI = new AiProvider(provider)
|
||||||
|
|
||||||
// store.dispatch(setGenerating(true))
|
|
||||||
|
|
||||||
// onResponse({ ...message })
|
|
||||||
|
|
||||||
// addAbortController(message.askId ?? message.id)
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
let _messages: Message[] = []
|
let _messages: Message[] = []
|
||||||
let isFirstChunk = true
|
let isFirstChunk = true
|
||||||
@ -70,7 +64,6 @@ export async function fetchChatCompletion({
|
|||||||
}
|
}
|
||||||
|
|
||||||
const allMCPTools = await window.api.mcp.listTools()
|
const allMCPTools = await window.api.mcp.listTools()
|
||||||
|
|
||||||
await AI.completions({
|
await AI.completions({
|
||||||
messages: filterUsefulMessages(messages),
|
messages: filterUsefulMessages(messages),
|
||||||
assistant,
|
assistant,
|
||||||
@ -127,9 +120,12 @@ export async function fetchChatCompletion({
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
if (isAbortError(error)) return
|
if (isAbortError(error)) {
|
||||||
message.status = 'error'
|
message.status = 'paused'
|
||||||
message.error = formatMessageError(error)
|
} else {
|
||||||
|
message.status = 'error'
|
||||||
|
message.error = formatMessageError(error)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Emit chat completion event
|
// Emit chat completion event
|
||||||
|
|||||||
@ -53,7 +53,7 @@ export function filterEmptyMessages(messages: Message[]): Message[] {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export function filterUsefulMessages(messages: Message[]): Message[] {
|
export function filterUsefulMessages(messages: Message[]): Message[] {
|
||||||
const _messages = messages
|
const _messages = [...messages]
|
||||||
const groupedMessages = getGroupedMessages(messages)
|
const groupedMessages = getGroupedMessages(messages)
|
||||||
|
|
||||||
Object.entries(groupedMessages).forEach(([key, messages]) => {
|
Object.entries(groupedMessages).forEach(([key, messages]) => {
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
|
|||||||
import { getAssistantMessage, resetAssistantMessage } from '@renderer/services/MessagesService'
|
import { getAssistantMessage, resetAssistantMessage } from '@renderer/services/MessagesService'
|
||||||
import type { AppDispatch, RootState } from '@renderer/store'
|
import type { AppDispatch, RootState } from '@renderer/store'
|
||||||
import type { Assistant, Message, Topic } from '@renderer/types'
|
import type { Assistant, Message, Topic } from '@renderer/types'
|
||||||
|
import { Model } from '@renderer/types'
|
||||||
import { clearTopicQueue, getTopicQueue, waitForTopicQueue } from '@renderer/utils/queue'
|
import { clearTopicQueue, getTopicQueue, waitForTopicQueue } from '@renderer/utils/queue'
|
||||||
import { throttle } from 'lodash'
|
import { throttle } from 'lodash'
|
||||||
|
|
||||||
@ -105,6 +106,29 @@ const messagesSlice = createSlice({
|
|||||||
state.messagesByTopic[topicId].push(messages)
|
state.messagesByTopic[topicId].push(messages)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
appendMessage: (
|
||||||
|
state,
|
||||||
|
action: PayloadAction<{ topicId: string; messages: Message | Message[]; position?: number }>
|
||||||
|
) => {
|
||||||
|
const { topicId, messages, position } = action.payload
|
||||||
|
if (!state.messagesByTopic[topicId]) {
|
||||||
|
state.messagesByTopic[topicId] = []
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保消息数组存在并且拿到引用
|
||||||
|
const messagesList = state.messagesByTopic[topicId]
|
||||||
|
|
||||||
|
// 要插入的消息
|
||||||
|
const messagesToInsert = Array.isArray(messages) ? messages : [messages]
|
||||||
|
|
||||||
|
if (position !== undefined && position >= 0 && position <= messagesList.length) {
|
||||||
|
// 如果指定了位置,在特定位置插入消息
|
||||||
|
messagesList.splice(position, 0, ...messagesToInsert)
|
||||||
|
} else {
|
||||||
|
// 否则默认添加到末尾
|
||||||
|
messagesList.push(...messagesToInsert)
|
||||||
|
}
|
||||||
|
},
|
||||||
updateMessage: (
|
updateMessage: (
|
||||||
state,
|
state,
|
||||||
action: PayloadAction<{ topicId: string; messageId: string; updates: Partial<Message> }>
|
action: PayloadAction<{ topicId: string; messageId: string; updates: Partial<Message> }>
|
||||||
@ -233,8 +257,9 @@ export const sendMessage =
|
|||||||
assistant: Assistant,
|
assistant: Assistant,
|
||||||
topic: Topic,
|
topic: Topic,
|
||||||
options?: {
|
options?: {
|
||||||
resendAssistantMessage?: Message
|
resendAssistantMessage?: Message | Message[]
|
||||||
isMentionModel?: boolean
|
isMentionModel?: boolean
|
||||||
|
mentions?: Model[]
|
||||||
}
|
}
|
||||||
) =>
|
) =>
|
||||||
async (dispatch: AppDispatch, getState: () => RootState) => {
|
async (dispatch: AppDispatch, getState: () => RootState) => {
|
||||||
@ -255,17 +280,27 @@ export const sendMessage =
|
|||||||
if (options?.resendAssistantMessage) {
|
if (options?.resendAssistantMessage) {
|
||||||
// 直接使用传入的助手消息,进行重置
|
// 直接使用传入的助手消息,进行重置
|
||||||
const messageToReset = options.resendAssistantMessage
|
const messageToReset = options.resendAssistantMessage
|
||||||
const { model, id } = messageToReset
|
if (Array.isArray(messageToReset)) {
|
||||||
const resetMessage = resetAssistantMessage(messageToReset, model)
|
assistantMessages = messageToReset.map((m) => {
|
||||||
// 更新状态
|
const { model, id } = m
|
||||||
dispatch(updateMessage({ topicId: topic.id, messageId: id, updates: resetMessage }))
|
const resetMessage = resetAssistantMessage(m, model)
|
||||||
// 使用重置后的消息
|
// 更新状态
|
||||||
assistantMessages.push(resetMessage)
|
dispatch(updateMessage({ topicId: topic.id, messageId: id, updates: resetMessage }))
|
||||||
|
// 使用重置后的消息
|
||||||
|
return resetMessage
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
const { model, id } = messageToReset
|
||||||
|
const resetMessage = resetAssistantMessage(messageToReset, model)
|
||||||
|
// 更新状态
|
||||||
|
dispatch(updateMessage({ topicId: topic.id, messageId: id, updates: resetMessage }))
|
||||||
|
// 使用重置后的消息
|
||||||
|
assistantMessages.push(resetMessage)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// 不是重发情况
|
// 为每个被 mention 的模型创建一个助手消息
|
||||||
if (userMessage.mentions?.length) {
|
if (options?.mentions?.length) {
|
||||||
// 为每个被 mention 的模型创建一个助手消息
|
assistantMessages = options?.mentions.map((m) => {
|
||||||
assistantMessages = userMessage.mentions.map((m) => {
|
|
||||||
const assistantMessage = getAssistantMessage({ assistant: { ...assistant, model: m }, topic })
|
const assistantMessage = getAssistantMessage({ assistant: { ...assistant, model: m }, topic })
|
||||||
assistantMessage.model = m
|
assistantMessage.model = m
|
||||||
assistantMessage.askId = userMessage.id
|
assistantMessage.askId = userMessage.id
|
||||||
@ -280,19 +315,36 @@ export const sendMessage =
|
|||||||
assistantMessages.push(assistantMessage)
|
assistantMessages.push(assistantMessage)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 获取当前消息列表
|
||||||
|
const currentMessages = getState().messages.messagesByTopic[topic.id]
|
||||||
|
|
||||||
|
// 最后一个具有相同askId的助手消息,在其后插入
|
||||||
|
let position: number | undefined
|
||||||
|
if (options?.isMentionModel) {
|
||||||
|
const lastAssistantIndex = currentMessages.findLastIndex(
|
||||||
|
(m) => m.role === 'assistant' && m.askId === userMessage.id
|
||||||
|
)
|
||||||
|
if (lastAssistantIndex !== -1) {
|
||||||
|
position = lastAssistantIndex + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
dispatch(
|
dispatch(
|
||||||
addMessage({
|
appendMessage({
|
||||||
topicId: topic.id,
|
topicId: topic.id,
|
||||||
messages: !options?.isMentionModel ? [userMessage, ...assistantMessages] : assistantMessages
|
messages: !options?.isMentionModel ? [userMessage, ...assistantMessages] : assistantMessages,
|
||||||
|
position
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
for (const assistantMessage of assistantMessages) {
|
||||||
|
// for of会收到await 影响,在暂停的时候会因为异步的原因有概率拿不到数据
|
||||||
|
dispatch(setStreamMessage({ topicId: topic.id, message: assistantMessage }))
|
||||||
|
}
|
||||||
const queue = getTopicQueue(topic.id)
|
const queue = getTopicQueue(topic.id)
|
||||||
|
|
||||||
for (const assistantMessage of assistantMessages) {
|
for (const assistantMessage of assistantMessages) {
|
||||||
// Set as stream message instead of adding to messages
|
// Set as stream message instead of adding to messages
|
||||||
dispatch(setStreamMessage({ topicId: topic.id, message: assistantMessage }))
|
|
||||||
|
|
||||||
// Sync user message with database
|
// Sync user message with database
|
||||||
const state = getState()
|
const state = getState()
|
||||||
@ -303,7 +355,7 @@ export const sendMessage =
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 保证请求有序,防止请求静态,限制并发数量
|
// 保证请求有序,防止请求静态,限制并发数量
|
||||||
await queue.add(async () => {
|
queue.add(async () => {
|
||||||
try {
|
try {
|
||||||
const messages = getState().messages.messagesByTopic[topic.id]
|
const messages = getState().messages.messagesByTopic[topic.id]
|
||||||
if (!messages) {
|
if (!messages) {
|
||||||
@ -324,13 +376,30 @@ export const sendMessage =
|
|||||||
|
|
||||||
// 节流
|
// 节流
|
||||||
const throttledDispatch = throttle(handleResponseMessageUpdate, 100, { trailing: true }) // 100ms的节流时间应足够平衡用户体验和性能
|
const throttledDispatch = throttle(handleResponseMessageUpdate, 100, { trailing: true }) // 100ms的节流时间应足够平衡用户体验和性能
|
||||||
|
// 寻找当前正在处理的消息在消息列表中的位置
|
||||||
|
// const messageIndex = messages.findIndex((m) => m.id === assistantMessage.id)
|
||||||
|
const handleMessages = (): Message[] => {
|
||||||
|
// 找到对应的用户消息位置
|
||||||
|
const userMessageIndex = messages.findIndex((m) => m.id === assistantMessage.askId)
|
||||||
|
|
||||||
const messageIndex = messages.findIndex((m) => m.id === assistantMessage.id)
|
if (userMessageIndex !== -1) {
|
||||||
|
// 先截取到用户消息为止的所有消息,再进行过滤
|
||||||
|
const messagesUpToUser = messages.slice(0, userMessageIndex + 1)
|
||||||
|
return messagesUpToUser.filter((m) => !m.status?.includes('ing'))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果找不到对应的用户消息,使用原有逻辑
|
||||||
|
// 按理说不会找不到 先注释掉看看
|
||||||
|
// if (messageIndex !== -1) {
|
||||||
|
// const messagesUpToAssistant = messages.slice(0, messageIndex)
|
||||||
|
// return messagesUpToAssistant.filter((m) => !m.status?.includes('ing'))
|
||||||
|
// }
|
||||||
|
// 没有找到消息索引的情况,过滤所有消息
|
||||||
|
return messages.filter((m) => !m.status?.includes('ing'))
|
||||||
|
}
|
||||||
await fetchChatCompletion({
|
await fetchChatCompletion({
|
||||||
message: { ...assistantMessage },
|
message: { ...assistantMessage },
|
||||||
messages: messages
|
messages: handleMessages(),
|
||||||
.filter((m) => !m.status?.includes('ing'))
|
|
||||||
.slice(0, messageIndex !== -1 ? messageIndex : undefined),
|
|
||||||
assistant: assistantWithModel,
|
assistant: assistantWithModel,
|
||||||
onResponse: async (msg) => {
|
onResponse: async (msg) => {
|
||||||
// 允许在回调外维护一个最新的消息状态,每次都更新这个对象,但只通过节流函数分发到Redux
|
// 允许在回调外维护一个最新的消息状态,每次都更新这个对象,但只通过节流函数分发到Redux
|
||||||
@ -362,13 +431,14 @@ export const sendMessage =
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
// 等待所有请求完成,设置loading
|
|
||||||
await queue.onIdle()
|
|
||||||
dispatch(setTopicLoading({ topicId: topic.id, loading: false }))
|
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
console.error('Error in sendMessage:', error)
|
console.error('Error in sendMessage:', error)
|
||||||
dispatch(setError(error.message))
|
dispatch(setError(error.message))
|
||||||
dispatch(setTopicLoading({ topicId: topic.id, loading: false }))
|
dispatch(setTopicLoading({ topicId: topic.id, loading: false }))
|
||||||
|
} finally {
|
||||||
|
// 等待所有请求完成,设置loading
|
||||||
|
await waitForTopicQueue(topic.id)
|
||||||
|
dispatch(setTopicLoading({ topicId: topic.id, loading: false }))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -385,7 +455,7 @@ export const resendMessage =
|
|||||||
// 如果是用户消息,直接重发
|
// 如果是用户消息,直接重发
|
||||||
if (message.role === 'user') {
|
if (message.role === 'user') {
|
||||||
// 查找此用户消息对应的助手消息
|
// 查找此用户消息对应的助手消息
|
||||||
const assistantMessage = topicMessages.find((m) => m.role === 'assistant' && m.askId === message.id)
|
const assistantMessage = topicMessages.filter((m) => m.role === 'assistant' && m.askId === message.id)
|
||||||
return dispatch(
|
return dispatch(
|
||||||
sendMessage(message, assistant, topic, {
|
sendMessage(message, assistant, topic, {
|
||||||
resendAssistantMessage: assistantMessage,
|
resendAssistantMessage: assistantMessage,
|
||||||
@ -408,6 +478,7 @@ export const resendMessage =
|
|||||||
return dispatch(sendMessage(userMessage, assistant, topic, { isMentionModel }))
|
return dispatch(sendMessage(userMessage, assistant, topic, { isMentionModel }))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
console.log('assistantMessage', message)
|
||||||
dispatch(
|
dispatch(
|
||||||
sendMessage(userMessage, assistant, topic, {
|
sendMessage(userMessage, assistant, topic, {
|
||||||
resendAssistantMessage: message
|
resendAssistantMessage: message
|
||||||
@ -521,7 +592,8 @@ export const {
|
|||||||
loadTopicMessages,
|
loadTopicMessages,
|
||||||
setStreamMessage,
|
setStreamMessage,
|
||||||
commitStreamMessage,
|
commitStreamMessage,
|
||||||
clearStreamMessage
|
clearStreamMessage,
|
||||||
|
appendMessage
|
||||||
} = messagesSlice.actions
|
} = messagesSlice.actions
|
||||||
|
|
||||||
export default messagesSlice.reducer
|
export default messagesSlice.reducer
|
||||||
|
|||||||
@ -1,25 +1,22 @@
|
|||||||
export const abortMap = new Map<string, () => void>()
|
export const abortMap = new Map<string, (() => void)[]>()
|
||||||
|
|
||||||
export const addAbortController = (id: string, abortFn: () => void) => {
|
export const addAbortController = (id: string, abortFn: () => void) => {
|
||||||
let callback = abortFn
|
abortMap.set(id, [...(abortMap.get(id) || []), abortFn])
|
||||||
const existingCallback = abortMap.get(id)
|
|
||||||
if (existingCallback) {
|
|
||||||
callback = () => {
|
|
||||||
existingCallback?.()
|
|
||||||
abortFn()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
abortMap.set(id, callback)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export const removeAbortController = (id: string) => {
|
export const removeAbortController = (id: string, abortFn: () => void) => {
|
||||||
abortMap.delete(id)
|
const callbackArr = abortMap.get(id)
|
||||||
|
if (abortFn) {
|
||||||
|
callbackArr?.splice(callbackArr?.indexOf(abortFn), 1)
|
||||||
|
} else abortMap.delete(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
export const abortCompletion = (id: string) => {
|
export const abortCompletion = (id: string) => {
|
||||||
const abortFn = abortMap.get(id)
|
const abortFns = abortMap.get(id)
|
||||||
if (abortFn) {
|
if (abortFns?.length) {
|
||||||
abortFn()
|
for (const fn of [...abortFns]) {
|
||||||
removeAbortController(id)
|
fn()
|
||||||
|
removeAbortController(id, fn)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,13 +8,9 @@ const requestQueues: { [topicId: string]: PQueue } = {}
|
|||||||
* @param topicId The ID of the topic
|
* @param topicId The ID of the topic
|
||||||
* @returns A PQueue instance for the topic
|
* @returns A PQueue instance for the topic
|
||||||
*/
|
*/
|
||||||
export const getTopicQueue = (topicId: string): PQueue => {
|
export const getTopicQueue = (topicId: string, options = {}): PQueue => {
|
||||||
if (!requestQueues[topicId]) {
|
if (!requestQueues[topicId]) {
|
||||||
requestQueues[topicId] = new PQueue({
|
requestQueues[topicId] = new PQueue(options)
|
||||||
concurrency: 4,
|
|
||||||
timeout: 1000 * 60 * 5, // 5 minutes
|
|
||||||
throwOnTimeout: false
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
return requestQueues[topicId]
|
return requestQueues[topicId]
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user