feat(websearch): improve web search enablement logic

This commit is contained in:
suyao 2025-04-09 01:51:37 +08:00 committed by 亢奋猫
parent 9689f00214
commit aa73025568
4 changed files with 32 additions and 19 deletions

View File

@ -130,6 +130,7 @@ import XirangModelLogoDark from '@renderer/assets/images/models/xirang_dark.png'
import YiModelLogo from '@renderer/assets/images/models/yi.png' import YiModelLogo from '@renderer/assets/images/models/yi.png'
import YiModelLogoDark from '@renderer/assets/images/models/yi_dark.png' import YiModelLogoDark from '@renderer/assets/images/models/yi_dark.png'
import { getProviderByModel } from '@renderer/services/AssistantService' import { getProviderByModel } from '@renderer/services/AssistantService'
import WebSearchService from '@renderer/services/WebSearchService'
import { Assistant, Model } from '@renderer/types' import { Assistant, Model } from '@renderer/types'
import OpenAI from 'openai' import OpenAI from 'openai'
@ -2270,6 +2271,9 @@ export function isGenerateImageModel(model: Model): boolean {
} }
export function getOpenAIWebSearchParams(assistant: Assistant, model: Model): Record<string, any> { export function getOpenAIWebSearchParams(assistant: Assistant, model: Model): Record<string, any> {
if (WebSearchService.isWebSearchEnabled() && WebSearchService.isOverwriteEnabled()) {
return {}
}
if (isWebSearchModel(model)) { if (isWebSearchModel(model)) {
if (assistant.enableWebSearch) { if (assistant.enableWebSearch) {
const webSearchTools = getWebSearchTools(model) const webSearchTools = getWebSearchTools(model)

View File

@ -777,20 +777,33 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
}) })
} }
const onEnableWebSearch = () => { const showWebSearchEnableModal = () => {
if (!isWebSearchModel(model)) { window.modal.confirm({
if (!WebSearchService.isWebSearchEnabled()) { title: t('chat.input.web_search.enable'),
window.modal.confirm({ content: t('chat.input.web_search.enable_content'),
title: t('chat.input.web_search.enable'), centered: true,
content: t('chat.input.web_search.enable_content'), okText: t('chat.input.web_search.button.ok'),
centered: true, onOk: () => {
okText: t('chat.input.web_search.button.ok'), navigate('/settings/web-search')
onOk: () => {
navigate('/settings/web-search')
}
})
return
} }
})
}
const shouldShowEnableModal = () => {
// 网络搜索功能是否未启用
const webSearchNotEnabled = !WebSearchService.isWebSearchEnabled()
// 非网络搜索模型:仅当网络搜索功能未启用时显示启用提示
if (!isWebSearchModel(model)) {
return webSearchNotEnabled
}
// 网络搜索模型:当允许覆盖但网络搜索功能未启用时显示启用提示
return WebSearchService.isOverwriteEnabled() && webSearchNotEnabled
}
const onEnableWebSearch = () => {
if (shouldShowEnableModal()) {
showWebSearchEnableModal()
return
} }
updateAssistant({ ...assistant, enableWebSearch: !assistant.enableWebSearch }) updateAssistant({ ...assistant, enableWebSearch: !assistant.enableWebSearch })

View File

@ -58,11 +58,7 @@ export async function fetchChatCompletion({
// Search web // Search web
if (WebSearchService.isWebSearchEnabled() && assistant.enableWebSearch && assistant.model) { if (WebSearchService.isWebSearchEnabled() && assistant.enableWebSearch && assistant.model) {
let webSearchParams = getOpenAIWebSearchParams(assistant, assistant.model) const webSearchParams = getOpenAIWebSearchParams(assistant, assistant.model)
if (WebSearchService.isOverwriteEnabled()) {
webSearchParams = {}
}
if (isEmpty(webSearchParams) && !isOpenAIWebSearch(assistant.model)) { if (isEmpty(webSearchParams) && !isOpenAIWebSearch(assistant.model)) {
const lastMessage = findLast(messages, (m) => m.role === 'user') const lastMessage = findLast(messages, (m) => m.role === 'user')
const lastAnswer = findLast(messages, (m) => m.role === 'assistant') const lastAnswer = findLast(messages, (m) => m.role === 'assistant')

View File

@ -41,7 +41,7 @@ const initialState: WebSearchState = {
maxResults: 5, maxResults: 5,
excludeDomains: [], excludeDomains: [],
enhanceMode: false, enhanceMode: false,
overwrite: true overwrite: false
} }
const websearchSlice = createSlice({ const websearchSlice = createSlice({