mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-07-06 05:55:28 +08:00
fix: limit builtin web search usage (#14466)
### What this PR does Before this PR: Built-in web search could be called repeatedly during a single assistant response, producing multiple large search/citation payloads and very high token usage. After this PR: Built-in web search is bounded for a single response: prepared queries are deduplicated and capped, repeated tool executions reuse cached results, later tool-loop steps cannot call the built-in web search tool again, and returned source URLs are shortened to their origin. <!-- (optional, in `fixes #<issue number>(, fixes #<issue_number>, ...)` format, will close the issue(s) when PR gets merged)*: --> Fixes #14465 ### Why we need it and why it was done in this way The repeated built-in web search calls can cause excessive token consumption and slow or expensive responses. The fix keeps the hotfix scope narrow by changing only the built-in web search path and the orchestration hook that exposes it to AI SDK tool loops. The following tradeoffs were made: The default behavior now favors bounded search cost over repeated autonomous searches in one assistant response. If one search pass is insufficient, the assistant should answer with the available sources or ask the user for a follow-up. The following alternatives were considered: - Lowering the global max tool-call setting: rejected because it would affect all tools, not just built-in web search. - Allowing multiple web searches with a larger budget: rejected for now because the reported issue is runaway token usage. Links to places where the discussion took place: #14465 ### Breaking changes None. ### Special notes for your reviewer This PR targets `main` as a minimal `hotfix/*` branch because the current behavior can consume a very large number of tokens in a single response. ### Checklist This checklist is not enforcing, but it's a reminder of items that could be relevant to every PR. Approvers are expected to review this list. - [x] PR: The PR description is expressive enough and will help future contributors - [x] Code: [Write code that humans can understand](https://en.wikiquote.org/wiki/Martin_Fowler#code-for-humans) and [Keep it simple](https://en.wikipedia.org/wiki/KISS_principle) - [ ] Refactor: You have [left the code cleaner than you found it (Boy Scout Rule)](https://learning.oreilly.com/library/view/97-things-every/9780596809515/ch08.html) - [x] Upgrade: Impact of this change on upgrade flows was considered and addressed if required - [x] Documentation: A [user-guide update](https://docs.cherry-ai.com) was considered and is present (link) or not required. Check this only when the PR introduces or changes a user-facing feature or behavior. - [x] Self-review: I have reviewed my own code (e.g., via [`/gh-pr-review`](/.claude/skills/gh-pr-review/SKILL.md), `gh pr diff`, or GitHub UI) before requesting review from others ### Release note <!-- Write your release note: 1. Enter your extended release note in the below block. If the PR requires additional action from users switching to the new release, include the string "action required". 2. If no release note is required, just write "NONE". 3. Only include user-facing changes (new features, bug fixes visible to users, UI changes, behavior changes). For CI, maintenance, internal refactoring, build tooling, or other non-user-facing work, write "NONE". --> ```release-note Fixed built-in web search repeatedly running within a single assistant response and consuming excessive tokens. ``` --------- Signed-off-by: kangfenmao <kangfenmao@qq.com>
This commit is contained in:
@@ -33,7 +33,7 @@ import { isEmpty } from 'lodash'
|
||||
import { MemoryProcessor } from '../../services/MemoryProcessor'
|
||||
import { knowledgeSearchTool } from '../tools/KnowledgeSearchTool'
|
||||
import { memorySearchTool } from '../tools/MemorySearchTool'
|
||||
import { webSearchToolWithPreExtractedKeywords } from '../tools/WebSearchTool'
|
||||
import { BUILTIN_WEB_SEARCH_TOOL_NAME, webSearchToolWithPreExtractedKeywords } from '../tools/WebSearchTool'
|
||||
|
||||
const logger = loggerService.withContext('SearchOrchestrationPlugin')
|
||||
|
||||
@@ -328,11 +328,28 @@ export const searchOrchestrationPlugin = (
|
||||
if (needsSearch) {
|
||||
// onChunk({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS })
|
||||
// logger.info('🌐 Adding web search tool with pre-extracted keywords')
|
||||
params.tools['builtin_web_search'] = webSearchToolWithPreExtractedKeywords(
|
||||
params.tools[BUILTIN_WEB_SEARCH_TOOL_NAME] = webSearchToolWithPreExtractedKeywords(
|
||||
assistant.webSearchProviderId,
|
||||
analysisResult.websearch,
|
||||
context.requestId
|
||||
)
|
||||
|
||||
const prepareStep = params.prepareStep
|
||||
params.prepareStep = async (options) => {
|
||||
const stepConfig = await prepareStep?.(options)
|
||||
const hasWebSearchCall = options.steps.some((step) =>
|
||||
step.toolCalls.some((toolCall) => toolCall.toolName === BUILTIN_WEB_SEARCH_TOOL_NAME)
|
||||
)
|
||||
|
||||
return hasWebSearchCall
|
||||
? {
|
||||
...stepConfig,
|
||||
activeTools: (stepConfig?.activeTools ?? Object.keys(params.tools!)).filter(
|
||||
(toolName) => toolName !== BUILTIN_WEB_SEARCH_TOOL_NAME
|
||||
)
|
||||
}
|
||||
: stepConfig
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,9 +2,35 @@ import { REFERENCE_PROMPT } from '@renderer/config/prompts'
|
||||
import WebSearchService from '@renderer/services/WebSearchService'
|
||||
import type { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
|
||||
import type { ExtractResults } from '@renderer/utils/extract'
|
||||
import { getUrlOriginOrFallback } from '@renderer/utils/url'
|
||||
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
|
||||
import * as z from 'zod'
|
||||
|
||||
export const BUILTIN_WEB_SEARCH_TOOL_NAME = 'builtin_web_search'
|
||||
|
||||
const MAX_BUILTIN_WEB_SEARCH_QUERIES = 3
|
||||
|
||||
function normalizeWebSearchQueries(questions: string[]): string[] {
|
||||
if (questions[0] === 'not_needed') {
|
||||
return ['not_needed']
|
||||
}
|
||||
|
||||
const seen = new Set<string>()
|
||||
|
||||
return questions
|
||||
.map((question) => question.trim())
|
||||
.filter((question) => question.length > 0)
|
||||
.filter((question) => {
|
||||
const key = question.toLocaleLowerCase()
|
||||
if (seen.has(key)) {
|
||||
return false
|
||||
}
|
||||
seen.add(key)
|
||||
return true
|
||||
})
|
||||
.slice(0, MAX_BUILTIN_WEB_SEARCH_QUERIES)
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用预提取关键词的网络搜索工具
|
||||
* 这个工具直接使用插件阶段分析的搜索意图,避免重复分析
|
||||
@@ -18,6 +44,7 @@ export const webSearchToolWithPreExtractedKeywords = (
|
||||
requestId: string
|
||||
) => {
|
||||
const webSearchProvider = WebSearchService.getWebSearchProvider(webSearchProviderId)
|
||||
let cachedSearchResultsPromise: Promise<WebSearchProviderResponse> | undefined
|
||||
|
||||
return tool({
|
||||
description: `Web search tool for finding current information, news, and real-time data from the internet.
|
||||
@@ -40,23 +67,23 @@ You can use this tool as-is to search with the prepared queries, or provide addi
|
||||
}),
|
||||
|
||||
execute: async ({ additionalContext }) => {
|
||||
let finalQueries = [...extractedKeywords.question]
|
||||
if (cachedSearchResultsPromise) {
|
||||
return cachedSearchResultsPromise
|
||||
}
|
||||
|
||||
let finalQueries = normalizeWebSearchQueries(extractedKeywords.question)
|
||||
|
||||
if (additionalContext?.trim()) {
|
||||
// 如果大模型提供了额外上下文,使用更具体的描述
|
||||
const cleanContext = additionalContext.trim()
|
||||
if (cleanContext) {
|
||||
finalQueries = [cleanContext]
|
||||
finalQueries = normalizeWebSearchQueries([cleanContext])
|
||||
}
|
||||
}
|
||||
|
||||
let searchResults: WebSearchProviderResponse = {
|
||||
query: '',
|
||||
results: []
|
||||
}
|
||||
// 检查是否需要搜索
|
||||
if (finalQueries[0] === 'not_needed') {
|
||||
return searchResults
|
||||
if (finalQueries.length === 0 || finalQueries[0] === 'not_needed') {
|
||||
return { query: '', results: [] }
|
||||
}
|
||||
|
||||
// 构建 ExtractResults 结构用于 processWebsearch
|
||||
@@ -66,9 +93,13 @@ You can use this tool as-is to search with the prepared queries, or provide addi
|
||||
links: extractedKeywords.links
|
||||
}
|
||||
}
|
||||
searchResults = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
|
||||
|
||||
return searchResults
|
||||
cachedSearchResultsPromise = WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
|
||||
try {
|
||||
return await cachedSearchResultsPromise
|
||||
} catch (error) {
|
||||
cachedSearchResultsPromise = undefined
|
||||
throw error
|
||||
}
|
||||
},
|
||||
toModelOutput: ({ output: results }) => {
|
||||
let summary = 'No search needed based on the query analysis.'
|
||||
@@ -80,10 +111,9 @@ You can use this tool as-is to search with the prepared queries, or provide addi
|
||||
number: index + 1,
|
||||
title: result.title,
|
||||
content: result.content,
|
||||
url: result.url
|
||||
url: getUrlOriginOrFallback(result.url)
|
||||
}))
|
||||
|
||||
// 🔑 返回引用友好的格式,复用 REFERENCE_PROMPT 逻辑
|
||||
const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
|
||||
const fullInstructions = REFERENCE_PROMPT.replace(
|
||||
'{question}',
|
||||
@@ -110,95 +140,5 @@ You can use this tool as-is to search with the prepared queries, or provide addi
|
||||
})
|
||||
}
|
||||
|
||||
// export const webSearchToolWithExtraction = (
|
||||
// webSearchProviderId: WebSearchProvider['id'],
|
||||
// requestId: string,
|
||||
// assistant: Assistant
|
||||
// ) => {
|
||||
// const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
||||
|
||||
// return tool({
|
||||
// name: 'web_search_with_extraction',
|
||||
// description: 'Search the web for information with automatic keyword extraction from user messages',
|
||||
// inputSchema: z.object({
|
||||
// userMessage: z.object({
|
||||
// content: z.string().describe('The main content of the message'),
|
||||
// role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||
// }),
|
||||
// lastAnswer: z.object({
|
||||
// content: z.string().describe('The main content of the message'),
|
||||
// role: z.enum(['user', 'assistant', 'system']).describe('Message role')
|
||||
// })
|
||||
// }),
|
||||
// outputSchema: z.object({
|
||||
// extractedKeywords: z.object({
|
||||
// question: z.array(z.string()),
|
||||
// links: z.array(z.string()).optional()
|
||||
// }),
|
||||
// searchResults: z.array(
|
||||
// z.object({
|
||||
// query: z.string(),
|
||||
// results: WebSearchProviderResult
|
||||
// })
|
||||
// )
|
||||
// }),
|
||||
// execute: async ({ userMessage, lastAnswer }) => {
|
||||
// const lastUserMessage: Message = {
|
||||
// id: requestId,
|
||||
// role: userMessage.role,
|
||||
// assistantId: assistant.id,
|
||||
// topicId: 'temp',
|
||||
// createdAt: new Date().toISOString(),
|
||||
// status: UserMessageStatus.SUCCESS,
|
||||
// blocks: []
|
||||
// }
|
||||
|
||||
// const lastAnswerMessage: Message | undefined = lastAnswer
|
||||
// ? {
|
||||
// id: requestId + '_answer',
|
||||
// role: lastAnswer.role,
|
||||
// assistantId: assistant.id,
|
||||
// topicId: 'temp',
|
||||
// createdAt: new Date().toISOString(),
|
||||
// status: UserMessageStatus.SUCCESS,
|
||||
// blocks: []
|
||||
// }
|
||||
// : undefined
|
||||
|
||||
// const extractResults = await extractSearchKeywords(lastUserMessage, assistant, {
|
||||
// shouldWebSearch: true,
|
||||
// shouldKnowledgeSearch: false,
|
||||
// lastAnswer: lastAnswerMessage
|
||||
// })
|
||||
|
||||
// if (!extractResults?.websearch || extractResults.websearch.question[0] === 'not_needed') {
|
||||
// return 'No search needed or extraction failed'
|
||||
// }
|
||||
|
||||
// const searchQueries = extractResults.websearch.question
|
||||
// const searchResults: Array<{ query: string; results: any }> = []
|
||||
|
||||
// for (const query of searchQueries) {
|
||||
// // 构建单个查询的ExtractResults结构
|
||||
// const queryExtractResults: ExtractResults = {
|
||||
// websearch: {
|
||||
// question: [query],
|
||||
// links: extractResults.websearch.links
|
||||
// }
|
||||
// }
|
||||
// const response = await webSearchService.processWebsearch(queryExtractResults, requestId)
|
||||
// searchResults.push({
|
||||
// query,
|
||||
// results: response
|
||||
// })
|
||||
// }
|
||||
|
||||
// return { extractedKeywords: extractResults.websearch, searchResults }
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
|
||||
// export type WebSearchToolWithExtractionOutput = InferToolOutput<ReturnType<typeof webSearchToolWithExtraction>>
|
||||
|
||||
export type WebSearchToolOutput = InferToolOutput<ReturnType<typeof webSearchToolWithPreExtractedKeywords>>
|
||||
export type WebSearchToolInput = InferToolInput<ReturnType<typeof webSearchToolWithPreExtractedKeywords>>
|
||||
|
||||
104
src/renderer/src/aiCore/tools/__tests__/WebSearchTool.test.ts
Normal file
104
src/renderer/src/aiCore/tools/__tests__/WebSearchTool.test.ts
Normal file
@@ -0,0 +1,104 @@
|
||||
import WebSearchService from '@renderer/services/WebSearchService'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { webSearchToolWithPreExtractedKeywords } from '../WebSearchTool'
|
||||
|
||||
vi.mock('@renderer/services/WebSearchService', () => ({
|
||||
default: {
|
||||
getWebSearchProvider: vi.fn(),
|
||||
processWebsearch: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
describe('webSearchToolWithPreExtractedKeywords', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.mocked(WebSearchService.getWebSearchProvider).mockReturnValue({ id: 'tavily' } as any)
|
||||
vi.mocked(WebSearchService.processWebsearch).mockResolvedValue({
|
||||
query: 'first | second',
|
||||
results: [
|
||||
{
|
||||
title: 'Result',
|
||||
content: 'Content',
|
||||
url: 'https://example.com/path?utm_source=newsletter#details'
|
||||
}
|
||||
]
|
||||
})
|
||||
})
|
||||
|
||||
it('deduplicates queries, limits them, keeps full URLs in output, and shortens model URLs', async () => {
|
||||
const searchTool = webSearchToolWithPreExtractedKeywords(
|
||||
'tavily',
|
||||
{
|
||||
question: [' first ', 'FIRST', 'second', 'third', 'fourth']
|
||||
},
|
||||
'request-1'
|
||||
) as any
|
||||
|
||||
const firstResult = await searchTool.execute({})
|
||||
const secondResult = await searchTool.execute({ additionalContext: 'new context' })
|
||||
|
||||
expect(WebSearchService.processWebsearch).toHaveBeenCalledTimes(1)
|
||||
expect(WebSearchService.processWebsearch).toHaveBeenCalledWith(
|
||||
{ id: 'tavily' },
|
||||
{
|
||||
websearch: {
|
||||
question: ['first', 'second', 'third'],
|
||||
links: undefined
|
||||
}
|
||||
},
|
||||
'request-1'
|
||||
)
|
||||
expect(firstResult.results[0].url).toBe('https://example.com/path?utm_source=newsletter#details')
|
||||
expect(secondResult).toBe(firstResult)
|
||||
|
||||
const modelOutput = searchTool.toModelOutput({ output: firstResult })
|
||||
const modelText = modelOutput.value.map((part: { text: string }) => part.text).join('\n')
|
||||
|
||||
expect(modelText).toContain('"url": "https://example.com"')
|
||||
expect(modelText).not.toContain('utm_source')
|
||||
})
|
||||
|
||||
it('reuses the in-flight search request for concurrent executions', async () => {
|
||||
const searchResponse = {
|
||||
query: 'first',
|
||||
results: [
|
||||
{
|
||||
title: 'Result',
|
||||
content: 'Content',
|
||||
url: 'https://example.com/path?utm_source=newsletter#details'
|
||||
}
|
||||
]
|
||||
}
|
||||
vi.mocked(WebSearchService.processWebsearch).mockImplementation(
|
||||
() => new Promise((resolve) => setTimeout(() => resolve(searchResponse), 0))
|
||||
)
|
||||
|
||||
const searchTool = webSearchToolWithPreExtractedKeywords(
|
||||
'tavily',
|
||||
{
|
||||
question: ['first']
|
||||
},
|
||||
'request-1'
|
||||
) as any
|
||||
|
||||
const [firstResult, secondResult] = await Promise.all([
|
||||
searchTool.execute({ additionalContext: 'first context' }),
|
||||
searchTool.execute({ additionalContext: 'second context' })
|
||||
])
|
||||
|
||||
expect(WebSearchService.processWebsearch).toHaveBeenCalledTimes(1)
|
||||
expect(WebSearchService.processWebsearch).toHaveBeenCalledWith(
|
||||
{ id: 'tavily' },
|
||||
{
|
||||
websearch: {
|
||||
question: ['first context'],
|
||||
links: undefined
|
||||
}
|
||||
},
|
||||
'request-1'
|
||||
)
|
||||
expect(firstResult).toBe(searchResponse)
|
||||
expect(secondResult).toBe(searchResponse)
|
||||
})
|
||||
})
|
||||
@@ -1,4 +1,5 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { BUILTIN_WEB_SEARCH_TOOL_NAME } from '@renderer/aiCore/tools/WebSearchTool'
|
||||
import type { AppDispatch } from '@renderer/store'
|
||||
import store from '@renderer/store'
|
||||
import { toolPermissionsActions } from '@renderer/store/toolPermissions'
|
||||
@@ -156,7 +157,7 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
|
||||
}
|
||||
blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true)
|
||||
// Handle citation block creation for web search results
|
||||
if (toolResponse.tool.name === 'builtin_web_search' && toolResponse.response) {
|
||||
if (toolResponse.tool.name === BUILTIN_WEB_SEARCH_TOOL_NAME && toolResponse.response) {
|
||||
const citationBlock = createCitationBlock(
|
||||
assistantMsgId,
|
||||
{
|
||||
|
||||
17
src/renderer/src/utils/__tests__/url.test.ts
Normal file
17
src/renderer/src/utils/__tests__/url.test.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import { getUrlOriginOrFallback } from '../url'
|
||||
|
||||
describe('url utils', () => {
|
||||
it('returns only the origin for valid urls', () => {
|
||||
expect(getUrlOriginOrFallback('https://example.com/path?utm_source=newsletter#details')).toBe('https://example.com')
|
||||
})
|
||||
|
||||
it('preserves ports in the origin', () => {
|
||||
expect(getUrlOriginOrFallback('https://example.com:8443/path')).toBe('https://example.com:8443')
|
||||
})
|
||||
|
||||
it('returns the original value for invalid urls', () => {
|
||||
expect(getUrlOriginOrFallback('not a url')).toBe('not a url')
|
||||
})
|
||||
})
|
||||
@@ -222,3 +222,4 @@ export * from './match'
|
||||
export * from './naming'
|
||||
export * from './sort'
|
||||
export * from './style'
|
||||
export * from './url'
|
||||
|
||||
7
src/renderer/src/utils/url.ts
Normal file
7
src/renderer/src/utils/url.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
export function getUrlOriginOrFallback(url: string): string {
|
||||
try {
|
||||
return new URL(url).origin
|
||||
} catch {
|
||||
return url
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user