feat: Add message completion abort functionality
This commit is contained in:
parent
241cb0c0d8
commit
4c9bd02f8e
@ -24,6 +24,7 @@ import FileManager from '@renderer/services/FileManager'
|
|||||||
import { estimateTextTokens as estimateTxtTokens } from '@renderer/services/TokenService'
|
import { estimateTextTokens as estimateTxtTokens } from '@renderer/services/TokenService'
|
||||||
import { translateText } from '@renderer/services/TranslateService'
|
import { translateText } from '@renderer/services/TranslateService'
|
||||||
import store, { useAppDispatch, useAppSelector } from '@renderer/store'
|
import store, { useAppDispatch, useAppSelector } from '@renderer/store'
|
||||||
|
import { abortCompletion } from '@renderer/store/abortController'
|
||||||
import { setGenerating, setSearching } from '@renderer/store/runtime'
|
import { setGenerating, setSearching } from '@renderer/store/runtime'
|
||||||
import { Assistant, FileType, KnowledgeBase, Message, Model, Topic } from '@renderer/types'
|
import { Assistant, FileType, KnowledgeBase, Message, Model, Topic } from '@renderer/types'
|
||||||
import { classNames, delay, getFileExtension, uuid } from '@renderer/utils'
|
import { classNames, delay, getFileExtension, uuid } from '@renderer/utils'
|
||||||
@ -85,7 +86,7 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic }) => {
|
|||||||
const [selectedKnowledgeBases, setSelectedKnowledgeBases] = useState<KnowledgeBase[]>([])
|
const [selectedKnowledgeBases, setSelectedKnowledgeBases] = useState<KnowledgeBase[]>([])
|
||||||
const [mentionModels, setMentionModels] = useState<Model[]>([])
|
const [mentionModels, setMentionModels] = useState<Model[]>([])
|
||||||
const [isMentionPopupOpen, setIsMentionPopupOpen] = useState(false)
|
const [isMentionPopupOpen, setIsMentionPopupOpen] = useState(false)
|
||||||
|
const currentMessageId = useRef<string>()
|
||||||
const isVision = useMemo(() => isVisionModel(model), [model])
|
const isVision = useMemo(() => isVisionModel(model), [model])
|
||||||
const supportExts = useMemo(() => [...textExts, ...documentExts, ...(isVision ? imageExts : [])], [isVision])
|
const supportExts = useMemo(() => [...textExts, ...documentExts, ...(isVision ? imageExts : [])], [isVision])
|
||||||
|
|
||||||
@ -133,7 +134,7 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic }) => {
|
|||||||
if (mentionModels.length > 0) {
|
if (mentionModels.length > 0) {
|
||||||
message.mentions = mentionModels
|
message.mentions = mentionModels
|
||||||
}
|
}
|
||||||
|
currentMessageId.current = message.id
|
||||||
EventEmitter.emit(EVENT_NAMES.SEND_MESSAGE, message)
|
EventEmitter.emit(EVENT_NAMES.SEND_MESSAGE, message)
|
||||||
|
|
||||||
setText('')
|
setText('')
|
||||||
@ -274,6 +275,9 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic }) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const onPause = () => {
|
const onPause = () => {
|
||||||
|
if (currentMessageId.current) {
|
||||||
|
abortCompletion(currentMessageId.current)
|
||||||
|
}
|
||||||
window.keyv.set(EVENT_NAMES.CHAT_COMPLETION_PAUSED, true)
|
window.keyv.set(EVENT_NAMES.CHAT_COMPLETION_PAUSED, true)
|
||||||
store.dispatch(setGenerating(false))
|
store.dispatch(setGenerating(false))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import i18n from '@renderer/i18n'
|
|||||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||||
import { EVENT_NAMES } from '@renderer/services/EventService'
|
import { EVENT_NAMES } from '@renderer/services/EventService'
|
||||||
import { filterContextMessages } from '@renderer/services/MessagesService'
|
import { filterContextMessages } from '@renderer/services/MessagesService'
|
||||||
|
import { addAbortController, removeAbortController } from '@renderer/store/abortController'
|
||||||
import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
|
import { Assistant, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||||
import { removeSpecialCharacters } from '@renderer/utils'
|
import { removeSpecialCharacters } from '@renderer/utils'
|
||||||
import { first, flatten, sum, takeRight } from 'lodash'
|
import { first, flatten, sum, takeRight } from 'lodash'
|
||||||
@ -13,7 +14,6 @@ import OpenAI from 'openai'
|
|||||||
|
|
||||||
import { CompletionsParams } from '.'
|
import { CompletionsParams } from '.'
|
||||||
import BaseProvider from './BaseProvider'
|
import BaseProvider from './BaseProvider'
|
||||||
|
|
||||||
export default class AnthropicProvider extends BaseProvider {
|
export default class AnthropicProvider extends BaseProvider {
|
||||||
private sdk: Anthropic
|
private sdk: Anthropic
|
||||||
|
|
||||||
@ -107,10 +107,16 @@ export default class AnthropicProvider extends BaseProvider {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
const abortController = new AbortController()
|
||||||
|
const { signal } = abortController
|
||||||
|
// 获取最后一条用户消息的 ID 作为 askId
|
||||||
|
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
|
||||||
|
if (lastUserMessage?.id) {
|
||||||
|
addAbortController(lastUserMessage.id, () => abortController.abort())
|
||||||
|
}
|
||||||
return new Promise<void>((resolve, reject) => {
|
return new Promise<void>((resolve, reject) => {
|
||||||
const stream = this.sdk.messages
|
const stream = this.sdk.messages
|
||||||
.stream({ ...body, stream: true })
|
.stream({ ...body, stream: true }, { signal })
|
||||||
.on('text', (text) => {
|
.on('text', (text) => {
|
||||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
|
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
|
||||||
stream.controller.abort()
|
stream.controller.abort()
|
||||||
@ -146,6 +152,10 @@ export default class AnthropicProvider extends BaseProvider {
|
|||||||
resolve()
|
resolve()
|
||||||
})
|
})
|
||||||
.on('error', (error) => reject(error))
|
.on('error', (error) => reject(error))
|
||||||
|
}).finally(() => {
|
||||||
|
if (lastUserMessage?.id) {
|
||||||
|
removeAbortController(lastUserMessage.id)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -16,6 +16,7 @@ import i18n from '@renderer/i18n'
|
|||||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||||
import { EVENT_NAMES } from '@renderer/services/EventService'
|
import { EVENT_NAMES } from '@renderer/services/EventService'
|
||||||
import { filterContextMessages } from '@renderer/services/MessagesService'
|
import { filterContextMessages } from '@renderer/services/MessagesService'
|
||||||
|
import { addAbortController, removeAbortController } from '@renderer/store/abortController'
|
||||||
import { Assistant, FileType, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
|
import { Assistant, FileType, FileTypes, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||||
import { removeSpecialCharacters } from '@renderer/utils'
|
import { removeSpecialCharacters } from '@renderer/utils'
|
||||||
import axios from 'axios'
|
import axios from 'axios'
|
||||||
@ -24,7 +25,6 @@ import OpenAI from 'openai'
|
|||||||
|
|
||||||
import { CompletionsParams } from '.'
|
import { CompletionsParams } from '.'
|
||||||
import BaseProvider from './BaseProvider'
|
import BaseProvider from './BaseProvider'
|
||||||
|
|
||||||
export default class GeminiProvider extends BaseProvider {
|
export default class GeminiProvider extends BaseProvider {
|
||||||
private sdk: GoogleGenerativeAI
|
private sdk: GoogleGenerativeAI
|
||||||
private requestOptions: RequestOptions
|
private requestOptions: RequestOptions
|
||||||
@ -204,7 +204,19 @@ export default class GeminiProvider extends BaseProvider {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
const userMessagesStream = await chat.sendMessageStream(messageContents.parts)
|
const abortController = new AbortController()
|
||||||
|
const { signal } = abortController
|
||||||
|
// 获取最后一条用户消息的 ID 作为 askId
|
||||||
|
const lastUserMessage = userMessages.findLast((m) => m.role === 'user')
|
||||||
|
if (lastUserMessage?.id) {
|
||||||
|
addAbortController(lastUserMessage.id, () => abortController.abort())
|
||||||
|
}
|
||||||
|
|
||||||
|
const userMessagesStream = await chat.sendMessageStream(messageContents.parts, { signal }).finally(() => {
|
||||||
|
if (lastUserMessage?.id) {
|
||||||
|
removeAbortController(lastUserMessage.id)
|
||||||
|
}
|
||||||
|
})
|
||||||
let time_first_token_millsec = 0
|
let time_first_token_millsec = 0
|
||||||
|
|
||||||
for await (const chunk of userMessagesStream.stream) {
|
for await (const chunk of userMessagesStream.stream) {
|
||||||
|
|||||||
@ -4,6 +4,7 @@ import i18n from '@renderer/i18n'
|
|||||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||||
import { EVENT_NAMES } from '@renderer/services/EventService'
|
import { EVENT_NAMES } from '@renderer/services/EventService'
|
||||||
import { filterContextMessages } from '@renderer/services/MessagesService'
|
import { filterContextMessages } from '@renderer/services/MessagesService'
|
||||||
|
import { addAbortController, removeAbortController } from '@renderer/store/abortController'
|
||||||
import { Assistant, FileTypes, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
|
import { Assistant, FileTypes, GenerateImageParams, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||||
import { removeSpecialCharacters } from '@renderer/utils'
|
import { removeSpecialCharacters } from '@renderer/utils'
|
||||||
import { takeRight } from 'lodash'
|
import { takeRight } from 'lodash'
|
||||||
@ -213,21 +214,40 @@ export default class OpenAIProvider extends BaseProvider {
|
|||||||
let time_first_token_millsec = 0
|
let time_first_token_millsec = 0
|
||||||
let time_first_content_millsec = 0
|
let time_first_content_millsec = 0
|
||||||
const start_time_millsec = new Date().getTime()
|
const start_time_millsec = new Date().getTime()
|
||||||
|
const abortController = new AbortController()
|
||||||
|
const { signal } = abortController
|
||||||
|
|
||||||
// @ts-ignore key is not typed
|
// 获取最后一条用户消息的 ID 作为 askId
|
||||||
const stream = await this.sdk.chat.completions.create({
|
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
|
||||||
model: model.id,
|
if (lastUserMessage?.id) {
|
||||||
messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[],
|
addAbortController(lastUserMessage.id, () => abortController.abort())
|
||||||
temperature: this.getTemperature(assistant, model),
|
}
|
||||||
top_p: this.getTopP(assistant, model),
|
|
||||||
max_tokens: maxTokens,
|
const stream = await this.sdk.chat.completions
|
||||||
keep_alive: this.keepAliveTime,
|
// @ts-ignore key is not typed
|
||||||
stream: isSupportStreamOutput(),
|
.create(
|
||||||
...this.getReasoningEffort(assistant, model),
|
{
|
||||||
...getOpenAIWebSearchParams(assistant, model),
|
model: model.id,
|
||||||
...this.getProviderSpecificParameters(assistant, model),
|
messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[],
|
||||||
...this.getCustomParameters(assistant)
|
temperature: this.getTemperature(assistant, model),
|
||||||
})
|
top_p: this.getTopP(assistant, model),
|
||||||
|
max_tokens: maxTokens,
|
||||||
|
keep_alive: this.keepAliveTime,
|
||||||
|
stream: isSupportStreamOutput(),
|
||||||
|
...this.getReasoningEffort(assistant, model),
|
||||||
|
...getOpenAIWebSearchParams(assistant, model),
|
||||||
|
...this.getProviderSpecificParameters(assistant, model),
|
||||||
|
...this.getCustomParameters(assistant)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
signal
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.finally(() => {
|
||||||
|
if (lastUserMessage?.id) {
|
||||||
|
removeAbortController(lastUserMessage.id)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
if (!isSupportStreamOutput()) {
|
if (!isSupportStreamOutput()) {
|
||||||
const time_completion_millsec = new Date().getTime() - start_time_millsec
|
const time_completion_millsec = new Date().getTime() - start_time_millsec
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import i18n from '@renderer/i18n'
|
import i18n from '@renderer/i18n'
|
||||||
import store from '@renderer/store'
|
import store from '@renderer/store'
|
||||||
|
import { addAbortController } from '@renderer/store/abortController'
|
||||||
import { setGenerating } from '@renderer/store/runtime'
|
import { setGenerating } from '@renderer/store/runtime'
|
||||||
import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types'
|
import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||||||
import { formatMessageError } from '@renderer/utils/error'
|
import { formatMessageError } from '@renderer/utils/error'
|
||||||
@ -16,7 +17,6 @@ import {
|
|||||||
import { EVENT_NAMES, EventEmitter } from './EventService'
|
import { EVENT_NAMES, EventEmitter } from './EventService'
|
||||||
import { filterMessages, filterUsefulMessages } from './MessagesService'
|
import { filterMessages, filterUsefulMessages } from './MessagesService'
|
||||||
import { estimateMessagesUsage } from './TokenService'
|
import { estimateMessagesUsage } from './TokenService'
|
||||||
|
|
||||||
export async function fetchChatCompletion({
|
export async function fetchChatCompletion({
|
||||||
message,
|
message,
|
||||||
messages,
|
messages,
|
||||||
@ -37,18 +37,14 @@ export async function fetchChatCompletion({
|
|||||||
|
|
||||||
onResponse({ ...message })
|
onResponse({ ...message })
|
||||||
|
|
||||||
// Handle paused state
|
const pauseFn = (message: Message) => {
|
||||||
let paused = false
|
message.status = 'paused'
|
||||||
const timer = setInterval(() => {
|
EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message)
|
||||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
|
store.dispatch(setGenerating(false))
|
||||||
paused = true
|
onResponse({ ...message, status: 'paused' })
|
||||||
message.status = 'paused'
|
}
|
||||||
EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message)
|
|
||||||
store.dispatch(setGenerating(false))
|
addAbortController(message.askId ?? message.id, pauseFn.bind(null, message))
|
||||||
onResponse({ ...message, status: 'paused' })
|
|
||||||
clearInterval(timer)
|
|
||||||
}
|
|
||||||
}, 1000)
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
let _messages: Message[] = []
|
let _messages: Message[] = []
|
||||||
@ -97,12 +93,6 @@ export async function fetchChatCompletion({
|
|||||||
message.error = formatMessageError(error)
|
message.error = formatMessageError(error)
|
||||||
}
|
}
|
||||||
|
|
||||||
timer && clearInterval(timer)
|
|
||||||
|
|
||||||
if (paused) {
|
|
||||||
return message
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update message status
|
// Update message status
|
||||||
message.status = window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED) ? 'paused' : message.status
|
message.status = window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED) ? 'paused' : message.status
|
||||||
|
|
||||||
|
|||||||
25
src/renderer/src/store/abortController.ts
Normal file
25
src/renderer/src/store/abortController.ts
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
export const abortMap = new Map<string, () => void>()
|
||||||
|
|
||||||
|
export const addAbortController = (messageId: string, abortFn: () => void) => {
|
||||||
|
let callback = abortFn
|
||||||
|
const existingCallback = abortMap.get(messageId)
|
||||||
|
if (existingCallback) {
|
||||||
|
callback = () => {
|
||||||
|
existingCallback?.()
|
||||||
|
abortFn()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
abortMap.set(messageId, callback)
|
||||||
|
}
|
||||||
|
|
||||||
|
export const removeAbortController = (messageId: string) => {
|
||||||
|
abortMap.delete(messageId)
|
||||||
|
}
|
||||||
|
|
||||||
|
export const abortCompletion = (messageId: string) => {
|
||||||
|
const abortFn = abortMap.get(messageId)
|
||||||
|
if (abortFn) {
|
||||||
|
abortFn()
|
||||||
|
removeAbortController(messageId)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user